hnluo
2023-04-27 c5992ca03ea6d6c7b78e5c1d481a612d0f91ac21
Update asr_inference_paraformer_streaming.py
1个文件已修改
8 ■■■■ 已修改文件
funasr/bin/asr_inference_paraformer_streaming.py 8 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference_paraformer_streaming.py
@@ -203,7 +203,7 @@
         results = []
         cache_en = cache["encoder"]
         if speech.shape[1] < 16 * 60 and cache_en["is_final"]:
-            cache_en["last_chunk"] = True
+            cache_en["tail_chunk"] = True
             feats = cache_en["feats"]
             feats_len = torch.tensor([feats.shape[1]])
         else:
@@ -232,7 +232,7 @@
                         feats_len = torch.tensor([feats_chunk2.shape[1]])
                         results_chunk2 = self.infer(feats_chunk2, feats_len, cache)
-                        return results_chunk1 + results_chunk2
+                        return ["".join(results_chunk1 + results_chunk2)]
                 results = self.infer(feats, feats_len, cache)
@@ -466,7 +466,7 @@
         cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, 320)),
                     "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False,
-                    "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], 560))}
+                    "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], 560)), "tail_chunk": False}
         cache["encoder"] = cache_en
         cache_de = {"decode_fsmn": None}
@@ -478,7 +478,7 @@
         if len(cache) > 0:
             cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, 320)),
                         "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False,
-                        "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], 560))}
+                        "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], 560)), "tail_chunk": False}
             cache["encoder"] = cache_en
             cache_de = {"decode_fsmn": None}