From 3d9f094e9652d4b84894c6fd4eae39a4a753b0f0 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 16 五月 2023 23:48:00 +0800
Subject: [PATCH] train

---
 funasr/models/e2e_asr_paraformer.py |   14 ++++++++------
 1 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py
index cc6f54b..00e08b1 100644
--- a/funasr/models/e2e_asr_paraformer.py
+++ b/funasr/models/e2e_asr_paraformer.py
@@ -326,9 +326,8 @@
 
         encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
             encoder_out.device)
-        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(encoder_out, None,
-                                                                                       encoder_out_mask,
-                                                                                       ignore_id=self.ignore_id)
+        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(encoder_out, None, encoder_out_mask,
+                                                                                  ignore_id=self.ignore_id)
         return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
 
     def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
@@ -659,6 +658,10 @@
             self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Frontend + Encoder. Note that this method is used by asr_inference.py
+<<<<<<< HEAD
+=======
+
+>>>>>>> 4cd79db451786548d8a100f25c3b03da0eb30f4b
         Args:
                 speech: (Batch, Length, ...)
                 speech_lengths: (Batch, )
@@ -1097,9 +1100,8 @@
         if self.predictor_bias == 1:
             _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
             ys_pad_lens = ys_pad_lens + self.predictor_bias
-        pre_acoustic_embeds, pre_token_length, _, pre_peak_index, _ = self.predictor(encoder_out, ys_pad,
-                                                                                     encoder_out_mask,
-                                                                                     ignore_id=self.ignore_id)
+        pre_acoustic_embeds, pre_token_length, _, pre_peak_index, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
+                                                                                  ignore_id=self.ignore_id)
 
         # 0. sampler
         decoder_out_1st = None

--
Gitblit v1.9.1