From f03a604204bbe0c79e53b01237a37e88683938c6 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期六, 13 五月 2023 00:17:32 +0800
Subject: [PATCH] Merge pull request #505 from zhaomingwork/cpp-python-websocket-compatible

---
 funasr/runtime/websocket/websocketsrv.cpp           |   60 ++++++++++----
 funasr/runtime/python/websocket/ws_client.py        |   56 ++++++++++---
 funasr/runtime/websocket/websocketsrv.h             |   11 ++
 funasr/runtime/python/websocket/ws_server_online.py |   41 ++++++---
 funasr/runtime/websocket/websocketclient.cpp        |   20 ++++
 5 files changed, 139 insertions(+), 49 deletions(-)

diff --git a/funasr/runtime/python/websocket/ws_client.py b/funasr/runtime/python/websocket/ws_client.py
index a4a6d9f..7ae44df 100644
--- a/funasr/runtime/python/websocket/ws_client.py
+++ b/funasr/runtime/python/websocket/ws_client.py
@@ -84,18 +84,20 @@
                     rate=RATE,
                     input=True,
                     frames_per_buffer=CHUNK)
+
+    message = json.dumps({"chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval, "wav_name": wav_name,"is_speaking": True})
+    voices.put(message)
     is_speaking = True
     while True:
 
         data = stream.read(CHUNK)
-        data = data.decode('ISO-8859-1')
-        message = json.dumps({"chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval, "audio": data, "is_speaking": is_speaking, "is_finished": is_finished})
+        message = data  
         
         voices.put(message)
 
         await asyncio.sleep(0.005)
 
-async def record_from_scp():
+async def record_from_scp(chunk_begin,chunk_size):
     import wave
     global voices
     is_finished = False
@@ -104,6 +106,8 @@
         wavs = f_scp.readlines()
     else:
         wavs = [args.audio_in]
+    if chunk_size>0:
+        wavs=wavs[chunk_begin:chunk_begin+chunk_size]
     for wav in wavs:
         wav_splits = wav.strip().split()
         wav_name = wav_splits[0] if len(wav_splits) > 1 else "demo"
@@ -122,15 +126,21 @@
         stride = int(60*args.chunk_size[1]/args.chunk_interval/1000*16000*2)
         chunk_num = (len(audio_bytes)-1)//stride + 1
         # print(stride)
+        
+        # send first time
+        message = json.dumps({"chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval, "wav_name": wav_name,"is_speaking": True})
+        voices.put(message)
         is_speaking = True
         for i in range(chunk_num):
-            if i == chunk_num-1:
-                is_speaking = False
+
             beg = i*stride
             data = audio_bytes[beg:beg+stride]
-            data = data.decode('ISO-8859-1')
-            message = json.dumps({"chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval, "is_speaking": is_speaking, "audio": data, "is_finished": is_finished, "wav_name": wav_name})
+            message = data  
             voices.put(message)
+            if i == chunk_num-1:
+                is_speaking = False
+                message = json.dumps({"is_speaking": is_speaking})
+                voices.put(message)
             # print("data_chunk: ", len(data_chunk))
             # print(voices.qsize())
             sleep_duration = 0.001 if args.send_without_sleep else 60*args.chunk_size[1]/args.chunk_interval/1000
@@ -213,27 +223,47 @@
             traceback.print_exc()
             exit(0)
 
-async def ws_client(id):
+async def ws_client(id,chunk_begin,chunk_size):
     global websocket
     uri = "ws://{}:{}".format(args.host, args.port)
     async for websocket in websockets.connect(uri, subprotocols=["binary"], ping_interval=None):
         if args.audio_in is not None:
-            task = asyncio.create_task(record_from_scp())
+            task = asyncio.create_task(record_from_scp(chunk_begin,chunk_size))
         else:
             task = asyncio.create_task(record_microphone())
         task2 = asyncio.create_task(ws_send())
         task3 = asyncio.create_task(message(id))
         await asyncio.gather(task, task2, task3)
 
-def one_thread(id):
-   asyncio.get_event_loop().run_until_complete(ws_client(id))
+def one_thread(id,chunk_begin,chunk_size):
+   asyncio.get_event_loop().run_until_complete(ws_client(id,chunk_begin,chunk_size))
    asyncio.get_event_loop().run_forever()
 
 
 if __name__ == '__main__':
+    # calculate the number of wavs for each process
+    if args.audio_in.endswith(".scp"):
+        f_scp = open(args.audio_in)
+        wavs = f_scp.readlines()
+    else:
+        wavs = [args.audio_in]
+    total_len=len(wavs)
+    if total_len>=args.test_thread_num:
+         chunk_size=int((total_len)/args.test_thread_num)
+         remain_wavs=total_len-chunk_size*args.test_thread_num
+    else:
+         chunk_size=0
+    
     process_list = []
-    for i in range(args.test_thread_num):   
-        p = Process(target=one_thread,args=(i,))
+    chunk_begin=0
+    for i in range(args.test_thread_num):
+        now_chunk_size= chunk_size
+        if remain_wavs>0:
+            now_chunk_size=chunk_size+1
+            remain_wavs=remain_wavs-1
+        # process i handles wavs starting at chunk_begin with size now_chunk_size
+        p = Process(target=one_thread,args=(i,chunk_begin,now_chunk_size))
+        chunk_begin=chunk_begin+now_chunk_size
         p.start()
         process_list.append(p)
 
diff --git a/funasr/runtime/python/websocket/ws_server_online.py b/funasr/runtime/python/websocket/ws_server_online.py
index 3c0fb16..44edf98 100644
--- a/funasr/runtime/python/websocket/ws_server_online.py
+++ b/funasr/runtime/python/websocket/ws_server_online.py
@@ -41,25 +41,37 @@
     global websocket_users
     websocket_users.add(websocket)
     websocket.param_dict_asr_online = {"cache": dict()}
-
+    print("new user connected",flush=True)
     try:
         async for message in websocket:
-            message = json.loads(message)
-            is_finished = message["is_finished"]
-            if not is_finished:
-                audio = bytes(message['audio'], 'ISO-8859-1')
-
-                is_speaking = message["is_speaking"]
-                websocket.param_dict_asr_online["is_final"] = not is_speaking
-                websocket.wav_name = message.get("wav_name", "demo")
-                websocket.param_dict_asr_online["chunk_size"] = message["chunk_size"]
-                
-                frames_asr_online.append(audio)
-                if len(frames_asr_online) % message["chunk_interval"] == 0 or not is_speaking:
+            
+ 
+            if isinstance(message,str):
+              messagejson = json.loads(message)
+               
+              if "is_speaking" in messagejson:
+                  websocket.is_speaking = messagejson["is_speaking"]  
+                  websocket.param_dict_asr_online["is_final"] = not websocket.is_speaking
+              if "is_finished" in messagejson:
+                  websocket.is_speaking = False
+                  websocket.param_dict_asr_online["is_final"] = True
+              if "chunk_interval" in messagejson:
+                  websocket.chunk_interval=messagejson["chunk_interval"]
+              if "wav_name" in messagejson:
+                  websocket.wav_name = messagejson.get("wav_name", "demo")
+              if "chunk_size" in messagejson:
+                  websocket.param_dict_asr_online["chunk_size"] = messagejson["chunk_size"]
+            # if has bytes in buffer or message is bytes
+            if len(frames_asr_online)>0 or not isinstance(message,str):
+               if not isinstance(message,str):
+                 frames_asr_online.append(message)
+               if len(frames_asr_online) % websocket.chunk_interval == 0 or not websocket.is_speaking:
                     audio_in = b"".join(frames_asr_online)
+                    if not websocket.is_speaking:
+                       # padding 0.5s at end to guarantee that asr engine can fire out last word
+                       audio_in=audio_in+b''.join(np.zeros(int(16000*0.5),dtype=np.int16))
                     await async_asr_online(websocket,audio_in)
                     frames_asr_online = []
-
 
      
     except websockets.ConnectionClosed:
@@ -69,6 +81,7 @@
         print("InvalidState...")
     except Exception as e:
         print("Exception:", e)
+
  
 async def async_asr_online(websocket,audio_in):
             if len(audio_in) > 0:
diff --git a/funasr/runtime/websocket/websocketclient.cpp b/funasr/runtime/websocket/websocketclient.cpp
index 3ab4e99..078fc5a 100644
--- a/funasr/runtime/websocket/websocketclient.cpp
+++ b/funasr/runtime/websocket/websocketclient.cpp
@@ -13,6 +13,7 @@
 #include <websocketpp/config/asio_no_tls_client.hpp>
 
 #include "audio.h"
+#include "nlohmann/json.hpp"
 
 /**
  * Define a semi-cross platform helper method that waits/sleeps for a bit.
@@ -156,6 +157,19 @@
       }
     }
     websocketpp::lib::error_code ec;
+
+    nlohmann::json jsonbegin;
+    nlohmann::json chunk_size = nlohmann::json::array();
+    chunk_size.push_back(5);
+    chunk_size.push_back(0);
+    chunk_size.push_back(5);
+    jsonbegin["chunk_size"] = chunk_size;
+    jsonbegin["chunk_interval"] = 10;
+    jsonbegin["wav_name"] = "damo";
+    jsonbegin["is_speaking"] = true;
+    m_client.send(m_hdl, jsonbegin.dump(), websocketpp::frame::opcode::text,
+                  ec);
+
     // fetch wav data use asr engine api
     while (audio.Fetch(buff, len, flag) > 0) {
       short iArray[len];
@@ -181,8 +195,10 @@
 
       wait_a_bit();
     }
-
-    m_client.send(m_hdl, "Done", websocketpp::frame::opcode::text, ec);
+    nlohmann::json jsonresult;
+    jsonresult["is_speaking"] = false;
+    m_client.send(m_hdl, jsonresult.dump(), websocketpp::frame::opcode::text,
+                  ec);
     wait_a_bit();
   }
 
diff --git a/funasr/runtime/websocket/websocketsrv.cpp b/funasr/runtime/websocket/websocketsrv.cpp
index 1a6adbf..598ad3d 100644
--- a/funasr/runtime/websocket/websocketsrv.cpp
+++ b/funasr/runtime/websocket/websocketsrv.cpp
@@ -34,6 +34,14 @@
       websocketpp::lib::error_code ec;
       nlohmann::json jsonresult;        // result json
       jsonresult["text"] = asr_result;  // put result in 'text'
+      jsonresult["mode"] = "offline";
+      std::shared_ptr<FUNASR_MESSAGE> msg_data = nullptr;
+      auto it_data = data_map.find(hdl);
+      if (it_data != data_map.end()) {
+        msg_data = it_data->second;
+      }
+
+      jsonresult["wav_name"] = msg_data->msg["wav_name"];
 
       // send the json to client
       server_->send(hdl, jsonresult.dump(), websocketpp::frame::opcode::text,
@@ -43,7 +51,7 @@
                 << ",result json=" << jsonresult.dump() << std::endl;
       if (!isonline) {
         //  close the client if it is not online asr
-        server_->close(hdl, websocketpp::close::status::normal, "DONE", ec);
+        // server_->close(hdl, websocketpp::close::status::normal, "DONE", ec);
         // fout.close();
       }
     }
@@ -56,25 +64,28 @@
 void WebSocketServer::on_open(websocketpp::connection_hdl hdl) {
   scoped_lock guard(m_lock);     // for threads safty
   check_and_clean_connection();  // remove closed connection
-  sample_map.emplace(
-      hdl, std::make_shared<std::vector<char>>());  // put a new data vector for
-                                                    // new connection
-  std::cout << "on_open, active connections: " << sample_map.size()
-            << std::endl;
+
+  std::shared_ptr<FUNASR_MESSAGE> data_msg =
+      std::make_shared<FUNASR_MESSAGE>();  // put a new data vector for new
+                                           // connection
+  data_msg->samples = std::make_shared<std::vector<char>>();
+  data_msg->msg = nlohmann::json::parse("{}");
+  data_map.emplace(hdl, data_msg);
+  std::cout << "on_open, active connections: " << data_map.size() << std::endl;
 }
 
 void WebSocketServer::on_close(websocketpp::connection_hdl hdl) {
   scoped_lock guard(m_lock);
-  sample_map.erase(hdl);  // remove data vector when  connection is closed
-  std::cout << "on_close, active connections: " << sample_map.size()
-            << std::endl;
+  data_map.erase(hdl);  // remove data vector when connection is closed
+
+  std::cout << "on_close, active connections: " << data_map.size() << std::endl;
 }
 
 // remove closed connection
 void WebSocketServer::check_and_clean_connection() {
   std::vector<websocketpp::connection_hdl> to_remove;  // remove list
-  auto iter = sample_map.begin();
-  while (iter != sample_map.end()) {  // loop to find closed connection
+  auto iter = data_map.begin();
+  while (iter != data_map.end()) {  // loop to find closed connection
     websocketpp::connection_hdl hdl = iter->first;
     server::connection_ptr con = server_->get_con_from_hdl(hdl);
     if (con->get_state() != 1) {  // session::state::open ==1
@@ -83,7 +94,7 @@
     iter++;
   }
   for (auto hdl : to_remove) {
-    sample_map.erase(hdl);
+    data_map.erase(hdl);
     std::cout << "remove one connection " << std::endl;
   }
 }
@@ -91,12 +102,15 @@
                                  message_ptr msg) {
   unique_lock lock(m_lock);
   // find the sample data vector according to one connection
-  std::shared_ptr<std::vector<char>> sample_data_p = nullptr;
 
-  auto it = sample_map.find(hdl);
-  if (it != sample_map.end()) {
-    sample_data_p = it->second;
+  std::shared_ptr<FUNASR_MESSAGE> msg_data = nullptr;
+
+  auto it_data = data_map.find(hdl);
+  if (it_data != data_map.end()) {
+    msg_data = it_data->second;
   }
+  std::shared_ptr<std::vector<char>> sample_data_p = msg_data->samples;
+
   lock.unlock();
   if (sample_data_p == nullptr) {
     std::cout << "error when fetch sample data vector" << std::endl;
@@ -106,13 +120,22 @@
   const std::string& payload = msg->get_payload();  // get msg type
 
   switch (msg->get_opcode()) {
-    case websocketpp::frame::opcode::text:
-      if (payload == "Done") {
+    case websocketpp::frame::opcode::text: {
+      nlohmann::json jsonresult = nlohmann::json::parse(payload);
+      if (jsonresult["wav_name"] != nullptr) {
+        msg_data->msg["wav_name"] = jsonresult["wav_name"];
+      }
+      if (jsonresult["is_speaking"] == false ||
+          jsonresult["is_finished"] == true) {
         std::cout << "client done" << std::endl;
 
         if (isonline) {
           // do_close(ws);
         } else {
+          // add padding to the end of the wav data
+          std::vector<short> padding(static_cast<short>(0.3 * 16000));
+          sample_data_p->insert(sample_data_p->end(), padding.data(),
+                                padding.data() + padding.size());
           // for offline, send all receive data to decoder engine
           asio::post(io_decoder_, std::bind(&WebSocketServer::do_decoder, this,
                                             std::move(*(sample_data_p.get())),
@@ -120,6 +143,7 @@
         }
       }
       break;
+    }
     case websocketpp::frame::opcode::binary: {
       // recived binary data
       const auto* pcm_data = static_cast<const char*>(payload.data());
diff --git a/funasr/runtime/websocket/websocketsrv.h b/funasr/runtime/websocket/websocketsrv.h
index e484724..1899f57 100644
--- a/funasr/runtime/websocket/websocketsrv.h
+++ b/funasr/runtime/websocket/websocketsrv.h
@@ -46,6 +46,11 @@
   float snippet_time;
 } FUNASR_RECOG_RESULT;
 
+typedef struct {
+  nlohmann::json msg;
+  std::shared_ptr<std::vector<char>> samples;
+} FUNASR_MESSAGE;
+
 class WebSocketServer {
  public:
   WebSocketServer(asio::io_context& io_decoder, server* server_)
@@ -84,9 +89,11 @@
 
   // use map to keep the received samples data from one connection in offline
   // engine. if for online engline, a data struct is needed(TODO)
-  std::map<websocketpp::connection_hdl, std::shared_ptr<std::vector<char>>,
+ 
+
+  std::map<websocketpp::connection_hdl, std::shared_ptr<FUNASR_MESSAGE>,
            std::owner_less<websocketpp::connection_hdl>>
-      sample_map;
+      data_map;
   websocketpp::lib::mutex m_lock;  // mutex for sample_map
 };
 

--
Gitblit v1.9.1