游雁
2023-05-13 9dad49c3a1f2495384bab4cc3763e4f8a461da00
funasr/runtime/python/websocket/ws_server_2pass.py
@@ -74,47 +74,54 @@
    websocket.param_dict_punc = {'cache': list()}
    websocket.vad_pre_idx = 0
    speech_start = False
    websocket.wav_name = "microphone"
    print("new user connected", flush=True)
    try:
        async for message in websocket:
            message = json.loads(message)
            is_finished = message["is_finished"]
            if not is_finished:
                audio = bytes(message['audio'], 'ISO-8859-1')
                frames.append(audio)
                duration_ms = len(audio)//32
                websocket.vad_pre_idx += duration_ms
                is_speaking = message["is_speaking"]
                websocket.param_dict_vad["is_final"] = not is_speaking
                websocket.param_dict_asr_online["is_final"] = not is_speaking
                websocket.param_dict_asr_online["chunk_size"] = message["chunk_size"]
                websocket.wav_name = message.get("wav_name", "demo")
                # asr online
                frames_asr_online.append(audio)
                if len(frames_asr_online) % message["chunk_interval"] == 0:
                    audio_in = b"".join(frames_asr_online)
                    await async_asr_online(websocket, audio_in)
                    frames_asr_online = []
                if speech_start:
                    frames_asr.append(audio)
                # vad online
                speech_start_i, speech_end_i = await async_vad(websocket, audio)
                if speech_start_i:
                    speech_start = True
                    beg_bias = (websocket.vad_pre_idx-speech_start_i)//duration_ms
                    frames_pre = frames[-beg_bias:]
                    frames_asr = []
                    frames_asr.extend(frames_pre)
            if isinstance(message, str):
                messagejson = json.loads(message)
                if "is_speaking" in messagejson:
                    websocket.is_speaking = messagejson["is_speaking"]
                    websocket.param_dict_asr_online["is_final"] = not websocket.is_speaking
                if "chunk_interval" in messagejson:
                    websocket.chunk_interval = messagejson["chunk_interval"]
                if "wav_name" in messagejson:
                    websocket.wav_name = messagejson.get("wav_name")
                if "chunk_size" in messagejson:
                    websocket.param_dict_asr_online["chunk_size"] = messagejson["chunk_size"]
            if len(frames_asr_online) > 0 or len(frames_asr) > 0 or not isinstance(message, str):
                if not isinstance(message, str):
                    frames.append(message)
                    duration_ms = len(message)//32
                    websocket.vad_pre_idx += duration_ms
                    # asr online
                    frames_asr_online.append(message)
                    if len(frames_asr_online) % websocket.chunk_interval == 0:
                        audio_in = b"".join(frames_asr_online)
                        await async_asr_online(websocket, audio_in)
                        frames_asr_online = []
                    if speech_start:
                        frames_asr.append(message)
                    # vad online
                    speech_start_i, speech_end_i = await async_vad(websocket, message)
                    if speech_start_i:
                        speech_start = True
                        beg_bias = (websocket.vad_pre_idx-speech_start_i)//duration_ms
                        frames_pre = frames[-beg_bias:]
                        frames_asr = []
                        frames_asr.extend(frames_pre)
                # asr punc offline
                if speech_end_i or not is_speaking:
                if speech_end_i or not websocket.is_speaking:
                    audio_in = b"".join(frames_asr)
                    await async_asr(websocket, audio_in)
                    frames_asr = []
                    speech_start = False
                    frames_asr_online = []
                    websocket.param_dict_asr_online = {"cache": dict()}
                    if not is_speaking:
                    if not websocket.is_speaking:
                        websocket.vad_pre_idx = 0
                        frames = []
                        websocket.param_dict_vad = {'in_cache': dict()}
@@ -168,7 +175,7 @@
        audio_in = load_bytes(audio_in)
        rec_result = inference_pipeline_asr_online(audio_in=audio_in,
                                                   param_dict=websocket.param_dict_asr_online)
        if websocket.param_dict_asr_online["is_final"]:
        if websocket.param_dict_asr_online.get("is_final", False):
            websocket.param_dict_asr_online["cache"] = dict()
        if "text" in rec_result:
            if rec_result["text"] != "sil" and rec_result["text"] != "waiting_for_more_voice":