zhifu gao
2023-10-16 1d7bbbffb6a024a33859b48a7a656d0455dc0be1
funasr/runtime/python/websocket/funasr_wss_client.py
@@ -12,7 +12,6 @@
import logging
# Audio file extensions associated with this client.
# NOTE(review): the membership check against this list appears to be
# commented out in the file-list parsing code; confirm whether it is
# still meant to be enforced.
SUPPORT_AUDIO_TYPE_SETS = [
    '.wav',
    '.pcm',
]

# Root logger emits only ERROR and above.
logging.basicConfig(level=logging.ERROR)

# Command-line parser; options are registered on it below.
parser = argparse.ArgumentParser()
@@ -30,6 +29,14 @@
                    type=str,
                    default="5, 10, 5",
                    help="chunk")
# Look-back windows (counted in chunks). Both values are copied into the
# JSON session message the client sends to the server at stream start.
parser.add_argument("--encoder_chunk_look_back", type=int, default=4,
                    help="number of chunks to lookback for encoder self-attention")
parser.add_argument("--decoder_chunk_look_back", type=int, default=1,
                    help="number of encoder chunks to lookback for decoder cross-attention")
parser.add_argument("--chunk_interval",
                    type=int,
                    default=10,
@@ -100,7 +107,8 @@
                    input=True,
                    frames_per_buffer=CHUNK)
    message = json.dumps({"mode": args.mode, "chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval,
    message = json.dumps({"mode": args.mode, "chunk_size": args.chunk_size, "encoder_chunk_look_back": args.encoder_chunk_look_back,
                          "decoder_chunk_look_back": args.decoder_chunk_look_back, "chunk_interval": args.chunk_interval,
                          "wav_name": "microphone", "is_speaking": True})
    #voices.put(message)
    await websocket.send(message)
@@ -197,7 +205,7 @@
    text_print_2pass_online = ""
    text_print_2pass_offline = ""
    if args.output_dir is not None:
        ibest_writer = open(os.path.join(args.output_dir, "text.{}".format(id)), "w+", encoding="utf-8")
        ibest_writer = open(os.path.join(args.output_dir, "text.{}".format(id)), "a", encoding="utf-8")
    else:
        ibest_writer = None
    try:
@@ -205,6 +213,7 @@
        
            meg = await websocket.recv()
            meg = json.loads(meg)
            # print(meg)
            wav_name = meg.get("wav_name", "demo")
            text = meg["text"]
@@ -219,10 +228,14 @@
                print("\rpid" + str(id) + ": " + text_print)
            elif meg["mode"] == "offline":
                text_print += "{}".format(text)
                text_print = text_print[-args.words_max_print:]
                # text_print = text_print[-args.words_max_print:]
                # os.system('clear')
                print("\rpid" + str(id) + ": " + text_print)
                offline_msg_done=True
                print("\rpid" + str(id) + ": " + wav_name + ": " + text_print)
                if ("is_final" in meg and meg["is_final"]==False):
                    offline_msg_done = True
                if not "is_final" in meg:
                    offline_msg_done = True
            else:
                if meg["mode"] == "2pass-online":
                    text_print_2pass_online += "{}".format(text)
@@ -295,9 +308,7 @@
            wav_name = wav_splits[0] if len(wav_splits) > 1 else "demo"
            wav_path = wav_splits[1] if len(wav_splits) > 1 else wav_splits[0]
            audio_type = os.path.splitext(wav_path)[-1].lower()
            # if audio_type not in SUPPORT_AUDIO_TYPE_SETS:
            #    raise NotImplementedError(
            #        f'Not supported audio type: {audio_type}')
        total_len = len(wavs)
        if total_len >= args.thread_num: