From b26d3de5fa022f4a44648fee24546aff4e1cf5bc Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期四, 07 九月 2023 14:23:58 +0800
Subject: [PATCH] add mic for funasr-wss-client-2pass
---
funasr/runtime/websocket/CMakeLists.txt | 36 +++++
funasr/runtime/websocket/funasr-wss-client-2pass.cpp | 273 +++++++++++++++++++++++++++++++++++----------
funasr/runtime/websocket/microphone.h | 16 ++
funasr/runtime/websocket/microphone.cpp | 27 ++++
4 files changed, 288 insertions(+), 64 deletions(-)
diff --git a/funasr/runtime/websocket/CMakeLists.txt b/funasr/runtime/websocket/CMakeLists.txt
index dd71174..d975a85 100644
--- a/funasr/runtime/websocket/CMakeLists.txt
+++ b/funasr/runtime/websocket/CMakeLists.txt
@@ -7,6 +7,7 @@
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
option(ENABLE_WEBSOCKET "Whether to build websocket server" ON)
+option(ENABLE_PORTAUDIO "Whether to build websocket server" ON)
if(ENABLE_WEBSOCKET)
# cmake_policy(SET CMP0135 NEW)
@@ -38,6 +39,37 @@
endif()
+if(ENABLE_PORTAUDIO)
+ include(FetchContent)
+
+ set(portaudio_URL "http://files.portaudio.com/archives/pa_stable_v190700_20210406.tgz")
+ set(portaudio_URL2 "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/dep_libs/pa_stable_v190700_20210406.tgz")
+ set(portaudio_HASH "SHA256=47efbf42c77c19a05d22e627d42873e991ec0c1357219c0d74ce6a2948cb2def")
+
+ FetchContent_Declare(portaudio
+ URL
+ ${portaudio_URL}
+ ${portaudio_URL2}
+ URL_HASH ${portaudio_HASH}
+ )
+
+ FetchContent_GetProperties(portaudio)
+ if(NOT portaudio_POPULATED)
+ message(STATUS "Downloading portaudio from ${portaudio_URL}")
+ FetchContent_Populate(portaudio)
+ endif()
+ message(STATUS "portaudio is downloaded to ${portaudio_SOURCE_DIR}")
+ message(STATUS "portaudio's binary dir is ${portaudio_BINARY_DIR}")
+
+ add_subdirectory(${portaudio_SOURCE_DIR} ${portaudio_BINARY_DIR} EXCLUDE_FROM_ALL)
+ if(NOT WIN32)
+ target_compile_options(portaudio PRIVATE "-Wno-deprecated-declarations")
+ else()
+ install(TARGETS portaudio DESTINATION ..)
+ endif()
+
+endif()
+
# Include generated *.pb.h files
link_directories(${ONNXRUNTIME_DIR}/lib)
link_directories(${FFMPEG_DIR}/lib)
@@ -61,9 +93,9 @@
add_executable(funasr-wss-server "funasr-wss-server.cpp" "websocket-server.cpp")
add_executable(funasr-wss-server-2pass "funasr-wss-server-2pass.cpp" "websocket-server-2pass.cpp")
add_executable(funasr-wss-client "funasr-wss-client.cpp")
-add_executable(funasr-wss-client-2pass "funasr-wss-client-2pass.cpp")
+add_executable(funasr-wss-client-2pass "funasr-wss-client-2pass.cpp" "microphone.cpp")
target_link_libraries(funasr-wss-client PUBLIC funasr ssl crypto)
-target_link_libraries(funasr-wss-client-2pass PUBLIC funasr ssl crypto)
+target_link_libraries(funasr-wss-client-2pass PUBLIC funasr ssl crypto portaudio)
target_link_libraries(funasr-wss-server PUBLIC funasr ssl crypto)
target_link_libraries(funasr-wss-server-2pass PUBLIC funasr ssl crypto)
diff --git a/funasr/runtime/websocket/funasr-wss-client-2pass.cpp b/funasr/runtime/websocket/funasr-wss-client-2pass.cpp
index e52e316..9010c86 100644
--- a/funasr/runtime/websocket/funasr-wss-client-2pass.cpp
+++ b/funasr/runtime/websocket/funasr-wss-client-2pass.cpp
@@ -17,6 +17,7 @@
#define ASIO_STANDALONE 1
#include <glog/logging.h>
+#include "portaudio.h"
#include <atomic>
#include <fstream>
@@ -30,6 +31,7 @@
#include "audio.h"
#include "nlohmann/json.hpp"
#include "tclap/CmdLine.h"
+#include "microphone.h"
/**
* Define a semi-cross platform helper method that waits/sleeps for a bit.
@@ -123,7 +125,6 @@
if (ec) {
LOG(ERROR) << "Error closing connection " << ec.message();
}
-
}
}
}
@@ -131,7 +132,7 @@
// This method will block until the connection is complete
void run(const std::string& uri, const std::vector<string>& wav_list,
const std::vector<string>& wav_ids, std::string asr_mode,
- std::vector<int> chunk_size) {
+ std::vector<int> chunk_size, bool is_record=false) {
// Create a new connection to the given URI
websocketpp::lib::error_code ec;
typename websocketpp::client<T>::connection_ptr con =
@@ -152,8 +153,11 @@
// Create a thread to run the ASIO io_service event loop
websocketpp::lib::thread asio_thread(&websocketpp::client<T>::run,
&m_client);
-
- send_wav_data(wav_list[0], wav_ids[0], asr_mode, chunk_size);
+ if(is_record){
+ send_rec_data(asr_mode, chunk_size);
+ }else{
+ send_wav_data(wav_list[0], wav_ids[0], asr_mode, chunk_size);
+ }
WaitABit();
@@ -264,16 +268,11 @@
send_block = len - offset;
}
m_client.send(m_hdl, iArray + offset, send_block * sizeof(short),
- websocketpp::frame::opcode::binary, ec);
+ websocketpp::frame::opcode::binary, ec);
offset += send_block;
}
LOG(INFO) << "sended data len=" << len * sizeof(short);
- // The most likely error that we will get is that the connection is
- // not in the right state. Usually this means we tried to send a
- // message to a connection that was closed or in the process of
- // closing. While many errors here can be easily recovered from,
- // in this simple example, we'll stop the data loop.
if (ec) {
m_client.get_alog().write(websocketpp::log::alevel::app,
"Send Error: " + ec.message());
@@ -300,11 +299,6 @@
}
LOG(INFO) << "sended data len=" << len;
- // The most likely error that we will get is that the connection is
- // not in the right state. Usually this means we tried to send a
- // message to a connection that was closed or in the process of
- // closing. While many errors here can be easily recovered from,
- // in this simple example, we'll stop the data loop.
if (ec) {
m_client.get_alog().write(websocketpp::log::alevel::app,
"Send Error: " + ec.message());
@@ -317,6 +311,137 @@
ec);
WaitABit();
}
+
+ static int RecordCallback(const void* inputBuffer, void* outputBuffer,
+ unsigned long framesPerBuffer, const PaStreamCallbackTimeInfo* timeInfo,
+ PaStreamCallbackFlags statusFlags, void* userData)
+ {
+ std::vector<float>* buffer = static_cast<std::vector<float>*>(userData);
+ const float* input = static_cast<const float*>(inputBuffer);
+
+ for (unsigned int i = 0; i < framesPerBuffer; i++)
+ {
+ buffer->push_back(input[i]);
+ }
+
+ return paContinue;
+ }
+
+ void send_rec_data(std::string asr_mode, std::vector<int> chunk_vector) {
+ // first message
+ bool wait = false;
+ while (1) {
+ {
+ scoped_lock guard(m_lock);
+ // If the connection has been closed, stop generating data
+ if (m_done) {
+ break;
+ }
+ // If the connection hasn't been opened yet wait a bit and retry
+ if (!m_open) {
+ wait = true;
+ } else {
+ break;
+ }
+ }
+
+ if (wait) {
+ // LOG(INFO) << "wait.." << m_open;
+ WaitABit();
+ continue;
+ }
+ }
+ websocketpp::lib::error_code ec;
+
+ nlohmann::json jsonbegin;
+ nlohmann::json chunk_size = nlohmann::json::array();
+ chunk_size.push_back(chunk_vector[0]);
+ chunk_size.push_back(chunk_vector[1]);
+ chunk_size.push_back(chunk_vector[2]);
+ jsonbegin["mode"] = asr_mode;
+ jsonbegin["chunk_size"] = chunk_size;
+ jsonbegin["wav_name"] = "record";
+ jsonbegin["wav_format"] = "pcm";
+ jsonbegin["is_speaking"] = true;
+ m_client.send(m_hdl, jsonbegin.dump(), websocketpp::frame::opcode::text,
+ ec);
+ // mic
+ Microphone mic;
+ PaDeviceIndex num_devices = Pa_GetDeviceCount();
+ LOG(INFO) << "Num devices: " << num_devices;
+
+ PaStreamParameters param;
+
+ param.device = Pa_GetDefaultInputDevice();
+ if (param.device == paNoDevice) {
+ LOG(INFO) << "No default input device found";
+ exit(EXIT_FAILURE);
+ }
+ LOG(INFO) << "Use default device: " << param.device;
+
+ const PaDeviceInfo *info = Pa_GetDeviceInfo(param.device);
+ LOG(INFO) << " Name: " << info->name;
+ LOG(INFO) << " Max input channels: " << info->maxInputChannels;
+
+ param.channelCount = 1;
+ param.sampleFormat = paFloat32;
+
+ param.suggestedLatency = info->defaultLowInputLatency;
+ param.hostApiSpecificStreamInfo = nullptr;
+ float sample_rate = 16000;
+
+ PaStream *stream;
+ std::vector<float> buffer;
+ PaError err =
+ Pa_OpenStream(&stream, ¶m, nullptr, /* &outputParameters, */
+ sample_rate,
+ 0, // frames per buffer
+ paClipOff, // we won't output out of range samples
+ // so don't bother clipping them
+ RecordCallback, &buffer);
+ if (err != paNoError) {
+ LOG(ERROR) << "portaudio error: " << Pa_GetErrorText(err);
+ exit(EXIT_FAILURE);
+ }
+
+ err = Pa_StartStream(stream);
+ LOG(INFO) << "Started: ";
+
+ if (err != paNoError) {
+ LOG(ERROR) << "portaudio error: " << Pa_GetErrorText(err);
+ exit(EXIT_FAILURE);
+ }
+
+ while(true){
+ int len = buffer.size();
+ short* iArray = new short[len];
+ for (size_t i = 0; i < len; ++i) {
+ iArray[i] = (short)(buffer[i] * 32768);
+ }
+
+ m_client.send(m_hdl, iArray, len * sizeof(short),
+ websocketpp::frame::opcode::binary, ec);
+ buffer.clear();
+
+ if (ec) {
+ m_client.get_alog().write(websocketpp::log::alevel::app,
+ "Send Error: " + ec.message());
+ }
+ Pa_Sleep(20); // sleep for 20ms
+ }
+
+ nlohmann::json jsonresult;
+ jsonresult["is_speaking"] = false;
+ m_client.send(m_hdl, jsonresult.dump(), websocketpp::frame::opcode::text,
+ ec);
+
+ err = Pa_CloseStream(stream);
+ if (err != paNoError) {
+ LOG(INFO) << "portaudio error: " << Pa_GetErrorText(err);
+ exit(EXIT_FAILURE);
+ }
+ }
+
websocketpp::client<T> m_client;
private:
@@ -331,7 +456,7 @@
google::InitGoogleLogging(argv[0]);
FLAGS_logtostderr = true;
- TCLAP::CmdLine cmd("funasr-wss-client", ' ', "1.0");
+ TCLAP::CmdLine cmd("funasr-wss-client-2pass", ' ', "1.0");
TCLAP::ValueArg<std::string> server_ip_("", "server-ip", "server-ip", true,
"127.0.0.1", "string");
TCLAP::ValueArg<std::string> port_("", "port", "port", true, "10095",
@@ -340,7 +465,11 @@
"", "wav-path",
"the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: "
"asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)",
- true, "", "string");
+ false, "", "string");
+ TCLAP::ValueArg<int> record_(
+ "", "record",
+ "record is 1 means use record", false, 0,
+ "int");
TCLAP::ValueArg<std::string> asr_mode_("", ASR_MODE, "offline, online, 2pass",
false, "2pass", "string");
TCLAP::ValueArg<std::string> chunk_size_("", "chunk-size",
@@ -357,6 +486,7 @@
cmd.add(port_);
cmd.add(wav_path_);
cmd.add(asr_mode_);
+ cmd.add(record_);
cmd.add(chunk_size_);
cmd.add(thread_num_);
cmd.add(is_ssl_);
@@ -382,6 +512,7 @@
int threads_num = thread_num_.getValue();
int is_ssl = is_ssl_.getValue();
+ int is_record = record_.getValue();
std::string uri = "";
if (is_ssl == 1) {
@@ -390,60 +521,78 @@
uri = "ws://" + server_ip + ":" + port;
}
- // read wav_path
- std::vector<string> wav_list;
- std::vector<string> wav_ids;
- string default_id = "wav_default_id";
- if (IsTargetFile(wav_path, "scp")) {
- ifstream in(wav_path);
- if (!in.is_open()) {
- printf("Failed to open scp file");
- return 0;
- }
- string line;
- while (getline(in, line)) {
- istringstream iss(line);
- string column1, column2;
- iss >> column1 >> column2;
- wav_list.emplace_back(column2);
- wav_ids.emplace_back(column1);
- }
- in.close();
- } else {
- wav_list.emplace_back(wav_path);
- wav_ids.emplace_back(default_id);
- }
-
- for (size_t wav_i = 0; wav_i < wav_list.size(); wav_i = wav_i + threads_num) {
- std::vector<websocketpp::lib::thread> client_threads;
- for (size_t i = 0; i < threads_num; i++) {
- if (wav_i + i >= wav_list.size()) {
- break;
- }
+ if(is_record == 1){
std::vector<string> tmp_wav_list;
std::vector<string> tmp_wav_ids;
- tmp_wav_list.emplace_back(wav_list[wav_i + i]);
- tmp_wav_ids.emplace_back(wav_ids[wav_i + i]);
+ if (is_ssl == 1) {
+ WebsocketClient<websocketpp::config::asio_tls_client> c(is_ssl);
- client_threads.emplace_back(
- [uri, tmp_wav_list, tmp_wav_ids, asr_mode, chunk_size, is_ssl]() {
- if (is_ssl == 1) {
- WebsocketClient<websocketpp::config::asio_tls_client> c(is_ssl);
+ c.m_client.set_tls_init_handler(bind(&OnTlsInit, ::_1));
- c.m_client.set_tls_init_handler(bind(&OnTlsInit, ::_1));
+ c.run(uri, tmp_wav_list, tmp_wav_ids, asr_mode, chunk_size, true);
+ } else {
+ WebsocketClient<websocketpp::config::asio_client> c(is_ssl);
- c.run(uri, tmp_wav_list, tmp_wav_ids, asr_mode, chunk_size);
- } else {
- WebsocketClient<websocketpp::config::asio_client> c(is_ssl);
+ c.run(uri, tmp_wav_list, tmp_wav_ids, asr_mode, chunk_size, true);
+ }
- c.run(uri, tmp_wav_list, tmp_wav_ids, asr_mode, chunk_size);
- }
- });
+ }else{
+ // read wav_path
+ std::vector<string> wav_list;
+ std::vector<string> wav_ids;
+ string default_id = "wav_default_id";
+ if (IsTargetFile(wav_path, "scp")) {
+ ifstream in(wav_path);
+ if (!in.is_open()) {
+ printf("Failed to open scp file");
+ return 0;
+ }
+ string line;
+ while (getline(in, line)) {
+ istringstream iss(line);
+ string column1, column2;
+ iss >> column1 >> column2;
+ wav_list.emplace_back(column2);
+ wav_ids.emplace_back(column1);
+ }
+ in.close();
+ } else {
+ wav_list.emplace_back(wav_path);
+ wav_ids.emplace_back(default_id);
}
- for (auto& t : client_threads) {
- t.join();
+ for (size_t wav_i = 0; wav_i < wav_list.size(); wav_i = wav_i + threads_num) {
+ std::vector<websocketpp::lib::thread> client_threads;
+ for (size_t i = 0; i < threads_num; i++) {
+ if (wav_i + i >= wav_list.size()) {
+ break;
+ }
+ std::vector<string> tmp_wav_list;
+ std::vector<string> tmp_wav_ids;
+
+ tmp_wav_list.emplace_back(wav_list[wav_i + i]);
+ tmp_wav_ids.emplace_back(wav_ids[wav_i + i]);
+
+ client_threads.emplace_back(
+ [uri, tmp_wav_list, tmp_wav_ids, asr_mode, chunk_size, is_ssl]() {
+ if (is_ssl == 1) {
+ WebsocketClient<websocketpp::config::asio_tls_client> c(is_ssl);
+
+ c.m_client.set_tls_init_handler(bind(&OnTlsInit, ::_1));
+
+ c.run(uri, tmp_wav_list, tmp_wav_ids, asr_mode, chunk_size);
+ } else {
+ WebsocketClient<websocketpp::config::asio_client> c(is_ssl);
+
+ c.run(uri, tmp_wav_list, tmp_wav_ids, asr_mode, chunk_size);
+ }
+ });
+ }
+
+ for (auto& t : client_threads) {
+ t.join();
+ }
}
}
}
diff --git a/funasr/runtime/websocket/microphone.cpp b/funasr/runtime/websocket/microphone.cpp
new file mode 100644
index 0000000..c8b7d5f
--- /dev/null
+++ b/funasr/runtime/websocket/microphone.cpp
@@ -0,0 +1,27 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
+ * Reserved. MIT License (https://opensource.org/licenses/MIT)
+ */
+
+#include "microphone.h"
+
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "portaudio.h" // NOLINT
+
+Microphone::Microphone() {
+ PaError err = Pa_Initialize();
+ if (err != paNoError) {
+ LOG(ERROR)<<"portaudio error: " << Pa_GetErrorText(err);
+ exit(-1);
+ }
+}
+
+Microphone::~Microphone() {
+ PaError err = Pa_Terminate();
+ if (err != paNoError) {
+ LOG(ERROR)<<"portaudio error: " << Pa_GetErrorText(err);
+ exit(-1);
+ }
+}
diff --git a/funasr/runtime/websocket/microphone.h b/funasr/runtime/websocket/microphone.h
new file mode 100644
index 0000000..250a815
--- /dev/null
+++ b/funasr/runtime/websocket/microphone.h
@@ -0,0 +1,16 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
+ * Reserved. MIT License (https://opensource.org/licenses/MIT)
+ */
+
+#ifndef WEBSOCKET_MICROPHONE_H_
+#define WEBSOCKET_MICROPHONE_H_
+#include <glog/logging.h>
+
+class Microphone {
+ public:
+ Microphone();
+ ~Microphone();
+};
+
+#endif // WEBSOCKET_MICROPHONE_H_
--
Gitblit v1.9.1