From b006d268cd52a6febdae7344ecb183b510391ae5 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 23 三月 2023 13:46:49 +0800
Subject: [PATCH] websocket

---
 funasr/runtime/python/websocket/ASR_server.py |   46 ++++++++++++++++++++++++++++++++++++++--------
 1 file changed, 38 insertions(+), 8 deletions(-)

diff --git a/funasr/runtime/python/websocket/ASR_server.py b/funasr/runtime/python/websocket/ASR_server.py
index 3627d3a..0af5208 100644
--- a/funasr/runtime/python/websocket/ASR_server.py
+++ b/funasr/runtime/python/websocket/ASR_server.py
@@ -6,19 +6,49 @@
 
 logger = get_logger(log_level=logging.CRITICAL)
 logger.setLevel(logging.CRITICAL)
+
 import asyncio
-import websockets  #鍖哄埆瀹㈡埛绔繖閲屾槸 websockets搴�
+import websockets
 import time
 from queue import Queue
-import  threading
+import threading
+import argparse
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--host",
+                    type=str,
+                    default="0.0.0.0",
+                    required=False,
+                    help="host ip, localhost, 0.0.0.0")
+parser.add_argument("--port",
+                    type=int,
+                    default=10095,
+                    required=False,
+                    help="grpc server port")
+parser.add_argument("--asr_model",
+                    type=str,
+                    default="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
+                    help="model from modelscope")
+parser.add_argument("--vad_model",
+                    type=str,
+                    default="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+                    help="model from modelscope")
+
+parser.add_argument("--punc_model",
+                    type=str,
+                    default="",
+                    help="model from modelscope")
+
+args = parser.parse_args()
 
 print("model loading")
 voices = Queue()
 speek = Queue()
+
 # 鍒涘缓涓�涓猇AD瀵硅薄
 vad_pipline = pipeline(
     task=Tasks.voice_activity_detection,
-    model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+    model=args.vad_model,
     model_revision="v1.2.0",
     output_dir=None,
     batch_size=1,
@@ -26,17 +56,17 @@
   
 # 鍒涘缓涓�涓狝SR瀵硅薄
 param_dict = dict()
-param_dict["hotword"] = "灏忎簲 灏忎簲鏈�"  # 璁剧疆鐑瘝锛岀敤绌烘牸闅斿紑
+# param_dict["hotword"] = "灏忎簲 灏忎簲鏈�"  # 璁剧疆鐑瘝锛岀敤绌烘牸闅斿紑
 inference_pipeline2 = pipeline(
     task=Tasks.auto_speech_recognition,
-    model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404",
+    model=args.asr_model,
     param_dict=param_dict,
 )
 print("model loaded")
 
 
 
-async def echo(websocket, path):
+async def ws_serve(websocket, path):
     global voices
     try:
         async for message in websocket:
@@ -47,7 +77,7 @@
     except Exception as e:
         print('Exception occurred:', e)
 
-start_server = websockets.serve(echo, "localhost", 8899, subprotocols=["binary"],ping_interval=None)
+start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None)
 
 
 def vad(data):  # 鎺ㄧ悊
@@ -95,7 +125,7 @@
                 frames.append(data)
                 RECORD_NUM += 1    
             
-            if  vad(data):
+            if vad(data):
                 if not speech_detected:
                     print("妫�娴嬪埌浜哄0...")
                     speech_detected = True  # 鏍囪涓烘娴嬪埌璇煶

--
Gitblit v1.9.1