From 61567c6d3bc8723243ffabd2d88a227a2d416e89 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Thu, 23 Nov 2023 12:53:47 +0800
Subject: [PATCH] Merge branch 'main' into dev_gzf_funasr2

---
 runtime/docs/SDK_advanced_guide_online.md        |    4 +
 runtime/docs/SDK_advanced_guide_online_zh.md     |    7 ++-
 runtime/docs/SDK_advanced_guide_offline_en.md    |    4 +
 runtime/docs/SDK_advanced_guide_offline_zh.md    |    7 ++-
 runtime/python/websocket/README.md               |    2 
 funasr/quick_start.md                            |    4 +-
 funasr/datasets/large_datasets/utils/padding.py  |   18 +++-----
 runtime/docs/SDK_advanced_guide_offline.md       |    4 +
 runtime/docs/SDK_advanced_guide_offline_en_zh.md |    4 +
 funasr/quick_start_zh.md                         |    4 +-
 runtime/python/websocket/funasr_wss_server.py    |    4 +-
 11 files changed, 36 insertions(+), 26 deletions(-)

diff --git a/funasr/datasets/large_datasets/utils/padding.py b/funasr/datasets/large_datasets/utils/padding.py
index 20ba7a3..26c6e84 100644
--- a/funasr/datasets/large_datasets/utils/padding.py
+++ b/funasr/datasets/large_datasets/utils/padding.py
@@ -32,7 +32,7 @@
             batch[data_name] = tensor_pad
             batch[data_name + "_lengths"] = tensor_lengths
 
-    # DHA, EAHC NOT INCLUDED
+    # SAC LABEL INCLUDE
     if "hotword_indxs" in batch:
         # if hotword indxs in batch
         # use it to slice hotwords out
@@ -41,28 +41,25 @@
         text = batch['text']
         text_lengths = batch['text_lengths']
         hotword_indxs = batch['hotword_indxs']
-        num_hw = sum([int(i) for i in batch['hotword_indxs_lengths'] if i != 1]) // 2
-        B, t1 = text.shape
+        dha_pad = torch.ones_like(text) * -1
+        _, t1 = text.shape
         t1 += 1  # TODO: as parameter which is same as predictor_bias
-        ideal_attn = torch.zeros(B, t1, num_hw+1)
         nth_hw = 0
         for b, (hotword_indx, one_text, length) in enumerate(zip(hotword_indxs, text, text_lengths)):
-            ideal_attn[b][:,-1] = 1
+            dha_pad[b][:length] = 8405
             if hotword_indx[0] != -1:
                 start, end = int(hotword_indx[0]), int(hotword_indx[1])
                 hotword = one_text[start: end+1]
                 hotword_list.append(hotword)
                 hotword_lengths.append(end-start+1)
-                ideal_attn[b][start:end+1, nth_hw] = 1
-                ideal_attn[b][start:end+1, -1] = 0
+                dha_pad[b][start: end+1] = one_text[start: end+1]
                 nth_hw += 1
                 if len(hotword_indx) == 4 and hotword_indx[2] != -1:
                     # the second hotword if exist
                     start, end = int(hotword_indx[2]), int(hotword_indx[3])
                     hotword_list.append(one_text[start: end+1])
                     hotword_lengths.append(end-start+1)
-                    ideal_attn[b][start:end+1, nth_hw-1] = 1
-                    ideal_attn[b][start:end+1, -1] = 0
+                    dha_pad[b][start: end+1] = one_text[start: end+1]
                     nth_hw += 1
         hotword_list.append(torch.tensor([1]))
         hotword_lengths.append(1)
@@ -71,8 +68,7 @@
                                 padding_value=0)
         batch["hotword_pad"] = hotword_pad
         batch["hotword_lengths"] = torch.tensor(hotword_lengths, dtype=torch.int32)
