From 1596f6f414f6f41da66506debb1dff19fffeb3ec Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 24 六月 2024 11:55:17 +0800
Subject: [PATCH] fixbug hotwords
---
funasr/models/sense_voice/whisper_lib/tokenizer.py | 25 +++++++++++++------------
1 files changed, 13 insertions(+), 12 deletions(-)
diff --git a/funasr/models/sense_voice/whisper_lib/tokenizer.py b/funasr/models/sense_voice/whisper_lib/tokenizer.py
index 463ce83..5b276c2 100644
--- a/funasr/models/sense_voice/whisper_lib/tokenizer.py
+++ b/funasr/models/sense_voice/whisper_lib/tokenizer.py
@@ -179,7 +179,12 @@
langs = tuple(LANGUAGES.keys())[: self.num_languages]
sot_sequence = [sot]
if self.language is not None:
- sot_sequence.append(sot + 1 + langs.index(self.language))
+ if self.language == "nospeech":
+ sot_sequence.append(self.no_speech)
+ else:
+ sot_sequence.append(sot + 1 + langs.index(self.language))
+ # if self.language is not None:
+ # sot_sequence.append(sot + 1 + langs.index(self.language))
if self.task is not None:
task_token: int = transcribe if self.task == "transcribe" else translate
sot_sequence.append(task_token)
@@ -199,7 +204,7 @@
This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
"""
return self.encoding.decode(token_ids, **kwargs)
-
+
def get_vocab_size(self) -> int:
return self.encoding.n_vocab
@@ -286,9 +291,7 @@
keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
"""
symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
- symbols += (
- "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
- )
+ symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
# symbols that may be a single token or multiple tokens depending on the tokenizer.
# In case they're multiple tokens, suppress the first token, which is safe because:
@@ -363,7 +366,7 @@
@lru_cache(maxsize=None)
-def get_encoding(name: str = "gpt2", num_languages: int = 99, vocab_path:str=None):
+def get_encoding(name: str = "gpt2", num_languages: int = 99, vocab_path: str = None):
if vocab_path is None:
vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
@@ -374,7 +377,7 @@
n_vocab = len(ranks)
special_tokens = {}
- if False: #name == "gpt2" or name == "multilingual":
+ if False: # name == "gpt2" or name == "multilingual":
specials = [
"<|endoftext|>",
"<|startoftranscript|>",
@@ -432,6 +435,8 @@
if language not in LANGUAGES:
if language in TO_LANGUAGE_CODE:
language = TO_LANGUAGE_CODE[language]
+ elif language == "nospeech":
+ pass
else:
raise ValueError(f"Unsupported language: {language}")
@@ -446,10 +451,6 @@
if encoding_path is not None:
encoding_name = encoding_path
-
encoding = get_encoding(name=encoding_name, num_languages=num_languages, vocab_path=vocab_path)
-
- return Tokenizer(
- encoding=encoding, num_languages=num_languages, language=language, task=task
- )
+ return Tokenizer(encoding=encoding, num_languages=num_languages, language=language, task=task)
--
Gitblit v1.9.1