From 12496e559feea69af2e77eac6f22b32df3bf6762 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Thu, 18 Jan 2024 23:21:12 +0800
Subject: [PATCH] streaming bugfix (#1271)
---
funasr/models/fsmn_vad_streaming/model.py | 6 ++++--
funasr/models/paraformer_streaming/model.py | 3 +--
examples/industrial_data_pretraining/paraformer_streaming/demo.py | 1 -
3 files changed, 5 insertions(+), 5 deletions(-)
diff --git a/examples/industrial_data_pretraining/paraformer_streaming/demo.py b/examples/industrial_data_pretraining/paraformer_streaming/demo.py
index 07efde6..6898030 100644
--- a/examples/industrial_data_pretraining/paraformer_streaming/demo.py
+++ b/examples/industrial_data_pretraining/paraformer_streaming/demo.py
@@ -10,7 +10,6 @@
decoder_chunk_look_back = 1 #number of encoder chunks to lookback for decoder cross-attention
model = AutoModel(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online", model_revision="v2.0.2")
-cache = {}
res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
chunk_size=chunk_size,
encoder_chunk_look_back=encoder_chunk_look_back,
diff --git a/funasr/models/fsmn_vad_streaming/model.py b/funasr/models/fsmn_vad_streaming/model.py
index 943cb47..7c21561 100644
--- a/funasr/models/fsmn_vad_streaming/model.py
+++ b/funasr/models/fsmn_vad_streaming/model.py
@@ -501,7 +501,9 @@
# self.AllResetDetection()
return segments
+
def init_cache(self, cache: dict = {}, **kwargs):
+
cache["frontend"] = {}
cache["prev_samples"] = torch.empty(0)
cache["encoder"] = {}
@@ -528,7 +530,7 @@
cache: dict = {},
**kwargs,
):
-
+
if len(cache) == 0:
self.init_cache(cache, **kwargs)
@@ -583,7 +585,7 @@
cache["prev_samples"] = audio_sample[:-m]
if _is_final:
- cache = {}
+ self.init_cache(cache)
ibest_writer = None
if ibest_writer is None and kwargs.get("output_dir") is not None:
diff --git a/funasr/models/paraformer_streaming/model.py b/funasr/models/paraformer_streaming/model.py
index bf45269..9bf5d39 100644
--- a/funasr/models/paraformer_streaming/model.py
+++ b/funasr/models/paraformer_streaming/model.py
@@ -502,8 +502,7 @@
logging.info("enable beam_search")
self.init_beam_search(**kwargs)
self.nbest = kwargs.get("nbest", 1)
-
-
+
if len(cache) == 0:
self.init_cache(cache, **kwargs)
--
Gitblit v1.9.1