zhifu gao
2024-04-29 11cf10e433c173efd892766b669e0bba57253fed
funasr/models/sense_voice/model.py
@@ -329,6 +329,8 @@
        stats["loss"] = torch.clone(loss.detach())
        stats["batch_size"] = batch_size
        stats["batch_size_x_frames"] = frames * batch_size
        stats["batch_size_real_frames"] = speech_lengths.sum().item()
        stats["padding_frames"] = stats["batch_size_x_frames"] - stats["batch_size_real_frames"]
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss: