游雁
2023-04-27 ba8d73d57db031fa7a1265d2c837ff694d5c5c93
websocket
1个文件已修改
2个文件已添加
1个文件已重命名
4个文件已删除
1053 ■■■■ 已修改文件
.gitignore 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/websocket/ASR_server.py 187 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/websocket/ASR_server_2pass.py 252 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/websocket/ASR_server_streaming.py 261 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/websocket/ASR_server_streaming_asr.py 161 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/websocket/parse_args.py 35 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/websocket/ws_client.py 48 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/websocket/ws_server_online.py 108 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
.gitignore
@@ -17,3 +17,4 @@
dist
build
funasr.egg-info
sherpa
funasr/runtime/python/websocket/ASR_server.py
File was deleted
funasr/runtime/python/websocket/ASR_server_2pass.py
File was deleted
funasr/runtime/python/websocket/ASR_server_streaming.py
File was deleted
funasr/runtime/python/websocket/ASR_server_streaming_asr.py
File was deleted
funasr/runtime/python/websocket/parse_args.py
New file
@@ -0,0 +1,35 @@
# -*- encoding: utf-8 -*-
import argparse
# Command-line options for the websocket ASR server; ws_server_online.py
# consumes the parsed result via ``from parse_args import args``.
parser = argparse.ArgumentParser()
parser.add_argument("--host",
                    type=str,
                    default="0.0.0.0",
                    required=False,
                    help="host ip, localhost, 0.0.0.0")
# The port is bound by websockets.serve(), not a gRPC server.
parser.add_argument("--port",
                    type=int,
                    default=10095,
                    required=False,
                    help="websocket server port")
# Offline (non-streaming) ASR model id.
parser.add_argument("--asr_model",
                    type=str,
                    default="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                    help="model from modelscope")
# Streaming (online) ASR model id, used by ws_server_online.py.
parser.add_argument("--asr_model_online",
                    type=str,
                    default="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
                    help="model from modelscope")
# Voice-activity-detection model id.
parser.add_argument("--vad_model",
                    type=str,
                    default="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                    help="model from modelscope")
# Punctuation-restoration model id.
parser.add_argument("--punc_model",
                    type=str,
                    default="damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
                    help="model from modelscope")
parser.add_argument("--ngpu",
                    type=int,
                    default=1,
                    help="0 for cpu, 1 for gpu")
# Parsed once at import time: importing this module reads sys.argv and
# exits with a usage error on unknown options.
args = parser.parse_args()
funasr/runtime/python/websocket/ws_client.py
File was renamed from funasr/runtime/python/websocket/ASR_client.py
@@ -1,4 +1,5 @@
# -*- encoding: utf-8 -*-
import os
import time
import websockets
import asyncio
@@ -18,15 +19,20 @@
                    required=False,
                    help="grpc server port")
parser.add_argument("--chunk_size",
                    type=str,
                    default="5, 10, 5",
                    help="chunk")
parser.add_argument("--chunk_interval",
                    type=int,
                    default=300,
                    help="ms")
                    default=10,
                    help="chunk")
parser.add_argument("--audio_in",
                    type=str,
                    default=None,
                    help="audio_in")
args = parser.parse_args()
args.chunk_size = [int(x) for x in args.chunk_size.split(",")]
# voices = asyncio.Queue()
from queue import Queue
@@ -34,13 +40,15 @@
    
# 其他函数可以通过调用send(data)来发送数据,例如:
async def record_microphone():
    is_finished = False
    import pyaudio
    #print("2")
    global voices 
    FORMAT = pyaudio.paInt16
    CHANNELS = 1
    RATE = 16000
    CHUNK = int(RATE / 1000 * args.chunk_size)
    chunk_size = 60*args.chunk_size[1]/args.chunk_interval
    CHUNK = int(RATE / 1000 * chunk_size)
    p = pyaudio.PyAudio()
