游雁
2024-05-07 78ff06a45cafdb7c093613cf7ed5c4a4cc26eda9
decoding key
1个文件已修改
37 ■■■■■ 已修改文件
funasr/models/sense_voice/model.py 37 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sense_voice/model.py
@@ -802,6 +802,14 @@
                data_type=kwargs.get("data_type", "sound"),
                tokenizer=tokenizer,
            )
            if len(kwargs.get("data_type", [])) > 1:
                audio_sample_list, text_token_int_list = audio_sample_list
                text_token_int = text_token_int_list[0]
                text_token_int = tokenizer.encode(text_token_int)
            else:
                text_token_int = None
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            speech, speech_lengths = extract_fbank(
@@ -837,6 +845,35 @@
            speech[None, :, :].permute(0, 2, 1), speech_lengths
        )
        if text_token_int is not None:
            i = 1
            results = []
            ibest_writer = None
            if kwargs.get("output_dir") is not None:
                if not hasattr(self, "writer"):
                    self.writer = DatadirWriter(kwargs.get("output_dir"))
                ibest_writer = self.writer[f"1best_recog"]
            # 1. Forward decoder
            ys_pad = torch.tensor(text_token_int, dtype=torch.int64).to(kwargs["device"])[None, :]
            ys_pad_lens = torch.tensor([len(text_token_int)], dtype=torch.int64).to(
                kwargs["device"]
            )[None, :]
            decoder_out = self.model.decoder(
                x=ys_pad, xa=encoder_out, hlens=encoder_out_lens, ys_in_lens=ys_pad_lens
            )
            token_int = decoder_out.argmax(-1)[0, :].tolist()
            text = tokenizer.decode(token_int)
            result_i = {"key": key[i], "text": text}
            results.append(result_i)
            if ibest_writer is not None:
                # ibest_writer["token"][key[i]] = " ".join(token)
                ibest_writer["text"][key[i]] = text
            return results, meta_data
        # c. Passed the encoder result and the beam search
        nbest_hyps = self.beam_search(
            x=encoder_out[0],