funasr/datasets/llm_datasets_vicuna/samplers.py
@@ -142,9 +142,9 @@ def set_epoch(self, epoch): self.epoch = epoch @tables.register("batch_sampler_classes", "CustomDistributedBatchSampler_fn") def CustomDistributedBatchSampler_fn(dataset, **kwargs): dataloader_args = {"dataset": dataset} dataloader_args = {} dataloader_args["batch_sampler"] = CustomDistributedBatchSampler(dataset, **kwargs) dataloader_args["num_workers"] = kwargs.get("num_workers", 4) dataloader_args["pin_memory"] = kwargs.get("pin_memory", True)