From 702b9b540c3c1524748cd975a10ce33f0fa53912 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期六, 30 三月 2024 11:54:51 +0800
Subject: [PATCH] sense voice (#1568)

---
 funasr/models/sense_voice/whisper_lib/decoding.py |   33 ++++++++++++++++++++++++++++-----
 1 files changed, 28 insertions(+), 5 deletions(-)

diff --git a/funasr/models/sense_voice/whisper_lib/decoding.py b/funasr/models/sense_voice/whisper_lib/decoding.py
index 49485d0..73b0262 100644
--- a/funasr/models/sense_voice/whisper_lib/decoding.py
+++ b/funasr/models/sense_voice/whisper_lib/decoding.py
@@ -17,7 +17,7 @@
 
 @torch.no_grad()
 def detect_language(
-    model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None
+    model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None, initial_prompt = None, x = None,
 ) -> Tuple[Tensor, List[dict]]:
     """
     Detect the spoken language in the audio, and return them as list of strings, along with the ids
@@ -48,12 +48,16 @@
         mel = mel.unsqueeze(0)
 
     # skip encoder forward pass if already-encoded audio features were given
-    if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
+    # FIX(funasr): sense vocie
+    if mel.shape[-1] != model.dims.n_audio_state:
         mel = model.encoder(mel)
 
     # forward pass using a single token, startoftranscript
     n_audio = mel.shape[0]
-    x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device)  # [n_audio, 1]
+    # FIX(funasr): sense vocie
+    # x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device)  # [n_audio, 1]
+    if x is None:
+        x = torch.tensor([tokenizer.encode(initial_prompt, allowed_special="all")] * n_audio).to(mel.device)  # [n_audio, 1]
     logits = model.logits(x, mel)[:, 0]
 
     # collect detected languages; suppress all non-language tokens
@@ -112,6 +116,9 @@
 
     # implementation details
     fp16: bool = True  # use fp16 for most of the calculation
+
+    # FIX(funasr): sense vocie
+    initial_prompt: str = None
 
 
 @dataclass(frozen=True)
@@ -609,6 +616,12 @@
                 + prompt_tokens[-(self.n_ctx // 2 - 1) :]
                 + tokens
             )
+        #FIX(gzf): sense vocie
+        if initial_prompt := self.options.initial_prompt:
+            tokens = self.tokenizer.encode(initial_prompt, allowed_special="all")
+            if self.options.language is None:
+                tokens += [0]
+
 
         return tuple(tokens)
 
@@ -669,11 +682,21 @@
 
         if self.options.language is None or self.options.task == "lang_id":
             lang_tokens, lang_probs = self.model.detect_language(
-                audio_features, self.tokenizer
+                audio_features, self.tokenizer, x=tokens
             )
             languages = [max(probs, key=probs.get) for probs in lang_probs]
+            # FIX(funasr): sense vocie
+            # if self.options.language is None:
+                # tokens[:, self.sot_index + 1] = lang_tokens  # write language tokens
             if self.options.language is None:
-                tokens[:, self.sot_index + 1] = lang_tokens  # write language tokens
+                # tokens[:, self.sot_index + 1] = lang_tokens  # write language tokens
+                languages = "".join([f"<|{language}|>" for language in languages])
+                n_audio = audio_features.shape[0]
+                lang_tokens = torch.tensor([self.tokenizer.encode(languages, allowed_special="all")] * n_audio).to(
+                    audio_features.device)  # [n_audio, 1]
+                
+                tokens[:, -1:] = lang_tokens[:, :]
+                languages = [languages]
 
         return languages, lang_probs
 

--
Gitblit v1.9.1