-        batch['ideal_attn'] = ideal_attn
+        batch['dha_pad'] = dha_pad
         del batch['hotword_indxs']
         del batch['hotword_indxs_lengths']
-
     return keys, batch
diff --git a/funasr/quick_start.md b/funasr/quick_start.md
index 0b316c0..4566b87 100644
--- a/funasr/quick_start.md
+++ b/funasr/quick_start.md
@@ -16,7 +16,7 @@
 #### Server Deployment
 
 ```shell
-cd funasr/runtime/python/websocket
+cd runtime/python/websocket
 python funasr_wss_server.py --port 10095
 ```
 
@@ -161,4 +161,4 @@
 cd egs/aishell/paraformer
 . ./run.sh --CUDA_VISIBLE_DEVICES="0,1" --gpu_num=2
 ```
-More examples could be found in [docs](https://alibaba-damo-academy.github.io/FunASR/en/modelscope_pipeline/quick_start.html)
\ No newline at end of file
+More examples could be found in [docs](https://alibaba-damo-academy.github.io/FunASR/en/modelscope_pipeline/quick_start.html)
diff --git a/funasr/quick_start_zh.md b/funasr/quick_start_zh.md
index 4e35866..64fe870 100644
--- a/funasr/quick_start_zh.md
+++ b/funasr/quick_start_zh.md
@@ -17,7 +17,7 @@
 
 ##### 服务端部署
 ```shell
-cd funasr/runtime/python/websocket
+cd runtime/python/websocket
 python funasr_wss_server.py --port 10095
 ```
 
@@ -161,4 +161,4 @@
 . ./run.sh --CUDA_VISIBLE_DEVICES="0,1" --gpu_num=2
 ```
 
