From e0fa63765bfb4a36bde7047c2a6066ca5a80e90f Mon Sep 17 00:00:00 2001
From: Yabin Li <wucong.lyb@alibaba-inc.com>
Date: 星期一, 21 八月 2023 10:37:42 +0800
Subject: [PATCH] Dev hw (#878)
---
funasr/runtime/websocket/websocket-server.cpp | 201 ++++++++++++++++++++++++++++++++++---------------
1 files changed, 139 insertions(+), 62 deletions(-)
diff --git a/funasr/runtime/websocket/websocket-server.cpp b/funasr/runtime/websocket/websocket-server.cpp
index 59109b3..952acad 100644
--- a/funasr/runtime/websocket/websocket-server.cpp
+++ b/funasr/runtime/websocket/websocket-server.cpp
@@ -56,25 +56,37 @@
// feed buffer to asr engine for decoder
void WebSocketServer::do_decoder(const std::vector<char>& buffer,
websocketpp::connection_hdl& hdl,
- const nlohmann::json& msg) {
+ websocketpp::lib::mutex& thread_lock,
+ std::vector<std::vector<float>> &hotwords_embedding,
+ std::string wav_name,
+ std::string wav_format) {
+ scoped_lock guard(thread_lock);
try {
int num_samples = buffer.size(); // the size of the buf
- if (!buffer.empty()) {
- // feed data to asr engine
- FUNASR_RESULT Result = FunOfflineInferBuffer(
- asr_hanlde, buffer.data(), buffer.size(), RASR_NONE, NULL, 16000, msg["wav_format"]);
+ if (!buffer.empty() && hotwords_embedding.size() >0 ) {
+ std::string asr_result;
+ std::string stamp_res;
+ try{
+ FUNASR_RESULT Result = FunOfflineInferBuffer(
+ asr_hanlde, buffer.data(), buffer.size(), RASR_NONE, NULL, hotwords_embedding, 16000, wav_format);
- std::string asr_result =
- ((FUNASR_RECOG_RESULT*)Result)->msg; // get decode result
- FunASRFreeResult(Result);
+ asr_result = ((FUNASR_RECOG_RESULT*)Result)->msg; // get decode result
+ stamp_res = ((FUNASR_RECOG_RESULT*)Result)->stamp;
+ FunASRFreeResult(Result);
+ }catch (std::exception const& e) {
+ LOG(ERROR) << e.what();
+ return;
+ }
websocketpp::lib::error_code ec;
nlohmann::json jsonresult; // result json
jsonresult["text"] = asr_result; // put result in 'text'
jsonresult["mode"] = "offline";
-
- jsonresult["wav_name"] = msg["wav_name"];
+ if(stamp_res != ""){
+ jsonresult["timestamp"] = stamp_res;
+ }
+ jsonresult["wav_name"] = wav_name;
// send the json to client
if (is_ssl) {
@@ -86,11 +98,6 @@
}
LOG(INFO) << "buffer.size=" << buffer.size() << ",result json=" << jsonresult.dump();
- if (!isonline) {
- // close the client if it is not online asr
- // server_->close(hdl, websocketpp::close::status::normal, "DONE", ec);
- // fout.close();
- }
}
} catch (std::exception const& e) {
@@ -100,12 +107,11 @@
void WebSocketServer::on_open(websocketpp::connection_hdl hdl) {
scoped_lock guard(m_lock); // for threads safty
- check_and_clean_connection(); // remove closed connection
-
std::shared_ptr<FUNASR_MESSAGE> data_msg =
std::make_shared<FUNASR_MESSAGE>(); // put a new data vector for new
// connection
data_msg->samples = std::make_shared<std::vector<char>>();
+ data_msg->thread_lock = std::make_shared<websocketpp::lib::mutex>();
data_msg->msg = nlohmann::json::parse("{}");
data_msg->msg["wav_format"] = "pcm";
data_map.emplace(hdl, data_msg);
@@ -114,37 +120,88 @@
void WebSocketServer::on_close(websocketpp::connection_hdl hdl) {
scoped_lock guard(m_lock);
- data_map.erase(hdl); // remove data vector when connection is closed
+
+ std::shared_ptr<FUNASR_MESSAGE> data_msg = nullptr;
+ auto it_data = data_map.find(hdl);
+ if (it_data != data_map.end()) {
+ data_msg = it_data->second;
+ } else {
+ return;
+ }
+ unique_lock guard_decoder(*(data_msg->thread_lock));
+ data_msg->msg["is_eof"]=true;
+ guard_decoder.unlock();
+ // data_map.erase(hdl); // remove data vector when connection is closed
LOG(INFO) << "on_close, active connections: " << data_map.size();
}
-// remove closed connection
-void WebSocketServer::check_and_clean_connection() {
- std::vector<websocketpp::connection_hdl> to_remove; // remove list
- auto iter = data_map.begin();
- while (iter != data_map.end()) { // loop to find closed connection
- websocketpp::connection_hdl hdl = iter->first;
-
- if (is_ssl) {
- wss_server::connection_ptr con = wss_server_->get_con_from_hdl(hdl);
- if (con->get_state() != 1) { // session::state::open ==1
- to_remove.push_back(hdl);
- }
- } else {
- server::connection_ptr con = server_->get_con_from_hdl(hdl);
- if (con->get_state() != 1) { // session::state::open ==1
- to_remove.push_back(hdl);
- }
- }
-
- iter++;
+void remove_hdl(
+ websocketpp::connection_hdl hdl,
+ std::map<websocketpp::connection_hdl, std::shared_ptr<FUNASR_MESSAGE>,
+ std::owner_less<websocketpp::connection_hdl>>& data_map) {
+
+ std::shared_ptr<FUNASR_MESSAGE> data_msg = nullptr;
+ auto it_data = data_map.find(hdl);
+ if (it_data != data_map.end()) {
+ data_msg = it_data->second;
+ } else {
+ return;
}
- for (auto hdl : to_remove) {
- data_map.erase(hdl);
- LOG(INFO)<< "remove one connection ";
+ unique_lock guard_decoder(*(data_msg->thread_lock));
+ if (data_msg->msg["is_eof"]==true) {
+ data_map.erase(hdl);
+ LOG(INFO) << "remove one connection";
+ }
+ guard_decoder.unlock();
+}
+
+void WebSocketServer::check_and_clean_connection() {
+ while(true){
+ std::this_thread::sleep_for(std::chrono::milliseconds(5000));
+ std::vector<websocketpp::connection_hdl> to_remove; // remove list
+ auto iter = data_map.begin();
+ while (iter != data_map.end()) { // loop to find closed connection
+ websocketpp::connection_hdl hdl = iter->first;
+ try{
+ if (is_ssl) {
+ wss_server::connection_ptr con = wss_server_->get_con_from_hdl(hdl);
+ if (con->get_state() != 1) { // session::state::open ==1
+ to_remove.push_back(hdl);
+ }
+ } else {
+ server::connection_ptr con = server_->get_con_from_hdl(hdl);
+ if (con->get_state() != 1) { // session::state::open ==1
+ to_remove.push_back(hdl);
+ }
+ }
+ }
+ catch (std::exception const &e)
+ {
+ // if connection is close, we set is_eof = true
+ std::shared_ptr<FUNASR_MESSAGE> data_msg = nullptr;
+ auto it_data = data_map.find(hdl);
+ if (it_data != data_map.end()) {
+ data_msg = it_data->second;
+ } else {
+ continue;
+ }
+ unique_lock guard_decoder(*(data_msg->thread_lock));
+ data_msg->msg["is_eof"]=true;
+ guard_decoder.unlock();
+ to_remove.push_back(hdl);
+ LOG(INFO)<<"connection is closed: "<<e.what();
+
+ }
+ iter++;
+ }
+ for (auto hdl : to_remove) {
+ remove_hdl(hdl, data_map);
+ //LOG(INFO) << "remove one connection ";
+ }
}
}
+
void WebSocketServer::on_message(websocketpp::connection_hdl hdl,
message_ptr msg) {
unique_lock lock(m_lock);
@@ -157,6 +214,7 @@
msg_data = it_data->second;
}
std::shared_ptr<std::vector<char>> sample_data_p = msg_data->samples;
+ std::shared_ptr<websocketpp::lib::mutex> thread_lock_p = msg_data->thread_lock;
lock.unlock();
if (sample_data_p == nullptr) {
@@ -165,7 +223,7 @@
}
const std::string& payload = msg->get_payload(); // get msg type
-
+ unique_lock guard_decoder(*(thread_lock_p)); // mutex for one connection
switch (msg->get_opcode()) {
case websocketpp::frame::opcode::text: {
nlohmann::json jsonresult = nlohmann::json::parse(payload);
@@ -175,24 +233,42 @@
if (jsonresult["wav_format"] != nullptr) {
msg_data->msg["wav_format"] = jsonresult["wav_format"];
}
+ if(msg_data->hotwords_embedding == NULL){
+ if (jsonresult["hotwords"] != nullptr) {
+ msg_data->msg["hotwords"] = jsonresult["hotwords"];
+ if (!msg_data->msg["hotwords"].empty()) {
+ std::string hw = msg_data->msg["hotwords"];
+ LOG(INFO)<<"hotwords: " << hw;
+ std::vector<std::vector<float>> new_hotwords_embedding= CompileHotwordEmbedding(asr_hanlde, hw);
+ msg_data->hotwords_embedding =
+ std::make_shared<std::vector<std::vector<float>>>(new_hotwords_embedding);
+ }
+ }else{
+ std::string hw = "";
+ LOG(INFO)<<"hotwords: " << hw;
+ std::vector<std::vector<float>> new_hotwords_embedding= CompileHotwordEmbedding(asr_hanlde, hw);
+ msg_data->hotwords_embedding =
+ std::make_shared<std::vector<std::vector<float>>>(new_hotwords_embedding);
+ }
+ }
if (jsonresult["is_speaking"] == false ||
jsonresult["is_finished"] == true) {
LOG(INFO) << "client done";
-
- if (isonline) {
- // do_close(ws);
- } else {
- // add padding to the end of the wav data
- // std::vector<short> padding(static_cast<short>(0.3 * 16000));
- // sample_data_p->insert(sample_data_p->end(), padding.data(),
- // padding.data() + padding.size());
- // for offline, send all receive data to decoder engine
- asio::post(io_decoder_,
- std::bind(&WebSocketServer::do_decoder, this,
- std::move(*(sample_data_p.get())),
- std::move(hdl), std::move(msg_data->msg)));
- }
+ // add padding to the end of the wav data
+ // std::vector<short> padding(static_cast<short>(0.3 * 16000));
+ // sample_data_p->insert(sample_data_p->end(), padding.data(),
+ // padding.data() + padding.size());
+ // for offline, send all receive data to decoder engine
+ std::vector<std::vector<float>> hotwords_embedding_(*(msg_data->hotwords_embedding));
+ asio::post(io_decoder_,
+ std::bind(&WebSocketServer::do_decoder, this,
+ std::move(*(sample_data_p.get())),
+ std::move(hdl),
+ std::ref(*thread_lock_p),
+ std::move(hotwords_embedding_),
+ msg_data->msg["wav_name"],
+ msg_data->msg["wav_format"]));
}
break;
}
@@ -200,19 +276,15 @@
// recived binary data
const auto* pcm_data = static_cast<const char*>(payload.data());
int32_t num_samples = payload.size();
+ //LOG(INFO) << "recv binary num_samples " << num_samples;
if (isonline) {
- // if online TODO(zhaoming) still not done
- std::vector<char> s(pcm_data, pcm_data + num_samples);
- asio::post(io_decoder_,
- std::bind(&WebSocketServer::do_decoder, this, std::move(s),
- std::move(hdl), std::move(msg_data->msg)));
+ // TODO
} else {
// for offline, we add receive data to end of the sample data vector
sample_data_p->insert(sample_data_p->end(), pcm_data,
pcm_data + num_samples);
}
-
break;
}
default:
@@ -228,6 +300,11 @@
asr_hanlde = FunOfflineInit(model_path, thread_num);
LOG(INFO) << "model successfully inited";
+
+ LOG(INFO) << "initAsr run check_and_clean_connection";
+ std::thread clean_thread(&WebSocketServer::check_and_clean_connection,this);
+ clean_thread.detach();
+ LOG(INFO) << "initAsr run check_and_clean_connection finished";
} catch (const std::exception& e) {
LOG(INFO) << e.what();
--
Gitblit v1.9.1