From f89111a45b73bc9eae6f161e0a5b8ee0464d58c3 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 24 五月 2023 23:30:26 +0800
Subject: [PATCH] websocket unify wss_srv for online offline 2pass

---
 funasr/runtime/python/websocket/wss_client_asr.py |  296 ++++++++++++++++++++++++++++++++
 funasr/runtime/python/websocket/wss_srv_asr.py    |  210 +++++++++++++++++++++++
 2 files changed, 506 insertions(+), 0 deletions(-)

diff --git a/funasr/runtime/python/websocket/wss_client_asr.py b/funasr/runtime/python/websocket/wss_client_asr.py
new file mode 100644
index 0000000..586e0a4
--- /dev/null
+++ b/funasr/runtime/python/websocket/wss_client_asr.py
@@ -0,0 +1,296 @@
+# -*- encoding: utf-8 -*-
+import os
+import time
+import websockets,ssl
+import asyncio
+# import threading
+import argparse
+import json
+import traceback
+from multiprocessing import Process
+from funasr.fileio.datadir_writer import DatadirWriter
+
+import logging
+
+logging.basicConfig(level=logging.ERROR)
+
# Command-line options; parsed at import time and shared (via fork) by all
# worker processes.
parser = argparse.ArgumentParser()
parser.add_argument("--host",
                    type=str,
                    default="localhost",
                    required=False,
                    help="host ip, localhost, 0.0.0.0")
# NOTE(review): help text says "grpc" but this port is used for the
# websocket connection below — consider rewording.
parser.add_argument("--port",
                    type=int,
                    default=10095,
                    required=False,
                    help="grpc server port")
# Streaming chunk geometry, e.g. "5, 10, 5"; converted to a list of ints below.
parser.add_argument("--chunk_size",
                    type=str,
                    default="5, 10, 5",
                    help="chunk")
parser.add_argument("--chunk_interval",
                    type=int,
                    default=10,
                    help="chunk")
# Input: a single wav path or a Kaldi-style .scp list; None = use microphone.
parser.add_argument("--audio_in",
                    type=str,
                    default=None,
                    help="audio_in")
parser.add_argument("--send_without_sleep",
                    action="store_true",
                    default=False,
                    help="if audio_in is set, send_without_sleep")
# Number of worker processes when reading from files.
parser.add_argument("--test_thread_num",
                    type=int,
                    default=1,
                    help="test_thread_num")
# Maximum number of trailing transcript characters kept on screen.
parser.add_argument("--words_max_print",
                    type=int,
                    default=10000,
                    help="chunk")
parser.add_argument("--output_dir",
                    type=str,
                    default=None,
                    help="output_dir")

parser.add_argument("--ssl",
                    type=int,
                    default=1,
                    help="1 for ssl connect, 0 for no ssl")
parser.add_argument("--mode",
                    type=str,
                    default="2pass",
                    help="offline, online, 2pass")

args = parser.parse_args()
args.chunk_size = [int(x) for x in args.chunk_size.split(",")]
print(args)
# Thread-safe queue bridging the audio producer tasks and the websocket sender.
from queue import Queue
voices = Queue()

# Optional Kaldi-style result writer (one "text" entry per wav name).
ibest_writer = None
if args.output_dir is not None:
    writer = DatadirWriter(args.output_dir)
    ibest_writer = writer[f"1best_recog"]
