From 8873c2c21a23a67e861fb2ae1672763ac709e7f6 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 23 三月 2023 13:33:55 +0800
Subject: [PATCH] websocket

---
 funasr/runtime/python/websocket/ASR_server.py |   42 ++++++++++++++++++++++++++++++++++++------
 1 files changed, 36 insertions(+), 6 deletions(-)

diff --git a/funasr/runtime/python/websocket/ASR_server.py b/funasr/runtime/python/websocket/ASR_server.py
index 0796a79..0af5208 100644
--- a/funasr/runtime/python/websocket/ASR_server.py
+++ b/funasr/runtime/python/websocket/ASR_server.py
@@ -6,19 +6,49 @@
 
 logger = get_logger(log_level=logging.CRITICAL)
 logger.setLevel(logging.CRITICAL)
+
 import asyncio
-import websockets  #鍖哄埆瀹㈡埛绔繖閲屾槸 websockets搴�
+import websockets
 import time
 from queue import Queue
 import threading
+import argparse
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--host",
+                    type=str,
+                    default="0.0.0.0",
+                    required=False,
+                    help="host ip, localhost, 0.0.0.0")
+parser.add_argument("--port",
+                    type=int,
+                    default=10095,
+                    required=False,
+                    help="grpc server port")
+parser.add_argument("--asr_model",
+                    type=str,
+                    default="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
+                    help="model from modelscope")
+parser.add_argument("--vad_model",
+                    type=str,
+                    default="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+                    help="model from modelscope")
+
+parser.add_argument("--punc_model",
+                    type=str,
+                    default="",
+                    help="model from modelscope")
+
+args = parser.parse_args()
 
 print("model loading")
 voices = Queue()
 speek = Queue()
+
 # 鍒涘缓涓�涓猇AD瀵硅薄
 vad_pipline = pipeline(
     task=Tasks.voice_activity_detection,
-    model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+    model=args.vad_model,
     model_revision="v1.2.0",
     output_dir=None,
     batch_size=1,
@@ -26,17 +56,17 @@
   
 # 鍒涘缓涓�涓狝SR瀵硅薄
 param_dict = dict()
-param_dict["hotword"] = "灏忎簲 灏忎簲鏈�"  # 璁剧疆鐑瘝锛岀敤绌烘牸闅斿紑
+# param_dict["hotword"] = "灏忎簲 灏忎簲鏈�"  # 璁剧疆鐑瘝锛岀敤绌烘牸闅斿紑
 inference_pipeline2 = pipeline(
     task=Tasks.auto_speech_recognition,
-    model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404",
+    model=args.asr_model,
     param_dict=param_dict,
 )
 print("model loaded")
 
 
 
-async def echo(websocket, path):
+async def ws_serve(websocket, path):
     global voices
     try:
         async for message in websocket:
@@ -47,7 +77,7 @@
     except Exception as e:
         print('Exception occurred:', e)
 
-start_server = websockets.serve(echo, "localhost", 8899, subprotocols=["binary"],ping_interval=None)
+start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None)
 
 
 def vad(data):  # 鎺ㄧ悊

--
Gitblit v1.9.1