From c2e4e3c2e9be855277d9f4fa9cd0544892ff829a Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 30 八月 2023 09:57:30 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add

---
 funasr/runtime/websocket/funasr-wss-client.cpp |  154 ++++++++++++++++++++++++++++++++++++++------------
 1 files changed, 116 insertions(+), 38 deletions(-)

diff --git a/funasr/runtime/websocket/funasr-wss-client.cpp b/funasr/runtime/websocket/funasr-wss-client.cpp
index 5330125..cdc5c44 100644
--- a/funasr/runtime/websocket/funasr-wss-client.cpp
+++ b/funasr/runtime/websocket/funasr-wss-client.cpp
@@ -5,14 +5,14 @@
 /* 2022-2023 by zhaomingwork */
 
 // client for websocket, support multiple threads
-// ./funasr-ws-client  --server-ip <string>
+// ./funasr-wss-client  --server-ip <string>
 //                     --port <string>
 //                     --wav-path <string>
 //                     [--thread-num <int>] 
 //                     [--is-ssl <int>]  [--]
 //                     [--version] [-h]
 // example:
-// ./funasr-ws-client --server-ip 127.0.0.1 --port 8889 --wav-path test.wav --thread-num 1 --is-ssl 0
+// ./funasr-wss-client --server-ip 127.0.0.1 --port 10095 --wav-path test.wav --thread-num 1 --is-ssl 1
 
 #define ASIO_STANDALONE 1
 #include <websocketpp/client.hpp>
@@ -20,6 +20,7 @@
 #include <websocketpp/config/asio_client.hpp>
 #include <fstream>
 #include <atomic>
+#include <thread>
 #include <glog/logging.h>
 
 #include "audio.h"
@@ -31,9 +32,9 @@
  */
 void WaitABit() {
     #ifdef WIN32
-        Sleep(1000);
+        Sleep(200);
     #else
-        sleep(1);
+        usleep(200);
     #endif
 }
 std::atomic<int> wav_index(0);
