From c2e4e3c2e9be855277d9f4fa9cd0544892ff829a Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 30 八月 2023 09:57:30 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add

---
 funasr/runtime/websocket/funasr-wss-client.cpp |   65 +++++++++++++++++++++++++-------
 1 files changed, 50 insertions(+), 15 deletions(-)

diff --git a/funasr/runtime/websocket/funasr-wss-client.cpp b/funasr/runtime/websocket/funasr-wss-client.cpp
index 231303f..cdc5c44 100644
--- a/funasr/runtime/websocket/funasr-wss-client.cpp
+++ b/funasr/runtime/websocket/funasr-wss-client.cpp
@@ -32,9 +32,9 @@
  */
 void WaitABit() {
     #ifdef WIN32
-        Sleep(1000);
+        Sleep(200);
     #else
-        sleep(1);
+        usleep(200);
     #endif
 }
 std::atomic<int> wav_index(0);
@@ -106,10 +106,12 @@
         const std::string& payload = msg->get_payload();
         switch (msg->get_opcode()) {
             case websocketpp::frame::opcode::text:
-				total_num=total_num+1;
-                LOG(INFO)<< "Thread: " << this_thread::get_id() <<",on_message = " << payload;
-				if((total_num+1)==wav_index)
+				total_recv=total_recv+1;
+                LOG(INFO)<< "Thread: " << this_thread::get_id() <<", on_message = " << payload;
+                LOG(INFO)<< "Thread: " << this_thread::get_id() << ", total_recv=" << total_recv << " total_send=" <<total_send;
+				if(total_recv==total_send)
 				{
+                    LOG(INFO)<< "Thread: " << this_thread::get_id() << ", close client";
 					websocketpp::lib::error_code ec;
 					m_client.close(m_hdl, websocketpp::close::status::going_away, "", ec);
 					if (ec){
@@ -120,7 +122,7 @@
     }
 
     // This method will block until the connection is complete  
-    void run(const std::string& uri, const std::vector<string>& wav_list, const std::vector<string>& wav_ids) {
+    void run(const std::string& uri, const std::vector<string>& wav_list, const std::vector<string>& wav_ids, std::string hotwords) {
         // Create a new connection to the given URI
         websocketpp::lib::error_code ec;
         typename websocketpp::client<T>::connection_ptr con =
@@ -141,12 +143,17 @@
         // Create a thread to run the ASIO io_service event loop
         websocketpp::lib::thread asio_thread(&websocketpp::client<T>::run,
                                             &m_client);
+        bool send_hotword = true;
         while(true){
             int i = wav_index.fetch_add(1);
             if (i >= wav_list.size()) {
                 break;
             }
-            send_wav_data(wav_list[i], wav_ids[i]);
+            total_send += 1;
+            send_wav_data(wav_list[i], wav_ids[i], hotwords, send_hotword);
+            if(send_hotword){
+                send_hotword = false;
+            }
         }
         WaitABit(); 
 
@@ -181,7 +188,7 @@
         m_done = true;
     }
     // send wav to server
-    void send_wav_data(string wav_path, string wav_id) {
+    void send_wav_data(string wav_path, string wav_id, string hotwords, bool send_hotword) {
         uint64_t count = 0;
         std::stringstream val;
 
@@ -237,6 +244,10 @@
         jsonbegin["wav_name"] = wav_id;
         jsonbegin["wav_format"] = wav_format;
         jsonbegin["is_speaking"] = true;
+        if(send_hotword){
+            LOG(INFO) << "hotwords: "<< hotwords;
+            jsonbegin["hotwords"] = hotwords;
+        }
         m_client.send(m_hdl, jsonbegin.dump(), websocketpp::frame::opcode::text,
                       ec);
 
@@ -263,7 +274,7 @@
                     offset += send_block;
                 }
 
-                LOG(INFO) << "sended data len=" << len * sizeof(short);
+                LOG(INFO)<< "Thread: " << this_thread::get_id() << ", sended data len=" << len * sizeof(short);
                 // The most likely error that we will get is that the connection is
                 // not in the right state. Usually this means we tried to send a
                 // message to a connection that was closed or in the process of
@@ -295,7 +306,7 @@
                 offset += send_block;
             }
 
-            LOG(INFO) << "sended data len=" << len;
+            LOG(INFO)<< "Thread: " << this_thread::get_id() << ", sended data len=" << len;
             // The most likely error that we will get is that the connection is
             // not in the right state. Usually this means we tried to send a
             // message to a connection that was closed or in the process of
@@ -311,7 +322,7 @@
         jsonresult["is_speaking"] = false;
         m_client.send(m_hdl, jsonresult.dump(), websocketpp::frame::opcode::text,
                       ec);
-        // WaitABit();
+        std::this_thread::sleep_for(std::chrono::milliseconds(20));
     }
     websocketpp::client<T> m_client;
 
@@ -320,7 +331,8 @@
     websocketpp::lib::mutex m_lock;
     bool m_open;
     bool m_done;
-	int total_num=0;
+	int total_send=0;
+    int total_recv=0;
 };
 
 int main(int argc, char* argv[]) {
@@ -340,12 +352,14 @@
     TCLAP::ValueArg<int> is_ssl_(
         "", "is-ssl", "is-ssl is 1 means use wss connection, or use ws connection", 
         false, 1, "int");
+    TCLAP::ValueArg<std::string> hotword_("", HOTWORD, "*.txt(one hotword perline) or hotwords seperate by space (could be: 闃块噷宸村反 杈炬懇闄�)", false, "", "string");
 
     cmd.add(server_ip_);
     cmd.add(port_);
     cmd.add(wav_path_);
     cmd.add(thread_num_);
     cmd.add(is_ssl_);
+    cmd.add(hotword_);
     cmd.parse(argc, argv);
 
     std::string server_ip = server_ip_.getValue();
@@ -361,6 +375,27 @@
     } else {
         uri = "ws://" + server_ip + ":" + port;
     }
+
+    // read hotwords
+    std::string hotword = hotword_.getValue();
+    std::string hotwords_;
+
+    if(IsTargetFile(hotword, "txt")){
+        ifstream in(hotword);
+        if (!in.is_open()) {
+            LOG(ERROR) << "Failed to open file: " <<  hotword;
+            return 0;
+        }
+        string line;
+        while(getline(in, line))
+        {
+            hotwords_ +=line+HOTWORD_SEP;
+        }
+        in.close();
+    }else{
+        hotwords_ = hotword;
+    }
+
 
     // read wav_path
     std::vector<string> wav_list;
@@ -388,17 +423,17 @@
     }
     
     for (size_t i = 0; i < threads_num; i++) {
-        client_threads.emplace_back([uri, wav_list, wav_ids, is_ssl]() {
+        client_threads.emplace_back([uri, wav_list, wav_ids, is_ssl, hotwords_]() {
           if (is_ssl == 1) {
             WebsocketClient<websocketpp::config::asio_tls_client> c(is_ssl);
 
             c.m_client.set_tls_init_handler(bind(&OnTlsInit, ::_1));
 
-            c.run(uri, wav_list, wav_ids);
+            c.run(uri, wav_list, wav_ids, hotwords_);
           } else {
             WebsocketClient<websocketpp::config::asio_client> c(is_ssl);
 
-            c.run(uri, wav_list, wav_ids);
+            c.run(uri, wav_list, wav_ids, hotwords_);
           }
         });
     }

--
Gitblit v1.9.1