雾聪
2024-02-04 45dc927a97831c85a422fee5ee3010b2dc0ce3c7
funasr/bin/train.py
@@ -154,7 +154,7 @@
    if batch_sampler is not None:
        batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
        batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
        batch_sampler_val = batch_sampler_class(dataset_tr, is_training=False, **kwargs.get("dataset_conf"))
        batch_sampler_val = batch_sampler_class(dataset_val, is_training=False, **kwargs.get("dataset_conf"))
    dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
                                                collate_fn=dataset_tr.collator,
                                                batch_sampler=batch_sampler,