Yabin Li
2023-08-21 e0fa63765bfb4a36bde7047c2a6066ca5a80e90f
funasr/runtime/websocket/websocket-server.cpp
@@ -56,25 +56,37 @@
// feed buffer to asr engine for decoder
void WebSocketServer::do_decoder(const std::vector<char>& buffer,
                                 websocketpp::connection_hdl& hdl,
                                 const nlohmann::json& msg) {
                                 websocketpp::lib::mutex& thread_lock,
                                 std::vector<std::vector<float>> &hotwords_embedding,
                                 std::string wav_name,
                                 std::string wav_format) {
  scoped_lock guard(thread_lock);
  try {
    int num_samples = buffer.size();  // the size of the buf
    if (!buffer.empty()) {
      // feed data to asr engine
      FUNASR_RESULT Result = FunOfflineInferBuffer(
          asr_hanlde, buffer.data(), buffer.size(), RASR_NONE, NULL, 16000, msg["wav_format"]);
    if (!buffer.empty() && hotwords_embedding.size() >0 ) {
      std::string asr_result;
      std::string stamp_res;
      try{
        FUNASR_RESULT Result = FunOfflineInferBuffer(
            asr_hanlde, buffer.data(), buffer.size(), RASR_NONE, NULL, hotwords_embedding, 16000, wav_format);
      std::string asr_result =
          ((FUNASR_RECOG_RESULT*)Result)->msg;  // get decode result
      FunASRFreeResult(Result);
        asr_result = ((FUNASR_RECOG_RESULT*)Result)->msg;  // get decode result
        stamp_res = ((FUNASR_RECOG_RESULT*)Result)->stamp;
        FunASRFreeResult(Result);
      }catch (std::exception const& e) {
        LOG(ERROR) << e.what();
        return;
      }
      websocketpp::lib::error_code ec;
      nlohmann::json jsonresult;        // result json
      jsonresult["text"] = asr_result;  // put result in 'text'
      jsonresult["mode"] = "offline";
      jsonresult["wav_name"] = msg["wav_name"];
      if(stamp_res != ""){
        jsonresult["timestamp"] = stamp_res;
      }
      jsonresult["wav_name"] = wav_name;
      // send the json to client
      if (is_ssl) {
@@ -86,11 +98,6 @@
      }
      LOG(INFO) << "buffer.size=" << buffer.size() << ",result json=" << jsonresult.dump();
      if (!isonline) {
        //  close the client if it is not online asr
        // server_->close(hdl, websocketpp::close::status::normal, "DONE", ec);
        // fout.close();
      }
    }
  } catch (std::exception const& e) {
@@ -100,12 +107,11 @@
void WebSocketServer::on_open(websocketpp::connection_hdl hdl) {
  scoped_lock guard(m_lock);     // for threads safty
  check_and_clean_connection();  // remove closed connection
  std::shared_ptr<FUNASR_MESSAGE> data_msg =
      std::make_shared<FUNASR_MESSAGE>();  // put a new data vector for new
                                           // connection
  data_msg->samples = std::make_shared<std::vector<char>>();
  data_msg->thread_lock = std::make_shared<websocketpp::lib::mutex>();
  data_msg->msg = nlohmann::json::parse("{}");
  data_msg->msg["wav_format"] = "pcm";
  data_map.emplace(hdl, data_msg);
@@ -114,37 +120,88 @@
void WebSocketServer::on_close(websocketpp::connection_hdl hdl) {
  scoped_lock guard(m_lock);
  data_map.erase(hdl);  // remove data vector when  connection is closed
  std::shared_ptr<FUNASR_MESSAGE> data_msg = nullptr;
  auto it_data = data_map.find(hdl);
  if (it_data != data_map.end()) {
    data_msg = it_data->second;
  } else {
    return;
  }
  unique_lock guard_decoder(*(data_msg->thread_lock));
  data_msg->msg["is_eof"]=true;
  guard_decoder.unlock();
  // data_map.erase(hdl);  // remove data vector when  connection is closed
  LOG(INFO) << "on_close, active connections: " << data_map.size();
}
// remove closed connection
void WebSocketServer::check_and_clean_connection() {
  std::vector<websocketpp::connection_hdl> to_remove;  // remove list
  auto iter = data_map.begin();
  while (iter != data_map.end()) {  // loop to find closed connection
    websocketpp::connection_hdl hdl = iter->first;
    if (is_ssl) {
      wss_server::connection_ptr con = wss_server_->get_con_from_hdl(hdl);
      if (con->get_state() != 1) {  // session::state::open ==1
        to_remove.push_back(hdl);
      }
    } else {
      server::connection_ptr con = server_->get_con_from_hdl(hdl);
      if (con->get_state() != 1) {  // session::state::open ==1
        to_remove.push_back(hdl);
      }
    }
    iter++;
void remove_hdl(
    websocketpp::connection_hdl hdl,
    std::map<websocketpp::connection_hdl, std::shared_ptr<FUNASR_MESSAGE>,
             std::owner_less<websocketpp::connection_hdl>>& data_map) {
  std::shared_ptr<FUNASR_MESSAGE> data_msg = nullptr;
  auto it_data = data_map.find(hdl);
  if (it_data != data_map.end()) {
    data_msg = it_data->second;
  } else {
    return;
  }
  for (auto hdl : to_remove) {
    data_map.erase(hdl);
    LOG(INFO)<< "remove one connection ";
  unique_lock guard_decoder(*(data_msg->thread_lock));
  if (data_msg->msg["is_eof"]==true) {
     data_map.erase(hdl);
    LOG(INFO) << "remove one connection";
  }
  guard_decoder.unlock();
}
void WebSocketServer::check_and_clean_connection() {
  while(true){
    std::this_thread::sleep_for(std::chrono::milliseconds(5000));
    std::vector<websocketpp::connection_hdl> to_remove;  // remove list
    auto iter = data_map.begin();
    while (iter != data_map.end()) {  // loop to find closed connection
      websocketpp::connection_hdl hdl = iter->first;
      try{
        if (is_ssl) {
          wss_server::connection_ptr con = wss_server_->get_con_from_hdl(hdl);
          if (con->get_state() != 1) {  // session::state::open ==1
            to_remove.push_back(hdl);
          }
        } else {
          server::connection_ptr con = server_->get_con_from_hdl(hdl);
          if (con->get_state() != 1) {  // session::state::open ==1
            to_remove.push_back(hdl);
          }
        }
      }
      catch (std::exception const &e)
      {
        // if connection is close, we set is_eof = true
        std::shared_ptr<FUNASR_MESSAGE> data_msg = nullptr;
        auto it_data = data_map.find(hdl);
        if (it_data != data_map.end()) {
          data_msg = it_data->second;
        } else {
            continue;
        }
        unique_lock guard_decoder(*(data_msg->thread_lock));
        data_msg->msg["is_eof"]=true;
        guard_decoder.unlock();
        to_remove.push_back(hdl);
        LOG(INFO)<<"connection is closed: "<<e.what();
      }
      iter++;
    }
    for (auto hdl : to_remove) {
      remove_hdl(hdl, data_map);
      //LOG(INFO) << "remove one connection ";
    }
  }
}
void WebSocketServer::on_message(websocketpp::connection_hdl hdl,
                                 message_ptr msg) {
  unique_lock lock(m_lock);
@@ -157,6 +214,7 @@
    msg_data = it_data->second;
  }
  std::shared_ptr<std::vector<char>> sample_data_p = msg_data->samples;
  std::shared_ptr<websocketpp::lib::mutex> thread_lock_p = msg_data->thread_lock;
  lock.unlock();
  if (sample_data_p == nullptr) {
@@ -165,7 +223,7 @@
  }
  const std::string& payload = msg->get_payload();  // get msg type
  unique_lock guard_decoder(*(thread_lock_p)); // mutex for one connection
  switch (msg->get_opcode()) {
    case websocketpp::frame::opcode::text: {
      nlohmann::json jsonresult = nlohmann::json::parse(payload);
@@ -175,24 +233,42 @@
      if (jsonresult["wav_format"] != nullptr) {
        msg_data->msg["wav_format"] = jsonresult["wav_format"];
      }
      if(msg_data->hotwords_embedding == NULL){
        if (jsonresult["hotwords"] != nullptr) {
          msg_data->msg["hotwords"] = jsonresult["hotwords"];
          if (!msg_data->msg["hotwords"].empty()) {
            std::string hw = msg_data->msg["hotwords"];
            LOG(INFO)<<"hotwords: " << hw;
            std::vector<std::vector<float>> new_hotwords_embedding= CompileHotwordEmbedding(asr_hanlde, hw);
            msg_data->hotwords_embedding =
                std::make_shared<std::vector<std::vector<float>>>(new_hotwords_embedding);
          }
        }else{
            std::string hw = "";
            LOG(INFO)<<"hotwords: " << hw;
            std::vector<std::vector<float>> new_hotwords_embedding= CompileHotwordEmbedding(asr_hanlde, hw);
            msg_data->hotwords_embedding =
                std::make_shared<std::vector<std::vector<float>>>(new_hotwords_embedding);
        }
      }
      if (jsonresult["is_speaking"] == false ||
          jsonresult["is_finished"] == true) {
        LOG(INFO) << "client done";
        if (isonline) {
          // do_close(ws);
        } else {
          // add padding to the end of the wav data
          // std::vector<short> padding(static_cast<short>(0.3 * 16000));
          // sample_data_p->insert(sample_data_p->end(), padding.data(),
          //                       padding.data() + padding.size());
          // for offline, send all receive data to decoder engine
          asio::post(io_decoder_,
                     std::bind(&WebSocketServer::do_decoder, this,
                               std::move(*(sample_data_p.get())),
                               std::move(hdl), std::move(msg_data->msg)));
        }
        // add padding to the end of the wav data
        // std::vector<short> padding(static_cast<short>(0.3 * 16000));
        // sample_data_p->insert(sample_data_p->end(), padding.data(),
        //                       padding.data() + padding.size());
        // for offline, send all receive data to decoder engine
        std::vector<std::vector<float>> hotwords_embedding_(*(msg_data->hotwords_embedding));
        asio::post(io_decoder_,
                    std::bind(&WebSocketServer::do_decoder, this,
                              std::move(*(sample_data_p.get())),
                              std::move(hdl),
                              std::ref(*thread_lock_p),
                              std::move(hotwords_embedding_),
                              msg_data->msg["wav_name"],
                              msg_data->msg["wav_format"]));
      }
      break;
    }
@@ -200,19 +276,15 @@
      // received binary data
      const auto* pcm_data = static_cast<const char*>(payload.data());
      int32_t num_samples = payload.size();
      //LOG(INFO) << "recv binary num_samples " << num_samples;
      if (isonline) {
        // if online TODO(zhaoming) still not done
        std::vector<char> s(pcm_data, pcm_data + num_samples);
        asio::post(io_decoder_,
                   std::bind(&WebSocketServer::do_decoder, this, std::move(s),
                             std::move(hdl), std::move(msg_data->msg)));
        // TODO
      } else {
        // for offline, we add receive data to end of the sample data vector
        sample_data_p->insert(sample_data_p->end(), pcm_data,
                              pcm_data + num_samples);
      }
      break;
    }
    default:
@@ -228,6 +300,11 @@
    asr_hanlde = FunOfflineInit(model_path, thread_num);
    LOG(INFO) << "model successfully inited";
    LOG(INFO) << "initAsr run check_and_clean_connection";
    std::thread clean_thread(&WebSocketServer::check_and_clean_connection,this);
    clean_thread.detach();
    LOG(INFO) << "initAsr run check_and_clean_connection finished";
  } catch (const std::exception& e) {
    LOG(INFO) << e.what();