From 21c590ad67bb00cf29c23b85666301359fb0e6e0 Mon Sep 17 00:00:00 2001
From: zhaomingwork <61895407+zhaomingwork@users.noreply.github.com>
Date: 星期四, 25 五月 2023 15:24:10 +0800
Subject: [PATCH] add ssl support for cpp websocket (#553)

---
 funasr/runtime/websocket/websocketclient.cpp |   66 ++++++++++++++++++++++++++------
 1 files changed, 53 insertions(+), 13 deletions(-)

diff --git a/funasr/runtime/websocket/websocketclient.cpp b/funasr/runtime/websocket/websocketclient.cpp
index 078fc5a..e9f8f1d 100644
--- a/funasr/runtime/websocket/websocketclient.cpp
+++ b/funasr/runtime/websocket/websocketclient.cpp
@@ -10,7 +10,7 @@
 #define ASIO_STANDALONE 1
 #include <websocketpp/client.hpp>
 #include <websocketpp/common/thread.hpp>
-#include <websocketpp/config/asio_no_tls_client.hpp>
+#include <websocketpp/config/asio_client.hpp>
 
 #include "audio.h"
 #include "nlohmann/json.hpp"
@@ -26,14 +26,37 @@
 #endif
 }
 typedef websocketpp::config::asio_client::message_type::ptr message_ptr;
+typedef websocketpp::lib::shared_ptr<websocketpp::lib::asio::ssl::context>
+    context_ptr;
+using websocketpp::lib::bind;
+using websocketpp::lib::placeholders::_1;
+using websocketpp::lib::placeholders::_2;
+context_ptr on_tls_init(websocketpp::connection_hdl) {
+  context_ptr ctx = websocketpp::lib::make_shared<asio::ssl::context>(
+      asio::ssl::context::sslv23);
 
+  try {
+    ctx->set_options(
+        asio::ssl::context::default_workarounds | asio::ssl::context::no_sslv2 |
+        asio::ssl::context::no_sslv3 | asio::ssl::context::single_dh_use);
+
+  } catch (std::exception& e) {
+    std::cout << e.what() << std::endl;
+  }
+  return ctx;
+}
+// template for tls or not config
+template <typename T>
 class websocket_client {
  public:
-  typedef websocketpp::client<websocketpp::config::asio_client> client;
+  // typedef websocketpp::client<T> client;
+  // typedef websocketpp::client<websocketpp::config::asio_tls_client>
+  // wss_client;
   typedef websocketpp::lib::lock_guard<websocketpp::lib::mutex> scoped_lock;
 
-  websocket_client() : m_open(false), m_done(false) {
+  websocket_client(int is_ssl) : m_open(false), m_done(false) {
     // set up access channels to only log interesting things
+
     m_client.clear_access_channels(websocketpp::log::alevel::all);
     m_client.set_access_channels(websocketpp::log::alevel::connect);
     m_client.set_access_channels(websocketpp::log::alevel::disconnect);
@@ -65,10 +88,12 @@
     }
   }
   // This method will block until the connection is complete
+
   void run(const std::string& uri, const std::string& wav_path) {
     // Create a new connection to the given URI
     websocketpp::lib::error_code ec;
-    client::connection_ptr con = m_client.get_connection(uri, ec);
+    typename websocketpp::client<T>::connection_ptr con =
+        m_client.get_connection(uri, ec);
     if (ec) {
       m_client.get_alog().write(websocketpp::log::alevel::app,
                                 "Get Connection Error: " + ec.message());
@@ -84,7 +109,8 @@
     m_client.connect(con);
 
     // Create a thread to run the ASIO io_service event loop
-    websocketpp::lib::thread asio_thread(&client::run, &m_client);
+    websocketpp::lib::thread asio_thread(&websocketpp::client<T>::run,
+                                         &m_client);
 
     send_wav_data();
     asio_thread.join();
@@ -201,9 +227,9 @@
                   ec);
     wait_a_bit();
   }
+  websocketpp::client<T> m_client;
 
  private:
-  client m_client;
   websocketpp::connection_hdl m_hdl;
   websocketpp::lib::mutex m_lock;
   std::string wav_path;
@@ -212,22 +238,36 @@
 };
 
 int main(int argc, char* argv[]) {
-  if (argc < 5) {
-    printf("Usage: %s server_ip port wav_path threads_num\n", argv[0]);
+  if (argc < 6) {
+    printf("Usage: %s server_ip port wav_path threads_num is_ssl\n", argv[0]);
     exit(-1);
   }
   std::string server_ip = argv[1];
   std::string port = argv[2];
   std::string wav_path = argv[3];
   int threads_num = atoi(argv[4]);
+  int is_ssl = atoi(argv[5]);
   std::vector<websocketpp::lib::thread> client_threads;
-
-  std::string uri = "ws://" + server_ip + ":" + port;
+  std::string uri = "";
+  if (is_ssl == 1) {
+    uri = "wss://" + server_ip + ":" + port;
+  } else {
+    uri = "ws://" + server_ip + ":" + port;
+  }
 
   for (size_t i = 0; i < threads_num; i++) {
-    client_threads.emplace_back([uri, wav_path]() {
-      websocket_client c;
-      c.run(uri, wav_path);
+    client_threads.emplace_back([uri, wav_path, is_ssl]() {
+      if (is_ssl == 1) {
+        websocket_client<websocketpp::config::asio_tls_client> c(is_ssl);
+
+        c.m_client.set_tls_init_handler(bind(&on_tls_init, ::_1));
+
+        c.run(uri, wav_path);
+      } else {
+        websocket_client<websocketpp::config::asio_client> c(is_ssl);
+
+        c.run(uri, wav_path);
+      }
     });
   }
 

--
Gitblit v1.9.1