From 112c8e6eb7c7b5dd150787848a5c1d6111688ed7 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 10 四月 2024 09:55:21 +0800
Subject: [PATCH] 修复无法预测nospeech标签的问题 (#1604)

---
 funasr/models/sense_voice/whisper_lib/tokenizer.py |    9 ++++++++-
 funasr/models/sense_voice/whisper_lib/decoding.py  |    8 +++++---
 2 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/funasr/models/sense_voice/whisper_lib/decoding.py b/funasr/models/sense_voice/whisper_lib/decoding.py
index caca114..b3fce7e 100644
--- a/funasr/models/sense_voice/whisper_lib/decoding.py
+++ b/funasr/models/sense_voice/whisper_lib/decoding.py
@@ -58,18 +58,20 @@
     # x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device)  # [n_audio, 1]
     if x is None:
         x = torch.tensor([tokenizer.encode(initial_prompt, allowed_special="all")] * n_audio).to(mel.device)  # [n_audio, 1]
-    logits = model.logits(x, mel)[:, 0]
-
+    logits = model.logits(x[:,:-1], mel)[:, -1]
     # collect detected languages; suppress all non-language tokens
     mask = torch.ones(logits.shape[-1], dtype=torch.bool)
     mask[list(tokenizer.all_language_tokens)] = False
+    mask[tokenizer.no_speech] = False
+    
     logits[:, mask] = -np.inf
     language_tokens = logits.argmax(dim=-1)
     language_token_probs = logits.softmax(dim=-1).cpu()
+
     language_probs = [
         {
             c: language_token_probs[i, j].item()
-            for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
+            for j, c in zip(list(tokenizer.all_language_tokens) + [tokenizer.no_speech], list(tokenizer.all_language_codes) + ["nospeech"])
         }
         for i in range(n_audio)
     ]
diff --git a/funasr/models/sense_voice/whisper_lib/tokenizer.py b/funasr/models/sense_voice/whisper_lib/tokenizer.py
index 463ce83..4334a13 100644
--- a/funasr/models/sense_voice/whisper_lib/tokenizer.py
+++ b/funasr/models/sense_voice/whisper_lib/tokenizer.py
@@ -179,7 +179,12 @@
         langs = tuple(LANGUAGES.keys())[: self.num_languages]
         sot_sequence = [sot]
         if self.language is not None:
-            sot_sequence.append(sot + 1 + langs.index(self.language))
+            if self.language == 'nospeech':
+                sot_sequence.append(self.no_speech)
+            else:
+                sot_sequence.append(sot + 1 + langs.index(self.language))
+        # if self.language is not None:
+        #     sot_sequence.append(sot + 1 + langs.index(self.language))
         if self.task is not None:
             task_token: int = transcribe if self.task == "transcribe" else translate
             sot_sequence.append(task_token)
@@ -432,6 +437,8 @@
         if language not in LANGUAGES:
             if language in TO_LANGUAGE_CODE:
                 language = TO_LANGUAGE_CODE[language]
+            elif language == 'nospeech':
+                pass
             else:
                 raise ValueError(f"Unsupported language: {language}")
 

--
Gitblit v1.9.1