zhaomingwork
2023-08-10 30f0c7ff2941ab08edb8cb257eb6cef74be42ec7
funasr/runtime/websocket/funasr-wss-client-2pass.cpp
@@ -12,18 +12,20 @@
//                     [--is-ssl <int>]  [--]
//                     [--version] [-h]
// example:
// ./funasr-wss-client --server-ip 127.0.0.1 --port 10095 --wav-path test.wav --thread-num 1 --is-ssl 1
// ./funasr-wss-client --server-ip 127.0.0.1 --port 10095 --wav-path test.wav
// --thread-num 1 --is-ssl 1
#define ASIO_STANDALONE 1
#include <glog/logging.h>
#include <atomic>
#include <fstream>
#include <iostream>
#include <sstream>
#include <thread>
#include <websocketpp/client.hpp>
#include <websocketpp/common/thread.hpp>
#include <websocketpp/config/asio_client.hpp>
#include <iostream>
#include <fstream>
#include <sstream>
#include <atomic>
#include <thread>
#include <glog/logging.h>
#include "audio.h"
#include "nlohmann/json.hpp"
@@ -51,7 +53,8 @@
}
typedef websocketpp::config::asio_client::message_type::ptr message_ptr;
typedef websocketpp::lib::shared_ptr<websocketpp::lib::asio::ssl::context> context_ptr;
typedef websocketpp::lib::shared_ptr<websocketpp::lib::asio::ssl::context>
    context_ptr;
