From 1596f6f414f6f41da66506debb1dff19fffeb3ec Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Mon, 24 Jun 2024 11:55:17 +0800
Subject: [PATCH] fixbug hotwords
---
funasr/models/sense_voice/whisper_lib/tokenizer.py | 105 +++++++++++++++++++++++++++++++++++++++++-----------
1 file changed, 83 insertions(+), 22 deletions(-)
diff --git a/funasr/models/sense_voice/whisper_lib/tokenizer.py b/funasr/models/sense_voice/whisper_lib/tokenizer.py
index 2af8375..5b276c2 100644
--- a/funasr/models/sense_voice/whisper_lib/tokenizer.py
+++ b/funasr/models/sense_voice/whisper_lib/tokenizer.py
@@ -7,6 +7,7 @@
import tiktoken
+# FIX(funasr): sense voice
LANGUAGES = {
"en": "english",
"zh": "chinese",
@@ -108,6 +109,11 @@
"jw": "javanese",
"su": "sundanese",
"yue": "cantonese",
+ "minnan": "minnan",
+ "wuyu": "wuyu",
+ "dialect": "dialect",
+ "zh/en": "zh/en",
+ "en/zh": "en/zh",
}
# language code lookup by name, with a few language aliases
@@ -125,6 +131,28 @@
"sinhalese": "si",
"castilian": "es",
"mandarin": "zh",
+}
+
+# FIX(funasr): sense voice
+AUDIO_EVENT = {
+ "ASR": "ASR",
+ "AED": "AED",
+ "SER": "SER",
+ "Speech": "Speech",
+ "/Speech": "/Speech",
+ "BGM": "BGM",
+ "/BGM": "/BGM",
+ "Laughter": "Laughter",
+ "/Laughter": "/Laughter",
+ "Applause": "Applause",
+ "/Applause": "/Applause",
+}
+
+EMOTION = {
+ "HAPPY": "HAPPY",
+ "SAD": "SAD",
+ "ANGRY": "ANGRY",
+ "NEUTRAL": "NEUTRAL",
}
@@ -151,7 +179,12 @@
langs = tuple(LANGUAGES.keys())[: self.num_languages]
sot_sequence = [sot]
if self.language is not None:
- sot_sequence.append(sot + 1 + langs.index(self.language))
+ if self.language == "nospeech":
+ sot_sequence.append(self.no_speech)
+ else:
+ sot_sequence.append(sot + 1 + langs.index(self.language))
+ # if self.language is not None:
+ # sot_sequence.append(sot + 1 + langs.index(self.language))
if self.task is not None:
task_token: int = transcribe if self.task == "transcribe" else translate
sot_sequence.append(task_token)
@@ -172,6 +205,9 @@
"""
return self.encoding.decode(token_ids, **kwargs)
+ def get_vocab_size(self) -> int:
+ return self.encoding.n_vocab
+
@cached_property
def eot(self) -> int:
return self.encoding.eot_token
@@ -186,6 +222,10 @@
@cached_property
def sot(self) -> int:
+ return self.special_tokens["<|startoftranscript|>"]
+
+ @cached_property
+ def sot_sense(self) -> int:
return self.special_tokens["<|startoftranscript|>"]
@cached_property
@@ -251,9 +291,7 @@
keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
"""
symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
- symbols += (
- "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
- )
+ symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
# symbols that may be a single token or multiple tokens depending on the tokenizer.
# In case they're multiple tokens, suppress the first token, which is safe because:
@@ -328,8 +366,10 @@
@lru_cache(maxsize=None)
-def get_encoding(name: str = "gpt2", num_languages: int = 99):
- vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
+def get_encoding(name: str = "gpt2", num_languages: int = 99, vocab_path: str = None):
+ if vocab_path is None:
+ vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
+
ranks = {
base64.b64decode(token): int(rank)
for token, rank in (line.split() for line in open(vocab_path) if line)
@@ -337,18 +377,35 @@
n_vocab = len(ranks)
special_tokens = {}
- specials = [
- "<|endoftext|>",
- "<|startoftranscript|>",
- *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
- "<|translate|>",
- "<|transcribe|>",
- "<|startoflm|>",
- "<|startofprev|>",
- "<|nospeech|>",
- "<|notimestamps|>",
- *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
- ]
+ if False: # name == "gpt2" or name == "multilingual":
+ specials = [
+ "<|endoftext|>",
+ "<|startoftranscript|>",
+ *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
+ "<|translate|>",
+ "<|transcribe|>",
+ "<|startoflm|>",
+ "<|startofprev|>",
+ "<|nospeech|>",
+ "<|notimestamps|>",
+ *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
+ ]
+ else:
+ specials = [
+ "<|endoftext|>",
+ "<|startoftranscript|>",
+ *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
+ *[f"<|{audio_event}|>" for audio_event in list(AUDIO_EVENT.keys())],
+ *[f"<|{emotion}|>" for emotion in list(EMOTION.keys())],
+ "<|translate|>",
+ "<|transcribe|>",
+ "<|startoflm|>",
+ "<|startofprev|>",
+ "<|nospeech|>",
+ "<|notimestamps|>",
+ *[f"<|SPECIAL_TOKEN_{i}|>" for i in range(1, 51)],
+ *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
+ ]
for token in specials:
special_tokens[token] = n_vocab
@@ -370,12 +427,16 @@
num_languages: int = 99,
language: Optional[str] = None,
task: Optional[str] = None, # Literal["transcribe", "translate", None]
+ encoding_path: Optional[str] = None,
+ vocab_path: Optional[str] = None,
) -> Tokenizer:
if language is not None:
language = language.lower()
if language not in LANGUAGES:
if language in TO_LANGUAGE_CODE:
language = TO_LANGUAGE_CODE[language]
+ elif language == "nospeech":
+ pass
else:
raise ValueError(f"Unsupported language: {language}")
@@ -387,9 +448,9 @@
encoding_name = "gpt2"
language = None
task = None
+ if encoding_path is not None:
+ encoding_name = encoding_path
- encoding = get_encoding(name=encoding_name, num_languages=num_languages)
+ encoding = get_encoding(name=encoding_name, num_languages=num_languages, vocab_path=vocab_path)
- return Tokenizer(
- encoding=encoding, num_languages=num_languages, language=language, task=task
- )
+ return Tokenizer(encoding=encoding, num_languages=num_languages, language=language, task=task)
--
Gitblit v1.9.1