游雁
2023-05-25 b18f7d121f2f17df8bf2d0c2bbb223bc5ddbcc0f
funasr/runtime/websocket/websocketclient.cpp
@@ -10,9 +10,10 @@
#define ASIO_STANDALONE 1
#include <websocketpp/client.hpp>
#include <websocketpp/common/thread.hpp>
#include <websocketpp/config/asio_no_tls_client.hpp>
#include <websocketpp/config/asio_client.hpp>
#include "audio.h"
#include "nlohmann/json.hpp"
/**
 * Define a semi-cross platform helper method that waits/sleeps for a bit.
@@ -25,14 +26,37 @@
#endif
}
typedef websocketpp::config::asio_client::message_type::ptr message_ptr;
typedef websocketpp::lib::shared_ptr<websocketpp::lib::asio::ssl::context>
    context_ptr;
using websocketpp::lib::bind;
using websocketpp::lib::placeholders::_1;
using websocketpp::lib::placeholders::_2;
context_ptr on_tls_init(websocketpp::connection_hdl) {
  context_ptr ctx = websocketpp::lib::make_shared<asio::ssl::context>(
      asio::ssl::context::sslv23);
  try {
    ctx->set_options(
        asio::ssl::context::default_workarounds | asio::ssl::context::no_sslv2 |
        asio::ssl::context::no_sslv3 | asio::ssl::context::single_dh_use);
  } catch (std::exception& e) {
    std::cout << e.what() << std::endl;
  }
  return ctx;
}
// template for tls or not config
template <typename T>
class websocket_client {
 public:
  typedef websocketpp::client<websocketpp::config::asio_client> client;
  // typedef websocketpp::client<T> client;
  // typedef websocketpp::client<websocketpp::config::asio_tls_client>
  // wss_client;
  typedef websocketpp::lib::lock_guard<websocketpp::lib::mutex> scoped_lock;
  websocket_client() : m_open(false), m_done(false) {
  websocket_client(int is_ssl) : m_open(false), m_done(false) {
    // set up access channels to only log interesting things
    m_client.clear_access_channels(websocketpp::log::alevel::all);
    m_client.set_access_channels(websocketpp::log::alevel::connect);
    m_client.set_access_channels(websocketpp::log::alevel::disconnect);
@@ -64,10 +88,12 @@
    }
  }
  // This method will block until the connection is complete
  void run(const std::string& uri, const std::string& wav_path) {
    // Create a new connection to the given URI
    websocketpp::lib::error_code ec;
    client::connection_ptr con = m_client.get_connection(uri, ec);
    typename websocketpp::client<T>::connection_ptr con =
        m_client.get_connection(uri, ec);
    if (ec) {
      m_client.get_alog().write(websocketpp::log::alevel::app,
                                "Get Connection Error: " + ec.message());
@@ -83,7 +109,8 @@
    m_client.connect(con);
    // Create a thread to run the ASIO io_service event loop
    websocketpp::lib::thread asio_thread(&client::run, &m_client);
    websocketpp::lib::thread asio_thread(&websocketpp::client<T>::run,
                                         &m_client);
    send_wav_data();
    asio_thread.join();
@@ -120,7 +147,7 @@
    uint64_t count = 0;
    std::stringstream val;
    Audio audio(1);
    funasr::Audio audio(1);
    int32_t sampling_rate = 16000;
    if (!audio.LoadPcmwav(wav_path.c_str(), &sampling_rate)) {
@@ -156,6 +183,19 @@
      }
    }
    websocketpp::lib::error_code ec;
    nlohmann::json jsonbegin;
    nlohmann::json chunk_size = nlohmann::json::array();
    chunk_size.push_back(5);
    chunk_size.push_back(0);
    chunk_size.push_back(5);
    jsonbegin["chunk_size"] = chunk_size;
    jsonbegin["chunk_interval"] = 10;
    jsonbegin["wav_name"] = "damo";
    jsonbegin["is_speaking"] = true;
    m_client.send(m_hdl, jsonbegin.dump(), websocketpp::frame::opcode::text,
                  ec);
    // fetch wav data use asr engine api
    while (audio.Fetch(buff, len, flag) > 0) {
      short iArray[len];
@@ -181,13 +221,15 @@
      wait_a_bit();
    }
    m_client.send(m_hdl, "Done", websocketpp::frame::opcode::text, ec);
    nlohmann::json jsonresult;
    jsonresult["is_speaking"] = false;
    m_client.send(m_hdl, jsonresult.dump(), websocketpp::frame::opcode::text,
                  ec);
    wait_a_bit();
  }
  websocketpp::client<T> m_client;
 private:
  client m_client;
  websocketpp::connection_hdl m_hdl;
  websocketpp::lib::mutex m_lock;
  std::string wav_path;
@@ -196,22 +238,36 @@
};
int main(int argc, char* argv[]) {
  if (argc < 5) {
    printf("Usage: %s server_ip port wav_path threads_num\n", argv[0]);
  if (argc < 6) {
    printf("Usage: %s server_ip port wav_path threads_num is_ssl\n", argv[0]);
    exit(-1);
  }
  std::string server_ip = argv[1];
  std::string port = argv[2];
  std::string wav_path = argv[3];
  int threads_num = atoi(argv[4]);
  int is_ssl = atoi(argv[5]);
  std::vector<websocketpp::lib::thread> client_threads;
  std::string uri = "ws://" + server_ip + ":" + port;
  std::string uri = "";
  if (is_ssl == 1) {
    uri = "wss://" + server_ip + ":" + port;
  } else {
    uri = "ws://" + server_ip + ":" + port;
  }
  for (size_t i = 0; i < threads_num; i++) {
    client_threads.emplace_back([uri, wav_path]() {
      websocket_client c;
      c.run(uri, wav_path);
    client_threads.emplace_back([uri, wav_path, is_ssl]() {
      if (is_ssl == 1) {
        websocket_client<websocketpp::config::asio_tls_client> c(is_ssl);
        c.m_client.set_tls_init_handler(bind(&on_tls_init, ::_1));
        c.run(uri, wav_path);
      } else {
        websocket_client<websocketpp::config::asio_client> c(is_ssl);
        c.run(uri, wav_path);
      }
    });
  }