From 5023dd04224eddd4c9a047bd946695c3932743ae Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 15 三月 2024 16:24:29 +0800
Subject: [PATCH] Dev gzf llm (#1503)

---
 funasr/datasets/llm_datasets_vicuna/samplers.py |    4 ++--
 1 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/funasr/datasets/llm_datasets_vicuna/samplers.py b/funasr/datasets/llm_datasets_vicuna/samplers.py
index fe840e2..c728d9c 100644
--- a/funasr/datasets/llm_datasets_vicuna/samplers.py
+++ b/funasr/datasets/llm_datasets_vicuna/samplers.py
@@ -142,9 +142,9 @@
     def set_epoch(self, epoch):
         self.epoch = epoch
 
-
+@tables.register("batch_sampler_classes", "CustomDistributedBatchSampler_fn")
 def CustomDistributedBatchSampler_fn(dataset, **kwargs):
-    dataloader_args = {"dataset": dataset}
+    dataloader_args = {}
     dataloader_args["batch_sampler"] = CustomDistributedBatchSampler(dataset, **kwargs)
     dataloader_args["num_workers"] = kwargs.get("num_workers", 4)
     dataloader_args["pin_memory"] = kwargs.get("pin_memory", True)

--
Gitblit v1.9.1