zhifu gao
2024-02-01 2ddfc27d5b69e5c1a38021554e97dae958328c20
funasr/bin/train.py
@@ -154,7 +154,7 @@
    if batch_sampler is not None:
        batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
        batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
        batch_sampler_val = batch_sampler_class(dataset_tr, is_training=False, **kwargs.get("dataset_conf"))
        batch_sampler_val = batch_sampler_class(dataset_val, is_training=False, **kwargs.get("dataset_conf"))
    dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
                                                collate_fn=dataset_tr.collator,
                                                batch_sampler=batch_sampler,