游雁
2024-03-24 5d74e107fc5696b70e75003c278f8babd17161e8
update
1个文件已修改
2 ■■■ 已修改文件
funasr/train_utils/trainer.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/train_utils/trainer.py
@@ -417,7 +417,7 @@
                    # Apply weighted averaging for loss and stats
                    loss = (loss * weight.type(loss.dtype)).sum()
                    # if distributed, this method can also apply all_reduce()
-                    stats, weight = recursive_average(stats, weight, distributed=True)
+                    # stats, weight = recursive_average(stats, weight, distributed=True)
                    if self.use_ddp or self.use_fsdp:
                        dist.all_reduce(weight, op=dist.ReduceOp.SUM)
                    # Now weight is summation over all workers