雾聪
2024-01-01 f047ff8aa4fb11215a3eb5a425fee77ff72e5dba
funasr/datasets/dataloader_fn.py
@@ -38,13 +38,16 @@
# Batch sampler handed to the DataLoader in the `__main__` smoke test below.
# NOTE(review): `BatchSampler` and `dataset` are defined earlier in this file
# (outside this chunk); presumably the sampler yields lists of dataset
# indices, one list per batch — confirm against its definition.
batch_sampler = BatchSampler(dataset)
def collator(samples: "list | None" = None):
    """Identity collate function: return ``samples`` unchanged.

    When used as a DataLoader ``collate_fn``, each batch is therefore the raw
    list of items fetched from the dataset, with no stacking or padding.

    Args:
        samples: The items making up one batch, or ``None``.

    Returns:
        The very same object that was passed in (``None`` by default).
    """
    # The annotation is quoted so the ``X | None`` form is valid on Python
    # versions before 3.10 (it is never evaluated at runtime).
    return samples
if __name__ == "__main__":
    # Smoke test: wrap the module-level `dataset` / `batch_sampler` in a
    # DataLoader and report the dataset size.
    #
    # Each keyword may be given only once; repeating `num_workers` is a
    # SyntaxError that prevents this module from being imported at all.
    # `num_workers=8` loads batches in worker processes (use 0 to debug in
    # the main process).
    # `shuffle` is not passed: with `batch_sampler` set, batch order comes
    # from the sampler, and DataLoader rejects `shuffle=True` in that case.
    dataloader_tr = torch.utils.data.DataLoader(
        dataset,
        collate_fn=dataset.collator,
        batch_sampler=batch_sampler,
        num_workers=8,
        pin_memory=True,
    )

    print(len(dataset))