From 26a2a232a94c4a729733d83e8175a16e3f8db481 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 23 十月 2023 16:42:43 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add

---
 funasr/runtime/websocket/bin/websocket-server.cpp |   27 +++++++++++++++++++++++----
 1 files changed, 23 insertions(+), 4 deletions(-)

diff --git a/funasr/runtime/websocket/bin/websocket-server.cpp b/funasr/runtime/websocket/bin/websocket-server.cpp
index da1ffa5..134f5fa 100644
--- a/funasr/runtime/websocket/bin/websocket-server.cpp
+++ b/funasr/runtime/websocket/bin/websocket-server.cpp
@@ -16,6 +16,8 @@
 #include <utility>
 #include <vector>
 
+extern std::string hotwords;
+
 context_ptr WebSocketServer::on_tls_init(tls_mode mode,
                                          websocketpp::connection_hdl hdl,
                                          std::string& s_certfile,
@@ -254,7 +256,15 @@
   unique_lock guard_decoder(*(thread_lock_p)); // mutex for one connection
   switch (msg->get_opcode()) {
     case websocketpp::frame::opcode::text: {
-      nlohmann::json jsonresult = nlohmann::json::parse(payload);
+      nlohmann::json jsonresult;
+      try{
+        jsonresult = nlohmann::json::parse(payload);
+      }catch (std::exception const &e)
+      {
+        LOG(ERROR)<<e.what();
+        break;
+      }
+
       if (jsonresult["wav_name"] != nullptr) {
         msg_data->msg["wav_name"] = jsonresult["wav_name"];
       }
@@ -266,17 +276,26 @@
           msg_data->msg["hotwords"] = jsonresult["hotwords"];
           if (!msg_data->msg["hotwords"].empty()) {
             std::string hw = msg_data->msg["hotwords"];
-            LOG(INFO)<<"hotwords: " << hw;
-            std::vector<std::vector<float>> new_hotwords_embedding= CompileHotwordEmbedding(asr_hanlde, hw);
+            hw = hw + " " + hotwords;
+            LOG(INFO) << "hotwords: " << hw;
+            std::vector<std::vector<float>> new_hotwords_embedding = CompileHotwordEmbedding(asr_hanlde, hw);
             msg_data->hotwords_embedding =
                 std::make_shared<std::vector<std::vector<float>>>(new_hotwords_embedding);
           }
-        }else{
+        } else {
+          if (hotwords.empty()) {
             std::string hw = "";
             LOG(INFO)<<"hotwords: " << hw;
             std::vector<std::vector<float>> new_hotwords_embedding= CompileHotwordEmbedding(asr_hanlde, hw);
             msg_data->hotwords_embedding =
                 std::make_shared<std::vector<std::vector<float>>>(new_hotwords_embedding);
+          }else {
+            std::string hw = hotwords;
+            LOG(INFO) << "hotwords: " << hw;
+            std::vector<std::vector<float>> new_hotwords_embedding= CompileHotwordEmbedding(asr_hanlde, hw);
+            msg_data->hotwords_embedding =
+                std::make_shared<std::vector<std::vector<float>>>(new_hotwords_embedding);
+          }
         }
       }
       if (jsonresult.contains("audio_fs")) {

--
Gitblit v1.9.1