From fa2f52caeaf6ad4b7624f53d4d9207b89edea5a6 Mon Sep 17 00:00:00 2001
From: Yabin Li <wucong.lyb@alibaba-inc.com>
Date: 星期三, 05 七月 2023 10:21:38 +0800
Subject: [PATCH] Update SDK_advanced_guide_offline_zh.md

---
 funasr/runtime/python/websocket/wss_client_asr.py |   86 +++++++++++++++++++++++++++++-------------
 1 files changed, 59 insertions(+), 27 deletions(-)

diff --git a/funasr/runtime/python/websocket/wss_client_asr.py b/funasr/runtime/python/websocket/wss_client_asr.py
index dec598a..dcd9576 100644
--- a/funasr/runtime/python/websocket/wss_client_asr.py
+++ b/funasr/runtime/python/websocket/wss_client_asr.py
@@ -40,12 +40,12 @@
                     help="audio_in")
 parser.add_argument("--send_without_sleep",
                     action="store_true",
-                    default=False,
+                    default=True,
                     help="if audio_in is set, send_without_sleep")
-parser.add_argument("--test_thread_num",
+parser.add_argument("--thread_num",
                     type=int,
                     default=1,
-                    help="test_thread_num")
+                    help="thread_num")
 parser.add_argument("--words_max_print",
                     type=int,
                     default=10000,
@@ -71,7 +71,8 @@
 from queue import Queue
 
 voices = Queue()
-
+offline_msg_done=False
+ 
 ibest_writer = None
 if args.output_dir is not None:
     writer = DatadirWriter(args.output_dir)
@@ -118,9 +119,11 @@
         wavs = wavs[chunk_begin:chunk_begin + chunk_size]
     for wav in wavs:
         wav_splits = wav.strip().split()
+ 
         wav_name = wav_splits[0] if len(wav_splits) > 1 else "demo"
         wav_path = wav_splits[1] if len(wav_splits) > 1 else wav_splits[0]
-
+        if not len(wav_path.strip())>0:
+           continue
         if wav_path.endswith(".pcm"):
             with open(wav_path, "rb") as f:
                 audio_bytes = f.read()
@@ -142,22 +145,38 @@
         # send first time
         message = json.dumps({"mode": args.mode, "chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval,
                               "wav_name": wav_name, "is_speaking": True})
-        voices.put(message)
+        #voices.put(message)
+        await websocket.send(message)
         is_speaking = True
         for i in range(chunk_num):
 
             beg = i * stride
             data = audio_bytes[beg:beg + stride]
             message = data
-            voices.put(message)
+            #voices.put(message)
+            await websocket.send(message)
             if i == chunk_num - 1:
                 is_speaking = False
                 message = json.dumps({"is_speaking": is_speaking})
-                voices.put(message)
-            # print("data_chunk: ", len(data_chunk))
-            # print(voices.qsize())
-            sleep_duration = 0.001 if args.send_without_sleep else 60 * args.chunk_size[1] / args.chunk_interval / 1000
+                #voices.put(message)
+                await websocket.send(message)
+ 
+            sleep_duration = 0.001 if args.mode == "offline" else 60 * args.chunk_size[1] / args.chunk_interval / 1000
+            
             await asyncio.sleep(sleep_duration)
+    # when all data sent, we need to close websocket
+    while not voices.empty():
+         await asyncio.sleep(1)
+    await asyncio.sleep(3)
+    # offline model need to wait for message recved
+    
+    if args.mode=="offline":
+      global offline_msg_done
+      while  not  offline_msg_done:
+         await asyncio.sleep(1)
+    
+    await websocket.close()
+
 
 async def ws_send():
     global voices
@@ -176,17 +195,19 @@
             await asyncio.sleep(0.005)
         await asyncio.sleep(0.005)
 
+ 
+             
 async def message(id):
-    global websocket
+    global websocket,voices,offline_msg_done
     text_print = ""
     text_print_2pass_online = ""
     text_print_2pass_offline = ""
-    while True:
-        try:
+    try:
+       while True:
+        
             meg = await websocket.recv()
             meg = json.loads(meg)
             wav_name = meg.get("wav_name", "demo")
-            # print(wav_name)
             text = meg["text"]
             if ibest_writer is not None:
                 ibest_writer["text"][wav_name] = text
@@ -201,6 +222,7 @@
                 text_print = text_print[-args.words_max_print:]
                 os.system('clear')
                 print("\rpid" + str(id) + ": " + text_print)
+                offline_msg_done=True
             else:
                 if meg["mode"] == "2pass-online":
                     text_print_2pass_online += "{}".format(text)
@@ -213,10 +235,11 @@
                 os.system('clear')
                 print("\rpid" + str(id) + ": " + text_print)
 
-        except Exception as e:
+    except Exception as e:
             print("Exception:", e)
-            traceback.print_exc()
-            exit(0)
+            #traceback.print_exc()
+            #await websocket.close()
+ 
 
 
 async def print_messge():
@@ -228,11 +251,18 @@
             print(meg)
         except Exception as e:
             print("Exception:", e)
-            traceback.print_exc()
+            #traceback.print_exc()
             exit(0)
 
 async def ws_client(id, chunk_begin, chunk_size):
-    global websocket
+  if args.audio_in is None:
+       chunk_begin=0
+       chunk_size=1
+  global websocket,voices,offline_msg_done
+ 
+  for i in range(chunk_begin,chunk_begin+chunk_size):
+    offline_msg_done=False
+    voices = Queue()
     if args.ssl == 1:
         ssl_context = ssl.SSLContext()
         ssl_context.check_hostname = False
@@ -242,14 +272,16 @@
         uri = "ws://{}:{}".format(args.host, args.port)
         ssl_context = None
     print("connect to", uri)
-    async for websocket in websockets.connect(uri, subprotocols=["binary"], ping_interval=None, ssl=ssl_context):
+    async with websockets.connect(uri, subprotocols=["binary"], ping_interval=None, ssl=ssl_context) as websocket:
         if args.audio_in is not None:
-            task = asyncio.create_task(record_from_scp(chunk_begin, chunk_size))
+            task = asyncio.create_task(record_from_scp(i, 1))
         else:
             task = asyncio.create_task(record_microphone())
         task2 = asyncio.create_task(ws_send())
-        task3 = asyncio.create_task(message(id))
+        task3 = asyncio.create_task(message(str(id)+"_"+str(i))) #processid+fileid
         await asyncio.gather(task, task2, task3)
+  exit(0)
+    
 
 def one_thread(id, chunk_begin, chunk_size):
     asyncio.get_event_loop().run_until_complete(ws_client(id, chunk_begin, chunk_size))
@@ -279,16 +311,16 @@
                     f'Not supported audio type: {audio_type}')
 
         total_len = len(wavs)
-        if total_len >= args.test_thread_num:
-            chunk_size = int(total_len / args.test_thread_num)
-            remain_wavs = total_len - chunk_size * args.test_thread_num
+        if total_len >= args.thread_num:
+            chunk_size = int(total_len / args.thread_num)
+            remain_wavs = total_len - chunk_size * args.thread_num
         else:
             chunk_size = 1
             remain_wavs = 0
 
         process_list = []
         chunk_begin = 0
-        for i in range(args.test_thread_num):
+        for i in range(args.thread_num):
             now_chunk_size = chunk_size
             if remain_wavs > 0:
                 now_chunk_size = chunk_size + 1

--
Gitblit v1.9.1