From dcb92f13eddbf3032ce363b35f13f80afa8f94d1 Mon Sep 17 00:00:00 2001
From: haoneng.lhn <haoneng.lhn@alibaba-inc.com>
Date: 星期四, 14 九月 2023 16:46:30 +0800
Subject: [PATCH] add paraformer online opt infer code

---
 funasr/bin/asr_inference_launch.py |    8 ++++----
 1 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index 1b38f8f..e6049e9 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -853,7 +853,7 @@
                     "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)), "tail_chunk": False}
         cache["encoder"] = cache_en
 
-        cache_de = {"decode_fsmn": None, "decoder_chunk_look_back": decoder_chunk_look_back, "opt": None}
+        cache_de = {"decode_fsmn": None, "decoder_chunk_look_back": decoder_chunk_look_back, "opt": None, "chunk_size": chunk_size}
         cache["decoder"] = cache_de
 
         return cache
@@ -870,7 +870,7 @@
                         "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)), "tail_chunk": False}
             cache["encoder"] = cache_en
 
-            cache_de = {"decode_fsmn": None, "decoder_chunk_look_back": decoder_chunk_look_back, "opt": None}
+            cache_de = {"decode_fsmn": None, "decoder_chunk_look_back": decoder_chunk_look_back, "opt": None, "chunk_size": chunk_size}
             cache["decoder"] = cache_de
 
         return cache
@@ -982,8 +982,8 @@
 
         asr_result_list.append(item)
         if is_final:
-            cache = _cache_reset(cache, chunk_size=chunk_size, batch_size=1,
-                                 encoder_chunk_look_back=encoder_chunk_look_back, decoder_chunk_look_back=decoder_chunk_look_back)
+            cache = _cache_reset(cache, chunk_size=chunk_size, encoder_chunk_look_back=encoder_chunk_look_back, 
+                                 decoder_chunk_look_back=decoder_chunk_look_back, batch_size=1)
         return asr_result_list
 
     return _forward

--
Gitblit v1.9.1