From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: Fri, 13 Mar 2026 17:41:41 +0800
Subject: [PATCH] 提交

---
 funasr/models/sa_asr/e2e_sa_asr.py |  191 ++++++++++++++++++++++++-----------------------
 1 file changed, 97 insertions(+), 94 deletions(-)

diff --git a/funasr/models/sa_asr/e2e_sa_asr.py b/funasr/models/sa_asr/e2e_sa_asr.py
index e0cb69a..f4827b2 100644
--- a/funasr/models/sa_asr/e2e_sa_asr.py
+++ b/funasr/models/sa_asr/e2e_sa_asr.py
@@ -14,19 +14,17 @@
 import torch.nn.functional as F
 
 from funasr.layers.abs_normalize import AbsNormalize
-from funasr.losses.label_smoothing_loss import (
-    LabelSmoothingLoss, NllLoss  # noqa: H301
-)
+from funasr.losses.label_smoothing_loss import LabelSmoothingLoss, NllLoss  # noqa: H301
 from funasr.models.ctc import CTC
 from funasr.models.decoder.abs_decoder import AbsDecoder
 from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.frontends.abs_frontend import AbsFrontend
 from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
 from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
 from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models.transformer.add_sos_eos import add_sos_eos
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
 from funasr.metrics import ErrorCalculator
-from funasr.models.transformer.utils.nets_utils import th_accuracy
+from funasr.metrics.compute_acc import th_accuracy
 from funasr.train_utils.device_funcs import force_gatherable
 from funasr.models.base_model import FunASRModel
 
@@ -43,28 +41,28 @@
     """CTC-attention hybrid Encoder-Decoder model"""
 
     def __init__(
-            self,
-            vocab_size: int,
-            max_spk_num: int,
-            token_list: Union[Tuple[str, ...], List[str]],
-            frontend: Optional[AbsFrontend],
-            specaug: Optional[AbsSpecAug],
-            normalize: Optional[AbsNormalize],
-            asr_encoder: AbsEncoder,
-            spk_encoder: torch.nn.Module,
-            decoder: AbsDecoder,
-            ctc: CTC,
-            spk_weight: float = 0.5,
-            ctc_weight: float = 0.5,
-            interctc_weight: float = 0.0,
-            ignore_id: int = -1,
-            lsm_weight: float = 0.0,
-            length_normalized_loss: bool = False,
-            report_cer: bool = True,
-            report_wer: bool = True,
-            sym_space: str = "<space>",
-            sym_blank: str = "<blank>",
-            extract_feats_in_collect_stats: bool = True,
+        self,
+        vocab_size: int,
+        max_spk_num: int,
+        token_list: Union[Tuple[str, ...], List[str]],
+        frontend: Optional[AbsFrontend],
+        specaug: Optional[AbsSpecAug],
+        normalize: Optional[AbsNormalize],
+        asr_encoder: AbsEncoder,
+        spk_encoder: torch.nn.Module,
+        decoder: AbsDecoder,
+        ctc: CTC,
+        spk_weight: float = 0.5,
+        ctc_weight: float = 0.5,
+        interctc_weight: float = 0.0,
+        ignore_id: int = -1,
+        lsm_weight: float = 0.0,
+        length_normalized_loss: bool = False,
+        report_cer: bool = True,
+        report_wer: bool = True,
+        sym_space: str = "<space>",
+        sym_blank: str = "<blank>",
+        extract_feats_in_collect_stats: bool = True,
     ):
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight
         assert 0.0 <= interctc_weight < 1.0, interctc_weight
@@ -75,7 +73,7 @@
         self.sos = 1
         self.eos = 2
         self.vocab_size = vocab_size
-        self.max_spk_num=max_spk_num
+        self.max_spk_num = max_spk_num
         self.ignore_id = ignore_id
         self.spk_weight = spk_weight
         self.ctc_weight = ctc_weight
@@ -96,7 +94,6 @@
             )
 
         self.error_calculator = None
-
 
         # we set self.decoder = None in the CTC mode since
         # self.decoder parameters were never used and PyTorch complained
