游雁
2023-05-24 f89111a45b73bc9eae6f161e0a5b8ee0464d58c3
websocket unify wss_srv for online offline 2pass
2个文件已添加
506 ■■■■■ 已修改文件
funasr/runtime/python/websocket/wss_client_asr.py 296 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/websocket/wss_srv_asr.py 210 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/websocket/wss_client_asr.py
New file
@@ -0,0 +1,296 @@
# -*- encoding: utf-8 -*-
import os
import time
import websockets,ssl
import asyncio
# import threading
import argparse
import json
import traceback
from multiprocessing import Process
from funasr.fileio.datadir_writer import DatadirWriter
import logging
logging.basicConfig(level=logging.ERROR)

# Command-line options of the websocket ASR test client.
parser = argparse.ArgumentParser()
parser.add_argument("--host",
                    type=str,
                    default="localhost",
                    required=False,
                    help="host ip, localhost, 0.0.0.0")
parser.add_argument("--port",
                    type=int,
                    default=10095,
                    required=False,
                    help="websocket server port")
parser.add_argument("--chunk_size",
                    type=str,
                    default="5, 10, 5",
                    help="comma separated chunk configuration, e.g. \"5, 10, 5\"; "
                         "the middle value times 60 ms is split into chunk_interval packets")
parser.add_argument("--chunk_interval",
                    type=int,
                    default=10,
                    help="number of packets one chunk is split into when sending")
parser.add_argument("--audio_in",
                    type=str,
                    default=None,
                    help="wav file or .scp list to send; the microphone is used when omitted")
parser.add_argument("--send_without_sleep",
                    action="store_true",
                    default=False,
                    help="if audio_in is set, send_without_sleep")
parser.add_argument("--test_thread_num",
                    type=int,
                    default=1,
                    help="number of client processes to start")
parser.add_argument("--words_max_print",
                    type=int,
                    default=10000,
                    help="maximum number of trailing characters of the transcript to print")
parser.add_argument("--output_dir",
                    type=str,
                    default=None,
                    help="output_dir")
parser.add_argument("--ssl",
                    type=int,
                    default=1,
                    help="1 for ssl connect, 0 for no ssl")
parser.add_argument("--mode",
                    type=str,
                    default="2pass",
                    help="offline, online, 2pass")

args = parser.parse_args()
# "5, 10, 5" -> [5, 10, 5]; int() tolerates the surrounding blanks.
args.chunk_size = [int(x) for x in args.chunk_size.split(",")]
print(args)

# Outgoing messages (JSON control strings and raw audio bytes): filled by the
# recorder coroutines and drained by ws_send().
from queue import Queue
voices = Queue()

# Optional writer for the recognised text; stays None without --output_dir.
ibest_writer = None
if args.output_dir is not None:
    writer = DatadirWriter(args.output_dir)
    ibest_writer = writer["1best_recog"]
