From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365

---
 funasr/models/mfcca/e2e_asr_mfcca.py |  124 ++++++++++++++++++++---------------------
 1 files changed, 61 insertions(+), 63 deletions(-)

diff --git a/funasr/models/mfcca/e2e_asr_mfcca.py b/funasr/models/mfcca/e2e_asr_mfcca.py
index 5ec9e94..681d1f6 100644
--- a/funasr/models/mfcca/e2e_asr_mfcca.py
+++ b/funasr/models/mfcca/e2e_asr_mfcca.py
@@ -9,15 +9,15 @@
 import torch
 
 from funasr.metrics import ErrorCalculator
-from funasr.models.transformer.utils.nets_utils import th_accuracy
-from funasr.models.transformer.add_sos_eos import add_sos_eos
+from funasr.metrics.compute_acc import th_accuracy
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
 from funasr.losses.label_smoothing_loss import (
     LabelSmoothingLoss,  # noqa: H301
 )
 from funasr.models.ctc import CTC
 from funasr.models.decoder.abs_decoder import AbsDecoder
 from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.frontends.abs_frontend import AbsFrontend
 from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
 from funasr.models.specaug.abs_specaug import AbsSpecAug
 from funasr.layers.abs_normalize import AbsNormalize
@@ -31,9 +31,12 @@
     @contextmanager
     def autocast(enabled=True):
         yield
+
+
 import pdb
 import random
 import math
+
 
 class MFCCA(FunASRModel):
     """
@@ -43,26 +46,26 @@
     """
 
     def __init__(
-            self,
-            vocab_size: int,
-            token_list: Union[Tuple[str, ...], List[str]],
-            frontend: Optional[AbsFrontend],
-            specaug: Optional[AbsSpecAug],
-            normalize: Optional[AbsNormalize],
-            encoder: AbsEncoder,
-            decoder: AbsDecoder,
-            ctc: CTC,
-            rnnt_decoder: None = None,
-            ctc_weight: float = 0.5,
-            ignore_id: int = -1,
-            lsm_weight: float = 0.0,
-            mask_ratio: float = 0.0,
-            length_normalized_loss: bool = False,
-            report_cer: bool = True,
-            report_wer: bool = True,
-            sym_space: str = "<space>",
-            sym_blank: str = "<blank>",
-            preencoder: Optional[AbsPreEncoder] = None,
+        self,
+        vocab_size: int,
+        token_list: Union[Tuple[str, ...], List[str]],
+        frontend: Optional[AbsFrontend],
+        specaug: Optional[AbsSpecAug],
+        normalize: Optional[AbsNormalize],
+        encoder: AbsEncoder,
+        decoder: AbsDecoder,
+        ctc: CTC,
+        rnnt_decoder: None = None,
+        ctc_weight: float = 0.5,
+        ignore_id: int = -1,
+        lsm_weight: float = 0.0,
+        mask_ratio: float = 0.0,
+        length_normalized_loss: bool = False,
+        report_cer: bool = True,
+        report_wer: bool = True,
+        sym_space: str = "<space>",
+        sym_blank: str = "<blank>",
+        preencoder: Optional[AbsPreEncoder] = None,
     ):
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight
         assert rnnt_decoder is None, "Not implemented"
@@ -111,11 +114,11 @@
             self.error_calculator = None
 
     def forward(
-            self,
-            speech: torch.Tensor,
-            speech_lengths: torch.Tensor,
-            text: torch.Tensor,
-            text_lengths: torch.Tensor,
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Frontend + Encoder + Decoder + Calc loss
         Args:
@@ -127,18 +130,15 @@
         assert text_lengths.dim() == 1, text_lengths.shape
         # Check that batch_size is unified
         assert (
-                speech.shape[0]
-                == speech_lengths.shape[0]
-                == text.shape[0]
-                == text_lengths.shape[0]
+            speech.shape[0] == speech_lengths.shape[0] == text.shape[0] == text_lengths.shape[0]
         ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
         # pdb.set_trace()
-        if (speech.dim() == 3 and speech.size(2) == 8 and self.mask_ratio != 0):
+        if speech.dim() == 3 and speech.size(2) == 8 and self.mask_ratio != 0:
             rate_num = random.random()
             # rate_num = 0.1
-            if (rate_num <= self.mask_ratio):
+            if rate_num <= self.mask_ratio:
                 retain_channel = math.ceil(random.random() * 8)
-                if (retain_channel > 1):
+                if retain_channel > 1:
                     speech = speech[:, :, torch.randperm(8)[0:retain_channel].sort().values]
                 else:
                     speech = speech[:, :, torch.randperm(8)[0]]
@@ -192,17 +192,17 @@
         return loss, stats, weight
 
     def collect_feats(
-            self,
-            speech: torch.Tensor,
-            speech_lengths: torch.Tensor,
-            text: torch.Tensor,
-            text_lengths: torch.Tensor,
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
     ) -> Dict[str, torch.Tensor]:
         feats, feats_lengths, channel_size = self._extract_feats(speech, speech_lengths)
         return {"feats": feats, "feats_lengths": feats_lengths}
 
     def encode(
-            self, speech: torch.Tensor, speech_lengths: torch.Tensor
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Frontend + Encoder. Note that this method is used by asr_inference.py
         Args:
@@ -230,7 +230,7 @@
             encoder_out.size(),
             speech.size(0),
         )
-        if (encoder_out.dim() == 4):
+        if encoder_out.dim() == 4:
             assert encoder_out.size(2) <= encoder_out_lens.max(), (
                 encoder_out.size(),
                 encoder_out_lens.max(),
@@ -244,7 +244,7 @@
         return encoder_out, encoder_out_lens
 
     def _extract_feats(
-            self, speech: torch.Tensor, speech_lengths: torch.Tensor
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         assert speech_lengths.dim() == 1, speech_lengths.shape
         # for data-parallel
@@ -262,19 +262,17 @@
         return feats, feats_lengths, channel_size
 
     def _calc_att_loss(
-            self,
-            encoder_out: torch.Tensor,
-            encoder_out_lens: torch.Tensor,
-            ys_pad: torch.Tensor,
-            ys_pad_lens: torch.Tensor,
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_pad_lens: torch.Tensor,
     ):
         ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
         ys_in_lens = ys_pad_lens + 1
 
         # 1. Forward decoder
-        decoder_out, _ = self.decoder(
-            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
-        )
+        decoder_out, _ = self.decoder(encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens)
 
         # 2. Compute attention loss
         loss_att = self.criterion_att(decoder_out, ys_out_pad)
@@ -294,14 +292,14 @@
         return loss_att, acc_att, cer_att, wer_att
 
     def _calc_ctc_loss(
-            self,
-            encoder_out: torch.Tensor,
-            encoder_out_lens: torch.Tensor,
-            ys_pad: torch.Tensor,
-            ys_pad_lens: torch.Tensor,
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_pad_lens: torch.Tensor,
     ):
         # Calc CTC loss
-        if (encoder_out.dim() == 4):
+        if encoder_out.dim() == 4:
             encoder_out = encoder_out.mean(1)
         loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
 
@@ -313,10 +311,10 @@
         return loss_ctc, cer_ctc
 
     def _calc_rnnt_loss(
-            self,
-            encoder_out: torch.Tensor,
-            encoder_out_lens: torch.Tensor,
-            ys_pad: torch.Tensor,
-            ys_pad_lens: torch.Tensor,
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_pad_lens: torch.Tensor,
     ):
-        raise NotImplementedError
\ No newline at end of file
+        raise NotImplementedError

--
Gitblit v1.9.1