@@ -133,15 +130,15 @@
         self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
 
     def forward(
-            self,
-            speech: torch.Tensor,
-            speech_lengths: torch.Tensor,
-            text: torch.Tensor,
-            text_lengths: torch.Tensor,
-            profile: torch.Tensor,
-            profile_lengths: torch.Tensor,
-            text_id: torch.Tensor,
-            text_id_lengths: torch.Tensor
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        profile: torch.Tensor,
+        profile_lengths: torch.Tensor,
+        text_id: torch.Tensor,
+        text_id_lengths: torch.Tensor,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Frontend + Encoder + Decoder + Calc loss
 
@@ -156,10 +153,7 @@
         assert text_lengths.dim() == 1, text_lengths.shape
         # Check that batch_size is unified
         assert (
-                speech.shape[0]
-                == speech_lengths.shape[0]
-                == text.shape[0]
-                == text_lengths.shape[0]
+            speech.shape[0] == speech_lengths.shape[0] == text.shape[0] == text_lengths.shape[0]
         ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
         batch_size = speech.shape[0]
 
@@ -183,7 +177,6 @@
                 asr_encoder_out, encoder_out_lens, text, text_lengths
             )
 
-
         # Intermediate CTC (optional)
         loss_interctc = 0.0
         if self.interctc_weight != 0.0 and intermediate_outs is not None:
@@ -204,15 +197,20 @@
             loss_interctc = loss_interctc / len(intermediate_outs)
 
             # calculate whole encoder loss
-            loss_ctc = (
-                               1 - self.interctc_weight
-                       ) * loss_ctc + self.interctc_weight * loss_interctc
-
+            loss_ctc = (1 - self.interctc_weight) * loss_ctc + self.interctc_weight * loss_interctc
 
         # 2b. Attention decoder branch
         if self.ctc_weight != 1.0:
             loss_att, loss_spk, acc_att, acc_spk, cer_att, wer_att = self._calc_att_loss(
-                asr_encoder_out, spk_encoder_out, encoder_out_lens, text, text_lengths, profile, profile_lengths, text_id, text_id_lengths
+                asr_encoder_out,
+                spk_encoder_out,
+                encoder_out_lens,
+                text,
+                text_lengths,
+                profile,
+                profile_lengths,
+                text_id,
+                text_id_lengths,
             )
 
         # 3. CTC-Att loss definition
@@ -227,7 +225,6 @@
             loss = loss_asr
         else:
             loss = self.spk_weight * loss_spk + (1 - self.spk_weight) * loss_asr
-
 
         stats = dict(
             loss=loss.detach(),
@@ -247,11 +244,11 @@
         return loss, stats, weight
 
     def collect_feats(
-            self,
-            speech: torch.Tensor,
-            speech_lengths: torch.Tensor,
-            text: torch.Tensor,
-            text_lengths: torch.Tensor,
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
     ) -> Dict[str, torch.Tensor]:
         if self.extract_feats_in_collect_stats:
             feats, feats_lengths = self._extract_feats(speech, speech_lengths)
@@ -266,7 +263,7 @@
         return {"feats": feats, "feats_lengths": feats_lengths}
 
     def encode(
-            self, speech: torch.Tensor, speech_lengths: torch.Tensor
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Frontend + Encoder. Note that this method is used by asr_inference.py
 
@@ -291,9 +288,7 @@
         # feats: (Batch, Length, Dim)
         # -> encoder_out: (Batch, Length2, Dim2)
         if self.asr_encoder.interctc_use_conditioning:
-            encoder_out, encoder_out_lens, _ = self.asr_encoder(
-                feats, feats_lengths, ctc=self.ctc
-            )
+            encoder_out, encoder_out_lens, _ = self.asr_encoder(feats, feats_lengths, ctc=self.ctc)
         else:
             encoder_out, encoder_out_lens, _ = self.asr_encoder(feats, feats_lengths)
         intermediate_outs = None
@@ -303,10 +298,12 @@
 
         encoder_out_spk_ori = self.spk_encoder(feats_raw, feats_lengths)[0]
         # import ipdb;ipdb.set_trace()
-        if encoder_out_spk_ori.size(1)!=encoder_out.size(1):
-            encoder_out_spk=F.interpolate(encoder_out_spk_ori.transpose(-2,-1), size=(encoder_out.size(1)), mode='nearest').transpose(-2,-1)
+        if encoder_out_spk_ori.size(1) != encoder_out.size(1):
+            encoder_out_spk = F.interpolate(
+                encoder_out_spk_ori.transpose(-2, -1), size=(encoder_out.size(1)), mode="nearest"
+            ).transpose(-2, -1)
         else:
-            encoder_out_spk=encoder_out_spk_ori
+            encoder_out_spk = encoder_out_spk_ori
 
         assert encoder_out.size(0) == speech.size(0), (
             encoder_out.size(),
@@ -327,7 +324,7 @@
         return encoder_out, encoder_out_lens, encoder_out_spk
 
     def _extract_feats(
-            self, speech: torch.Tensor, speech_lengths: torch.Tensor
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         assert speech_lengths.dim() == 1, speech_lengths.shape
 
@@ -346,11 +343,11 @@
         return feats, feats_lengths
 
     def nll(
-            self,
-            encoder_out: torch.Tensor,
-            encoder_out_lens: torch.Tensor,
-            ys_pad: torch.Tensor,
-            ys_pad_lens: torch.Tensor,
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_pad_lens: torch.Tensor,
     ) -> torch.Tensor:
         """Compute negative log likelihood(nll) from transformer-decoder
 
@@ -384,12 +381,12 @@
         return nll
 
     def batchify_nll(
-            self,
-            encoder_out: torch.Tensor,
-            encoder_out_lens: torch.Tensor,
-            ys_pad: torch.Tensor,
-            ys_pad_lens: torch.Tensor,
-            batch_size: int = 100,
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_pad_lens: torch.Tensor,
+        batch_size: int = 100,
     ):
         """Compute negative log likelihood(nll) from transformer-decoder
 
@@ -431,28 +428,34 @@
         return nll
 
     def _calc_att_loss(
-            self,
-            asr_encoder_out: torch.Tensor,
-            spk_encoder_out: torch.Tensor,
-            encoder_out_lens: torch.Tensor,
-            ys_pad: torch.Tensor,
-            ys_pad_lens: torch.Tensor,
-            profile: torch.Tensor,
-            profile_lens: torch.Tensor,
-            text_id: torch.Tensor,
-            text_id_lengths: torch.Tensor
+        self,
+        asr_encoder_out: torch.Tensor,
+        spk_encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_pad_lens: torch.Tensor,
+        profile: torch.Tensor,
+        profile_lens: torch.Tensor,
+        text_id: torch.Tensor,
+        text_id_lengths: torch.Tensor,
     ):
         ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
         ys_in_lens = ys_pad_lens + 1
 
         # 1. Forward decoder
         decoder_out, weights_no_pad, _ = self.decoder(
-            asr_encoder_out, spk_encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens, profile, profile_lens
+            asr_encoder_out,
+            spk_encoder_out,
+            encoder_out_lens,
+            ys_in_pad,
+            ys_in_lens,
+            profile,
+            profile_lens,
         )
 
-        spk_num_no_pad=weights_no_pad.size(-1)
-        pad=(0,self.max_spk_num-spk_num_no_pad)
-        weights=F.pad(weights_no_pad, pad, mode='constant', value=0)
+        spk_num_no_pad = weights_no_pad.size(-1)
+        pad = (0, self.max_spk_num - spk_num_no_pad)
+        weights = F.pad(weights_no_pad, pad, mode="constant", value=0)
 
         # pre_id=weights.argmax(-1)
         # pre_text=decoder_out.argmax(-1)
@@ -467,7 +470,7 @@
         loss_att = self.criterion_att(decoder_out, ys_out_pad)
         loss_spk = self.criterion_spk(torch.log(weights), text_id)
 
-        acc_spk= th_accuracy(
+        acc_spk = th_accuracy(
             weights.view(-1, self.max_spk_num),
             text_id,
             ignore_label=self.ignore_id,
@@ -488,11 +491,11 @@
         return loss_att, loss_spk, acc_att, acc_spk, cer_att, wer_att
 
     def _calc_ctc_loss(
-            self,
-            encoder_out: torch.Tensor,
-            encoder_out_lens: torch.Tensor,
-            ys_pad: torch.Tensor,
-            ys_pad_lens: torch.Tensor,
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_pad_lens: torch.Tensor,
     ):
         # Calc CTC loss
         loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)

--
Gitblit v1.9.1