zhifu gao
2023-05-13 f03a604204bbe0c79e53b01237a37e88683938c6
Merge pull request #505 from zhaomingwork/cpp-python-websocket-compatible

python websocket send binary bytes directly
5个文件已修改
182 ■■■■ 已修改文件
funasr/runtime/python/websocket/ws_client.py 54 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/websocket/ws_server_online.py 37 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/websocket/websocketclient.cpp 20 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/websocket/websocketsrv.cpp 60 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/websocket/websocketsrv.h 11 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/websocket/ws_client.py
@@ -84,18 +84,20 @@
                    rate=RATE,
                    input=True,
                    frames_per_buffer=CHUNK)
    message = json.dumps({"chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval, "wav_name": wav_name,"is_speaking": True})
    voices.put(message)
    is_speaking = True
    while True:
        data = stream.read(CHUNK)
        data = data.decode('ISO-8859-1')
        message = json.dumps({"chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval, "audio": data, "is_speaking": is_speaking, "is_finished": is_finished})
        message = data
        
        voices.put(message)
        await asyncio.sleep(0.005)
async def record_from_scp():
async def record_from_scp(chunk_begin,chunk_size):
    import wave
    global voices
    is_finished = False
@@ -104,6 +106,8 @@
        wavs = f_scp.readlines()
    else:
        wavs = [args.audio_in]
    if chunk_size>0:
        wavs=wavs[chunk_begin:chunk_begin+chunk_size]
    for wav in wavs:
        wav_splits = wav.strip().split()
        wav_name = wav_splits[0] if len(wav_splits) > 1 else "demo"
@@ -122,14 +126,20 @@
        stride = int(60*args.chunk_size[1]/args.chunk_interval/1000*16000*2)
        chunk_num = (len(audio_bytes)-1)//stride + 1
        # print(stride)
        # send first time
        message = json.dumps({"chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval, "wav_name": wav_name,"is_speaking": True})
        voices.put(message)
        is_speaking = True
        for i in range(chunk_num):
            if i == chunk_num-1:
                is_speaking = False
            beg = i*stride
            data = audio_bytes[beg:beg+stride]
            data = data.decode('ISO-8859-1')
            message = json.dumps({"chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval, "is_speaking": is_speaking, "audio": data, "is_finished": is_finished, "wav_name": wav_name})
            message = data
            voices.put(message)
            if i == chunk_num-1:
                is_speaking = False
                message = json.dumps({"is_speaking": is_speaking})
            voices.put(message)
            # print("data_chunk: ", len(data_chunk))
            # print(voices.qsize())
@@ -213,27 +223,47 @@
            traceback.print_exc()
            exit(0)
async def ws_client(id):
async def ws_client(id,chunk_begin,chunk_size):
    global websocket
    uri = "ws://{}:{}".format(args.host, args.port)
    async for websocket in websockets.connect(uri, subprotocols=["binary"], ping_interval=None):
        if args.audio_in is not None:
            task = asyncio.create_task(record_from_scp())
            task = asyncio.create_task(record_from_scp(chunk_begin,chunk_size))
        else:
            task = asyncio.create_task(record_microphone())
        task2 = asyncio.create_task(ws_send())
        task3 = asyncio.create_task(message(id))
        await asyncio.gather(task, task2, task3)
def one_thread(id):
   asyncio.get_event_loop().run_until_complete(ws_client(id))
def one_thread(id,chunk_begin,chunk_size):
   asyncio.get_event_loop().run_until_complete(ws_client(id,chunk_begin,chunk_size))
   asyncio.get_event_loop().run_forever()
if __name__ == '__main__':
    # calculate the number of wavs for each preocess
    if args.audio_in.endswith(".scp"):
        f_scp = open(args.audio_in)
        wavs = f_scp.readlines()
    else:
        wavs = [args.audio_in]
    total_len=len(wavs)
    if total_len>=args.test_thread_num:
         chunk_size=int((total_len)/args.test_thread_num)
         remain_wavs=total_len-chunk_size*args.test_thread_num
    else:
         chunk_size=0
    process_list = []
    chunk_begin=0
    for i in range(args.test_thread_num):   
        p = Process(target=one_thread,args=(i,))
        now_chunk_size= chunk_size
        if remain_wavs>0:
            now_chunk_size=chunk_size+1
            remain_wavs=remain_wavs-1
        # process i handle wavs at chunk_begin and size of now_chunk_size
        p = Process(target=one_thread,args=(i,chunk_begin,now_chunk_size))
        chunk_begin=chunk_begin+now_chunk_size
        p.start()
        process_list.append(p)
funasr/runtime/python/websocket/ws_server_online.py
@@ -41,25 +41,37 @@
    global websocket_users
    websocket_users.add(websocket)
    websocket.param_dict_asr_online = {"cache": dict()}
    print("new user connected",flush=True)
    try:
        async for message in websocket:
            message = json.loads(message)
            is_finished = message["is_finished"]
            if not is_finished:
                audio = bytes(message['audio'], 'ISO-8859-1')
                is_speaking = message["is_speaking"]
                websocket.param_dict_asr_online["is_final"] = not is_speaking
                websocket.wav_name = message.get("wav_name", "demo")
                websocket.param_dict_asr_online["chunk_size"] = message["chunk_size"]
                
                frames_asr_online.append(audio)
                if len(frames_asr_online) % message["chunk_interval"] == 0 or not is_speaking:
            if isinstance(message,str):
              messagejson = json.loads(message)
              if "is_speaking" in messagejson:
                  websocket.is_speaking = messagejson["is_speaking"]
                  websocket.param_dict_asr_online["is_final"] = not websocket.is_speaking
              if "is_finished" in messagejson:
                  websocket.is_speaking = False
                  websocket.param_dict_asr_online["is_final"] = True
              if "chunk_interval" in messagejson:
                  websocket.chunk_interval=messagejson["chunk_interval"]
              if "wav_name" in messagejson:
                  websocket.wav_name = messagejson.get("wav_name", "demo")
              if "chunk_size" in messagejson:
                  websocket.param_dict_asr_online["chunk_size"] = messagejson["chunk_size"]
            # if has bytes in buffer or message is bytes
            if len(frames_asr_online)>0 or not isinstance(message,str):
               if not isinstance(message,str):
                 frames_asr_online.append(message)
               if len(frames_asr_online) % websocket.chunk_interval == 0 or not websocket.is_speaking:
                    audio_in = b"".join(frames_asr_online)
                    if not websocket.is_speaking:
                       #padding 0.5s at end gurantee that asr engine can fire out last word
                       audio_in=audio_in+b''.join(np.zeros(int(16000*0.5),dtype=np.int16))
                    await async_asr_online(websocket,audio_in)
                    frames_asr_online = []
     
    except websockets.ConnectionClosed:
@@ -70,6 +82,7 @@
    except Exception as e:
        print("Exception:", e)
 
async def async_asr_online(websocket,audio_in):
            if len(audio_in) > 0:
                audio_in = load_bytes(audio_in)
funasr/runtime/websocket/websocketclient.cpp
@@ -13,6 +13,7 @@
#include <websocketpp/config/asio_no_tls_client.hpp>
#include "audio.h"
#include "nlohmann/json.hpp"
/**
 * Define a semi-cross platform helper method that waits/sleeps for a bit.
@@ -156,6 +157,19 @@
      }
    }
    websocketpp::lib::error_code ec;
    nlohmann::json jsonbegin;
    nlohmann::json chunk_size = nlohmann::json::array();
    chunk_size.push_back(5);
    chunk_size.push_back(0);
    chunk_size.push_back(5);
    jsonbegin["chunk_size"] = chunk_size;
    jsonbegin["chunk_interval"] = 10;
    jsonbegin["wav_name"] = "damo";
    jsonbegin["is_speaking"] = true;
    m_client.send(m_hdl, jsonbegin.dump(), websocketpp::frame::opcode::text,
                  ec);
    // fetch wav data use asr engine api
    while (audio.Fetch(buff, len, flag) > 0) {
      short iArray[len];
@@ -181,8 +195,10 @@
      wait_a_bit();
    }
    m_client.send(m_hdl, "Done", websocketpp::frame::opcode::text, ec);
    nlohmann::json jsonresult;
    jsonresult["is_speaking"] = false;
    m_client.send(m_hdl, jsonresult.dump(), websocketpp::frame::opcode::text,
                  ec);
    wait_a_bit();
  }
funasr/runtime/websocket/websocketsrv.cpp
@@ -34,6 +34,14 @@
      websocketpp::lib::error_code ec;
      nlohmann::json jsonresult;        // result json
      jsonresult["text"] = asr_result;  // put result in 'text'
      jsonresult["mode"] = "offline";
      std::shared_ptr<FUNASR_MESSAGE> msg_data = nullptr;
      auto it_data = data_map.find(hdl);
      if (it_data != data_map.end()) {
        msg_data = it_data->second;
      }
      jsonresult["wav_name"] = msg_data->msg["wav_name"];
      // send the json to client
      server_->send(hdl, jsonresult.dump(), websocketpp::frame::opcode::text,
@@ -43,7 +51,7 @@
                << ",result json=" << jsonresult.dump() << std::endl;
      if (!isonline) {
        //  close the client if it is not online asr
        server_->close(hdl, websocketpp::close::status::normal, "DONE", ec);
        // server_->close(hdl, websocketpp::close::status::normal, "DONE", ec);
        // fout.close();
      }
    }
@@ -56,25 +64,28 @@
void WebSocketServer::on_open(websocketpp::connection_hdl hdl) {
  scoped_lock guard(m_lock);     // for threads safty
  check_and_clean_connection();  // remove closed connection
  sample_map.emplace(
      hdl, std::make_shared<std::vector<char>>());  // put a new data vector for
                                                    // new connection
  std::cout << "on_open, active connections: " << sample_map.size()
            << std::endl;
  std::shared_ptr<FUNASR_MESSAGE> data_msg =
      std::make_shared<FUNASR_MESSAGE>();  // put a new data vector for new
                                           // connection
  data_msg->samples = std::make_shared<std::vector<char>>();
  data_msg->msg = nlohmann::json::parse("{}");
  data_map.emplace(hdl, data_msg);
  std::cout << "on_open, active connections: " << data_map.size() << std::endl;
}
void WebSocketServer::on_close(websocketpp::connection_hdl hdl) {
  scoped_lock guard(m_lock);
  sample_map.erase(hdl);  // remove data vector when  connection is closed
  std::cout << "on_close, active connections: " << sample_map.size()
            << std::endl;
  data_map.erase(hdl);  // remove data vector when  connection is closed
  std::cout << "on_close, active connections: " << data_map.size() << std::endl;
}
// remove closed connection
void WebSocketServer::check_and_clean_connection() {
  std::vector<websocketpp::connection_hdl> to_remove;  // remove list
  auto iter = sample_map.begin();
  while (iter != sample_map.end()) {  // loop to find closed connection
  auto iter = data_map.begin();
  while (iter != data_map.end()) {  // loop to find closed connection
    websocketpp::connection_hdl hdl = iter->first;
    server::connection_ptr con = server_->get_con_from_hdl(hdl);
    if (con->get_state() != 1) {  // session::state::open ==1
@@ -83,7 +94,7 @@
    iter++;
  }
  for (auto hdl : to_remove) {
    sample_map.erase(hdl);
    data_map.erase(hdl);
    std::cout << "remove one connection " << std::endl;
  }
}
@@ -91,12 +102,15 @@
                                 message_ptr msg) {
  unique_lock lock(m_lock);
  // find the sample data vector according to one connection
  std::shared_ptr<std::vector<char>> sample_data_p = nullptr;
  auto it = sample_map.find(hdl);
  if (it != sample_map.end()) {
    sample_data_p = it->second;
  std::shared_ptr<FUNASR_MESSAGE> msg_data = nullptr;
  auto it_data = data_map.find(hdl);
  if (it_data != data_map.end()) {
    msg_data = it_data->second;
  }
  std::shared_ptr<std::vector<char>> sample_data_p = msg_data->samples;
  lock.unlock();
  if (sample_data_p == nullptr) {
    std::cout << "error when fetch sample data vector" << std::endl;
@@ -106,13 +120,22 @@
  const std::string& payload = msg->get_payload();  // get msg type
  switch (msg->get_opcode()) {
    case websocketpp::frame::opcode::text:
      if (payload == "Done") {
    case websocketpp::frame::opcode::text: {
      nlohmann::json jsonresult = nlohmann::json::parse(payload);
      if (jsonresult["wav_name"] != nullptr) {
        msg_data->msg["wav_name"] = jsonresult["wav_name"];
      }
      if (jsonresult["is_speaking"] == false ||
          jsonresult["is_finished"] == true) {
        std::cout << "client done" << std::endl;
        if (isonline) {
          // do_close(ws);
        } else {
          // add padding to the end of the wav data
          std::vector<short> padding(static_cast<short>(0.3 * 16000));
          sample_data_p->insert(sample_data_p->end(), padding.data(),
                                padding.data() + padding.size());
          // for offline, send all receive data to decoder engine
          asio::post(io_decoder_, std::bind(&WebSocketServer::do_decoder, this,
                                            std::move(*(sample_data_p.get())),
@@ -120,6 +143,7 @@
        }
      }
      break;
    }
    case websocketpp::frame::opcode::binary: {
      // recived binary data
      const auto* pcm_data = static_cast<const char*>(payload.data());
funasr/runtime/websocket/websocketsrv.h
@@ -46,6 +46,11 @@
  float snippet_time;
} FUNASR_RECOG_RESULT;
// Per-connection state stored in WebSocketServer::data_map; one instance is
// created for each connection in on_open.
typedef struct {
  // JSON metadata for the connection. Starts as an empty object ("{}") and
  // receives "wav_name" when a text frame carries that key.
  nlohmann::json msg;
  // Byte buffer for the connection's received sample data (allocated empty
  // in on_open).
  std::shared_ptr<std::vector<char>> samples;
} FUNASR_MESSAGE;
class WebSocketServer {
 public:
  WebSocketServer(asio::io_context& io_decoder, server* server_)
@@ -84,9 +89,11 @@
  // use a map to keep the data received from each connection for the offline
  // engine. for the online engine, a dedicated data struct is needed (TODO)
  std::map<websocketpp::connection_hdl, std::shared_ptr<std::vector<char>>,
  std::map<websocketpp::connection_hdl, std::shared_ptr<FUNASR_MESSAGE>,
           std::owner_less<websocketpp::connection_hdl>>
      sample_map;
      data_map;
  websocketpp::lib::mutex m_lock;  // mutex guarding access to data_map
};