@@ -105,10 +106,12 @@
         const std::string& payload = msg->get_payload();
         switch (msg->get_opcode()) {
             case websocketpp::frame::opcode::text:
-				total_num=total_num+1;
-                LOG(INFO)<<total_num<<",on_message = " << payload;
-				if((total_num+1)==wav_index)
+				total_recv=total_recv+1;
+                LOG(INFO)<< "Thread: " << this_thread::get_id() <<", on_message = " << payload;
+                LOG(INFO)<< "Thread: " << this_thread::get_id() << ", total_recv=" << total_recv << " total_send=" <<total_send;
+				if(total_recv==total_send)
 				{
+                    LOG(INFO)<< "Thread: " << this_thread::get_id() << ", close client";
 					websocketpp::lib::error_code ec;
 					m_client.close(m_hdl, websocketpp::close::status::going_away, "", ec);
 					if (ec){
@@ -119,7 +122,7 @@
     }
 
     // This method will block until the connection is complete  
-    void run(const std::string& uri, const std::vector<string>& wav_list, const std::vector<string>& wav_ids) {
+    void run(const std::string& uri, const std::vector<string>& wav_list, const std::vector<string>& wav_ids, std::string hotwords) {
         // Create a new connection to the given URI
         websocketpp::lib::error_code ec;
         typename websocketpp::client<T>::connection_ptr con =
@@ -140,12 +143,17 @@
         // Create a thread to run the ASIO io_service event loop
         websocketpp::lib::thread asio_thread(&websocketpp::client<T>::run,
                                             &m_client);
+        bool send_hotword = true;
         while(true){
             int i = wav_index.fetch_add(1);
             if (i >= wav_list.size()) {
                 break;
             }
-            send_wav_data(wav_list[i], wav_ids[i]);
+            total_send += 1;
+            send_wav_data(wav_list[i], wav_ids[i], hotwords, send_hotword);
+            if(send_hotword){
+                send_hotword = false;
+            }
         }
         WaitABit(); 
 
@@ -180,12 +188,13 @@
         m_done = true;
     }
     // send wav to server
-    void send_wav_data(string wav_path, string wav_id) {
+    void send_wav_data(string wav_path, string wav_id, string hotwords, bool send_hotword) {
         uint64_t count = 0;
         std::stringstream val;
 
 		funasr::Audio audio(1);
         int32_t sampling_rate = 16000;
+        std::string wav_format = "pcm";
 		if(IsTargetFile(wav_path.c_str(), "wav")){
 			int32_t sampling_rate = -1;
 			if(!audio.LoadWav(wav_path.c_str(), &sampling_rate))
@@ -194,8 +203,9 @@
 			if (!audio.LoadPcmwav(wav_path.c_str(), &sampling_rate))
 				return ;
 		}else{
-			printf("Wrong wav extension");
-			exit(-1);
+			wav_format = "others";
+            if (!audio.LoadOthers2Char(wav_path.c_str()))
+				return ;
 		}
 
         float* buff;
@@ -232,39 +242,87 @@
         jsonbegin["chunk_size"] = chunk_size;
         jsonbegin["chunk_interval"] = 10;
         jsonbegin["wav_name"] = wav_id;
+        jsonbegin["wav_format"] = wav_format;
         jsonbegin["is_speaking"] = true;
+        if(send_hotword){
+            LOG(INFO) << "hotwords: "<< hotwords;
+            jsonbegin["hotwords"] = hotwords;
+        }
         m_client.send(m_hdl, jsonbegin.dump(), websocketpp::frame::opcode::text,
                       ec);
 
         // fetch wav data use asr engine api
-        while (audio.Fetch(buff, len, flag) > 0) {
-            short iArray[len];
+        if(wav_format == "pcm"){
+            while (audio.Fetch(buff, len, flag) > 0) {
+                short* iArray = new short[len];
+                for (size_t i = 0; i < len; ++i) {
+                iArray[i] = (short)(buff[i]*32768);
+                }
 
-            // convert float -1,1 to short -32768,32767
-            for (size_t i = 0; i < len; ++i) {
-              iArray[i] = (short)(buff[i] * 32767);
+                // send data to server
+                int offset = 0;
+                int block_size = 102400;
+                while(offset < len){
+                    int send_block = 0;
+                    if (offset + block_size <= len){
+                        send_block = block_size;
+                    }else{
+                        send_block = len - offset;
+                    }
+                    m_client.send(m_hdl, iArray+offset, send_block * sizeof(short),
+                        websocketpp::frame::opcode::binary, ec);
+                    offset += send_block;
+                }
+
+                LOG(INFO)<< "Thread: " << this_thread::get_id() << ", sended data len=" << len * sizeof(short);
+                // The most likely error that we will get is that the connection is
+                // not in the right state. Usually this means we tried to send a
+                // message to a connection that was closed or in the process of
+                // closing. While many errors here can be easily recovered from,
+                // in this simple example, we'll stop the data loop.
+                if (ec) {
+                m_client.get_alog().write(websocketpp::log::alevel::app,
+                                            "Send Error: " + ec.message());
+                break;
+                }
+                delete[] iArray;
+                // WaitABit();
             }
-            // send data to server
-            m_client.send(m_hdl, iArray, len * sizeof(short),
-                          websocketpp::frame::opcode::binary, ec);
-            LOG(INFO) << "sended data len=" << len * sizeof(short);
+        }else{
+            int offset = 0;
+            int block_size = 204800;
+            len = audio.GetSpeechLen();
+            char* others_buff = audio.GetSpeechChar();
+
+            while(offset < len){
+                int send_block = 0;
+                if (offset + block_size <= len){
+                    send_block = block_size;
+                }else{
+                    send_block = len - offset;
+                }
+                m_client.send(m_hdl, others_buff+offset, send_block,
+                    websocketpp::frame::opcode::binary, ec);
+                offset += send_block;
+            }
+
+            LOG(INFO)<< "Thread: " << this_thread::get_id() << ", sended data len=" << len;
             // The most likely error that we will get is that the connection is
             // not in the right state. Usually this means we tried to send a
             // message to a connection that was closed or in the process of
             // closing. While many errors here can be easily recovered from,
             // in this simple example, we'll stop the data loop.
             if (ec) {
-              m_client.get_alog().write(websocketpp::log::alevel::app,
+                m_client.get_alog().write(websocketpp::log::alevel::app,
                                         "Send Error: " + ec.message());
-              break;
             }
-            // WaitABit();
         }
+
         nlohmann::json jsonresult;
         jsonresult["is_speaking"] = false;
         m_client.send(m_hdl, jsonresult.dump(), websocketpp::frame::opcode::text,
                       ec);
-        // WaitABit();
+        std::this_thread::sleep_for(std::chrono::milliseconds(20));
     }
     websocketpp::client<T> m_client;
 
@@ -273,7 +331,8 @@
     websocketpp::lib::mutex m_lock;
     bool m_open;
     bool m_done;
-	int total_num=0;
+	int total_send=0;
+    int total_recv=0;
 };
 
 int main(int argc, char* argv[]) {
@@ -281,7 +340,7 @@
     google::InitGoogleLogging(argv[0]);
     FLAGS_logtostderr = true;
 
-    TCLAP::CmdLine cmd("funasr-ws-client", ' ', "1.0");
+    TCLAP::CmdLine cmd("funasr-wss-client", ' ', "1.0");
     TCLAP::ValueArg<std::string> server_ip_("", "server-ip", "server-ip", true,
                                            "127.0.0.1", "string");
     TCLAP::ValueArg<std::string> port_("", "port", "port", true, "10095", "string");
@@ -293,12 +352,14 @@
     TCLAP::ValueArg<int> is_ssl_(
         "", "is-ssl", "is-ssl is 1 means use wss connection, or use ws connection", 
         false, 1, "int");
+    TCLAP::ValueArg<std::string> hotword_("", HOTWORD, "*.txt(one hotword perline) or hotwords seperate by space (could be: 闃块噷宸村反 杈炬懇闄�)", false, "", "string");
 
     cmd.add(server_ip_);
     cmd.add(port_);
     cmd.add(wav_path_);
     cmd.add(thread_num_);
     cmd.add(is_ssl_);
+    cmd.add(hotword_);
     cmd.parse(argc, argv);
 
     std::string server_ip = server_ip_.getValue();
@@ -315,15 +376,32 @@
         uri = "ws://" + server_ip + ":" + port;
     }
 
+    // read hotwords
+    std::string hotword = hotword_.getValue();
+    std::string hotwords_;
+
+    if(IsTargetFile(hotword, "txt")){
+        ifstream in(hotword);
+        if (!in.is_open()) {
+            LOG(ERROR) << "Failed to open file: " <<  hotword;
+            return 0;
+        }
+        string line;
+        while(getline(in, line))
+        {
+            hotwords_ +=line+HOTWORD_SEP;
+        }
+        in.close();
+    }else{
+        hotwords_ = hotword;
+    }
+
+
     // read wav_path
     std::vector<string> wav_list;
     std::vector<string> wav_ids;
     string default_id = "wav_default_id";
-    if(IsTargetFile(wav_path, "wav") || IsTargetFile(wav_path, "pcm")){
-        wav_list.emplace_back(wav_path);
-        wav_ids.emplace_back(default_id);
-    }
-    else if(IsTargetFile(wav_path, "scp")){
+    if(IsTargetFile(wav_path, "scp")){
         ifstream in(wav_path);
         if (!in.is_open()) {
             printf("Failed to open scp file");
@@ -340,22 +418,22 @@
         }
         in.close();
     }else{
-        printf("Please check the wav extension!");
-        exit(-1);
+        wav_list.emplace_back(wav_path);
+        wav_ids.emplace_back(default_id);
     }
     
     for (size_t i = 0; i < threads_num; i++) {
-        client_threads.emplace_back([uri, wav_list, wav_ids, is_ssl]() {
+        client_threads.emplace_back([uri, wav_list, wav_ids, is_ssl, hotwords_]() {
           if (is_ssl == 1) {
             WebsocketClient<websocketpp::config::asio_tls_client> c(is_ssl);
 
             c.m_client.set_tls_init_handler(bind(&OnTlsInit, ::_1));
 
-            c.run(uri, wav_list, wav_ids);
+            c.run(uri, wav_list, wav_ids, hotwords_);
           } else {
             WebsocketClient<websocketpp::config::asio_client> c(is_ssl);
 
-            c.run(uri, wav_list, wav_ids);
+            c.run(uri, wav_list, wav_ids, hotwords_);
           }
         });
     }
@@ -363,4 +441,4 @@
     for (auto& t : client_threads) {
         t.join();
     }
-}
\ No newline at end of file
+}

--
Gitblit v1.9.1