From bbda5496ffae1d9ab052e8736a8c0b080ea017f5 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 21 三月 2024 16:28:00 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR merge

---
 funasr/models/seaco_paraformer/model.py                           |    7 ++++---
 funasr/models/contextual_paraformer/model.py                      |    6 ++----
 examples/industrial_data_pretraining/seaco_paraformer/finetune.sh |    2 +-
 3 files changed, 7 insertions(+), 8 deletions(-)

diff --git a/examples/industrial_data_pretraining/seaco_paraformer/finetune.sh b/examples/industrial_data_pretraining/seaco_paraformer/finetune.sh
index cfdec77..5614f44 100644
--- a/examples/industrial_data_pretraining/seaco_paraformer/finetune.sh
+++ b/examples/industrial_data_pretraining/seaco_paraformer/finetune.sh
@@ -10,7 +10,7 @@
 
 ## option 1, download model automatically
 model_name_or_model_dir="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
-model_revision="v2.0.4"
+model_revision="v2.0.7"
 
 ## option 2, download model by git
 #local_path_root=${workspace}/modelscope_models
diff --git a/funasr/models/contextual_paraformer/model.py b/funasr/models/contextual_paraformer/model.py
index 9968bf2..b9fd3c4 100644
--- a/funasr/models/contextual_paraformer/model.py
+++ b/funasr/models/contextual_paraformer/model.py
@@ -94,10 +94,8 @@
                 text: (Batch, Length)
                 text_lengths: (Batch,)
         """
-        if len(text_lengths.size()) > 1:
-            text_lengths = text_lengths[:, 0]
-        if len(speech_lengths.size()) > 1:
-            speech_lengths = speech_lengths[:, 0]
+        text_lengths = text_lengths.squeeze()
+        speech_lengths = speech_lengths.squeeze()
 
         batch_size = speech.shape[0]
 
diff --git a/funasr/models/seaco_paraformer/model.py b/funasr/models/seaco_paraformer/model.py
index 27ff5d1..21b6aba 100644
--- a/funasr/models/seaco_paraformer/model.py
+++ b/funasr/models/seaco_paraformer/model.py
@@ -117,6 +117,8 @@
                 text: (Batch, Length)
                 text_lengths: (Batch,)
         """
+        text_lengths = text_lengths.squeeze()
+        speech_lengths = speech_lengths.squeeze()
         assert text_lengths.dim() == 1, text_lengths.shape
         # Check that batch_size is unified
         assert (
@@ -164,7 +166,7 @@
 
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         if self.length_normalized_loss:
-            batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
+            batch_size = (text_lengths + self.predictor_bias).sum()
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight
 
@@ -190,8 +192,7 @@
         # predictor forward
         encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
             encoder_out.device)
-        pre_acoustic_embeds, _, _, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
-                                                                                  ignore_id=self.ignore_id)
+        pre_acoustic_embeds = self.predictor(encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id)[0]
         # decoder forward
         decoder_out, _ = self.decoder(encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_lengths, return_hidden=True)
         selected = self._hotword_representation(hotword_pad, 

--
Gitblit v1.9.1