From bbda5496ffae1d9ab052e8736a8c0b080ea017f5 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Thu, 21 Mar 2024 16:28:00 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR
---
 examples/industrial_data_pretraining/seaco_paraformer/finetune.sh |  2 +-
 funasr/models/contextual_paraformer/model.py                      |  6 ++----
 funasr/models/seaco_paraformer/model.py                           |  7 ++++---
 3 files changed, 7 insertions(+), 8 deletions(-)
diff --git a/examples/industrial_data_pretraining/seaco_paraformer/finetune.sh b/examples/industrial_data_pretraining/seaco_paraformer/finetune.sh
index cfdec77..5614f44 100644
--- a/examples/industrial_data_pretraining/seaco_paraformer/finetune.sh
+++ b/examples/industrial_data_pretraining/seaco_paraformer/finetune.sh
@@ -10,7 +10,7 @@
 ## option 1, download model automatically
 model_name_or_model_dir="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
-model_revision="v2.0.4"
+model_revision="v2.0.7"
 
 ## option 2, download model by git
 #local_path_root=${workspace}/modelscope_models
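
Note (not part of the patch): the pinned revision can also be selected from
Python. A minimal sketch using FunASR's AutoModel; the wav path and hotword
string are placeholders:

    from funasr import AutoModel

    # Load the SeACo-Paraformer checkpoint at the revision pinned above.
    model = AutoModel(
        model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
        model_revision="v2.0.7",
    )
    # SeACo-Paraformer takes space-separated hotwords at inference time.
    res = model.generate(input="example.wav", hotword="达摩院 魔搭")
    print(res)
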
diff --git a/funasr/models/contextual_paraformer/model.py b/funasr/models/contextual_paraformer/model.py
index 9968bf2..b9fd3c4 100644
--- a/funasr/models/contextual_paraformer/model.py
+++ b/funasr/models/contextual_paraformer/model.py
@@ -94,10 +94,8 @@
             text: (Batch, Length)
             text_lengths: (Batch,)
         """
-        if len(text_lengths.size()) > 1:
-            text_lengths = text_lengths[:, 0]
-        if len(speech_lengths.size()) > 1:
-            speech_lengths = speech_lengths[:, 0]
+        text_lengths = text_lengths.squeeze()
+        speech_lengths = speech_lengths.squeeze()
 
         batch_size = speech.shape[0]
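
Note: squeeze() and the removed [:, 0] branches agree when the lengths arrive
as (Batch, 1), but they differ for a batch of one, where squeeze() drops the
batch axis as well. A standalone sketch of the difference, assuming only torch:

    import torch

    lengths = torch.tensor([[7], [3]])   # (Batch, 1) with Batch = 2
    print(lengths[:, 0].shape)           # torch.Size([2])
    print(lengths.squeeze().shape)       # torch.Size([2]), same result

    single = torch.tensor([[7]])         # (Batch, 1) with Batch = 1
    print(single[:, 0].shape)            # torch.Size([1])
    print(single.squeeze().shape)        # torch.Size([]), 0-dim: batch axis gone
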
diff --git a/funasr/models/seaco_paraformer/model.py b/funasr/models/seaco_paraformer/model.py
index 27ff5d1..21b6aba 100644
--- a/funasr/models/seaco_paraformer/model.py
+++ b/funasr/models/seaco_paraformer/model.py
@@ -117,6 +117,8 @@
             text: (Batch, Length)
             text_lengths: (Batch,)
         """
+        text_lengths = text_lengths.squeeze()
+        speech_lengths = speech_lengths.squeeze()
         assert text_lengths.dim() == 1, text_lengths.shape
         # Check that batch_size is unified
         assert (
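
Note: with the squeeze() added above, a Batch = 1 input reaches this assert as
a 0-dim tensor and fails it. A defensive variant that collapses only a trailing
singleton axis (an illustrative alternative, not what the patch does):

    import torch

    def to_1d(lengths: torch.Tensor) -> torch.Tensor:
        # Drop a trailing (Batch, 1) axis but never the batch axis itself,
        # so dim() == 1 still holds when Batch = 1.
        return lengths.squeeze(-1) if lengths.dim() > 1 else lengths

    assert to_1d(torch.tensor([[7]])).dim() == 1       # Batch = 1 stays 1-D
    assert to_1d(torch.tensor([[7], [3]])).dim() == 1  # (Batch, 1) -> (Batch,)
    assert to_1d(torch.tensor([7, 3])).dim() == 1      # already 1-D, unchanged
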
@@ -164,7 +166,7 @@
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         if self.length_normalized_loss:
-            batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
+            batch_size = (text_lengths + self.predictor_bias).sum()
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight
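
Note: under length_normalized_loss the gather weight becomes the total
target-token count (text lengths plus the predictor bias) rather than the
sentence count; dropping the type_as cast leaves it an integer tensor, which
force_gatherable handles. A toy sketch of the weighted reduction such a weight
supports across DataParallel replicas (the numbers are made up):

    import torch

    losses  = torch.tensor([2.0, 1.0])  # per-replica length-normalized losses
    weights = torch.tensor([50, 10])    # per-replica token counts (the weight above)

    # Weighted average: replicas with more tokens count proportionally more.
    print((losses * weights).sum() / weights.sum())  # tensor(1.8333)
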
@@ -190,8 +192,7 @@
         # predictor forward
         encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
             encoder_out.device)
-        pre_acoustic_embeds, _, _, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
-                                                      ignore_id=self.ignore_id)
+        pre_acoustic_embeds = self.predictor(encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id)[0]
         # decoder forward
         decoder_out, _ = self.decoder(encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_lengths, return_hidden=True)
         selected = self._hotword_representation(hotword_pad,
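
Note: the predictor returns a tuple (for the CIF-style predictors, roughly
(acoustic_embeds, token_num, alphas, peaks); the names here are assumed), and
the rewrite keeps only its first element instead of unpacking all four. A
standalone sketch of the idiom with a stand-in predictor:

    import torch

    def predictor(x):
        # Stand-in for self.predictor: returns a 4-tuple, embeddings first.
        return x, x.new_ones(x.size(0)), x.mean(-1), None

    x = torch.randn(2, 5, 8)
    pre_acoustic_embeds = predictor(x)[0]  # keep only the first element
    print(pre_acoustic_embeds.shape)       # torch.Size([2, 5, 8])
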
--
Gitblit v1.9.1