-鏇村渚嬪瓙鍙互鍙傝�冿紙[鐐瑰嚮姝ゅ](https://alibaba-damo-academy.github.io/FunASR/en/academic_recipe/asr_recipe.html)锛�
\ No newline at end of file
+鏇村渚嬪瓙鍙互鍙傝�冿紙[鐐瑰嚮姝ゅ](https://alibaba-damo-academy.github.io/FunASR/en/academic_recipe/asr_recipe.html)锛�
diff --git a/runtime/docs/SDK_advanced_guide_offline.md b/runtime/docs/SDK_advanced_guide_offline.md
index 87e4ed6..130eee7 100644
--- a/runtime/docs/SDK_advanced_guide_offline.md
+++ b/runtime/docs/SDK_advanced_guide_offline.md
@@ -94,7 +94,9 @@
 --punc-quant: True for quantized PUNC model, False for non-quantized PUNC model. Default is True.
 --itn-dir modelscope model ID or local model path.
 --port: Port number that the server listens on. Default is 10095.
---decoder-thread-num: Number of inference threads that the server starts. Default is 8.
+--decoder-thread-num: The number of thread pools on the server side that can handle concurrent requests. The default value is 8.
+--model-thread-num: The number of internal threads for each recognition route to control the parallelism of the ONNX model. 
+        The default value is 1. It is recommended that decoder-thread-num * model-thread-num equals the total number of threads.
 --io-thread-num: Number of IO threads that the server starts. Default is 1.
 --certfile <string>: SSL certificate file. Default is ../../../ssl_key/server.crt. If you want to close ssl，set 0
 --keyfile <string>: SSL key file. Default is ../../../ssl_key/server.key. 
diff --git a/runtime/docs/SDK_advanced_guide_offline_en.md b/runtime/docs/SDK_advanced_guide_offline_en.md
index 1e53422..4985984 100644
--- a/runtime/docs/SDK_advanced_guide_offline_en.md
+++ b/runtime/docs/SDK_advanced_guide_offline_en.md
@@ -73,7 +73,9 @@
 --punc-quant: True for quantized PUNC model, False for non-quantized PUNC model. Default is True.
 --itn-dir modelscope model ID or local model path.
 --port: Port number that the server listens on. Default is 10095.
---decoder-thread-num: Number of inference threads that the server starts. Default is 8.
+--decoder-thread-num: The number of thread pools on the server side that can handle concurrent requests. The default value is 8.
+--model-thread-num: The number of internal threads for each recognition route to control the parallelism of the ONNX model. 
+        The default value is 1. It is recommended that decoder-thread-num * model-thread-num equals the total number of threads.
 --io-thread-num: Number of IO threads that the server starts. Default is 1.
 --certfile <string>: SSL certificate file. Default is ../../../ssl_key/server.crt. If you want to close ssl，set 0
 --keyfile <string>: SSL key file. Default is ../../../ssl_key/server.key. 
diff --git a/runtime/docs/SDK_advanced_guide_offline_en_zh.md b/runtime/docs/SDK_advanced_guide_offline_en_zh.md
index 2cedccd..57a793d 100644
--- a/runtime/docs/SDK_advanced_guide_offline_en_zh.md
+++ b/runtime/docs/SDK_advanced_guide_offline_en_zh.md
@@ -158,7 +158,9 @@
 --punc-quant   True为量化PUNC模型，False为非量化PUNC模型，默认是True
 --itn-dir modelscope model ID 或者 本地模型路径
 --port  服务端监听的端口号，默认为 10095
---decoder-thread-num  服务端启动的推理线程数，默认为 8
+--decoder-thread-num  服务端线程池个数(支持的最大并发路数)，默认为 8
+--model-thread-num  每路识别的内部线程数(控制ONNX模型的并行)，默认为 1，
+                    其中建议 decoder-thread-num*model-thread-num 等于总线程数
 --io-thread-num  服务端启动的IO线程数，默认为 1
 --certfile  ssl的证书文件，默认为：../../../ssl_key/server.crt，如果需要关闭ssl，参数设置为0
 --keyfile   ssl的密钥文件，默认为：../../../ssl_key/server.key
diff --git a/runtime/docs/SDK_advanced_guide_offline_zh.md b/runtime/docs/SDK_advanced_guide_offline_zh.md
index fe1f2f6..299b27d 100644
--- a/runtime/docs/SDK_advanced_guide_offline_zh.md
+++ b/runtime/docs/SDK_advanced_guide_offline_zh.md
@@ -175,11 +175,14 @@
 --lm-dir modelscope model ID 或者 本地模型路径
 --itn-dir modelscope model ID 或者 本地模型路径
 --port  服务端监听的端口号，默认为 10095
---decoder-thread-num  服务端启动的推理线程数，默认为 8
+--decoder-thread-num  服务端线程池个数(支持的最大并发路数)，默认为 8
+--model-thread-num  每路识别的内部线程数(控制ONNX模型的并行)，默认为 1，
+                    其中建议 decoder-thread-num*model-thread-num 等于总线程数
 --io-thread-num  服务端启动的IO线程数，默认为 1
 --certfile  ssl的证书文件，默认为：../../../ssl_key/server.crt，如果需要关闭ssl，参数设置为0
 --keyfile   ssl的密钥文件，默认为：../../../ssl_key/server.key
---hotword   热词文件路径，每行一个热词，格式：热词 权重(例如:阿里巴巴 20)，如果客户端提供热词，则与客户端提供的热词合并一起使用。
+--hotword   热词文件路径，每行一个热词，格式：热词 权重(例如:阿里巴巴 20)，
+            如果客户端提供热词，则与客户端提供的热词合并一起使用，服务端热词全局生效，客户端热词只针对对应客户端生效。
 ```
 
 ### 关闭FunASR服务
diff --git a/runtime/docs/SDK_advanced_guide_online.md b/runtime/docs/SDK_advanced_guide_online.md
index ea52c55..384b13b 100644
--- a/runtime/docs/SDK_advanced_guide_online.md
+++ b/runtime/docs/SDK_advanced_guide_online.md
@@ -111,7 +111,9 @@
 --punc-quant: True for quantized PUNC model, False for non-quantized PUNC model. Default is True.
 --itn-dir modelscope model ID or local model path.
 --port: Port number that the server listens on. Default is 10095.
---decoder-thread-num: Number of inference threads that the server starts. Default is 8.
+--decoder-thread-num: The number of thread pools on the server side that can handle concurrent requests. The default value is 8.
+--model-thread-num: The number of internal threads for each recognition route to control the parallelism of the ONNX model. 
+        The default value is 1. It is recommended that decoder-thread-num * model-thread-num equals the total number of threads.
 --io-thread-num: Number of IO threads that the server starts. Default is 1.
 --certfile <string>: SSL certificate file. Default is ../../../ssl_key/server.crt. If you want to close ssl，set 0
 --keyfile <string>: SSL key file. Default is ../../../ssl_key/server.key. 
diff --git a/runtime/docs/SDK_advanced_guide_online_zh.md b/runtime/docs/SDK_advanced_guide_online_zh.md
index d8da63f..bafc329 100644
--- a/runtime/docs/SDK_advanced_guide_online_zh.md
+++ b/runtime/docs/SDK_advanced_guide_online_zh.md
@@ -120,11 +120,14 @@
 --punc-quant   True为量化PUNC模型，False为非量化PUNC模型，默认是True
 --itn-dir modelscope model ID 或者 本地模型路径
 --port  服务端监听的端口号，默认为 10095
---decoder-thread-num  服务端启动的推理线程数，默认为 8
+--decoder-thread-num  服务端线程池个数(支持的最大并发路数)，默认为 8
+--model-thread-num  每路识别的内部线程数(控制ONNX模型的并行)，默认为 1，
+                    其中建议 decoder-thread-num*model-thread-num 等于总线程数
 --io-thread-num  服务端启动的IO线程数，默认为 1
 --certfile  ssl的证书文件，默认为：../../../ssl_key/server.crt，如果需要关闭ssl，参数设置为0
 --keyfile   ssl的密钥文件，默认为：../../../ssl_key/server.key
---hotword   热词文件路径，每行一个热词，格式：热词 权重(例如:阿里巴巴 20)，如果客户端提供热词，则与客户端提供的热词合并一起使用。
+--hotword   热词文件路径，每行一个热词，格式：热词 权重(例如:阿里巴巴 20)，
+            如果客户端提供热词，则与客户端提供的热词合并一起使用，服务端热词全局生效，客户端热词只针对对应客户端生效。
 ```
 
 ### 关闭FunASR服务
diff --git a/runtime/python/websocket/README.md b/runtime/python/websocket/README.md
index d50c8e1..304008d 100644
--- a/runtime/python/websocket/README.md
+++ b/runtime/python/websocket/README.md
@@ -16,7 +16,7 @@
 ### Install the requirements for server
 
 ```shell
-cd funasr/runtime/python/websocket
+cd runtime/python/websocket
 pip install -r requirements_server.txt
 ```
 
diff --git a/runtime/python/websocket/funasr_wss_server.py b/runtime/python/websocket/funasr_wss_server.py
index 716e281..22d2a7f 100644
--- a/runtime/python/websocket/funasr_wss_server.py
+++ b/runtime/python/websocket/funasr_wss_server.py
@@ -53,13 +53,13 @@
                     help="cpu cores")
 parser.add_argument("--certfile",
                     type=str,
-                    default="../ssl_key/server.crt",
+                    default="../../ssl_key/server.crt",
                     required=False,
                     help="certfile for ssl")
 
 parser.add_argument("--keyfile",
                     type=str,
-                    default="../ssl_key/server.key",
+                    default="../../ssl_key/server.key",
                     required=False,
                     help="keyfile for ssl")
 args = parser.parse_args()

--
Gitblit v1.9.1