From 4f224c88068b66bcb6f81570da59d99c9bba8288 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Fri, 26 Jan 2024 16:02:14 +0800
Subject: [PATCH] python-websocket funasr1.0 (#1310)

---
 runtime/python/websocket/funasr_wss_client.py |   89 ++++++++++++++++++++++++++++++--------------
 1 file changed, 60 insertions(+), 29 deletions(-)

diff --git a/runtime/python/websocket/funasr_wss_client.py b/runtime/python/websocket/funasr_wss_client.py
index a2d8889..b30964a 100644
--- a/runtime/python/websocket/funasr_wss_client.py
+++ b/runtime/python/websocket/funasr_wss_client.py
@@ -29,6 +29,14 @@
                     type=str,
                     default="5, 10, 5",
                     help="chunk")
+parser.add_argument("--encoder_chunk_look_back",
+                    type=int,
+                    default=4,
+                    help="number of chunks to look back for encoder self-attention")
+parser.add_argument("--decoder_chunk_look_back",
+                    type=int,
+                    default=0,
+                    help="number of encoder chunks to look back for decoder cross-attention")
 parser.add_argument("--chunk_interval",
                     type=int,
                     default=10,
@@ -113,25 +121,36 @@
     fst_dict = {}
     hotword_msg = ""
     if args.hotword.strip() != "":
-        f_scp = open(args.hotword)
-        hot_lines = f_scp.readlines()
-        for line in hot_lines:
-            words = line.strip().split(" ")
-            if len(words) < 2:
-                print("Please checkout format of hotwords")
-                continue
-            try:
-                fst_dict[" ".join(words[:-1])] = int(words[-1])
-            except ValueError:
-                print("Please checkout format of hotwords")
-        hotword_msg=json.dumps(fst_dict)
+        if os.path.exists(args.hotword):  # --hotword names an existing path: parse it line by line
+            with open(args.hotword) as f_scp:  # context manager closes the handle once the lines are read
+                hot_lines = f_scp.readlines()
+            for line in hot_lines:
+                words = line.strip().split(" ")
+                if len(words) < 2:
+                    print("Please checkout format of hotwords")
+                    continue
+                try:
+                    fst_dict[" ".join(words[:-1])] = int(words[-1])  # key: all tokens but the last; value: last token as int
+                except ValueError:
+                    print("Please checkout format of hotwords")
+            hotword_msg = json.dumps(fst_dict)
+        else:
+            hotword_msg = args.hotword  # not a path on disk: forward the raw argument string unchanged
 
-    use_itn=True
+    use_itn = True
     if args.use_itn == 0:
         use_itn=False
     
-    message = json.dumps({"mode": args.mode, "chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval,
-                          "wav_name": "microphone", "is_speaking": True, "hotwords":hotword_msg, "itn": use_itn})
+    message = json.dumps({"mode": args.mode,
+                          "chunk_size": args.chunk_size,
+                          "chunk_interval": args.chunk_interval,
+                          "encoder_chunk_look_back": args.encoder_chunk_look_back,
+                          "decoder_chunk_look_back": args.decoder_chunk_look_back,
+                          "wav_name": "microphone",
+                          "is_speaking": True,
+                          "hotwords": hotword_msg,
+                          "itn": use_itn,
+                          })
     #voices.put(message)
     await websocket.send(message)
     while True:
@@ -154,18 +173,21 @@
     fst_dict = {}
     hotword_msg = ""
     if args.hotword.strip() != "":
-        f_scp = open(args.hotword)
-        hot_lines = f_scp.readlines()
-        for line in hot_lines:
-            words = line.strip().split(" ")
-            if len(words) < 2:
-                print("Please checkout format of hotwords")
-                continue
-            try:
-                fst_dict[" ".join(words[:-1])] = int(words[-1])
-            except ValueError:
-                print("Please checkout format of hotwords")
-        hotword_msg=json.dumps(fst_dict)
+        if os.path.exists(args.hotword):  # --hotword names an existing path: parse it line by line
+            with open(args.hotword) as f_scp:  # context manager closes the handle once the lines are read
+                hot_lines = f_scp.readlines()
+            for line in hot_lines:
+                words = line.strip().split(" ")
+                if len(words) < 2:
+                    print("Please checkout format of hotwords")
+                    continue
+                try:
+                    fst_dict[" ".join(words[:-1])] = int(words[-1])  # key: all tokens but the last; value: last token as int
+                except ValueError:
+                    print("Please checkout format of hotwords")
+            hotword_msg = json.dumps(fst_dict)
+        else:
+            hotword_msg = args.hotword  # not a path on disk: forward the raw argument string unchanged
         print (hotword_msg)
 
     sample_rate = args.audio_fs
@@ -203,8 +225,17 @@
         # print(stride)
 
         # send first time
-        message = json.dumps({"mode": args.mode, "chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval, "audio_fs":sample_rate,
-                          "wav_name": wav_name, "wav_format": wav_format, "is_speaking": True, "hotwords":hotword_msg, "itn": use_itn})
+        message = json.dumps({"mode": args.mode,
+                              "chunk_size": args.chunk_size,
+                              "chunk_interval": args.chunk_interval,
+                              "encoder_chunk_look_back": args.encoder_chunk_look_back,
+                              "decoder_chunk_look_back": args.decoder_chunk_look_back,
+                              "audio_fs":sample_rate,
+                              "wav_name": wav_name,
+                              "wav_format": wav_format,
+                              "is_speaking": True,
+                              "hotwords": hotword_msg,
+                              "itn": use_itn})
 
         #voices.put(message)
         await websocket.send(message)

--
Gitblit v1.9.1