From e0fa63765bfb4a36bde7047c2a6066ca5a80e90f Mon Sep 17 00:00:00 2001
From: Yabin Li <wucong.lyb@alibaba-inc.com>
Date: 星期一, 21 八月 2023 10:37:42 +0800
Subject: [PATCH] Dev hw (#878)

---
 funasr/runtime/websocket/funasr-wss-client.cpp |   51 ++++++++++++++++++++++++++++++++++++++++++---------
 1 files changed, 42 insertions(+), 9 deletions(-)

diff --git a/funasr/runtime/websocket/funasr-wss-client.cpp b/funasr/runtime/websocket/funasr-wss-client.cpp
index 231303f..7a93735 100644
--- a/funasr/runtime/websocket/funasr-wss-client.cpp
+++ b/funasr/runtime/websocket/funasr-wss-client.cpp
@@ -32,9 +32,9 @@
  */
 void WaitABit() {
     #ifdef WIN32
-        Sleep(1000);
+        Sleep(500);
     #else
-        sleep(1);
+        usleep(500);
     #endif
 }
 std::atomic<int> wav_index(0);
@@ -108,8 +108,10 @@
             case websocketpp::frame::opcode::text:
 				total_num=total_num+1;
                 LOG(INFO)<< "Thread: " << this_thread::get_id() <<",on_message = " << payload;
+                LOG(INFO) << "total_num=" << total_num << " wav_index=" <<wav_index;
 				if((total_num+1)==wav_index)
 				{
+                    LOG(INFO) << "close client";
 					websocketpp::lib::error_code ec;
 					m_client.close(m_hdl, websocketpp::close::status::going_away, "", ec);
 					if (ec){
@@ -120,7 +122,7 @@
     }
 
     // This method will block until the connection is complete  
-    void run(const std::string& uri, const std::vector<string>& wav_list, const std::vector<string>& wav_ids) {
+    void run(const std::string& uri, const std::vector<string>& wav_list, const std::vector<string>& wav_ids, std::string hotwords) {
         // Create a new connection to the given URI
         websocketpp::lib::error_code ec;
         typename websocketpp::client<T>::connection_ptr con =
@@ -141,12 +143,16 @@
         // Create a thread to run the ASIO io_service event loop
         websocketpp::lib::thread asio_thread(&websocketpp::client<T>::run,
                                             &m_client);
+        bool send_hotword = true;
         while(true){
             int i = wav_index.fetch_add(1);
             if (i >= wav_list.size()) {
                 break;
             }
-            send_wav_data(wav_list[i], wav_ids[i]);
+            send_wav_data(wav_list[i], wav_ids[i], hotwords, send_hotword);
+            if(send_hotword){
+                send_hotword = false;
+            }
         }
         WaitABit(); 
 
@@ -181,7 +187,7 @@
         m_done = true;
     }
     // send wav to server
-    void send_wav_data(string wav_path, string wav_id) {
+    void send_wav_data(string wav_path, string wav_id, string hotwords, bool send_hotword) {
         uint64_t count = 0;
         std::stringstream val;
 
@@ -237,6 +243,10 @@
         jsonbegin["wav_name"] = wav_id;
         jsonbegin["wav_format"] = wav_format;
         jsonbegin["is_speaking"] = true;
+        if(send_hotword){
+            LOG(INFO) << "hotwords: "<< hotwords;
+            jsonbegin["hotwords"] = hotwords;
+        }
         m_client.send(m_hdl, jsonbegin.dump(), websocketpp::frame::opcode::text,
                       ec);
 
@@ -311,7 +321,7 @@
         jsonresult["is_speaking"] = false;
         m_client.send(m_hdl, jsonresult.dump(), websocketpp::frame::opcode::text,
                       ec);
-        // WaitABit();
+        std::this_thread::sleep_for(std::chrono::milliseconds(100));
     }
     websocketpp::client<T> m_client;
 
@@ -340,12 +350,14 @@
     TCLAP::ValueArg<int> is_ssl_(
         "", "is-ssl", "is-ssl is 1 means use wss connection, or use ws connection", 
         false, 1, "int");
+    TCLAP::ValueArg<std::string> hotword_("", HOTWORD, "*.txt(one hotword perline) or hotwords seperate by space (could be: 闃块噷宸村反 杈炬懇闄�)", false, "", "string");
 
     cmd.add(server_ip_);
     cmd.add(port_);
     cmd.add(wav_path_);
     cmd.add(thread_num_);
     cmd.add(is_ssl_);
+    cmd.add(hotword_);
     cmd.parse(argc, argv);
 
     std::string server_ip = server_ip_.getValue();
@@ -361,6 +373,27 @@
     } else {
         uri = "ws://" + server_ip + ":" + port;
     }
+
+    // read hotwords
+    std::string hotword = hotword_.getValue();
+    std::string hotwords_;
+
+    if(IsTargetFile(hotword, "txt")){
+        ifstream in(hotword);
+        if (!in.is_open()) {
+            LOG(ERROR) << "Failed to open file: " <<  hotword;
+            return 0;
+        }
+        string line;
+        while(getline(in, line))
+        {
+            hotwords_ +=line+HOTWORD_SEP;
+        }
+        in.close();
+    }else{
+        hotwords_ = hotword;
+    }
+
 
     // read wav_path
     std::vector<string> wav_list;
@@ -388,17 +421,17 @@
     }
     
     for (size_t i = 0; i < threads_num; i++) {
-        client_threads.emplace_back([uri, wav_list, wav_ids, is_ssl]() {
+        client_threads.emplace_back([uri, wav_list, wav_ids, is_ssl, hotwords_]() {
           if (is_ssl == 1) {
             WebsocketClient<websocketpp::config::asio_tls_client> c(is_ssl);
 
             c.m_client.set_tls_init_handler(bind(&OnTlsInit, ::_1));
 
-            c.run(uri, wav_list, wav_ids);
+            c.run(uri, wav_list, wav_ids, hotwords_);
           } else {
             WebsocketClient<websocketpp::config::asio_client> c(is_ssl);
 
-            c.run(uri, wav_list, wav_ids);
+            c.run(uri, wav_list, wav_ids, hotwords_);
           }
         });
     }

--
Gitblit v1.9.1