游雁
2024-05-07 fb0da9f849a5d3bd473dcdbaf6197c6a5ff24a57
decoding key
3个文件已修改
13 ■■■■■ 已修改文件
funasr/models/sense_voice/decoder.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sense_voice/model.py 7 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sense_voice/whisper_lib/decoding.py 4 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sense_voice/decoder.py
@@ -472,7 +472,7 @@
        is_pad_mask = kwargs.get("is_pad_mask", False)
        is_pad_memory_mask = kwargs.get("is_pad_memory_mask", False)
        fsmn_cache = cache[layer]["fsmn_cache"] if len(cache) > 0 or cache is None else None
        fsmn_cache = cache[layer]["fsmn_cache"] if cache is not None and len(cache) > 0 else None
        # if fsmn_cache is not None:
        #     x = x[:, -1:]
        att_res, fsmn_cache = self.attn(self.attn_ln(x), mask=None, cache=fsmn_cache)
funasr/models/sense_voice/model.py
@@ -806,7 +806,6 @@
            if len(kwargs.get("data_type", [])) > 1:
                audio_sample_list, text_token_int_list = audio_sample_list
                text_token_int = text_token_int_list[0]
                text_token_int = tokenizer.encode(text_token_int)
            else:
                text_token_int = None
@@ -846,7 +845,7 @@
        )
        if text_token_int is not None:
            i = 1
            i = 0
            results = []
            ibest_writer = None
            if kwargs.get("output_dir") is not None:
@@ -855,7 +854,9 @@
                ibest_writer = self.writer[f"1best_recog"]
            # 1. Forward decoder
            ys_pad = torch.tensor(text_token_int, dtype=torch.int64).to(kwargs["device"])[None, :]
            ys_pad = torch.tensor(sos_int + text_token_int, dtype=torch.int64).to(kwargs["device"])[
                None, :
            ]
            ys_pad_lens = torch.tensor([len(text_token_int)], dtype=torch.int64).to(
                kwargs["device"]
            )[None, :]
funasr/models/sense_voice/whisper_lib/decoding.py
@@ -62,8 +62,10 @@
    else:
        x = x.to(mel.device)
    # FIX(funasr): sense vocie
    # logits = model.logits(x[:, :-1], mel)[:, -1]
    logits = model.logits(x[:, :], mel)[:, -1]
    logits = model.logits(x[:, :-1], mel)[:, -1]
    # collect detected languages; suppress all non-language tokens
    mask = torch.ones(logits.shape[-1], dtype=torch.bool)
    mask[list(tokenizer.all_language_tokens)] = False