aky15
2023-07-04 3360a1d9453ef0ce441cc41b0090d09b3bb296bb
funasr/runtime/python/websocket/wss_srv_asr.py
@@ -5,8 +5,8 @@
import logging
import tracemalloc
import numpy as np
import argparse
import ssl
from parse_args import args
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
@@ -16,6 +16,54 @@
# Silence modelscope's logger: request CRITICAL via get_logger and also pin
# the returned logger's level to CRITICAL explicitly.
logger = get_logger(log_level=logging.CRITICAL)
logger.setLevel(logging.CRITICAL)
# Command-line options for the websocket ASR server: bind address, the
# modelscope model ids for each pipeline, device/thread settings, and the
# TLS certificate/key pair.
parser = argparse.ArgumentParser()
parser.add_argument("--host",
                    type=str,
                    default="0.0.0.0",
                    required=False,
                    help="host ip, localhost, 0.0.0.0")
# The server is started with websockets.serve, so this is a websocket port
# (the old help text wrongly called it a grpc port).
parser.add_argument("--port",
                    type=int,
                    default=10095,
                    required=False,
                    help="websocket server port")
parser.add_argument("--asr_model",
                    type=str,
                    default="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                    help="model from modelscope")
parser.add_argument("--asr_model_online",
                    type=str,
                    default="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
                    help="model from modelscope")
parser.add_argument("--vad_model",
                    type=str,
                    default="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                    help="model from modelscope")
parser.add_argument("--punc_model",
                    type=str,
                    default="damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
                    help="model from modelscope")
parser.add_argument("--ngpu",
                    type=int,
                    default=1,
                    help="0 for cpu, 1 for gpu")
parser.add_argument("--ncpu",
                    type=int,
                    default=4,
                    help="cpu cores")
parser.add_argument("--certfile",
                    type=str,
                    default="./ssl_key/server.crt",
                    required=False,
                    help="certfile for ssl")
parser.add_argument("--keyfile",
                    type=str,
                    default="./ssl_key/server.key",
                    required=False,
                    help="keyfile for ssl")
# Resolve the options above against the process command line (sys.argv[1:]).
args = parser.parse_args(args=None)
# Live client connections; ws_serve registers each new connection here.
websocket_users: set = set()
@@ -35,8 +83,6 @@
    task=Tasks.voice_activity_detection,
    model=args.vad_model,
    model_revision=None,
    output_dir=None,
    batch_size=1,
    mode='online',
    ngpu=args.ngpu,
    ncpu=args.ncpu,
@@ -58,15 +104,36 @@
    model=args.asr_model_online,
    ngpu=args.ngpu,
    ncpu=args.ncpu,
    model_revision='v1.0.4')
    model_revision='v1.0.4',
    update_model='v1.0.4',
    mode='paraformer_streaming')
# Startup banner, printed after the module-level pipelines above are built.
# NOTE(review): these two prints look like an old/new pair from the diff;
# presumably only the second is meant to remain — confirm.
print("model loaded")
print("model loaded! only support one client at the same time now!!!!")
async def ws_reset(websocket):
    """Reset a connection's streaming state and close it.

    Replaces the online-ASR and VAD parameter dicts stored on ``websocket``
    with fresh ones whose ``is_final`` flag is True, then awaits
    ``websocket.close()``. This does not remove the connection from
    ``websocket_users``; callers do that themselves.

    Args:
        websocket: the server-side connection object passed to ``ws_serve``;
            the ``param_dict_*`` attributes live directly on it.
    """
    print("ws reset now, total num is ", len(websocket_users), flush=True)
    # Fresh caches marked final; presumably is_final tells the streaming
    # pipelines that this stream has ended (TODO confirm with the pipelines).
    websocket.param_dict_asr_online = {"cache": dict(), "is_final": True}
    websocket.param_dict_vad = {"in_cache": dict(), "is_final": True}
    await websocket.close()
async def clear_websocket():
    """Reset and close every tracked connection, then empty ``websocket_users``.

    Called when a new client connects so that only one client is served at a
    time.
    """
    # Iterate over a snapshot. Each ws_reset() awaits websocket.close(), and
    # while this coroutine is suspended other connection handlers can add to
    # or remove from websocket_users. Mutating a set during iteration raises
    # "RuntimeError: Set changed size during iteration".
    for websocket in list(websocket_users):
        await ws_reset(websocket)
    websocket_users.clear()
async def ws_serve(websocket, path):
    frames = []
    frames_asr = []
    frames_asr_online = []
    global websocket_users
    await clear_websocket()
    websocket_users.add(websocket)
    websocket.param_dict_asr = {}
    websocket.param_dict_asr_online = {"cache": dict()}
@@ -74,7 +141,7 @@
    websocket.param_dict_punc = {'cache': list()}
    websocket.vad_pre_idx = 0
    speech_start = False
    speech_end_i = False
    speech_end_i = -1
    websocket.wav_name = "microphone"
    websocket.mode = "2pass"
    print("new user connected", flush=True)
@@ -103,7 +170,7 @@
        
                    # asr online
                    frames_asr_online.append(message)
                    websocket.param_dict_asr_online["is_final"] = speech_end_i
                    websocket.param_dict_asr_online["is_final"] = speech_end_i != -1
                    if len(frames_asr_online) % websocket.chunk_interval == 0 or websocket.param_dict_asr_online["is_final"]:
                        if websocket.mode == "2pass" or websocket.mode == "online":
                            audio_in = b"".join(frames_asr_online)
@@ -113,14 +180,14 @@
                        frames_asr.append(message)
                    # vad online
                    speech_start_i, speech_end_i = await async_vad(websocket, message)
                    if speech_start_i:
                    if speech_start_i != -1:
                        speech_start = True
                        beg_bias = (websocket.vad_pre_idx-speech_start_i)//duration_ms
                        frames_pre = frames[-beg_bias:]
                        frames_asr = []
                        frames_asr.extend(frames_pre)
                # asr punc offline
                if speech_end_i or not websocket.is_speaking:
                if speech_end_i != -1 or not websocket.is_speaking:
                    # print("vad end point")
                    if websocket.mode == "2pass" or websocket.mode == "offline":
                        audio_in = b"".join(frames_asr)
@@ -138,7 +205,8 @@
     
    except websockets.ConnectionClosed:
        print("ConnectionClosed...", websocket_users)
        print("ConnectionClosed...", websocket_users,flush=True)
        await ws_reset(websocket)
        websocket_users.remove(websocket)
    except websockets.InvalidState:
        print("InvalidState...")
@@ -150,15 +218,15 @@
    segments_result = inference_pipeline_vad(audio_in=audio_in, param_dict=websocket.param_dict_vad)
    speech_start = False
    speech_end = False
    speech_start = -1
    speech_end = -1
    
    if len(segments_result) == 0 or len(segments_result["text"]) > 1:
        return speech_start, speech_end
    if segments_result["text"][0][0] != -1:
        speech_start = segments_result["text"][0][0]
    if segments_result["text"][0][1] != -1:
        speech_end = True
        speech_end = segments_result["text"][0][1]
    return speech_start, speech_end
@@ -207,4 +275,4 @@
else:
    start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None)
# Bind the server (start_server is assigned in the if/else above), then hand
# control to the event loop indefinitely.
asyncio.get_event_loop().run_until_complete(start_server)
asyncio.get_event_loop().run_forever()
# NOTE(review): run_forever() only returns after loop.stop(), so this second
# call is normally never reached; it looks like a leftover old/new diff pair —
# confirm and drop one.
asyncio.get_event_loop().run_forever()