From ec0e75ea8bf444d0ebff71a1b37f1a9b10f071a8 Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期三, 09 八月 2023 20:23:57 +0800
Subject: [PATCH] Merge branch 'main' of https://github.com/alibaba-damo-academy/FunASR into main

---
 funasr/runtime/websocket/websocket-server-2pass.cpp |   30 ++++++++++++------------------
 1 files changed, 12 insertions(+), 18 deletions(-)

diff --git a/funasr/runtime/websocket/websocket-server-2pass.cpp b/funasr/runtime/websocket/websocket-server-2pass.cpp
index 8833b0b..7df6341 100644
--- a/funasr/runtime/websocket/websocket-server-2pass.cpp
+++ b/funasr/runtime/websocket/websocket-server-2pass.cpp
@@ -53,30 +53,23 @@
   return ctx;
 }
 
-nlohmann::json handle_result(FUNASR_RESULT result, std::string& online_res,
-                             std::string& tpass_res, nlohmann::json msg) {
+nlohmann::json handle_result(FUNASR_RESULT result) {
 
     websocketpp::lib::error_code ec;
     nlohmann::json jsonresult;
     jsonresult["text"]="";
 
     std::string tmp_online_msg = FunASRGetResult(result, 0);
-    online_res += tmp_online_msg;
     if (tmp_online_msg != "") {
       LOG(INFO) << "online_res :" << tmp_online_msg;
       jsonresult["text"] = tmp_online_msg; 
       jsonresult["mode"] = "2pass-online";
     }
     std::string tmp_tpass_msg = FunASRGetTpassResult(result, 0);
-    tpass_res += tmp_tpass_msg;
     if (tmp_tpass_msg != "") {
       LOG(INFO) << "offline results : " << tmp_tpass_msg;
       jsonresult["text"] = tmp_tpass_msg; 
       jsonresult["mode"] = "2pass-offline";    
-    }
-
-    if (msg.contains("wav_name")) {
-      jsonresult["wav_name"] = msg["wav_name"];
     }
 
     return jsonresult;
@@ -86,8 +79,7 @@
     std::vector<char>& buffer, websocketpp::connection_hdl& hdl,
     nlohmann::json& msg, std::vector<std::vector<std::string>>& punc_cache,
     websocketpp::lib::mutex& thread_lock, bool& is_final,
-    FUNASR_HANDLE& tpass_online_handle, std::string& online_res,
-    std::string& tpass_res) {
+    std::string wav_name, FUNASR_HANDLE& tpass_online_handle) {
  
   // lock for each connection
   scoped_lock guard(thread_lock);
@@ -127,7 +119,8 @@
       if (Result) {
         websocketpp::lib::error_code ec;
         nlohmann::json jsonresult =
-            handle_result(Result, online_res, tpass_res, msg["wav_name"]);
+            handle_result(Result);
+        jsonresult["wav_name"] = wav_name;
         jsonresult["is_final"] = false;
         if(jsonresult["text"] != "") {
           if (is_ssl) {
@@ -158,7 +151,8 @@
       if (Result) {
         websocketpp::lib::error_code ec;
         nlohmann::json jsonresult =
-            handle_result(Result, online_res, tpass_res, msg["wav_name"]);
+            handle_result(Result);
+        jsonresult["wav_name"] = wav_name;
         jsonresult["is_final"] = true;
         if (is_ssl) {
           wss_server_->send(hdl, jsonresult.dump(),
@@ -212,6 +206,7 @@
     return;
   }
   scoped_lock guard_decoder(*(data_msg->thread_lock));  //wait for do_decoder finished and avoid access freed tpass_online_handle 
+  LOG(INFO) << "----------------FunTpassOnlineUninit----------------------";
   FunTpassOnlineUninit(data_msg->tpass_online_handle);
   data_map.erase(hdl);  // remove data vector when  connection is closed
   LOG(INFO) << "on_close, active connections: "<< data_map.size();
@@ -288,6 +283,7 @@
       if (jsonresult.contains("chunk_size")){
         if(msg_data->tpass_online_handle == NULL){
           std::vector<int> chunk_size_vec = jsonresult["chunk_size"].get<std::vector<int>>();
+          LOG(INFO) << "----------------FunTpassOnlineInit----------------------";
           FUNASR_HANDLE tpass_online_handle =
               FunTpassOnlineInit(tpass_handle, chunk_size_vec);
           msg_data->tpass_online_handle = tpass_online_handle;
@@ -306,9 +302,8 @@
                       std::move(*(sample_data_p.get())), std::move(hdl),
                       std::ref(msg_data->msg), std::ref(*(punc_cache_p.get())),
                       std::ref(*thread_lock_p), std::move(true),
-                      std::ref(msg_data->tpass_online_handle),
-                      std::ref(msg_data->online_res),
-                      std::ref(msg_data->tpass_res)));
+                      msg_data->msg["wav_name"],
+                      std::ref(msg_data->tpass_online_handle)));
       }
       break;
     }
@@ -338,9 +333,8 @@
                                   std::ref(msg_data->msg),
                                   std::ref(*(punc_cache_p.get())),
                                   std::ref(*thread_lock_p), std::move(false),
-                                  std::ref(msg_data->tpass_online_handle),
-                                  std::ref(msg_data->online_res),
-                                  std::ref(msg_data->tpass_res)));
+                                  msg_data->msg["wav_name"],
+                                  std::ref(msg_data->tpass_online_handle)));
         }
       } else {
         sample_data_p->insert(sample_data_p->end(), pcm_data,

--
Gitblit v1.9.1