From c2e4e3c2e9be855277d9f4fa9cd0544892ff829a Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Wed, 30 Aug 2023 09:57:30 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add
---
funasr/runtime/websocket/funasr-wss-client.cpp | 154 ++++++++++++++++++++++++++++++++++++++------------
1 files changed, 116 insertions(+), 38 deletions(-)
diff --git a/funasr/runtime/websocket/funasr-wss-client.cpp b/funasr/runtime/websocket/funasr-wss-client.cpp
index 5330125..cdc5c44 100644
--- a/funasr/runtime/websocket/funasr-wss-client.cpp
+++ b/funasr/runtime/websocket/funasr-wss-client.cpp
@@ -5,14 +5,14 @@
/* 2022-2023 by zhaomingwork */
// client for websocket, support multiple threads
-// ./funasr-ws-client --server-ip <string>
+// ./funasr-wss-client --server-ip <string>
// --port <string>
// --wav-path <string>
// [--thread-num <int>]
// [--is-ssl <int>] [--]
// [--version] [-h]
// example:
-// ./funasr-ws-client --server-ip 127.0.0.1 --port 8889 --wav-path test.wav --thread-num 1 --is-ssl 0
+// ./funasr-wss-client --server-ip 127.0.0.1 --port 10095 --wav-path test.wav --thread-num 1 --is-ssl 1
#define ASIO_STANDALONE 1
#include <websocketpp/client.hpp>
@@ -20,6 +20,7 @@
#include <websocketpp/config/asio_client.hpp>
#include <fstream>
#include <atomic>
+#include <thread>
#include <glog/logging.h>
#include "audio.h"
@@ -31,9 +32,9 @@
*/
void WaitABit() {
#ifdef WIN32
- Sleep(1000);
+ Sleep(200);  // Win32 Sleep() takes milliseconds: pause 200 ms
#else
- sleep(1);
+ usleep(200 * 1000);  // usleep() takes microseconds: 200 ms, matching the WIN32 branch (argument stays below the 1,000,000 limit)
#endif
}
std::atomic<int> wav_index(0);
@@ -105,10 +106,12 @@
const std::string& payload = msg->get_payload();
switch (msg->get_opcode()) {
case websocketpp::frame::opcode::text:
- total_num=total_num+1;
- LOG(INFO)<<total_num<<",on_message = " << payload;
- if((total_num+1)==wav_index)
+ total_recv=total_recv+1;
+ LOG(INFO)<< "Thread: " << this_thread::get_id() <<", on_message = " << payload;
+ LOG(INFO)<< "Thread: " << this_thread::get_id() << ", total_recv=" << total_recv << " total_send=" <<total_send;
+ if(total_recv==total_send)
{
+ LOG(INFO)<< "Thread: " << this_thread::get_id() << ", close client";
websocketpp::lib::error_code ec;
m_client.close(m_hdl, websocketpp::close::status::going_away, "", ec);
if (ec){
@@ -119,7 +122,7 @@
}
// This method will block until the connection is complete
- void run(const std::string& uri, const std::vector<string>& wav_list, const std::vector<string>& wav_ids) {
+ void run(const std::string& uri, const std::vector<string>& wav_list, const std::vector<string>& wav_ids, std::string hotwords) {
// Create a new connection to the given URI
websocketpp::lib::error_code ec;
typename websocketpp::client<T>::connection_ptr con =
@@ -140,12 +143,17 @@
// Create a thread to run the ASIO io_service event loop
websocketpp::lib::thread asio_thread(&websocketpp::client<T>::run,
&m_client);
+ bool send_hotword = true;
while(true){
int i = wav_index.fetch_add(1);
if (i >= wav_list.size()) {
break;
}
- send_wav_data(wav_list[i], wav_ids[i]);
+ total_send += 1;
+ send_wav_data(wav_list[i], wav_ids[i], hotwords, send_hotword);
+ if(send_hotword){
+ send_hotword = false;
+ }
}
WaitABit();
@@ -180,12 +188,13 @@
m_done = true;
}
// send wav to server
- void send_wav_data(string wav_path, string wav_id) {
+ void send_wav_data(string wav_path, string wav_id, string hotwords, bool send_hotword) {
uint64_t count = 0;
std::stringstream val;
funasr::Audio audio(1);
int32_t sampling_rate = 16000;
+ std::string wav_format = "pcm";
if(IsTargetFile(wav_path.c_str(), "wav")){
int32_t sampling_rate = -1;
if(!audio.LoadWav(wav_path.c_str(), &sampling_rate))
@@ -194,8 +203,9 @@
if (!audio.LoadPcmwav(wav_path.c_str(), &sampling_rate))
return ;
}else{
- printf("Wrong wav extension");
- exit(-1);
+ wav_format = "others";
+ if (!audio.LoadOthers2Char(wav_path.c_str()))
+ return ;
}
float* buff;
@@ -232,39 +242,87 @@
jsonbegin["chunk_size"] = chunk_size;
jsonbegin["chunk_interval"] = 10;
jsonbegin["wav_name"] = wav_id;
+ jsonbegin["wav_format"] = wav_format;
jsonbegin["is_speaking"] = true;
+ if(send_hotword){
+ LOG(INFO) << "hotwords: "<< hotwords;
+ jsonbegin["hotwords"] = hotwords;
+ }
m_client.send(m_hdl, jsonbegin.dump(), websocketpp::frame::opcode::text,
ec);
// fetch wav data use asr engine api
- while (audio.Fetch(buff, len, flag) > 0) {
- short iArray[len];
+ if(wav_format == "pcm"){
+ while (audio.Fetch(buff, len, flag) > 0) {
+ short* iArray = new short[len];
+ for (size_t i = 0; i < len; ++i) {
+ iArray[i] = (short)(buff[i]*32768);
+ }
- // convert float -1,1 to short -32768,32767
- for (size_t i = 0; i < len; ++i) {
- iArray[i] = (short)(buff[i] * 32767);
+ // send data to server
+ int offset = 0;
+ int block_size = 102400;
+ while(offset < len){
+ int send_block = 0;
+ if (offset + block_size <= len){
+ send_block = block_size;
+ }else{
+ send_block = len - offset;
+ }
+ m_client.send(m_hdl, iArray+offset, send_block * sizeof(short),
+ websocketpp::frame::opcode::binary, ec);
+ offset += send_block;
+ }
+
+ LOG(INFO)<< "Thread: " << this_thread::get_id() << ", sended data len=" << len * sizeof(short);
+ // The most likely error that we will get is that the connection is
+ // not in the right state. Usually this means we tried to send a
+ // message to a connection that was closed or in the process of
+ // closing. While many errors here can be easily recovered from,
+ // in this simple example, we'll stop the data loop.
+ if (ec) {
+ m_client.get_alog().write(websocketpp::log::alevel::app,
+ "Send Error: " + ec.message());
+ break;
+ }
+ delete[] iArray;
+ // WaitABit();
}
- // send data to server
- m_client.send(m_hdl, iArray, len * sizeof(short),
- websocketpp::frame::opcode::binary, ec);
- LOG(INFO) << "sended data len=" << len * sizeof(short);
+ }else{
+ int offset = 0;
+ int block_size = 204800;
+ len = audio.GetSpeechLen();
+ char* others_buff = audio.GetSpeechChar();
+
+ while(offset < len){
+ int send_block = 0;
+ if (offset + block_size <= len){
+ send_block = block_size;
+ }else{
+ send_block = len - offset;
+ }
+ m_client.send(m_hdl, others_buff+offset, send_block,
+ websocketpp::frame::opcode::binary, ec);
+ offset += send_block;
+ }
+
+ LOG(INFO)<< "Thread: " << this_thread::get_id() << ", sended data len=" << len;
// The most likely error that we will get is that the connection is
// not in the right state. Usually this means we tried to send a
// message to a connection that was closed or in the process of
// closing. While many errors here can be easily recovered from,
// in this simple example, we'll stop the data loop.
if (ec) {
- m_client.get_alog().write(websocketpp::log::alevel::app,
+ m_client.get_alog().write(websocketpp::log::alevel::app,
"Send Error: " + ec.message());
- break;
}
- // WaitABit();
}
+
nlohmann::json jsonresult;
jsonresult["is_speaking"] = false;
m_client.send(m_hdl, jsonresult.dump(), websocketpp::frame::opcode::text,
ec);
- // WaitABit();
+ std::this_thread::sleep_for(std::chrono::milliseconds(20));
}
websocketpp::client<T> m_client;
@@ -273,7 +331,8 @@
websocketpp::lib::mutex m_lock;
bool m_open;
bool m_done;
- int total_num=0;
+ int total_send=0;
+ int total_recv=0;
};
int main(int argc, char* argv[]) {
@@ -281,7 +340,7 @@
google::InitGoogleLogging(argv[0]);
FLAGS_logtostderr = true;
- TCLAP::CmdLine cmd("funasr-ws-client", ' ', "1.0");
+ TCLAP::CmdLine cmd("funasr-wss-client", ' ', "1.0");
TCLAP::ValueArg<std::string> server_ip_("", "server-ip", "server-ip", true,
"127.0.0.1", "string");
TCLAP::ValueArg<std::string> port_("", "port", "port", true, "10095", "string");
@@ -293,12 +352,14 @@
TCLAP::ValueArg<int> is_ssl_(
"", "is-ssl", "is-ssl is 1 means use wss connection, or use ws connection",
false, 1, "int");
+ TCLAP::ValueArg<std::string> hotword_("", HOTWORD, "*.txt(one hotword perline) or hotwords seperate by space (could be: 阿里巴巴 达摩院)", false, "", "string");
cmd.add(server_ip_);
cmd.add(port_);
cmd.add(wav_path_);
cmd.add(thread_num_);
cmd.add(is_ssl_);
+ cmd.add(hotword_);
cmd.parse(argc, argv);
std::string server_ip = server_ip_.getValue();
@@ -315,15 +376,32 @@
uri = "ws://" + server_ip + ":" + port;
}
+ // read hotwords
+ std::string hotword = hotword_.getValue();
+ std::string hotwords_;
+
+ if(IsTargetFile(hotword, "txt")){
+ ifstream in(hotword);
+ if (!in.is_open()) {
+ LOG(ERROR) << "Failed to open file: " << hotword;
+ return 0;
+ }
+ string line;
+ while(getline(in, line))
+ {
+ hotwords_ +=line+HOTWORD_SEP;
+ }
+ in.close();
+ }else{
+ hotwords_ = hotword;
+ }
+
+
// read wav_path
std::vector<string> wav_list;
std::vector<string> wav_ids;
string default_id = "wav_default_id";
- if(IsTargetFile(wav_path, "wav") || IsTargetFile(wav_path, "pcm")){
- wav_list.emplace_back(wav_path);
- wav_ids.emplace_back(default_id);
- }
- else if(IsTargetFile(wav_path, "scp")){
+ if(IsTargetFile(wav_path, "scp")){
ifstream in(wav_path);
if (!in.is_open()) {
printf("Failed to open scp file");
@@ -340,22 +418,22 @@
}
in.close();
}else{
- printf("Please check the wav extension!");
- exit(-1);
+ wav_list.emplace_back(wav_path);
+ wav_ids.emplace_back(default_id);
}
for (size_t i = 0; i < threads_num; i++) {
- client_threads.emplace_back([uri, wav_list, wav_ids, is_ssl]() {
+ client_threads.emplace_back([uri, wav_list, wav_ids, is_ssl, hotwords_]() {
if (is_ssl == 1) {
WebsocketClient<websocketpp::config::asio_tls_client> c(is_ssl);
c.m_client.set_tls_init_handler(bind(&OnTlsInit, ::_1));
- c.run(uri, wav_list, wav_ids);
+ c.run(uri, wav_list, wav_ids, hotwords_);
} else {
WebsocketClient<websocketpp::config::asio_client> c(is_ssl);
- c.run(uri, wav_list, wav_ids);
+ c.run(uri, wav_list, wav_ids, hotwords_);
}
});
}
@@ -363,4 +441,4 @@
for (auto& t : client_threads) {
t.join();
}
-}
\ No newline at end of file
+}
--
Gitblit v1.9.1