From 97eb2bd6568eda1fdce93e074f2cb83412163385 Mon Sep 17 00:00:00 2001
From: haoneng.lhn <haoneng.lhn@alibaba-inc.com>
Date: 星期一, 25 九月 2023 17:45:03 +0800
Subject: [PATCH] update paraformer online python websocket code

---
 funasr/runtime/python/websocket/funasr_wss_client.py |   11 ++++++++++-
 funasr/runtime/python/websocket/funasr_wss_server.py |    8 ++++++--
 funasr/runtime/python/websocket/funasr_client_api.py |    5 +++--
 3 files changed, 19 insertions(+), 5 deletions(-)

diff --git a/funasr/runtime/python/websocket/funasr_client_api.py b/funasr/runtime/python/websocket/funasr_client_api.py
index aa573c0..d0992ff 100644
--- a/funasr/runtime/python/websocket/funasr_client_api.py
+++ b/funasr/runtime/python/websocket/funasr_client_api.py
@@ -51,7 +51,8 @@
         stride = int(60 *  chunk_size[1]/  chunk_interval / 1000 * 16000 * 2)
         chunk_num = (len(audio_bytes) - 1) // stride + 1
        
-        message = json.dumps({"mode":  mode, "chunk_size":  chunk_size, "chunk_interval":  chunk_interval,
+        message = json.dumps({"mode": args.mode, "chunk_size": args.chunk_size, "encoder_chunk_look_back": 4,
+                              "decoder_chunk_look_back": 1, "chunk_interval": args.chunk_interval, 
                               "wav_name": wav_name, "is_speaking": True})
  
         self.websocket.send(message)
@@ -131,4 +132,4 @@
     print("text",text)
  
     
-            
\ No newline at end of file
+            
diff --git a/funasr/runtime/python/websocket/funasr_wss_client.py b/funasr/runtime/python/websocket/funasr_wss_client.py
index bed0081..f4f35bb 100644
--- a/funasr/runtime/python/websocket/funasr_wss_client.py
+++ b/funasr/runtime/python/websocket/funasr_wss_client.py
@@ -29,6 +29,14 @@
                     type=str,
                     default="5, 10, 5",
                     help="chunk")
+parser.add_argument("--encoder_chunk_look_back",
+                    type=int,
+                    default=4,
+                    help="number of chunks to lookback for encoder self-attention")
+parser.add_argument("--decoder_chunk_look_back",
+                    type=int,
+                    default=1,
+                    help="number of encoder chunks to lookback for decoder cross-attention")
 parser.add_argument("--chunk_interval",
                     type=int,
                     default=10,
@@ -99,7 +107,8 @@
                     input=True,
                     frames_per_buffer=CHUNK)
 
-    message = json.dumps({"mode": args.mode, "chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval,
+    message = json.dumps({"mode": args.mode, "chunk_size": args.chunk_size, "encoder_chunk_look_back": args.encoder_chunk_look_back,
+                          "decoder_chunk_look_back": args.decoder_chunk_look_back, "chunk_interval": args.chunk_interval, 
                           "wav_name": "microphone", "is_speaking": True})
     #voices.put(message)
     await websocket.send(message)
diff --git a/funasr/runtime/python/websocket/funasr_wss_server.py b/funasr/runtime/python/websocket/funasr_wss_server.py
index 57f7584..b74f225 100644
--- a/funasr/runtime/python/websocket/funasr_wss_server.py
+++ b/funasr/runtime/python/websocket/funasr_wss_server.py
@@ -103,8 +103,8 @@
     model=args.asr_model_online,
     ngpu=args.ngpu,
     ncpu=args.ncpu,
-    model_revision='v1.0.4',
-    update_model='v1.0.4',
+    model_revision='v1.0.7',
+    update_model='v1.0.7',
     mode='paraformer_streaming')
 
 print("model loaded! only support one client at the same time now!!!!")
@@ -159,6 +159,10 @@
                     websocket.wav_name = messagejson.get("wav_name")
                 if "chunk_size" in messagejson:
                     websocket.param_dict_asr_online["chunk_size"] = messagejson["chunk_size"]
+                if "encoder_chunk_look_back" in messagejson:
+                    websocket.param_dict_asr_online["encoder_chunk_look_back"] = messagejson["encoder_chunk_look_back"]
+                if "decoder_chunk_look_back" in messagejson:
+                    websocket.param_dict_asr_online["decoder_chunk_look_back"] = messagejson["decoder_chunk_look_back"]
                 if "mode" in messagejson:
                     websocket.mode = messagejson["mode"]
             if len(frames_asr_online) > 0 or len(frames_asr) > 0 or not isinstance(message, str):

--
Gitblit v1.9.1