async def record_microphone():
    """Capture microphone audio and queue it for sending.

    First queues a JSON configuration message (mode, chunk_size,
    chunk_interval, wav_name "microphone", is_speaking True), then loops
    forever reading 16 kHz, mono, 16-bit packets from PyAudio and putting the
    raw bytes on the global ``voices`` queue.

    The audio stream and the PyAudio instance are released when the coroutine
    is cancelled or fails.
    """
    import pyaudio
    global voices
    FORMAT = pyaudio.paInt16
    CHANNELS = 1
    RATE = 16000
    # Duration of one packet in milliseconds.
    chunk_size = 60 * args.chunk_size[1] / args.chunk_interval
    # Number of samples per packet.
    CHUNK = int(RATE / 1000 * chunk_size)

    p = pyaudio.PyAudio()
    try:
        stream = p.open(format=FORMAT,
                        channels=CHANNELS,
                        rate=RATE,
                        input=True,
                        frames_per_buffer=CHUNK)
        try:
            message = json.dumps({"mode": args.mode, "chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval, "wav_name": "microphone", "is_speaking": True})
            voices.put(message)
            while True:
                data = stream.read(CHUNK)
                voices.put(data)
                # Yield to the event loop so the sender/receiver tasks can run.
                await asyncio.sleep(0.005)
        finally:
            stream.close()
    finally:
        p.terminate()
async def record_from_scp(chunk_begin, chunk_size):
    """Read wav files and queue their audio in fixed-size packets.

    ``args.audio_in`` is either a single wav path or a ``.scp`` list whose
    lines are ``<wav_name> <wav_path>`` (or just ``<wav_path>``, in which case
    the name "demo" is used).

    Args:
        chunk_begin: index of the first list entry handled by this call.
        chunk_size: number of entries to handle; when it is not positive the
            whole list is processed.

    For every wav a JSON configuration message is queued first, then the raw
    frames in packets of ``stride`` bytes, and right after the last packet a
    JSON ``{"is_speaking": false}`` message.
    """
    import wave
    global voices
    if args.audio_in.endswith(".scp"):
        # Close the list file as soon as it has been read.
        with open(args.audio_in) as f_scp:
            wavs = f_scp.readlines()
    else:
        wavs = [args.audio_in]
    if chunk_size > 0:
        wavs = wavs[chunk_begin:chunk_begin + chunk_size]

    # Both values depend only on the command line, so compute them once.
    packet_ms = 60 * args.chunk_size[1] / args.chunk_interval
    # Bytes per packet: milliseconds -> samples at 16 kHz -> 2 bytes per sample.
    stride = int(packet_ms / 1000 * 16000 * 2)
    sleep_duration = 0.001 if args.send_without_sleep else packet_ms / 1000

    for wav in wavs:
        wav_splits = wav.strip().split()
        if not wav_splits:
            # Blank line in the scp file: nothing to send.
            continue
        wav_name = wav_splits[0] if len(wav_splits) > 1 else "demo"
        wav_path = wav_splits[1] if len(wav_splits) > 1 else wav_splits[0]

        with wave.open(wav_path, "rb") as wav_file:
            frames = wav_file.readframes(wav_file.getnframes())
        audio_bytes = bytes(frames)
        # Ceiling division; yields 0 for an empty file.
        chunk_num = (len(audio_bytes) - 1) // stride + 1

        # send first time
        message = json.dumps({"mode": args.mode, "chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval, "wav_name": wav_name, "is_speaking": True})
        voices.put(message)
        for i in range(chunk_num):
            beg = i * stride
            voices.put(audio_bytes[beg:beg + stride])
            if i == chunk_num - 1:
                # Tell the server that this utterance is complete.
                voices.put(json.dumps({"is_speaking": False}))
            await asyncio.sleep(sleep_duration)
async def ws_send():
    """Forward everything queued in ``voices`` to the server, forever.

    The queue is polled every 5 ms. A failing send prints the exception and
    its traceback, then terminates via ``exit(0)``.
    """
    global voices
    global websocket
    print("started to sending data!")
    poll_delay = 0.005
    while True:
        if voices.empty():
            # Nothing to send yet: wait and poll again.
            await asyncio.sleep(poll_delay)
            continue
        payload = voices.get()
        voices.task_done()
        try:
            await websocket.send(payload)
        except Exception as err:
            print('Exception occurred:', err)
            traceback.print_exc()
            exit(0)
        await asyncio.sleep(poll_delay)
async def message(id):
    """Receive recognition results and print the running transcript.

    Args:
        id: identifier of this client process, shown in the printed prefix.

    Every received message is parsed as JSON and must contain "text" and
    "mode"; "wav_name" defaults to "demo". When ``ibest_writer`` is set the
    text is also stored under the wav name.

    Transcript handling per mode:
      * "online" / "offline": the text is appended to the transcript.
      * "2pass-online": the text is appended to the pending online part, shown
        after the confirmed offline part.
      * anything else (e.g. "2pass-offline"): the pending online part is
        dropped and the text becomes part of the confirmed offline part.

    Only the last ``args.words_max_print`` characters are printed. Any
    exception is printed with its traceback and ends the process via
    ``exit(0)``.
    """
    global websocket
    text_print = ""
    text_print_2pass_online = ""
    text_print_2pass_offline = ""
    while True:
        try:
            meg = await websocket.recv()
            meg = json.loads(meg)
            wav_name = meg.get("wav_name", "demo")
            text = meg["text"]
            if ibest_writer is not None:
                ibest_writer["text"][wav_name] = text

            mode = meg["mode"]
            if mode == "online" or mode == "offline":
                # The original second branch tested "online" again and was
                # unreachable; "offline" is the intended mode.
                text_print += "{}".format(text)
            elif mode == "2pass-online":
                text_print_2pass_online += "{}".format(text)
                text_print = text_print_2pass_offline + text_print_2pass_online
            else:
                text_print_2pass_online = ""
                text_print = text_print_2pass_offline + "{}".format(text)
                text_print_2pass_offline += "{}".format(text)

            text_print = text_print[-args.words_max_print:]
            os.system('clear')
            print("\rpid" + str(id) + ": " + text_print)
        except Exception as e:
            print("Exception:", e)
            traceback.print_exc()
            exit(0)
async def print_messge():
    """Print every JSON message received from the server, forever.

    Any exception is printed with its traceback and terminates via
    ``exit(0)``.
    """
    global websocket
    while True:
        try:
            raw = await websocket.recv()
            print(json.loads(raw))
        except Exception as err:
            print("Exception:", err)
            traceback.print_exc()
            exit(0)
async def ws_client(id, chunk_begin, chunk_size):
    """Connect to the ASR server and run the recorder, sender and receiver.

    Args:
        id: identifier of this client process, passed on to ``message``.
        chunk_begin: index of the first wav-list entry for this client.
        chunk_size: number of wav-list entries for this client.

    With ``args.ssl == 1`` a ``wss://`` connection is opened with certificate
    and hostname verification disabled; otherwise a plain ``ws://`` connection
    is used. Audio comes from ``args.audio_in`` when given, else from the
    microphone.
    """
    global websocket
    if args.ssl == 1:
        # An explicit client protocol replaces the deprecated argument-less
        # SSLContext(). check_hostname must be cleared before verify_mode can
        # be set to CERT_NONE.
        ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        ssl_context.check_hostname = False
        ssl_context.verify_mode = ssl.CERT_NONE
        uri = "wss://{}:{}".format(args.host, args.port)
    else:
        uri = "ws://{}:{}".format(args.host, args.port)
        ssl_context = None
    print("connect to", uri)
    # Each iteration provides a connection bound to the global ``websocket``
    # that the sender/receiver coroutines use.
    async for websocket in websockets.connect(uri, subprotocols=["binary"], ping_interval=None, ssl=ssl_context):
        if args.audio_in is not None:
            task = asyncio.create_task(record_from_scp(chunk_begin, chunk_size))
        else:
            task = asyncio.create_task(record_microphone())
        task2 = asyncio.create_task(ws_send())
        task3 = asyncio.create_task(message(id))
        await asyncio.gather(task, task2, task3)
def one_thread(id, chunk_begin, chunk_size):
    """Process entry point: run one websocket client on the event loop.

    Runs ``ws_client`` to completion and then keeps the loop running.
    """
    loop = asyncio.get_event_loop()
    loop.run_until_complete(ws_client(id, chunk_begin, chunk_size))
    asyncio.get_event_loop().run_forever()
if __name__ == '__main__':
    if args.audio_in is None:
        # for microphone
        p = Process(target=one_thread, args=(0, 0, 0))
        p.start()
        p.join()
        print('end')
    else:
        # calculate the number of wavs for each process
        if args.audio_in.endswith(".scp"):
            # Close the list file as soon as it has been read.
            with open(args.audio_in) as f_scp:
                wavs = f_scp.readlines()
        else:
            wavs = [args.audio_in]

        total_len = len(wavs)
        if total_len >= args.test_thread_num:
            chunk_size = total_len // args.test_thread_num
            remain_wavs = total_len - chunk_size * args.test_thread_num
        else:
            chunk_size = 1
            remain_wavs = 0

        process_list = []
        chunk_begin = 0
        for i in range(args.test_thread_num):
            now_chunk_size = chunk_size
            if remain_wavs > 0:
                # Spread the remainder: the first processes take one extra wav.
                now_chunk_size = chunk_size + 1
                remain_wavs = remain_wavs - 1
            # process i handles now_chunk_size wavs starting at chunk_begin
            p = Process(target=one_thread, args=(i, chunk_begin, now_chunk_size))
            chunk_begin = chunk_begin + now_chunk_size
            p.start()
            process_list.append(p)

        # Wait for every started process (the original joined only the last
        # one, repeatedly).
        for p in process_list:
            p.join()
        print('end')
funasr/runtime/python/websocket/wss_srv_asr.py
New file
@@ -0,0 +1,210 @@
import asyncio
import json
import websockets
import time
import logging
import tracemalloc
import numpy as np
import ssl
from parse_args import args
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from funasr.runtime.python.onnxruntime.funasr_onnx.utils.frontend import load_bytes
tracemalloc.start()

# Keep the modelscope logger quiet: only CRITICAL records get through.
logger = get_logger(log_level=logging.CRITICAL)
logger.setLevel(logging.CRITICAL)

# Currently connected client sockets.
websocket_users = set()

print("model loading")

# asr (offline pass)
inference_pipeline_asr = pipeline(
    task=Tasks.auto_speech_recognition,
    model=args.asr_model,
    model_revision=None,
    ngpu=args.ngpu,
    ncpu=args.ncpu,
)

# vad, run in online mode
inference_pipeline_vad = pipeline(
    task=Tasks.voice_activity_detection,
    model=args.vad_model,
    mode='online',
    batch_size=1,
    output_dir=None,
    model_revision=None,
    ngpu=args.ngpu,
    ncpu=args.ncpu,
)

# punctuation is optional: an empty model name disables it
inference_pipeline_punc = (
    pipeline(
        task=Tasks.punctuation,
        model=args.punc_model,
        model_revision="v1.0.2",
        ngpu=args.ngpu,
        ncpu=args.ncpu,
    )
    if args.punc_model != ""
    else None
)

# asr (online pass)
inference_pipeline_asr_online = pipeline(
    task=Tasks.auto_speech_recognition,
    model=args.asr_model_online,
    model_revision='v1.0.4',
    ngpu=args.ngpu,
    ncpu=args.ncpu,
)

print("model loaded")
async def ws_serve(websocket, path):
    """Handle one client connection: VAD, online ASR and offline ASR.

    Args:
        websocket: the client connection; per-connection state is stored as
            attributes on it.
        path: request path supplied by ``websockets.serve``; unused.

    Text messages are JSON control messages and may set "is_speaking",
    "chunk_interval", "wav_name", "chunk_size" and "mode". Binary messages are
    audio packets: they are buffered, passed to the online recogniser every
    ``chunk_interval`` packets (modes "2pass"/"online"), and run through the
    VAD. When the VAD reports an end point, or the client reports that it
    stopped speaking, the buffered speech is sent to the offline recogniser
    (modes "2pass"/"offline").

    The connection is always removed from ``websocket_users`` on exit.
    """
    frames = []
    frames_asr = []
    frames_asr_online = []
    global websocket_users
    websocket_users.add(websocket)
    websocket.param_dict_asr = {}
    websocket.param_dict_asr_online = {"cache": dict()}
    websocket.param_dict_vad = {'in_cache': dict(), "is_final": False}
    websocket.param_dict_punc = {'cache': list()}
    websocket.vad_pre_idx = 0
    speech_start = False
    speech_end_i = False
    websocket.wav_name = "microphone"
    websocket.mode = "2pass"
    # Defaults for clients that send audio before (or without) a control
    # message; otherwise the first audio packet raises AttributeError.
    # 10 mirrors the client's default --chunk_interval.
    websocket.is_speaking = True
    websocket.chunk_interval = 10
    print("new user connected", flush=True)

    try:
        async for message in websocket:
            if isinstance(message, str):
                # Control message: update the per-connection settings.
                messagejson = json.loads(message)

                if "is_speaking" in messagejson:
                    websocket.is_speaking = messagejson["is_speaking"]
                    websocket.param_dict_asr_online["is_final"] = not websocket.is_speaking
                if "chunk_interval" in messagejson:
                    websocket.chunk_interval = messagejson["chunk_interval"]
                if "wav_name" in messagejson:
                    websocket.wav_name = messagejson.get("wav_name")
                if "chunk_size" in messagejson:
                    websocket.param_dict_asr_online["chunk_size"] = messagejson["chunk_size"]
                if "mode" in messagejson:
                    websocket.mode = messagejson["mode"]
            if len(frames_asr_online) > 0 or len(frames_asr) > 0 or not isinstance(message, str):
                if not isinstance(message, str):
                    frames.append(message)
                    # 32 bytes per millisecond — presumably 16 kHz 16-bit mono,
                    # as sent by the bundled client; confirm for other clients.
                    duration_ms = len(message)//32
                    websocket.vad_pre_idx += duration_ms

                    # asr online
                    frames_asr_online.append(message)
                    # speech_end_i still holds the VAD result of the previous packet.
                    websocket.param_dict_asr_online["is_final"] = speech_end_i
                    if len(frames_asr_online) % websocket.chunk_interval == 0 or websocket.param_dict_asr_online["is_final"]:
                        if websocket.mode == "2pass" or websocket.mode == "online":
                            audio_in = b"".join(frames_asr_online)
                            await async_asr_online(websocket, audio_in)
                        frames_asr_online = []
                    if speech_start:
                        frames_asr.append(message)
                    # vad online
                    speech_start_i, speech_end_i = await async_vad(websocket, message)
                    if speech_start_i:
                        speech_start = True
                        # Number of already received packets since the reported start.
                        beg_bias = (websocket.vad_pre_idx-speech_start_i)//duration_ms
                        # NOTE(review): beg_bias == 0 makes this slice return the
                        # whole buffer rather than nothing — confirm intent.
                        frames_pre = frames[-beg_bias:]
                        frames_asr = []
                        frames_asr.extend(frames_pre)
                # asr punc offline
                if speech_end_i or not websocket.is_speaking:
                    if websocket.mode == "2pass" or websocket.mode == "offline":
                        audio_in = b"".join(frames_asr)
                        await async_asr(websocket, audio_in)
                    frames_asr = []
                    speech_start = False
                    if not websocket.is_speaking:
                        # End of the utterance: reset the VAD state and history.
                        websocket.vad_pre_idx = 0
                        frames = []
                        websocket.param_dict_vad = {'in_cache': dict()}
                    else:
                        # Keep only the most recent packets as pre-speech context.
                        frames = frames[-20:]

    except websockets.ConnectionClosed:
        print("ConnectionClosed...", websocket_users)
    except websockets.InvalidState:
        print("InvalidState...")
    except Exception as e:
        print("Exception:", e)
    finally:
        # Remove the connection on every exit path, not only on
        # ConnectionClosed, so the set does not grow without bound.
        websocket_users.discard(websocket)
async def async_vad(websocket, audio_in):
    """Run the VAD pipeline on one audio packet.

    Returns:
        ``(speech_start, speech_end)``. ``speech_start`` is the first value of
        the reported segment when it is not -1, otherwise False.
        ``speech_end`` is True when the second value of the segment is not -1.
        Both are False when the pipeline returns an empty result or more than
        one segment.
    """
    vad_output = inference_pipeline_vad(audio_in=audio_in, param_dict=websocket.param_dict_vad)

    if len(vad_output) == 0 or len(vad_output["text"]) > 1:
        return False, False

    begin_mark = vad_output["text"][0][0]
    speech_start = begin_mark if begin_mark != -1 else False
    speech_end = True if vad_output["text"][0][1] != -1 else False
    return speech_start, speech_end
async def async_asr(websocket, audio_in):
    """Run offline recognition (plus optional punctuation) and send the result.

    Empty input is ignored. When the result contains "text", it is sent to the
    client as a JSON message with mode "2pass-offline".
    """
    if len(audio_in) == 0:
        return

    samples = load_bytes(audio_in)
    rec_result = inference_pipeline_asr(audio_in=samples,
                                        param_dict=websocket.param_dict_asr)

    has_text = 'text' in rec_result and len(rec_result["text"]) > 0
    if inference_pipeline_punc is not None and has_text:
        # Replace the raw recognition result with the punctuated one.
        rec_result = inference_pipeline_punc(text_in=rec_result['text'],
                                             param_dict=websocket.param_dict_punc)

    if 'text' not in rec_result:
        return
    reply = json.dumps({"mode": "2pass-offline", "text": rec_result["text"], "wav_name": websocket.wav_name})
    await websocket.send(reply)
async def async_asr_online(websocket, audio_in):
    """Run online recognition on buffered audio and send the partial result.

    Empty input is ignored. In "2pass" mode nothing is sent for a final chunk
    (the recogniser is still invoked). Results whose text is "sil" or
    "waiting_for_more_voice" are not sent. Sent messages use mode
    "2pass-online".
    """
    if len(audio_in) == 0:
        return

    samples = load_bytes(audio_in)
    rec_result = inference_pipeline_asr_online(audio_in=samples,
                                               param_dict=websocket.param_dict_asr_online)

    is_final = websocket.param_dict_asr_online.get("is_final", False)
    if websocket.mode == "2pass" and is_final:
        return

    if "text" not in rec_result:
        return
    if rec_result["text"] != "sil" and rec_result["text"] != "waiting_for_more_voice":
        reply = json.dumps({"mode": "2pass-online", "text": rec_result["text"], "wav_name": websocket.wav_name})
        await websocket.send(reply)
# Start the websocket server: TLS when a certificate file is configured,
# plain websocket otherwise.
if len(args.certfile) > 0:
    ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)

    # Generate with Lets Encrypt, copied to this location, chown to current user and 400 permissions
    ssl_cert = args.certfile
    ssl_key = args.keyfile

    ssl_context.load_cert_chain(ssl_cert, keyfile=ssl_key)
    start_server = websockets.serve(
        ws_serve,
        args.host,
        args.port,
        subprotocols=["binary"],
        ping_interval=None,
        ssl=ssl_context,
    )
else:
    start_server = websockets.serve(
        ws_serve,
        args.host,
        args.port,
        subprotocols=["binary"],
        ping_interval=None,
    )

event_loop = asyncio.get_event_loop()
event_loop.run_until_complete(start_server)
asyncio.get_event_loop().run_forever()