@@ -54,7 +62,7 @@
        data = stream.read(CHUNK)
        data = data.decode('ISO-8859-1')
        message = json.dumps({"chunk": args.chunk_size, "is_speaking": is_speaking, "audio": data})
        message = json.dumps({"chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval, "audio": data, "is_speaking": is_speaking, "is_finished": is_finished})
        
        voices.put(message)
        #print(voices.qsize())
@@ -65,6 +73,7 @@
async def record_from_scp():
    import wave
    global voices
    is_finished = False
    if args.audio_in.endswith(".scp"):
        f_scp = open(args.audio_in)
        wavs = f_scp.readlines()
@@ -86,9 +95,10 @@
        # 将音频帧数据转换为字节类型的数据
        audio_bytes = bytes(frames)
        stride = int(args.chunk_size/1000*16000*2)
        # stride = int(args.chunk_size/1000*16000*2)
        stride = int(60*args.chunk_size[1]/args.chunk_interval/1000*16000*2)
        chunk_num = (len(audio_bytes)-1)//stride + 1
        print(stride)
        # print(stride)
        is_speaking = True
        for i in range(chunk_num):
            if i == chunk_num-1:
@@ -96,13 +106,16 @@
            beg = i*stride
            data = audio_bytes[beg:beg+stride]
            data = data.decode('ISO-8859-1')
            message = json.dumps({"chunk": args.chunk_size, "is_speaking": is_speaking, "audio": data})
            message = json.dumps({"chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval, "is_speaking": is_speaking, "audio": data, "is_finished": is_finished})
            voices.put(message)
            # print("data_chunk: ", len(data_chunk))
            # print(voices.qsize())
        
            await asyncio.sleep(args.chunk_size/1000)
            await asyncio.sleep(60*args.chunk_size[1]/args.chunk_interval/1000)
     
    is_finished = True
    message = json.dumps({"is_finished": is_finished})
    voices.put(message)
async def ws_send():
    global voices
@@ -123,6 +136,24 @@
async def message():
    """Receive recognition results and show a rolling transcript.

    Reads JSON messages from the module-level ``websocket`` connection and
    appends ``meg["text"][0]`` to a buffer trimmed to its last 55 characters.
    After each append it clears the terminal and reprints the buffer.
    Returns once the connection is closed, instead of spinning forever on the
    same ``ConnectionClosed`` error.
    """
    global websocket
    text_print = ""
    while True:
        try:
            meg = await websocket.recv()
        except websockets.ConnectionClosed as e:
            # recv() raises again on every call once the socket is closed;
            # looping would print this exception endlessly and hog the loop.
            print("Exception:", e)
            break
        try:
            meg = json.loads(meg)
            # NOTE(review): this keeps element 0 of "text". The server side
            # sends rec_result["text"]; if that is a plain str, only its first
            # character is kept here. Confirm the payload type.
            text = meg["text"][0]
            text_print += text
            text_print = text_print[-55:]
            os.system('clear')
            print("\r"+text_print)
        except Exception as e:
            # A malformed message should not stop the receiver.
            print("Exception:", e)
async def print_messge():
    global websocket
    while True:
        try:
            meg = await websocket.recv()
@@ -130,7 +161,6 @@
            print(meg)
        except Exception as e:
            print("Exception:", e)          
async def ws_client():
funasr/runtime/python/websocket/ws_server_online.py
New file
@@ -0,0 +1,108 @@
import asyncio
import json
import websockets
import time
from queue import Queue
import threading
import logging
import tracemalloc
import numpy as np
from parse_args import args
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from funasr_onnx.utils.frontend import load_bytes
# Allocation tracing is switched on before anything else is created.
# NOTE(review): tracemalloc adds overhead to every allocation and nothing in
# this file reads its snapshots; confirm it is still wanted outside debugging.
tracemalloc.start()

# Quiet the logger obtained from modelscope's get_logger: CRITICAL only.
logger = get_logger(log_level=logging.CRITICAL)
logger.setLevel(logging.CRITICAL)

# Connections currently being served. Each connection's asr_online worker
# thread keeps running only while its websocket is a member of this set.
websocket_users = set()

# Load the online ASR pipeline (model id from --asr_model_online) once; every
# connection's worker thread calls this same object.
print("model loading")
inference_pipeline_asr_online = pipeline(
    model_revision='v1.0.4',
    model=args.asr_model_online,
    task=Tasks.auto_speech_recognition,
)
print("model loaded")
async def _forward_results(websocket):
    """Push queued recognition results to one client as soon as they appear.

    ``asr_online`` (running in a worker thread) puts JSON strings on
    ``websocket.send_msg``. This coroutine is the only consumer of that queue,
    so the ``empty()`` check followed by ``get()`` never blocks the event loop.
    It polls every 10 ms. It returns once the connection has left
    ``websocket_users`` or the socket is closed.
    """
    try:
        while websocket in websocket_users:
            while not websocket.send_msg.empty():
                await websocket.send(websocket.send_msg.get())
                websocket.send_msg.task_done()
            await asyncio.sleep(0.01)
    except websockets.ConnectionClosed:
        # ws_serve's receive loop reports the closure and unregisters the client.
        pass


async def ws_serve(websocket, path):
    """Handle one streaming-ASR websocket connection.

    Each incoming frame is JSON. While ``is_finished`` is false, the frame
    carries ``audio`` (raw bytes encoded as an ISO-8859-1 string),
    ``is_speaking``, ``chunk_size`` and ``chunk_interval``.

    Audio frames are buffered and handed to the ``asr_online`` worker thread
    every ``chunk_interval`` frames, or immediately when ``is_speaking`` turns
    false. Results come back through ``websocket.send_msg``. A background task
    sends them, so they reach the client even when no further audio arrives.

    Args:
        websocket: the client connection.
        path: request path (unused; part of the handler signature that
            ``websockets.serve`` calls).
    """
    frames_online = []
    global websocket_users
    websocket.send_msg = Queue()
    # Register before starting the worker: asr_online runs only while the
    # websocket is in websocket_users.
    websocket_users.add(websocket)
    websocket.param_dict_asr_online = {"cache": dict()}
    websocket.speek_online = Queue()
    ss_online = threading.Thread(target=asr_online, args=(websocket,))
    ss_online.start()
    # Deliver results independently of incoming traffic. The last result
    # (after is_speaking turns false) is otherwise never sent.
    sender = asyncio.ensure_future(_forward_results(websocket))
    try:
        async for message in websocket:
            message = json.loads(message)
            is_finished = message["is_finished"]
            if not is_finished:
                audio = bytes(message['audio'], 'ISO-8859-1')
                is_speaking = message["is_speaking"]
                websocket.param_dict_asr_online["is_final"] = not is_speaking
                websocket.param_dict_asr_online["chunk_size"] = message["chunk_size"]
                frames_online.append(audio)
                if len(frames_online) % message["chunk_interval"] == 0 or not is_speaking:
                    audio_in = b"".join(frames_online)
                    websocket.speek_online.put(audio_in)
                    frames_online = []
    except websockets.ConnectionClosed:
        print("ConnectionClosed...", websocket_users)    # connection closed
    except websockets.InvalidState:
        print("InvalidState...")    # invalid connection state
    except Exception as e:
        print("Exception:", e)
    finally:
        # Unregister on every exit path. A clean close ends the async-for
        # without raising, so the worker thread and the sender task would
        # otherwise run forever.
        websocket_users.discard(websocket)
        sender.cancel()
def asr_online(websocket):  # ASR inference worker
    """Run streaming recognition for one connection in a worker thread.

    Takes audio segments that ``ws_serve`` queues on ``websocket.speek_online``.
    Each segment is fed to ``inference_pipeline_asr_online`` together with the
    connection's ``param_dict_asr_online``, whose ``cache`` entry carries
    state between calls.

    For every result whose text is not "sil" or "waiting_for_more_voice", a
    JSON ``{"mode": "online", "text": ...}`` message is put on
    ``websocket.send_msg``. Returns once the connection is no longer in
    ``websocket_users``.
    """
    from queue import Empty

    global websocket_users
    while websocket in websocket_users:
        try:
            # Block briefly instead of polling empty(): this wakes as soon as
            # audio arrives, yet still re-checks websocket_users every 5 ms so
            # the thread exits soon after the client disconnects.
            audio_in = websocket.speek_online.get(timeout=0.005)
        except Empty:
            continue
        websocket.speek_online.task_done()
        if len(audio_in) == 0:
            continue
        try:
            rec_result = inference_pipeline_asr_online(
                audio_in=load_bytes(audio_in),
                param_dict=websocket.param_dict_asr_online)
        except Exception as e:
            # Keep serving this connection. An uncaught error would end the
            # thread silently and the client would get no further results.
            print("Exception:", e)
            rec_result = {}
        if websocket.param_dict_asr_online["is_final"]:
            # The client stopped speaking: start the next utterance with a
            # fresh cache.
            websocket.param_dict_asr_online["cache"] = dict()
        if "text" in rec_result:
            text = rec_result["text"]
            if text not in ("sil", "waiting_for_more_voice"):
                print(text)
                message = json.dumps({"mode": "online", "text": text})
                websocket.send_msg.put(message)
# Start the websocket server and serve forever.
# asyncio.get_event_loop() with no current loop is deprecated and no longer
# creates one on recent Python versions. So create and install a loop
# explicitly, before websockets.serve() is called, so that any current-loop
# lookup it performs finds this same loop.
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None)
loop.run_until_complete(start_server)
loop.run_forever()