funasr/train_utils/trainer.py
@@ -417,7 +417,7 @@ # Apply weighted averaging for loss and stats loss = (loss * weight.type(loss.dtype)).sum() # if distributed, this method can also apply all_reduce() stats, weight = recursive_average(stats, weight, distributed=True) # stats, weight = recursive_average(stats, weight, distributed=True) if self.use_ddp or self.use_fsdp: dist.all_reduce(weight, op=dist.ReduceOp.SUM) # Now weight is summation over all workers