From 420deca6146c433f9521b43a57d1f10200678db2 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 25 五月 2023 20:08:57 +0800
Subject: [PATCH] websocket docs

---
 /dev/null                                 |  147 ------------------------------------
 funasr/runtime/python/websocket/README.md |   55 +++----------
 2 files changed, 12 insertions(+), 190 deletions(-)

diff --git a/funasr/runtime/python/websocket/README.md b/funasr/runtime/python/websocket/README.md
index f489bac..e19c00c 100644
--- a/funasr/runtime/python/websocket/README.md
+++ b/funasr/runtime/python/websocket/README.md
@@ -21,43 +21,10 @@
 ```
 
 ### Start server
-#### ASR offline server
-##### API-reference
-```shell
-python ws_server_offline.py \
---port [port id] \
---asr_model [asr model_name] \
---punc_model [punc model_name] \
---ngpu [0 or 1] \
---ncpu [1 or 4] \
---certfile [path of certfile for ssl] \
---keyfile [path of keyfile for ssl] 
-```
-##### Usage examples
-```shell
-python ws_server_offline.py --port 10095 --asr_model "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
-```
 
-#### ASR streaming server
 ##### API-reference
 ```shell
-python ws_server_online.py \
---port [port id] \
---asr_model_online [asr model_name] \
---ngpu [0 or 1] \
---ncpu [1 or 4] \
---certfile [path of certfile for ssl] \
---keyfile [path of keyfile for ssl] 
-```
-##### Usage examples
-```shell
-python ws_server_online.py --port 10095 --asr_model_online "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online"
-```
-
-#### ASR offline/online 2pass server
-##### API-reference
-```shell
-python ws_server_2pass.py \
+python wss_srv_asr.py \
 --port [port id] \
 --asr_model [asr model_name] \
 --asr_model_online [asr model_name] \
@@ -69,7 +36,7 @@
 ```
 ##### Usage examples
 ```shell
-python ws_server_2pass.py --port 10095 --asr_model "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"  --asr_model_online "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online"
+python wss_srv_asr.py --port 10095 --asr_model "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"  --asr_model_online "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online"
 ```
 
 ## For the client
@@ -84,7 +51,7 @@
 ### Start client
 #### API-reference
 ```shell
-python ws_client.py \
+python wss_client_asr.py \
 --host [ip_address] \
 --port [port id] \
 --chunk_size ["5,10,5"=600ms, "8,8,4"=480ms] \
@@ -93,43 +60,45 @@
 --audio_in [if set, loadding from wav.scp, else recording from mircrophone] \
 --output_dir [if set, write the results to output_dir] \
 --send_without_sleep [only set for offline] \
---ssl [1 for wss connect, 0 for ws, default is 1]
+--ssl [1 for wss connect, 0 for ws, default is 1] \
+--mode [`online` for streaming asr, `offline` for non-streaming, `2pass` for unifying streaming and non-streaming asr]
 ```
+
 #### Usage examples
 ##### ASR offline client
 Recording from mircrophone
 ```shell
 # --chunk_interval, "10": 600/10=60ms, "5"=600/5=120ms, "20": 600/12=30ms
-python ws_client.py --host "0.0.0.0" --port 10095 --chunk_interval 10 --words_max_print 100
+python wss_client_asr.py --host "0.0.0.0" --port 10095 --mode offline --chunk_interval 10 --words_max_print 100
 ```
 Loadding from wav.scp(kaldi style)
 ```shell
 # --chunk_interval, "10": 600/10=60ms, "5"=600/5=120ms, "20": 600/12=30ms
-python ws_client.py --host "0.0.0.0" --port 10095 --chunk_interval 10 --words_max_print 100 --audio_in "./data/wav.scp" --send_without_sleep --output_dir "./results"
+python wss_client_asr.py --host "0.0.0.0" --port 10095 --mode offline --chunk_interval 10 --words_max_print 100 --audio_in "./data/wav.scp" --send_without_sleep --output_dir "./results"
 ```
 
 ##### ASR streaming client
 Recording from mircrophone
 ```shell
 # --chunk_size, "5,10,5"=600ms, "8,8,4"=480ms
-python ws_client.py --host "0.0.0.0" --port 10095 --chunk_size "5,10,5" --words_max_print 100
+python wss_client_asr.py --host "0.0.0.0" --port 10095 --mode online --chunk_size "5,10,5" --words_max_print 100
 ```
 Loadding from wav.scp(kaldi style)
 ```shell
 # --chunk_size, "5,10,5"=600ms, "8,8,4"=480ms
