From 4ce2e2d76c66ddecd6903f4ffce98eaba675c685 Mon Sep 17 00:00:00 2001
From: hnluo <haoneng.lhn@alibaba-inc.com>
Date: 星期三, 31 五月 2023 18:42:38 +0800
Subject: [PATCH] Merge pull request #571 from alibaba-damo-academy/dev_lhn

---
 funasr/runtime/websocket/websocketsrv.cpp |   85 ++++++++++++++++++++++++++++++++++--------
 1 files changed, 68 insertions(+), 17 deletions(-)

diff --git a/funasr/runtime/websocket/websocketsrv.cpp b/funasr/runtime/websocket/websocketsrv.cpp
index 598ad3d..eb3c8db 100644
--- a/funasr/runtime/websocket/websocketsrv.cpp
+++ b/funasr/runtime/websocket/websocketsrv.cpp
@@ -16,9 +16,48 @@
 #include <utility>
 #include <vector>
 
+context_ptr WebSocketServer::on_tls_init(tls_mode mode,
+                                         websocketpp::connection_hdl hdl,
+                                         std::string& s_certfile,
+                                         std::string& s_keyfile) {
+  namespace asio = websocketpp::lib::asio;
+
+  std::cout << "on_tls_init called with hdl: " << hdl.lock().get() << std::endl;
+  std::cout << "using TLS mode: "
+            << (mode == MOZILLA_MODERN ? "Mozilla Modern"
+                                       : "Mozilla Intermediate")
+            << std::endl;
+
+  context_ptr ctx = websocketpp::lib::make_shared<asio::ssl::context>(
+      asio::ssl::context::sslv23);
+
+  try {
+    if (mode == MOZILLA_MODERN) {
+      // Modern disables TLSv1
+      ctx->set_options(
+          asio::ssl::context::default_workarounds |
+          asio::ssl::context::no_sslv2 | asio::ssl::context::no_sslv3 |
+          asio::ssl::context::no_tlsv1 | asio::ssl::context::single_dh_use);
+    } else {
+      ctx->set_options(asio::ssl::context::default_workarounds |
+                       asio::ssl::context::no_sslv2 |
+                       asio::ssl::context::no_sslv3 |
+                       asio::ssl::context::single_dh_use);
+    }
+
+    ctx->use_certificate_chain_file(s_certfile);
+    ctx->use_private_key_file(s_keyfile, asio::ssl::context::pem);
+
+  } catch (std::exception& e) {
+    std::cout << "Exception: " << e.what() << std::endl;
+  }
+  return ctx;
+}
+
 // feed buffer to asr engine for decoder
 void WebSocketServer::do_decoder(const std::vector<char>& buffer,
-                                 websocketpp::connection_hdl& hdl) {
+                                 websocketpp::connection_hdl& hdl,
+                                 const nlohmann::json& msg) {
   try {
     int num_samples = buffer.size();  // the size of the buf
 
@@ -35,17 +74,17 @@
       nlohmann::json jsonresult;        // result json
       jsonresult["text"] = asr_result;  // put result in 'text'
       jsonresult["mode"] = "offline";
-      std::shared_ptr<FUNASR_MESSAGE> msg_data = nullptr;
-      auto it_data = data_map.find(hdl);
-      if (it_data != data_map.end()) {
-        msg_data = it_data->second;
-      }
 
-      jsonresult["wav_name"] = msg_data->msg["wav_name"];
+      jsonresult["wav_name"] = msg["wav_name"];
 
       // send the json to client
-      server_->send(hdl, jsonresult.dump(), websocketpp::frame::opcode::text,
-                    ec);
+      if (is_ssl) {
+        wss_server_->send(hdl, jsonresult.dump(),
+                          websocketpp::frame::opcode::text, ec);
+      } else {
+        server_->send(hdl, jsonresult.dump(), websocketpp::frame::opcode::text,
+                      ec);
+      }
 
       std::cout << "buffer.size=" << buffer.size()
                 << ",result json=" << jsonresult.dump() << std::endl;
@@ -87,10 +126,19 @@
   auto iter = data_map.begin();
   while (iter != data_map.end()) {  // loop to find closed connection
     websocketpp::connection_hdl hdl = iter->first;
-    server::connection_ptr con = server_->get_con_from_hdl(hdl);
-    if (con->get_state() != 1) {  // session::state::open ==1
-      to_remove.push_back(hdl);
+
+    if (is_ssl) {
+      wss_server::connection_ptr con = wss_server_->get_con_from_hdl(hdl);
+      if (con->get_state() != 1) {  // session::state::open ==1
+        to_remove.push_back(hdl);
+      }
+    } else {
+      server::connection_ptr con = server_->get_con_from_hdl(hdl);
+      if (con->get_state() != 1) {  // session::state::open ==1
+        to_remove.push_back(hdl);
+      }
     }
+
     iter++;
   }
   for (auto hdl : to_remove) {
@@ -125,6 +173,7 @@
       if (jsonresult["wav_name"] != nullptr) {
         msg_data->msg["wav_name"] = jsonresult["wav_name"];
       }
+
       if (jsonresult["is_speaking"] == false ||
           jsonresult["is_finished"] == true) {
         std::cout << "client done" << std::endl;
@@ -137,9 +186,10 @@
           sample_data_p->insert(sample_data_p->end(), padding.data(),
                                 padding.data() + padding.size());
           // for offline, send all receive data to decoder engine
-          asio::post(io_decoder_, std::bind(&WebSocketServer::do_decoder, this,
-                                            std::move(*(sample_data_p.get())),
-                                            std::move(hdl)));
+          asio::post(io_decoder_,
+                     std::bind(&WebSocketServer::do_decoder, this,
+                               std::move(*(sample_data_p.get())),
+                               std::move(hdl), std::move(msg_data->msg)));
         }
       }
       break;
@@ -152,8 +202,9 @@
       if (isonline) {
         // if online TODO(zhaoming) still not done
         std::vector<char> s(pcm_data, pcm_data + num_samples);
-        asio::post(io_decoder_, std::bind(&WebSocketServer::do_decoder, this,
-                                          std::move(s), std::move(hdl)));
+        asio::post(io_decoder_,
+                   std::bind(&WebSocketServer::do_decoder, this, std::move(s),
+                             std::move(hdl), std::move(msg_data->msg)));
       } else {
         // for offline, we add receive data to end of the sample data vector
         sample_data_p->insert(sample_data_p->end(), pcm_data,

--
Gitblit v1.9.1