游雁
2023-03-24 d14855ef20b691130a49228f2afc0889fe0b5905
funasr/runtime/python/websocket/ASR_server.py
@@ -62,7 +62,7 @@
    mode='online',
    ngpu=args.ngpu,
)
param_dict_vad = {'in_cache': dict(), "is_final": False}
# param_dict_vad = {'in_cache': dict(), "is_final": False}
  
# asr
param_dict_asr = {}
@@ -74,7 +74,7 @@
    ngpu=args.ngpu,
)
if args.punc_model != "":
    param_dict_punc = {'cache': list()}
    # param_dict_punc = {'cache': list()}
    inference_pipeline_punc = pipeline(
        task=Tasks.punctuation,
        model=args.punc_model,
@@ -96,6 +96,8 @@
    global websocket_users
    speech_start, speech_end = False, False
    # Call the ASR function
    websocket.param_dict_vad = {'in_cache': dict(), "is_final": False}
    websocket.param_dict_punc = {'cache': list()}
    websocket.speek = Queue()  #websocket 添加进队列对象 让asr读取语音数据包
    websocket.send_msg = Queue()   #websocket 添加个队列对象  让ws发送消息到客户端
    websocket_users.add(websocket)
@@ -114,7 +116,7 @@
            if speech_start:
                frames.append(message)
                RECORD_NUM += 1
            speech_start_i, speech_end_i = vad(message)
            speech_start_i, speech_end_i = vad(message, websocket)
            #print(speech_start_i, speech_end_i)
            if speech_start_i:
                speech_start = speech_start_i
@@ -143,7 +145,7 @@
def asr(websocket):  # ASR推理
        global inference_pipeline2
        global param_dict_punc
        # global param_dict_punc
        global websocket_users
        while websocket in  websocket_users:
            if not websocket.speek.empty():
@@ -152,17 +154,18 @@
                if len(audio_in) > 0:
                    rec_result = inference_pipeline_asr(audio_in=audio_in)
                    if inference_pipeline_punc is not None and 'text' in rec_result:
                        rec_result = inference_pipeline_punc(text_in=rec_result['text'], param_dict=param_dict_punc)
                        rec_result = inference_pipeline_punc(text_in=rec_result['text'], param_dict=websocket.param_dict_punc)
                    # print(rec_result)
                    if "text" in rec_result:
                        websocket.send_msg.put(rec_result["text"]) # 存入发送队列  直接调用send发送不了
               
            time.sleep(0.1)
def vad(data):  # VAD推理
def vad(data, websocket):  # VAD推理
    global vad_pipline, param_dict_vad
    #print(type(data))
    # print(param_dict_vad)
    segments_result = inference_pipeline_vad(audio_in=data, param_dict=param_dict_vad)
    segments_result = inference_pipeline_vad(audio_in=data, param_dict=websocket.param_dict_vad)
    # print(segments_result)
    # print(param_dict_vad)
    speech_start = False