From e9d2cfc3a134b00f4e98271fbee3838d1ccecbcc Mon Sep 17 00:00:00 2001
From: VirtuosoQ <2416050435@qq.com>
Date: 星期五, 26 四月 2024 14:59:30 +0800
Subject: [PATCH] FunASR Java HTTP client

---
 funasr/models/sense_voice/whisper_lib/tokenizer.py |   20 ++++++++++++++++----
 1 file changed, 16 insertions(+), 4 deletions(-)

diff --git a/funasr/models/sense_voice/whisper_lib/tokenizer.py b/funasr/models/sense_voice/whisper_lib/tokenizer.py
index e941fb2..4334a13 100644
--- a/funasr/models/sense_voice/whisper_lib/tokenizer.py
+++ b/funasr/models/sense_voice/whisper_lib/tokenizer.py
@@ -179,7 +179,12 @@
         langs = tuple(LANGUAGES.keys())[: self.num_languages]
         sot_sequence = [sot]
         if self.language is not None:
-            sot_sequence.append(sot + 1 + langs.index(self.language))
+            if self.language == 'nospeech':
+                sot_sequence.append(self.no_speech)
+            else:
+                sot_sequence.append(sot + 1 + langs.index(self.language))
+        # if self.language is not None:
+        #     sot_sequence.append(sot + 1 + langs.index(self.language))
         if self.task is not None:
             task_token: int = transcribe if self.task == "transcribe" else translate
             sot_sequence.append(task_token)
@@ -363,8 +368,10 @@
 
 
 @lru_cache(maxsize=None)
-def get_encoding(name: str = "gpt2", num_languages: int = 99):
-    vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
+def get_encoding(name: str = "gpt2", num_languages: int = 99, vocab_path:str=None):
+    if vocab_path is None:
+        vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
+
     ranks = {
         base64.b64decode(token): int(rank)
         for token, rank in (line.split() for line in open(vocab_path) if line)
@@ -423,12 +430,15 @@
     language: Optional[str] = None,
     task: Optional[str] = None,  # Literal["transcribe", "translate", None]
     encoding_path: Optional[str] = None,
+    vocab_path: Optional[str] = None,
 ) -> Tokenizer:
     if language is not None:
         language = language.lower()
         if language not in LANGUAGES:
             if language in TO_LANGUAGE_CODE:
                 language = TO_LANGUAGE_CODE[language]
+            elif language == 'nospeech':
+                pass
             else:
                 raise ValueError(f"Unsupported language: {language}")
 
@@ -443,7 +453,9 @@
     if encoding_path is not None:
         encoding_name = encoding_path
 
-    encoding = get_encoding(name=encoding_name, num_languages=num_languages)
+
+    encoding = get_encoding(name=encoding_name, num_languages=num_languages, vocab_path=vocab_path)
+
 
     return Tokenizer(
         encoding=encoding, num_languages=num_languages, language=language, task=task

--
Gitblit v1.9.1