From ffb05b9ae7eccc47416e9e7fae9dea54d400a245 Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期四, 10 八月 2023 19:05:51 +0800
Subject: [PATCH] Merge branch 'main' of https://github.com/alibaba-damo-academy/FunASR into main

---
 funasr/runtime/websocket/websocket-server-2pass.cpp |   66 +++++++++++++++++++++-----------
 1 files changed, 43 insertions(+), 23 deletions(-)

diff --git a/funasr/runtime/websocket/websocket-server-2pass.cpp b/funasr/runtime/websocket/websocket-server-2pass.cpp
index 75312a3..50f0edc 100644
--- a/funasr/runtime/websocket/websocket-server-2pass.cpp
+++ b/funasr/runtime/websocket/websocket-server-2pass.cpp
@@ -81,6 +81,10 @@
     FUNASR_HANDLE& tpass_online_handle) {
   // lock for each connection
   scoped_lock guard(thread_lock);
+  if(!tpass_online_handle){
+	  LOG(INFO) << "tpass_online_handle  is free, return";
+	  return;
+  }
   FUNASR_RESULT Result = nullptr;
   int asr_mode_ = 2;
   if (msg.contains("mode")) {
@@ -148,8 +152,10 @@
       } catch (std::exception const& e) {
         LOG(ERROR) << e.what();
       }
-      for (auto& vec : punc_cache) {
-        vec.clear();
+      if(punc_cache.size()>0){
+        for (auto& vec : punc_cache) {
+          vec.clear();
+        }
       }
       if (Result) {
         websocketpp::lib::error_code ec;
@@ -180,7 +186,7 @@
       std::make_shared<FUNASR_MESSAGE>();  // put a new data vector for new
                                            // connection
   data_msg->samples = std::make_shared<std::vector<char>>();
-  data_msg->thread_lock = new websocketpp::lib::mutex();
+  data_msg->thread_lock = std::make_shared<websocketpp::lib::mutex>();  
 
   data_msg->msg = nlohmann::json::parse("{}");
   data_msg->msg["wav_format"] = "pcm";
@@ -199,7 +205,7 @@
     websocketpp::connection_hdl hdl,
     std::map<websocketpp::connection_hdl, std::shared_ptr<FUNASR_MESSAGE>,
              std::owner_less<websocketpp::connection_hdl>>& data_map) {
-  // return;
+ 
   std::shared_ptr<FUNASR_MESSAGE> data_msg = nullptr;
   auto it_data = data_map.find(hdl);
   if (it_data != data_map.end()) {
@@ -215,8 +221,9 @@
     FunTpassOnlineUninit(data_msg->tpass_online_handle);
     data_msg->tpass_online_handle = nullptr;
   }
+  
+ 
   guard_decoder.unlock();
-  delete data_msg->thread_lock;
   data_map.erase(hdl);  // remove data vector when  connection is closed
 }
 
@@ -270,7 +277,7 @@
   std::shared_ptr<std::vector<char>> sample_data_p = msg_data->samples;
   std::shared_ptr<std::vector<std::vector<std::string>>> punc_cache_p =
       msg_data->punc_cache;
-  websocketpp::lib::mutex* thread_lock_p = msg_data->thread_lock;
+  std::shared_ptr<websocketpp::lib::mutex> thread_lock_p = msg_data->thread_lock;
 
   lock.unlock();
 
@@ -315,14 +322,20 @@
         LOG(INFO) << "client done";
 
         // if it is in final message, post the sample_data to decode
-        asio::post(
-            io_decoder_,
-            std::bind(&WebSocketServer::do_decoder, this,
-                      std::move(*(sample_data_p.get())), std::move(hdl),
-                      std::ref(msg_data->msg), std::ref(*(punc_cache_p.get())),
-                      std::ref(*thread_lock_p), std::move(true),
-                      msg_data->msg["wav_name"],
-                      std::ref(msg_data->tpass_online_handle)));
+        try{
+          asio::post(
+              io_decoder_,
+              std::bind(&WebSocketServer::do_decoder, this,
+                        std::move(*(sample_data_p.get())), std::move(hdl),
+                        std::ref(msg_data->msg), std::ref(*(punc_cache_p.get())),
+                        std::ref(*thread_lock_p), std::move(true),
+                        msg_data->msg["wav_name"],
+                        std::ref(msg_data->tpass_online_handle)));
+        }
+        catch (std::exception const &e)
+        {
+            LOG(ERROR)<<e.what();
+        }
       }
       break;
     }
@@ -346,15 +359,22 @@
           // keep remain in sample_data
           sample_data_p->erase(sample_data_p->begin(),
                                sample_data_p->begin() + chunksize * setpsize);
-          // post to decode
-          asio::post(io_decoder_,
-                     std::bind(&WebSocketServer::do_decoder, this,
-                               std::move(subvector), std::move(hdl),
-                               std::ref(msg_data->msg),
-                               std::ref(*(punc_cache_p.get())),
-                               std::ref(*thread_lock_p), std::move(false),
-                               msg_data->msg["wav_name"],
-                               std::ref(msg_data->tpass_online_handle)));
+
+          try{
+            // post to decode
+            asio::post(io_decoder_,
+                      std::bind(&WebSocketServer::do_decoder, this,
+                                std::move(subvector), std::move(hdl),
+                                std::ref(msg_data->msg),
+                                std::ref(*(punc_cache_p.get())),
+                                std::ref(*thread_lock_p), std::move(false),
+                                msg_data->msg["wav_name"],
+                                std::ref(msg_data->tpass_online_handle)));
+          }
+          catch (std::exception const &e)
+          {
+              LOG(ERROR)<<e.what();
+          }
         }
       } else {
         sample_data_p->insert(sample_data_p->end(), pcm_data,

--
Gitblit v1.9.1