zhaomingwork
2023-05-25 21c590ad67bb00cf29c23b85666301359fb0e6e0
add ssl support for cpp websocket (#553)

6个文件已修改
299 ■■■■ 已修改文件
funasr/runtime/websocket/CMakeLists.txt 8 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/websocket/readme.md 22 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/websocket/websocketclient.cpp 66 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/websocket/websocketmain.cpp 65 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/websocket/websocketsrv.cpp 62 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/websocket/websocketsrv.h 76 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/websocket/CMakeLists.txt
@@ -55,10 +55,12 @@
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog)
set(BUILD_TESTING OFF)
add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog glog)
# install openssl first apt-get install libssl-dev
find_package(OpenSSL REQUIRED)
add_executable(websocketmain "websocketmain.cpp" "websocketsrv.cpp")
add_executable(websocketclient "websocketclient.cpp")
target_link_libraries(websocketclient PUBLIC funasr)
target_link_libraries(websocketmain PUBLIC funasr)
target_link_libraries(websocketclient PUBLIC funasr ssl crypto)
target_link_libraries(websocketmain PUBLIC funasr ssl crypto)
funasr/runtime/websocket/readme.md
@@ -33,7 +33,12 @@
```
### Build runtime
required openssl lib
```shell
#install openssl lib first
apt-get install libssl-dev
git clone https://github.com/alibaba-damo-academy/FunASR.git && cd funasr/runtime/websocket
mkdir build && cd build
cmake  -DCMAKE_BUILD_TYPE=release .. -DONNXRUNTIME_DIR=/path/to/onnxruntime-linux-x64-1.14.0
@@ -43,11 +48,12 @@
```shell
cd bin
./websocketmain  [--model_thread_num <int>] [--decoder_thread_num <int>]
   ./websocketmain  [--model_thread_num <int>] [--decoder_thread_num <int>]
                    [--io_thread_num <int>] [--port <int>] [--listen_ip
                    <string>] [--punc-quant <string>] [--punc-dir <string>]
                    [--vad-quant <string>] [--vad-dir <string>] [--quantize
                    <string>] --model-dir <string> [--] [--version] [-h]
                    <string>] --model-dir <string> [--keyfile <string>]
                    [--certfile <string>] [--] [--version] [-h]
Where:
   --model-dir <string>
     (required)  the asr model path, which contains model.onnx, config.yaml, am.mvn
@@ -70,6 +76,10 @@
     number of threads for network io, default:8
   --port <int>
     listen port, default:8889
   --certfile <string>
     path of certficate for WSS connection. if it is empty, it will be in WS mode.
   --keyfile <string>
     path of keyfile for WSS connection
  
   Required:  --model-dir <string>
   If use vad, please add: --vad-dir <string>
@@ -81,14 +91,16 @@
## Run websocket client test
```shell
Usage: websocketclient server_ip port wav_path threads_num
Usage: ./websocketclient server_ip port wav_path threads_num is_ssl
is_ssl is 1 means use wss connection, or use ws connection
example:
websocketclient 127.0.0.1 8889 funasr/runtime/websocket/test.pcm.wav 64
websocketclient 127.0.0.1 8889 funasr/runtime/websocket/test.pcm.wav 64 0
result json, example like:
{"text":"一二三四五六七八九十一二三四五六七八九十"}
{"mode":"offline","text":"欢迎大家来体验达摩院推出的语音识别模型","wav_name":"wav2"}
```
funasr/runtime/websocket/websocketclient.cpp
@@ -10,7 +10,7 @@
#define ASIO_STANDALONE 1
#include <websocketpp/client.hpp>
#include <websocketpp/common/thread.hpp>
#include <websocketpp/config/asio_no_tls_client.hpp>
#include <websocketpp/config/asio_client.hpp>
#include "audio.h"
#include "nlohmann/json.hpp"
@@ -26,14 +26,37 @@
#endif
}
typedef websocketpp::config::asio_client::message_type::ptr message_ptr;
typedef websocketpp::lib::shared_ptr<websocketpp::lib::asio::ssl::context>
    context_ptr;
using websocketpp::lib::bind;
using websocketpp::lib::placeholders::_1;
using websocketpp::lib::placeholders::_2;
context_ptr on_tls_init(websocketpp::connection_hdl) {
  context_ptr ctx = websocketpp::lib::make_shared<asio::ssl::context>(
      asio::ssl::context::sslv23);
  try {
    ctx->set_options(
        asio::ssl::context::default_workarounds | asio::ssl::context::no_sslv2 |
        asio::ssl::context::no_sslv3 | asio::ssl::context::single_dh_use);
  } catch (std::exception& e) {
    std::cout << e.what() << std::endl;
  }
  return ctx;
}
// template for tls or not config
template <typename T>
class websocket_client {
 public:
  typedef websocketpp::client<websocketpp::config::asio_client> client;
  // typedef websocketpp::client<T> client;
  // typedef websocketpp::client<websocketpp::config::asio_tls_client>
  // wss_client;
  typedef websocketpp::lib::lock_guard<websocketpp::lib::mutex> scoped_lock;
  websocket_client() : m_open(false), m_done(false) {
  websocket_client(int is_ssl) : m_open(false), m_done(false) {
    // set up access channels to only log interesting things
    m_client.clear_access_channels(websocketpp::log::alevel::all);
    m_client.set_access_channels(websocketpp::log::alevel::connect);
    m_client.set_access_channels(websocketpp::log::alevel::disconnect);
@@ -65,10 +88,12 @@
    }
  }
  // This method will block until the connection is complete
  void run(const std::string& uri, const std::string& wav_path) {
    // Create a new connection to the given URI
    websocketpp::lib::error_code ec;
    client::connection_ptr con = m_client.get_connection(uri, ec);
    typename websocketpp::client<T>::connection_ptr con =
        m_client.get_connection(uri, ec);
    if (ec) {
      m_client.get_alog().write(websocketpp::log::alevel::app,
                                "Get Connection Error: " + ec.message());
@@ -84,7 +109,8 @@
    m_client.connect(con);
    // Create a thread to run the ASIO io_service event loop
    websocketpp::lib::thread asio_thread(&client::run, &m_client);
    websocketpp::lib::thread asio_thread(&websocketpp::client<T>::run,
                                         &m_client);
    send_wav_data();
    asio_thread.join();
@@ -201,9 +227,9 @@
                  ec);
    wait_a_bit();
  }
  websocketpp::client<T> m_client;
 private:
  client m_client;
  websocketpp::connection_hdl m_hdl;
  websocketpp::lib::mutex m_lock;
  std::string wav_path;
@@ -212,22 +238,36 @@
};
int main(int argc, char* argv[]) {
  if (argc < 5) {
    printf("Usage: %s server_ip port wav_path threads_num\n", argv[0]);
  if (argc < 6) {
    printf("Usage: %s server_ip port wav_path threads_num is_ssl\n", argv[0]);
    exit(-1);
  }
  std::string server_ip = argv[1];
  std::string port = argv[2];
  std::string wav_path = argv[3];
  int threads_num = atoi(argv[4]);
  int is_ssl = atoi(argv[5]);
  std::vector<websocketpp::lib::thread> client_threads;
  std::string uri = "ws://" + server_ip + ":" + port;
  std::string uri = "";
  if (is_ssl == 1) {
    uri = "wss://" + server_ip + ":" + port;
  } else {
    uri = "ws://" + server_ip + ":" + port;
  }
  for (size_t i = 0; i < threads_num; i++) {
    client_threads.emplace_back([uri, wav_path]() {
      websocket_client c;
      c.run(uri, wav_path);
    client_threads.emplace_back([uri, wav_path, is_ssl]() {
      if (is_ssl == 1) {
        websocket_client<websocketpp::config::asio_tls_client> c(is_ssl);
        c.m_client.set_tls_init_handler(bind(&on_tls_init, ::_1));
        c.run(uri, wav_path);
      } else {
        websocket_client<websocketpp::config::asio_client> c(is_ssl);
        c.run(uri, wav_path);
      }
    });
  }
funasr/runtime/websocket/websocketmain.cpp
@@ -64,6 +64,14 @@
    TCLAP::ValueArg<int> model_thread_num("", "model_thread_num",
                                          "model_thread_num", false, 1, "int");
    TCLAP::ValueArg<std::string> certfile("", "certfile", "certfile", false, "",
                                          "string");
    TCLAP::ValueArg<std::string> keyfile("", "keyfile", "keyfile", false, "",
                                         "string");
    cmd.add(certfile);
    cmd.add(keyfile);
    cmd.add(model_dir);
    cmd.add(quantize);
    cmd.add(vad_dir);
@@ -97,6 +105,14 @@
    std::vector<std::thread> decoder_threads;
    std::string s_certfile = certfile.getValue();
    std::string s_keyfile = keyfile.getValue();
    bool is_ssl = false;
    if (!s_certfile.empty()) {
      is_ssl = true;
    }
    auto conn_guard = asio::make_work_guard(
        io_decoder);  // make sure threads can wait in the queue
@@ -105,30 +121,55 @@
      decoder_threads.emplace_back([&io_decoder]() { io_decoder.run(); });
    }
    server server_;       // server for websocket
    server_.init_asio();  // init asio
    server_.set_reuse_addr(
        true);  // reuse address as we create multiple threads
    server server_;  // server for websocket
    wss_server wss_server_;
    if (is_ssl) {
      wss_server_.init_asio();  // init asio
      wss_server_.set_reuse_addr(
          true);  // reuse address as we create multiple threads
    // list on port for accept
    server_.listen(asio::ip::address::from_string(s_listen_ip), s_port);
      // list on port for accept
      wss_server_.listen(asio::ip::address::from_string(s_listen_ip), s_port);
      WebSocketServer websocket_srv(
          io_decoder, is_ssl, nullptr, &wss_server_, s_certfile,
          s_keyfile);  // websocket server for asr engine
      websocket_srv.initAsr(model_path, s_model_thread_num);  // init asr model
    WebSocketServer websocket_srv(io_decoder,
                                  &server_);  // websocket server for asr engine
    websocket_srv.initAsr(model_path, s_model_thread_num);  // init asr model
    } else {
      server_.init_asio();  // init asio
      server_.set_reuse_addr(
          true);  // reuse address as we create multiple threads
      // list on port for accept
      server_.listen(asio::ip::address::from_string(s_listen_ip), s_port);
      WebSocketServer websocket_srv(
          io_decoder, is_ssl, &server_, nullptr, s_certfile,
          s_keyfile);  // websocket server for asr engine
      websocket_srv.initAsr(model_path, s_model_thread_num);  // init asr model
    }
    std::cout << "asr model init finished. listen on port:" << s_port
              << std::endl;
    // Start the ASIO network io_service run loop
    if (s_io_thread_num == 1) {
      server_.run();
      if (is_ssl) {
        wss_server_.run();
      } else {
        server_.run();
      }
    } else {
      typedef websocketpp::lib::shared_ptr<websocketpp::lib::thread> thread_ptr;
      std::vector<thread_ptr> ts;
      // create threads for io network
      for (size_t i = 0; i < s_io_thread_num; i++) {
        ts.push_back(websocketpp::lib::make_shared<websocketpp::lib::thread>(
            &server::run, &server_));
        if (is_ssl) {
          ts.push_back(websocketpp::lib::make_shared<websocketpp::lib::thread>(
              &wss_server::run, &wss_server_));
        } else {
          ts.push_back(websocketpp::lib::make_shared<websocketpp::lib::thread>(
              &server::run, &server_));
        }
      }
      // wait for theads
      for (size_t i = 0; i < s_io_thread_num; i++) {
funasr/runtime/websocket/websocketsrv.cpp
@@ -16,6 +16,44 @@
#include <utility>
#include <vector>
context_ptr WebSocketServer::on_tls_init(tls_mode mode,
                                         websocketpp::connection_hdl hdl,
                                         std::string& s_certfile,
                                         std::string& s_keyfile) {
  namespace asio = websocketpp::lib::asio;
  std::cout << "on_tls_init called with hdl: " << hdl.lock().get() << std::endl;
  std::cout << "using TLS mode: "
            << (mode == MOZILLA_MODERN ? "Mozilla Modern"
                                       : "Mozilla Intermediate")
            << std::endl;
  context_ptr ctx = websocketpp::lib::make_shared<asio::ssl::context>(
      asio::ssl::context::sslv23);
  try {
    if (mode == MOZILLA_MODERN) {
      // Modern disables TLSv1
      ctx->set_options(
          asio::ssl::context::default_workarounds |
          asio::ssl::context::no_sslv2 | asio::ssl::context::no_sslv3 |
          asio::ssl::context::no_tlsv1 | asio::ssl::context::single_dh_use);
    } else {
      ctx->set_options(asio::ssl::context::default_workarounds |
                       asio::ssl::context::no_sslv2 |
                       asio::ssl::context::no_sslv3 |
                       asio::ssl::context::single_dh_use);
    }
    ctx->use_certificate_chain_file(s_certfile);
    ctx->use_private_key_file(s_keyfile, asio::ssl::context::pem);
  } catch (std::exception& e) {
    std::cout << "Exception: " << e.what() << std::endl;
  }
  return ctx;
}
// feed buffer to asr engine for decoder
void WebSocketServer::do_decoder(const std::vector<char>& buffer,
                                 websocketpp::connection_hdl& hdl,
@@ -40,8 +78,13 @@
      jsonresult["wav_name"] = msg["wav_name"];
      // send the json to client
      server_->send(hdl, jsonresult.dump(), websocketpp::frame::opcode::text,
                    ec);
      if (is_ssl) {
        wss_server_->send(hdl, jsonresult.dump(),
                          websocketpp::frame::opcode::text, ec);
      } else {
        server_->send(hdl, jsonresult.dump(), websocketpp::frame::opcode::text,
                      ec);
      }
      std::cout << "buffer.size=" << buffer.size()
                << ",result json=" << jsonresult.dump() << std::endl;
@@ -83,10 +126,19 @@
  auto iter = data_map.begin();
  while (iter != data_map.end()) {  // loop to find closed connection
    websocketpp::connection_hdl hdl = iter->first;
    server::connection_ptr con = server_->get_con_from_hdl(hdl);
    if (con->get_state() != 1) {  // session::state::open ==1
      to_remove.push_back(hdl);
    if (is_ssl) {
      wss_server::connection_ptr con = wss_server_->get_con_from_hdl(hdl);
      if (con->get_state() != 1) {  // session::state::open ==1
        to_remove.push_back(hdl);
      }
    } else {
      server::connection_ptr con = server_->get_con_from_hdl(hdl);
      if (con->get_state() != 1) {  // session::state::open ==1
        to_remove.push_back(hdl);
      }
    }
    iter++;
  }
  for (auto hdl : to_remove) {
funasr/runtime/websocket/websocketsrv.h
@@ -25,7 +25,7 @@
#include <fstream>
#include <functional>
#include <websocketpp/common/thread.hpp>
#include <websocketpp/config/asio_no_tls.hpp>
#include <websocketpp/config/asio.hpp>
#include <websocketpp/server.hpp>
#include "asio.hpp"
@@ -34,12 +34,16 @@
#include "nlohmann/json.hpp"
#include "tclap/CmdLine.h"
typedef websocketpp::server<websocketpp::config::asio> server;
typedef websocketpp::server<websocketpp::config::asio_tls> wss_server;
typedef server::message_ptr message_ptr;
using websocketpp::lib::bind;
using websocketpp::lib::placeholders::_1;
using websocketpp::lib::placeholders::_2;
typedef websocketpp::lib::lock_guard<websocketpp::lib::mutex> scoped_lock;
typedef websocketpp::lib::unique_lock<websocketpp::lib::mutex> unique_lock;
typedef websocketpp::lib::shared_ptr<websocketpp::lib::asio::ssl::context>
    context_ptr;
typedef struct {
  std::string msg;
@@ -51,25 +55,55 @@
  std::shared_ptr<std::vector<char>> samples;
} FUNASR_MESSAGE;
// See https://wiki.mozilla.org/Security/Server_Side_TLS for more details about
// the TLS modes. The code below demonstrates how to implement both the modern
enum tls_mode { MOZILLA_INTERMEDIATE = 1, MOZILLA_MODERN = 2 };
class WebSocketServer {
 public:
  WebSocketServer(asio::io_context& io_decoder, server* server_)
      : io_decoder_(io_decoder), server_(server_) {
    // set message handle
    server_->set_message_handler(
        [this](websocketpp::connection_hdl hdl, message_ptr msg) {
          on_message(hdl, msg);
        });
    // set open handle
    server_->set_open_handler(
        [this](websocketpp::connection_hdl hdl) { on_open(hdl); });
    // set close handle
    server_->set_close_handler(
        [this](websocketpp::connection_hdl hdl) { on_close(hdl); });
    // begin accept
    server_->start_accept();
    // not print log
    server_->clear_access_channels(websocketpp::log::alevel::all);
  WebSocketServer(asio::io_context& io_decoder, bool is_ssl, server* server,
                  wss_server* wss_server, std::string& s_certfile,
                  std::string& s_keyfile)
      : io_decoder_(io_decoder),
        is_ssl(is_ssl),
        server_(server),
        wss_server_(wss_server) {
    if (is_ssl) {
      std::cout << "certfile path is " << s_certfile << std::endl;
      wss_server->set_tls_init_handler(
          bind<context_ptr>(&WebSocketServer::on_tls_init, this,
                            MOZILLA_INTERMEDIATE, ::_1, s_certfile, s_keyfile));
      wss_server_->set_message_handler(
          [this](websocketpp::connection_hdl hdl, message_ptr msg) {
            on_message(hdl, msg);
          });
      // set open handle
      wss_server_->set_open_handler(
          [this](websocketpp::connection_hdl hdl) { on_open(hdl); });
      // set close handle
      wss_server_->set_close_handler(
          [this](websocketpp::connection_hdl hdl) { on_close(hdl); });
      // begin accept
      wss_server_->start_accept();
      // not print log
      wss_server_->clear_access_channels(websocketpp::log::alevel::all);
    } else {
      // set message handle
      server_->set_message_handler(
          [this](websocketpp::connection_hdl hdl, message_ptr msg) {
            on_message(hdl, msg);
          });
      // set open handle
      server_->set_open_handler(
          [this](websocketpp::connection_hdl hdl) { on_open(hdl); });
      // set close handle
      server_->set_close_handler(
          [this](websocketpp::connection_hdl hdl) { on_close(hdl); });
      // begin accept
      server_->start_accept();
      // not print log
      server_->clear_access_channels(websocketpp::log::alevel::all);
    }
  }
  void do_decoder(const std::vector<char>& buffer,
                  websocketpp::connection_hdl& hdl, const nlohmann::json& msg);
@@ -78,6 +112,8 @@
  void on_message(websocketpp::connection_hdl hdl, message_ptr msg);
  void on_open(websocketpp::connection_hdl hdl);
  void on_close(websocketpp::connection_hdl hdl);
  context_ptr on_tls_init(tls_mode mode, websocketpp::connection_hdl hdl,
                          std::string& s_certfile, std::string& s_keyfile);
 private:
  void check_and_clean_connection();
@@ -85,7 +121,9 @@
  // std::ofstream fout;
  FUNASR_HANDLE asr_hanlde;  // asr engine handle
  bool isonline = false;  // online or offline engine, now only support offline
  server* server_;        // websocket server
  bool is_ssl = true;
  server* server_;          // websocket server
  wss_server* wss_server_;  // websocket server
  // use map to keep the received samples data from one connection in offline
  // engine. if for online engline, a data struct is needed(TODO)