From d929c8e0f7bf07e4ae5008fb9409a78fd4e551c7 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 22 三月 2024 19:37:41 +0800
Subject: [PATCH] update
---
funasr/models/seaco_paraformer/model.py | 18 ++++++++++++++----
1 files changed, 14 insertions(+), 4 deletions(-)
diff --git a/funasr/models/seaco_paraformer/model.py b/funasr/models/seaco_paraformer/model.py
index 92fc989..21b6aba 100644
--- a/funasr/models/seaco_paraformer/model.py
+++ b/funasr/models/seaco_paraformer/model.py
@@ -117,6 +117,8 @@
text: (Batch, Length)
text_lengths: (Batch,)
"""
+ text_lengths = text_lengths.squeeze()
+ speech_lengths = speech_lengths.squeeze()
assert text_lengths.dim() == 1, text_lengths.shape
# Check that batch_size is unified
assert (
@@ -164,7 +166,7 @@
# force_gatherable: to-device and to-tensor if scalar for DataParallel
if self.length_normalized_loss:
- batch_size = (text_lengths + self.predictor_bias).sum().type_as(batch_size)
+ batch_size = (text_lengths + self.predictor_bias).sum()
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
@@ -190,8 +192,7 @@
# predictor forward
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
encoder_out.device)
- pre_acoustic_embeds, _, _, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
- ignore_id=self.ignore_id)
+ pre_acoustic_embeds = self.predictor(encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id)[0]
# decoder forward
decoder_out, _ = self.decoder(encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_lengths, return_hidden=True)
selected = self._hotword_representation(hotword_pad,
@@ -430,7 +431,6 @@
return results, meta_data
-
def generate_hotwords_list(self, hotword_list_or_file, tokenizer=None, frontend=None):
def load_seg_dict(seg_dict_file):
seg_dict = {}
@@ -532,3 +532,13 @@
hotword_list = None
return hotword_list
+ def export(
+ self,
+ **kwargs,
+ ):
+ if 'max_seq_len' not in kwargs:
+ kwargs['max_seq_len'] = 512
+ from .export_meta import export_rebuild_model
+ models = export_rebuild_model(model=self, **kwargs)
+ return models
+
--
Gitblit v1.9.1