using websocketpp::lib::bind;
using websocketpp::lib::placeholders::_1;
using websocketpp::lib::placeholders::_2;
@@ -109,20 +112,26 @@
        switch (msg->get_opcode()) {
            case websocketpp::frame::opcode::text:
                nlohmann::json jsonresult = nlohmann::json::parse(payload);
                LOG(INFO)<< "Thread: " << this_thread::get_id() <<",on_message = " << payload;
        LOG(INFO) << "Thread: " << this_thread::get_id()
                  << ",on_message = " << payload << "jsonresult" << jsonresult;
            
                // if (jsonresult["is_final"] == true){
            //    websocketpp::lib::error_code ec;
            //    m_client.close(m_hdl, websocketpp::close::status::going_away, "", ec);
            //    if (ec){
                //         LOG(ERROR)<< "Error closing connection " << ec.message();
            //    }
                // }
        if (jsonresult["is_final"] == true) {
          websocketpp::lib::error_code ec;
          m_client.close(hdl, websocketpp::close::status::going_away, "", ec);
          if (ec) {
            LOG(ERROR) << "Error closing connection " << ec.message();
          }
        }
        }
    }
    // This method will block until the connection is complete  
    void run(const std::string& uri, const std::vector<string>& wav_list, const std::vector<string>& wav_ids, std::string asr_mode, std::vector<int> chunk_size) {
  void run(const std::string& uri, const std::vector<string>& wav_list,
           const std::vector<string>& wav_ids, std::string asr_mode,
           std::vector<int> chunk_size) {
        // Create a new connection to the given URI
        websocketpp::lib::error_code ec;
        typename websocketpp::client<T>::connection_ptr con =
@@ -143,17 +152,12 @@
        // Create a thread to run the ASIO io_service event loop
        websocketpp::lib::thread asio_thread(&websocketpp::client<T>::run,
                                            &m_client);
        while(true){
            int i = wav_index.fetch_add(1);
            if (i >= wav_list.size()) {
                break;
            }
            send_wav_data(wav_list[i], wav_ids[i], asr_mode, chunk_size);
        }
    send_wav_data(wav_list[0], wav_ids[0], asr_mode, chunk_size);
        WaitABit(); 
        asio_thread.join();
    }
    // The open handler will signal that we are ready to start sending data
@@ -183,7 +187,8 @@
        m_done = true;
    }
    // send wav to server
    void send_wav_data(string wav_path, string wav_id, std::string asr_mode, std::vector<int> chunk_vector) {
  void send_wav_data(string wav_path, string wav_id, std::string asr_mode,
                     std::vector<int> chunk_vector) {
        uint64_t count = 0;
        std::stringstream val;
@@ -192,15 +197,12 @@
        std::string wav_format = "pcm";
      if(IsTargetFile(wav_path.c_str(), "wav")){
         int32_t sampling_rate = -1;
         if(!audio.LoadWav(wav_path.c_str(), &sampling_rate))
            return ;
      if (!audio.LoadWav(wav_path.c_str(), &sampling_rate)) return;
      }else if(IsTargetFile(wav_path.c_str(), "pcm")){
         if (!audio.LoadPcmwav(wav_path.c_str(), &sampling_rate))
            return ;
      if (!audio.LoadPcmwav(wav_path.c_str(), &sampling_rate)) return;
      }else{
         wav_format = "others";
            if (!audio.LoadOthers2Char(wav_path.c_str()))
            return ;
      if (!audio.LoadOthers2Char(wav_path.c_str())) return;
      }
        float* buff;
@@ -221,6 +223,7 @@
                  break;
                }
            }
            if (wait) {
                // LOG(INFO) << "wait.." << m_open;
                WaitABit();
@@ -326,24 +329,30 @@
};
int main(int argc, char* argv[]) {
    google::InitGoogleLogging(argv[0]);
    FLAGS_logtostderr = true;
    TCLAP::CmdLine cmd("funasr-wss-client", ' ', "1.0");
    TCLAP::ValueArg<std::string> server_ip_("", "server-ip", "server-ip", true,
                                           "127.0.0.1", "string");
    TCLAP::ValueArg<std::string> port_("", "port", "port", true, "10095", "string");
    TCLAP::ValueArg<std::string> wav_path_("", "wav-path",
        "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)",
  TCLAP::ValueArg<std::string> port_("", "port", "port", true, "10095",
                                     "string");
  TCLAP::ValueArg<std::string> wav_path_(
      "", "wav-path",
      "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: "
      "asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)",
        true, "", "string");
    TCLAP::ValueArg<std::string>    asr_mode_("", ASR_MODE, "offline, online, 2pass", false, "2pass", "string");
    TCLAP::ValueArg<std::string>    chunk_size_("", "chunk-size", "chunk_size: 5-10-5 or 5-12-5", false, "5-10-5", "string");
    TCLAP::ValueArg<int> thread_num_("", "thread-num", "thread-num",
                                       false, 1, "int");
  TCLAP::ValueArg<std::string> asr_mode_("", ASR_MODE, "offline, online, 2pass",
                                         false, "2pass", "string");
  TCLAP::ValueArg<std::string> chunk_size_("", "chunk-size",
                                           "chunk_size: 5-10-5 or 5-12-5",
                                           false, "5-10-5", "string");
  TCLAP::ValueArg<int> thread_num_("", "thread-num", "thread-num", false, 1,
                                   "int");
    TCLAP::ValueArg<int> is_ssl_(
        "", "is-ssl", "is-ssl is 1 means use wss connection, or use ws connection",
        false, 1, "int");
      "", "is-ssl",
      "is-ssl is 1 means use wss connection, or use ws connection", false, 1,
      "int");
    cmd.add(server_ip_);
    cmd.add(port_);
@@ -375,7 +384,6 @@
    int threads_num = thread_num_.getValue();
    int is_ssl = is_ssl_.getValue();
    std::vector<websocketpp::lib::thread> client_threads;
    std::string uri = "";
    if (is_ssl == 1) {
        uri = "wss://" + server_ip + ":" + port;
@@ -394,8 +402,7 @@
            return 0;
        }
        string line;
        while(getline(in, line))
        {
    while (getline(in, line)) {
            istringstream iss(line);
            string column1, column2;
            iss >> column1 >> column2;
@@ -408,18 +415,30 @@
        wav_ids.emplace_back(default_id);
    }
    
  for (size_t wav_i = 0; wav_i < wav_list.size(); wav_i = wav_i + threads_num) {
    std::vector<websocketpp::lib::thread> client_threads;
    for (size_t i = 0; i < threads_num; i++) {
        client_threads.emplace_back([uri, wav_list, wav_ids, asr_mode, chunk_size, is_ssl]() {
      if (wav_i + i >= wav_list.size()) {
        break;
      }
      std::vector<string> tmp_wav_list;
      std::vector<string> tmp_wav_ids;
      tmp_wav_list.emplace_back(wav_list[wav_i + i]);
      tmp_wav_ids.emplace_back(wav_ids[wav_i + i]);
      client_threads.emplace_back(
          [uri, tmp_wav_list, tmp_wav_ids, asr_mode, chunk_size, is_ssl]() {
          if (is_ssl == 1) {
            WebsocketClient<websocketpp::config::asio_tls_client> c(is_ssl);
            c.m_client.set_tls_init_handler(bind(&OnTlsInit, ::_1));
            c.run(uri, wav_list, wav_ids, asr_mode, chunk_size);
              c.run(uri, tmp_wav_list, tmp_wav_ids, asr_mode, chunk_size);
          } else {
            WebsocketClient<websocketpp::config::asio_client> c(is_ssl);
            c.run(uri, wav_list, wav_ids, asr_mode, chunk_size);
              c.run(uri, tmp_wav_list, tmp_wav_ids, asr_mode, chunk_size);
          }
        });
    }
@@ -428,3 +447,4 @@
        t.join();
    }
}
}