雾聪
2024-01-08 2acef4bdaea588adee3098a057a395937dff4e6a
funasr/models/e2e_asr.py
@@ -11,7 +11,6 @@
from typing import Union
import torch
from typeguard import check_argument_types
from funasr.layers.abs_normalize import AbsNormalize
from funasr.losses.label_smoothing_loss import (
@@ -65,7 +64,6 @@
            preencoder: Optional[AbsPreEncoder] = None,
            postencoder: Optional[AbsPostEncoder] = None,
    ):
        assert check_argument_types()
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
        assert 0.0 <= interctc_weight < 1.0, interctc_weight
@@ -124,6 +122,7 @@
            self.ctc = ctc
        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
        self.length_normalized_loss = length_normalized_loss
    def forward(
            self,
@@ -222,6 +221,8 @@
        stats["loss"] = torch.clone(loss.detach())
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            batch_size = int((text_lengths + 1).sum())
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight