| funasr/bin/punctuation_infer_vadrealtime.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
| funasr/runtime/python/websocket/README.md | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
| funasr/runtime/python/websocket/ws_client.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
| funasr/runtime/python/websocket/ws_server_2pass.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
| funasr/runtime/python/websocket/ws_server_online.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 |
funasr/bin/punctuation_infer_vadrealtime.py
@@ -61,7 +61,7 @@ text_name="text", non_linguistic_symbols=train_args.non_linguistic_symbols, ) print("start decoding!!!") @torch.no_grad() def __call__(self, text: Union[list, str], cache: list, split_size=20): funasr/runtime/python/websocket/README.md
@@ -33,11 +33,9 @@ #### ASR offline/online 2pass server [//]: # (```shell) [//]: # (python ws_server_online.py --host "0.0.0.0" --port 10095 --asr_model "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch") [//]: # (```) ```shell python ws_server_2pass.py --port 10095 --asr_model "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" --asr_model_online "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online" ``` ## For the client @@ -49,6 +47,7 @@ ``` ### Start client #### ASR offline client ##### Recording from microphone ```shell @@ -60,6 +59,7 @@ # --chunk_interval, "10": 600/10=60ms, "5"=600/5=120ms, "20": 600/20=30ms python ws_client.py --host "0.0.0.0" --port 10095 --chunk_interval 10 --words_max_print 100 --audio_in "./data/wav.scp" --send_without_sleep --output_dir "./results" ``` #### ASR streaming client ##### Recording from microphone ```shell @@ -73,7 +73,16 @@ ``` #### ASR offline/online 2pass client ##### Recording from microphone ```shell # --chunk_size, "5,10,5"=600ms, "8,8,4"=480ms python ws_client.py --host "0.0.0.0" --port 10095 --chunk_size "8,8,4" --words_max_print 10000 ``` ##### Loading from wav.scp(kaldi style) ```shell # --chunk_size, "5,10,5"=600ms, "8,8,4"=480ms python ws_client.py --host "0.0.0.0" --port 10095 --chunk_size "8,8,4" --audio_in "./data/wav.scp" --words_max_print 10000 --output_dir "./results" ``` ## Acknowledge 1. This project is maintained by [FunASR community](https://github.com/alibaba-damo-academy/FunASR). 2. We acknowledge [zhaoming](https://github.com/zhaomingwork/FunASR/tree/fix_bug_for_python_websocket) for contributing the websocket service. funasr/runtime/python/websocket/ws_client.py
@@ -10,6 +10,10 @@ from multiprocessing import Process from funasr.fileio.datadir_writer import DatadirWriter import logging logging.basicConfig(level=logging.ERROR) parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, @@ -158,25 +162,40 @@ async def message(id): global websocket text_print = "" text_print_2pass_online = "" text_print_2pass_offline = "" while True: try: meg = await websocket.recv() meg = json.loads(meg) # print(meg, end = '') # print("\r") # print(meg) wav_name = meg.get("wav_name", "demo") print(wav_name) # print(wav_name) text = meg["text"] if ibest_writer is not None: ibest_writer["text"][wav_name] = text if meg["mode"] == "online": text_print += " {}".format(text) else: text_print += "{}".format(text) text_print = text_print[-args.words_max_print:] os.system('clear') print("\rpid"+str(id)+": "+text_print) elif meg["mode"] == "online": text_print += "{}".format(text) text_print = text_print[-args.words_max_print:] os.system('clear') print("\rpid"+str(id)+": "+text_print) else: if meg["mode"] == "2pass-online": text_print_2pass_online += " {}".format(text) text_print = text_print_2pass_offline + text_print_2pass_online else: text_print_2pass_online = " " text_print = text_print_2pass_offline + "{}".format(text) text_print_2pass_offline += "{}".format(text) text_print = text_print[-args.words_max_print:] os.system('clear') print("\rpid" + str(id) + ": " + text_print) except Exception as e: print("Exception:", e) traceback.print_exc() @@ -207,7 +226,7 @@ await asyncio.gather(task, task2, task3) def one_thread(id): asyncio.get_event_loop().run_until_complete(ws_client(id)) # 启动协程 asyncio.get_event_loop().run_until_complete(ws_client(id)) asyncio.get_event_loop().run_forever() funasr/runtime/python/websocket/ws_server_2pass.py
# New file @@ -0,0 +1,182 @@  (funasr/runtime/python/websocket/ws_server_2pass.py)
"""Two-pass (streaming + offline) ASR websocket server.

For every connected client the server feeds incoming audio frames to

* an online (streaming) ASR pipeline, whose partial results are sent back
  with ``"mode": "2pass-online"``, and
* an online VAD pipeline; when the VAD reports the end of a speech segment
  (or the client stops speaking) the buffered segment is decoded by the
  offline ASR pipeline, optionally punctuated, and sent back with
  ``"mode": "2pass-offline"``.

NOTE(review): the indentation of this file was reconstructed from a
flattened diff view -- confirm the nesting against the repository copy.
"""
import asyncio
import json
import websockets
import time
import logging
import tracemalloc
import numpy as np

