/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
 * Reserved. MIT License  (https://opensource.org/licenses/MIT)
 */
/* 2023 by burkliu(刘柏基) liubaiji@xverse.cn */

#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#include <grpc/grpc.h>
#include <grpcpp/server.h>
#include <grpcpp/server_builder.h>
#include <grpcpp/server_context.h>
#include <grpcpp/security/server_credentials.h>

#include "paraformer.grpc.pb.h"
#include "paraformer-server.h"
#include "tclap/CmdLine.h"
#include "com-define.h"
#include "glog/logging.h"

using grpc::Server;
using grpc::ServerBuilder;
using grpc::ServerContext;
using grpc::ServerReader;
using grpc::ServerReaderWriter;
using grpc::ServerWriter;
using grpc::Status;
| | | GrpcEngine::GrpcEngine( |
| | | grpc::ServerReaderWriter<Response, Request>* stream, |
| | | std::shared_ptr<FUNASR_HANDLE> asr_handler) |
| | | : stream_(std::move(stream)), |
| | | asr_handler_(std::move(asr_handler)) { |
| | | |
| | | using paraformer::Request; |
| | | using paraformer::Response; |
| | | using paraformer::ASR; |
| | | |
| | | ASRServicer::ASRServicer(std::map<std::string, std::string>& model_path) { |
| | | AsrHanlde=FunOfflineInit(model_path, 1); |
| | | std::cout << "ASRServicer init" << std::endl; |
| | | init_flag = 0; |
| | | request_ = std::make_shared<Request>(); |
| | | } |
| | | |
| | | void ASRServicer::clear_states(const std::string& user) { |
| | | clear_buffers(user); |
| | | clear_transcriptions(user); |
| | | } |
| | | void GrpcEngine::DecodeThreadFunc() { |
| | | FUNASR_HANDLE tpass_online_handler = FunTpassOnlineInit(*asr_handler_, chunk_size_); |
| | | int step = (sampling_rate_ * step_duration_ms_ / 1000) * 2; // int16 = 2bytes; |
| | | std::vector<std::vector<std::string>> punc_cache(2); |
| | | |
| | | void ASRServicer::clear_buffers(const std::string& user) { |
| | | if (client_buffers.count(user)) { |
| | | client_buffers.erase(user); |
| | | } |
| | | } |
| | | bool is_final = false; |
| | | std::string online_result = ""; |
| | | std::string tpass_result = ""; |
| | | |
| | | void ASRServicer::clear_transcriptions(const std::string& user) { |
| | | if (client_transcription.count(user)) { |
| | | client_transcription.erase(user); |
| | | } |
| | | } |
| | | LOG(INFO) << "Decoder init, start decoding loop with mode"; |
| | | |
| | | void ASRServicer::disconnect(const std::string& user) { |
| | | clear_states(user); |
| | | std::cout << "Disconnecting user: " << user << std::endl; |
| | | } |
| | | while (true) { |
| | | if (audio_buffer_.length() > step || is_end_) { |
| | | if (audio_buffer_.length() <= step && is_end_) { |
| | | is_final = true; |
| | | step = audio_buffer_.length(); |
| | | } |
| | | |
| | | grpc::Status ASRServicer::Recognize( |
| | | grpc::ServerContext* context, |
| | | grpc::ServerReaderWriter<Response, Request>* stream) { |
| | | FUNASR_RESULT result = FunTpassInferBuffer(*asr_handler_, |
| | | tpass_online_handler, |
| | | audio_buffer_.c_str(), |
| | | step, |
| | | punc_cache, |
| | | is_final, |
| | | sampling_rate_, |
| | | encoding_, |
| | | mode_); |
| | | audio_buffer_ = audio_buffer_.substr(step); |
| | | |
| | | Request req; |
| | | while (stream->Read(&req)) { |
| | | if (req.isend()) { |
| | | std::cout << "asr end" << std::endl; |
| | | disconnect(req.user()); |
| | | Response res; |
| | | res.set_sentence( |
| | | R"({"success": true, "detail": "asr end"})" |
| | | ); |
| | | res.set_user(req.user()); |
| | | res.set_action("terminate"); |
| | | res.set_language(req.language()); |
| | | stream->Write(res); |
| | | } else if (req.speaking()) { |
| | | if (req.audio_data().size() > 0) { |
| | | auto& buf = client_buffers[req.user()]; |
| | | buf.insert(buf.end(), req.audio_data().begin(), req.audio_data().end()); |
| | | } |
| | | Response res; |
| | | res.set_sentence( |
| | | R"({"success": true, "detail": "speaking"})" |
| | | ); |
| | | res.set_user(req.user()); |
| | | res.set_action("speaking"); |
| | | res.set_language(req.language()); |
| | | stream->Write(res); |
| | | } else if (!req.speaking()) { |
| | | if (client_buffers.count(req.user()) == 0 && req.audio_data().size() == 0) { |
| | | Response res; |
| | | res.set_sentence( |
| | | R"({"success": true, "detail": "waiting_for_voice"})" |
| | | ); |
| | | res.set_user(req.user()); |
| | | res.set_action("waiting"); |
| | | res.set_language(req.language()); |
| | | stream->Write(res); |
| | | }else { |
| | | auto begin_time = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()).count(); |
| | | if (req.audio_data().size() > 0) { |
| | | auto& buf = client_buffers[req.user()]; |
| | | buf.insert(buf.end(), req.audio_data().begin(), req.audio_data().end()); |
| | | } |
| | | std::string tmp_data = this->client_buffers[req.user()]; |
| | | this->clear_states(req.user()); |
| | | |
| | | Response res; |
| | | res.set_sentence( |
| | | R"({"success": true, "detail": "decoding data: " + std::to_string(tmp_data.length()) + " bytes"})" |
| | | ); |
| | | int data_len_int = tmp_data.length(); |
| | | std::string data_len = std::to_string(data_len_int); |
| | | std::stringstream ss; |
| | | ss << R"({"success": true, "detail": "decoding data: )" << data_len << R"( bytes")" << R"("})"; |
| | | std::string result = ss.str(); |
| | | res.set_sentence(result); |
| | | res.set_user(req.user()); |
| | | res.set_action("decoding"); |
| | | res.set_language(req.language()); |
| | | stream->Write(res); |
| | | if (tmp_data.length() < 800) { //min input_len for asr model |
| | | auto end_time = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()).count(); |
| | | std::string delay_str = std::to_string(end_time - begin_time); |
| | | std::cout << "user: " << req.user() << " , delay(ms): " << delay_str << ", error: data_is_not_long_enough" << std::endl; |
| | | Response res; |
| | | std::stringstream ss; |
| | | std::string asr_result = ""; |
| | | ss << R"({"success": true, "detail": "finish_sentence","server_delay_ms":)" << delay_str << R"(,"text":")" << asr_result << R"("})"; |
| | | std::string result = ss.str(); |
| | | res.set_sentence(result); |
| | | res.set_user(req.user()); |
| | | res.set_action("finish"); |
| | | res.set_language(req.language()); |
| | | stream->Write(res); |
| | | } |
| | | else { |
| | | FUNASR_RESULT Result= FunOfflineInferBuffer(AsrHanlde, tmp_data.c_str(), data_len_int, RASR_NONE, NULL, 16000); |
| | | std::string asr_result = ((FUNASR_RECOG_RESULT*)Result)->msg; |
| | | |
| | | auto end_time = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()).count(); |
| | | std::string delay_str = std::to_string(end_time - begin_time); |
| | | |
| | | std::cout << "user: " << req.user() << " , delay(ms): " << delay_str << ", text: " << asr_result << std::endl; |
| | | Response res; |
| | | std::stringstream ss; |
| | | ss << R"({"success": true, "detail": "finish_sentence","server_delay_ms":)" << delay_str << R"(,"text":")" << asr_result << R"("})"; |
| | | std::string result = ss.str(); |
| | | res.set_sentence(result); |
| | | res.set_user(req.user()); |
| | | res.set_action("finish"); |
| | | res.set_language(req.language()); |
| | | |
| | | stream->Write(res); |
| | | } |
| | | } |
| | | }else { |
| | | Response res; |
| | | res.set_sentence( |
| | | R"({"success": false, "detail": "error, no condition matched! Unknown reason."})" |
| | | ); |
| | | res.set_user(req.user()); |
| | | res.set_action("terminate"); |
| | | res.set_language(req.language()); |
| | | stream->Write(res); |
| | | if (result) { |
| | | std::string online_message = FunASRGetResult(result, 0); |
| | | online_result += online_message; |
| | | if(online_message != ""){ |
| | | Response response; |
| | | response.set_mode(DecodeMode::online); |
| | | response.set_text(online_message); |
| | | response.set_is_final(is_final); |
| | | stream_->Write(response); |
| | | LOG(INFO) << "send online results: " << online_message; |
| | | } |
| | | std::string tpass_message = FunASRGetTpassResult(result, 0); |
| | | tpass_result += tpass_message; |
| | | if(tpass_message != ""){ |
| | | Response response; |
| | | response.set_mode(DecodeMode::two_pass); |
| | | response.set_text(tpass_message); |
| | | response.set_is_final(is_final); |
| | | stream_->Write(response); |
| | | LOG(INFO) << "send offline results: " << tpass_message; |
| | | } |
| | | FunASRFreeResult(result); |
| | | } |
| | | |
| | | if (is_final) { |
| | | FunTpassOnlineUninit(tpass_online_handler); |
| | | break; |
| | | } |
| | | } |
| | | return Status::OK; |
| | | sleep(0.001); |
| | | } |
| | | } |
| | | |
| | | void RunServer(std::map<std::string, std::string>& model_path) { |
| | | std::string port; |
| | | try{ |
| | | port = model_path.at(PORT_ID); |
| | | }catch(std::exception const &e){ |
| | | printf("Error when read port.\n"); |
| | | exit(0); |
| | | void GrpcEngine::OnSpeechStart() { |
| | | if (request_->chunk_size_size() == 3) { |
| | | for (int i = 0; i < 3; i++) { |
| | | chunk_size_[i] = int(request_->chunk_size(i)); |
| | | } |
| | | std::string server_address; |
| | | server_address = "0.0.0.0:" + port; |
| | | ASRServicer service(model_path); |
| | | } |
| | | std::string chunk_size_str; |
| | | for (int i = 0; i < 3; i++) { |
| | | chunk_size_str = " " + chunk_size_[i]; |
| | | } |
| | | LOG(INFO) << "chunk_size is" << chunk_size_str; |
| | | |
| | | ServerBuilder builder; |
| | | builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); |
| | | builder.RegisterService(&service); |
| | | std::unique_ptr<Server> server(builder.BuildAndStart()); |
| | | std::cout << "Server listening on " << server_address << std::endl; |
| | | server->Wait(); |
| | | if (request_->sampling_rate() != 0) { |
| | | sampling_rate_ = request_->sampling_rate(); |
| | | } |
| | | LOG(INFO) << "sampling_rate is " << sampling_rate_; |
| | | |
| | | switch(request_->wav_format()) { |
| | | case WavFormat::pcm: encoding_ = "pcm"; |
| | | } |
| | | LOG(INFO) << "encoding is " << encoding_; |
| | | |
| | | std::string mode_str; |
| | | switch(request_->mode()) { |
| | | case DecodeMode::offline: |
| | | mode_ = ASR_OFFLINE; |
| | | mode_str = "offline"; |
| | | break; |
| | | case DecodeMode::online: |
| | | mode_ = ASR_ONLINE; |
| | | mode_str = "online"; |
| | | break; |
| | | case DecodeMode::two_pass: |
| | | mode_ = ASR_TWO_PASS; |
| | | mode_str = "two_pass"; |
| | | break; |
| | | } |
| | | LOG(INFO) << "decode mode is " << mode_str; |
| | | |
| | | decode_thread_ = std::make_shared<std::thread>(&GrpcEngine::DecodeThreadFunc, this); |
| | | is_start_ = true; |
| | | } |
| | | |
| | | void GetValue(TCLAP::ValueArg<std::string>& value_arg, std::string key, std::map<std::string, std::string>& model_path) |
| | | { |
| | | if (value_arg.isSet()){ |
| | | model_path.insert({key, value_arg.getValue()}); |
| | | LOG(INFO)<< key << " : " << value_arg.getValue(); |
| | | void GrpcEngine::OnSpeechData() { |
| | | audio_buffer_ += request_->audio_data(); |
| | | } |
| | | |
| | | void GrpcEngine::OnSpeechEnd() { |
| | | is_end_ = true; |
| | | LOG(INFO) << "Read all pcm data, wait for decoding thread"; |
| | | if (decode_thread_ != nullptr) { |
| | | decode_thread_->join(); |
| | | } |
| | | } |
| | | |
| | | void GrpcEngine::operator()() { |
| | | try { |
| | | LOG(INFO) << "start engine main loop"; |
| | | while (stream_->Read(request_.get())) { |
| | | LOG(INFO) << "receive data"; |
| | | if (!is_start_) { |
| | | OnSpeechStart(); |
| | | } |
| | | OnSpeechData(); |
| | | if (request_->is_final()) { |
| | | break; |
| | | } |
| | | } |
| | | OnSpeechEnd(); |
| | | LOG(INFO) << "Connect finish"; |
| | | } catch (std::exception const& e) { |
| | | LOG(ERROR) << e.what(); |
| | | } |
| | | } |
| | | |
| | | GrpcService::GrpcService(std::map<std::string, std::string>& config, int onnx_thread) |
| | | : config_(config) { |
| | | |
| | | asr_handler_ = std::make_shared<FUNASR_HANDLE>(std::move(FunTpassInit(config_, onnx_thread))); |
| | | LOG(INFO) << "GrpcService model loaded"; |
| | | |
| | | std::vector<int> chunk_size = {5, 10, 5}; |
| | | FUNASR_HANDLE tmp_online_handler = FunTpassOnlineInit(*asr_handler_, chunk_size); |
| | | int sampling_rate = 16000; |
| | | int buffer_len = sampling_rate * 1; |
| | | std::string tmp_data(buffer_len, '0'); |
| | | std::vector<std::vector<std::string>> punc_cache(2); |
| | | bool is_final = true; |
| | | std::string encoding = "pcm"; |
| | | FUNASR_RESULT result = FunTpassInferBuffer(*asr_handler_, |
| | | tmp_online_handler, |
| | | tmp_data.c_str(), |
| | | buffer_len, |
| | | punc_cache, |
| | | is_final, |
| | | buffer_len, |
| | | encoding, |
| | | ASR_TWO_PASS); |
| | | if (result) { |
| | | FunASRFreeResult(result); |
| | | } |
| | | FunTpassOnlineUninit(tmp_online_handler); |
| | | LOG(INFO) << "GrpcService model warmup"; |
| | | } |
| | | |
| | | grpc::Status GrpcService::Recognize( |
| | | grpc::ServerContext* context, |
| | | grpc::ServerReaderWriter<Response, Request>* stream) { |
| | | LOG(INFO) << "Get Recognize request"; |
| | | GrpcEngine engine( |
| | | stream, |
| | | asr_handler_ |
| | | ); |
| | | |
| | | std::thread t(std::move(engine)); |
| | | t.join(); |
| | | return grpc::Status::OK; |
| | | } |
| | | |
| | | void GetValue(TCLAP::ValueArg<std::string>& value_arg, std::string key, std::map<std::string, std::string>& config) { |
| | | if (value_arg.isSet()) { |
| | | config.insert({key, value_arg.getValue()}); |
| | | LOG(INFO) << key << " : " << value_arg.getValue(); |
| | | } |
| | | } |
| | | |
| | | int main(int argc, char* argv[]) { |
| | | FLAGS_logtostderr = true; |
| | | google::InitGoogleLogging(argv[0]); |
| | | |
| | | google::InitGoogleLogging(argv[0]); |
| | | FLAGS_logtostderr = true; |
| | | TCLAP::CmdLine cmd("funasr-onnx-2pass", ' ', "1.0"); |
| | | TCLAP::ValueArg<std::string> offline_model_dir("", OFFLINE_MODEL_DIR, "the asr offline model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string"); |
| | | TCLAP::ValueArg<std::string> online_model_dir("", ONLINE_MODEL_DIR, "the asr online model path, which contains encoder.onnx, decoder.onnx, config.yaml, am.mvn", true, "", "string"); |
| | | TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string"); |
| | | TCLAP::ValueArg<std::string> vad_dir("", VAD_DIR, "the vad online model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string"); |
| | | TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string"); |
| | | TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc online model path, which contains model.onnx, punc.yaml", false, "", "string"); |
| | | TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string"); |
| | | TCLAP::ValueArg<std::int32_t> onnx_thread("", "onnx-inter-thread", "onnxruntime SetIntraOpNumThreads", false, 1, "int32_t"); |
| | | TCLAP::ValueArg<std::string> port_id("", PORT_ID, "port id", true, "", "string"); |
| | | |
| | | TCLAP::CmdLine cmd("paraformer-server", ' ', "1.0"); |
| | | TCLAP::ValueArg<std::string> model_dir("", MODEL_DIR, "the asr model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string"); |
| | | TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string"); |
| | | TCLAP::ValueArg<std::string> vad_dir("", VAD_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string"); |
| | | TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "false", "string"); |
| | | TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc model path, which contains model.onnx, punc.yaml", false, "", "string"); |
| | | TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "false", "string"); |
| | | TCLAP::ValueArg<std::string> port_id("", PORT_ID, "port id", true, "", "string"); |
| | | cmd.add(offline_model_dir); |
| | | cmd.add(online_model_dir); |
| | | cmd.add(quantize); |
| | | cmd.add(vad_dir); |
| | | cmd.add(vad_quant); |
| | | cmd.add(punc_dir); |
| | | cmd.add(punc_quant); |
| | | cmd.add(onnx_thread); |
| | | cmd.add(port_id); |
| | | cmd.parse(argc, argv); |
| | | |
| | | cmd.add(model_dir); |
| | | cmd.add(quantize); |
| | | cmd.add(vad_dir); |
| | | cmd.add(vad_quant); |
| | | cmd.add(punc_dir); |
| | | cmd.add(punc_quant); |
| | | cmd.add(port_id); |
| | | cmd.parse(argc, argv); |
| | | std::map<std::string, std::string> config; |
| | | GetValue(offline_model_dir, OFFLINE_MODEL_DIR, config); |
| | | GetValue(online_model_dir, ONLINE_MODEL_DIR, config); |
| | | GetValue(quantize, QUANTIZE, config); |
| | | GetValue(vad_dir, VAD_DIR, config); |
| | | GetValue(vad_quant, VAD_QUANT, config); |
| | | GetValue(punc_dir, PUNC_DIR, config); |
| | | GetValue(punc_quant, PUNC_QUANT, config); |
| | | GetValue(port_id, PORT_ID, config); |
| | | |
| | | std::map<std::string, std::string> model_path; |
| | | GetValue(model_dir, MODEL_DIR, model_path); |
| | | GetValue(quantize, QUANTIZE, model_path); |
| | | GetValue(vad_dir, VAD_DIR, model_path); |
| | | GetValue(vad_quant, VAD_QUANT, model_path); |
| | | GetValue(punc_dir, PUNC_DIR, model_path); |
| | | GetValue(punc_quant, PUNC_QUANT, model_path); |
| | | GetValue(port_id, PORT_ID, model_path); |
| | | std::string port; |
| | | try { |
| | | port = config.at(PORT_ID); |
| | | } catch(std::exception const &e) { |
| | | LOG(INFO) << ("Error when read port."); |
| | | exit(0); |
| | | } |
| | | std::string server_address; |
| | | server_address = "0.0.0.0:" + port; |
| | | GrpcService service(config, onnx_thread); |
| | | |
| | | RunServer(model_path); |
| | | return 0; |
| | | grpc::ServerBuilder builder; |
| | | builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); |
| | | builder.RegisterService(&service); |
| | | std::unique_ptr<grpc::Server> server(builder.BuildAndStart()); |
| | | LOG(INFO) << "Server listening on " << server_address; |
| | | server->Wait(); |
| | | |
| | | return 0; |
| | | } |