From 6a59f94278b4d05f5b7715b11b6dd0bfffc28cce Mon Sep 17 00:00:00 2001
From: 志浩 <neo.dzh@alibaba-inc.com>
Date: 星期四, 15 六月 2023 17:12:15 +0800
Subject: [PATCH] Merge branch 'main' into dev_dzh

---
 funasr/runtime/python/websocket/wss_srv_asr.py |   46 +++++++++++++++++++++++++++++++++-------------
 1 files changed, 33 insertions(+), 13 deletions(-)

diff --git a/funasr/runtime/python/websocket/wss_srv_asr.py b/funasr/runtime/python/websocket/wss_srv_asr.py
index 71c97e6..09f2305 100644
--- a/funasr/runtime/python/websocket/wss_srv_asr.py
+++ b/funasr/runtime/python/websocket/wss_srv_asr.py
@@ -35,8 +35,6 @@
     task=Tasks.voice_activity_detection,
     model=args.vad_model,
     model_revision=None,
-    output_dir=None,
-    batch_size=1,
     mode='online',
     ngpu=args.ngpu,
     ncpu=args.ncpu,
@@ -58,15 +56,36 @@
     model=args.asr_model_online,
     ngpu=args.ngpu,
     ncpu=args.ncpu,
-    model_revision='v1.0.4')
+    model_revision='v1.0.4',
+    update_model='v1.0.4',
+    mode='paraformer_streaming')
 
-print("model loaded")
+print("model loaded! only support one client at the same time now!!!!")
 
+async def ws_reset(websocket):
+    print("ws reset now, total num is ",len(websocket_users))
+    websocket.param_dict_asr_online = {"cache": dict()}
+    websocket.param_dict_vad = {'in_cache': dict(), "is_final": True}
+    websocket.param_dict_asr_online["is_final"]=True
+    # audio_in=b''.join(np.zeros(int(16000),dtype=np.int16))
+    # inference_pipeline_vad(audio_in=audio_in, param_dict=websocket.param_dict_vad)
+    # inference_pipeline_asr_online(audio_in=audio_in, param_dict=websocket.param_dict_asr_online)
+    await websocket.close()
+    
+    
+async def clear_websocket():
+   for websocket in websocket_users:
+       await ws_reset(websocket)
+   websocket_users.clear()
+ 
+ 
+       
 async def ws_serve(websocket, path):
     frames = []
     frames_asr = []
     frames_asr_online = []
     global websocket_users
+    await clear_websocket()
     websocket_users.add(websocket)
     websocket.param_dict_asr = {}
     websocket.param_dict_asr_online = {"cache": dict()}
@@ -74,7 +93,7 @@
     websocket.param_dict_punc = {'cache': list()}
     websocket.vad_pre_idx = 0
     speech_start = False
-    speech_end_i = False
+    speech_end_i = -1
     websocket.wav_name = "microphone"
     websocket.mode = "2pass"
     print("new user connected", flush=True)
@@ -103,7 +122,7 @@
         
                     # asr online
                     frames_asr_online.append(message)
-                    websocket.param_dict_asr_online["is_final"] = speech_end_i
+                    websocket.param_dict_asr_online["is_final"] = speech_end_i != -1
                     if len(frames_asr_online) % websocket.chunk_interval == 0 or websocket.param_dict_asr_online["is_final"]:
                         if websocket.mode == "2pass" or websocket.mode == "online":
                             audio_in = b"".join(frames_asr_online)
@@ -113,14 +132,14 @@
                         frames_asr.append(message)
                     # vad online
                     speech_start_i, speech_end_i = await async_vad(websocket, message)
-                    if speech_start_i:
+                    if speech_start_i != -1:
                         speech_start = True
                         beg_bias = (websocket.vad_pre_idx-speech_start_i)//duration_ms
                         frames_pre = frames[-beg_bias:]
                         frames_asr = []
                         frames_asr.extend(frames_pre)
                 # asr punc offline
-                if speech_end_i or not websocket.is_speaking:
+                if speech_end_i != -1 or not websocket.is_speaking:
                     # print("vad end point")
                     if websocket.mode == "2pass" or websocket.mode == "offline":
                         audio_in = b"".join(frames_asr)
@@ -138,7 +157,8 @@
 
      
     except websockets.ConnectionClosed:
-        print("ConnectionClosed...", websocket_users)
+        print("ConnectionClosed...", websocket_users,flush=True)
+        await ws_reset(websocket)
         websocket_users.remove(websocket)
     except websockets.InvalidState:
         print("InvalidState...")
@@ -150,15 +170,15 @@
 
     segments_result = inference_pipeline_vad(audio_in=audio_in, param_dict=websocket.param_dict_vad)
 
-    speech_start = False
-    speech_end = False
+    speech_start = -1
+    speech_end = -1
     
     if len(segments_result) == 0 or len(segments_result["text"]) > 1:
         return speech_start, speech_end
     if segments_result["text"][0][0] != -1:
         speech_start = segments_result["text"][0][0]
     if segments_result["text"][0][1] != -1:
-        speech_end = True
+        speech_end = segments_result["text"][0][1]
     return speech_start, speech_end
 
 
@@ -207,4 +227,4 @@
 else:
     start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None)
 asyncio.get_event_loop().run_until_complete(start_server)
-asyncio.get_event_loop().run_forever()
\ No newline at end of file
+asyncio.get_event_loop().run_forever()

--
Gitblit v1.9.1