游雁
2023-08-30 c2e4e3c2e9be855277d9f4fa9cd0544892ff829a
funasr/runtime/websocket/funasr-wss-client.cpp
@@ -32,9 +32,9 @@
 */
void WaitABit() {
    #ifdef WIN32
        Sleep(1000);
        Sleep(200);
    #else
        sleep(1);
        usleep(200);
    #endif
}
std::atomic<int> wav_index(0);
@@ -106,10 +106,12 @@
        const std::string& payload = msg->get_payload();
        switch (msg->get_opcode()) {
            case websocketpp::frame::opcode::text:
            total_num=total_num+1;
                LOG(INFO)<< "Thread: " << this_thread::get_id() <<",on_message = " << payload;
            if((total_num+1)==wav_index)
            total_recv=total_recv+1;
                LOG(INFO)<< "Thread: " << this_thread::get_id() <<", on_message = " << payload;
                LOG(INFO)<< "Thread: " << this_thread::get_id() << ", total_recv=" << total_recv << " total_send=" <<total_send;
            if(total_recv==total_send)
            {
                    LOG(INFO)<< "Thread: " << this_thread::get_id() << ", close client";
               websocketpp::lib::error_code ec;
               m_client.close(m_hdl, websocketpp::close::status::going_away, "", ec);
               if (ec){
@@ -120,7 +122,7 @@
    }
    // This method will block until the connection is complete  
    void run(const std::string& uri, const std::vector<string>& wav_list, const std::vector<string>& wav_ids) {
    void run(const std::string& uri, const std::vector<string>& wav_list, const std::vector<string>& wav_ids, std::string hotwords) {
        // Create a new connection to the given URI
        websocketpp::lib::error_code ec;
        typename websocketpp::client<T>::connection_ptr con =
@@ -141,12 +143,17 @@
        // Create a thread to run the ASIO io_service event loop
        websocketpp::lib::thread asio_thread(&websocketpp::client<T>::run,
                                            &m_client);
        bool send_hotword = true;
        while(true){
            int i = wav_index.fetch_add(1);
            if (i >= wav_list.size()) {
                break;
            }
            send_wav_data(wav_list[i], wav_ids[i]);
            total_send += 1;
            send_wav_data(wav_list[i], wav_ids[i], hotwords, send_hotword);
            if(send_hotword){
                send_hotword = false;
            }
        }
        WaitABit(); 
@@ -181,7 +188,7 @@
        m_done = true;
    }
    // send wav to server
    void send_wav_data(string wav_path, string wav_id) {
    void send_wav_data(string wav_path, string wav_id, string hotwords, bool send_hotword) {
        uint64_t count = 0;
        std::stringstream val;
@@ -237,6 +244,10 @@
        jsonbegin["wav_name"] = wav_id;
        jsonbegin["wav_format"] = wav_format;
        jsonbegin["is_speaking"] = true;
        if(send_hotword){
            LOG(INFO) << "hotwords: "<< hotwords;
            jsonbegin["hotwords"] = hotwords;
        }
        m_client.send(m_hdl, jsonbegin.dump(), websocketpp::frame::opcode::text,
                      ec);
@@ -263,7 +274,7 @@
                    offset += send_block;
                }
                LOG(INFO) << "sended data len=" << len * sizeof(short);
                LOG(INFO)<< "Thread: " << this_thread::get_id() << ", sended data len=" << len * sizeof(short);
                // The most likely error that we will get is that the connection is
                // not in the right state. Usually this means we tried to send a
                // message to a connection that was closed or in the process of
@@ -295,7 +306,7 @@
                offset += send_block;
            }
            LOG(INFO) << "sended data len=" << len;
            LOG(INFO)<< "Thread: " << this_thread::get_id() << ", sended data len=" << len;
            // The most likely error that we will get is that the connection is
            // not in the right state. Usually this means we tried to send a
            // message to a connection that was closed or in the process of
@@ -311,7 +322,7 @@
        jsonresult["is_speaking"] = false;
        m_client.send(m_hdl, jsonresult.dump(), websocketpp::frame::opcode::text,
                      ec);
        // WaitABit();
        std::this_thread::sleep_for(std::chrono::milliseconds(20));
    }
    websocketpp::client<T> m_client;
@@ -320,7 +331,8 @@
    websocketpp::lib::mutex m_lock;
    bool m_open;
    bool m_done;
   int total_num=0;
   int total_send=0;
    int total_recv=0;
};
int main(int argc, char* argv[]) {
@@ -340,12 +352,14 @@
    TCLAP::ValueArg<int> is_ssl_(
        "", "is-ssl", "is-ssl is 1 means use wss connection, or use ws connection", 
        false, 1, "int");
    TCLAP::ValueArg<std::string> hotword_("", HOTWORD, "*.txt(one hotword perline) or hotwords seperate by space (could be: 阿里巴巴 达摩院)", false, "", "string");
    cmd.add(server_ip_);
    cmd.add(port_);
    cmd.add(wav_path_);
    cmd.add(thread_num_);
    cmd.add(is_ssl_);
    cmd.add(hotword_);
    cmd.parse(argc, argv);
    std::string server_ip = server_ip_.getValue();
@@ -361,6 +375,27 @@
    } else {
        uri = "ws://" + server_ip + ":" + port;
    }
    // read hotwords
    std::string hotword = hotword_.getValue();
    std::string hotwords_;
    if(IsTargetFile(hotword, "txt")){
        ifstream in(hotword);
        if (!in.is_open()) {
            LOG(ERROR) << "Failed to open file: " <<  hotword;
            return 0;
        }
        string line;
        while(getline(in, line))
        {
            hotwords_ +=line+HOTWORD_SEP;
        }
        in.close();
    }else{
        hotwords_ = hotword;
    }
    // read wav_path
    std::vector<string> wav_list;
@@ -388,17 +423,17 @@
    }
    
    for (size_t i = 0; i < threads_num; i++) {
        client_threads.emplace_back([uri, wav_list, wav_ids, is_ssl]() {
        client_threads.emplace_back([uri, wav_list, wav_ids, is_ssl, hotwords_]() {
          if (is_ssl == 1) {
            WebsocketClient<websocketpp::config::asio_tls_client> c(is_ssl);
            c.m_client.set_tls_init_handler(bind(&OnTlsInit, ::_1));
            c.run(uri, wav_list, wav_ids);
            c.run(uri, wav_list, wav_ids, hotwords_);
          } else {
            WebsocketClient<websocketpp::config::asio_client> c(is_ssl);
            c.run(uri, wav_list, wav_ids);
            c.run(uri, wav_list, wav_ids, hotwords_);
          }
        });
    }