From c7fc6149b3c5c2de3107c4f1d4983309882d1a1a Mon Sep 17 00:00:00 2001
From: nichongjia-2007 <nichongjia@gmail.com>
Date: 星期三, 07 六月 2023 14:57:49 +0800
Subject: [PATCH] Merge branch 'main' of https://github.com/alibaba-damo-academy/FunASR
---
funasr/runtime/python/websocket/wss_srv_asr.py | 41 +++++++++++++++++++++++++++++++----------
1 files changed, 31 insertions(+), 10 deletions(-)
diff --git a/funasr/runtime/python/websocket/wss_srv_asr.py b/funasr/runtime/python/websocket/wss_srv_asr.py
index 6460fbf..3810cd6 100644
--- a/funasr/runtime/python/websocket/wss_srv_asr.py
+++ b/funasr/runtime/python/websocket/wss_srv_asr.py
@@ -58,16 +58,36 @@
model=args.asr_model_online,
ngpu=args.ngpu,
ncpu=args.ncpu,
- model_revision='v1.0.6',
+ model_revision='v1.0.4',
+ update_model='v1.0.4',
mode='paraformer_streaming')
-print("model loaded")
+print("model loaded! only support one client at the same time now!!!!")
+async def ws_reset(websocket):
+ print("ws reset now, total num is ",len(websocket_users))
+ websocket.param_dict_asr_online = {"cache": dict()}
+ websocket.param_dict_vad = {'in_cache': dict(), "is_final": True}
+ websocket.param_dict_asr_online["is_final"]=True
+ audio_in=b''.join(np.zeros(int(16000),dtype=np.int16))
+ inference_pipeline_vad(audio_in=audio_in, param_dict=websocket.param_dict_vad)
+ inference_pipeline_asr_online(audio_in=audio_in, param_dict=websocket.param_dict_asr_online)
+ await websocket.close()
+
+
+async def clear_websocket():
+ for websocket in websocket_users:
+ await ws_reset(websocket)
+ websocket_users.clear()
+
+
+
async def ws_serve(websocket, path):
frames = []
frames_asr = []
frames_asr_online = []
global websocket_users
+ await clear_websocket()
websocket_users.add(websocket)
websocket.param_dict_asr = {}
websocket.param_dict_asr_online = {"cache": dict()}
@@ -75,7 +95,7 @@
websocket.param_dict_punc = {'cache': list()}
websocket.vad_pre_idx = 0
speech_start = False
- speech_end_i = False
+ speech_end_i = -1
websocket.wav_name = "microphone"
websocket.mode = "2pass"
print("new user connected", flush=True)
@@ -104,7 +124,7 @@
# asr online
frames_asr_online.append(message)
- websocket.param_dict_asr_online["is_final"] = speech_end_i
+ websocket.param_dict_asr_online["is_final"] = speech_end_i != -1
if len(frames_asr_online) % websocket.chunk_interval == 0 or websocket.param_dict_asr_online["is_final"]:
if websocket.mode == "2pass" or websocket.mode == "online":
audio_in = b"".join(frames_asr_online)
@@ -114,14 +134,14 @@
frames_asr.append(message)
# vad online
speech_start_i, speech_end_i = await async_vad(websocket, message)
- if speech_start_i:
+ if speech_start_i != -1:
speech_start = True
beg_bias = (websocket.vad_pre_idx-speech_start_i)//duration_ms
frames_pre = frames[-beg_bias:]
frames_asr = []
frames_asr.extend(frames_pre)
# asr punc offline
- if speech_end_i or not websocket.is_speaking:
+ if speech_end_i != -1 or not websocket.is_speaking:
# print("vad end point")
if websocket.mode == "2pass" or websocket.mode == "offline":
audio_in = b"".join(frames_asr)
@@ -139,7 +159,8 @@
except websockets.ConnectionClosed:
- print("ConnectionClosed...", websocket_users)
+ print("ConnectionClosed...", websocket_users,flush=True)
+ await ws_reset(websocket)
websocket_users.remove(websocket)
except websockets.InvalidState:
print("InvalidState...")
@@ -151,15 +172,15 @@
segments_result = inference_pipeline_vad(audio_in=audio_in, param_dict=websocket.param_dict_vad)
- speech_start = False
- speech_end = False
+ speech_start = -1
+ speech_end = -1
if len(segments_result) == 0 or len(segments_result["text"]) > 1:
return speech_start, speech_end
if segments_result["text"][0][0] != -1:
speech_start = segments_result["text"][0][0]
if segments_result["text"][0][1] != -1:
- speech_end = True
+ speech_end = segments_result["text"][0][1]
return speech_start, speech_end
--
Gitblit v1.9.1