From c5992ca03ea6d6c7b78e5c1d481a612d0f91ac21 Mon Sep 17 00:00:00 2001
From: hnluo <haoneng.lhn@alibaba-inc.com>
Date: Thu, 27 Apr 2023 01:47:46 +0800
Subject: [PATCH] Update asr_inference_paraformer_streaming.py

---
 funasr/bin/asr_inference_paraformer_streaming.py |    8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/funasr/bin/asr_inference_paraformer_streaming.py b/funasr/bin/asr_inference_paraformer_streaming.py
index 63ef3a3..3f13982 100644
--- a/funasr/bin/asr_inference_paraformer_streaming.py
+++ b/funasr/bin/asr_inference_paraformer_streaming.py
@@ -203,7 +203,7 @@
         results = []
         cache_en = cache["encoder"]
         if speech.shape[1] < 16 * 60 and cache_en["is_final"]:
-            cache_en["last_chunk"] = True
+            cache_en["tail_chunk"] = True
             feats = cache_en["feats"]
             feats_len = torch.tensor([feats.shape[1]])
         else:
@@ -232,7 +232,7 @@
                         feats_len = torch.tensor([feats_chunk2.shape[1]])
                         results_chunk2 = self.infer(feats_chunk2, feats_len, cache)
 
-                        return results_chunk1 + results_chunk2
+                        return ["".join(results_chunk1 + results_chunk2)]
 
                 results = self.infer(feats, feats_len, cache)
 
@@ -466,7 +466,7 @@
 
         cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, 320)),
                     "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False,
-                    "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], 560))}
+                    "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], 560)), "tail_chunk": False}
         cache["encoder"] = cache_en
 
         cache_de = {"decode_fsmn": None}
@@ -478,7 +478,7 @@
         if len(cache) > 0:
             cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, 320)),
                         "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False,
-                        "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], 560))}
+                        "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], 560)), "tail_chunk": False}
             cache["encoder"] = cache_en
 
             cache_de = {"decode_fsmn": None}

--
Gitblit v1.9.1