zhifu gao
2023-05-06 b639e136b4a37400e569fbd0e64dc5aed066ce33
Merge pull request #466 from zhaomingwork/fix_bug_for_python_websocket

fix bug for python websocket
2 files changed, 78 lines modified
funasr/runtime/python/websocket/ws_client.py | 48 lines changed (patch | view | raw | blame | history)
funasr/runtime/python/websocket/ws_server_online.py | 30 lines changed (patch | view | raw | blame | history)
funasr/runtime/python/websocket/ws_client.py
@@ -6,7 +6,8 @@
# import threading
import argparse
import json
import traceback
from multiprocessing import  Process
parser = argparse.ArgumentParser()
parser.add_argument("--host",
                    type=str,
@@ -30,6 +31,11 @@
                    type=str,
                    default=None,
                    help="audio_in")
parser.add_argument(
    "--test_thread_num",
    type=int,
    default=1,
    help="test_thread_num",
)
args = parser.parse_args()
# --chunk_size is given as a comma-separated string; turn it into a list of ints.
args.chunk_size = list(map(int, args.chunk_size.split(",")))
@@ -129,12 +135,14 @@
                await websocket.send(data) # 通过ws对象发送数据
            except Exception as e:
                print('Exception occurred:', e)
                traceback.print_exc()
                exit(0)
            await asyncio.sleep(0.005)
        await asyncio.sleep(0.005)
async def message():
async def message(id):
    global websocket
    text_print = ""
    while True:
@@ -143,14 +151,14 @@
            meg = json.loads(meg)
            # print(meg, end = '')
            # print("\r")
            text = meg["text"][0]
            text_print += text
            text_print += " {}".format(meg["text"][0])
            text_print = text_print[-55:]
            os.system('clear')
            print("\r"+text_print)
            #os.system('clear')
            print("\r"+str(id)+":"+text_print)
        except Exception as e:
            print("Exception:", e)
            traceback.print_exc()
            exit(0)
async def print_messge():
    global websocket
@@ -161,9 +169,10 @@
            print(meg)
        except Exception as e:
            print("Exception:", e)
            traceback.print_exc()
            exit(0)
async def ws_client():
async def ws_client(id):
    global websocket # 定义一个全局变量ws,用于保存websocket连接对象
    # uri = "ws://11.167.134.197:8899"
    uri = "ws://{}:{}".format(args.host, args.port)
@@ -174,9 +183,24 @@
        else:
            task = asyncio.create_task(record_microphone())  # 创建一个后台任务录音
        task2 = asyncio.create_task(ws_send()) # 创建一个后台任务发送
        task3 = asyncio.create_task(message()) # 创建一个后台接收消息的任务
        task3 = asyncio.create_task(message(id)) # 创建一个后台接收消息的任务
        await asyncio.gather(task, task2, task3)
def one_thread(id):
    """Run one websocket test client inside the current (child) process.

    Used as the ``multiprocessing.Process`` target by the ``__main__`` launcher,
    which passes the process index as ``id``. The id is forwarded to
    ``ws_client`` so each client's printed transcript is prefixed with it.

    A dedicated event loop is created and installed explicitly. Relying on
    ``asyncio.get_event_loop()`` to create one implicitly when no loop is
    current is deprecated, and an error on newer Python versions.
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(ws_client(id))  # start the client coroutine
    # Kept from the original: keep the loop running after ws_client returns.
    loop.run_forever()
# NOTE(review): these two statements look like the pre-change single-client
# entry point. They call ws_client() with no argument, but ws_client now takes
# an id, so running them would raise TypeError. They would also run at import
# time, ahead of the multi-process launcher below. Confirm they are removed.
asyncio.get_event_loop().run_until_complete(ws_client()) # start the coroutine
asyncio.get_event_loop().run_forever()
if __name__ == '__main__':
    # Launch one client process per requested test "thread"; each process
    # receives its index as the client id.
    process_list = []
    for i in range(args.test_thread_num):
        p = Process(target=one_thread, args=(i,))  # create the process object
        p.start()
        process_list.append(p)
    # Wait for every client process. The previous loop joined only the last
    # started process, so the final message could print while others still ran.
    # NOTE(review): one_thread keeps its event loop running (run_forever), so
    # these joins may block until the child processes exit or are terminated.
    for p in process_list:
        p.join()
    print('结束测试')  # "test finished"
funasr/runtime/python/websocket/ws_server_online.py
@@ -41,8 +41,6 @@
    websocket_users.add(websocket)
    websocket.param_dict_asr_online = {"cache": dict()}
    websocket.speek_online = Queue()
    ss_online = threading.Thread(target=asr_online, args=(websocket,))
    ss_online.start()
    try:
        async for message in websocket:
@@ -56,18 +54,12 @@
                websocket.param_dict_asr_online["chunk_size"] = message["chunk_size"]
                
                frames_online.append(audio)
                if len(frames_online) % message["chunk_interval"] == 0 or not is_speaking:
                    audio_in = b"".join(frames_online)
                    websocket.speek_online.put(audio_in)
                    await async_asr_online(websocket,audio_in)
                    frames_online = []
            if not websocket.send_msg.empty():
                await websocket.send(websocket.send_msg.get())
                websocket.send_msg.task_done()
     
    except websockets.ConnectionClosed:
@@ -78,29 +70,21 @@
    except Exception as e:
        print("Exception:", e)
 
def asr_online(websocket):  # ASR推理
    global websocket_users
    while websocket in websocket_users:
        if not websocket.speek_online.empty():
            audio_in = websocket.speek_online.get()
            websocket.speek_online.task_done()
async def async_asr_online(websocket,audio_in): # ASR推理
            if len(audio_in) > 0:
                # print(len(audio_in))
                audio_in = load_bytes(audio_in)
                rec_result = inference_pipeline_asr_online(audio_in=audio_in,
                                                           param_dict=websocket.param_dict_asr_online)
                if websocket.param_dict_asr_online["is_final"]:
                    websocket.param_dict_asr_online["cache"] = dict()
                if "text" in rec_result:
                    if rec_result["text"] != "sil" and rec_result["text"] != "waiting_for_more_voice":
                        print(rec_result["text"])
                        if len(rec_result["text"])>0:
                            rec_result["text"][0]=rec_result["text"][0].replace(" ","")
                        message = json.dumps({"mode": "online", "text": rec_result["text"]})
                        websocket.send_msg.put(message)
        time.sleep(0.005)
                        await websocket.send(message)
# Server bound to args.host/args.port with ws_serve passed as the handler;
# the "binary" subprotocol is offered and ping_interval is set to None
# (NOTE(review): presumably disables keepalive pings — confirm in websockets docs).
start_server = websockets.serve(
    ws_serve,
    args.host,
    args.port,
    subprotocols=["binary"],
    ping_interval=None,
)