From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: Thu, 18 Jul 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid a naming conflict with line 365
---
funasr/models/xvector/e2e_sv.py | 67 +++++++++++++++------------------
 1 file changed, 30 insertions(+), 37 deletions(-)
diff --git a/funasr/models/xvector/e2e_sv.py b/funasr/models/xvector/e2e_sv.py
index 3eac9ef..56c9e3b 100644
--- a/funasr/models/xvector/e2e_sv.py
+++ b/funasr/models/xvector/e2e_sv.py
@@ -20,13 +20,13 @@
from funasr.models.ctc import CTC
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.frontends.abs_frontend import AbsFrontend
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models.transformer.add_sos_eos import add_sos_eos
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.metrics import ErrorCalculator
-from funasr.models.transformer.utils.nets_utils import th_accuracy
+from funasr.metrics.compute_acc import th_accuracy
from funasr.train_utils.device_funcs import force_gatherable
from funasr.models.base_model import FunASRModel
@@ -43,17 +43,17 @@
"""CTC-attention hybrid Encoder-Decoder model"""
def __init__(
- self,
- vocab_size: int,
- token_list: Union[Tuple[str, ...], List[str]],
- frontend: Optional[AbsFrontend],
- specaug: Optional[AbsSpecAug],
- normalize: Optional[AbsNormalize],
- preencoder: Optional[AbsPreEncoder],
- encoder: AbsEncoder,
- postencoder: Optional[AbsPostEncoder],
- pooling_layer: torch.nn.Module,
- decoder: AbsDecoder,
+ self,
+ vocab_size: int,
+ token_list: Union[Tuple[str, ...], List[str]],
+ frontend: Optional[AbsFrontend],
+ specaug: Optional[AbsSpecAug],
+ normalize: Optional[AbsNormalize],
+ preencoder: Optional[AbsPreEncoder],
+ encoder: AbsEncoder,
+ postencoder: Optional[AbsPostEncoder],
+ pooling_layer: torch.nn.Module,
+ decoder: AbsDecoder,
):
super().__init__()
@@ -71,11 +71,11 @@
self.decoder = decoder
def forward(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- text: torch.Tensor,
- text_lengths: torch.Tensor,
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
@@ -87,10 +87,7 @@
assert text_lengths.dim() == 1, text_lengths.shape
# Check that batch_size is unified
assert (
- speech.shape[0]
- == speech_lengths.shape[0]
- == text.shape[0]
- == text_lengths.shape[0]
+ speech.shape[0] == speech_lengths.shape[0] == text.shape[0] == text_lengths.shape[0]
), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
batch_size = speech.shape[0]
@@ -139,9 +136,7 @@
loss_interctc = loss_interctc / len(intermediate_outs)
# calculate whole encoder loss
- loss_ctc = (
- 1 - self.interctc_weight
- ) * loss_ctc + self.interctc_weight * loss_interctc
+ loss_ctc = (1 - self.interctc_weight) * loss_ctc + self.interctc_weight * loss_interctc
if self.use_transducer_decoder:
# 2a. Transducer decoder branch
@@ -196,11 +191,11 @@
return loss, stats, weight
def collect_feats(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- text: torch.Tensor,
- text_lengths: torch.Tensor,
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
) -> Dict[str, torch.Tensor]:
if self.extract_feats_in_collect_stats:
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
@@ -215,7 +210,7 @@
return {"feats": feats, "feats_lengths": feats_lengths}
def encode(
- self, speech: torch.Tensor, speech_lengths: torch.Tensor
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
@@ -244,14 +239,12 @@
# Post-encoder, e.g. NLU
if self.postencoder is not None:
- encoder_out, encoder_out_lens = self.postencoder(
- encoder_out, encoder_out_lens
- )
+ encoder_out, encoder_out_lens = self.postencoder(encoder_out, encoder_out_lens)
return encoder_out, encoder_out_lens
def _extract_feats(
- self, speech: torch.Tensor, speech_lengths: torch.Tensor
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
assert speech_lengths.dim() == 1, speech_lengths.shape
@@ -267,4 +260,4 @@
else:
# No frontend and no feature extract
feats, feats_lengths = speech, speech_lengths
- return feats, feats_lengths
\ No newline at end of file
+ return feats, feats_lengths
--
Gitblit v1.9.1