From 9dad49c3a1f2495384bab4cc3763e4f8a461da00 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期六, 13 五月 2023 00:20:19 +0800
Subject: [PATCH] websocket new version for offline 2pass send bytes

---
 funasr/runtime/python/onnxruntime/funasr_onnx.egg-info/top_level.txt        |    1 
 funasr/runtime/python/websocket/ws_server_offline.py                        |   51 +++--
 funasr/runtime/python/onnxruntime/funasr_onnx.egg-info/SOURCES.txt          |   17 +
 funasr/runtime/python/onnxruntime/funasr_onnx.egg-info/dependency_links.txt |    1 
 funasr/runtime/python/onnxruntime/funasr_onnx.egg-info/requires.txt         |   10 +
 funasr/utils/modelscope_utils.py                                            |   16 +
 funasr/runtime/python/websocket/ws_client.py                                |   63 ++++---
 funasr/runtime/python/websocket/ws_server_2pass.py                          |   71 ++++---
 funasr/runtime/python/websocket/ws_server_online.py                         |  116 ++++++------
 funasr/runtime/python/onnxruntime/funasr_onnx.egg-info/PKG-INFO             |  190 +++++++++++++++++++++
 10 files changed, 395 insertions(+), 141 deletions(-)

diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx.egg-info/PKG-INFO b/funasr/runtime/python/onnxruntime/funasr_onnx.egg-info/PKG-INFO
new file mode 100644
index 0000000..18fc04c
--- /dev/null
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx.egg-info/PKG-INFO
@@ -0,0 +1,190 @@
+Metadata-Version: 2.1
+Name: funasr-onnx
+Version: 0.1.0
+Summary: FunASR: A Fundamental End-to-End Speech Recognition Toolkit
+Home-page: https://github.com/alibaba-damo-academy/FunASR.git
+Author: Speech Lab of DAMO Academy, Alibaba Group
+Author-email: funasr@list.alibaba-inc.com
+License: MIT
+Keywords: funasr,asr
+Platform: Any
+Classifier: Programming Language :: Python :: 3.6
+Classifier: Programming Language :: Python :: 3.7
+Classifier: Programming Language :: Python :: 3.8
+Classifier: Programming Language :: Python :: 3.9
+Classifier: Programming Language :: Python :: 3.10
+Description-Content-Type: text/markdown
+
+# ONNXRuntime-python
+
+
+## Install `funasr_onnx`
+
+install from pip
+```shell
+pip install -U funasr_onnx
+# For the users in China, you could install with the command:
+# pip install -U funasr_onnx -i https://mirror.sjtu.edu.cn/pypi/web/simple
+```
+
+or install from source code
+
+```shell
+git clone https://github.com/alibaba/FunASR.git && cd FunASR
+cd funasr/runtime/python/onnxruntime
+pip install -e ./
+# For the users in China, you could install with the command:
+# pip install -e ./ -i https://mirror.sjtu.edu.cn/pypi/web/simple
+```
+
+## Inference with runtime
+
+### Speech Recognition
+#### Paraformer
+ ```python
+from funasr_onnx import Paraformer
+from pathlib import Path
+
+model_dir = "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+model = Paraformer(model_dir, batch_size=1, quantize=True)
+
+wav_path = ['{}/.cache/modelscope/hub/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav'.format(Path.home())]
+
+result = model(wav_path)
+print(result)
+ ```
+- `model_dir`: model_name in modelscope or local path downloaded from modelscope. If the local path is set, it should contain `model.onnx`, `config.yaml`, `am.mvn`
+- `batch_size`: `1` (Default), the batch size duration inference
+- `device_id`: `-1` (Default), infer on CPU. If you want to infer with GPU, set it to gpu_id (Please make sure that you have install the onnxruntime-gpu)
+- `quantize`: `False` (Default), load the model of `model.onnx` in `model_dir`. If set `True`, load the model of `model_quant.onnx` in `model_dir`
+- `intra_op_num_threads`: `4` (Default), sets the number of threads used for intraop parallelism on CPU
+
+Input: wav formt file, support formats: `str, np.ndarray, List[str]`
+
+Output: `List[str]`: recognition result
+
+#### Paraformer-online
+
+### Voice Activity Detection
+#### FSMN-VAD
+```python
+from funasr_onnx import Fsmn_vad
+from pathlib import Path
+
+model_dir = "damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
+wav_path = '{}/.cache/modelscope/hub/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/example/vad_example.wav'.format(Path.home())
+
+model = Fsmn_vad(model_dir)
+
+result = model(wav_path)
+print(result)
+```
+- `model_dir`: model_name in modelscope or local path downloaded from modelscope. If the local path is set, it should contain `model.onnx`, `config.yaml`, `am.mvn`
+- `batch_size`: `1` (Default), the batch size duration inference
+- `device_id`: `-1` (Default), infer on CPU. If you want to infer with GPU, set it to gpu_id (Please make sure that you have install the onnxruntime-gpu)
+- `quantize`: `False` (Default), load the model of `model.onnx` in `model_dir`. If set `True`, load the model of `model_quant.onnx` in `model_dir`
+- `intra_op_num_threads`: `4` (Default), sets the number of threads used for intraop parallelism on CPU
+
+Input: wav formt file, support formats: `str, np.ndarray, List[str]`
+
+Output: `List[str]`: recognition result
+
+
+#### FSMN-VAD-online
+```python
+from funasr_onnx import Fsmn_vad_online
+import soundfile
+from pathlib import Path
+
+model_dir = "damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
+wav_path = '{}/.cache/modelscope/hub/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/example/vad_example.wav'.format(Path.home())
+
+model = Fsmn_vad_online(model_dir)
+
+
+##online vad
+speech, sample_rate = soundfile.read(wav_path)
+speech_length = speech.shape[0]
+#
+sample_offset = 0
+step = 1600
+param_dict = {'in_cache': []}
+for sample_offset in range(0, speech_length, min(step, speech_length - sample_offset)):
+    if sample_offset + step >= speech_length - 1:
+        step = speech_length - sample_offset
+        is_final = True
+    else:
+        is_final = False
+    param_dict['is_final'] = is_final
+    segments_result = model(audio_in=speech[sample_offset: sample_offset + step],
+                            param_dict=param_dict)
+    if segments_result:
+        print(segments_result)
+```
+- `model_dir`: model_name in modelscope or local path downloaded from modelscope. If the local path is set, it should contain `model.onnx`, `config.yaml`, `am.mvn`
+- `batch_size`: `1` (Default), the batch size duration inference
+- `device_id`: `-1` (Default), infer on CPU. If you want to infer with GPU, set it to gpu_id (Please make sure that you have install the onnxruntime-gpu)
+- `quantize`: `False` (Default), load the model of `model.onnx` in `model_dir`. If set `True`, load the model of `model_quant.onnx` in `model_dir`
+- `intra_op_num_threads`: `4` (Default), sets the number of threads used for intraop parallelism on CPU
+
+Input: wav formt file, support formats: `str, np.ndarray, List[str]`
+
+Output: `List[str]`: recognition result
+
+
+### Punctuation Restoration
+#### CT-Transformer
+```python
+from funasr_onnx import CT_Transformer
+
+model_dir = "damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
+model = CT_Transformer(model_dir)
+
+text_in="璺ㄥ娌虫祦鏄吇鑲叉部宀镐汉姘戠殑鐢熷懡涔嬫簮闀挎湡浠ユ潵涓哄府鍔╀笅娓稿湴鍖洪槻鐏惧噺鐏句腑鏂规妧鏈汉鍛樺湪涓婃父鍦板尯鏋佷负鎭跺姡鐨勮嚜鐒舵潯浠朵笅鍏嬫湇宸ㄥぇ鍥伴毦鐢氳嚦鍐掔潃鐢熷懡鍗遍櫓鍚戝嵃鏂规彁渚涙睕鏈熸按鏂囪祫鏂欏鐞嗙揣鎬ヤ簨浠朵腑鏂归噸瑙嗗嵃鏂瑰湪璺ㄥ娌虫祦闂涓婄殑鍏冲垏鎰挎剰杩涗竴姝ュ畬鍠勫弻鏂硅仈鍚堝伐浣滄満鍒跺嚒鏄腑鏂硅兘鍋氱殑鎴戜滑閮戒細鍘诲仛鑰屼笖浼氬仛寰楁洿濂芥垜璇峰嵃搴︽湅鍙嬩滑鏀惧績涓浗鍦ㄤ笂娓哥殑浠讳綍寮�鍙戝埄鐢ㄩ兘浼氱粡杩囩瀛﹁鍒掑拰璁鸿瘉鍏奸【涓婁笅娓哥殑鍒╃泭"
+result = model(text_in)
+print(result[0])
+```
+- `model_dir`: model_name in modelscope or local path downloaded from modelscope. If the local path is set, it should contain `model.onnx`, `config.yaml`, `am.mvn`
+- `device_id`: `-1` (Default), infer on CPU. If you want to infer with GPU, set it to gpu_id (Please make sure that you have install the onnxruntime-gpu)
+- `quantize`: `False` (Default), load the model of `model.onnx` in `model_dir`. If set `True`, load the model of `model_quant.onnx` in `model_dir`
+- `intra_op_num_threads`: `4` (Default), sets the number of threads used for intraop parallelism on CPU
+
+Input: `str`, raw text of asr result
+
+Output: `List[str]`: recognition result
+
+
+#### CT-Transformer-online
+```python
+from funasr_onnx import CT_Transformer_VadRealtime
+
+model_dir = "damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727"
+model = CT_Transformer_VadRealtime(model_dir)
+
+text_in  = "璺ㄥ娌虫祦鏄吇鑲叉部宀竱浜烘皯鐨勭敓鍛戒箣婧愰暱鏈熶互鏉ヤ负甯姪涓嬫父鍦板尯闃茬伨鍑忕伨涓柟鎶�鏈汉鍛榺鍦ㄤ笂娓稿湴鍖烘瀬涓烘伓鍔g殑鑷劧鏉′欢涓嬪厠鏈嶅法澶у洶闅剧敋鑷冲啋鐫�鐢熷懡鍗遍櫓|鍚戝嵃鏂规彁渚涙睕鏈熸按鏂囪祫鏂欏鐞嗙揣鎬ヤ簨浠朵腑鏂归噸瑙嗗嵃鏂瑰湪璺ㄥ娌虫祦>闂涓婄殑鍏冲垏|鎰挎剰杩涗竴姝ュ畬鍠勫弻鏂硅仈鍚堝伐浣滄満鍒秥鍑℃槸|涓柟鑳藉仛鐨勬垜浠瑋閮戒細鍘诲仛鑰屼笖浼氬仛寰楁洿濂芥垜璇峰嵃搴︽湅鍙嬩滑鏀惧績涓浗鍦ㄤ笂娓哥殑|浠讳綍寮�鍙戝埄鐢ㄩ兘浼氱粡杩囩瀛瑙勫垝鍜岃璇佸吋椤句笂涓嬫父鐨勫埄鐩�"
+
+vads = text_in.split("|")
+rec_result_all=""
+param_dict = {"cache": []}
+for vad in vads:
+    result = model(vad, param_dict=param_dict)
+    rec_result_all += result[0]
+
+print(rec_result_all)
+```
+- `model_dir`: model_name in modelscope or local path downloaded from modelscope. If the local path is set, it should contain `model.onnx`, `config.yaml`, `am.mvn`
+- `device_id`: `-1` (Default), infer on CPU. If you want to infer with GPU, set it to gpu_id (Please make sure that you have install the onnxruntime-gpu)
+- `quantize`: `False` (Default), load the model of `model.onnx` in `model_dir`. If set `True`, load the model of `model_quant.onnx` in `model_dir`
+- `intra_op_num_threads`: `4` (Default), sets the number of threads used for intraop parallelism on CPU
+
+Input: `str`, raw text of asr result
+
+Output: `List[str]`: recognition result
+
+## Performance benchmark
+
+Please ref to [benchmark](https://github.com/alibaba-damo-academy/FunASR/blob/main/funasr/runtime/python/benchmark_onnx.md)
+
+## Acknowledge
+1. This project is maintained by [FunASR community](https://github.com/alibaba-damo-academy/FunASR).
+2. We partially refer [SWHL](https://github.com/RapidAI/RapidASR) for onnxruntime (only for paraformer model).
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx.egg-info/SOURCES.txt b/funasr/runtime/python/onnxruntime/funasr_onnx.egg-info/SOURCES.txt
new file mode 100644
index 0000000..e759e27
--- /dev/null
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx.egg-info/SOURCES.txt
@@ -0,0 +1,17 @@
+README.md
+setup.py
+funasr_onnx/__init__.py
+funasr_onnx/paraformer_bin.py
+funasr_onnx/punc_bin.py
+funasr_onnx/vad_bin.py
+funasr_onnx.egg-info/PKG-INFO
+funasr_onnx.egg-info/SOURCES.txt
+funasr_onnx.egg-info/dependency_links.txt
+funasr_onnx.egg-info/requires.txt
+funasr_onnx.egg-info/top_level.txt
+funasr_onnx/utils/__init__.py
+funasr_onnx/utils/e2e_vad.py
+funasr_onnx/utils/frontend.py
+funasr_onnx/utils/postprocess_utils.py
+funasr_onnx/utils/timestamp_utils.py
+funasr_onnx/utils/utils.py
\ No newline at end of file
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx.egg-info/dependency_links.txt b/funasr/runtime/python/onnxruntime/funasr_onnx.egg-info/dependency_links.txt
new file mode 100644
index 0000000..8b13789
--- /dev/null
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx.egg-info/dependency_links.txt
@@ -0,0 +1 @@
+
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx.egg-info/requires.txt b/funasr/runtime/python/onnxruntime/funasr_onnx.egg-info/requires.txt
new file mode 100644
index 0000000..cf777b4
--- /dev/null
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx.egg-info/requires.txt
@@ -0,0 +1,10 @@
+librosa
+onnxruntime>=1.7.0
+scipy
+numpy>=1.19.3
+typeguard
+kaldi-native-fbank
+PyYAML>=5.1.2
+funasr
+modelscope
+onnx
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx.egg-info/top_level.txt b/funasr/runtime/python/onnxruntime/funasr_onnx.egg-info/top_level.txt
new file mode 100644
index 0000000..de41eb9
--- /dev/null
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx.egg-info/top_level.txt
@@ -0,0 +1 @@
+funasr_onnx
diff --git a/funasr/runtime/python/websocket/ws_client.py b/funasr/runtime/python/websocket/ws_client.py
index 7ae44df..45c745a 100644
--- a/funasr/runtime/python/websocket/ws_client.py
+++ b/funasr/runtime/python/websocket/ws_client.py
@@ -85,9 +85,8 @@
                     input=True,
                     frames_per_buffer=CHUNK)
 
-    message = json.dumps({"chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval, "wav_name": wav_name,"is_speaking": True})
+    message = json.dumps({"chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval, "wav_name": "microphone", "is_speaking": True})
     voices.put(message)
-    is_speaking = True
     while True:
 
         data = stream.read(CHUNK)
@@ -146,9 +145,6 @@
             sleep_duration = 0.001 if args.send_without_sleep else 60*args.chunk_size[1]/args.chunk_interval/1000
             await asyncio.sleep(sleep_duration)
 
-    is_finished = True
-    message = json.dumps({"is_finished": is_finished})
-    voices.put(message)
 
 async def ws_send():
     global voices
@@ -241,29 +237,9 @@
 
 
 if __name__ == '__main__':
-    # calculate the number of wavs for each preocess
-    if args.audio_in.endswith(".scp"):
-        f_scp = open(args.audio_in)
-        wavs = f_scp.readlines()
-    else:
-        wavs = [args.audio_in]
-    total_len=len(wavs)
-    if total_len>=args.test_thread_num:
-         chunk_size=int((total_len)/args.test_thread_num)
-         remain_wavs=total_len-chunk_size*args.test_thread_num
-    else:
-         chunk_size=0
-    
     process_list = []
-    chunk_begin=0
     for i in range(args.test_thread_num):
-        now_chunk_size= chunk_size
-        if remain_wavs>0:
-            now_chunk_size=chunk_size+1
-            remain_wavs=remain_wavs-1
-        # process i handle wavs at chunk_begin and size of now_chunk_size
-        p = Process(target=one_thread,args=(i,chunk_begin,now_chunk_size))
-        chunk_begin=chunk_begin+now_chunk_size
+        p = Process(target=one_thread,args=(i, 0, 0))
         p.start()
         process_list.append(p)
 
@@ -271,5 +247,38 @@
         p.join()
 
     print('end')
- 
+
+#
+# if __name__ == '__main__':
+#     # calculate the number of wavs for each preocess
+#     if args.audio_in.endswith(".scp"):
+#         f_scp = open(args.audio_in)
+#         wavs = f_scp.readlines()
+#     else:
+#         wavs = [args.audio_in]
+#     total_len=len(wavs)
+#     if total_len>=args.test_thread_num:
+#          chunk_size=int((total_len)/args.test_thread_num)
+#          remain_wavs=total_len-chunk_size*args.test_thread_num
+#     else:
+#          chunk_size=0
+#
+#     process_list = []
+#     chunk_begin=0
+#     for i in range(args.test_thread_num):
+#         now_chunk_size= chunk_size
+#         if remain_wavs>0:
+#             now_chunk_size=chunk_size+1
+#             remain_wavs=remain_wavs-1
+#         # process i handle wavs at chunk_begin and size of now_chunk_size
+#         p = Process(target=one_thread,args=(i,chunk_begin,now_chunk_size))
+#         chunk_begin=chunk_begin+now_chunk_size
+#         p.start()
+#         process_list.append(p)
+#
+#     for i in process_list:
+#         p.join()
+#
+#     print('end')
+#
 
diff --git a/funasr/runtime/python/websocket/ws_server_2pass.py b/funasr/runtime/python/websocket/ws_server_2pass.py
index ced67ff..186197a 100644
--- a/funasr/runtime/python/websocket/ws_server_2pass.py
+++ b/funasr/runtime/python/websocket/ws_server_2pass.py
@@ -74,47 +74,54 @@
     websocket.param_dict_punc = {'cache': list()}
     websocket.vad_pre_idx = 0
     speech_start = False
+    websocket.wav_name = "microphone"
+    print("new user connected", flush=True)
 
     try:
         async for message in websocket:
-            message = json.loads(message)
-            is_finished = message["is_finished"]
-            if not is_finished:
-                audio = bytes(message['audio'], 'ISO-8859-1')
-                frames.append(audio)
-                duration_ms = len(audio)//32
-                websocket.vad_pre_idx += duration_ms
-
-                is_speaking = message["is_speaking"]
-                websocket.param_dict_vad["is_final"] = not is_speaking
-                websocket.param_dict_asr_online["is_final"] = not is_speaking
-                websocket.param_dict_asr_online["chunk_size"] = message["chunk_size"]
-                websocket.wav_name = message.get("wav_name", "demo")
-                # asr online
-                frames_asr_online.append(audio)
-                if len(frames_asr_online) % message["chunk_interval"] == 0:
-                    audio_in = b"".join(frames_asr_online)
-                    await async_asr_online(websocket, audio_in)
-                    frames_asr_online = []
-                if speech_start:
-                    frames_asr.append(audio)
-                # vad online
-                speech_start_i, speech_end_i = await async_vad(websocket, audio)
-                if speech_start_i:
-                    speech_start = True
-                    beg_bias = (websocket.vad_pre_idx-speech_start_i)//duration_ms
-                    frames_pre = frames[-beg_bias:]
-                    frames_asr = []
-                    frames_asr.extend(frames_pre)
+            if isinstance(message, str):
+                messagejson = json.loads(message)
+        
+                if "is_speaking" in messagejson:
+                    websocket.is_speaking = messagejson["is_speaking"]
+                    websocket.param_dict_asr_online["is_final"] = not websocket.is_speaking
+                if "chunk_interval" in messagejson:
+                    websocket.chunk_interval = messagejson["chunk_interval"]
+                if "wav_name" in messagejson:
+                    websocket.wav_name = messagejson.get("wav_name")
+                if "chunk_size" in messagejson:
+                    websocket.param_dict_asr_online["chunk_size"] = messagejson["chunk_size"]
+            if len(frames_asr_online) > 0 or len(frames_asr) > 0 or not isinstance(message, str):
+                if not isinstance(message, str):
+                    frames.append(message)
+                    duration_ms = len(message)//32
+                    websocket.vad_pre_idx += duration_ms
+        
+                    # asr online
+                    frames_asr_online.append(message)
+                    if len(frames_asr_online) % websocket.chunk_interval == 0:
+                        audio_in = b"".join(frames_asr_online)
+                        await async_asr_online(websocket, audio_in)
+                        frames_asr_online = []
+                    if speech_start:
+                        frames_asr.append(message)
+                    # vad online
+                    speech_start_i, speech_end_i = await async_vad(websocket, message)
+                    if speech_start_i:
+                        speech_start = True
+                        beg_bias = (websocket.vad_pre_idx-speech_start_i)//duration_ms
+                        frames_pre = frames[-beg_bias:]
+                        frames_asr = []
+                        frames_asr.extend(frames_pre)
                 # asr punc offline
-                if speech_end_i or not is_speaking:
+                if speech_end_i or not websocket.is_speaking:
                     audio_in = b"".join(frames_asr)
                     await async_asr(websocket, audio_in)
                     frames_asr = []
                     speech_start = False
                     frames_asr_online = []
                     websocket.param_dict_asr_online = {"cache": dict()}
-                    if not is_speaking:
+                    if not websocket.is_speaking:
                         websocket.vad_pre_idx = 0
                         frames = []
                         websocket.param_dict_vad = {'in_cache': dict()}
@@ -168,7 +175,7 @@
         audio_in = load_bytes(audio_in)
         rec_result = inference_pipeline_asr_online(audio_in=audio_in,
                                                    param_dict=websocket.param_dict_asr_online)
-        if websocket.param_dict_asr_online["is_final"]:
+        if websocket.param_dict_asr_online.get("is_final", False):
             websocket.param_dict_asr_online["cache"] = dict()
         if "text" in rec_result:
             if rec_result["text"] != "sil" and rec_result["text"] != "waiting_for_more_voice":
diff --git a/funasr/runtime/python/websocket/ws_server_offline.py b/funasr/runtime/python/websocket/ws_server_offline.py
index 15578f6..1fcc246 100644
--- a/funasr/runtime/python/websocket/ws_server_offline.py
+++ b/funasr/runtime/python/websocket/ws_server_offline.py
@@ -65,35 +65,40 @@
     websocket.param_dict_punc = {'cache': list()}
     websocket.vad_pre_idx = 0
     speech_start = False
+    websocket.wav_name = "microphone"
+    print("new user connected", flush=True)
 
     try:
         async for message in websocket:
-            message = json.loads(message)
-            is_finished = message["is_finished"]
-            if not is_finished:
-                audio = bytes(message['audio'], 'ISO-8859-1')
-                frames.append(audio)
-                duration_ms = len(audio)//32
-                websocket.vad_pre_idx += duration_ms
-
-                is_speaking = message["is_speaking"]
-                websocket.param_dict_vad["is_final"] = not is_speaking
-                websocket.wav_name = message.get("wav_name", "demo")
-                if speech_start:
-                    frames_asr.append(audio)
-                speech_start_i, speech_end_i = await async_vad(websocket, audio)
-                if speech_start_i:
-                    speech_start = True
-                    beg_bias = (websocket.vad_pre_idx-speech_start_i)//duration_ms
-                    frames_pre = frames[-beg_bias:]
-                    frames_asr = []
-                    frames_asr.extend(frames_pre)
-                if speech_end_i or not is_speaking:
+            if isinstance(message, str):
+                messagejson = json.loads(message)
+                if "is_speaking" in messagejson:
+                    websocket.is_speaking = messagejson["is_speaking"]
+                    websocket.param_dict_vad["is_final"] = not websocket.is_speaking
+                if "wav_name" in messagejson:
+                    websocket.wav_name = messagejson.get("wav_name")
+            
+            if len(frames_asr) > 0 or not isinstance(message, str):
+                if not isinstance(message, str):
+                    frames.append(message)
+                    duration_ms = len(message)//32
+                    websocket.vad_pre_idx += duration_ms
+    
+                    if speech_start:
+                        frames_asr.append(message)
+                    speech_start_i, speech_end_i = await async_vad(websocket, message)
+                    if speech_start_i:
+                        speech_start = True
+                        beg_bias = (websocket.vad_pre_idx-speech_start_i)//duration_ms
+                        frames_pre = frames[-beg_bias:]
+                        frames_asr = []
+                        frames_asr.extend(frames_pre)
+                if speech_end_i or not websocket.is_speaking:
                     audio_in = b"".join(frames_asr)
                     await async_asr(websocket, audio_in)
                     frames_asr = []
                     speech_start = False
-                    if not is_speaking:
+                    if not websocket.is_speaking:
                         websocket.vad_pre_idx = 0
                         frames = []
                         websocket.param_dict_vad = {'in_cache': dict()}
@@ -133,7 +138,7 @@
                 
                 rec_result = inference_pipeline_asr(audio_in=audio_in,
                                                     param_dict=websocket.param_dict_asr)
-                # print(rec_result)
+                print(rec_result)
                 if inference_pipeline_punc is not None and 'text' in rec_result and len(rec_result["text"])>0:
                     rec_result = inference_pipeline_punc(text_in=rec_result['text'],
                                                          param_dict=websocket.param_dict_punc)
diff --git a/funasr/runtime/python/websocket/ws_server_online.py b/funasr/runtime/python/websocket/ws_server_online.py
index 44edf98..a35b127 100644
--- a/funasr/runtime/python/websocket/ws_server_online.py
+++ b/funasr/runtime/python/websocket/ws_server_online.py
@@ -26,74 +26,72 @@
 print("model loading")
 
 inference_pipeline_asr_online = pipeline(
-    task=Tasks.auto_speech_recognition,
-    model=args.asr_model_online,
-    ngpu=args.ngpu,
-    ncpu=args.ncpu,
-    model_revision='v1.0.4')
+	task=Tasks.auto_speech_recognition,
+	model=args.asr_model_online,
+	ngpu=args.ngpu,
+	ncpu=args.ncpu,
+	model_revision='v1.0.4')
 
 print("model loaded")
 
 
 
 async def ws_serve(websocket, path):
-    frames_asr_online = []
-    global websocket_users
-    websocket_users.add(websocket)
-    websocket.param_dict_asr_online = {"cache": dict()}
-    print("new user connected",flush=True)
-    try:
-        async for message in websocket:
-            
- 
-            if isinstance(message,str):
-              messagejson = json.loads(message)
-               
-              if "is_speaking" in messagejson:
-                  websocket.is_speaking = messagejson["is_speaking"]  
-                  websocket.param_dict_asr_online["is_final"] = not websocket.is_speaking
-              if "is_finished" in messagejson:
-                  websocket.is_speaking = False
-                  websocket.param_dict_asr_online["is_final"] = True
-              if "chunk_interval" in messagejson:
-                  websocket.chunk_interval=messagejson["chunk_interval"]
-              if "wav_name" in messagejson:
-                  websocket.wav_name = messagejson.get("wav_name", "demo")
-              if "chunk_size" in messagejson:
-                  websocket.param_dict_asr_online["chunk_size"] = messagejson["chunk_size"]
-            # if has bytes in buffer or message is bytes
-            if len(frames_asr_online)>0 or not isinstance(message,str):
-               if not isinstance(message,str):
-                 frames_asr_online.append(message)
-               if len(frames_asr_online) % websocket.chunk_interval == 0 or not websocket.is_speaking:
-                    audio_in = b"".join(frames_asr_online)
-                    if not websocket.is_speaking:
-                       #padding 0.5s at end gurantee that asr engine can fire out last word
-                       audio_in=audio_in+b''.join(np.zeros(int(16000*0.5),dtype=np.int16))
-                    await async_asr_online(websocket,audio_in)
-                    frames_asr_online = []
+	frames_asr_online = []
+	global websocket_users
+	websocket_users.add(websocket)
+	websocket.param_dict_asr_online = {"cache": dict()}
+	websocket.wav_name = "microphone"
+	print("new user connected",flush=True)
+	try:
+		async for message in websocket:
+			
+			
+			if isinstance(message, str):
+				messagejson = json.loads(message)
+				
+				if "is_speaking" in messagejson:
+					websocket.is_speaking = messagejson["is_speaking"]
+					websocket.param_dict_asr_online["is_final"] = not websocket.is_speaking
+				if "chunk_interval" in messagejson:
+					websocket.chunk_interval=messagejson["chunk_interval"]
+				if "wav_name" in messagejson:
+					websocket.wav_name = messagejson.get("wav_name")
+				if "chunk_size" in messagejson:
+					websocket.param_dict_asr_online["chunk_size"] = messagejson["chunk_size"]
+			# if has bytes in buffer or message is bytes
+			if len(frames_asr_online) > 0 or not isinstance(message, str):
+				if not isinstance(message,str):
+					frames_asr_online.append(message)
+				if len(frames_asr_online) % websocket.chunk_interval == 0 or not websocket.is_speaking:
+					audio_in = b"".join(frames_asr_online)
+					# if not websocket.is_speaking:
+						#padding 0.5s at end gurantee that asr engine can fire out last word
+						# audio_in=audio_in+b''.join(np.zeros(int(16000*0.5),dtype=np.int16))
+					await async_asr_online(websocket,audio_in)
+					frames_asr_online = []
+	
+	
+	except websockets.ConnectionClosed:
+		print("ConnectionClosed...", websocket_users)
+		websocket_users.remove(websocket)
+	except websockets.InvalidState:
+		print("InvalidState...")
+	except Exception as e:
+		print("Exception:", e)
 
-     
-    except websockets.ConnectionClosed:
-        print("ConnectionClosed...", websocket_users)
-        websocket_users.remove(websocket)
-    except websockets.InvalidState:
-        print("InvalidState...")
-    except Exception as e:
-        print("Exception:", e)
 
- 
 async def async_asr_online(websocket,audio_in):
-            if len(audio_in) > 0:
-                audio_in = load_bytes(audio_in)
-                rec_result = inference_pipeline_asr_online(audio_in=audio_in,
-                                                           param_dict=websocket.param_dict_asr_online)
-                if websocket.param_dict_asr_online["is_final"]:
-                    websocket.param_dict_asr_online["cache"] = dict()
-                if "text" in rec_result:
-                    if rec_result["text"] != "sil" and rec_result["text"] != "waiting_for_more_voice":
-                        message = json.dumps({"mode": "online", "text": rec_result["text"], "wav_name": websocket.wav_name})
-                        await websocket.send(message)
+	if len(audio_in) > 0:
+		audio_in = load_bytes(audio_in)
+		rec_result = inference_pipeline_asr_online(audio_in=audio_in,
+		                                           param_dict=websocket.param_dict_asr_online)
+		if websocket.param_dict_asr_online.get("is_final", False):
+			websocket.param_dict_asr_online["cache"] = dict()
+		if "text" in rec_result:
+			if rec_result["text"] != "sil" and rec_result["text"] != "waiting_for_more_voice":
+				message = json.dumps({"mode": "online", "text": rec_result["text"], "wav_name": websocket.wav_name})
+				await websocket.send(message)
 
 
 
diff --git a/funasr/utils/modelscope_utils.py b/funasr/utils/modelscope_utils.py
new file mode 100644
index 0000000..9712e09
--- /dev/null
+++ b/funasr/utils/modelscope_utils.py
@@ -0,0 +1,16 @@
+import os
+from modelscope.hub.snapshot_download import snapshot_download
+
+
+def check_model_dir(model_dir, model_name: str = "damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"):
+	model_dir = "/Users/zhifu/test_modelscope_pipeline/FSMN-VAD"
+	
+	cache_root = os.path.dirname(model_dir)
+	dst_dir_root = os.path.join(cache_root, ".cache")
+	dst = os.path.join(dst_dir_root, model_name)
+	dst_dir = os.path.dirname(dst)
+	os.makedirs(dst_dir, exist_ok=True)
+	if not os.path.exists(dst):
+		os.symlink(model_dir, dst)
+	
+	model_dir = snapshot_download(model_name, cache_dir=dst_dir_root)
\ No newline at end of file

--
Gitblit v1.9.1