+
async def record_microphone():
    """Capture 16 kHz mono 16-bit PCM from the default microphone and
    push it onto the global ``voices`` queue forever.

    A JSON handshake describing the streaming session is queued first,
    followed by raw binary audio chunks.
    """
    import pyaudio
    global voices

    sample_rate = 16000
    # Chunk duration (ms) derived from the online chunk size and interval.
    chunk_ms = 60 * args.chunk_size[1] / args.chunk_interval
    frames_per_chunk = int(sample_rate / 1000 * chunk_ms)

    audio = pyaudio.PyAudio()
    stream = audio.open(format=pyaudio.paInt16,
                        channels=1,
                        rate=sample_rate,
                        input=True,
                        frames_per_buffer=frames_per_chunk)

    # First message: session parameters as JSON text.
    handshake = json.dumps({"mode": args.mode, "chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval, "wav_name": "microphone", "is_speaking": True})
    voices.put(handshake)
    while True:
        # Binary frames follow the JSON handshake.
        voices.put(stream.read(frames_per_chunk))
        await asyncio.sleep(0.005)
+
async def record_from_scp(chunk_begin, chunk_size):
    """Read wav file(s) and stream their PCM data onto the global queue.

    Args:
        chunk_begin: index of the first wav (from the scp list) this
            worker should handle.
        chunk_size: number of wavs to handle; values <= 0 mean "all".

    ``args.audio_in`` is either a Kaldi-style .scp file ("name path" per
    line) or a single wav path.  For each wav a JSON handshake is queued,
    followed by fixed-size binary chunks, and finally an
    ``{"is_speaking": false}`` end-of-utterance marker.
    """
    import wave
    global voices
    if args.audio_in.endswith(".scp"):
        # Use a context manager so the scp handle is not leaked.
        with open(args.audio_in) as f_scp:
            wavs = f_scp.readlines()
    else:
        wavs = [args.audio_in]
    if chunk_size > 0:
        wavs = wavs[chunk_begin:chunk_begin + chunk_size]
    for wav in wavs:
        wav_splits = wav.strip().split()
        wav_name = wav_splits[0] if len(wav_splits) > 1 else "demo"
        wav_path = wav_splits[1] if len(wav_splits) > 1 else wav_splits[0]

        with wave.open(wav_path, "rb") as wav_file:
            frames = wav_file.readframes(wav_file.getnframes())
        audio_bytes = bytes(frames)

        # Bytes per chunk: chunk duration (ms) * 16 kHz * 2 bytes/sample.
        stride = int(60 * args.chunk_size[1] / args.chunk_interval / 1000 * 16000 * 2)
        chunk_num = (len(audio_bytes) - 1) // stride + 1

        # Handshake: session parameters as JSON text.
        message = json.dumps({"mode": args.mode, "chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval, "wav_name": wav_name, "is_speaking": True})
        voices.put(message)
        for i in range(chunk_num):
            beg = i * stride
            voices.put(audio_bytes[beg:beg + stride])
            if i == chunk_num - 1:
                # Signal end of utterance so the server flushes final results.
                voices.put(json.dumps({"is_speaking": False}))
            # Pace sends in (near) real time unless asked not to sleep.
            sleep_duration = 0.001 if args.send_without_sleep else 60 * args.chunk_size[1] / args.chunk_interval / 1000
            await asyncio.sleep(sleep_duration)
+
+
async def ws_send():
    """Drain the global ``voices`` queue and forward each item over the
    global websocket connection.

    Runs forever, polling the queue; exits the whole process if a send
    fails (e.g. the connection dropped).
    """
    global voices
    global websocket
    print("started to sending data!")
    while True:
        if voices.empty():
            await asyncio.sleep(0.005)
            continue
        data = voices.get()
        voices.task_done()
        try:
            await websocket.send(data)
        except Exception as e:
            print('Exception occurred:', e)
            traceback.print_exc()
            exit(0)
        await asyncio.sleep(0.005)
+
+
+
async def message(id):
    """Receive recognition results from the server and render them.

    Args:
        id: worker index, used only to tag the printed output ("pidN: ...").

    "online" results are incremental and simply appended to a rolling
    transcript.  In 2pass mode, interim "2pass-online" text is shown after
    the confirmed "2pass-offline" text and replaced whenever a new offline
    result arrives.  When ``--output_dir`` was given, every result is also
    mirrored into ``ibest_writer``.

    BUGFIX: the original had a second, unreachable
    ``elif meg["mode"] == "online"`` branch duplicating the first; the dead
    branch has been removed (behavior unchanged).
    """
    global websocket
    text_print = ""
    text_print_2pass_online = ""
    text_print_2pass_offline = ""
    while True:
        try:
            meg = await websocket.recv()
            meg = json.loads(meg)
            wav_name = meg.get("wav_name", "demo")
            text = meg["text"]
            if ibest_writer is not None:
                ibest_writer["text"][wav_name] = text

            if meg["mode"] == "online":
                # Streaming mode: results are incremental, just append.
                text_print += "{}".format(text)
                text_print = text_print[-args.words_max_print:]
                os.system('clear')
                print("\rpid" + str(id) + ": " + text_print)
            else:
                if meg["mode"] == "2pass-online":
                    # Interim hypothesis: shown after the confirmed text.
                    text_print_2pass_online += "{}".format(text)
                    text_print = text_print_2pass_offline + text_print_2pass_online
                else:
                    # Confirmed (offline) result: replaces the interim text.
                    text_print_2pass_online = ""
                    text_print = text_print_2pass_offline + "{}".format(text)
                    text_print_2pass_offline += "{}".format(text)
                text_print = text_print[-args.words_max_print:]
                os.system('clear')
                print("\rpid" + str(id) + ": " + text_print)

        except Exception as e:
            print("Exception:", e)
            traceback.print_exc()
            exit(0)
+
async def print_messge():
    """Debug helper: receive server messages forever and print the parsed
    JSON payloads verbatim.  Exits the process on any receive error.

    NOTE(review): the name contains a typo ("messge") but is kept because
    it is this function's public identifier.
    """
    global websocket
    while True:
        try:
            raw = await websocket.recv()
            meg = json.loads(raw)
            print(meg)
        except Exception as e:
            print("Exception:", e)
            traceback.print_exc()
            exit(0)
+
async def ws_client(id, chunk_begin, chunk_size):
    """Connect to the ASR server and run the producer/sender/receiver tasks.

    Args:
        id: worker index (forwarded to ``message`` for display).
        chunk_begin, chunk_size: wav-list slice this worker handles.

    Uses the ``async for`` reconnecting iterator of ``websockets.connect``,
    so a dropped connection is re-established and the tasks restarted.
    """
    global websocket
    if args.ssl == 1:
        # Accept self-signed certificates: no hostname or cert verification.
        ssl_context = ssl.SSLContext()
        ssl_context.check_hostname = False
        ssl_context.verify_mode = ssl.CERT_NONE
        uri = "wss://{}:{}".format(args.host, args.port)
    else:
        ssl_context = None
        uri = "ws://{}:{}".format(args.host, args.port)
    print("connect to", uri)
    async for websocket in websockets.connect(uri, subprotocols=["binary"], ping_interval=None, ssl=ssl_context):
        if args.audio_in is not None:
            producer = asyncio.create_task(record_from_scp(chunk_begin, chunk_size))
        else:
            producer = asyncio.create_task(record_microphone())
        sender = asyncio.create_task(ws_send())
        receiver = asyncio.create_task(message(id))
        await asyncio.gather(producer, sender, receiver)
+
def one_thread(id,chunk_begin,chunk_size):
   """Process entry point: run the websocket client on this process's
   event loop, then keep the loop alive.

   NOTE(review): implicit loop creation via ``asyncio.get_event_loop()``
   is deprecated on Python 3.10+; ``asyncio.run`` may be preferable if
   ``run_forever`` is not actually required — TODO confirm.
   """
   asyncio.get_event_loop().run_until_complete(ws_client(id,chunk_begin,chunk_size))
   asyncio.get_event_loop().run_forever()
+
+
if __name__ == '__main__':
    # Microphone input: a single worker process is enough.
    if args.audio_in is None:
        p = Process(target=one_thread, args=(0, 0, 0))
        p.start()
        p.join()
        print('end')
    else:
        # File input: split the wav list evenly across worker processes.
        if args.audio_in.endswith(".scp"):
            # Context manager so the scp handle is not leaked.
            with open(args.audio_in) as f_scp:
                wavs = f_scp.readlines()
        else:
            wavs = [args.audio_in]
        total_len = len(wavs)
        if total_len >= args.test_thread_num:
            chunk_size = int(total_len / args.test_thread_num)
            remain_wavs = total_len - chunk_size * args.test_thread_num
        else:
            chunk_size = 1
            remain_wavs = 0

        process_list = []
        chunk_begin = 0
        for i in range(args.test_thread_num):
            now_chunk_size = chunk_size
            if remain_wavs > 0:
                # Spread the remainder: one extra wav per leading worker.
                now_chunk_size = chunk_size + 1
                remain_wavs = remain_wavs - 1
            # Worker i handles wavs[chunk_begin:chunk_begin+now_chunk_size].
            p = Process(target=one_thread, args=(i, chunk_begin, now_chunk_size))
            chunk_begin = chunk_begin + now_chunk_size
            p.start()
            process_list.append(p)

        # BUGFIX: join every worker process.  The original looped
        # "for i in process_list: p.join()", joining only the last-created
        # process repeatedly and never waiting on the others.
        for proc in process_list:
            proc.join()

        print('end')
+
+
diff --git a/funasr/runtime/python/websocket/wss_srv_asr.py b/funasr/runtime/python/websocket/wss_srv_asr.py
new file mode 100644
index 0000000..71c97e6
--- /dev/null
+++ b/funasr/runtime/python/websocket/wss_srv_asr.py
@@ -0,0 +1,210 @@
+import asyncio
+import json
+import websockets
+import time
+import logging
+import tracemalloc
+import numpy as np
+import ssl
+from parse_args import args
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
+from funasr.runtime.python.onnxruntime.funasr_onnx.utils.frontend import load_bytes
+
# Trace memory allocations (diagnostics only; snapshots not taken here).
tracemalloc.start()

# Silence modelscope logging; only critical messages get through.
logger = get_logger(log_level=logging.CRITICAL)
logger.setLevel(logging.CRITICAL)


# All currently connected clients (added/removed in ws_serve).
websocket_users = set()

print("model loading")
# Offline (full-context) ASR model, used for the final/offline pass.
inference_pipeline_asr = pipeline(
    task=Tasks.auto_speech_recognition,
    model=args.asr_model,
    ngpu=args.ngpu,
    ncpu=args.ncpu,
    model_revision=None)


# Streaming voice-activity detector, used to find utterance endpoints.
inference_pipeline_vad = pipeline(
    task=Tasks.voice_activity_detection,
    model=args.vad_model,
    model_revision=None,
    output_dir=None,
    batch_size=1,
    mode='online',
    ngpu=args.ngpu,
    ncpu=args.ncpu,
)

# Punctuation restoration is optional: an empty --punc_model disables it.
if args.punc_model != "":
    inference_pipeline_punc = pipeline(
        task=Tasks.punctuation,
        model=args.punc_model,
        model_revision="v1.0.2",
        ngpu=args.ngpu,
        ncpu=args.ncpu,
    )
else:
    inference_pipeline_punc = None

# Streaming (online) ASR model, used for partial results.
inference_pipeline_asr_online = pipeline(
    task=Tasks.auto_speech_recognition,
    model=args.asr_model_online,
    ngpu=args.ngpu,
    ncpu=args.ncpu,
    model_revision='v1.0.4')

print("model loaded")
+
async def ws_serve(websocket, path):
    """Handle one websocket client.

    Text frames carry JSON control messages ("mode", "chunk_size",
    "chunk_interval", "wav_name", "is_speaking"); binary frames carry
    16 kHz / 16-bit mono PCM audio (see the //32 bytes-per-ms math below).
    Per-connection state (decoder caches, VAD state, mode) is stored as
    attributes on the ``websocket`` object.  Depending on the mode
    ("online", "offline", "2pass"), audio chunks are fed to the streaming
    recognizer and/or buffered between VAD endpoints for the offline
    recognizer plus optional punctuation.
    """
    frames = []             # rolling raw-audio history for VAD look-back
    frames_asr = []         # audio buffered for the offline pass
    frames_asr_online = []  # audio buffered for the next streaming step
    global websocket_users
    websocket_users.add(websocket)
    websocket.param_dict_asr = {}
    websocket.param_dict_asr_online = {"cache": dict()}
    websocket.param_dict_vad = {'in_cache': dict(), "is_final": False}
    websocket.param_dict_punc = {'cache': list()}
    websocket.vad_pre_idx = 0
    # BUGFIX: defaults so a client that sends audio before (or without)
    # the JSON handshake does not trigger AttributeError on
    # websocket.is_speaking / websocket.chunk_interval below.
    websocket.is_speaking = True
    websocket.chunk_interval = 10
    speech_start = False
    speech_end_i = False
    websocket.wav_name = "microphone"
    websocket.mode = "2pass"
    print("new user connected", flush=True)

    try:
        async for message in websocket:
            if isinstance(message, str):
                messagejson = json.loads(message)

                if "is_speaking" in messagejson:
                    websocket.is_speaking = messagejson["is_speaking"]
                    websocket.param_dict_asr_online["is_final"] = not websocket.is_speaking
                if "chunk_interval" in messagejson:
                    websocket.chunk_interval = messagejson["chunk_interval"]
                if "wav_name" in messagejson:
                    websocket.wav_name = messagejson.get("wav_name")
                if "chunk_size" in messagejson:
                    websocket.param_dict_asr_online["chunk_size"] = messagejson["chunk_size"]
                if "mode" in messagejson:
                    websocket.mode = messagejson["mode"]
            if len(frames_asr_online) > 0 or len(frames_asr) > 0 or not isinstance(message, str):
                if not isinstance(message, str):
                    frames.append(message)
                    duration_ms = len(message) // 32  # 16 kHz * 2 bytes = 32 bytes/ms
                    websocket.vad_pre_idx += duration_ms

                    # Streaming ASR: decode every chunk_interval-th buffer,
                    # or immediately when this is the final chunk.
                    frames_asr_online.append(message)
                    websocket.param_dict_asr_online["is_final"] = speech_end_i
                    if len(frames_asr_online) % websocket.chunk_interval == 0 or websocket.param_dict_asr_online["is_final"]:
                        if websocket.mode == "2pass" or websocket.mode == "online":
                            audio_in = b"".join(frames_asr_online)
                            await async_asr_online(websocket, audio_in)
                        frames_asr_online = []
                    if speech_start:
                        frames_asr.append(message)
                    # Streaming VAD on the raw chunk.
                    speech_start_i, speech_end_i = await async_vad(websocket, message)
                    if speech_start_i:
                        speech_start = True
                        # Back-fill audio recorded before the detected start.
                        beg_bias = (websocket.vad_pre_idx - speech_start_i) // duration_ms
                        frames_pre = frames[-beg_bias:]
                        frames_asr = []
                        frames_asr.extend(frames_pre)
                # Offline ASR + punctuation at a VAD endpoint or end-of-stream.
                if speech_end_i or not websocket.is_speaking:
                    if websocket.mode == "2pass" or websocket.mode == "offline":
                        audio_in = b"".join(frames_asr)
                        await async_asr(websocket, audio_in)
                    frames_asr = []
                    speech_start = False
                    if not websocket.is_speaking:
                        # End of stream: reset all per-utterance state.
                        websocket.vad_pre_idx = 0
                        frames = []
                        websocket.param_dict_vad = {'in_cache': dict()}
                    else:
                        # Keep a short history window for the next look-back.
                        frames = frames[-20:]

    except websockets.ConnectionClosed:
        print("ConnectionClosed...", websocket_users)
    except websockets.InvalidState:
        print("InvalidState...")
    except Exception as e:
        print("Exception:", e)
    finally:
        # BUGFIX: always unregister the client.  Previously only the
        # ConnectionClosed path removed it, leaking entries in
        # websocket_users on any other error.
        websocket_users.discard(websocket)
+
+
async def async_vad(websocket, audio_in):
    """Run one streaming VAD step on ``audio_in``.

    Returns:
        (speech_start, speech_end): ``speech_start`` is the start offset
        reported by the VAD (an int; the caller's arithmetic treats it as
        milliseconds) or ``False`` when no start was detected;
        ``speech_end`` is a plain bool.

    NOTE(review): a genuine start offset of 0 is falsy, so the caller's
    ``if speech_start_i:`` check would miss it — TODO confirm the VAD
    never reports offset 0 here.
    """
    segments_result = inference_pipeline_vad(audio_in=audio_in, param_dict=websocket.param_dict_vad)

    speech_start = False
    speech_end = False

    # "text" holds segment pairs; -1 means "endpoint not detected yet".
    # More than one segment in a single step is not handled: bail out.
    if len(segments_result) == 0 or len(segments_result["text"]) > 1:
        return speech_start, speech_end
    if segments_result["text"][0][0] != -1:
        speech_start = segments_result["text"][0][0]
    if segments_result["text"][0][1] != -1:
        speech_end = True
    return speech_start, speech_end
+
+
async def async_asr(websocket, audio_in):
    """Offline pass: recognize the buffered utterance, optionally restore
    punctuation, and send the result back tagged as "2pass-offline".

    ``audio_in`` is raw PCM bytes; empty input is a no-op.
    """
    if not audio_in:
        return
    samples = load_bytes(audio_in)
    rec_result = inference_pipeline_asr(audio_in=samples,
                                        param_dict=websocket.param_dict_asr)
    has_text = 'text' in rec_result and len(rec_result["text"]) > 0
    if inference_pipeline_punc is not None and has_text:
        # Punctuation restoration on the recognized text.
        rec_result = inference_pipeline_punc(text_in=rec_result['text'],
                                             param_dict=websocket.param_dict_punc)
    if 'text' in rec_result:
        message = json.dumps({"mode": "2pass-offline", "text": rec_result["text"], "wav_name": websocket.wav_name})
        await websocket.send(message)
+
+
async def async_asr_online(websocket, audio_in):
    """Streaming pass: decode one audio chunk and send the partial
    hypothesis back tagged as "2pass-online".

    ``audio_in`` is raw PCM bytes; empty input is a no-op.  Note the
    recognizer is still invoked on the final chunk (to flush its cache)
    even though in 2pass mode that final result is not sent — the offline
    pass produces the confirmed text instead.
    """
    if not audio_in:
        return
    samples = load_bytes(audio_in)
    rec_result = inference_pipeline_asr_online(audio_in=samples,
                                               param_dict=websocket.param_dict_asr_online)
    if websocket.mode == "2pass" and websocket.param_dict_asr_online.get("is_final", False):
        # Final chunk in 2pass mode: the offline pass will report it.
        return
    if "text" in rec_result:
        hypothesis = rec_result["text"]
        # Skip placeholder outputs from the streaming recognizer.
        if hypothesis != "sil" and hypothesis != "waiting_for_more_voice":
            message = json.dumps({"mode": "2pass-online", "text": hypothesis, "wav_name": websocket.wav_name})
            await websocket.send(message)
+
# Start the websocket server: wss:// when a certificate is configured,
# plain ws:// otherwise, then run the event loop forever.
if len(args.certfile) > 0:
    # TLS: serve wss:// with the configured certificate/key pair
    # (e.g. Let's Encrypt files owned by the current user, mode 400).
    ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    ssl_context.load_cert_chain(args.certfile, keyfile=args.keyfile)
    start_server = websockets.serve(ws_serve, args.host, args.port,
                                    subprotocols=["binary"], ping_interval=None, ssl=ssl_context)
else:
    start_server = websockets.serve(ws_serve, args.host, args.port,
                                    subprotocols=["binary"], ping_interval=None)
asyncio.get_event_loop().run_until_complete(start_server)
asyncio.get_event_loop().run_forever()
\ No newline at end of file

--
Gitblit v1.9.1