From d19f48e17478be273584853568ac101c994c37e5 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Mon, 08 Apr 2024 18:51:53 +0800
Subject: [PATCH] Dev gzf exp (#1593)

---
 funasr/models/sense_voice/whisper_lib/decoding.py |   12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/funasr/models/sense_voice/whisper_lib/decoding.py b/funasr/models/sense_voice/whisper_lib/decoding.py
index 73b0262..caca114 100644
--- a/funasr/models/sense_voice/whisper_lib/decoding.py
+++ b/funasr/models/sense_voice/whisper_lib/decoding.py
@@ -119,6 +119,7 @@
 
     # FIX(funasr): sense vocie
     initial_prompt: str = None
+    vocab_path: str = None
 
 
 @dataclass(frozen=True)
@@ -527,6 +528,7 @@
             num_languages=model.num_languages,
             language=language,
             task=options.task,
+            vocab_path=options.vocab_path
         )
         self.tokenizer: Tokenizer = tokenizer
         self.options: DecodingOptions = self._verify_options(options)
@@ -616,10 +618,13 @@
                 + prompt_tokens[-(self.n_ctx // 2 - 1) :]
                 + tokens
             )
-        #FIX(gzf): sense vocie
+        # FIX(funasr): sense voice
         if initial_prompt := self.options.initial_prompt:
-            tokens = self.tokenizer.encode(initial_prompt, allowed_special="all")
-            if self.options.language is None:
+            if self.options.language is not None:
+                initial_prompt = f"{initial_prompt}<|{self.options.language}|>"
+                tokens = self.tokenizer.encode(initial_prompt, allowed_special="all")
+            else:
+                tokens = self.tokenizer.encode(initial_prompt, allowed_special="all")
                 tokens += [0]
 
 
@@ -691,6 +696,7 @@
             if self.options.language is None:
                 # tokens[:, self.sot_index + 1] = lang_tokens  # write language tokens
                 languages = "".join([f"<|{language}|>" for language in languages])
+
                 n_audio = audio_features.shape[0]
                 lang_tokens = torch.tensor([self.tokenizer.encode(languages, allowed_special="all")] * n_audio).to(
                     audio_features.device)  # [n_audio, 1]

--
Gitblit v1.9.1