雾聪
2023-08-10 6a315917163d68a5e48a40809a420f44b2ec5a15
Merge branch 'main' of https://github.com/alibaba-damo-academy/FunASR into main
3个文件已修改
16 ■■■■■ 已修改文件
funasr/models/e2e_asr_contextual_paraformer.py 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/websocket/websocket-server-2pass.cpp 13 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/websocket/websocket-server-2pass.h 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/e2e_asr_contextual_paraformer.py
@@ -350,6 +350,7 @@
            else:
                hw_embed = self.bias_embed(hw_list_pad)
            hw_embed, (h_n, _) = self.bias_encoder(hw_embed)
            hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
        else:
            hw_lengths = [len(i) for i in hw_list]
            hw_list_pad = pad_list([torch.Tensor(i).long() for i in hw_list], 0).to(encoder_out.device)
funasr/runtime/websocket/websocket-server-2pass.cpp
@@ -81,6 +81,10 @@
    FUNASR_HANDLE& tpass_online_handle) {
  // lock for each connection
  scoped_lock guard(thread_lock);
  if(!tpass_online_handle){
      LOG(INFO) << "tpass_online_handle  is free, return";
      return;
  }
  FUNASR_RESULT Result = nullptr;
  int asr_mode_ = 2;
  if (msg.contains("mode")) {
@@ -180,7 +184,7 @@
      std::make_shared<FUNASR_MESSAGE>();  // put a new data vector for new
                                           // connection
  data_msg->samples = std::make_shared<std::vector<char>>();
  data_msg->thread_lock = new websocketpp::lib::mutex();
  data_msg->thread_lock = std::make_shared<websocketpp::lib::mutex>();
  data_msg->msg = nlohmann::json::parse("{}");
  data_msg->msg["wav_format"] = "pcm";
@@ -199,7 +203,7 @@
    websocketpp::connection_hdl hdl,
    std::map<websocketpp::connection_hdl, std::shared_ptr<FUNASR_MESSAGE>,
             std::owner_less<websocketpp::connection_hdl>>& data_map) {
  // return;
  std::shared_ptr<FUNASR_MESSAGE> data_msg = nullptr;
  auto it_data = data_map.find(hdl);
  if (it_data != data_map.end()) {
@@ -215,8 +219,9 @@
    FunTpassOnlineUninit(data_msg->tpass_online_handle);
    data_msg->tpass_online_handle = nullptr;
  }
  guard_decoder.unlock();
  delete data_msg->thread_lock;
  data_map.erase(hdl);  // remove data vector when  connection is closed
}
@@ -270,7 +275,7 @@
  std::shared_ptr<std::vector<char>> sample_data_p = msg_data->samples;
  std::shared_ptr<std::vector<std::vector<std::string>>> punc_cache_p =
      msg_data->punc_cache;
  websocketpp::lib::mutex* thread_lock_p = msg_data->thread_lock;
  std::shared_ptr<websocketpp::lib::mutex> thread_lock_p = msg_data->thread_lock;
  lock.unlock();
funasr/runtime/websocket/websocket-server-2pass.h
@@ -54,7 +54,7 @@
  nlohmann::json msg;
  std::shared_ptr<std::vector<char>> samples;
  std::shared_ptr<std::vector<std::vector<std::string>>> punc_cache;
  websocketpp::lib::mutex* thread_lock; // lock for each connection
  std::shared_ptr<websocketpp::lib::mutex> thread_lock; // lock for each connection
  FUNASR_HANDLE tpass_online_handle=NULL;
  std::string online_res = "";
  std::string tpass_res = "";