雾聪
2023-08-09 08dc101ee2044a8811148e4b40ba119b24861b0f
funasr/runtime/websocket/websocket-server-2pass.cpp
@@ -53,22 +53,19 @@
  return ctx;
}
nlohmann::json handle_result(FUNASR_RESULT result, std::string& online_res,
                             std::string& tpass_res, nlohmann::json msg) {
nlohmann::json handle_result(FUNASR_RESULT result, nlohmann::json msg) {
    websocketpp::lib::error_code ec;
    nlohmann::json jsonresult;
    jsonresult["text"]="";
    std::string tmp_online_msg = FunASRGetResult(result, 0);
    online_res += tmp_online_msg;
    if (tmp_online_msg != "") {
      LOG(INFO) << "online_res :" << tmp_online_msg;
      jsonresult["text"] = tmp_online_msg; 
      jsonresult["mode"] = "2pass-online";
    }
    std::string tmp_tpass_msg = FunASRGetTpassResult(result, 0);
    tpass_res += tmp_tpass_msg;
    if (tmp_tpass_msg != "") {
      LOG(INFO) << "offline results : " << tmp_tpass_msg;
      jsonresult["text"] = tmp_tpass_msg; 
@@ -86,8 +83,7 @@
    std::vector<char>& buffer, websocketpp::connection_hdl& hdl,
    nlohmann::json& msg, std::vector<std::vector<std::string>>& punc_cache,
    websocketpp::lib::mutex& thread_lock, bool& is_final,
    FUNASR_HANDLE& tpass_online_handle, std::string& online_res,
    std::string& tpass_res) {
    FUNASR_HANDLE& tpass_online_handle) {
 
  // lock for each connection
  scoped_lock guard(thread_lock);
@@ -127,7 +123,7 @@
      if (Result) {
        websocketpp::lib::error_code ec;
        nlohmann::json jsonresult =
            handle_result(Result, online_res, tpass_res, msg["wav_name"]);
            handle_result(Result, msg["wav_name"]);
        jsonresult["is_final"] = false;
        if(jsonresult["text"] != "") {
          if (is_ssl) {
@@ -158,7 +154,7 @@
      if (Result) {
        websocketpp::lib::error_code ec;
        nlohmann::json jsonresult =
            handle_result(Result, online_res, tpass_res, msg["wav_name"]);
            handle_result(Result, msg["wav_name"]);
        jsonresult["is_final"] = true;
        if (is_ssl) {
          wss_server_->send(hdl, jsonresult.dump(),
@@ -306,9 +302,7 @@
                      std::move(*(sample_data_p.get())), std::move(hdl),
                      std::ref(msg_data->msg), std::ref(*(punc_cache_p.get())),
                      std::ref(*thread_lock_p), std::move(true),
                      std::ref(msg_data->tpass_online_handle),
                      std::ref(msg_data->online_res),
                      std::ref(msg_data->tpass_res)));
                      std::ref(msg_data->tpass_online_handle)));
      }
      break;
    }
@@ -338,9 +332,7 @@
                                  std::ref(msg_data->msg),
                                  std::ref(*(punc_cache_p.get())),
                                  std::ref(*thread_lock_p), std::move(false),
                                  std::ref(msg_data->tpass_online_handle),
                                  std::ref(msg_data->online_res),
                                  std::ref(msg_data->tpass_res)));
                                  std::ref(msg_data->tpass_online_handle)));
        }
      } else {
        sample_data_p->insert(sample_data_p->end(), pcm_data,