from parse_args import args
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from funasr.runtime.python.onnxruntime.funasr_onnx.utils.frontend import load_bytes

tracemalloc.start()

logger = get_logger(log_level=logging.CRITICAL)
logger.setLevel(logging.CRITICAL)

# Currently connected client websockets.
websocket_users = set()

print("model loading")
# asr (offline, second pass)
inference_pipeline_asr = pipeline(
    task=Tasks.auto_speech_recognition,
    model=args.asr_model,
    ngpu=args.ngpu,
    ncpu=args.ncpu,
    model_revision=None)

# vad
inference_pipeline_vad = pipeline(
    task=Tasks.voice_activity_detection,
    model=args.vad_model,
    model_revision=None,
    output_dir=None,
    batch_size=1,
    mode='online',
    ngpu=args.ngpu,
    ncpu=args.ncpu,
)

# punctuation is optional: an empty --punc_model disables it
if args.punc_model != "":
    inference_pipeline_punc = pipeline(
        task=Tasks.punctuation,
        model=args.punc_model,
        model_revision=None,
        ngpu=args.ngpu,
        ncpu=args.ncpu,
    )
else:
    inference_pipeline_punc = None

# asr (online / streaming, first pass)
inference_pipeline_asr_online = pipeline(
    task=Tasks.auto_speech_recognition,
    model=args.asr_model_online,
    ngpu=args.ngpu,
    ncpu=args.ncpu,
    model_revision='v1.0.4')

print("model loaded")


async def ws_serve(websocket, path):
    """Handle one client connection until it closes.

    Each incoming message is a JSON object with the keys ``is_finished``,
    ``audio`` (ISO-8859-1 encoded bytes), ``is_speaking``, ``chunk_size``,
    ``chunk_interval`` and optionally ``wav_name``.

    Args:
        websocket: the client connection; per-connection decoding state is
            stored on it as attributes (``param_dict_*``, ``vad_pre_idx``,
            ``wav_name``).
        path: request path supplied by ``websockets.serve``; unused.
    """
    frames = []             # recent raw frames, used to recover audio before a VAD start
    frames_asr = []         # frames of the current speech segment (offline pass)
    frames_asr_online = []  # frames waiting for the next streaming decode
    global websocket_users
    websocket_users.add(websocket)
    websocket.param_dict_asr = {}
    websocket.param_dict_asr_online = {"cache": dict()}
    websocket.param_dict_vad = {'in_cache': dict(), "is_final": False}
    websocket.param_dict_punc = {'cache': list()}
    websocket.vad_pre_idx = 0  # milliseconds of audio fed to the VAD so far
    speech_start = False

    try:
        async for message in websocket:
            message = json.loads(message)
            is_finished = message["is_finished"]
            if is_finished:
                continue

            audio = bytes(message['audio'], 'ISO-8859-1')
            frames.append(audio)
            # 32 bytes per millisecond -- presumably 16 kHz, 16-bit mono PCM
            duration_ms = len(audio)//32
            websocket.vad_pre_idx += duration_ms

            is_speaking = message["is_speaking"]
            websocket.param_dict_vad["is_final"] = not is_speaking
            websocket.param_dict_asr_online["is_final"] = not is_speaking
            websocket.param_dict_asr_online["chunk_size"] = message["chunk_size"]
            websocket.wav_name = message.get("wav_name", "demo")

            # asr online: decode once every `chunk_interval` frames
            frames_asr_online.append(audio)
            if len(frames_asr_online) % message["chunk_interval"] == 0:
                audio_in = b"".join(frames_asr_online)
                await async_asr_online(websocket, audio_in)
                frames_asr_online = []
            if speech_start:
                frames_asr.append(audio)

            # vad online
            speech_start_i, speech_end_i = await async_vad(websocket, audio)
            # `async_vad` uses False as "no start"; compare by identity so a
            # segment starting at 0 ms (falsy) is not silently dropped.
            if speech_start_i is not False:
                speech_start = True
                # Number of already-received frames that belong to the segment.
                # NOTE(review): beg_bias == 0 makes frames[-0:] the whole
                # buffer and duration_ms == 0 raises ZeroDivisionError --
                # confirm the intended handling of these edge cases.
                beg_bias = (websocket.vad_pre_idx-speech_start_i)//duration_ms
                frames_pre = frames[-beg_bias:]
                frames_asr = []
                frames_asr.extend(frames_pre)

            # asr punc offline: flush the segment at speech end / end of input
            if speech_end_i or not is_speaking:
                audio_in = b"".join(frames_asr)
                await async_asr(websocket, audio_in)
                frames_asr = []
                speech_start = False
                frames_asr_online = []
                websocket.param_dict_asr_online = {"cache": dict()}
                if not is_speaking:
                    # end of the utterance: reset all VAD state
                    websocket.vad_pre_idx = 0
                    frames = []
                    websocket.param_dict_vad = {'in_cache': dict()}
                else:
                    # keep only a short look-back window
                    frames = frames[-20:]

    except websockets.ConnectionClosed:
        print("ConnectionClosed...", websocket_users)
    except websockets.InvalidState:
        print("InvalidState...")
    except Exception as e:
        print("Exception:", e)
    finally:
        # Always forget the connection, whichever way the handler ends;
        # previously it was only removed on ConnectionClosed.
        websocket_users.discard(websocket)


