funasr/bin/train.py
@@ -154,7 +154,7 @@ if batch_sampler is not None: batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler) batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf")) batch_sampler_val = batch_sampler_class(dataset_tr, is_training=False, **kwargs.get("dataset_conf")) batch_sampler_val = batch_sampler_class(dataset_val, is_training=False, **kwargs.get("dataset_conf")) dataloader_tr = torch.utils.data.DataLoader(dataset_tr, collate_fn=dataset_tr.collator, batch_sampler=batch_sampler,