From 3d9f094e9652d4b84894c6fd4eae39a4a753b0f0 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 16 五月 2023 23:48:00 +0800
Subject: [PATCH] train

---
 funasr/models/e2e_asr_paraformer.py |   34 ++++++++++++++++++----------------
 1 files changed, 18 insertions(+), 16 deletions(-)

diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py
index 9d4f106..00e08b1 100644
--- a/funasr/models/e2e_asr_paraformer.py
+++ b/funasr/models/e2e_asr_paraformer.py
@@ -55,9 +55,7 @@
             frontend: Optional[AbsFrontend],
             specaug: Optional[AbsSpecAug],
             normalize: Optional[AbsNormalize],
-            preencoder: Optional[AbsPreEncoder],
             encoder: AbsEncoder,
-            postencoder: Optional[AbsPostEncoder],
             decoder: AbsDecoder,
             ctc: CTC,
             ctc_weight: float = 0.5,
@@ -78,6 +76,8 @@
             predictor_bias: int = 0,
             sampling_ratio: float = 0.2,
             share_embedding: bool = False,
+            preencoder: Optional[AbsPreEncoder] = None,
+            postencoder: Optional[AbsPostEncoder] = None,
     ):
         assert check_argument_types()
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight
@@ -326,9 +326,8 @@
 
         encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
             encoder_out.device)
-        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(encoder_out, None,
-                                                                                       encoder_out_mask,
-                                                                                       ignore_id=self.ignore_id)
+        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(encoder_out, None, encoder_out_mask,
+                                                                                  ignore_id=self.ignore_id)
         return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
 
     def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
@@ -659,6 +658,10 @@
             self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Frontend + Encoder. Note that this method is used by asr_inference.py
+<<<<<<< HEAD
+=======
+
+>>>>>>> 4cd79db451786548d8a100f25c3b03da0eb30f4b
         Args:
                 speech: (Batch, Length, ...)
                 speech_lengths: (Batch, )
@@ -732,9 +735,7 @@
             frontend: Optional[AbsFrontend],
             specaug: Optional[AbsSpecAug],
             normalize: Optional[AbsNormalize],
-            preencoder: Optional[AbsPreEncoder],
             encoder: AbsEncoder,
-            postencoder: Optional[AbsPostEncoder],
             decoder: AbsDecoder,
             ctc: CTC,
             ctc_weight: float = 0.5,
@@ -757,6 +758,8 @@
             embeds_id: int = 2,
             embeds_loss_weight: float = 0.0,
             embed_dims: int = 768,
+            preencoder: Optional[AbsPreEncoder] = None,
+            postencoder: Optional[AbsPostEncoder] = None,
     ):
         assert check_argument_types()
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight
@@ -906,9 +909,9 @@
         self.step_cur += 1
         # for data-parallel
         text = text[:, : text_lengths.max()]
-        speech = speech[:, :speech_lengths.max(), :]
+        speech = speech[:, :speech_lengths.max()]
         if embed is not None:
-            embed = embed[:, :embed_lengths.max(), :]
+            embed = embed[:, :embed_lengths.max()]
 
         # 1. Encoder
         encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
@@ -1008,9 +1011,7 @@
             frontend: Optional[AbsFrontend],
             specaug: Optional[AbsSpecAug],
             normalize: Optional[AbsNormalize],
-            preencoder: Optional[AbsPreEncoder],
             encoder: AbsEncoder,
-            postencoder: Optional[AbsPostEncoder],
             decoder: AbsDecoder,
             ctc: CTC,
             ctc_weight: float = 0.5,
@@ -1030,6 +1031,8 @@
             predictor_weight: float = 0.0,
             predictor_bias: int = 0,
             sampling_ratio: float = 0.2,
+            preencoder: Optional[AbsPreEncoder] = None,
+            postencoder: Optional[AbsPostEncoder] = None,
     ):
         assert check_argument_types()
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight
@@ -1097,9 +1100,8 @@
         if self.predictor_bias == 1:
             _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
             ys_pad_lens = ys_pad_lens + self.predictor_bias
-        pre_acoustic_embeds, pre_token_length, _, pre_peak_index, _ = self.predictor(encoder_out, ys_pad,
-                                                                                     encoder_out_mask,
-                                                                                     ignore_id=self.ignore_id)
+        pre_acoustic_embeds, pre_token_length, _, pre_peak_index, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
+                                                                                  ignore_id=self.ignore_id)
 
         # 0. sampler
         decoder_out_1st = None
@@ -1277,9 +1279,7 @@
             frontend: Optional[AbsFrontend],
             specaug: Optional[AbsSpecAug],
             normalize: Optional[AbsNormalize],
-            preencoder: Optional[AbsPreEncoder],
             encoder: AbsEncoder,
-            postencoder: Optional[AbsPostEncoder],
             decoder: AbsDecoder,
             ctc: CTC,
             ctc_weight: float = 0.5,
@@ -1309,6 +1309,8 @@
             bias_encoder_type: str = 'lstm',
             label_bracket: bool = False,
             use_decoder_embedding: bool = False,
+            preencoder: Optional[AbsPreEncoder] = None,
+            postencoder: Optional[AbsPostEncoder] = None,
     ):
         assert check_argument_types()
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight

--
Gitblit v1.9.1