kongdeqiang
2026-03-13 28ccfbfc51068a663a80764e14074df5edf2b5ba
funasr/models/sa_asr/e2e_sa_asr.py
@@ -14,19 +14,17 @@
import torch.nn.functional as F
from funasr.layers.abs_normalize import AbsNormalize
-from funasr.losses.label_smoothing_loss import (
-    LabelSmoothingLoss, NllLoss  # noqa: H301
-)
+from funasr.losses.label_smoothing_loss import LabelSmoothingLoss, NllLoss  # noqa: H301
from funasr.models.ctc import CTC
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.frontends.abs_frontend import AbsFrontend
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models.transformer.add_sos_eos import add_sos_eos
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.metrics import ErrorCalculator
-from funasr.models.transformer.utils.nets_utils import th_accuracy
+from funasr.metrics.compute_acc import th_accuracy
from funasr.train_utils.device_funcs import force_gatherable
from funasr.models.base_model import FunASRModel
@@ -43,28 +41,28 @@
    """CTC-attention hybrid Encoder-Decoder model"""
    def __init__(
-            self,
-            vocab_size: int,
-            max_spk_num: int,
-            token_list: Union[Tuple[str, ...], List[str]],
-            frontend: Optional[AbsFrontend],
-            specaug: Optional[AbsSpecAug],
-            normalize: Optional[AbsNormalize],
-            asr_encoder: AbsEncoder,
-            spk_encoder: torch.nn.Module,
-            decoder: AbsDecoder,
-            ctc: CTC,
-            spk_weight: float = 0.5,
-            ctc_weight: float = 0.5,
-            interctc_weight: float = 0.0,
-            ignore_id: int = -1,
-            lsm_weight: float = 0.0,
-            length_normalized_loss: bool = False,
-            report_cer: bool = True,
-            report_wer: bool = True,
-            sym_space: str = "<space>",
-            sym_blank: str = "<blank>",
-            extract_feats_in_collect_stats: bool = True,
+        self,
+        vocab_size: int,
+        max_spk_num: int,
+        token_list: Union[Tuple[str, ...], List[str]],
+        frontend: Optional[AbsFrontend],
+        specaug: Optional[AbsSpecAug],
+        normalize: Optional[AbsNormalize],
+        asr_encoder: AbsEncoder,
+        spk_encoder: torch.nn.Module,
+        decoder: AbsDecoder,
+        ctc: CTC,
+        spk_weight: float = 0.5,
+        ctc_weight: float = 0.5,
+        interctc_weight: float = 0.0,
+        ignore_id: int = -1,
+        lsm_weight: float = 0.0,
+        length_normalized_loss: bool = False,
+        report_cer: bool = True,
+        report_wer: bool = True,
+        sym_space: str = "<space>",
+        sym_blank: str = "<blank>",
+        extract_feats_in_collect_stats: bool = True,
    ):
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
        assert 0.0 <= interctc_weight < 1.0, interctc_weight
@@ -75,7 +73,7 @@
        self.sos = 1
        self.eos = 2
        self.vocab_size = vocab_size
-        self.max_spk_num=max_spk_num
+        self.max_spk_num = max_spk_num
        self.ignore_id = ignore_id
        self.spk_weight = spk_weight
        self.ctc_weight = ctc_weight
@@ -96,7 +94,6 @@
            )
        self.error_calculator = None
        # we set self.decoder = None in the CTC mode since
        # self.decoder parameters were never used and PyTorch complained
@@ -133,15 +130,15 @@
        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
    def forward(
-            self,
-            speech: torch.Tensor,
-            speech_lengths: torch.Tensor,
-            text: torch.Tensor,
-            text_lengths: torch.Tensor,
-            profile: torch.Tensor,
-            profile_lengths: torch.Tensor,
-            text_id: torch.Tensor,
-            text_id_lengths: torch.Tensor
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        profile: torch.Tensor,
+        profile_lengths: torch.Tensor,
+        text_id: torch.Tensor,
+        text_id_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss
@@ -156,10 +153,7 @@
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
-                speech.shape[0]
-                == speech_lengths.shape[0]
-                == text.shape[0]
-                == text_lengths.shape[0]
+            speech.shape[0] == speech_lengths.shape[0] == text.shape[0] == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]
@@ -183,7 +177,6 @@
                asr_encoder_out, encoder_out_lens, text, text_lengths
            )
        # Intermediate CTC (optional)
        loss_interctc = 0.0
        if self.interctc_weight != 0.0 and intermediate_outs is not None:
@@ -204,15 +197,20 @@
            loss_interctc = loss_interctc / len(intermediate_outs)
            # calculate whole encoder loss
-            loss_ctc = (
-                               1 - self.interctc_weight
-                       ) * loss_ctc + self.interctc_weight * loss_interctc
+            loss_ctc = (1 - self.interctc_weight) * loss_ctc + self.interctc_weight * loss_interctc
        # 2b. Attention decoder branch
        if self.ctc_weight != 1.0:
            loss_att, loss_spk, acc_att, acc_spk, cer_att, wer_att = self._calc_att_loss(
-                asr_encoder_out, spk_encoder_out, encoder_out_lens, text, text_lengths, profile, profile_lengths, text_id, text_id_lengths
+                asr_encoder_out,
+                spk_encoder_out,
+                encoder_out_lens,
+                text,
+                text_lengths,
+                profile,
+                profile_lengths,
+                text_id,
+                text_id_lengths,
            )
        # 3. CTC-Att loss definition
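
Note: the reflowed interpolation in the hunk above is the standard intermediate-CTC objective, i.e. average the per-layer losses, then blend with the final-layer loss. A minimal standalone sketch with illustrative values (none of these names come from the file):

import torch

interctc_weight = 0.3  # illustrative; the model's default is 0.0
loss_ctc = torch.tensor(2.0)  # final-layer CTC loss
intermediate_losses = [torch.tensor(2.4), torch.tensor(2.2)]  # one per tapped layer

# average the intermediate losses, then blend with the final-layer loss
loss_interctc = sum(intermediate_losses) / len(intermediate_losses)
loss_ctc = (1 - interctc_weight) * loss_ctc + interctc_weight * loss_interctc
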
@@ -227,7 +225,6 @@
            loss = loss_asr
        else:
            loss = self.spk_weight * loss_spk + (1 - self.spk_weight) * loss_asr
        stats = dict(
            loss=loss.detach(),
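
The `else` branch above interpolates the speaker loss with the ASR loss via spk_weight. A sketch of the full combination, assuming (it is elided from this hunk) that loss_asr is the usual CTC/attention blend:

def combine_losses(loss_att, loss_ctc, loss_spk, ctc_weight=0.5, spk_weight=0.5):
    # assumed CTC/attention hybrid; its definition sits outside this hunk
    loss_asr = ctc_weight * loss_ctc + (1 - ctc_weight) * loss_att
    # speaker-attributed total, matching the `else` branch above
    return spk_weight * loss_spk + (1 - spk_weight) * loss_asr
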
@@ -247,11 +244,11 @@
        return loss, stats, weight
    def collect_feats(
-            self,
-            speech: torch.Tensor,
-            speech_lengths: torch.Tensor,
-            text: torch.Tensor,
-            text_lengths: torch.Tensor,
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        if self.extract_feats_in_collect_stats:
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
@@ -266,7 +263,7 @@
        return {"feats": feats, "feats_lengths": feats_lengths}
    def encode(
-            self, speech: torch.Tensor, speech_lengths: torch.Tensor
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py
@@ -291,9 +288,7 @@
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        if self.asr_encoder.interctc_use_conditioning:
-            encoder_out, encoder_out_lens, _ = self.asr_encoder(
-                feats, feats_lengths, ctc=self.ctc
-            )
+            encoder_out, encoder_out_lens, _ = self.asr_encoder(feats, feats_lengths, ctc=self.ctc)
        else:
            encoder_out, encoder_out_lens, _ = self.asr_encoder(feats, feats_lengths)
        intermediate_outs = None
@@ -303,10 +298,12 @@
        encoder_out_spk_ori = self.spk_encoder(feats_raw, feats_lengths)[0]
        # import ipdb;ipdb.set_trace()
-        if encoder_out_spk_ori.size(1)!=encoder_out.size(1):
-            encoder_out_spk=F.interpolate(encoder_out_spk_ori.transpose(-2,-1), size=(encoder_out.size(1)), mode='nearest').transpose(-2,-1)
+        if encoder_out_spk_ori.size(1) != encoder_out.size(1):
+            encoder_out_spk = F.interpolate(
+                encoder_out_spk_ori.transpose(-2, -1), size=(encoder_out.size(1)), mode="nearest"
+            ).transpose(-2, -1)
        else:
-            encoder_out_spk=encoder_out_spk_ori
+            encoder_out_spk = encoder_out_spk_ori
        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
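
The branch reformatted above resamples the speaker-encoder output along time so it lines up frame-for-frame with the ASR encoder output, since the two encoders may run at different frame rates. A self-contained sketch of the same nearest-neighbour resampling (shapes are invented):

import torch
import torch.nn.functional as F

asr_out = torch.randn(2, 120, 256)  # (batch, T_asr, dim)
spk_out = torch.randn(2, 98, 256)   # (batch, T_spk, dim), different frame rate

if spk_out.size(1) != asr_out.size(1):
    # F.interpolate resamples the trailing dim, so swap time to the end first
    spk_out = F.interpolate(
        spk_out.transpose(-2, -1), size=asr_out.size(1), mode="nearest"
    ).transpose(-2, -1)

assert spk_out.shape == (2, 120, 256)
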
@@ -327,7 +324,7 @@
        return encoder_out, encoder_out_lens, encoder_out_spk
    def _extract_feats(
-            self, speech: torch.Tensor, speech_lengths: torch.Tensor
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        assert speech_lengths.dim() == 1, speech_lengths.shape
@@ -346,11 +343,11 @@
        return feats, feats_lengths
    def nll(
-            self,
-            encoder_out: torch.Tensor,
-            encoder_out_lens: torch.Tensor,
-            ys_pad: torch.Tensor,
-            ys_pad_lens: torch.Tensor,
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_pad_lens: torch.Tensor,
    ) -> torch.Tensor:
        """Compute negative log likelihood(nll) from transformer-decoder
@@ -384,12 +381,12 @@
        return nll
    def batchify_nll(
-            self,
-            encoder_out: torch.Tensor,
-            encoder_out_lens: torch.Tensor,
-            ys_pad: torch.Tensor,
-            ys_pad_lens: torch.Tensor,
-            batch_size: int = 100,
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_pad_lens: torch.Tensor,
+        batch_size: int = 100,
    ):
        """Compute negative log likelihood(nll) from transformer-decoder
@@ -431,28 +428,34 @@
        return nll
    def _calc_att_loss(
-            self,
-            asr_encoder_out: torch.Tensor,
-            spk_encoder_out: torch.Tensor,
-            encoder_out_lens: torch.Tensor,
-            ys_pad: torch.Tensor,
-            ys_pad_lens: torch.Tensor,
-            profile: torch.Tensor,
-            profile_lens: torch.Tensor,
-            text_id: torch.Tensor,
-            text_id_lengths: torch.Tensor
+        self,
+        asr_encoder_out: torch.Tensor,
+        spk_encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_pad_lens: torch.Tensor,
+        profile: torch.Tensor,
+        profile_lens: torch.Tensor,
+        text_id: torch.Tensor,
+        text_id_lengths: torch.Tensor,
    ):
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1
        # 1. Forward decoder
        decoder_out, weights_no_pad, _ = self.decoder(
-            asr_encoder_out, spk_encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens, profile, profile_lens
+            asr_encoder_out,
+            spk_encoder_out,
+            encoder_out_lens,
+            ys_in_pad,
+            ys_in_lens,
+            profile,
+            profile_lens,
        )
-        spk_num_no_pad=weights_no_pad.size(-1)
-        pad=(0,self.max_spk_num-spk_num_no_pad)
-        weights=F.pad(weights_no_pad, pad, mode='constant', value=0)
+        spk_num_no_pad = weights_no_pad.size(-1)
+        pad = (0, self.max_spk_num - spk_num_no_pad)
+        weights = F.pad(weights_no_pad, pad, mode="constant", value=0)
        # pre_id=weights.argmax(-1)
        # pre_text=decoder_out.argmax(-1)
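
The decoder emits speaker weights only for the speakers actually enrolled in the profile, so the hunk above right-pads the speaker axis with zeros up to max_spk_num before the loss is computed. A minimal sketch of that padding (shapes are invented):

import torch
import torch.nn.functional as F

max_spk_num = 4
weights_no_pad = torch.softmax(torch.randn(2, 10, 3), dim=-1)  # 3 enrolled speakers

# F.pad with a (left, right) pair pads the last dimension only
pad = (0, max_spk_num - weights_no_pad.size(-1))
weights = F.pad(weights_no_pad, pad, mode="constant", value=0)
assert weights.shape == (2, 10, 4)
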
@@ -467,7 +470,7 @@
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        loss_spk = self.criterion_spk(torch.log(weights), text_id)
-        acc_spk= th_accuracy(
+        acc_spk = th_accuracy(
            weights.view(-1, self.max_spk_num),
            text_id,
            ignore_label=self.ignore_id,
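
criterion_spk is fed torch.log(weights), i.e. it expects log-probabilities. The file's NllLoss is FunASR's own class (imported at the top), but torch.nn.NLLLoss illustrates the same contract, including skipping padded targets; values below are illustrative:

import torch

ignore_id = -1
criterion = torch.nn.NLLLoss(ignore_index=ignore_id)

weights = torch.softmax(torch.randn(2, 10, 4), dim=-1)  # (batch, T, max_spk_num)
text_id = torch.randint(0, 4, (2, 10))                  # per-token speaker labels
text_id[0, -2:] = ignore_id                             # padded positions are skipped

# NLLLoss wants (N, C) inputs, so flatten batch and time
loss_spk = criterion(torch.log(weights).view(-1, 4), text_id.view(-1))
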
@@ -488,11 +491,11 @@
        return loss_att, loss_spk, acc_att, acc_spk, cer_att, wer_att
    def _calc_ctc_loss(
-            self,
-            encoder_out: torch.Tensor,
-            encoder_out_lens: torch.Tensor,
-            ys_pad: torch.Tensor,
-            ys_pad_lens: torch.Tensor,
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_pad_lens: torch.Tensor,
    ):
        # Calc CTC loss
        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
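
The CTC branch delegates to self.ctc, FunASR's wrapper module. A standalone equivalent with torch.nn.CTCLoss, making the log-softmax and tensor layout explicit (shapes invented; the wrapper handles this bookkeeping internally):

import torch

B, T, V, U = 2, 50, 30, 12
encoder_out = torch.randn(B, T, V)                       # encoder logits
encoder_out_lens = torch.full((B,), T, dtype=torch.long)
ys_pad = torch.randint(1, V, (B, U))                     # labels; 0 reserved for blank
ys_pad_lens = torch.full((B,), U, dtype=torch.long)

# CTCLoss consumes (T, B, V) log-probabilities
ctc = torch.nn.CTCLoss(blank=0, zero_infinity=True)
log_probs = encoder_out.log_softmax(-1).transpose(0, 1)
loss_ctc = ctc(log_probs, ys_pad, encoder_out_lens, ys_pad_lens)
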