async def async_vad(websocket, audio_in):
    """Run the online VAD on one audio frame.

    Returns:
        A ``(speech_start, speech_end)`` tuple. ``speech_start`` is the start
        value reported by the VAD, or ``False`` when no start was reported;
        ``speech_end`` is ``True`` when the VAD reported a segment end.
        ``(False, False)`` is returned when the VAD yields no result or more
        than one segment.
    """
    segments_result = inference_pipeline_vad(audio_in=audio_in, param_dict=websocket.param_dict_vad)

    speech_start = False
    speech_end = False

    if len(segments_result) == 0 or len(segments_result["text"]) > 1:
        return speech_start, speech_end
    first_segment = segments_result["text"][0]
    # -1 marks "not detected" for either boundary
    if first_segment[0] != -1:
        speech_start = first_segment[0]
    if first_segment[1] != -1:
        speech_end = True
    return speech_start, speech_end


async def async_asr(websocket, audio_in):
    """Decode a finished speech segment offline and send the result.

    Nothing is sent when ``audio_in`` is empty or when the recogniser returns
    no ``"text"`` entry.
    """
    if len(audio_in) > 0:
        audio_in = load_bytes(audio_in)
        rec_result = inference_pipeline_asr(audio_in=audio_in,
                                            param_dict=websocket.param_dict_asr)
        if inference_pipeline_punc is not None and 'text' in rec_result and len(rec_result["text"])>0:
            rec_result = inference_pipeline_punc(text_in=rec_result['text'],
                                                 param_dict=websocket.param_dict_punc)
        # The recogniser may return no "text" (see the guard above); skip the
        # reply instead of raising KeyError and ending the connection loop.
        if "text" in rec_result:
            message = json.dumps({"mode": "2pass-offline", "text": rec_result["text"], "wav_name": websocket.wav_name})
            await websocket.send(message)


async def async_asr_online(websocket, audio_in):
    """Decode a chunk with the streaming model and send the partial result.

    The streaming cache is cleared after a chunk flagged ``is_final``. The
    placeholder results ``"sil"`` and ``"waiting_for_more_voice"`` are not
    forwarded to the client.
    """
    if len(audio_in) > 0:
        audio_in = load_bytes(audio_in)
        rec_result = inference_pipeline_asr_online(audio_in=audio_in,
                                                   param_dict=websocket.param_dict_asr_online)
        if websocket.param_dict_asr_online["is_final"]:
            websocket.param_dict_asr_online["cache"] = dict()
        if "text" in rec_result:
            if rec_result["text"] != "sil" and rec_result["text"] != "waiting_for_more_voice":
                message = json.dumps({"mode": "2pass-online", "text": rec_result["text"], "wav_name": websocket.wav_name})
                await websocket.send(message)


start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None)
asyncio.get_event_loop().run_until_complete(start_server)
asyncio.get_event_loop().run_forever()
# funasr/runtime/python/websocket/ws_server_online.py
@@ -37,12 +37,10 @@ async def ws_serve(websocket, path): frames_online = [] frames_asr_online = [] global websocket_users websocket.send_msg = Queue() websocket_users.add(websocket) websocket.param_dict_asr_online = {"cache": dict()} websocket.speek_online = Queue() try: async for message in websocket: @@ -56,11 +54,11 @@ websocket.wav_name = message.get("wav_name", "demo") websocket.param_dict_asr_online["chunk_size"] = message["chunk_size"] frames_online.append(audio) if len(frames_online) % message["chunk_interval"] == 0 or not is_speaking: audio_in = b"".join(frames_online) frames_asr_online.append(audio) if len(frames_asr_online) % message["chunk_interval"] == 0 or not is_speaking: audio_in = b"".join(frames_asr_online) await async_asr_online(websocket,audio_in) frames_online = [] frames_asr_online = [] @@ -81,8 +79,6 @@ websocket.param_dict_asr_online["cache"] = dict() if "text" in rec_result: if rec_result["text"] != "sil" and rec_result["text"] != "waiting_for_more_voice": # if len(rec_result["text"])>0: # rec_result["text"][0]=rec_result["text"][0] #.replace(" ","") message = json.dumps({"mode": "online", "text": rec_result["text"], "wav_name": websocket.wav_name}) await websocket.send(message)