| funasr/models/sense_voice/whisper_lib/decoding.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
| funasr/models/sense_voice/whisper_lib/tokenizer.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 |
funasr/models/sense_voice/whisper_lib/decoding.py
@@ -58,18 +58,20 @@
     # x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1]
     if x is None:
         x = torch.tensor([tokenizer.encode(initial_prompt, allowed_special="all")] * n_audio).to(mel.device) # [n_audio, 1]
-    logits = model.logits(x, mel)[:, 0]
+    logits = model.logits(x[:,:-1], mel)[:, -1]
     # collect detected languages; suppress all non-language tokens
     mask = torch.ones(logits.shape[-1], dtype=torch.bool)
     mask[list(tokenizer.all_language_tokens)] = False
+    mask[tokenizer.no_speech] = False
     logits[:, mask] = -np.inf
     language_tokens = logits.argmax(dim=-1)
     language_token_probs = logits.softmax(dim=-1).cpu()
     language_probs = [
         {
             c: language_token_probs[i, j].item()
-            for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
+            for j, c in zip(list(tokenizer.all_language_tokens) + [tokenizer.no_speech], list(tokenizer.all_language_codes) + ["nospeech"])
         }
         for i in range(n_audio)
     ]
funasr/models/sense_voice/whisper_lib/tokenizer.py
@@ -179,7 +179,12 @@
         langs = tuple(LANGUAGES.keys())[: self.num_languages]
         sot_sequence = [sot]
         if self.language is not None:
-            sot_sequence.append(sot + 1 + langs.index(self.language))
+            if self.language == 'nospeech':
+                sot_sequence.append(self.no_speech)
+            else:
+                sot_sequence.append(sot + 1 + langs.index(self.language))
+        # if self.language is not None:
+        #     sot_sequence.append(sot + 1 + langs.index(self.language))
         if self.task is not None:
             task_token: int = transcribe if self.task == "transcribe" else translate
             sot_sequence.append(task_token)
@@ -432,6 +437,8 @@
         if language not in LANGUAGES:
             if language in TO_LANGUAGE_CODE:
                 language = TO_LANGUAGE_CODE[language]
+            elif language == 'nospeech':
+                pass
             else:
                 raise ValueError(f"Unsupported language: {language}")