雾聪
2024-01-08 2acef4bdaea588adee3098a057a395937dff4e6a
funasr/models/e2e_uni_asr.py
@@ -8,7 +8,6 @@
from typing import Union
import torch
from typeguard import check_argument_types
from funasr.models.e2e_asr_common import ErrorCalculator
from funasr.modules.nets_utils import th_accuracy
@@ -82,7 +81,6 @@
        postencoder: Optional[AbsPostEncoder] = None,
        encoder1_encoder2_joint_training: bool = True,
    ):
        assert check_argument_types()
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
        assert 0.0 <= interctc_weight < 1.0, interctc_weight
@@ -169,6 +167,7 @@
        self.enable_maas_finetune = enable_maas_finetune
        self.freeze_encoder2 = freeze_encoder2
        self.encoder1_encoder2_joint_training = encoder1_encoder2_joint_training
        self.length_normalized_loss = length_normalized_loss
    def forward(
        self,
@@ -442,6 +441,8 @@
        stats["loss2"] = torch.clone(loss2.detach())
        stats["loss"] = torch.clone(loss.detach())
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            batch_size = int((text_lengths + 1).sum())
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight