From cb4bd0e6dee4d8bae49351c1ac3b099bcdd72acb Mon Sep 17 00:00:00 2001
From: 仁迷 <haoneng.lhn@alibaba-inc.com>
Date: 星期五, 17 三月 2023 22:14:21 +0800
Subject: [PATCH] add paraformer streaming recipe
---
egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py | 52 ++++++++++++++++++++++++++++++++++++++++++++++++++++
1 files changed, 52 insertions(+), 0 deletions(-)
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py b/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py
new file mode 100644
index 0000000..f2024fb
--- /dev/null
+++ b/egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py
@@ -0,0 +1,52 @@
+import torch
+import torchaudio
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+inference_pipeline = pipeline(
+ task=Tasks.auto_speech_recognition,
+ model='damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online',
+ model_revision='v1.0.2')
+
+waveform, sample_rate = torchaudio.load("asr_example_zh.wav")
+speech_length = waveform.shape[1]
+speech = waveform[0]
+
+cache_en = {"start_idx": 0, "pad_left": 0, "stride": 10, "pad_right": 5, "cif_hidden": None, "cif_alphas": None}
+cache_de = {"decode_fsmn": None}
+cache = {"encoder": cache_en, "decoder": cache_de}
+param_dict = {}
+param_dict["cache"] = cache
+
+first_chunk = True
+speech_buffer = speech
+speech_cache = []
+final_result = ""
+
+while len(speech_buffer) > 0:
+ if first_chunk:
+ if len(speech_buffer) >= 14400:
+ rec_result = inference_pipeline(audio_in=speech_buffer[0:14400], param_dict=param_dict)
+ speech_buffer = speech_buffer[4800:]
+ else:
+ cache_en["stride"] = len(speech_buffer) // 960
+ cache_en["pad_right"] = 0
+ rec_result = inference_pipeline(audio_in=speech_buffer, param_dict=param_dict)
+ speech_buffer = []
+ cache_en["start_idx"] = -5
+ first_chunk = False
+ else:
+ cache_en["start_idx"] += 10
+ if len(speech_buffer) >= 4800:
+ cache_en["pad_left"] = 5
+ rec_result = inference_pipeline(audio_in=speech_buffer[:19200], param_dict=param_dict)
+ speech_buffer = speech_buffer[9600:]
+ else:
+ cache_en["stride"] = len(speech_buffer) // 960
+ cache_en["pad_right"] = 0
+ rec_result = inference_pipeline(audio_in=speech_buffer, param_dict=param_dict)
+ speech_buffer = []
+ if rec_result['text'] != "sil":
+ final_result += rec_result['text']
+ print(rec_result)
+print(final_result)
--
Gitblit v1.9.1