From a539392ad48f9d03696587cb49ac595782a6f95f Mon Sep 17 00:00:00 2001
From: lyblsgo <lyblsgo@163.com>
Date: Thu, 27 Apr 2023 15:17:24 +0800
Subject: [PATCH] fix paraformer server for new apis
---
funasr/runtime/grpc/CMakeLists.txt | 21 +++--
funasr/runtime/onnxruntime/include/com-define.h | 1
funasr/runtime/grpc/Readme.md | 44 ++++++++--
funasr/runtime/onnxruntime/src/libfunasrapi.cpp | 6 +
funasr/runtime/grpc/paraformer-server.cc | 120 +++++++++++++++++++++--------
funasr/runtime/onnxruntime/src/precomp.h | 5
funasr/runtime/grpc/paraformer-server.h | 9 +
funasr/runtime/onnxruntime/include/libfunasrapi.h | 11 ++
8 files changed, 158 insertions(+), 59 deletions(-)
diff --git a/funasr/runtime/grpc/CMakeLists.txt b/funasr/runtime/grpc/CMakeLists.txt
index c7727d5..98c4787 100644
--- a/funasr/runtime/grpc/CMakeLists.txt
+++ b/funasr/runtime/grpc/CMakeLists.txt
@@ -42,16 +42,22 @@
"${rg_proto}"
DEPENDS "${rg_proto}")
-
# Include generated *.pb.h files
include_directories("${CMAKE_CURRENT_BINARY_DIR}")
-include_directories(../onnxruntime/include/)
-link_directories(../onnxruntime/build/src/)
-link_directories(../onnxruntime/build/third_party/yaml-cpp/)
-
link_directories(${ONNXRUNTIME_DIR}/lib)
+
+include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/include/)
+include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/yaml-cpp/include/)
+include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/kaldi-native-fbank)
+
+add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/yaml-cpp yaml-cpp)
+add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/kaldi-native-fbank/kaldi-native-fbank/csrc csrc)
add_subdirectory("../onnxruntime/src" onnx_src)
+
+include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog)
+set(BUILD_TESTING OFF)
+add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog glog)
# rg_grpc_proto
add_library(rg_grpc_proto
@@ -60,16 +66,13 @@
${rg_proto_srcs}
${rg_proto_hdrs})
-
-
target_link_libraries(rg_grpc_proto
${_REFLECTION}
${_GRPC_GRPCPP}
${_PROTOBUF_LIBPROTOBUF})
-# Targets paraformer_(server)
foreach(_target
- paraformer_server)
+ paraformer-server)
add_executable(${_target}
"${_target}.cc")
target_link_libraries(${_target}
diff --git a/funasr/runtime/grpc/Readme.md b/funasr/runtime/grpc/Readme.md
index 23e618c..da92559 100644
--- a/funasr/runtime/grpc/Readme.md
+++ b/funasr/runtime/grpc/Readme.md
@@ -4,15 +4,6 @@
### Build [onnxruntime](./onnxruntime_cpp.md) as it's document
-```
-#put onnx-lib & onnx-asr-model into /path/to/asrmodel(eg: /data/asrmodel)
-ls /data/asrmodel/
-onnxruntime-linux-x64-1.14.0 speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch
-
-#make sure you have config.yaml, am.mvn, model.onnx(or model_quant.onnx) under speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch
-
-```
-
### Compile and install grpc v1.52.0 in case of grpc bugs
```
export GRPC_INSTALL_DIR=/data/soft/grpc
@@ -46,8 +37,39 @@
### Start grpc paraformer server
```
-Usage: ./cmake/build/paraformer_server port thread_num /path/to/model_file quantize(true or false)
-./cmake/build/paraformer_server 10108 4 /data/asrmodel/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch false
+./cmake/build/paraformer-server --port-id <string> [--punc-config
+ <string>] [--punc-model <string>]
+ --am-config <string> --am-cmvn <string>
+ --am-model <string> [--vad-config
+ <string>] [--vad-cmvn <string>]
+ [--vad-model <string>] [--] [--version]
+ [-h]
+Where:
+ --port-id <string>
+ (required) port id
+
+ --am-config <string>
+ (required) am config path
+ --am-cmvn <string>
+ (required) am cmvn path
+ --am-model <string>
+ (required) am model path
+
+ --punc-config <string>
+ punc config path
+ --punc-model <string>
+ punc model path
+
+ --vad-config <string>
+ vad config path
+ --vad-cmvn <string>
+ vad cmvn path
+ --vad-model <string>
+ vad model path
+
+ Required: --port-id <string> --am-config <string> --am-cmvn <string> --am-model <string>
+ To enable VAD, also pass: [--vad-config <string>] [--vad-cmvn <string>] [--vad-model <string>]
+ To enable punctuation, also pass: [--punc-config <string>] [--punc-model <string>]
```
## For the client
diff --git a/funasr/runtime/grpc/paraformer_server.cc b/funasr/runtime/grpc/paraformer-server.cc
similarity index 65%
rename from funasr/runtime/grpc/paraformer_server.cc
rename to funasr/runtime/grpc/paraformer-server.cc
index 2893d4c..31333c9 100644
--- a/funasr/runtime/grpc/paraformer_server.cc
+++ b/funasr/runtime/grpc/paraformer-server.cc
@@ -13,7 +13,10 @@
#include <grpcpp/security/server_credentials.h>
#include "paraformer.grpc.pb.h"
-#include "paraformer_server.h"
+#include "paraformer-server.h"
+#include "tclap/CmdLine.h"
+#include "com-define.h"
+#include "glog/logging.h"
using grpc::Server;
using grpc::ServerBuilder;
@@ -27,10 +30,32 @@
using paraformer::Response;
using paraformer::ASR;
-ASRServicer::ASRServicer(const char* model_path, int thread_num, bool quantize) {
- AsrHanlde=FunASRInit(model_path, thread_num, quantize);
+ASRServicer::ASRServicer(std::map<std::string, std::string>& model_path) {  // model_path: option name -> value, as collected in main()
+ AsrHanlde=FunASRInit(model_path, 1);  // second argument is thread_num; fixed at 1 now that the thread_num CLI argument is gone
std::cout << "ASRServicer init" << std::endl;
init_flag = 0;
+}
+
+void ASRServicer::clear_states(const std::string& user) {  // Forget everything this servicer stores for `user`.
+ clear_buffers(user);  // entry in client_buffers (accumulated audio_data)
+ clear_transcriptions(user);  // entry in client_transcription
+}
+
+void ASRServicer::clear_buffers(const std::string& user) {
+    // Drop the audio buffered for `user`; erase(key) is a no-op for an absent key.
+    // NOTE(review): client_buffers is now a servicer member and no lock is visible here; confirm concurrent Recognize calls are safe.
+    client_buffers.erase(user);
+}
+
+void ASRServicer::clear_transcriptions(const std::string& user) {
+    // Forget the transcription stored for `user`, if any. erase(key) does
+    // nothing when the key is absent, so the count() pre-check is dropped.
+    client_transcription.erase(user);
+}
+
+void ASRServicer::disconnect(const std::string& user) {  // Invoked from Recognize when a request has isend() set.
+ clear_states(user);  // drop the user's buffered audio and transcription entries
+ std::cout << "Disconnecting user: " << user << std::endl;
}
grpc::Status ASRServicer::Recognize(
@@ -38,20 +63,10 @@
grpc::ServerReaderWriter<Response, Request>* stream) {
Request req;
- std::unordered_map<std::string, std::string> client_buffers;
- std::unordered_map<std::string, std::string> client_transcription;
-
while (stream->Read(&req)) {
if (req.isend()) {
std::cout << "asr end" << std::endl;
- // disconnect
- if (client_buffers.count(req.user())) {
- client_buffers.erase(req.user());
- }
- if (client_transcription.count(req.user())) {
- client_transcription.erase(req.user());
- }
-
+ disconnect(req.user());
Response res;
res.set_sentence(
R"({"success": true, "detail": "asr end"})"
@@ -89,14 +104,8 @@
auto& buf = client_buffers[req.user()];
buf.insert(buf.end(), req.audio_data().begin(), req.audio_data().end());
}
- std::string tmp_data = client_buffers[req.user()];
- // clear_states
- if (client_buffers.count(req.user())) {
- client_buffers.erase(req.user());
- }
- if (client_transcription.count(req.user())) {
- client_transcription.erase(req.user());
- }
+ std::string tmp_data = this->client_buffers[req.user()];
+ this->clear_states(req.user());
Response res;
res.set_sentence(
@@ -161,10 +170,17 @@
return Status::OK;
}
-void RunServer(const std::string& port, int thread_num, const char* model_path, bool quantize) {
+void RunServer(std::map<std::string, std::string>& model_path) {
+ std::string port;
+ try{
+ port = model_path.at(PORT_ID);
+ }catch(std::exception const &e){
+ printf("Error when read port.\n");
+ exit(0);
+ }
std::string server_address;
server_address = "0.0.0.0:" + port;
- ASRServicer service(model_path, thread_num, quantize);
+ ASRServicer service(model_path);
ServerBuilder builder;
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
@@ -174,16 +190,54 @@
server->Wait();
}
-int main(int argc, char* argv[]) {
- if (argc < 5)
- {
- printf("Usage: %s port thread_num /path/to/model_file quantize(true or false) \n", argv[0]);
- exit(-1);
+void GetValue(TCLAP::ValueArg<std::string>& value_arg, std::string key, std::map<std::string, std::string>& model_path)
+{  // Copies the option's value into model_path under `key`, but only when the option was given on the command line.
+ if (value_arg.isSet()){
+ model_path.insert({key, value_arg.getValue()});  // insert() leaves an already-present entry for `key` unchanged
+ LOG(INFO)<< key << " : " << value_arg.getValue();
}
+}
- // is quantize
- bool quantize = false;
- std::istringstream(argv[4]) >> std::boolalpha >> quantize;
- RunServer(argv[1], atoi(argv[2]), argv[3], quantize);
+int main(int argc, char* argv[]) {
+ // Initialise glog and set its logtostderr flag.
+ google::InitGoogleLogging(argv[0]);
+ FLAGS_logtostderr = true;
+ // Command-line parser, followed by the VAD options (not required).
+ TCLAP::CmdLine cmd("paraformer-server", ' ', "1.0");
+ TCLAP::ValueArg<std::string> vad_model("", VAD_MODEL_PATH, "vad model path", false, "", "string");
+ TCLAP::ValueArg<std::string> vad_cmvn("", VAD_CMVN_PATH, "vad cmvn path", false, "", "string");
+ TCLAP::ValueArg<std::string> vad_config("", VAD_CONFIG_PATH, "vad config path", false, "", "string");
+ // Acoustic-model (am) options: all three are required.
+ TCLAP::ValueArg<std::string> am_model("", AM_MODEL_PATH, "am model path", true, "", "string");
+ TCLAP::ValueArg<std::string> am_cmvn("", AM_CMVN_PATH, "am cmvn path", true, "", "string");
+ TCLAP::ValueArg<std::string> am_config("", AM_CONFIG_PATH, "am config path", true, "", "string");
+ // Punctuation options are not required; the listening port is.
+ TCLAP::ValueArg<std::string> punc_model("", PUNC_MODEL_PATH, "punc model path", false, "", "string");
+ TCLAP::ValueArg<std::string> punc_config("", PUNC_CONFIG_PATH, "punc config path", false, "", "string");
+ TCLAP::ValueArg<std::string> port_id("", PORT_ID, "port id", true, "", "string");
+ // Register every option with the parser, then parse argv.
+ cmd.add(vad_model);
+ cmd.add(vad_cmvn);
+ cmd.add(vad_config);
+ cmd.add(am_model);
+ cmd.add(am_cmvn);
+ cmd.add(am_config);
+ cmd.add(punc_model);
+ cmd.add(punc_config);
+ cmd.add(port_id);
+ cmd.parse(argc, argv);
+ // Collect the options that were actually supplied into an option-name -> value map.
+ std::map<std::string, std::string> model_path;
+ GetValue(vad_model, VAD_MODEL_PATH, model_path);
+ GetValue(vad_cmvn, VAD_CMVN_PATH, model_path);
+ GetValue(vad_config, VAD_CONFIG_PATH, model_path);
+ GetValue(am_model, AM_MODEL_PATH, model_path);
+ GetValue(am_cmvn, AM_CMVN_PATH, model_path);
+ GetValue(am_config, AM_CONFIG_PATH, model_path);
+ GetValue(punc_model, PUNC_MODEL_PATH, model_path);
+ GetValue(punc_config, PUNC_CONFIG_PATH, model_path);
+ GetValue(port_id, PORT_ID, model_path);
+ // RunServer reads the port from the same map (key PORT_ID) and builds the servicer from it.
+ RunServer(model_path);
return 0;
}
diff --git a/funasr/runtime/grpc/paraformer_server.h b/funasr/runtime/grpc/paraformer-server.h
similarity index 70%
rename from funasr/runtime/grpc/paraformer_server.h
rename to funasr/runtime/grpc/paraformer-server.h
index dba1e45..108e3b6 100644
--- a/funasr/runtime/grpc/paraformer_server.h
+++ b/funasr/runtime/grpc/paraformer-server.h
@@ -37,13 +37,18 @@
float snippet_time;
}FUNASR_RECOG_RESULT;
-
class ASRServicer final : public ASR::Service {
private:
int init_flag;
+ std::unordered_map<std::string, std::string> client_buffers;
+ std::unordered_map<std::string, std::string> client_transcription;
public:
- ASRServicer(const char* model_path, int thread_num, bool quantize);
+ ASRServicer(std::map<std::string, std::string>& model_path);
+ void clear_states(const std::string& user);
+ void clear_buffers(const std::string& user);
+ void clear_transcriptions(const std::string& user);
+ void disconnect(const std::string& user);
grpc::Status Recognize(grpc::ServerContext* context, grpc::ServerReaderWriter<Response, Request>* stream);
FUNASR_HANDLE AsrHanlde;
diff --git a/funasr/runtime/onnxruntime/include/com-define.h b/funasr/runtime/onnxruntime/include/com-define.h
index 8c88517..9b7b212 100644
--- a/funasr/runtime/onnxruntime/include/com-define.h
+++ b/funasr/runtime/onnxruntime/include/com-define.h
@@ -24,6 +24,7 @@
#define WAV_PATH "wav-path"
#define WAV_SCP "wav-scp"
#define THREAD_NUM "thread-num"
+#define PORT_ID "port-id"
// vad
#ifndef VAD_SILENCE_DURATION
diff --git a/funasr/runtime/onnxruntime/include/libfunasrapi.h b/funasr/runtime/onnxruntime/include/libfunasrapi.h
index 8dca7f4..f65efcc 100644
--- a/funasr/runtime/onnxruntime/include/libfunasrapi.h
+++ b/funasr/runtime/onnxruntime/include/libfunasrapi.h
@@ -47,10 +47,9 @@
typedef void (* QM_CALLBACK)(int cur_step, int n_total); // n_total: total steps; cur_step: Current Step.
-// APIs for funasr
+// ASR
_FUNASRAPI FUNASR_HANDLE FunASRInit(std::map<std::string, std::string>& model_path, int thread_num);
-// if not give a fn_callback ,it should be NULL
_FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback);
_FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback);
_FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* sz_filename, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback);
@@ -62,6 +61,14 @@
_FUNASRAPI void FunASRUninit(FUNASR_HANDLE handle);
_FUNASRAPI const float FunASRGetRetSnippetTime(FUNASR_RESULT result);
+// VAD
+_FUNASRAPI FUNASR_HANDLE FunVadInit(std::map<std::string, std::string>& model_path, int thread_num);
+
+_FUNASRAPI FUNASR_RESULT FunASRVadBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback);
+_FUNASRAPI FUNASR_RESULT FunASRVadPCMBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback);
+_FUNASRAPI FUNASR_RESULT FunASRVadPCMFile(FUNASR_HANDLE handle, const char* sz_filename, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback);
+_FUNASRAPI FUNASR_RESULT FunASRVadFile(FUNASR_HANDLE handle, const char* sz_wavfile, FUNASR_MODE mode, QM_CALLBACK fn_callback);
+
#ifdef __cplusplus
}
diff --git a/funasr/runtime/onnxruntime/src/libfunasrapi.cpp b/funasr/runtime/onnxruntime/src/libfunasrapi.cpp
index 93434bb..01aa38a 100644
--- a/funasr/runtime/onnxruntime/src/libfunasrapi.cpp
+++ b/funasr/runtime/onnxruntime/src/libfunasrapi.cpp
@@ -11,6 +11,12 @@
return mm;
}
+ _FUNASRAPI FUNASR_HANDLE FunVadInit(std::map<std::string, std::string>& model_path, int thread_num)
+ {  // Returns the created Model* as an opaque FUNASR_HANDLE.
+ Model* mm = CreateModel(model_path, thread_num);  // NOTE(review): CreateModel is not VAD-specific by name; confirm it builds the VAD model here
+ return mm;
+ }
+
_FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback)
{
Model* recog_obj = (Model*)handle;
diff --git a/funasr/runtime/onnxruntime/src/precomp.h b/funasr/runtime/onnxruntime/src/precomp.h
index cf69ad9..68e0fe8 100644
--- a/funasr/runtime/onnxruntime/src/precomp.h
+++ b/funasr/runtime/onnxruntime/src/precomp.h
@@ -21,8 +21,8 @@
// third part
#include "onnxruntime_run_options_config_keys.h"
#include "onnxruntime_cxx_api.h"
-#include <kaldi-native-fbank/csrc/feature-fbank.h>
-#include <kaldi-native-fbank/csrc/online-feature.h>
+#include "kaldi-native-fbank/csrc/feature-fbank.h"
+#include "kaldi-native-fbank/csrc/online-feature.h"
// mine
#include <glog/logging.h>
@@ -40,6 +40,7 @@
#include "util.h"
#include "resample.h"
#include "model.h"
+#include "vad-model.h"
#include "paraformer.h"
#include "libfunasrapi.h"
--
Gitblit v1.9.1