From 851e3e3ef83d0769d9bde172d8841f6b20e3e377 Mon Sep 17 00:00:00 2001
From: gaochangfeng <54253717+gaochangfeng@users.noreply.github.com>
Date: 星期三, 10 四月 2024 14:37:35 +0800
Subject: [PATCH] Gcf (#1605)
---
funasr/models/sense_voice/whisper_lib/decoding.py | 24 +++++++++++++++++++-----
1 files changed, 19 insertions(+), 5 deletions(-)
diff --git a/funasr/models/sense_voice/whisper_lib/decoding.py b/funasr/models/sense_voice/whisper_lib/decoding.py
index 73b0262..2239b64 100644
--- a/funasr/models/sense_voice/whisper_lib/decoding.py
+++ b/funasr/models/sense_voice/whisper_lib/decoding.py
@@ -10,6 +10,8 @@
from .audio import CHUNK_LENGTH
from .tokenizer import Tokenizer, get_tokenizer
from .utils import compression_ratio
+from funasr.models.transformer.utils.nets_utils import to_device
+
if TYPE_CHECKING:
from .model import Whisper
@@ -58,18 +60,24 @@
# x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1]
if x is None:
x = torch.tensor([tokenizer.encode(initial_prompt, allowed_special="all")] * n_audio).to(mel.device) # [n_audio, 1]
- logits = model.logits(x, mel)[:, 0]
+ else:
+ x = x.to(mel.device)
+
+ logits = model.logits(x[:,:-1], mel)[:, -1]
# collect detected languages; suppress all non-language tokens
mask = torch.ones(logits.shape[-1], dtype=torch.bool)
mask[list(tokenizer.all_language_tokens)] = False
+ mask[tokenizer.no_speech] = False
+
logits[:, mask] = -np.inf
language_tokens = logits.argmax(dim=-1)
language_token_probs = logits.softmax(dim=-1).cpu()
+
language_probs = [
{
c: language_token_probs[i, j].item()
- for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
+ for j, c in zip(list(tokenizer.all_language_tokens) + [tokenizer.no_speech], list(tokenizer.all_language_codes) + ["nospeech"])
}
for i in range(n_audio)
]
@@ -119,6 +127,7 @@
# FIX(funasr): sense vocie
initial_prompt: str = None
+ vocab_path: str = None
@dataclass(frozen=True)
@@ -527,6 +536,7 @@
num_languages=model.num_languages,
language=language,
task=options.task,
+ vocab_path=options.vocab_path
)
self.tokenizer: Tokenizer = tokenizer
self.options: DecodingOptions = self._verify_options(options)
@@ -616,10 +626,13 @@
+ prompt_tokens[-(self.n_ctx // 2 - 1) :]
+ tokens
)
- #FIX(gzf): sense vocie
+ #FIX(funasr): sense vocie
if initial_prompt := self.options.initial_prompt:
- tokens = self.tokenizer.encode(initial_prompt, allowed_special="all")
- if self.options.language is None:
+ if self.options.language is not None:
+ initial_prompt = f"{initial_prompt}<|{self.options.language}|>"
+ tokens = self.tokenizer.encode(initial_prompt, allowed_special="all")
+ else:
+ tokens = self.tokenizer.encode(initial_prompt, allowed_special="all")
tokens += [0]
@@ -691,6 +704,7 @@
if self.options.language is None:
# tokens[:, self.sot_index + 1] = lang_tokens # write language tokens
languages = "".join([f"<|{language}|>" for language in languages])
+
n_audio = audio_features.shape[0]
lang_tokens = torch.tensor([self.tokenizer.encode(languages, allowed_special="all")] * n_audio).to(
audio_features.device) # [n_audio, 1]
--
Gitblit v1.9.1