雾聪
2023-08-09 d4a021b45b18fc24d9893ef8a3cdf3adba380490
funasr/runtime/websocket/websocket-server-2pass.cpp
@@ -53,7 +53,7 @@
  return ctx;
}
nlohmann::json handle_result(FUNASR_RESULT result, nlohmann::json msg) {
nlohmann::json handle_result(FUNASR_RESULT result) {
    websocketpp::lib::error_code ec;
    nlohmann::json jsonresult;
@@ -72,10 +72,6 @@
      jsonresult["mode"] = "2pass-offline";    
    }
    if (msg.contains("wav_name")) {
      jsonresult["wav_name"] = msg["wav_name"];
    }
    return jsonresult;
}
// feed buffer to asr engine for decoder
@@ -83,7 +79,7 @@
    std::vector<char>& buffer, websocketpp::connection_hdl& hdl,
    nlohmann::json& msg, std::vector<std::vector<std::string>>& punc_cache,
    websocketpp::lib::mutex& thread_lock, bool& is_final,
    FUNASR_HANDLE& tpass_online_handle) {
    std::string wav_name, FUNASR_HANDLE& tpass_online_handle) {
 
  // lock for each connection
  scoped_lock guard(thread_lock);
@@ -123,7 +119,8 @@
      if (Result) {
        websocketpp::lib::error_code ec;
        nlohmann::json jsonresult =
            handle_result(Result, msg);
            handle_result(Result);
        jsonresult["wav_name"] = wav_name;
        jsonresult["is_final"] = false;
        if(jsonresult["text"] != "") {
          if (is_ssl) {
@@ -154,7 +151,8 @@
      if (Result) {
        websocketpp::lib::error_code ec;
        nlohmann::json jsonresult =
            handle_result(Result, msg);
            handle_result(Result);
        jsonresult["wav_name"] = wav_name;
        jsonresult["is_final"] = true;
        if (is_ssl) {
          wss_server_->send(hdl, jsonresult.dump(),
@@ -285,6 +283,7 @@
      if (jsonresult.contains("chunk_size")){
        if(msg_data->tpass_online_handle == NULL){
          std::vector<int> chunk_size_vec = jsonresult["chunk_size"].get<std::vector<int>>();
          LOG(INFO) << "----------------FunTpassOnlineInit----------------------";
          FUNASR_HANDLE tpass_online_handle =
              FunTpassOnlineInit(tpass_handle, chunk_size_vec);
          msg_data->tpass_online_handle = tpass_online_handle;
@@ -303,6 +302,7 @@
                      std::move(*(sample_data_p.get())), std::move(hdl),
                      std::ref(msg_data->msg), std::ref(*(punc_cache_p.get())),
                      std::ref(*thread_lock_p), std::move(true),
                      msg_data->msg["wav_name"],
                      std::ref(msg_data->tpass_online_handle)));
      }
      break;
@@ -333,6 +333,7 @@
                                  std::ref(msg_data->msg),
                                  std::ref(*(punc_cache_p.get())),
                                  std::ref(*thread_lock_p), std::move(false),
                                  msg_data->msg["wav_name"],
                                  std::ref(msg_data->tpass_online_handle)));
        }
      } else {