zhifu gao
2024-04-10 112c8e6eb7c7b5dd150787848a5c1d6111688ed7
修复无法预测nospeech标签的问题 (#1604)

Co-authored-by: 常材 <gaochangfeng.gcf@alibaba-inc.com>
2个文件已修改
共 17 行改动
funasr/models/sense_voice/whisper_lib/decoding.py：8 行改动
funasr/models/sense_voice/whisper_lib/tokenizer.py：9 行改动
funasr/models/sense_voice/whisper_lib/decoding.py
@@ -58,18 +58,20 @@
    # x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device)  # [n_audio, 1]
    if x is None:
        x = torch.tensor([tokenizer.encode(initial_prompt, allowed_special="all")] * n_audio).to(mel.device)  # [n_audio, 1]
    logits = model.logits(x, mel)[:, 0]
    logits = model.logits(x[:,:-1], mel)[:, -1]
    # collect detected languages; suppress all non-language tokens
    mask = torch.ones(logits.shape[-1], dtype=torch.bool)
    mask[list(tokenizer.all_language_tokens)] = False
    mask[tokenizer.no_speech] = False
    logits[:, mask] = -np.inf
    language_tokens = logits.argmax(dim=-1)
    language_token_probs = logits.softmax(dim=-1).cpu()
    language_probs = [
        {
            c: language_token_probs[i, j].item()
            for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
            for j, c in zip(list(tokenizer.all_language_tokens) + [tokenizer.no_speech], list(tokenizer.all_language_codes) + ["nospeech"])
        }
        for i in range(n_audio)
    ]
funasr/models/sense_voice/whisper_lib/tokenizer.py
@@ -179,7 +179,12 @@
        langs = tuple(LANGUAGES.keys())[: self.num_languages]
        sot_sequence = [sot]
        if self.language is not None:
            sot_sequence.append(sot + 1 + langs.index(self.language))
            if self.language == 'nospeech':
                sot_sequence.append(self.no_speech)
            else:
                sot_sequence.append(sot + 1 + langs.index(self.language))
        # if self.language is not None:
        #     sot_sequence.append(sot + 1 + langs.index(self.language))
        if self.task is not None:
            task_token: int = transcribe if self.task == "transcribe" else translate
            sot_sequence.append(task_token)
@@ -432,6 +437,8 @@
        if language not in LANGUAGES:
            if language in TO_LANGUAGE_CODE:
                language = TO_LANGUAGE_CODE[language]
            elif language == 'nospeech':
                pass
            else:
                raise ValueError(f"Unsupported language: {language}")