From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: Thu, 18 Jul 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid a naming conflict with line 365

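The conflict is a variable-shadowing issue: a later assignment to 'res'
overwrites a value computed earlier that is still needed afterwards, so the
later occurrence gets a new name. A minimal, hypothetical sketch of the
pattern (the names and logic below are illustrative only, not the actual
code in e2e_sv.py):

    def stats(xs):
        res = sum(xs) / len(xs)       # mean, still needed at the end
        res = max(xs) - min(xs)       # range; reusing 'res' clobbers the mean
        return res, res               # both values are now the range

    def stats_fixed(xs):
        mean_res = sum(xs) / len(xs)  # renamed: the mean survives
        range_res = max(xs) - min(xs)
        return mean_res, range_res
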
---
 funasr/models/xvector/e2e_sv.py |   67 +++++++++++++++------------------
 1 file changed, 30 insertions(+), 37 deletions(-)

diff --git a/funasr/models/xvector/e2e_sv.py b/funasr/models/xvector/e2e_sv.py
index 3eac9ef..56c9e3b 100644
--- a/funasr/models/xvector/e2e_sv.py
+++ b/funasr/models/xvector/e2e_sv.py
@@ -20,13 +20,13 @@
 from funasr.models.ctc import CTC
 from funasr.models.decoder.abs_decoder import AbsDecoder
 from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.frontends.abs_frontend import AbsFrontend
 from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
 from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
 from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models.transformer.add_sos_eos import add_sos_eos
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
 from funasr.metrics import ErrorCalculator
-from funasr.models.transformer.utils.nets_utils import th_accuracy
+from funasr.metrics.compute_acc import th_accuracy
 from funasr.train_utils.device_funcs import force_gatherable
 from funasr.models.base_model import FunASRModel
 
@@ -43,17 +43,17 @@
     """CTC-attention hybrid Encoder-Decoder model"""
 
     def __init__(
-            self,
-            vocab_size: int,
-            token_list: Union[Tuple[str, ...], List[str]],
-            frontend: Optional[AbsFrontend],
-            specaug: Optional[AbsSpecAug],
-            normalize: Optional[AbsNormalize],
-            preencoder: Optional[AbsPreEncoder],
-            encoder: AbsEncoder,
-            postencoder: Optional[AbsPostEncoder],
-            pooling_layer: torch.nn.Module,
-            decoder: AbsDecoder,
+        self,
+        vocab_size: int,
+        token_list: Union[Tuple[str, ...], List[str]],
+        frontend: Optional[AbsFrontend],
+        specaug: Optional[AbsSpecAug],
+        normalize: Optional[AbsNormalize],
+        preencoder: Optional[AbsPreEncoder],
+        encoder: AbsEncoder,
+        postencoder: Optional[AbsPostEncoder],
+        pooling_layer: torch.nn.Module,
+        decoder: AbsDecoder,
     ):
 
         super().__init__()
@@ -71,11 +71,11 @@
         self.decoder = decoder
 
     def forward(
-            self,
-            speech: torch.Tensor,
-            speech_lengths: torch.Tensor,
-            text: torch.Tensor,
-            text_lengths: torch.Tensor,
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Frontend + Encoder + Decoder + Calc loss
         Args:
@@ -87,10 +87,7 @@
         assert text_lengths.dim() == 1, text_lengths.shape
         # Check that batch_size is unified
         assert (
-                speech.shape[0]
-                == speech_lengths.shape[0]
-                == text.shape[0]
-                == text_lengths.shape[0]
+            speech.shape[0] == speech_lengths.shape[0] == text.shape[0] == text_lengths.shape[0]
         ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
         batch_size = speech.shape[0]
 
@@ -139,9 +136,7 @@
             loss_interctc = loss_interctc / len(intermediate_outs)
 
             # calculate whole encoder loss
-            loss_ctc = (
-                               1 - self.interctc_weight
-                       ) * loss_ctc + self.interctc_weight * loss_interctc
+            loss_ctc = (1 - self.interctc_weight) * loss_ctc + self.interctc_weight * loss_interctc
 
         if self.use_transducer_decoder:
             # 2a. Transducer decoder branch
@@ -196,11 +191,11 @@
         return loss, stats, weight
 
     def collect_feats(
-            self,
-            speech: torch.Tensor,
-            speech_lengths: torch.Tensor,
-            text: torch.Tensor,
-            text_lengths: torch.Tensor,
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
     ) -> Dict[str, torch.Tensor]:
         if self.extract_feats_in_collect_stats:
             feats, feats_lengths = self._extract_feats(speech, speech_lengths)
@@ -215,7 +210,7 @@
         return {"feats": feats, "feats_lengths": feats_lengths}
 
     def encode(
-            self, speech: torch.Tensor, speech_lengths: torch.Tensor
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Frontend + Encoder. Note that this method is used by asr_inference.py
         Args:
@@ -244,14 +239,12 @@
 
         # Post-encoder, e.g. NLU
         if self.postencoder is not None:
-            encoder_out, encoder_out_lens = self.postencoder(
-                encoder_out, encoder_out_lens
-            )
+            encoder_out, encoder_out_lens = self.postencoder(encoder_out, encoder_out_lens)
 
         return encoder_out, encoder_out_lens
 
     def _extract_feats(
-            self, speech: torch.Tensor, speech_lengths: torch.Tensor
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         assert speech_lengths.dim() == 1, speech_lengths.shape
 
@@ -267,4 +260,4 @@
         else:
             # No frontend and no feature extract
             feats, feats_lengths = speech, speech_lengths
-        return feats, feats_lengths
\ No newline at end of file
+        return feats, feats_lengths

--
Gitblit v1.9.1