zhifu gao
2023-03-17 5ba0cedd531f6d35e98c3ab5a6a9b6eb5ae60f85
Merge pull request #261 from alibaba-damo-academy/tmp

add paraformer streaming recipe
1个文件已添加
52 ■■■■■ 已修改文件
egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py 52 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs_modelscope/asr/paraformer/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py
New file
@@ -0,0 +1,52 @@
import torch
import torchaudio
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
# Streaming (chunk-by-chunk) recognition demo for the online Paraformer model.
#
# The audio is pushed through the pipeline in overlapping windows.  A single
# `cache` dict is handed to every call via `param_dict`; this script rewrites
# the encoder entries (`start_idx`, `pad_left`, `stride`, `pad_right`) between
# calls.  NOTE(review): presumably the pipeline stores its own running state
# in the entries initialised to None below -- confirm against the model code.
inference_pipeline = pipeline(
    task=Tasks.auto_speech_recognition,
    model='damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online',
    model_revision='v1.0.2')

# Window geometry, in samples.  The cache expresses stride/padding in units of
# 960 samples (the tail branches below compute `stride` as len // 960).
SAMPLES_PER_UNIT = 960
FIRST_WINDOW = 15 * SAMPLES_PER_UNIT  # 14400: equals stride (10) + pad_right (5) units
FIRST_HOP = 5 * SAMPLES_PER_UNIT      # 4800: buffer advance after the first window
NEXT_WINDOW = 20 * SAMPLES_PER_UNIT   # 19200: equals pad_left (5) + stride (10) + pad_right (5) units
NEXT_HOP = 10 * SAMPLES_PER_UNIT      # 9600: buffer advance after each later window
NEXT_MIN_SAMPLES = 5 * SAMPLES_PER_UNIT  # 4800: threshold for taking the "full window" path

# NOTE(review): `sample_rate` is never checked; the model name suggests 16 kHz
# input is expected -- verify the wav file matches.
waveform, sample_rate = torchaudio.load("asr_example_zh.wav")
speech = waveform[0]  # first row of the loaded tensor (torchaudio returns channels first by default)

cache_en = {"start_idx": 0, "pad_left": 0, "stride": 10, "pad_right": 5, "cif_hidden": None, "cif_alphas": None}
cache_de = {"decode_fsmn": None}
cache = {"encoder": cache_en, "decoder": cache_de}
# `param_dict` holds a reference to `cache`, so the in-place updates to
# `cache_en` made inside the loop are visible to every pipeline call.
param_dict = {}
param_dict["cache"] = cache

first_chunk = True
speech_buffer = speech
recognized_parts = []  # non-"sil" text pieces, in arrival order

while len(speech_buffer) > 0:
    if first_chunk:
        if len(speech_buffer) >= FIRST_WINDOW:
            rec_result = inference_pipeline(audio_in=speech_buffer[:FIRST_WINDOW], param_dict=param_dict)
            # Advance by less than the stride so the next window starts
            # 5 units before the end of this one's stride (presumably the
            # left context that `pad_left = 5` refers to -- TODO confirm).
            speech_buffer = speech_buffer[FIRST_HOP:]
        else:
            # The whole utterance fits in the first window: shrink the stride
            # to the available whole units and drop the right padding.
            cache_en["stride"] = len(speech_buffer) // SAMPLES_PER_UNIT
            cache_en["pad_right"] = 0
            rec_result = inference_pipeline(audio_in=speech_buffer, param_dict=param_dict)
            speech_buffer = []
        # Set after the first call; the `+= 10` below turns it into 5, 15, ...
        cache_en["start_idx"] = -5
        first_chunk = False
    else:
        cache_en["start_idx"] += 10
        # NOTE(review): the threshold (5 units) is smaller than the window
        # sliced (20 units), so a short remainder is still sent with the
        # full-window stride/pad_right settings.  Kept exactly as in the
        # original recipe -- confirm this is intended.
        if len(speech_buffer) >= NEXT_MIN_SAMPLES:
            cache_en["pad_left"] = 5
            rec_result = inference_pipeline(audio_in=speech_buffer[:NEXT_WINDOW], param_dict=param_dict)
            speech_buffer = speech_buffer[NEXT_HOP:]
        else:
            # Final remainder: `pad_left` keeps whatever value it already had.
            cache_en["stride"] = len(speech_buffer) // SAMPLES_PER_UNIT
            cache_en["pad_right"] = 0
            rec_result = inference_pipeline(audio_in=speech_buffer, param_dict=param_dict)
            speech_buffer = []

    # A result without a 'text' entry (or with empty text) contributes
    # nothing instead of raising KeyError and aborting the stream.
    # "sil" is skipped as in the original (presumably a silence marker).
    text = rec_result.get('text')
    if text and text != "sil":
        recognized_parts.append(text)
    print(rec_result)

final_result = "".join(recognized_parts)
print(final_result)