From 21c590ad67bb00cf29c23b85666301359fb0e6e0 Mon Sep 17 00:00:00 2001
From: zhaomingwork <61895407+zhaomingwork@users.noreply.github.com>
Date: 星期四, 25 五月 2023 15:24:10 +0800
Subject: [PATCH] add ssl support for cpp websocket (#553)
---
funasr/runtime/websocket/CMakeLists.txt | 8 +
funasr/runtime/websocket/websocketsrv.cpp | 62 +++++++++++-
funasr/runtime/websocket/readme.md | 22 +++-
funasr/runtime/websocket/websocketsrv.h | 76 +++++++++++---
funasr/runtime/websocket/websocketclient.cpp | 66 ++++++++++--
funasr/runtime/websocket/websocketmain.cpp | 65 ++++++++++--
6 files changed, 242 insertions(+), 57 deletions(-)
diff --git a/funasr/runtime/websocket/CMakeLists.txt b/funasr/runtime/websocket/CMakeLists.txt
index e89537b..8217b30 100644
--- a/funasr/runtime/websocket/CMakeLists.txt
+++ b/funasr/runtime/websocket/CMakeLists.txt
@@ -55,10 +55,12 @@
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog)
set(BUILD_TESTING OFF)
add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog glog)
-
+
+# install openssl first apt-get install libssl-dev
+find_package(OpenSSL REQUIRED)
add_executable(websocketmain "websocketmain.cpp" "websocketsrv.cpp")
add_executable(websocketclient "websocketclient.cpp")
-target_link_libraries(websocketclient PUBLIC funasr)
-target_link_libraries(websocketmain PUBLIC funasr)
+target_link_libraries(websocketclient PUBLIC funasr ssl crypto)
+target_link_libraries(websocketmain PUBLIC funasr ssl crypto)
diff --git a/funasr/runtime/websocket/readme.md b/funasr/runtime/websocket/readme.md
index 078184e..d2a54e9 100644
--- a/funasr/runtime/websocket/readme.md
+++ b/funasr/runtime/websocket/readme.md
@@ -33,7 +33,12 @@
```
### Build runtime
+required openssl lib
+
```shell
+#install openssl lib first
+apt-get install libssl-dev
+
git clone https://github.com/alibaba-damo-academy/FunASR.git && cd funasr/runtime/websocket
mkdir build && cd build
cmake -DCMAKE_BUILD_TYPE=release .. -DONNXRUNTIME_DIR=/path/to/onnxruntime-linux-x64-1.14.0
@@ -43,11 +48,12 @@
```shell
cd bin
-./websocketmain [--model_thread_num <int>] [--decoder_thread_num <int>]
+ ./websocketmain [--model_thread_num <int>] [--decoder_thread_num <int>]
[--io_thread_num <int>] [--port <int>] [--listen_ip
<string>] [--punc-quant <string>] [--punc-dir <string>]
[--vad-quant <string>] [--vad-dir <string>] [--quantize
- <string>] --model-dir <string> [--] [--version] [-h]
+ <string>] --model-dir <string> [--keyfile <string>]
+ [--certfile <string>] [--] [--version] [-h]
Where:
--model-dir <string>
(required) the asr model path, which contains model.onnx, config.yaml, am.mvn
@@ -70,6 +76,10 @@
number of threads for network io, default:8
--port <int>
listen port, default:8889
+ --certfile <string>
+ path of certficate for WSS connection. if it is empty, it will be in WS mode.
+ --keyfile <string>
+ path of keyfile for WSS connection
Required: --model-dir <string>
If use vad, please add: --vad-dir <string>
@@ -81,14 +91,16 @@
## Run websocket client test
```shell
-Usage: websocketclient server_ip port wav_path threads_num
+Usage: ./websocketclient server_ip port wav_path threads_num is_ssl
+
+is_ssl is 1 means use wss connection, or use ws connection
example:
-websocketclient 127.0.0.1 8889 funasr/runtime/websocket/test.pcm.wav 64
+websocketclient 127.0.0.1 8889 funasr/runtime/websocket/test.pcm.wav 64 0
result json, example like:
-{"text":"涓�浜屼笁鍥涗簲鍏竷鍏節鍗佷竴浜屼笁鍥涗簲鍏竷鍏節鍗�"}
+{"mode":"offline","text":"娆㈣繋澶у鏉ヤ綋楠岃揪鎽╅櫌鎺ㄥ嚭鐨勮闊宠瘑鍒ā鍨�","wav_name":"wav2"}
```
diff --git a/funasr/runtime/websocket/websocketclient.cpp b/funasr/runtime/websocket/websocketclient.cpp
index 078fc5a..e9f8f1d 100644
--- a/funasr/runtime/websocket/websocketclient.cpp
+++ b/funasr/runtime/websocket/websocketclient.cpp
@@ -10,7 +10,7 @@
#define ASIO_STANDALONE 1
#include <websocketpp/client.hpp>
#include <websocketpp/common/thread.hpp>
-#include <websocketpp/config/asio_no_tls_client.hpp>
+#include <websocketpp/config/asio_client.hpp>
#include "audio.h"
#include "nlohmann/json.hpp"
@@ -26,14 +26,37 @@
#endif
}
typedef websocketpp::config::asio_client::message_type::ptr message_ptr;
+typedef websocketpp::lib::shared_ptr<websocketpp::lib::asio::ssl::context>
+ context_ptr;
+using websocketpp::lib::bind;
+using websocketpp::lib::placeholders::_1;
+using websocketpp::lib::placeholders::_2;
+context_ptr on_tls_init(websocketpp::connection_hdl) {
+ context_ptr ctx = websocketpp::lib::make_shared<asio::ssl::context>(
+ asio::ssl::context::sslv23);
+ try {
+ ctx->set_options(
+ asio::ssl::context::default_workarounds | asio::ssl::context::no_sslv2 |
+ asio::ssl::context::no_sslv3 | asio::ssl::context::single_dh_use);
+
+ } catch (std::exception& e) {
+ std::cout << e.what() << std::endl;
+ }
+ return ctx;
+}
+// template for tls or not config
+template <typename T>
class websocket_client {
public:
- typedef websocketpp::client<websocketpp::config::asio_client> client;
+ // typedef websocketpp::client<T> client;
+ // typedef websocketpp::client<websocketpp::config::asio_tls_client>
+ // wss_client;
typedef websocketpp::lib::lock_guard<websocketpp::lib::mutex> scoped_lock;
- websocket_client() : m_open(false), m_done(false) {
+ websocket_client(int is_ssl) : m_open(false), m_done(false) {
// set up access channels to only log interesting things
+
m_client.clear_access_channels(websocketpp::log::alevel::all);
m_client.set_access_channels(websocketpp::log::alevel::connect);
m_client.set_access_channels(websocketpp::log::alevel::disconnect);
@@ -65,10 +88,12 @@
}
}
// This method will block until the connection is complete
+
void run(const std::string& uri, const std::string& wav_path) {
// Create a new connection to the given URI
websocketpp::lib::error_code ec;
- client::connection_ptr con = m_client.get_connection(uri, ec);
+ typename websocketpp::client<T>::connection_ptr con =
+ m_client.get_connection(uri, ec);
if (ec) {
m_client.get_alog().write(websocketpp::log::alevel::app,
"Get Connection Error: " + ec.message());
@@ -84,7 +109,8 @@
m_client.connect(con);
// Create a thread to run the ASIO io_service event loop
- websocketpp::lib::thread asio_thread(&client::run, &m_client);
+ websocketpp::lib::thread asio_thread(&websocketpp::client<T>::run,
+ &m_client);
send_wav_data();
asio_thread.join();
@@ -201,9 +227,9 @@
ec);
wait_a_bit();
}
+ websocketpp::client<T> m_client;
private:
- client m_client;
websocketpp::connection_hdl m_hdl;
websocketpp::lib::mutex m_lock;
std::string wav_path;
@@ -212,22 +238,36 @@
};
int main(int argc, char* argv[]) {
- if (argc < 5) {
- printf("Usage: %s server_ip port wav_path threads_num\n", argv[0]);
+ if (argc < 6) {
+ printf("Usage: %s server_ip port wav_path threads_num is_ssl\n", argv[0]);
exit(-1);
}
std::string server_ip = argv[1];
std::string port = argv[2];
std::string wav_path = argv[3];
int threads_num = atoi(argv[4]);
+ int is_ssl = atoi(argv[5]);
std::vector<websocketpp::lib::thread> client_threads;
-
- std::string uri = "ws://" + server_ip + ":" + port;
+ std::string uri = "";
+ if (is_ssl == 1) {
+ uri = "wss://" + server_ip + ":" + port;
+ } else {
+ uri = "ws://" + server_ip + ":" + port;
+ }
for (size_t i = 0; i < threads_num; i++) {
- client_threads.emplace_back([uri, wav_path]() {
- websocket_client c;
- c.run(uri, wav_path);
+ client_threads.emplace_back([uri, wav_path, is_ssl]() {
+ if (is_ssl == 1) {
+ websocket_client<websocketpp::config::asio_tls_client> c(is_ssl);
+
+ c.m_client.set_tls_init_handler(bind(&on_tls_init, ::_1));
+
+ c.run(uri, wav_path);
+ } else {
+ websocket_client<websocketpp::config::asio_client> c(is_ssl);
+
+ c.run(uri, wav_path);
+ }
});
}
diff --git a/funasr/runtime/websocket/websocketmain.cpp b/funasr/runtime/websocket/websocketmain.cpp
index 4614b51..306c3f0 100644
--- a/funasr/runtime/websocket/websocketmain.cpp
+++ b/funasr/runtime/websocket/websocketmain.cpp
@@ -64,6 +64,14 @@
TCLAP::ValueArg<int> model_thread_num("", "model_thread_num",
"model_thread_num", false, 1, "int");
+ TCLAP::ValueArg<std::string> certfile("", "certfile", "certfile", false, "",
+ "string");
+ TCLAP::ValueArg<std::string> keyfile("", "keyfile", "keyfile", false, "",
+ "string");
+
+ cmd.add(certfile);
+ cmd.add(keyfile);
+
cmd.add(model_dir);
cmd.add(quantize);
cmd.add(vad_dir);
@@ -97,6 +105,14 @@
std::vector<std::thread> decoder_threads;
+ std::string s_certfile = certfile.getValue();
+ std::string s_keyfile = keyfile.getValue();
+
+ bool is_ssl = false;
+ if (!s_certfile.empty()) {
+ is_ssl = true;
+ }
+
auto conn_guard = asio::make_work_guard(
io_decoder); // make sure threads can wait in the queue
@@ -105,30 +121,55 @@
decoder_threads.emplace_back([&io_decoder]() { io_decoder.run(); });
}
- server server_; // server for websocket
- server_.init_asio(); // init asio
- server_.set_reuse_addr(
- true); // reuse address as we create multiple threads
+ server server_; // server for websocket
+ wss_server wss_server_;
+ if (is_ssl) {
+ wss_server_.init_asio(); // init asio
+ wss_server_.set_reuse_addr(
+ true); // reuse address as we create multiple threads
- // list on port for accept
- server_.listen(asio::ip::address::from_string(s_listen_ip), s_port);
+ // list on port for accept
+ wss_server_.listen(asio::ip::address::from_string(s_listen_ip), s_port);
+ WebSocketServer websocket_srv(
+ io_decoder, is_ssl, nullptr, &wss_server_, s_certfile,
+ s_keyfile); // websocket server for asr engine
+ websocket_srv.initAsr(model_path, s_model_thread_num); // init asr model
- WebSocketServer websocket_srv(io_decoder,
- &server_); // websocket server for asr engine
- websocket_srv.initAsr(model_path, s_model_thread_num); // init asr model
+ } else {
+ server_.init_asio(); // init asio
+ server_.set_reuse_addr(
+ true); // reuse address as we create multiple threads
+
+ // list on port for accept
+ server_.listen(asio::ip::address::from_string(s_listen_ip), s_port);
+ WebSocketServer websocket_srv(
+ io_decoder, is_ssl, &server_, nullptr, s_certfile,
+ s_keyfile); // websocket server for asr engine
+ websocket_srv.initAsr(model_path, s_model_thread_num); // init asr model
+ }
+
std::cout << "asr model init finished. listen on port:" << s_port
<< std::endl;
// Start the ASIO network io_service run loop
if (s_io_thread_num == 1) {
- server_.run();
+ if (is_ssl) {
+ wss_server_.run();
+ } else {
+ server_.run();
+ }
} else {
typedef websocketpp::lib::shared_ptr<websocketpp::lib::thread> thread_ptr;
std::vector<thread_ptr> ts;
// create threads for io network
for (size_t i = 0; i < s_io_thread_num; i++) {
- ts.push_back(websocketpp::lib::make_shared<websocketpp::lib::thread>(
- &server::run, &server_));
+ if (is_ssl) {
+ ts.push_back(websocketpp::lib::make_shared<websocketpp::lib::thread>(
+ &wss_server::run, &wss_server_));
+ } else {
+ ts.push_back(websocketpp::lib::make_shared<websocketpp::lib::thread>(
+ &server::run, &server_));
+ }
}
// wait for theads
for (size_t i = 0; i < s_io_thread_num; i++) {
diff --git a/funasr/runtime/websocket/websocketsrv.cpp b/funasr/runtime/websocket/websocketsrv.cpp
index b81442c..eb3c8db 100644
--- a/funasr/runtime/websocket/websocketsrv.cpp
+++ b/funasr/runtime/websocket/websocketsrv.cpp
@@ -16,6 +16,44 @@
#include <utility>
#include <vector>
+context_ptr WebSocketServer::on_tls_init(tls_mode mode,
+ websocketpp::connection_hdl hdl,
+ std::string& s_certfile,
+ std::string& s_keyfile) {
+ namespace asio = websocketpp::lib::asio;
+
+ std::cout << "on_tls_init called with hdl: " << hdl.lock().get() << std::endl;
+ std::cout << "using TLS mode: "
+ << (mode == MOZILLA_MODERN ? "Mozilla Modern"
+ : "Mozilla Intermediate")
+ << std::endl;
+
+ context_ptr ctx = websocketpp::lib::make_shared<asio::ssl::context>(
+ asio::ssl::context::sslv23);
+
+ try {
+ if (mode == MOZILLA_MODERN) {
+ // Modern disables TLSv1
+ ctx->set_options(
+ asio::ssl::context::default_workarounds |
+ asio::ssl::context::no_sslv2 | asio::ssl::context::no_sslv3 |
+ asio::ssl::context::no_tlsv1 | asio::ssl::context::single_dh_use);
+ } else {
+ ctx->set_options(asio::ssl::context::default_workarounds |
+ asio::ssl::context::no_sslv2 |
+ asio::ssl::context::no_sslv3 |
+ asio::ssl::context::single_dh_use);
+ }
+
+ ctx->use_certificate_chain_file(s_certfile);
+ ctx->use_private_key_file(s_keyfile, asio::ssl::context::pem);
+
+ } catch (std::exception& e) {
+ std::cout << "Exception: " << e.what() << std::endl;
+ }
+ return ctx;
+}
+
// feed buffer to asr engine for decoder
void WebSocketServer::do_decoder(const std::vector<char>& buffer,
websocketpp::connection_hdl& hdl,
@@ -40,8 +78,13 @@
jsonresult["wav_name"] = msg["wav_name"];
// send the json to client
- server_->send(hdl, jsonresult.dump(), websocketpp::frame::opcode::text,
- ec);
+ if (is_ssl) {
+ wss_server_->send(hdl, jsonresult.dump(),
+ websocketpp::frame::opcode::text, ec);
+ } else {
+ server_->send(hdl, jsonresult.dump(), websocketpp::frame::opcode::text,
+ ec);
+ }
std::cout << "buffer.size=" << buffer.size()
<< ",result json=" << jsonresult.dump() << std::endl;
@@ -83,10 +126,19 @@
auto iter = data_map.begin();
while (iter != data_map.end()) { // loop to find closed connection
websocketpp::connection_hdl hdl = iter->first;
- server::connection_ptr con = server_->get_con_from_hdl(hdl);
- if (con->get_state() != 1) { // session::state::open ==1
- to_remove.push_back(hdl);
+
+ if (is_ssl) {
+ wss_server::connection_ptr con = wss_server_->get_con_from_hdl(hdl);
+ if (con->get_state() != 1) { // session::state::open ==1
+ to_remove.push_back(hdl);
+ }
+ } else {
+ server::connection_ptr con = server_->get_con_from_hdl(hdl);
+ if (con->get_state() != 1) { // session::state::open ==1
+ to_remove.push_back(hdl);
+ }
}
+
iter++;
}
for (auto hdl : to_remove) {
diff --git a/funasr/runtime/websocket/websocketsrv.h b/funasr/runtime/websocket/websocketsrv.h
index 82d717e..3cb8816 100644
--- a/funasr/runtime/websocket/websocketsrv.h
+++ b/funasr/runtime/websocket/websocketsrv.h
@@ -25,7 +25,7 @@
#include <fstream>
#include <functional>
#include <websocketpp/common/thread.hpp>
-#include <websocketpp/config/asio_no_tls.hpp>
+#include <websocketpp/config/asio.hpp>
#include <websocketpp/server.hpp>
#include "asio.hpp"
@@ -34,12 +34,16 @@
#include "nlohmann/json.hpp"
#include "tclap/CmdLine.h"
typedef websocketpp::server<websocketpp::config::asio> server;
+typedef websocketpp::server<websocketpp::config::asio_tls> wss_server;
typedef server::message_ptr message_ptr;
using websocketpp::lib::bind;
using websocketpp::lib::placeholders::_1;
using websocketpp::lib::placeholders::_2;
+
typedef websocketpp::lib::lock_guard<websocketpp::lib::mutex> scoped_lock;
typedef websocketpp::lib::unique_lock<websocketpp::lib::mutex> unique_lock;
+typedef websocketpp::lib::shared_ptr<websocketpp::lib::asio::ssl::context>
+ context_ptr;
typedef struct {
std::string msg;
@@ -51,25 +55,55 @@
std::shared_ptr<std::vector<char>> samples;
} FUNASR_MESSAGE;
+// See https://wiki.mozilla.org/Security/Server_Side_TLS for more details about
+// the TLS modes. The code below demonstrates how to implement both the modern
+enum tls_mode { MOZILLA_INTERMEDIATE = 1, MOZILLA_MODERN = 2 };
class WebSocketServer {
public:
- WebSocketServer(asio::io_context& io_decoder, server* server_)
- : io_decoder_(io_decoder), server_(server_) {
- // set message handle
- server_->set_message_handler(
- [this](websocketpp::connection_hdl hdl, message_ptr msg) {
- on_message(hdl, msg);
- });
- // set open handle
- server_->set_open_handler(
- [this](websocketpp::connection_hdl hdl) { on_open(hdl); });
- // set close handle
- server_->set_close_handler(
- [this](websocketpp::connection_hdl hdl) { on_close(hdl); });
- // begin accept
- server_->start_accept();
- // not print log
- server_->clear_access_channels(websocketpp::log::alevel::all);
+ WebSocketServer(asio::io_context& io_decoder, bool is_ssl, server* server,
+ wss_server* wss_server, std::string& s_certfile,
+ std::string& s_keyfile)
+ : io_decoder_(io_decoder),
+ is_ssl(is_ssl),
+ server_(server),
+ wss_server_(wss_server) {
+ if (is_ssl) {
+ std::cout << "certfile path is " << s_certfile << std::endl;
+ wss_server->set_tls_init_handler(
+ bind<context_ptr>(&WebSocketServer::on_tls_init, this,
+ MOZILLA_INTERMEDIATE, ::_1, s_certfile, s_keyfile));
+ wss_server_->set_message_handler(
+ [this](websocketpp::connection_hdl hdl, message_ptr msg) {
+ on_message(hdl, msg);
+ });
+ // set open handle
+ wss_server_->set_open_handler(
+ [this](websocketpp::connection_hdl hdl) { on_open(hdl); });
+ // set close handle
+ wss_server_->set_close_handler(
+ [this](websocketpp::connection_hdl hdl) { on_close(hdl); });
+ // begin accept
+ wss_server_->start_accept();
+ // not print log
+ wss_server_->clear_access_channels(websocketpp::log::alevel::all);
+
+ } else {
+ // set message handle
+ server_->set_message_handler(
+ [this](websocketpp::connection_hdl hdl, message_ptr msg) {
+ on_message(hdl, msg);
+ });
+ // set open handle
+ server_->set_open_handler(
+ [this](websocketpp::connection_hdl hdl) { on_open(hdl); });
+ // set close handle
+ server_->set_close_handler(
+ [this](websocketpp::connection_hdl hdl) { on_close(hdl); });
+ // begin accept
+ server_->start_accept();
+ // not print log
+ server_->clear_access_channels(websocketpp::log::alevel::all);
+ }
}
void do_decoder(const std::vector<char>& buffer,
websocketpp::connection_hdl& hdl, const nlohmann::json& msg);
@@ -78,6 +112,8 @@
void on_message(websocketpp::connection_hdl hdl, message_ptr msg);
void on_open(websocketpp::connection_hdl hdl);
void on_close(websocketpp::connection_hdl hdl);
+ context_ptr on_tls_init(tls_mode mode, websocketpp::connection_hdl hdl,
+ std::string& s_certfile, std::string& s_keyfile);
private:
void check_and_clean_connection();
@@ -85,7 +121,9 @@
// std::ofstream fout;
FUNASR_HANDLE asr_hanlde; // asr engine handle
bool isonline = false; // online or offline engine, now only support offline
- server* server_; // websocket server
+ bool is_ssl = true;
+ server* server_; // websocket server
+ wss_server* wss_server_; // websocket server
// use map to keep the received samples data from one connection in offline
// engine. if for online engline, a data struct is needed(TODO)
--
Gitblit v1.9.1