From c2e232451f2f87b1ebdddd6a7f6d8434cb309808 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 07 九月 2023 14:30:12 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add
---
funasr/runtime/websocket/funasr-wss-client-2pass.cpp | 273 ++++++++++++++++++++++++++++++++++++++++++------------
1 files changed, 211 insertions(+), 62 deletions(-)
diff --git a/funasr/runtime/websocket/funasr-wss-client-2pass.cpp b/funasr/runtime/websocket/funasr-wss-client-2pass.cpp
index e52e316..9010c86 100644
--- a/funasr/runtime/websocket/funasr-wss-client-2pass.cpp
+++ b/funasr/runtime/websocket/funasr-wss-client-2pass.cpp
@@ -17,6 +17,7 @@
#define ASIO_STANDALONE 1
#include <glog/logging.h>
+#include "portaudio.h"
#include <atomic>
#include <fstream>
@@ -30,6 +31,7 @@
#include "audio.h"
#include "nlohmann/json.hpp"
#include "tclap/CmdLine.h"
+#include "microphone.h"
/**
* Define a semi-cross platform helper method that waits/sleeps for a bit.
@@ -123,7 +125,6 @@
if (ec) {
LOG(ERROR) << "Error closing connection " << ec.message();
}
-
}
}
}
@@ -131,7 +132,7 @@
// This method will block until the connection is complete
void run(const std::string& uri, const std::vector<string>& wav_list,
const std::vector<string>& wav_ids, std::string asr_mode,
- std::vector<int> chunk_size) {
+ std::vector<int> chunk_size, bool is_record=false) {
// Create a new connection to the given URI
websocketpp::lib::error_code ec;
typename websocketpp::client<T>::connection_ptr con =
@@ -152,8 +153,11 @@
// Create a thread to run the ASIO io_service event loop
websocketpp::lib::thread asio_thread(&websocketpp::client<T>::run,
&m_client);
-
- send_wav_data(wav_list[0], wav_ids[0], asr_mode, chunk_size);
+ if(is_record){
+ send_rec_data(asr_mode, chunk_size);
+ }else{
+ send_wav_data(wav_list[0], wav_ids[0], asr_mode, chunk_size);
+ }
WaitABit();
@@ -264,16 +268,11 @@
send_block = len - offset;
}
m_client.send(m_hdl, iArray + offset, send_block * sizeof(short),
- websocketpp::frame::opcode::binary, ec);
+ websocketpp::frame::opcode::binary, ec);
offset += send_block;
}
LOG(INFO) << "sended data len=" << len * sizeof(short);
- // The most likely error that we will get is that the connection is
- // not in the right state. Usually this means we tried to send a
- // message to a connection that was closed or in the process of
- // closing. While many errors here can be easily recovered from,
- // in this simple example, we'll stop the data loop.
if (ec) {
m_client.get_alog().write(websocketpp::log::alevel::app,
"Send Error: " + ec.message());
@@ -300,11 +299,6 @@
}
LOG(INFO) << "sended data len=" << len;
- // The most likely error that we will get is that the connection is
- // not in the right state. Usually this means we tried to send a
- // message to a connection that was closed or in the process of
- // closing. While many errors here can be easily recovered from,
- // in this simple example, we'll stop the data loop.
if (ec) {
m_client.get_alog().write(websocketpp::log::alevel::app,
"Send Error: " + ec.message());
@@ -317,6 +311,137 @@
ec);
WaitABit();
}
+
+ static int RecordCallback(const void* inputBuffer, void* outputBuffer,
+ unsigned long framesPerBuffer, const PaStreamCallbackTimeInfo* timeInfo,
+ PaStreamCallbackFlags statusFlags, void* userData)
+ {
+ std::vector<float>* buffer = static_cast<std::vector<float>*>(userData);
+ const float* input = static_cast<const float*>(inputBuffer);
+
+ for (unsigned int i = 0; i < framesPerBuffer; i++)
+ {
+ buffer->push_back(input[i]);
+ }
+
+ return paContinue;
+ }
+
+ void send_rec_data(std::string asr_mode, std::vector<int> chunk_vector) {
+ // first message
+ bool wait = false;
+ while (1) {
+ {
+ scoped_lock guard(m_lock);
+ // If the connection has been closed, stop generating data
+ if (m_done) {
+ break;
+ }
+ // If the connection hasn't been opened yet wait a bit and retry
+ if (!m_open) {
+ wait = true;
+ } else {
+ break;
+ }
+ }
+
+ if (wait) {
+ // LOG(INFO) << "wait.." << m_open;
+ WaitABit();
+ continue;
+ }
+ }
+ websocketpp::lib::error_code ec;
+
+ nlohmann::json jsonbegin;
+ nlohmann::json chunk_size = nlohmann::json::array();
+ chunk_size.push_back(chunk_vector[0]);
+ chunk_size.push_back(chunk_vector[1]);
+ chunk_size.push_back(chunk_vector[2]);
+ jsonbegin["mode"] = asr_mode;
+ jsonbegin["chunk_size"] = chunk_size;
+ jsonbegin["wav_name"] = "record";
+ jsonbegin["wav_format"] = "pcm";
+ jsonbegin["is_speaking"] = true;
+ m_client.send(m_hdl, jsonbegin.dump(), websocketpp::frame::opcode::text,
+ ec);
+ // mic
+ Microphone mic;
+ PaDeviceIndex num_devices = Pa_GetDeviceCount();
+ LOG(INFO) << "Num devices: " << num_devices;
+
+ PaStreamParameters param;
+
+ param.device = Pa_GetDefaultInputDevice();
+ if (param.device == paNoDevice) {
+ LOG(INFO) << "No default input device found";
+ exit(EXIT_FAILURE);
+ }
+ LOG(INFO) << "Use default device: " << param.device;
+
+ const PaDeviceInfo *info = Pa_GetDeviceInfo(param.device);
+ LOG(INFO) << " Name: " << info->name;
+ LOG(INFO) << " Max input channels: " << info->maxInputChannels;
+
+ param.channelCount = 1;
+ param.sampleFormat = paFloat32;
+
+ param.suggestedLatency = info->defaultLowInputLatency;
+ param.hostApiSpecificStreamInfo = nullptr;
+ float sample_rate = 16000;
+
+ PaStream *stream;
+ std::vector<float> buffer;
+ PaError err =
+ Pa_OpenStream(&stream, ¶m, nullptr, /* &outputParameters, */
+ sample_rate,
+ 0, // frames per buffer
+ paClipOff, // we won't output out of range samples
+ // so don't bother clipping them
+ RecordCallback, &buffer);
+ if (err != paNoError) {
+ LOG(ERROR) << "portaudio error: " << Pa_GetErrorText(err);
+ exit(EXIT_FAILURE);
+ }
+
+ err = Pa_StartStream(stream);
+ LOG(INFO) << "Started: ";
+
+ if (err != paNoError) {
+ LOG(ERROR) << "portaudio error: " << Pa_GetErrorText(err);
+ exit(EXIT_FAILURE);
+ }
+
+ while(true){
+ int len = buffer.size();
+ short* iArray = new short[len];
+ for (size_t i = 0; i < len; ++i) {
+ iArray[i] = (short)(buffer[i] * 32768);
+ }
+
+ m_client.send(m_hdl, iArray, len * sizeof(short),
+ websocketpp::frame::opcode::binary, ec);
+ buffer.clear();
+
+ if (ec) {
+ m_client.get_alog().write(websocketpp::log::alevel::app,
+ "Send Error: " + ec.message());
+ }
+ Pa_Sleep(20); // sleep for 20ms
+ }
+
+ nlohmann::json jsonresult;
+ jsonresult["is_speaking"] = false;
+ m_client.send(m_hdl, jsonresult.dump(), websocketpp::frame::opcode::text,
+ ec);
+
+ err = Pa_CloseStream(stream);
+ if (err != paNoError) {
+ LOG(INFO) << "portaudio error: " << Pa_GetErrorText(err);
+ exit(EXIT_FAILURE);
+ }
+ }
+
websocketpp::client<T> m_client;
private:
@@ -331,7 +456,7 @@
google::InitGoogleLogging(argv[0]);
FLAGS_logtostderr = true;
- TCLAP::CmdLine cmd("funasr-wss-client", ' ', "1.0");
+ TCLAP::CmdLine cmd("funasr-wss-client-2pass", ' ', "1.0");
TCLAP::ValueArg<std::string> server_ip_("", "server-ip", "server-ip", true,
"127.0.0.1", "string");
TCLAP::ValueArg<std::string> port_("", "port", "port", true, "10095",
@@ -340,7 +465,11 @@
"", "wav-path",
"the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: "
"asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)",
- true, "", "string");
+ false, "", "string");
+ TCLAP::ValueArg<int> record_(
+ "", "record",
+ "record is 1 means use record", false, 0,
+ "int");
TCLAP::ValueArg<std::string> asr_mode_("", ASR_MODE, "offline, online, 2pass",
false, "2pass", "string");
TCLAP::ValueArg<std::string> chunk_size_("", "chunk-size",
@@ -357,6 +486,7 @@
cmd.add(port_);
cmd.add(wav_path_);
cmd.add(asr_mode_);
+ cmd.add(record_);
cmd.add(chunk_size_);
cmd.add(thread_num_);
cmd.add(is_ssl_);
@@ -382,6 +512,7 @@
int threads_num = thread_num_.getValue();
int is_ssl = is_ssl_.getValue();
+ int is_record = record_.getValue();
std::string uri = "";
if (is_ssl == 1) {
@@ -390,60 +521,78 @@
uri = "ws://" + server_ip + ":" + port;
}
- // read wav_path
- std::vector<string> wav_list;
- std::vector<string> wav_ids;
- string default_id = "wav_default_id";
- if (IsTargetFile(wav_path, "scp")) {
- ifstream in(wav_path);
- if (!in.is_open()) {
- printf("Failed to open scp file");
- return 0;
- }
- string line;
- while (getline(in, line)) {
- istringstream iss(line);
- string column1, column2;
- iss >> column1 >> column2;
- wav_list.emplace_back(column2);
- wav_ids.emplace_back(column1);
- }
- in.close();
- } else {
- wav_list.emplace_back(wav_path);
- wav_ids.emplace_back(default_id);
- }
-
- for (size_t wav_i = 0; wav_i < wav_list.size(); wav_i = wav_i + threads_num) {
- std::vector<websocketpp::lib::thread> client_threads;
- for (size_t i = 0; i < threads_num; i++) {
- if (wav_i + i >= wav_list.size()) {
- break;
- }
+ if(is_record == 1){
std::vector<string> tmp_wav_list;
std::vector<string> tmp_wav_ids;
- tmp_wav_list.emplace_back(wav_list[wav_i + i]);
- tmp_wav_ids.emplace_back(wav_ids[wav_i + i]);
+ if (is_ssl == 1) {
+ WebsocketClient<websocketpp::config::asio_tls_client> c(is_ssl);
- client_threads.emplace_back(
- [uri, tmp_wav_list, tmp_wav_ids, asr_mode, chunk_size, is_ssl]() {
- if (is_ssl == 1) {
- WebsocketClient<websocketpp::config::asio_tls_client> c(is_ssl);
+ c.m_client.set_tls_init_handler(bind(&OnTlsInit, ::_1));
- c.m_client.set_tls_init_handler(bind(&OnTlsInit, ::_1));
+ c.run(uri, tmp_wav_list, tmp_wav_ids, asr_mode, chunk_size, true);
+ } else {
+ WebsocketClient<websocketpp::config::asio_client> c(is_ssl);
- c.run(uri, tmp_wav_list, tmp_wav_ids, asr_mode, chunk_size);
- } else {
- WebsocketClient<websocketpp::config::asio_client> c(is_ssl);
+ c.run(uri, tmp_wav_list, tmp_wav_ids, asr_mode, chunk_size, true);
+ }
- c.run(uri, tmp_wav_list, tmp_wav_ids, asr_mode, chunk_size);
- }
- });
+ }else{
+ // read wav_path
+ std::vector<string> wav_list;
+ std::vector<string> wav_ids;
+ string default_id = "wav_default_id";
+ if (IsTargetFile(wav_path, "scp")) {
+ ifstream in(wav_path);
+ if (!in.is_open()) {
+ printf("Failed to open scp file");
+ return 0;
+ }
+ string line;
+ while (getline(in, line)) {
+ istringstream iss(line);
+ string column1, column2;
+ iss >> column1 >> column2;
+ wav_list.emplace_back(column2);
+ wav_ids.emplace_back(column1);
+ }
+ in.close();
+ } else {
+ wav_list.emplace_back(wav_path);
+ wav_ids.emplace_back(default_id);
}
- for (auto& t : client_threads) {
- t.join();
+ for (size_t wav_i = 0; wav_i < wav_list.size(); wav_i = wav_i + threads_num) {
+ std::vector<websocketpp::lib::thread> client_threads;
+ for (size_t i = 0; i < threads_num; i++) {
+ if (wav_i + i >= wav_list.size()) {
+ break;
+ }
+ std::vector<string> tmp_wav_list;
+ std::vector<string> tmp_wav_ids;
+
+ tmp_wav_list.emplace_back(wav_list[wav_i + i]);
+ tmp_wav_ids.emplace_back(wav_ids[wav_i + i]);
+
+ client_threads.emplace_back(
+ [uri, tmp_wav_list, tmp_wav_ids, asr_mode, chunk_size, is_ssl]() {
+ if (is_ssl == 1) {
+ WebsocketClient<websocketpp::config::asio_tls_client> c(is_ssl);
+
+ c.m_client.set_tls_init_handler(bind(&OnTlsInit, ::_1));
+
+ c.run(uri, tmp_wav_list, tmp_wav_ids, asr_mode, chunk_size);
+ } else {
+ WebsocketClient<websocketpp::config::asio_client> c(is_ssl);
+
+ c.run(uri, tmp_wav_list, tmp_wav_ids, asr_mode, chunk_size);
+ }
+ });
+ }
+
+ for (auto& t : client_threads) {
+ t.join();
+ }
}
}
}
--
Gitblit v1.9.1