From 61f00e84c2cc5f3e9eab8dba5c96ea8aa61e0721 Mon Sep 17 00:00:00 2001
From: hnluo <haoneng.lhn@alibaba-inc.com>
Date: 星期一, 06 十一月 2023 17:13:22 +0800
Subject: [PATCH] Merge pull request #1062 from alibaba-damo-academy/dev_lhn
---
funasr/runtime/websocket/bin/websocket-server-2pass.cpp | 40 ++++++++++++++++++++++++++++++++--------
1 files changed, 32 insertions(+), 8 deletions(-)
diff --git a/funasr/runtime/websocket/bin/websocket-server-2pass.cpp b/funasr/runtime/websocket/bin/websocket-server-2pass.cpp
index 107be40..a637471 100644
--- a/funasr/runtime/websocket/bin/websocket-server-2pass.cpp
+++ b/funasr/runtime/websocket/bin/websocket-server-2pass.cpp
@@ -15,7 +15,9 @@
#include <thread>
#include <utility>
#include <vector>
-#include <chrono>
+
+extern std::string hotwords;
+
context_ptr WebSocketServer::on_tls_init(tls_mode mode,
websocketpp::connection_hdl hdl,
std::string& s_certfile,
@@ -354,7 +356,14 @@
unique_lock guard_decoder(*(thread_lock_p)); // mutex for one connection
switch (msg->get_opcode()) {
case websocketpp::frame::opcode::text: {
- nlohmann::json jsonresult = nlohmann::json::parse(payload);
+ nlohmann::json jsonresult;
+ try{
+ jsonresult = nlohmann::json::parse(payload);
+ }catch (std::exception const &e)
+ {
+ LOG(ERROR)<<e.what();
+ break;
+ }
if (jsonresult.contains("wav_name")) {
msg_data->msg["wav_name"] = jsonresult["wav_name"];
@@ -370,17 +379,26 @@
msg_data->msg["hotwords"] = jsonresult["hotwords"];
if (!msg_data->msg["hotwords"].empty()) {
std::string hw = msg_data->msg["hotwords"];
- LOG(INFO)<<"hotwords: " << hw;
- std::vector<std::vector<float>> new_hotwords_embedding= CompileHotwordEmbedding(tpass_handle, hw, ASR_TWO_PASS);
+ hw = hw + " " + hotwords;
+ LOG(INFO) << "hotwords: " << hw;
+ std::vector<std::vector<float>> new_hotwords_embedding = CompileHotwordEmbedding(tpass_handle, hw, ASR_TWO_PASS);
msg_data->hotwords_embedding =
std::make_shared<std::vector<std::vector<float>>>(new_hotwords_embedding);
}
- }else{
+ } else {
+ if (hotwords.empty()) {
std::string hw = "";
LOG(INFO)<<"hotwords: " << hw;
std::vector<std::vector<float>> new_hotwords_embedding= CompileHotwordEmbedding(tpass_handle, hw, ASR_TWO_PASS);
msg_data->hotwords_embedding =
std::make_shared<std::vector<std::vector<float>>>(new_hotwords_embedding);
+ }else {
+ std::string hw = hotwords;
+ LOG(INFO) << "hotwords: " << hw;
+ std::vector<std::vector<float>> new_hotwords_embedding= CompileHotwordEmbedding(tpass_handle, hw, ASR_TWO_PASS);
+ msg_data->hotwords_embedding =
+ std::make_shared<std::vector<std::vector<float>>>(new_hotwords_embedding);
+ }
}
}
if (jsonresult.contains("audio_fs")) {
@@ -390,9 +408,15 @@
if (msg_data->tpass_online_handle == NULL) {
std::vector<int> chunk_size_vec =
jsonresult["chunk_size"].get<std::vector<int>>();
- FUNASR_HANDLE tpass_online_handle =
- FunTpassOnlineInit(tpass_handle, chunk_size_vec);
- msg_data->tpass_online_handle = tpass_online_handle;
+ // check chunk_size_vec
+ if(chunk_size_vec.size() == 3 && chunk_size_vec[1] != 0){
+ FUNASR_HANDLE tpass_online_handle =
+ FunTpassOnlineInit(tpass_handle, chunk_size_vec);
+ msg_data->tpass_online_handle = tpass_online_handle;
+ }else{
+ LOG(ERROR) << "Wrong chunk_size!";
+ break;
+ }
}
}
if (jsonresult.contains("itn")) {
--
Gitblit v1.9.1