游雁
2024-02-05 08e6f946aa4c02e3ba3cebb47af2ac3cd5abe97b
funasr/bin/train.py
@@ -40,7 +40,7 @@
def main(**kwargs):
    print(kwargs)
    # set random seed
    tables.print()
    set_all_random_seed(kwargs.get("seed", 0))
@@ -154,7 +154,7 @@
    if batch_sampler is not None:
        batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
        batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
        batch_sampler_val = batch_sampler_class(dataset_tr, is_training=False, **kwargs.get("dataset_conf"))
        batch_sampler_val = batch_sampler_class(dataset_val, is_training=False, **kwargs.get("dataset_conf"))
    dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
                                                collate_fn=dataset_tr.collator,
                                                batch_sampler=batch_sampler,