From f03a604204bbe0c79e53b01237a37e88683938c6 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Sat, 13 May 2023 00:17:32 +0800
Subject: [PATCH] Merge pull request #505 from zhaomingwork/cpp-python-websocket-compatible
---
funasr/runtime/websocket/websocketsrv.cpp | 60 ++++++++++----
funasr/runtime/python/websocket/ws_client.py | 56 ++++++++++---
funasr/runtime/websocket/websocketsrv.h | 11 ++
funasr/runtime/python/websocket/ws_server_online.py | 41 ++++++---
funasr/runtime/websocket/websocketclient.cpp | 20 ++++
5 files changed, 139 insertions(+), 49 deletions(-)
diff --git a/funasr/runtime/python/websocket/ws_client.py b/funasr/runtime/python/websocket/ws_client.py
index a4a6d9f..7ae44df 100644
--- a/funasr/runtime/python/websocket/ws_client.py
+++ b/funasr/runtime/python/websocket/ws_client.py
@@ -84,18 +84,20 @@
rate=RATE,
input=True,
frames_per_buffer=CHUNK)
+
+ message = json.dumps({"chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval, "wav_name": wav_name,"is_speaking": True})
+ voices.put(message)
is_speaking = True
while True:
data = stream.read(CHUNK)
- data = data.decode('ISO-8859-1')
- message = json.dumps({"chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval, "audio": data, "is_speaking": is_speaking, "is_finished": is_finished})
+ message = data
voices.put(message)
await asyncio.sleep(0.005)
-async def record_from_scp():
+async def record_from_scp(chunk_begin,chunk_size):
import wave
global voices
is_finished = False
@@ -104,6 +106,8 @@
wavs = f_scp.readlines()
else:
wavs = [args.audio_in]
+ if chunk_size>0:
+ wavs=wavs[chunk_begin:chunk_begin+chunk_size]
for wav in wavs:
wav_splits = wav.strip().split()
wav_name = wav_splits[0] if len(wav_splits) > 1 else "demo"
@@ -122,15 +126,21 @@
stride = int(60*args.chunk_size[1]/args.chunk_interval/1000*16000*2)
chunk_num = (len(audio_bytes)-1)//stride + 1
# print(stride)
+
+ # send first time
+ message = json.dumps({"chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval, "wav_name": wav_name,"is_speaking": True})
+ voices.put(message)
is_speaking = True
for i in range(chunk_num):
- if i == chunk_num-1:
- is_speaking = False
+
beg = i*stride
data = audio_bytes[beg:beg+stride]
- data = data.decode('ISO-8859-1')
- message = json.dumps({"chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval, "is_speaking": is_speaking, "audio": data, "is_finished": is_finished, "wav_name": wav_name})
+ message = data
voices.put(message)
+ if i == chunk_num-1:
+ is_speaking = False
+ message = json.dumps({"is_speaking": is_speaking})
+ voices.put(message)
# print("data_chunk: ", len(data_chunk))
# print(voices.qsize())
sleep_duration = 0.001 if args.send_without_sleep else 60*args.chunk_size[1]/args.chunk_interval/1000
@@ -213,27 +223,47 @@
traceback.print_exc()
exit(0)
-async def ws_client(id):
+async def ws_client(id,chunk_begin,chunk_size):
global websocket
uri = "ws://{}:{}".format(args.host, args.port)
async for websocket in websockets.connect(uri, subprotocols=["binary"], ping_interval=None):
if args.audio_in is not None:
- task = asyncio.create_task(record_from_scp())
+ task = asyncio.create_task(record_from_scp(chunk_begin,chunk_size))
else:
task = asyncio.create_task(record_microphone())
task2 = asyncio.create_task(ws_send())
task3 = asyncio.create_task(message(id))
await asyncio.gather(task, task2, task3)
-def one_thread(id):
- asyncio.get_event_loop().run_until_complete(ws_client(id))
+def one_thread(id,chunk_begin,chunk_size):
+ asyncio.get_event_loop().run_until_complete(ws_client(id,chunk_begin,chunk_size))
asyncio.get_event_loop().run_forever()
if __name__ == '__main__':
+    # calculate the number of wavs for each process
+ if args.audio_in.endswith(".scp"):
+ f_scp = open(args.audio_in)
+ wavs = f_scp.readlines()
+ else:
+ wavs = [args.audio_in]
+ total_len=len(wavs)
+ if total_len>=args.test_thread_num:
+ chunk_size=int((total_len)/args.test_thread_num)
+ remain_wavs=total_len-chunk_size*args.test_thread_num
+ else:
+ chunk_size=0
+
process_list = []
- for i in range(args.test_thread_num):
- p = Process(target=one_thread,args=(i,))
+ chunk_begin=0
+ for i in range(args.test_thread_num):
+ now_chunk_size= chunk_size
+ if remain_wavs>0:
+ now_chunk_size=chunk_size+1
+ remain_wavs=remain_wavs-1
+     # process i handles wavs at chunk_begin with size of now_chunk_size
+ p = Process(target=one_thread,args=(i,chunk_begin,now_chunk_size))
+ chunk_begin=chunk_begin+now_chunk_size
p.start()
process_list.append(p)
diff --git a/funasr/runtime/python/websocket/ws_server_online.py b/funasr/runtime/python/websocket/ws_server_online.py
index 3c0fb16..44edf98 100644
--- a/funasr/runtime/python/websocket/ws_server_online.py
+++ b/funasr/runtime/python/websocket/ws_server_online.py
@@ -41,25 +41,37 @@
global websocket_users
websocket_users.add(websocket)
websocket.param_dict_asr_online = {"cache": dict()}
-
+ print("new user connected",flush=True)
try:
async for message in websocket:
- message = json.loads(message)
- is_finished = message["is_finished"]
- if not is_finished:
- audio = bytes(message['audio'], 'ISO-8859-1')
-
- is_speaking = message["is_speaking"]
- websocket.param_dict_asr_online["is_final"] = not is_speaking
- websocket.wav_name = message.get("wav_name", "demo")
- websocket.param_dict_asr_online["chunk_size"] = message["chunk_size"]
-
- frames_asr_online.append(audio)
- if len(frames_asr_online) % message["chunk_interval"] == 0 or not is_speaking:
+
+
+ if isinstance(message,str):
+ messagejson = json.loads(message)
+
+ if "is_speaking" in messagejson:
+ websocket.is_speaking = messagejson["is_speaking"]
+ websocket.param_dict_asr_online["is_final"] = not websocket.is_speaking
+ if "is_finished" in messagejson:
+ websocket.is_speaking = False
+ websocket.param_dict_asr_online["is_final"] = True
+ if "chunk_interval" in messagejson:
+ websocket.chunk_interval=messagejson["chunk_interval"]
+ if "wav_name" in messagejson:
+ websocket.wav_name = messagejson.get("wav_name", "demo")
+ if "chunk_size" in messagejson:
+ websocket.param_dict_asr_online["chunk_size"] = messagejson["chunk_size"]
+ # if has bytes in buffer or message is bytes
+ if len(frames_asr_online)>0 or not isinstance(message,str):
+ if not isinstance(message,str):
+ frames_asr_online.append(message)
+ if len(frames_asr_online) % websocket.chunk_interval == 0 or not websocket.is_speaking:
audio_in = b"".join(frames_asr_online)
+ if not websocket.is_speaking:
+                    #padding 0.5s at end guarantee that asr engine can fire out last word
+ audio_in=audio_in+b''.join(np.zeros(int(16000*0.5),dtype=np.int16))
await async_asr_online(websocket,audio_in)
frames_asr_online = []
-
except websockets.ConnectionClosed:
@@ -69,6 +81,7 @@
print("InvalidState...")
except Exception as e:
print("Exception:", e)
+
async def async_asr_online(websocket,audio_in):
if len(audio_in) > 0:
diff --git a/funasr/runtime/websocket/websocketclient.cpp b/funasr/runtime/websocket/websocketclient.cpp
index 3ab4e99..078fc5a 100644
--- a/funasr/runtime/websocket/websocketclient.cpp
+++ b/funasr/runtime/websocket/websocketclient.cpp
@@ -13,6 +13,7 @@
#include <websocketpp/config/asio_no_tls_client.hpp>
#include "audio.h"
+#include "nlohmann/json.hpp"
/**
* Define a semi-cross platform helper method that waits/sleeps for a bit.
@@ -156,6 +157,19 @@
}
}
websocketpp::lib::error_code ec;
+
+ nlohmann::json jsonbegin;
+ nlohmann::json chunk_size = nlohmann::json::array();
+ chunk_size.push_back(5);
+ chunk_size.push_back(0);
+ chunk_size.push_back(5);
+ jsonbegin["chunk_size"] = chunk_size;
+ jsonbegin["chunk_interval"] = 10;
+ jsonbegin["wav_name"] = "damo";
+ jsonbegin["is_speaking"] = true;
+ m_client.send(m_hdl, jsonbegin.dump(), websocketpp::frame::opcode::text,
+ ec);
+
// fetch wav data use asr engine api
while (audio.Fetch(buff, len, flag) > 0) {
short iArray[len];
@@ -181,8 +195,10 @@
wait_a_bit();
}
-
- m_client.send(m_hdl, "Done", websocketpp::frame::opcode::text, ec);
+ nlohmann::json jsonresult;
+ jsonresult["is_speaking"] = false;
+ m_client.send(m_hdl, jsonresult.dump(), websocketpp::frame::opcode::text,
+ ec);
wait_a_bit();
}
diff --git a/funasr/runtime/websocket/websocketsrv.cpp b/funasr/runtime/websocket/websocketsrv.cpp
index 1a6adbf..598ad3d 100644
--- a/funasr/runtime/websocket/websocketsrv.cpp
+++ b/funasr/runtime/websocket/websocketsrv.cpp
@@ -34,6 +34,14 @@
websocketpp::lib::error_code ec;
nlohmann::json jsonresult; // result json
jsonresult["text"] = asr_result; // put result in 'text'
+ jsonresult["mode"] = "offline";
+ std::shared_ptr<FUNASR_MESSAGE> msg_data = nullptr;
+ auto it_data = data_map.find(hdl);
+ if (it_data != data_map.end()) {
+ msg_data = it_data->second;
+ }
+
+ jsonresult["wav_name"] = msg_data->msg["wav_name"];
// send the json to client
server_->send(hdl, jsonresult.dump(), websocketpp::frame::opcode::text,
@@ -43,7 +51,7 @@
<< ",result json=" << jsonresult.dump() << std::endl;
if (!isonline) {
// close the client if it is not online asr
- server_->close(hdl, websocketpp::close::status::normal, "DONE", ec);
+ // server_->close(hdl, websocketpp::close::status::normal, "DONE", ec);
// fout.close();
}
}
@@ -56,25 +64,28 @@
void WebSocketServer::on_open(websocketpp::connection_hdl hdl) {
scoped_lock guard(m_lock); // for threads safty
check_and_clean_connection(); // remove closed connection
- sample_map.emplace(
- hdl, std::make_shared<std::vector<char>>()); // put a new data vector for
- // new connection
- std::cout << "on_open, active connections: " << sample_map.size()
- << std::endl;
+
+ std::shared_ptr<FUNASR_MESSAGE> data_msg =
+ std::make_shared<FUNASR_MESSAGE>(); // put a new data vector for new
+ // connection
+ data_msg->samples = std::make_shared<std::vector<char>>();
+ data_msg->msg = nlohmann::json::parse("{}");
+ data_map.emplace(hdl, data_msg);
+ std::cout << "on_open, active connections: " << data_map.size() << std::endl;
}
void WebSocketServer::on_close(websocketpp::connection_hdl hdl) {
scoped_lock guard(m_lock);
- sample_map.erase(hdl); // remove data vector when connection is closed
- std::cout << "on_close, active connections: " << sample_map.size()
- << std::endl;
+ data_map.erase(hdl); // remove data vector when connection is closed
+
+ std::cout << "on_close, active connections: " << data_map.size() << std::endl;
}
// remove closed connection
void WebSocketServer::check_and_clean_connection() {
std::vector<websocketpp::connection_hdl> to_remove; // remove list
- auto iter = sample_map.begin();
- while (iter != sample_map.end()) { // loop to find closed connection
+ auto iter = data_map.begin();
+ while (iter != data_map.end()) { // loop to find closed connection
websocketpp::connection_hdl hdl = iter->first;
server::connection_ptr con = server_->get_con_from_hdl(hdl);
if (con->get_state() != 1) { // session::state::open ==1
@@ -83,7 +94,7 @@
iter++;
}
for (auto hdl : to_remove) {
- sample_map.erase(hdl);
+ data_map.erase(hdl);
std::cout << "remove one connection " << std::endl;
}
}
@@ -91,12 +102,15 @@
message_ptr msg) {
unique_lock lock(m_lock);
// find the sample data vector according to one connection
- std::shared_ptr<std::vector<char>> sample_data_p = nullptr;
- auto it = sample_map.find(hdl);
- if (it != sample_map.end()) {
- sample_data_p = it->second;
+ std::shared_ptr<FUNASR_MESSAGE> msg_data = nullptr;
+
+ auto it_data = data_map.find(hdl);
+ if (it_data != data_map.end()) {
+ msg_data = it_data->second;
}
+ std::shared_ptr<std::vector<char>> sample_data_p = msg_data->samples;
+
lock.unlock();
if (sample_data_p == nullptr) {
std::cout << "error when fetch sample data vector" << std::endl;
@@ -106,13 +120,22 @@
const std::string& payload = msg->get_payload(); // get msg type
switch (msg->get_opcode()) {
- case websocketpp::frame::opcode::text:
- if (payload == "Done") {
+ case websocketpp::frame::opcode::text: {
+ nlohmann::json jsonresult = nlohmann::json::parse(payload);
+ if (jsonresult["wav_name"] != nullptr) {
+ msg_data->msg["wav_name"] = jsonresult["wav_name"];
+ }
+ if (jsonresult["is_speaking"] == false ||
+ jsonresult["is_finished"] == true) {
std::cout << "client done" << std::endl;
if (isonline) {
// do_close(ws);
} else {
+ // add padding to the end of the wav data
+ std::vector<short> padding(static_cast<short>(0.3 * 16000));
+ sample_data_p->insert(sample_data_p->end(), padding.data(),
+ padding.data() + padding.size());
// for offline, send all receive data to decoder engine
asio::post(io_decoder_, std::bind(&WebSocketServer::do_decoder, this,
std::move(*(sample_data_p.get())),
@@ -120,6 +143,7 @@
}
}
break;
+ }
case websocketpp::frame::opcode::binary: {
// recived binary data
const auto* pcm_data = static_cast<const char*>(payload.data());
diff --git a/funasr/runtime/websocket/websocketsrv.h b/funasr/runtime/websocket/websocketsrv.h
index e484724..1899f57 100644
--- a/funasr/runtime/websocket/websocketsrv.h
+++ b/funasr/runtime/websocket/websocketsrv.h
@@ -46,6 +46,11 @@
float snippet_time;
} FUNASR_RECOG_RESULT;
+typedef struct {
+ nlohmann::json msg;
+ std::shared_ptr<std::vector<char>> samples;
+} FUNASR_MESSAGE;
+
class WebSocketServer {
public:
WebSocketServer(asio::io_context& io_decoder, server* server_)
@@ -84,9 +89,11 @@
// use map to keep the received samples data from one connection in offline
// engine. if for online engline, a data struct is needed(TODO)
- std::map<websocketpp::connection_hdl, std::shared_ptr<std::vector<char>>,
+
+
+ std::map<websocketpp::connection_hdl, std::shared_ptr<FUNASR_MESSAGE>,
std::owner_less<websocketpp::connection_hdl>>
- sample_map;
+ data_map;
websocketpp::lib::mutex m_lock; // mutex for sample_map
};
--
Gitblit v1.9.1