| | |
| | | from modelscope.utils.constant import Tasks |
| | | from modelscope.utils.logger import get_logger |
| | | import logging |
| | | import tracemalloc |
# Allocation tracing is switched on at import time.
# NOTE(review): nothing visible in this file reads a tracemalloc snapshot, and
# tracing adds overhead to every allocation of this long-running server —
# confirm it is still wanted.
tracemalloc.start()

# Only CRITICAL records from the modelscope logger are emitted.
logger = get_logger(log_level=logging.CRITICAL)
logger.setLevel(logging.CRITICAL)

# Connected client websockets. ws_serve() registers each connection here, and a
# per-connection ASR worker keeps looping while its socket is still a member.
websocket_users = set()
| | | |
| | | parser = argparse.ArgumentParser() |
| | | parser.add_argument("--host", |
| | |
args = parser.parse_args()

print("model loading")
# Process-wide hand-off queues: `voices` receives raw audio chunks from clients,
# `speek` receives VAD-segmented utterances (bytes) for the global ASR worker.
voices, speek = Queue(), Queue()
| | | |
| | | |
| | | # vad |
| | | inference_pipeline_vad = pipeline( |
| | |
| | | |
| | | |
async def ws_serve(websocket, path):
    """Serve one client: cut its audio stream into utterances and relay ASR results.

    Each connection gets its own utterance queue (``websocket.speek``), consumed
    by an ``asr(websocket)`` worker thread, and its own outgoing queue
    (``websocket.send_msg``), which this coroutine drains back to the client.

    Args:
        websocket: the client connection handed over by ``websockets.serve``.
        path: request path supplied by ``websockets.serve``; unused.
    """
    # NOTE(review): the pasted source interleaved two revisions of this handler:
    # a bare ``voices.put(message)`` forwarder and this per-connection VAD loop.
    # This keeps the per-connection one, which the queue/thread setup below is
    # built for. A zero-argument ``asr()`` is also defined later in the file; if
    # it shadows ``asr(websocket)``, the thread below fails with TypeError —
    # confirm only one definition survives.
    frames = []  # chunks of the utterance currently being recorded
    buffer = []  # the last (at most two) chunks, used as pre-roll when speech starts
    record_num = 0
    speech_start = False
    websocket.speek = Queue()  # utterances for this connection's ASR worker
    websocket.send_msg = Queue()  # transcripts waiting to be sent to the client
    websocket_users.add(websocket)
    try:
        threading.Thread(target=asr, args=(websocket,)).start()
        async for message in websocket:
            buffer.append(message)
            if len(buffer) > 2:
                buffer.pop(0)  # keep only the two most recent chunks

            if speech_start:
                frames.append(message)
                record_num += 1
            speech_start_i, speech_end_i = vad(message)
            if speech_start_i:
                speech_start = speech_start_i
                frames = list(buffer)  # start the utterance with the pre-roll chunks
            if speech_end_i or record_num > 300:
                speech_start = False
                audio_in = b"".join(frames)
                if audio_in:
                    websocket.speek.put(audio_in)
                frames = []
                buffer = []
                record_num = 0
            # Drain every pending transcript, not just one per incoming chunk,
            # so results cannot pile up behind a fast ASR worker.
            # NOTE(review): results are still only flushed when the client sends
            # data; anything produced after the last chunk is not delivered.
            while not websocket.send_msg.empty():
                await websocket.send(websocket.send_msg.get())
                websocket.send_msg.task_done()
    except websockets.ConnectionClosed:
        print("ConnectionClosed...", websocket_users)  # connection dropped
    except websockets.InvalidState:
        print("InvalidState...")  # invalid protocol state
    except Exception as e:
        print("Exception:", e)
    finally:
        # A clean close ends ``async for`` without raising, so deregistering only
        # in the ConnectionClosed handler leaked the socket and left its ASR
        # worker looping forever. Always deregister here.
        websocket_users.discard(websocket)
| | | |
| | | |
def asr(websocket):  # ASR inference (per connection)
    """Transcribe the utterances queued for one client and queue the transcripts.

    Runs until ``websocket`` is no longer in ``websocket_users``. Each item taken
    from ``websocket.speek`` is an audio utterance (bytes). Its transcript, after
    the optional punctuation pipeline, is put on ``websocket.send_msg`` for the
    websocket handler to send.

    Args:
        websocket: connection object carrying ``speek`` (input) and
            ``send_msg`` (output) queues.
    """
    import queue

    while websocket in websocket_users:
        try:
            # Wait briefly instead of polling empty() + sleep: an utterance is
            # handled as soon as it arrives, and membership is still re-checked
            # at least every 0.1 s.
            audio_in = websocket.speek.get(timeout=0.1)
        except queue.Empty:
            continue
        websocket.speek.task_done()
        if not audio_in:
            continue
        try:
            rec_result = inference_pipeline_asr(audio_in=audio_in)
            if inference_pipeline_punc is not None and 'text' in rec_result:
                rec_result = inference_pipeline_punc(text_in=rec_result['text'], param_dict=param_dict_punc)
        except Exception as e:
            # One failed utterance must not kill this connection's worker.
            print("Exception:", e)
            continue
        # Forward text only. Previously the whole result dict was queued when it
        # had no 'text' key, and websocket.send() raises TypeError for a dict,
        # which aborted the client's handler.
        if "text" in rec_result:
            websocket.send_msg.put(rec_result["text"])  # queued; send() is awaited by ws_serve
| | | |
| | | def vad(data): # VAD推理 |
| | | global vad_pipline, param_dict_vad |
| | | #print(type(data)) |
| | | # print(param_dict_vad) |
| | |
| | | speech_end = True |
| | | return speech_start, speech_end |
| | | |
def asr():  # ASR inference (global worker)
    """Transcribe utterances from the global ``speek`` queue and print the results.

    Runs forever. Each item is an audio utterance (bytes) queued by ``main()``.
    The printed value is the transcript (punctuated when
    ``inference_pipeline_punc`` is set), or the raw result when it has no
    ``'text'`` key.
    """
    while True:
        # Blocking get() instead of polling empty() + sleep: no idle wake-ups,
        # and no fixed 0.1 s delay added after every utterance.
        audio_in = speek.get()
        speek.task_done()
        if not audio_in:
            continue
        try:
            rec_result = inference_pipeline_asr(audio_in=audio_in)
            if inference_pipeline_punc is not None and 'text' in rec_result:
                rec_result = inference_pipeline_punc(text_in=rec_result['text'], param_dict=param_dict_punc)
        except Exception as e:
            # An uncaught error would end this thread, and every later utterance
            # would sit in `speek` forever.
            print("Exception:", e)
            continue
        print(rec_result["text"] if "text" in rec_result else rec_result)
| | | |
| | | |
def main(max_record_num=300):  # VAD segmentation worker
    """Cut the global ``voices`` chunk stream into utterances and queue them on ``speek``.

    Every chunk is passed to ``vad(data)``, which is expected to return a
    ``(speech_start, speech_end)`` pair. When speech starts, recording begins
    with the two most recent chunks as pre-roll. When VAD reports an end, or
    more than ``max_record_num`` chunks have been recorded, the joined bytes are
    queued on ``speek``.

    Args:
        max_record_num: force-flush an utterance after this many recorded
            chunks even without a VAD end (default 300, the original limit).
    """
    frames = []  # chunks of the utterance being recorded
    buffer = []  # the last (at most two) chunks, used as pre-roll
    record_num = 0
    speech_start = False
    while True:
        # Blocking get() instead of a 10 ms empty()/sleep polling loop.
        data = voices.get()
        voices.task_done()
        buffer.append(data)
        if len(buffer) > 2:
            buffer.pop(0)  # keep only the two most recent chunks

        if speech_start:
            frames.append(data)
            record_num += 1
        speech_start_i, speech_end_i = vad(data)
        if speech_start_i:
            speech_start = speech_start_i
            frames = list(buffer)  # start the utterance with the pre-roll chunks
        if speech_end_i or record_num > max_record_num:
            speech_start = False
            audio_in = b"".join(frames)
            if audio_in:  # a VAD end without a recorded start yields no audio
                speek.put(audio_in)
            frames = []
            buffer = []
            record_num = 0
| | | |
| | | |
| | | |
# Start the background workers, then serve websocket clients forever.
# The workers loop forever, so they are daemon threads: otherwise Ctrl+C on the
# event loop leaves the interpreter waiting on them and the process hangs.
vad_thread = threading.Thread(target=main, daemon=True)
vad_thread.start()
asr_thread = threading.Thread(target=asr, daemon=True)
asr_thread.start()

# Create and install the event loop explicitly. asyncio.get_event_loop() with no
# current loop is deprecated and raises RuntimeError on Python 3.14. The loop is
# installed before websockets.serve() is called so the server binds to it.
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None)
loop.run_until_complete(start_server)
loop.run_forever()