kongdeqiang
7 天以前 28ccfbfc51068a663a80764e14074df5edf2b5ba
runtime/websocket/bin/websocket-server-2pass.cpp
@@ -111,7 +111,9 @@
    int audio_fs,
    std::string wav_format,
    FUNASR_HANDLE& tpass_online_handle,
    FUNASR_DEC_HANDLE& decoder_handle) {
    FUNASR_DEC_HANDLE& decoder_handle,
    std::string svs_lang,
    bool sys_itn) {
  // lock for each connection
  if(!tpass_online_handle){
    scoped_lock guard(thread_lock);
@@ -140,7 +142,8 @@
                                       subvector.data(), subvector.size(),
                                       punc_cache, false, audio_fs,
                                       wav_format, (ASR_TYPE)asr_mode_,
                                       hotwords_embedding, itn, decoder_handle);
                                       hotwords_embedding, itn, decoder_handle,
                                       svs_lang, sys_itn);
        } else {
          scoped_lock guard(thread_lock);
@@ -177,7 +180,8 @@
                                       buffer.data(), buffer.size(), punc_cache,
                                       is_final, audio_fs,
                                       wav_format, (ASR_TYPE)asr_mode_,
                                       hotwords_embedding, itn, decoder_handle);
                                       hotwords_embedding, itn, decoder_handle,
                                       svs_lang, sys_itn);
        } else {
          scoped_lock guard(thread_lock);
          msg["access_num"]=(int)msg["access_num"]-1;    
@@ -211,7 +215,7 @@
        if(wav_format != "pcm" && wav_format != "PCM"){
          websocketpp::lib::error_code ec;
          nlohmann::json jsonresult;
          jsonresult["text"] = "ERROR. Real-time transcription service ONLY SUPPORT wav_format pcm.";
          jsonresult["text"] = "ERROR. Real-time transcription service ONLY SUPPORT PCM stream.";
          jsonresult["wav_name"] = wav_name;
          jsonresult["is_final"] = true;
          if (is_ssl) {
@@ -250,6 +254,8 @@
    data_msg->msg["audio_fs"] = 16000; // default is 16k
    data_msg->msg["access_num"] = 0; // the number of accesses to this object; when it is 0, we can free it safely
    data_msg->msg["is_eof"]=false; // if this connection is closed
    data_msg->msg["svs_lang"]="auto";
    data_msg->msg["svs_itn"]=true;
    FUNASR_DEC_HANDLE decoder_handle =
      FunASRWfstDecoderInit(tpass_handle, ASR_TWO_PASS, global_beam_, lattice_beam_, am_scale_);
    data_msg->decoder_handle = decoder_handle;
@@ -409,7 +415,7 @@
      }
      // hotwords: fst/nn
      if(msg_data->hotwords_embedding == NULL){
      if(msg_data->hotwords_embedding == nullptr){
        std::unordered_map<std::string, int> merged_hws_map;
        std::string nn_hotwords = "";
@@ -458,7 +464,7 @@
        msg_data->msg["audio_fs"] = jsonresult["audio_fs"];
      }
      if (jsonresult.contains("chunk_size")) {
        if (msg_data->tpass_online_handle == NULL) {
        if (msg_data->tpass_online_handle == nullptr) {
          std::vector<int> chunk_size_vec =
              jsonresult["chunk_size"].get<std::vector<int>>();
          // check chunk_size_vec
@@ -475,12 +481,18 @@
      if (jsonresult.contains("itn")) {
        msg_data->msg["itn"] = jsonresult["itn"];
      }
      if (jsonresult.contains("svs_lang")) {
        msg_data->msg["svs_lang"] = jsonresult["svs_lang"];
      }
      if (jsonresult.contains("svs_itn")) {
        msg_data->msg["svs_itn"] = jsonresult["svs_itn"];
      }
      LOG(INFO) << "jsonresult=" << jsonresult
                << ", msg_data->msg=" << msg_data->msg;
      if ((jsonresult["is_speaking"] == false ||
          jsonresult["is_finished"] == true) && 
          msg_data->msg["is_eof"] != true &&
          msg_data->hotwords_embedding != NULL) {
          msg_data->hotwords_embedding != nullptr) {
        LOG(INFO) << "client done";
        // if it is in final message, post the sample_data to decode
@@ -499,7 +511,9 @@
                        msg_data->msg["audio_fs"],
                        msg_data->msg["wav_format"],
                        std::ref(msg_data->tpass_online_handle),
                        std::ref(msg_data->decoder_handle)));
                        std::ref(msg_data->decoder_handle),
                        msg_data->msg["svs_lang"],
                        msg_data->msg["svs_itn"]));
            msg_data->msg["access_num"]=(int)(msg_data->msg["access_num"])+1;
        }
        catch (std::exception const &e)
@@ -532,7 +546,7 @@
          try{
            // post to decode
            if (msg_data->msg["is_eof"] != true && msg_data->hotwords_embedding != NULL) {
            if (msg_data->msg["is_eof"] != true && msg_data->hotwords_embedding != nullptr) {
              std::vector<std::vector<float>> hotwords_embedding_(*(msg_data->hotwords_embedding));
              msg_data->strand_->post(
                        std::bind(&WebSocketServer::do_decoder, this,
@@ -547,7 +561,9 @@
                                  msg_data->msg["audio_fs"],
                                  msg_data->msg["wav_format"],
                                  std::ref(msg_data->tpass_online_handle),
                                  std::ref(msg_data->decoder_handle)));
                                  std::ref(msg_data->decoder_handle),
                                  msg_data->msg["svs_lang"],
                                  msg_data->msg["svs_itn"]));
              msg_data->msg["access_num"]=(int)(msg_data->msg["access_num"])+1;
            }
          }