zhifu gao
2024-03-21 3ac03e448b7673604eb86f619b27521fca55f34d
funasr/models/paraformer/model.py
@@ -231,6 +231,7 @@
        stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
        
        stats["loss"] = torch.clone(loss.detach())
        stats["batch_size"] = batch_size
        
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss: