From cbe2ea7e07cbf364827bd89cefc42b3f643ea3be Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Mon, 18 Mar 2024 23:59:09 +0800
Subject: [PATCH] paraformer streaming bugfix

---
 funasr/datasets/llm_datasets_vicuna/samplers.py |    9 +++++++++
 1 files changed, 9 insertions(+), 0 deletions(-)

diff --git a/funasr/datasets/llm_datasets_vicuna/samplers.py b/funasr/datasets/llm_datasets_vicuna/samplers.py
index 512250d..c728d9c 100644
--- a/funasr/datasets/llm_datasets_vicuna/samplers.py
+++ b/funasr/datasets/llm_datasets_vicuna/samplers.py
@@ -142,6 +142,15 @@
     def set_epoch(self, epoch):
         self.epoch = epoch
 
+@tables.register("batch_sampler_classes", "CustomDistributedBatchSampler_fn")
+def CustomDistributedBatchSampler_fn(dataset, **kwargs):
+    """Return dataloader args: a CustomDistributedBatchSampler plus worker/pin settings."""
+    return {
+        "batch_sampler": CustomDistributedBatchSampler(dataset, **kwargs),
+        "num_workers": kwargs.get("num_workers", 4),
+        "pin_memory": kwargs.get("pin_memory", True),
+    }
+
 @tables.register("batch_sampler_classes", "CustomDistributedBatchSampler")
 class CustomDistributedBatchSampler(Sampler):
     def __init__(self, dataset,

--
Gitblit v1.9.1