-python ws_client.py --host "0.0.0.0" --port 10095 --chunk_size "5,10,5" --audio_in "./data/wav.scp" --output_dir "./results"
+python wss_client_asr.py --host "0.0.0.0" --port 10095 --mode online --chunk_size "5,10,5" --audio_in "./data/wav.scp" --output_dir "./results"
 ```
 
 ##### ASR offline/online 2pass client
 Recording from mircrophone
 ```shell
 # --chunk_size, "5,10,5"=600ms, "8,8,4"=480ms
-python ws_client.py --host "0.0.0.0" --port 10095 --chunk_size "8,8,4"
+python wss_client_asr.py --host "0.0.0.0" --port 10095 --mode 2pass --chunk_size "8,8,4"
 ```
 Loadding from wav.scp(kaldi style)
 ```shell
 # --chunk_size, "5,10,5"=600ms, "8,8,4"=480ms
-python ws_client.py --host "0.0.0.0" --port 10095 --chunk_size "8,8,4" --audio_in "./data/wav.scp" --output_dir "./results"
+python wss_client_asr.py --host "0.0.0.0" --port 10095 --mode 2pass --chunk_size "8,8,4" --audio_in "./data/wav.scp" --output_dir "./results"
 ```
 ## Acknowledge
 1. This project is maintained by [FunASR community](https://github.com/alibaba-damo-academy/FunASR).
diff --git a/funasr/runtime/python/websocket/ws_client.py b/funasr/runtime/python/websocket/ws_client.py
deleted file mode 100644
index f7dfcaf..0000000
--- a/funasr/runtime/python/websocket/ws_client.py
+++ /dev/null
@@ -1,292 +0,0 @@
-# -*- encoding: utf-8 -*-
-import os
-import time
-import websockets,ssl
-import asyncio
-# import threading
-import argparse
-import json
-import traceback
-from multiprocessing import Process
-from funasr.fileio.datadir_writer import DatadirWriter
-
-import logging
-
-logging.basicConfig(level=logging.ERROR)
-
-parser = argparse.ArgumentParser()
-parser.add_argument("--host",
-                    type=str,
-                    default="localhost",
-                    required=False,
-                    help="host ip, localhost, 0.0.0.0")
-parser.add_argument("--port",
-                    type=int,
-                    default=10095,
-                    required=False,
-                    help="grpc server port")
-parser.add_argument("--chunk_size",
-                    type=str,
-                    default="5, 10, 5",
-                    help="chunk")
-parser.add_argument("--chunk_interval",
-                    type=int,
-                    default=10,
-                    help="chunk")
-parser.add_argument("--audio_in",
-                    type=str,
-                    default=None,
-                    help="audio_in")
-parser.add_argument("--send_without_sleep",
-                    action="store_true",
-                    default=False,
-                    help="if audio_in is set, send_without_sleep")
-parser.add_argument("--test_thread_num",
-                    type=int,
-                    default=1,
-                    help="test_thread_num")
-parser.add_argument("--words_max_print",
-                    type=int,
-                    default=10000,
-                    help="chunk")
-parser.add_argument("--output_dir",
-                    type=str,
-                    default=None,
-                    help="output_dir")
-                    
-parser.add_argument("--ssl",
-                    type=int,
-                    default=1,
-                    help="1 for ssl connect, 0 for no ssl")
-
-args = parser.parse_args()
-args.chunk_size = [int(x) for x in args.chunk_size.split(",")]
-print(args)
-# voices = asyncio.Queue()
-from queue import Queue
-voices = Queue()
-
-ibest_writer = None
-if args.output_dir is not None:
-    writer = DatadirWriter(args.output_dir)
-    ibest_writer = writer[f"1best_recog"]
-
-async def record_microphone():
-    is_finished = False
-    import pyaudio
-    #print("2")
-    global voices 
-    FORMAT = pyaudio.paInt16
-    CHANNELS = 1
-    RATE = 16000
-    chunk_size = 60*args.chunk_size[1]/args.chunk_interval
-    CHUNK = int(RATE / 1000 * chunk_size)
-
-    p = pyaudio.PyAudio()
-
-    stream = p.open(format=FORMAT,
-                    channels=CHANNELS,
-                    rate=RATE,
-                    input=True,
-                    frames_per_buffer=CHUNK)
-
-    message = json.dumps({"chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval, "wav_name": "microphone", "is_speaking": True})
-    voices.put(message)
-    while True:
-
-        data = stream.read(CHUNK)
-        message = data  
-        
-        voices.put(message)
-
-        await asyncio.sleep(0.005)
-
-async def record_from_scp(chunk_begin,chunk_size):
-    import wave
-    global voices
-    is_finished = False
-    if args.audio_in.endswith(".scp"):
-        f_scp = open(args.audio_in)
-        wavs = f_scp.readlines()
-    else:
-        wavs = [args.audio_in]
-    if chunk_size>0:
-        wavs=wavs[chunk_begin:chunk_begin+chunk_size]
-    for wav in wavs:
-        wav_splits = wav.strip().split()
-        wav_name = wav_splits[0] if len(wav_splits) > 1 else "demo"
-        wav_path = wav_splits[1] if len(wav_splits) > 1 else wav_splits[0]
-        
-        # bytes_f = open(wav_path, "rb")
-        # bytes_data = bytes_f.read()
-        with wave.open(wav_path, "rb") as wav_file:
-            params = wav_file.getparams()
-            # header_length = wav_file.getheaders()[0][1]
-            # wav_file.setpos(header_length)
-            frames = wav_file.readframes(wav_file.getnframes())
-
-        audio_bytes = bytes(frames)
-        # stride = int(args.chunk_size/1000*16000*2)
-        stride = int(60*args.chunk_size[1]/args.chunk_interval/1000*16000*2)
-        chunk_num = (len(audio_bytes)-1)//stride + 1
-        # print(stride)
-        
-        # send first time
-        message = json.dumps({"chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval, "wav_name": wav_name,"is_speaking": True})
-        voices.put(message)
-        is_speaking = True
-        for i in range(chunk_num):
-
-            beg = i*stride
-            data = audio_bytes[beg:beg+stride]
-            message = data  
-            voices.put(message)
-            if i == chunk_num-1:
-                is_speaking = False
-                message = json.dumps({"is_speaking": is_speaking})
-                voices.put(message)
-            # print("data_chunk: ", len(data_chunk))
-            # print(voices.qsize())
-            sleep_duration = 0.001 if args.send_without_sleep else 60*args.chunk_size[1]/args.chunk_interval/1000
-            await asyncio.sleep(sleep_duration)
-
-
-async def ws_send():
-    global voices
-    global websocket
-    print("started to sending data!")
-    while True:
-        while not voices.empty():
-            data = voices.get()
-            voices.task_done()
-            try:
-                await websocket.send(data)
-            except Exception as e:
-                print('Exception occurred:', e)
-                traceback.print_exc()
-                exit(0)
-            await asyncio.sleep(0.005)
-        await asyncio.sleep(0.005)
-
-
-
-async def message(id):
-    global websocket
-    text_print = ""
-    text_print_2pass_online = ""
-    text_print_2pass_offline = ""
-    while True:
-        try:
-            meg = await websocket.recv()
-            meg = json.loads(meg)
-            wav_name = meg.get("wav_name", "demo")
-            # print(wav_name)
-            text = meg["text"]
-            if ibest_writer is not None:
-                ibest_writer["text"][wav_name] = text
-            
-            if meg["mode"] == "online":
-                text_print += "{}".format(text)
-                text_print = text_print[-args.words_max_print:]
-                os.system('clear')
-                print("\rpid"+str(id)+": "+text_print)
-            elif meg["mode"] == "online":
-                text_print += "{}".format(text)
-                text_print = text_print[-args.words_max_print:]
-                os.system('clear')
-                print("\rpid"+str(id)+": "+text_print)
-            else:
-                if meg["mode"] == "2pass-online":
-                    text_print_2pass_online += "{}".format(text)
-                    text_print = text_print_2pass_offline + text_print_2pass_online
-                else:
-                    text_print_2pass_online = ""
-                    text_print = text_print_2pass_offline + "{}".format(text)
-                    text_print_2pass_offline += "{}".format(text)
-                text_print = text_print[-args.words_max_print:]
-                os.system('clear')
-                print("\rpid" + str(id) + ": " + text_print)
-
-        except Exception as e:
-            print("Exception:", e)
-            traceback.print_exc()
-            exit(0)
-
-async def print_messge():
-    global websocket
-    while True:
-        try:
-            meg = await websocket.recv()
-            meg = json.loads(meg)
-            print(meg)
-        except Exception as e:
-            print("Exception:", e)
-            traceback.print_exc()
-            exit(0)
-
-async def ws_client(id,chunk_begin,chunk_size):
-    global websocket
-    if  args.ssl==1:
-       ssl_context = ssl.SSLContext()
-       ssl_context.check_hostname = False
-       ssl_context.verify_mode = ssl.CERT_NONE
-       uri = "wss://{}:{}".format(args.host, args.port)
-    else:
-       uri = "ws://{}:{}".format(args.host, args.port)
-       ssl_context=None
-    print("connect to",uri)
-    async for websocket in websockets.connect(uri, subprotocols=["binary"], ping_interval=None,ssl=ssl_context):
-        if args.audio_in is not None:
-            task = asyncio.create_task(record_from_scp(chunk_begin,chunk_size))
-        else:
-            task = asyncio.create_task(record_microphone())
-        task2 = asyncio.create_task(ws_send())
-        task3 = asyncio.create_task(message(id))
-        await asyncio.gather(task, task2, task3)
-
-def one_thread(id,chunk_begin,chunk_size):
-   asyncio.get_event_loop().run_until_complete(ws_client(id,chunk_begin,chunk_size))
-   asyncio.get_event_loop().run_forever()
-
-
-if __name__ == '__main__':
-   # for microphone 
-   if  args.audio_in is  None:
-     p = Process(target=one_thread,args=(0, 0, 0))
-     p.start()
-     p.join()
-     print('end')
-   else:
-     # calculate the number of wavs for each preocess
-     if args.audio_in.endswith(".scp"):
-         f_scp = open(args.audio_in)
-         wavs = f_scp.readlines()
-     else:
-         wavs = [args.audio_in]
-     total_len=len(wavs)
-     if total_len>=args.test_thread_num:
-          chunk_size=int((total_len)/args.test_thread_num)
-          remain_wavs=total_len-chunk_size*args.test_thread_num
-     else:
-          chunk_size=1
-          remain_wavs=0
-
-     process_list = []
-     chunk_begin=0
-     for i in range(args.test_thread_num):
-         now_chunk_size= chunk_size
-         if remain_wavs>0:
-             now_chunk_size=chunk_size+1
-             remain_wavs=remain_wavs-1
-         # process i handle wavs at chunk_begin and size of now_chunk_size
-         p = Process(target=one_thread,args=(i,chunk_begin,now_chunk_size))
-         chunk_begin=chunk_begin+now_chunk_size
-         p.start()
-         process_list.append(p)
-
-     for i in process_list:
-         p.join()
-
-     print('end')
-
-
diff --git a/funasr/runtime/python/websocket/ws_server_2pass.py b/funasr/runtime/python/websocket/ws_server_2pass.py
deleted file mode 100644
index df13ad9..0000000
--- a/funasr/runtime/python/websocket/ws_server_2pass.py
+++ /dev/null
@@ -1,206 +0,0 @@
-import asyncio
-import json
-import websockets
-import time
-import logging
-import tracemalloc
-import numpy as np
-import ssl
-from parse_args import args
-from modelscope.pipelines import pipeline
-from modelscope.utils.constant import Tasks
-from modelscope.utils.logger import get_logger
-from funasr.runtime.python.onnxruntime.funasr_onnx.utils.frontend import load_bytes
-
-tracemalloc.start()
-
-logger = get_logger(log_level=logging.CRITICAL)
-logger.setLevel(logging.CRITICAL)
-
-
-websocket_users = set()
-
-print("model loading")
-# asr
-inference_pipeline_asr = pipeline(
-    task=Tasks.auto_speech_recognition,
-    model=args.asr_model,
-    ngpu=args.ngpu,
-    ncpu=args.ncpu,
-    model_revision=None)
-
-
-# vad
-inference_pipeline_vad = pipeline(
-    task=Tasks.voice_activity_detection,
-    model=args.vad_model,
-    model_revision=None,
-    output_dir=None,
-    batch_size=1,
-    mode='online',
-    ngpu=args.ngpu,
-    ncpu=args.ncpu,
-)
-
-if args.punc_model != "":
-    inference_pipeline_punc = pipeline(
-        task=Tasks.punctuation,
-        model=args.punc_model,
-        model_revision="v1.0.2",
-        ngpu=args.ngpu,
-        ncpu=args.ncpu,
-    )
-else:
-    inference_pipeline_punc = None
-
-inference_pipeline_asr_online = pipeline(
-    task=Tasks.auto_speech_recognition,
-    model=args.asr_model_online,
-    ngpu=args.ngpu,
-    ncpu=args.ncpu,
-    model_revision='v1.0.4')
-
-print("model loaded")
-
-async def ws_serve(websocket, path):
-    frames = []
-    frames_asr = []
-    frames_asr_online = []
-    global websocket_users
-    websocket_users.add(websocket)
-    websocket.param_dict_asr = {}
-    websocket.param_dict_asr_online = {"cache": dict()}
-    websocket.param_dict_vad = {'in_cache': dict(), "is_final": False}
-    websocket.param_dict_punc = {'cache': list()}
-    websocket.vad_pre_idx = 0
-    speech_start = False
-    speech_end_i = False
-    websocket.wav_name = "microphone"
-    print("new user connected", flush=True)
-
-    try:
-        async for message in websocket:
-            if isinstance(message, str):
-                messagejson = json.loads(message)
-        
-                if "is_speaking" in messagejson:
-                    websocket.is_speaking = messagejson["is_speaking"]
-                    websocket.param_dict_asr_online["is_final"] = not websocket.is_speaking
-                if "chunk_interval" in messagejson:
-                    websocket.chunk_interval = messagejson["chunk_interval"]
-                if "wav_name" in messagejson:
-                    websocket.wav_name = messagejson.get("wav_name")
-                if "chunk_size" in messagejson:
-                    websocket.param_dict_asr_online["chunk_size"] = messagejson["chunk_size"]
-            if len(frames_asr_online) > 0 or len(frames_asr) > 0 or not isinstance(message, str):
-                if not isinstance(message, str):
-                    frames.append(message)
-                    duration_ms = len(message)//32
-                    websocket.vad_pre_idx += duration_ms
-        
-                    # asr online
-                    frames_asr_online.append(message)
-                    websocket.param_dict_asr_online["is_final"] = speech_end_i
-                    if len(frames_asr_online) % websocket.chunk_interval == 0 or websocket.param_dict_asr_online["is_final"]:
-                        
-                        audio_in = b"".join(frames_asr_online)
-                        await async_asr_online(websocket, audio_in)
-                        frames_asr_online = []
-                    if speech_start:
-                        frames_asr.append(message)
-                    # vad online
-                    speech_start_i, speech_end_i = await async_vad(websocket, message)
-                    if speech_start_i:
-                        speech_start = True
-                        beg_bias = (websocket.vad_pre_idx-speech_start_i)//duration_ms
-                        frames_pre = frames[-beg_bias:]
-                        frames_asr = []
-                        frames_asr.extend(frames_pre)
-                # asr punc offline
-                if speech_end_i or not websocket.is_speaking:
-                    # print("vad end point")
-                    audio_in = b"".join(frames_asr)
-                    await async_asr(websocket, audio_in)
-                    frames_asr = []
-                    speech_start = False
-                    # frames_asr_online = []
-                    # websocket.param_dict_asr_online = {"cache": dict()}
-                    if not websocket.is_speaking:
-                        websocket.vad_pre_idx = 0
-                        frames = []
-                        websocket.param_dict_vad = {'in_cache': dict()}
-                    else:
-                        frames = frames[-20:]
-
-     
-    except websockets.ConnectionClosed:
-        print("ConnectionClosed...", websocket_users)
-        websocket_users.remove(websocket)
-    except websockets.InvalidState:
-        print("InvalidState...")
-    except Exception as e:
-        print("Exception:", e)
-
-
-async def async_vad(websocket, audio_in):
-
-    segments_result = inference_pipeline_vad(audio_in=audio_in, param_dict=websocket.param_dict_vad)
-
-    speech_start = False
-    speech_end = False
-    
-    if len(segments_result) == 0 or len(segments_result["text"]) > 1:
-        return speech_start, speech_end
-    if segments_result["text"][0][0] != -1:
-        speech_start = segments_result["text"][0][0]
-    if segments_result["text"][0][1] != -1:
-        speech_end = True
-    return speech_start, speech_end
-
-
-async def async_asr(websocket, audio_in):
-            if len(audio_in) > 0:
-                # print(len(audio_in))
-                audio_in = load_bytes(audio_in)
-                
-                rec_result = inference_pipeline_asr(audio_in=audio_in,
-                                                    param_dict=websocket.param_dict_asr)
-                # print(rec_result)
-                if inference_pipeline_punc is not None and 'text' in rec_result and len(rec_result["text"])>0:
-                    rec_result = inference_pipeline_punc(text_in=rec_result['text'],
-                                                         param_dict=websocket.param_dict_punc)
-                    # print("offline", rec_result)
-                if 'text' in rec_result:
-                    message = json.dumps({"mode": "2pass-offline", "text": rec_result["text"], "wav_name": websocket.wav_name})
-                    await websocket.send(message)
-
-
-async def async_asr_online(websocket, audio_in):
-    if len(audio_in) > 0:
-        audio_in = load_bytes(audio_in)
-        # print(websocket.param_dict_asr_online.get("is_final", False))
-        rec_result = inference_pipeline_asr_online(audio_in=audio_in,
-                                                   param_dict=websocket.param_dict_asr_online)
-        # print(rec_result)
-        if websocket.param_dict_asr_online.get("is_final", False):
-            return
-            #     websocket.param_dict_asr_online["cache"] = dict()
-        if "text" in rec_result:
-            if rec_result["text"] != "sil" and rec_result["text"] != "waiting_for_more_voice":
-                # print("online", rec_result)
-                message = json.dumps({"mode": "2pass-online", "text": rec_result["text"], "wav_name": websocket.wav_name})
-                await websocket.send(message)
-
-if len(args.certfile)>0:
-	ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
-	
-	# Generate with Lets Encrypt, copied to this location, chown to current user and 400 permissions
-	ssl_cert = args.certfile
-	ssl_key = args.keyfile
-
-	ssl_context.load_cert_chain(ssl_cert, keyfile=ssl_key)
-	start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None,ssl=ssl_context)
-else:
-	start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None)
-asyncio.get_event_loop().run_until_complete(start_server)
-asyncio.get_event_loop().run_forever()
\ No newline at end of file
diff --git a/funasr/runtime/python/websocket/ws_server_offline.py b/funasr/runtime/python/websocket/ws_server_offline.py
deleted file mode 100644
index 1ea1ff7..0000000
--- a/funasr/runtime/python/websocket/ws_server_offline.py
+++ /dev/null
@@ -1,163 +0,0 @@
-import asyncio
-import json
-import websockets
-import time
-import logging
-import tracemalloc
-import numpy as np
-import ssl
-
-from parse_args import args
-from modelscope.pipelines import pipeline
-from modelscope.utils.constant import Tasks
-from modelscope.utils.logger import get_logger
-from funasr.runtime.python.onnxruntime.funasr_onnx.utils.frontend import load_bytes
-
-tracemalloc.start()
-
-logger = get_logger(log_level=logging.CRITICAL)
-logger.setLevel(logging.CRITICAL)
-
-
-websocket_users = set()
-
-print("model loading")
-# asr
-inference_pipeline_asr = pipeline(
-    task=Tasks.auto_speech_recognition,
-    model=args.asr_model,
-    ngpu=args.ngpu,
-    ncpu=args.ncpu,
-    model_revision=None)
-
-
-# vad
-inference_pipeline_vad = pipeline(
-    task=Tasks.voice_activity_detection,
-    model=args.vad_model,
-    model_revision=None,
-    output_dir=None,
-    batch_size=1,
-    mode='online',
-    ngpu=args.ngpu,
-    ncpu=args.ncpu,
-)
-
-if args.punc_model != "":
-    inference_pipeline_punc = pipeline(
-        task=Tasks.punctuation,
-        model=args.punc_model,
-        model_revision=None,
-        ngpu=args.ngpu,
-        ncpu=args.ncpu,
-    )
-else:
-    inference_pipeline_punc = None
-
-print("model loaded")
-
-async def ws_serve(websocket, path):
-    frames = []
-    frames_asr = []
-    global websocket_users
-    websocket_users.add(websocket)
-    websocket.param_dict_asr = {}
-    websocket.param_dict_vad = {'in_cache': dict(), "is_final": False}
-    websocket.param_dict_punc = {'cache': list()}
-    websocket.vad_pre_idx = 0
-    speech_start = False
-    websocket.wav_name = "microphone"
-    print("new user connected", flush=True)
-
-    try:
-        async for message in websocket:
-            if isinstance(message, str):
-                messagejson = json.loads(message)
-                if "is_speaking" in messagejson:
-                    websocket.is_speaking = messagejson["is_speaking"]
-                    websocket.param_dict_vad["is_final"] = not websocket.is_speaking
-                if "wav_name" in messagejson:
-                    websocket.wav_name = messagejson.get("wav_name")
-            
-            if len(frames_asr) > 0 or not isinstance(message, str):
-                if not isinstance(message, str):
-                    frames.append(message)
-                    duration_ms = len(message)//32
-                    websocket.vad_pre_idx += duration_ms
-    
-                    if speech_start:
-                        frames_asr.append(message)
-                    speech_start_i, speech_end_i = await async_vad(websocket, message)
-                    if speech_start_i:
-                        speech_start = True
-                        beg_bias = (websocket.vad_pre_idx-speech_start_i)//duration_ms
-                        frames_pre = frames[-beg_bias:]
-                        frames_asr = []
-                        frames_asr.extend(frames_pre)
-                if speech_end_i or not websocket.is_speaking:
-                    audio_in = b"".join(frames_asr)
-                    await async_asr(websocket, audio_in)
-                    frames_asr = []
-                    speech_start = False
-                    if not websocket.is_speaking:
-                        websocket.vad_pre_idx = 0
-                        frames = []
-                        websocket.param_dict_vad = {'in_cache': dict()}
-                    else:
-                        frames = frames[-20:]
-
-     
-    except websockets.ConnectionClosed:
-        print("ConnectionClosed...", websocket_users)
-        websocket_users.remove(websocket)
-    except websockets.InvalidState:
-        print("InvalidState...")
-    except Exception as e:
-        print("Exception:", e)
-
-
-async def async_vad(websocket, audio_in):
-
-    segments_result = inference_pipeline_vad(audio_in=audio_in, param_dict=websocket.param_dict_vad)
-
-    speech_start = False
-    speech_end = False
-    
-    if len(segments_result) == 0 or len(segments_result["text"]) > 1:
-        return speech_start, speech_end
-    if segments_result["text"][0][0] != -1:
-        speech_start = segments_result["text"][0][0]
-    if segments_result["text"][0][1] != -1:
-        speech_end = True
-    return speech_start, speech_end
-
-
-async def async_asr(websocket, audio_in):
-            if len(audio_in) > 0:
-                # print(len(audio_in))
-                audio_in = load_bytes(audio_in)
-                
-                rec_result = inference_pipeline_asr(audio_in=audio_in,
-                                                    param_dict=websocket.param_dict_asr)
-                print(rec_result)
-                if inference_pipeline_punc is not None and 'text' in rec_result and len(rec_result["text"])>0:
-                    rec_result = inference_pipeline_punc(text_in=rec_result['text'],
-                                                         param_dict=websocket.param_dict_punc)
-                    # print(rec_result)
-                message = json.dumps({"mode": "offline", "text": rec_result["text"], "wav_name": websocket.wav_name})
-                await websocket.send(message)
-                
-                
-if len(args.certfile)>0:
-	ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
-	
-	# Generate with Lets Encrypt, copied to this location, chown to current user and 400 permissions
-	ssl_cert = args.certfile
-	ssl_key = args.keyfile
-
-	ssl_context.load_cert_chain(ssl_cert, keyfile=ssl_key)
-	start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None,ssl=ssl_context)
-else:
-	start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None)
-asyncio.get_event_loop().run_until_complete(start_server)
-asyncio.get_event_loop().run_forever()
\ No newline at end of file
diff --git a/funasr/runtime/python/websocket/ws_server_online.py b/funasr/runtime/python/websocket/ws_server_online.py
deleted file mode 100644
index 4cecd5f..0000000
--- a/funasr/runtime/python/websocket/ws_server_online.py
+++ /dev/null
@@ -1,147 +0,0 @@
-import asyncio
-import json
-import websockets
-import time
-from queue import Queue
-import threading
-import logging
-import tracemalloc
-import numpy as np
-import ssl
-from parse_args import args
-from modelscope.pipelines import pipeline
-from modelscope.utils.constant import Tasks
-from modelscope.utils.logger import get_logger
-from funasr.runtime.python.onnxruntime.funasr_onnx.utils.frontend import load_bytes
-
-tracemalloc.start()
-
-logger = get_logger(log_level=logging.CRITICAL)
-logger.setLevel(logging.CRITICAL)
-
-
-websocket_users = set()
-
-
-print("model loading")
-
-inference_pipeline_asr_online = pipeline(
-	task=Tasks.auto_speech_recognition,
-	model=args.asr_model_online,
-	ngpu=args.ngpu,
-	ncpu=args.ncpu,
-	model_revision='v1.0.4')
-
-# vad
-inference_pipeline_vad = pipeline(
-    task=Tasks.voice_activity_detection,
-    model=args.vad_model,
-    model_revision=None,
-    output_dir=None,
-    batch_size=1,
-    mode='online',
-    ngpu=args.ngpu,
-    ncpu=1,
-)
-
-print("model loaded")
-
-
-
-async def ws_serve(websocket, path):
-	frames = []
-	frames_asr_online = []
-	global websocket_users
-	websocket_users.add(websocket)
-	websocket.param_dict_asr_online = {"cache": dict()}
-	websocket.param_dict_vad = {'in_cache': dict()}
-	websocket.wav_name = "microphone"
-	print("new user connected",flush=True)
-	try:
-		async for message in websocket:
-			
-			
-			if isinstance(message, str):
-				messagejson = json.loads(message)
-				
-				if "is_speaking" in messagejson:
-					websocket.is_speaking = messagejson["is_speaking"]
-					websocket.param_dict_asr_online["is_final"] = not websocket.is_speaking
-					websocket.param_dict_vad["is_final"] = not websocket.is_speaking
-					# need to fire engine manually if no data received any more
-					if not websocket.is_speaking:
-						await async_asr_online(websocket, b"")
-				if "chunk_interval" in messagejson:
-					websocket.chunk_interval=messagejson["chunk_interval"]
-				if "wav_name" in messagejson:
-					websocket.wav_name = messagejson.get("wav_name")
-				if "chunk_size" in messagejson:
-					websocket.param_dict_asr_online["chunk_size"] = messagejson["chunk_size"]
-			# if has bytes in buffer or message is bytes
-			if len(frames_asr_online) > 0 or not isinstance(message, str):
-				if not isinstance(message, str):
-					frames_asr_online.append(message)
-					# frames.append(message)
-					# duration_ms = len(message) // 32
-					# websocket.vad_pre_idx += duration_ms
-					speech_start_i, speech_end_i = await async_vad(websocket, message)
-					websocket.is_speaking = not speech_end_i
-					
-				if len(frames_asr_online) % websocket.chunk_interval == 0 or not websocket.is_speaking:
-					websocket.param_dict_asr_online["is_final"] = not websocket.is_speaking
-					audio_in = b"".join(frames_asr_online)
-					await async_asr_online(websocket, audio_in)
-					frames_asr_online = []
-	
-	
-	except websockets.ConnectionClosed:
-		print("ConnectionClosed...", websocket_users)
-		websocket_users.remove(websocket)
-	except websockets.InvalidState:
-		print("InvalidState...")
-	except Exception as e:
-		print("Exception:", e)
-
-
-async def async_asr_online(websocket,audio_in):
-	if len(audio_in) >= 0:
-		audio_in = load_bytes(audio_in)
-		# print(websocket.param_dict_asr_online.get("is_final", False))
-		rec_result = inference_pipeline_asr_online(audio_in=audio_in,
-		                                           param_dict=websocket.param_dict_asr_online)
-		# print(rec_result)
-		if websocket.param_dict_asr_online.get("is_final", False):
-			websocket.param_dict_asr_online["cache"] = dict()
-		if "text" in rec_result:
-			if rec_result["text"] != "sil" and rec_result["text"] != "waiting_for_more_voice":
-				message = json.dumps({"mode": "online", "text": rec_result["text"], "wav_name": websocket.wav_name})
-				await websocket.send(message)
-
-
-async def async_vad(websocket, audio_in):
-	segments_result = inference_pipeline_vad(audio_in=audio_in, param_dict=websocket.param_dict_vad)
-	
-	speech_start = False
-	speech_end = False
-	
-	if len(segments_result) == 0 or len(segments_result["text"]) > 1:
-		return speech_start, speech_end
-	if segments_result["text"][0][0] != -1:
-		speech_start = segments_result["text"][0][0]
-	if segments_result["text"][0][1] != -1:
-		speech_end = True
-	return speech_start, speech_end
-
-if len(args.certfile)>0:
-	ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
-	
-	# Generate with Lets Encrypt, copied to this location, chown to current user and 400 permissions
-	ssl_cert = args.certfile
-	ssl_key = args.keyfile
-	
-	ssl_context.load_cert_chain(ssl_cert, keyfile=ssl_key)
-	start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None,ssl=ssl_context)
-else:
-	start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None)
-asyncio.get_event_loop().run_until_complete(start_server)
-asyncio.get_event_loop().run_forever()
\ No newline at end of file

--
Gitblit v1.9.1