From 32e783664534bbb8d3b8ba64c2c2ecb42398eb00 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 06 六月 2024 09:54:35 +0800
Subject: [PATCH] update with main (#1786)
---
funasr/models/sense_voice/model.py | 28 +++++++++++++++++++++-------
1 files changed, 21 insertions(+), 7 deletions(-)
diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py
index 127d5a0..22272ee 100644
--- a/funasr/models/sense_voice/model.py
+++ b/funasr/models/sense_voice/model.py
@@ -1264,15 +1264,29 @@
if isinstance(task, str):
task = [task]
task = "".join([f"<|{x}|>" for x in task])
- initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
+
+ sos = kwargs.get("model_conf").get("sos")
+ if isinstance(sos, str):
+ initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
- language = DecodingOptions.get("language", None)
- language = None if language == "auto" else language
+ language = DecodingOptions.get("language", None)
+ language = None if language == "auto" else language
- sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
- sos_int = tokenizer.encode(sos, allowed_special="all")
+ sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
+ sos_int = tokenizer.encode(sos, allowed_special="all")
+ else:
+ language = DecodingOptions.get("language", None)
+ language = None if language == "auto" else language
+ initial_prompt = kwargs.get("initial_prompt", f"{task}")
+ initial_prompt_lid = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
+ initial_prompt_lid_int = tokenizer.encode(initial_prompt_lid, allowed_special="all")
+ sos_int = [sos] + initial_prompt_lid_int
eos = kwargs.get("model_conf").get("eos")
- eos_int = tokenizer.encode(eos, allowed_special="all")
+ if isinstance(eos, str):
+ eos_int = tokenizer.encode(eos, allowed_special="all")
+ else:
+ eos_int = [eos]
+
self.beam_search.sos = sos_int
self.beam_search.eos = eos_int[0]
@@ -1298,7 +1312,7 @@
self.beam_search.event_score_ga = DecodingOptions.get("gain_tokens_score", [1, 1, 1, 1])
encoder_out, encoder_out_lens = self.encode(
- speech[None, :, :].permute(0, 2, 1), speech_lengths
+ speech[None, :, :], speech_lengths
)
if text_token_int is not None:
--
Gitblit v1.9.1