游雁
2023-10-23 26a2a232a94c4a729733d83e8175a16e3f8db481
funasr/runtime/websocket/bin/websocket-server-2pass.cpp
@@ -15,7 +15,9 @@
#include <thread>
#include <utility>
#include <vector>
#include <chrono>
extern std::string hotwords;
context_ptr WebSocketServer::on_tls_init(tls_mode mode,
                                         websocketpp::connection_hdl hdl,
                                         std::string& s_certfile,
@@ -354,7 +356,14 @@
  unique_lock guard_decoder(*(thread_lock_p)); // mutex for one connection
  switch (msg->get_opcode()) {
    case websocketpp::frame::opcode::text: {
      nlohmann::json jsonresult = nlohmann::json::parse(payload);
      nlohmann::json jsonresult;
      try{
        jsonresult = nlohmann::json::parse(payload);
      }catch (std::exception const &e)
      {
        LOG(ERROR)<<e.what();
        break;
      }
      if (jsonresult.contains("wav_name")) {
        msg_data->msg["wav_name"] = jsonresult["wav_name"];
@@ -370,17 +379,26 @@
          msg_data->msg["hotwords"] = jsonresult["hotwords"];
          if (!msg_data->msg["hotwords"].empty()) {
            std::string hw = msg_data->msg["hotwords"];
            LOG(INFO)<<"hotwords: " << hw;
            std::vector<std::vector<float>> new_hotwords_embedding= CompileHotwordEmbedding(tpass_handle, hw, ASR_TWO_PASS);
            hw = hw + " " + hotwords;
            LOG(INFO) << "hotwords: " << hw;
            std::vector<std::vector<float>> new_hotwords_embedding = CompileHotwordEmbedding(tpass_handle, hw, ASR_TWO_PASS);
            msg_data->hotwords_embedding =
                std::make_shared<std::vector<std::vector<float>>>(new_hotwords_embedding);
          }
        }else{
        } else {
          if (hotwords.empty()) {
            std::string hw = "";
            LOG(INFO)<<"hotwords: " << hw;
            std::vector<std::vector<float>> new_hotwords_embedding= CompileHotwordEmbedding(tpass_handle, hw, ASR_TWO_PASS);
            msg_data->hotwords_embedding =
                std::make_shared<std::vector<std::vector<float>>>(new_hotwords_embedding);
          }else {
            std::string hw = hotwords;
            LOG(INFO) << "hotwords: " << hw;
            std::vector<std::vector<float>> new_hotwords_embedding= CompileHotwordEmbedding(tpass_handle, hw, ASR_TWO_PASS);
            msg_data->hotwords_embedding =
                std::make_shared<std::vector<std::vector<float>>>(new_hotwords_embedding);
          }
        }
      }
      if (jsonresult.contains("audio_fs")) {