haoneng.lhn
2023-12-08 ace42e5043be992c3a324fc637ba0594f10ad3b6
funasr/models/e2e_uni_asr.py
@@ -167,6 +167,7 @@
        self.enable_maas_finetune = enable_maas_finetune
        self.freeze_encoder2 = freeze_encoder2
        self.encoder1_encoder2_joint_training = encoder1_encoder2_joint_training
        self.length_normalized_loss = length_normalized_loss
    def forward(
        self,
@@ -440,6 +441,8 @@
        stats["loss2"] = torch.clone(loss2.detach())
        stats["loss"] = torch.clone(loss.detach())
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            batch_size = int((text_lengths + 1).sum())
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight