From f89111a45b73bc9eae6f161e0a5b8ee0464d58c3 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 24 五月 2023 23:30:26 +0800
Subject: [PATCH] websocket unify wss_srv for online offline 2pass
---
funasr/runtime/python/websocket/wss_client_asr.py | 296 ++++++++++++++++++++++++++++++++
funasr/runtime/python/websocket/wss_srv_asr.py | 210 +++++++++++++++++++++++
2 files changed, 506 insertions(+), 0 deletions(-)
diff --git a/funasr/runtime/python/websocket/wss_client_asr.py b/funasr/runtime/python/websocket/wss_client_asr.py
new file mode 100644
index 0000000..586e0a4
--- /dev/null
+++ b/funasr/runtime/python/websocket/wss_client_asr.py
@@ -0,0 +1,296 @@
+# -*- encoding: utf-8 -*-
+import os
+import time
+import websockets,ssl
+import asyncio
+# import threading
+import argparse
+import json
+import traceback
+from multiprocessing import Process
+from funasr.fileio.datadir_writer import DatadirWriter
+
+import logging
+
+logging.basicConfig(level=logging.ERROR)
+
+# ---- command-line interface ----
+parser = argparse.ArgumentParser()
+parser.add_argument("--host",
+                    type=str,
+                    default="localhost",
+                    required=False,
+                    help="host ip, localhost, 0.0.0.0")
+parser.add_argument("--port",
+                    type=int,
+                    default=10095,
+                    required=False,
+                    help="grpc server port")
+parser.add_argument("--chunk_size",
+                    type=str,
+                    default="5, 10, 5",
+                    help="chunk")
+parser.add_argument("--chunk_interval",
+                    type=int,
+                    default=10,
+                    help="chunk")
+parser.add_argument("--audio_in",
+                    type=str,
+                    default=None,
+                    help="audio_in")
+parser.add_argument("--send_without_sleep",
+                    action="store_true",
+                    default=False,
+                    help="if audio_in is set, send_without_sleep")
+parser.add_argument("--test_thread_num",
+                    type=int,
+                    default=1,
+                    help="test_thread_num")
+parser.add_argument("--words_max_print",
+                    type=int,
+                    default=10000,
+                    help="chunk")
+parser.add_argument("--output_dir",
+                    type=str,
+                    default=None,
+                    help="output_dir")
+
+parser.add_argument("--ssl",
+                    type=int,
+                    default=1,
+                    help="1 for ssl connect, 0 for no ssl")
+parser.add_argument("--mode",
+                    type=str,
+                    default="2pass",
+                    help="offline, online, 2pass")
+
+args = parser.parse_args()
+# "--chunk_size 5,10,5" -> [5, 10, 5]; index [1] is the streaming chunk frames.
+args.chunk_size = [int(x) for x in args.chunk_size.split(",")]
+print(args)
+# voices = asyncio.Queue()
+# Thread-safe queue shared between the recorder coroutine and ws_send().
+from queue import Queue
+voices = Queue()
+
+# Optional Kaldi-style output writer; None disables result dumping.
+ibest_writer = None
+if args.output_dir is not None:
+    writer = DatadirWriter(args.output_dir)
+    ibest_writer = writer[f"1best_recog"]
+
+async def record_microphone():
+    """Capture 16 kHz mono 16-bit audio from the microphone into `voices`.
+
+    Sends the JSON config message first, then raw PCM chunks forever.
+    """
+    is_finished = False
+    import pyaudio
+    #print("2")
+    global voices
+    FORMAT = pyaudio.paInt16
+    CHANNELS = 1
+    RATE = 16000
+    # Chunk duration in ms: chunk_size[1] frames of (60/chunk_interval) ms each.
+    chunk_size = 60*args.chunk_size[1]/args.chunk_interval
+    CHUNK = int(RATE / 1000 * chunk_size)
+
+    p = pyaudio.PyAudio()
+
+    stream = p.open(format=FORMAT,
+                    channels=CHANNELS,
+                    rate=RATE,
+                    input=True,
+                    frames_per_buffer=CHUNK)
+
+    # First message negotiates mode/chunking with the server.
+    message = json.dumps({"mode": args.mode, "chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval, "wav_name": "microphone", "is_speaking": True})
+    voices.put(message)
+    while True:
+
+        data = stream.read(CHUNK)
+        message = data
+
+        voices.put(message)
+
+        # Yield to the event loop so ws_send()/message() can run.
+        await asyncio.sleep(0.005)
+
+async def record_from_scp(chunk_begin,chunk_size):
+    """Stream wav files (a .scp list or one file) as PCM chunks into `voices`.
+
+    chunk_begin/chunk_size select this process's slice of the wav list;
+    chunk_size <= 0 means "take everything".
+    """
+    import wave
+    global voices
+    if args.audio_in.endswith(".scp"):
+        # Kaldi-style scp: "utt_id /path/to/wav" per line.
+        # FIX: use `with` so the scp handle is closed (original leaked it).
+        with open(args.audio_in) as f_scp:
+            wavs = f_scp.readlines()
+    else:
+        wavs = [args.audio_in]
+    if chunk_size>0:
+        wavs=wavs[chunk_begin:chunk_begin+chunk_size]
+    for wav in wavs:
+        wav_splits = wav.strip().split()
+        wav_name = wav_splits[0] if len(wav_splits) > 1 else "demo"
+        wav_path = wav_splits[1] if len(wav_splits) > 1 else wav_splits[0]
+
+        with wave.open(wav_path, "rb") as wav_file:
+            frames = wav_file.readframes(wav_file.getnframes())
+        audio_bytes = bytes(frames)
+
+        # Bytes per chunk: chunk_size[1] frames of 60ms/chunk_interval each,
+        # at 16 kHz 16-bit mono (2 bytes per sample).
+        stride = int(60*args.chunk_size[1]/args.chunk_interval/1000*16000*2)
+        chunk_num = (len(audio_bytes)-1)//stride + 1
+
+        # First message negotiates mode/chunking with the server.
+        message = json.dumps({"mode": args.mode, "chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval, "wav_name": wav_name,"is_speaking": True})
+        voices.put(message)
+        is_speaking = True
+        for i in range(chunk_num):
+            beg = i*stride
+            data = audio_bytes[beg:beg+stride]
+            voices.put(data)
+            if i == chunk_num-1:
+                # Last chunk: tell the server this utterance is finished.
+                is_speaking = False
+                message = json.dumps({"is_speaking": is_speaking})
+                voices.put(message)
+            # Pace sending in (near) real time unless told otherwise.
+            sleep_duration = 0.001 if args.send_without_sleep else 60*args.chunk_size[1]/args.chunk_interval/1000
+            await asyncio.sleep(sleep_duration)
+
+
+async def ws_send():
+    """Drain the shared `voices` queue and forward each item to the server.
+
+    Uses the module-level `websocket` bound by ws_client(); exits the whole
+    process on any send failure.
+    """
+    global voices
+    global websocket
+    print("started to sending data!")
+    while True:
+        while not voices.empty():
+            data = voices.get()
+            voices.task_done()
+            try:
+                await websocket.send(data)
+            except Exception as e:
+                print('Exception occurred:', e)
+                traceback.print_exc()
+                exit(0)
+            # Yield so the recorder/receiver coroutines can run.
+            await asyncio.sleep(0.005)
+        await asyncio.sleep(0.005)
+
+
+
+async def message(id):
+    """Receive recognition results and render a rolling transcript.
+
+    In 2pass mode, online partials are shown immediately and then replaced by
+    the offline (final) result for the same span.
+    """
+    global websocket
+    text_print = ""
+    text_print_2pass_online = ""
+    text_print_2pass_offline = ""
+    while True:
+        try:
+            meg = await websocket.recv()
+            meg = json.loads(meg)
+            wav_name = meg.get("wav_name", "demo")
+            text = meg["text"]
+            if ibest_writer is not None:
+                ibest_writer["text"][wav_name] = text
+
+            # BUG FIX: the original tested `meg["mode"] == "online"` in both the
+            # `if` and the `elif` (copy-paste slip); the first branch was
+            # clearly meant for "offline" results.
+            if meg["mode"] == "offline":
+                text_print += "{}".format(text)
+                text_print = text_print[-args.words_max_print:]
+                os.system('clear')
+                print("\rpid"+str(id)+": "+text_print)
+            elif meg["mode"] == "online":
+                text_print += "{}".format(text)
+                text_print = text_print[-args.words_max_print:]
+                os.system('clear')
+                print("\rpid"+str(id)+": "+text_print)
+            else:
+                if meg["mode"] == "2pass-online":
+                    text_print_2pass_online += "{}".format(text)
+                    text_print = text_print_2pass_offline + text_print_2pass_online
+                else:
+                    # Final pass arrived: drop the partial, keep the fixed text.
+                    text_print_2pass_online = ""
+                    text_print = text_print_2pass_offline + "{}".format(text)
+                    text_print_2pass_offline += "{}".format(text)
+                text_print = text_print[-args.words_max_print:]
+                os.system('clear')
+                print("\rpid" + str(id) + ": " + text_print)
+
+        except Exception as e:
+            print("Exception:", e)
+            traceback.print_exc()
+            exit(0)
+
+async def print_messge():
+    """Debug helper: dump every server message as parsed JSON."""
+    global websocket
+    while True:
+        try:
+            raw = await websocket.recv()
+            print(json.loads(raw))
+        except Exception as e:
+            print("Exception:", e)
+            traceback.print_exc()
+            exit(0)
+
+async def ws_client(id,chunk_begin,chunk_size):
+    """Connect to the server and run sender/receiver tasks over one socket.
+
+    `async for ... in websockets.connect(...)` reconnects automatically when
+    the connection drops; each iteration rebinds the module-level `websocket`
+    global that ws_send() and message() read.
+    """
+    global websocket
+    if args.ssl==1:
+        # Self-signed-friendly TLS: no hostname check, no cert verification.
+        ssl_context = ssl.SSLContext()
+        ssl_context.check_hostname = False
+        ssl_context.verify_mode = ssl.CERT_NONE
+        uri = "wss://{}:{}".format(args.host, args.port)
+    else:
+        uri = "ws://{}:{}".format(args.host, args.port)
+        ssl_context=None
+    print("connect to",uri)
+    async for websocket in websockets.connect(uri, subprotocols=["binary"], ping_interval=None,ssl=ssl_context):
+        if args.audio_in is not None:
+            task = asyncio.create_task(record_from_scp(chunk_begin,chunk_size))
+        else:
+            task = asyncio.create_task(record_microphone())
+        task2 = asyncio.create_task(ws_send())
+        task3 = asyncio.create_task(message(id))
+        await asyncio.gather(task, task2, task3)
+
+def one_thread(id,chunk_begin,chunk_size):
+    """Process entry point: drive one websocket client on its own event loop."""
+    loop = asyncio.get_event_loop()
+    loop.run_until_complete(ws_client(id,chunk_begin,chunk_size))
+    loop.run_forever()
+
+
+if __name__ == '__main__':
+    # for microphone
+    if args.audio_in is None:
+        # Microphone mode: a single client process.
+        p = Process(target=one_thread,args=(0, 0, 0))
+        p.start()
+        p.join()
+        print('end')
+    else:
+        # File mode: split the wav list evenly across test_thread_num processes.
+        if args.audio_in.endswith(".scp"):
+            # FIX: close the scp handle (original leaked it).
+            with open(args.audio_in) as f_scp:
+                wavs = f_scp.readlines()
+        else:
+            wavs = [args.audio_in]
+        total_len=len(wavs)
+        if total_len>=args.test_thread_num:
+            chunk_size=int((total_len)/args.test_thread_num)
+            remain_wavs=total_len-chunk_size*args.test_thread_num
+        else:
+            chunk_size=1
+            remain_wavs=0
+
+        process_list = []
+        chunk_begin=0
+        for i in range(args.test_thread_num):
+            now_chunk_size= chunk_size
+            if remain_wavs>0:
+                # Spread the remainder: one extra wav per leading process.
+                now_chunk_size=chunk_size+1
+                remain_wavs=remain_wavs-1
+            # process i handles wavs[chunk_begin:chunk_begin+now_chunk_size]
+            p = Process(target=one_thread,args=(i,chunk_begin,now_chunk_size))
+            chunk_begin=chunk_begin+now_chunk_size
+            p.start()
+            process_list.append(p)
+
+        # BUG FIX: original looped `for i in process_list: p.join()`, joining
+        # only the last-started process repeatedly and never waiting on the
+        # others. Join each child process.
+        for p in process_list:
+            p.join()
+
+        print('end')
+
diff --git a/funasr/runtime/python/websocket/wss_srv_asr.py b/funasr/runtime/python/websocket/wss_srv_asr.py
new file mode 100644
index 0000000..71c97e6
--- /dev/null
+++ b/funasr/runtime/python/websocket/wss_srv_asr.py
@@ -0,0 +1,210 @@
+import asyncio
+import json
+import websockets
+import time
+import logging
+import tracemalloc
+import numpy as np
+import ssl
+from parse_args import args
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
+from funasr.runtime.python.onnxruntime.funasr_onnx.utils.frontend import load_bytes
+
+# Track allocations so memory growth can be inspected while serving.
+tracemalloc.start()
+
+# Silence modelscope logging; only critical messages get through.
+logger = get_logger(log_level=logging.CRITICAL)
+logger.setLevel(logging.CRITICAL)
+
+
+# Registry of currently-connected client sockets.
+websocket_users = set()
+
+print("model loading")
+# asr: offline (full-utterance) recognizer.
+inference_pipeline_asr = pipeline(
+    task=Tasks.auto_speech_recognition,
+    model=args.asr_model,
+    ngpu=args.ngpu,
+    ncpu=args.ncpu,
+    model_revision=None)
+
+
+# vad: streaming voice-activity detection.
+inference_pipeline_vad = pipeline(
+    task=Tasks.voice_activity_detection,
+    model=args.vad_model,
+    model_revision=None,
+    output_dir=None,
+    batch_size=1,
+    mode='online',
+    ngpu=args.ngpu,
+    ncpu=args.ncpu,
+)
+
+# Punctuation restoration is optional: empty --punc_model disables it.
+if args.punc_model != "":
+    inference_pipeline_punc = pipeline(
+        task=Tasks.punctuation,
+        model=args.punc_model,
+        model_revision="v1.0.2",
+        ngpu=args.ngpu,
+        ncpu=args.ncpu,
+    )
+else:
+    inference_pipeline_punc = None
+
+# Streaming (online) recognizer used for partial results.
+inference_pipeline_asr_online = pipeline(
+    task=Tasks.auto_speech_recognition,
+    model=args.asr_model_online,
+    ngpu=args.ngpu,
+    ncpu=args.ncpu,
+    model_revision='v1.0.4')
+
+print("model loaded")
+
+async def ws_serve(websocket, path):
+    """Per-connection handler: route audio to online ASR, VAD and offline ASR.
+
+    The client first sends a JSON config message (mode/chunk_size/...), then
+    interleaves binary PCM chunks with JSON control messages.
+    """
+    frames = []              # rolling raw-audio history (for VAD look-back)
+    frames_asr = []          # audio of the current speech segment (offline pass)
+    frames_asr_online = []   # audio buffered for the next online-ASR call
+    global websocket_users
+    websocket_users.add(websocket)
+    websocket.param_dict_asr = {}
+    websocket.param_dict_asr_online = {"cache": dict()}
+    websocket.param_dict_vad = {'in_cache': dict(), "is_final": False}
+    websocket.param_dict_punc = {'cache': list()}
+    websocket.vad_pre_idx = 0
+    # FIX: give these defaults so a client that streams audio before (or
+    # without) the initial config message cannot trigger AttributeError at
+    # the `websocket.chunk_interval` / `websocket.is_speaking` reads below.
+    websocket.is_speaking = True
+    websocket.chunk_interval = 10
+    speech_start = False
+    speech_end_i = False
+    websocket.wav_name = "microphone"
+    websocket.mode = "2pass"
+    print("new user connected", flush=True)
+
+    try:
+        async for message in websocket:
+            if isinstance(message, str):
+                messagejson = json.loads(message)
+
+                if "is_speaking" in messagejson:
+                    websocket.is_speaking = messagejson["is_speaking"]
+                    websocket.param_dict_asr_online["is_final"] = not websocket.is_speaking
+                if "chunk_interval" in messagejson:
+                    websocket.chunk_interval = messagejson["chunk_interval"]
+                if "wav_name" in messagejson:
+                    websocket.wav_name = messagejson.get("wav_name")
+                if "chunk_size" in messagejson:
+                    websocket.param_dict_asr_online["chunk_size"] = messagejson["chunk_size"]
+                if "mode" in messagejson:
+                    websocket.mode = messagejson["mode"]
+            if len(frames_asr_online) > 0 or len(frames_asr) > 0 or not isinstance(message, str):
+                if not isinstance(message, str):
+                    frames.append(message)
+                    duration_ms = len(message)//32  # 16 kHz 16-bit mono: 32 bytes/ms
+                    websocket.vad_pre_idx += duration_ms
+
+                    # asr online
+                    frames_asr_online.append(message)
+                    websocket.param_dict_asr_online["is_final"] = speech_end_i
+                    if len(frames_asr_online) % websocket.chunk_interval == 0 or websocket.param_dict_asr_online["is_final"]:
+                        if websocket.mode == "2pass" or websocket.mode == "online":
+                            audio_in = b"".join(frames_asr_online)
+                            await async_asr_online(websocket, audio_in)
+                        frames_asr_online = []
+                    if speech_start:
+                        frames_asr.append(message)
+                    # vad online
+                    speech_start_i, speech_end_i = await async_vad(websocket, message)
+                    if speech_start_i:
+                        speech_start = True
+                        # Back-fill audio from before the detected speech onset.
+                        beg_bias = (websocket.vad_pre_idx-speech_start_i)//duration_ms
+                        frames_pre = frames[-beg_bias:]
+                        frames_asr = []
+                        frames_asr.extend(frames_pre)
+                # asr punc offline
+                if speech_end_i or not websocket.is_speaking:
+                    if websocket.mode == "2pass" or websocket.mode == "offline":
+                        audio_in = b"".join(frames_asr)
+                        await async_asr(websocket, audio_in)
+                    frames_asr = []
+                    speech_start = False
+                    if not websocket.is_speaking:
+                        # Utterance over: reset VAD state completely.
+                        websocket.vad_pre_idx = 0
+                        frames = []
+                        websocket.param_dict_vad = {'in_cache': dict()}
+                    else:
+                        frames = frames[-20:]
+
+
+    except websockets.ConnectionClosed:
+        print("ConnectionClosed...", websocket_users)
+    except websockets.InvalidState:
+        print("InvalidState...")
+    except Exception as e:
+        print("Exception:", e)
+    finally:
+        # FIX: the original removed the socket from websocket_users only on
+        # ConnectionClosed, leaking registry entries on every other exit path.
+        websocket_users.discard(websocket)
+
+
+async def async_vad(websocket, audio_in):
+    """Feed one audio chunk to the streaming VAD.
+
+    Returns (speech_start, speech_end):
+      - speech_start: segment-onset position (truthy int, ms per the caller's
+        arithmetic) when detected, else False
+      - speech_end: True when the VAD reports a segment end, else False
+    """
+
+    segments_result = inference_pipeline_vad(audio_in=audio_in, param_dict=websocket.param_dict_vad)
+
+    speech_start = False
+    speech_end = False
+
+    # No decision yet, or multiple segments in one chunk: report nothing.
+    if len(segments_result) == 0 or len(segments_result["text"]) > 1:
+        return speech_start, speech_end
+    # -1 means "boundary not determined yet" in the VAD output.
+    if segments_result["text"][0][0] != -1:
+        speech_start = segments_result["text"][0][0]
+    if segments_result["text"][0][1] != -1:
+        speech_end = True
+    return speech_start, speech_end
+
+
+async def async_asr(websocket, audio_in):
+    """Run the offline ASR (+ optional punctuation) pass and send the result."""
+    if len(audio_in) > 0:
+        audio_in = load_bytes(audio_in)
+
+        rec_result = inference_pipeline_asr(audio_in=audio_in,
+                                            param_dict=websocket.param_dict_asr)
+        if inference_pipeline_punc is not None and 'text' in rec_result and len(rec_result["text"])>0:
+            rec_result = inference_pipeline_punc(text_in=rec_result['text'],
+                                                 param_dict=websocket.param_dict_punc)
+        if 'text' in rec_result:
+            # FIX: report the actual serving mode; the original always claimed
+            # "2pass-offline" even when the client requested plain "offline".
+            # (This pass only runs for mode "2pass" or "offline".)
+            result_mode = "2pass-offline" if websocket.mode == "2pass" else websocket.mode
+            message = json.dumps({"mode": result_mode, "text": rec_result["text"], "wav_name": websocket.wav_name})
+            await websocket.send(message)
+
+
+async def async_asr_online(websocket, audio_in):
+    """Run the streaming ASR pass and push partial results to the client."""
+    if len(audio_in) > 0:
+        audio_in = load_bytes(audio_in)
+        rec_result = inference_pipeline_asr_online(audio_in=audio_in,
+                                                   param_dict=websocket.param_dict_asr_online)
+        if websocket.mode == "2pass" and websocket.param_dict_asr_online.get("is_final", False):
+            # In 2pass mode the final chunk is handled by the offline pass;
+            # skip sending a duplicate online result.
+            return
+        if "text" in rec_result:
+            if rec_result["text"] != "sil" and rec_result["text"] != "waiting_for_more_voice":
+                # FIX: report the actual serving mode; the original always
+                # claimed "2pass-online" even in plain "online" mode.
+                # (This pass only runs for mode "2pass" or "online".)
+                result_mode = "2pass-online" if websocket.mode == "2pass" else websocket.mode
+                message = json.dumps({"mode": result_mode, "text": rec_result["text"], "wav_name": websocket.wav_name})
+                await websocket.send(message)
+
+# Serve over TLS (wss://) when a certificate is configured, plain ws:// otherwise.
+if len(args.certfile)>0:
+    ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+
+    # Generate with Lets Encrypt, copied to this location, chown to current user and 400 permissions
+    ssl_cert = args.certfile
+    ssl_key = args.keyfile
+
+    ssl_context.load_cert_chain(ssl_cert, keyfile=ssl_key)
+    start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None,ssl=ssl_context)
+else:
+    start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None)
+# Run the server forever on the default event loop.
+asyncio.get_event_loop().run_until_complete(start_server)
+asyncio.get_event_loop().run_forever()
\ No newline at end of file
--
Gitblit v1.9.1