游雁
2024-03-21 bd2b6f6a116f9cd4425c270942a3b45d9a7901c0
funasr/models/paraformer/model.py
@@ -231,6 +231,7 @@
        stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
        
        stats["loss"] = torch.clone(loss.detach())
        stats["batch_size"] = batch_size
        
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss: