From 94de39dde2e616a01683c518023d0fab72b4e103 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 19 二月 2024 22:21:50 +0800
Subject: [PATCH] aishell example

---
 runtime/python/websocket/funasr_wss_client.py |  107 ++++++++++++++++++++++++++++++++++++-----------------
 1 files changed, 72 insertions(+), 35 deletions(-)

diff --git a/runtime/python/websocket/funasr_wss_client.py b/runtime/python/websocket/funasr_wss_client.py
index 66b3ce0..b30964a 100644
--- a/runtime/python/websocket/funasr_wss_client.py
+++ b/runtime/python/websocket/funasr_wss_client.py
@@ -29,6 +29,14 @@
                     type=str,
                     default="5, 10, 5",
                     help="chunk")
+parser.add_argument("--encoder_chunk_look_back",
+                    type=int,
+                    default=4,  # forwarded unchanged in the first JSON message sent to the server
+                    help="number of chunks to look back for encoder self-attention")
+parser.add_argument("--decoder_chunk_look_back",
+                    type=int,
+                    default=0,  # forwarded unchanged in the first JSON message sent to the server
+                    help="number of encoder chunks to look back for decoder cross-attention")
 parser.add_argument("--chunk_interval",
                     type=int,
                     default=10,
@@ -41,6 +49,10 @@
                     type=str,
                     default=None,
                     help="audio_in")
+parser.add_argument("--audio_fs",
+                    type=int,
+                    default=16000,  # initial sample_rate; replaced by the wav header's frame rate when one is read
+                    help="sample rate of the input audio in Hz")
 parser.add_argument("--send_without_sleep",
                     action="store_true",
                     default=True,
@@ -109,25 +121,36 @@
     fst_dict = {}
     hotword_msg = ""
     if args.hotword.strip() != "":
-        f_scp = open(args.hotword)
-        hot_lines = f_scp.readlines()
-        for line in hot_lines:
-            words = line.strip().split(" ")
-            if len(words) < 2:
-                print("Please checkout format of hotwords")
-                continue
-            try:
-                fst_dict[" ".join(words[:-1])] = int(words[-1])
-            except ValueError:
-                print("Please checkout format of hotwords")
-        hotword_msg=json.dumps(fst_dict)
+        if os.path.isfile(args.hotword):  # isfile, not exists: a directory must not reach open()
+            with open(args.hotword) as f_scp:  # context manager closes the handle once the lines are read
+                hot_lines = f_scp.readlines()
+            for line in hot_lines:
+                words = line.strip().split(" ")
+                if len(words) < 2:
+                    print("Please checkout format of hotwords")
+                    continue
+                try:
+                    fst_dict[" ".join(words[:-1])] = int(words[-1])  # key: all tokens but the last; value: int(last token)
+                except ValueError:
+                    print("Please checkout format of hotwords")
+            hotword_msg = json.dumps(fst_dict)
+        else:
+            hotword_msg = args.hotword  # not a file on disk: send the argument text itself
 
-    use_itn=True
+    use_itn = True
     if args.use_itn == 0:
         use_itn=False
     
-    message = json.dumps({"mode": args.mode, "chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval,
-                          "wav_name": "microphone", "is_speaking": True, "hotwords":hotword_msg, "itn": use_itn})
+    message = json.dumps({"mode": args.mode,
+                          "chunk_size": args.chunk_size,
+                          "chunk_interval": args.chunk_interval,
+                          "encoder_chunk_look_back": args.encoder_chunk_look_back,
+                          "decoder_chunk_look_back": args.decoder_chunk_look_back,
+                          "wav_name": "microphone",
+                          "is_speaking": True,
+                          "hotwords": hotword_msg,
+                          "itn": use_itn,
+                          })
     #voices.put(message)
     await websocket.send(message)
     while True:
@@ -150,21 +173,24 @@
     fst_dict = {}
     hotword_msg = ""
     if args.hotword.strip() != "":
-        f_scp = open(args.hotword)
-        hot_lines = f_scp.readlines()
-        for line in hot_lines:
-            words = line.strip().split(" ")
-            if len(words) < 2:
-                print("Please checkout format of hotwords")
-                continue
-            try:
-                fst_dict[" ".join(words[:-1])] = int(words[-1])
-            except ValueError:
-                print("Please checkout format of hotwords")
-        hotword_msg=json.dumps(fst_dict)
+        if os.path.isfile(args.hotword):  # isfile, not exists: a directory must not reach open()
+            with open(args.hotword) as f_scp:  # context manager closes the handle once the lines are read
+                hot_lines = f_scp.readlines()
+            for line in hot_lines:
+                words = line.strip().split(" ")
+                if len(words) < 2:
+                    print("Please checkout format of hotwords")
+                    continue
+                try:
+                    fst_dict[" ".join(words[:-1])] = int(words[-1])  # key: all tokens but the last; value: int(last token)
+                except ValueError:
+                    print("Please checkout format of hotwords")
+            hotword_msg = json.dumps(fst_dict)
+        else:
+            hotword_msg = args.hotword  # not a file on disk: send the argument text itself
         print (hotword_msg)
 
-    sample_rate = 16000
+    sample_rate = args.audio_fs
     wav_format = "pcm"
     use_itn=True
     if args.use_itn == 0:
@@ -188,20 +214,28 @@
                 params = wav_file.getparams()
                 sample_rate = wav_file.getframerate()
                 frames = wav_file.readframes(wav_file.getnframes())
-                audio_bytes = bytes(frames)
+                audio_bytes = bytes(frames)        
         else:
             wav_format = "others"
             with open(wav_path, "rb") as f:
                 audio_bytes = f.read()
 
-        # stride = int(args.chunk_size/1000*16000*2)
-        stride = int(60 * args.chunk_size[1] / args.chunk_interval / 1000 * 16000 * 2)
+        stride = int(60 * args.chunk_size[1] / args.chunk_interval / 1000 * sample_rate * 2)
         chunk_num = (len(audio_bytes) - 1) // stride + 1
         # print(stride)
 
         # send first time
-        message = json.dumps({"mode": args.mode, "chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval, "audio_fs":sample_rate,
-                          "wav_name": wav_name, "wav_format": wav_format, "is_speaking": True, "hotwords":hotword_msg, "itn": use_itn})
+        message = json.dumps({"mode": args.mode,
+                              "chunk_size": args.chunk_size,
+                              "chunk_interval": args.chunk_interval,
+                              "encoder_chunk_look_back": args.encoder_chunk_look_back,
+                              "decoder_chunk_look_back": args.decoder_chunk_look_back,
+                              "audio_fs":sample_rate,
+                              "wav_name": wav_name,
+                              "wav_format": wav_format,
+                              "is_speaking": True,
+                              "hotwords": hotword_msg,
+                              "itn": use_itn})
 
         #voices.put(message)
         await websocket.send(message)
@@ -253,6 +287,7 @@
             wav_name = meg.get("wav_name", "demo")
             text = meg["text"]
             timestamp=""
+            offline_msg_done = meg.get("is_final", False)
             if "timestamp" in meg:
                 timestamp = meg["timestamp"]
 
@@ -262,7 +297,9 @@
                 else:
                     text_write_line = "{}\t{}\n".format(wav_name, text)
                 ibest_writer.write(text_write_line)
-                
+
+            if 'mode' not in meg:
+                continue
             if meg["mode"] == "online":
                 text_print += "{}".format(text)
                 text_print = text_print[-args.words_max_print:]
@@ -289,7 +326,7 @@
                 text_print = text_print[-args.words_max_print:]
                 os.system('clear')
                 print("\rpid" + str(id) + ": " + text_print)
-                offline_msg_done=True
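+                # offline_msg_done is no longer forced to True here; it is assigned from the message's "is_final" field when the message is parsed.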
 
     except Exception as e:
             print("Exception:", e)

--
Gitblit v1.9.1