游雁
2024-03-18 cbe2ea7e07cbf364827bd89cefc42b3f643ea3be
funasr/datasets/llm_datasets_vicuna/samplers.py
@@ -142,6 +142,15 @@
    def set_epoch(self, epoch):
        self.epoch = epoch
@tables.register("batch_sampler_classes", "CustomDistributedBatchSampler_fn")
def CustomDistributedBatchSampler_fn(dataset, **kwargs):
    """Assemble DataLoader keyword arguments around a CustomDistributedBatchSampler.

    Args:
        dataset: dataset passed as the first argument to
            ``CustomDistributedBatchSampler``.
        **kwargs: forwarded unchanged to ``CustomDistributedBatchSampler``.
            ``num_workers`` (default 4) and ``pin_memory`` (default True)
            are also read here.

    Returns:
        dict with keys ``"batch_sampler"``, ``"num_workers"`` and
        ``"pin_memory"``. It is presumably unpacked into a DataLoader by
        the caller; confirm at the call site.
    """
    # Build the sampler first so construction errors surface before the
    # option lookups, matching the original statement order.
    sampler = CustomDistributedBatchSampler(dataset, **kwargs)
    return {
        "batch_sampler": sampler,
        "num_workers": kwargs.get("num_workers", 4),
        "pin_memory": kwargs.get("pin_memory", True),
    }
@tables.register("batch_sampler_classes", "CustomDistributedBatchSampler")
class CustomDistributedBatchSampler(Sampler):
    def __init__(self, dataset,