From 1596f6f414f6f41da66506debb1dff19fffeb3ec Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Mon, 24 Jun 2024 11:55:17 +0800
Subject: [PATCH] fixbug hotwords

---
 funasr/models/sense_voice/whisper_lib/tokenizer.py |  105 +++++++++++++++++++++++++++++++++++++++++-----------
 1 file changed, 83 insertions(+), 22 deletions(-)

diff --git a/funasr/models/sense_voice/whisper_lib/tokenizer.py b/funasr/models/sense_voice/whisper_lib/tokenizer.py
index 2af8375..5b276c2 100644
--- a/funasr/models/sense_voice/whisper_lib/tokenizer.py
+++ b/funasr/models/sense_voice/whisper_lib/tokenizer.py
@@ -7,6 +7,7 @@
 
 import tiktoken
 
+# FIX(funasr): sense voice
 LANGUAGES = {
     "en": "english",
     "zh": "chinese",
@@ -108,6 +109,11 @@
     "jw": "javanese",
     "su": "sundanese",
     "yue": "cantonese",
+    "minnan": "minnan",
+    "wuyu": "wuyu",
+    "dialect": "dialect",
+    "zh/en": "zh/en",
+    "en/zh": "en/zh",
 }
 
 # language code lookup by name, with a few language aliases
@@ -125,6 +131,28 @@
     "sinhalese": "si",
     "castilian": "es",
     "mandarin": "zh",
+}
+
+# FIX(funasr): sense voice
+AUDIO_EVENT = {
+    "ASR": "ASR",
+    "AED": "AED",
+    "SER": "SER",
+    "Speech": "Speech",
+    "/Speech": "/Speech",
+    "BGM": "BGM",
+    "/BGM": "/BGM",
+    "Laughter": "Laughter",
+    "/Laughter": "/Laughter",
+    "Applause": "Applause",
+    "/Applause": "/Applause",
+}
+
+EMOTION = {
+    "HAPPY": "HAPPY",
+    "SAD": "SAD",
+    "ANGRY": "ANGRY",
+    "NEUTRAL": "NEUTRAL",
 }
 
 
@@ -151,7 +179,12 @@
         langs = tuple(LANGUAGES.keys())[: self.num_languages]
         sot_sequence = [sot]
         if self.language is not None:
-            sot_sequence.append(sot + 1 + langs.index(self.language))
+            if self.language == "nospeech":
+                sot_sequence.append(self.no_speech)
+            else:
+                sot_sequence.append(sot + 1 + langs.index(self.language))
+        # if self.language is not None:
+        #     sot_sequence.append(sot + 1 + langs.index(self.language))
         if self.task is not None:
             task_token: int = transcribe if self.task == "transcribe" else translate
             sot_sequence.append(task_token)
@@ -172,6 +205,9 @@
         """
         return self.encoding.decode(token_ids, **kwargs)
 
+    def get_vocab_size(self) -> int:
+        return self.encoding.n_vocab
+
     @cached_property
     def eot(self) -> int:
         return self.encoding.eot_token
@@ -186,6 +222,10 @@
 
     @cached_property
     def sot(self) -> int:
+        return self.special_tokens["<|startoftranscript|>"]
+
+    @cached_property
+    def sot_sense(self) -> int:
         return self.special_tokens["<|startoftranscript|>"]
 
     @cached_property
@@ -251,9 +291,7 @@
         keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
         """
         symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
-        symbols += (
-            "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
-        )
+        symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
 
         # symbols that may be a single token or multiple tokens depending on the tokenizer.
         # In case they're multiple tokens, suppress the first token, which is safe because:
@@ -328,8 +366,10 @@
 
 
 @lru_cache(maxsize=None)
-def get_encoding(name: str = "gpt2", num_languages: int = 99):
-    vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
+def get_encoding(name: str = "gpt2", num_languages: int = 99, vocab_path: str = None):
+    if vocab_path is None:
+        vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
+
     ranks = {
         base64.b64decode(token): int(rank)
         for token, rank in (line.split() for line in open(vocab_path) if line)
@@ -337,18 +377,35 @@
     n_vocab = len(ranks)
     special_tokens = {}
 
-    specials = [
-        "<|endoftext|>",
-        "<|startoftranscript|>",
-        *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
-        "<|translate|>",
-        "<|transcribe|>",
-        "<|startoflm|>",
-        "<|startofprev|>",
-        "<|nospeech|>",
-        "<|notimestamps|>",
-        *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
-    ]
+    if False:  # name == "gpt2" or name == "multilingual":
+        specials = [
+            "<|endoftext|>",
+            "<|startoftranscript|>",
+            *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
+            "<|translate|>",
+            "<|transcribe|>",
+            "<|startoflm|>",
+            "<|startofprev|>",
+            "<|nospeech|>",
+            "<|notimestamps|>",
+            *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
+        ]
+    else:
+        specials = [
+            "<|endoftext|>",
+            "<|startoftranscript|>",
+            *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
+            *[f"<|{audio_event}|>" for audio_event in list(AUDIO_EVENT.keys())],
+            *[f"<|{emotion}|>" for emotion in list(EMOTION.keys())],
+            "<|translate|>",
+            "<|transcribe|>",
+            "<|startoflm|>",
+            "<|startofprev|>",
+            "<|nospeech|>",
+            "<|notimestamps|>",
+            *[f"<|SPECIAL_TOKEN_{i}|>" for i in range(1, 51)],
+            *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
+        ]
 
     for token in specials:
         special_tokens[token] = n_vocab
@@ -370,12 +427,16 @@
     num_languages: int = 99,
     language: Optional[str] = None,
     task: Optional[str] = None,  # Literal["transcribe", "translate", None]
+    encoding_path: Optional[str] = None,
+    vocab_path: Optional[str] = None,
 ) -> Tokenizer:
     if language is not None:
         language = language.lower()
         if language not in LANGUAGES:
             if language in TO_LANGUAGE_CODE:
                 language = TO_LANGUAGE_CODE[language]
+            elif language == "nospeech":
+                pass
             else:
                 raise ValueError(f"Unsupported language: {language}")
 
@@ -387,9 +448,9 @@
         encoding_name = "gpt2"
         language = None
         task = None
+    if encoding_path is not None:
+        encoding_name = encoding_path
 
-    encoding = get_encoding(name=encoding_name, num_languages=num_languages)
+    encoding = get_encoding(name=encoding_name, num_languages=num_languages, vocab_path=vocab_path)
 
-    return Tokenizer(
-        encoding=encoding, num_languages=num_languages, language=language, task=task
-    )
+    return Tokenizer(encoding=encoding, num_languages=num_languages, language=language, task=task)

--
Gitblit v1.9.1