游雁
2023-12-21 c8bae0ec85eee25d66de6b1e4502eff74d750b24
funasr/bin/train.py
@@ -145,7 +145,8 @@
   # dataloader
   batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
   batch_sampler_class = registry_tables.batch_sampler_classes.get(batch_sampler.lower())
   batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
   if batch_sampler is not None:
      batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
   dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
                                               collate_fn=dataset_tr.collator,
                                               batch_sampler=batch_sampler,
@@ -153,7 +154,6 @@
                                               pin_memory=True)
   
   trainer = Trainer(
       model=model,
       optim=optim,