hnluo
2023-12-07 73959c6e8e80a6e735bbb7d63acf942a6e2f652d
funasr/models/e2e_uni_asr.py
@@ -167,6 +167,7 @@
        self.enable_maas_finetune = enable_maas_finetune
        self.freeze_encoder2 = freeze_encoder2
        self.encoder1_encoder2_joint_training = encoder1_encoder2_joint_training
        self.length_normalized_loss = length_normalized_loss
    def forward(
        self,
@@ -440,6 +441,8 @@
        stats["loss2"] = torch.clone(loss2.detach())
        stats["loss"] = torch.clone(loss.detach())
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            batch_size = (text_lengths + 1).sum().type_as(batch_size)
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight