From 1ce704d8c09bd4d4c7e5ab087f951f31fad9fca6 Mon Sep 17 00:00:00 2001
From: nichongjia-2007 <nichongjia@gmail.com>
Date: 星期五, 07 七月 2023 15:47:19 +0800
Subject: [PATCH] Merge branch 'main' of https://github.com/alibaba-damo-academy/FunASR

---
 funasr/runtime/python/websocket/funasr_wss_server.py |   58 +++++++++++++++++++++++++++++++++++++++++++++++++++-------
 1 files changed, 51 insertions(+), 7 deletions(-)

diff --git a/funasr/runtime/python/websocket/wss_srv_asr.py b/funasr/runtime/python/websocket/funasr_wss_server.py
similarity index 80%
rename from funasr/runtime/python/websocket/wss_srv_asr.py
rename to funasr/runtime/python/websocket/funasr_wss_server.py
index 09f2305..4929090 100644
--- a/funasr/runtime/python/websocket/wss_srv_asr.py
+++ b/funasr/runtime/python/websocket/funasr_wss_server.py
@@ -5,17 +5,64 @@
 import logging
 import tracemalloc
 import numpy as np
+import argparse
 import ssl
-from parse_args import args
 from modelscope.pipelines import pipeline
 from modelscope.utils.constant import Tasks
 from modelscope.utils.logger import get_logger
-from funasr.runtime.python.onnxruntime.funasr_onnx.utils.frontend import load_bytes
 
 tracemalloc.start()
 
 logger = get_logger(log_level=logging.CRITICAL)
 logger.setLevel(logging.CRITICAL)
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--host",
+                    type=str,
+                    default="0.0.0.0",
+                    required=False,
+                    help="host ip, localhost, 0.0.0.0")
+parser.add_argument("--port",
+                    type=int,
+                    default=10095,
+                    required=False,
+                    help="grpc server port")
+parser.add_argument("--asr_model",
+                    type=str,
+                    default="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
+                    help="model from modelscope")
+parser.add_argument("--asr_model_online",
+                    type=str,
+                    default="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
+                    help="model from modelscope")
+parser.add_argument("--vad_model",
+                    type=str,
+                    default="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+                    help="model from modelscope")
+parser.add_argument("--punc_model",
+                    type=str,
+                    default="damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
+                    help="model from modelscope")
+parser.add_argument("--ngpu",
+                    type=int,
+                    default=1,
+                    help="0 for cpu, 1 for gpu")
+parser.add_argument("--ncpu",
+                    type=int,
+                    default=4,
+                    help="cpu cores")
+parser.add_argument("--certfile",
+                    type=str,
+                    default="./ssl_key/server.crt",
+                    required=False,
+                    help="certfile for ssl")
+
+parser.add_argument("--keyfile",
+                    type=str,
+                    default="./ssl_key/server.key",
+                    required=False,
+                    help="keyfile for ssl")
+args = parser.parse_args()
 
 
 websocket_users = set()
@@ -185,8 +232,6 @@
 async def async_asr(websocket, audio_in):
             if len(audio_in) > 0:
                 # print(len(audio_in))
-                audio_in = load_bytes(audio_in)
-                
                 rec_result = inference_pipeline_asr(audio_in=audio_in,
                                                     param_dict=websocket.param_dict_asr)
                 # print(rec_result)
@@ -195,13 +240,12 @@
                                                          param_dict=websocket.param_dict_punc)
                     # print("offline", rec_result)
                 if 'text' in rec_result:
-                    message = json.dumps({"mode": "2pass-offline", "text": rec_result["text"], "wav_name": websocket.wav_name})
+                    message = json.dumps({"mode": websocket.mode, "text": rec_result["text"], "wav_name": websocket.wav_name})
                     await websocket.send(message)
 
 
 async def async_asr_online(websocket, audio_in):
     if len(audio_in) > 0:
-        audio_in = load_bytes(audio_in)
         # print(websocket.param_dict_asr_online.get("is_final", False))
         rec_result = inference_pipeline_asr_online(audio_in=audio_in,
                                                    param_dict=websocket.param_dict_asr_online)
@@ -212,7 +256,7 @@
         if "text" in rec_result:
             if rec_result["text"] != "sil" and rec_result["text"] != "waiting_for_more_voice":
                 # print("online", rec_result)
-                message = json.dumps({"mode": "2pass-online", "text": rec_result["text"], "wav_name": websocket.wav_name})
+                message = json.dumps({"mode": websocket.mode, "text": rec_result["text"], "wav_name": websocket.wav_name})
                 await websocket.send(message)
 
 if len(args.certfile)>0:

--
Gitblit v1.9.1