From b9bcf1f093c3053fdc4e2cf4a1d38e27bbf429fb Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Thu, 19 Oct 2023 14:03:48 +0800
Subject: [PATCH] docs

---
 funasr/runtime/python/websocket/funasr_wss_client.py |   18 ++++++++++++++++--
 1 file changed, 16 insertions(+), 2 deletions(-)

diff --git a/funasr/runtime/python/websocket/funasr_wss_client.py b/funasr/runtime/python/websocket/funasr_wss_client.py
index 72121f7..f4f35bb 100644
--- a/funasr/runtime/python/websocket/funasr_wss_client.py
+++ b/funasr/runtime/python/websocket/funasr_wss_client.py
@@ -29,6 +29,14 @@
                     type=str,
                     default="5, 10, 5",
                     help="chunk")
+parser.add_argument("--encoder_chunk_look_back",
+                    type=int,
+                    default=4,
+                    help="number of chunks to lookback for encoder self-attention")
+parser.add_argument("--decoder_chunk_look_back",
+                    type=int,
+                    default=1,
+                    help="number of encoder chunks to lookback for decoder cross-attention")
 parser.add_argument("--chunk_interval",
                     type=int,
                     default=10,
@@ -99,7 +107,8 @@
                     input=True,
                     frames_per_buffer=CHUNK)
 
-    message = json.dumps({"mode": args.mode, "chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval,
+    message = json.dumps({"mode": args.mode, "chunk_size": args.chunk_size, "encoder_chunk_look_back": args.encoder_chunk_look_back,
+                          "decoder_chunk_look_back": args.decoder_chunk_look_back, "chunk_interval": args.chunk_interval, 
                           "wav_name": "microphone", "is_speaking": True})
     #voices.put(message)
     await websocket.send(message)
@@ -204,6 +213,7 @@
         
             meg = await websocket.recv()
             meg = json.loads(meg)
+            # print(meg)
             wav_name = meg.get("wav_name", "demo")
             text = meg["text"]
 
@@ -221,7 +231,11 @@
                 # text_print = text_print[-args.words_max_print:]
                 # os.system('clear')
                 print("\rpid" + str(id) + ": " + wav_name + ": " + text_print)
-                offline_msg_done = True
+                if ("is_final" in meg and meg["is_final"]==False):
+                    offline_msg_done = True
+                
+                if not "is_final" in meg:
+                    offline_msg_done = True
             else:
                 if meg["mode"] == "2pass-online":
                     text_print_2pass_online += "{}".format(text)

--
Gitblit v1.9.1