From e0fa63765bfb4a36bde7047c2a6066ca5a80e90f Mon Sep 17 00:00:00 2001
From: Yabin Li <wucong.lyb@alibaba-inc.com>
Date: 星期一, 21 八月 2023 10:37:42 +0800
Subject: [PATCH] Dev hw (#878)
---
funasr/runtime/websocket/funasr-wss-client.cpp | 51 ++++++++++++++++++++++++++++++++++++++++++---------
1 files changed, 42 insertions(+), 9 deletions(-)
diff --git a/funasr/runtime/websocket/funasr-wss-client.cpp b/funasr/runtime/websocket/funasr-wss-client.cpp
index 231303f..7a93735 100644
--- a/funasr/runtime/websocket/funasr-wss-client.cpp
+++ b/funasr/runtime/websocket/funasr-wss-client.cpp
@@ -32,9 +32,9 @@
*/
void WaitABit() {
#ifdef WIN32
- Sleep(1000);
+ Sleep(500);
#else
- sleep(1);
+ usleep(500);
#endif
}
std::atomic<int> wav_index(0);
@@ -108,8 +108,10 @@
case websocketpp::frame::opcode::text:
total_num=total_num+1;
LOG(INFO)<< "Thread: " << this_thread::get_id() <<",on_message = " << payload;
+ LOG(INFO) << "total_num=" << total_num << " wav_index=" <<wav_index;
if((total_num+1)==wav_index)
{
+ LOG(INFO) << "close client";
websocketpp::lib::error_code ec;
m_client.close(m_hdl, websocketpp::close::status::going_away, "", ec);
if (ec){
@@ -120,7 +122,7 @@
}
// This method will block until the connection is complete
- void run(const std::string& uri, const std::vector<string>& wav_list, const std::vector<string>& wav_ids) {
+ void run(const std::string& uri, const std::vector<string>& wav_list, const std::vector<string>& wav_ids, std::string hotwords) {
// Create a new connection to the given URI
websocketpp::lib::error_code ec;
typename websocketpp::client<T>::connection_ptr con =
@@ -141,12 +143,16 @@
// Create a thread to run the ASIO io_service event loop
websocketpp::lib::thread asio_thread(&websocketpp::client<T>::run,
&m_client);
+ bool send_hotword = true;
while(true){
int i = wav_index.fetch_add(1);
if (i >= wav_list.size()) {
break;
}
- send_wav_data(wav_list[i], wav_ids[i]);
+ send_wav_data(wav_list[i], wav_ids[i], hotwords, send_hotword);
+ if(send_hotword){
+ send_hotword = false;
+ }
}
WaitABit();
@@ -181,7 +187,7 @@
m_done = true;
}
// send wav to server
- void send_wav_data(string wav_path, string wav_id) {
+ void send_wav_data(string wav_path, string wav_id, string hotwords, bool send_hotword) {
uint64_t count = 0;
std::stringstream val;
@@ -237,6 +243,10 @@
jsonbegin["wav_name"] = wav_id;
jsonbegin["wav_format"] = wav_format;
jsonbegin["is_speaking"] = true;
+ if(send_hotword){
+ LOG(INFO) << "hotwords: "<< hotwords;
+ jsonbegin["hotwords"] = hotwords;
+ }
m_client.send(m_hdl, jsonbegin.dump(), websocketpp::frame::opcode::text,
ec);
@@ -311,7 +321,7 @@
jsonresult["is_speaking"] = false;
m_client.send(m_hdl, jsonresult.dump(), websocketpp::frame::opcode::text,
ec);
- // WaitABit();
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
websocketpp::client<T> m_client;
@@ -340,12 +350,14 @@
TCLAP::ValueArg<int> is_ssl_(
"", "is-ssl", "is-ssl is 1 means use wss connection, or use ws connection",
false, 1, "int");
+ TCLAP::ValueArg<std::string> hotword_("", HOTWORD, "*.txt(one hotword perline) or hotwords seperate by space (could be: 闃块噷宸村反 杈炬懇闄�)", false, "", "string");
cmd.add(server_ip_);
cmd.add(port_);
cmd.add(wav_path_);
cmd.add(thread_num_);
cmd.add(is_ssl_);
+ cmd.add(hotword_);
cmd.parse(argc, argv);
std::string server_ip = server_ip_.getValue();
@@ -361,6 +373,27 @@
} else {
uri = "ws://" + server_ip + ":" + port;
}
+
+ // read hotwords
+ std::string hotword = hotword_.getValue();
+ std::string hotwords_;
+
+ if(IsTargetFile(hotword, "txt")){
+ ifstream in(hotword);
+ if (!in.is_open()) {
+ LOG(ERROR) << "Failed to open file: " << hotword;
+ return 0;
+ }
+ string line;
+ while(getline(in, line))
+ {
+ hotwords_ +=line+HOTWORD_SEP;
+ }
+ in.close();
+ }else{
+ hotwords_ = hotword;
+ }
+
// read wav_path
std::vector<string> wav_list;
@@ -388,17 +421,17 @@
}
for (size_t i = 0; i < threads_num; i++) {
- client_threads.emplace_back([uri, wav_list, wav_ids, is_ssl]() {
+ client_threads.emplace_back([uri, wav_list, wav_ids, is_ssl, hotwords_]() {
if (is_ssl == 1) {
WebsocketClient<websocketpp::config::asio_tls_client> c(is_ssl);
c.m_client.set_tls_init_handler(bind(&OnTlsInit, ::_1));
- c.run(uri, wav_list, wav_ids);
+ c.run(uri, wav_list, wav_ids, hotwords_);
} else {
WebsocketClient<websocketpp::config::asio_client> c(is_ssl);
- c.run(uri, wav_list, wav_ids);
+ c.run(uri, wav_list, wav_ids, hotwords_);
}
});
}
--
Gitblit v1.9.1