funasr/train_utils/trainer.py
@@ -398,7 +398,7 @@ speed_stats = {} time5 = time.perf_counter() # iterator_stop = torch.tensor(0).to(self.device) dataloader_val.batch_sampler.set_epoch(epoch) for batch_idx, batch in enumerate(dataloader_val): # if self.use_ddp or self.use_fsdp: # dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)