| | |
| | | // [--vad-quant <string>] [--vad-dir <string>] [--quantize |
| | | // <string>] --model-dir <string> [--] [--version] [-h] |
| | | #include "websocket-server.h" |
| | | #include <unistd.h> |
| | | |
| | | using namespace std; |
| | | void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key, |
| | |
| | | } |
| | | int main(int argc, char* argv[]) { |
| | | try { |
| | | |
| | | google::InitGoogleLogging(argv[0]); |
| | | FLAGS_logtostderr = true; |
| | | |
| | | TCLAP::CmdLine cmd("funasr-ws-server", ' ', "1.0"); |
| | | TCLAP::ValueArg<std::string> download_model_dir( |
| | | "", "download-model-dir", |
| | | "Download model from Modelscope to download_model_dir", |
| | | false, "", "string"); |
| | | TCLAP::ValueArg<std::string> model_dir( |
| | | "", MODEL_DIR, |
| | | "default: /workspace/models/asr, the asr model path, which contains model.onnx, config.yaml, am.mvn", |
| | |
| | | "true, load the model of model_quant.onnx in punc_dir", |
| | | false, "true", "string"); |
| | | |
| | | TCLAP::ValueArg<std::string> listen_ip("", "listen_ip", "listen_ip", false, |
| | | TCLAP::ValueArg<std::string> listen_ip("", "listen-ip", "listen ip", false, |
| | | "0.0.0.0", "string"); |
| | | TCLAP::ValueArg<int> port("", "port", "port", false, 8889, "int"); |
| | | TCLAP::ValueArg<int> io_thread_num("", "io_thread_num", "io_thread_num", |
| | | TCLAP::ValueArg<int> port("", "port", "port", false, 10095, "int"); |
| | | TCLAP::ValueArg<int> io_thread_num("", "io-thread-num", "io thread num", |
| | | false, 8, "int"); |
| | | TCLAP::ValueArg<int> decoder_thread_num( |
| | | "", "decoder_thread_num", "decoder_thread_num", false, 8, "int"); |
| | | TCLAP::ValueArg<int> model_thread_num("", "model_thread_num", |
| | | "model_thread_num", false, 1, "int"); |
| | | "", "decoder-thread-num", "decoder thread num", false, 8, "int"); |
| | | TCLAP::ValueArg<int> model_thread_num("", "model-thread-num", |
| | | "model thread num", false, 1, "int"); |
| | | |
| | | TCLAP::ValueArg<std::string> certfile("", "certfile", |
| | | "default: ../../../ssl_key/server.crt, path of certficate for WSS connection. if it is empty, it will be in WS mode.", |
| | |
| | | cmd.add(certfile); |
| | | cmd.add(keyfile); |
| | | |
| | | cmd.add(download_model_dir); |
| | | cmd.add(model_dir); |
| | | cmd.add(quantize); |
| | | cmd.add(vad_dir); |
| | |
| | | GetValue(punc_dir, PUNC_DIR, model_path); |
| | | GetValue(punc_quant, PUNC_QUANT, model_path); |
| | | |
| | | // Download model from Modelscope |
| | | try{ |
| | | std::string s_download_model_dir = download_model_dir.getValue(); |
| | | if(download_model_dir.isSet() && !s_download_model_dir.empty()){ |
| | | if (access(s_download_model_dir.c_str(), F_OK) != 0){ |
| | | LOG(ERROR) << s_download_model_dir << " do not exists."; |
| | | exit(-1); |
| | | } |
| | | std::string s_vad_path = model_path[VAD_DIR]; |
| | | std::string s_asr_path = model_path[MODEL_DIR]; |
| | | std::string s_punc_path = model_path[PUNC_DIR]; |
| | | std::string python_cmd = "python -m funasr.export.export_model --type onnx --quantize True "; |
| | | if(vad_dir.isSet() && !s_vad_path.empty()){ |
| | | std::string python_cmd_vad = python_cmd + " --model-name " + s_vad_path + " --export-dir " + s_download_model_dir; |
| | | LOG(INFO) << "Download model: " << s_vad_path << " from modelscope: "; |
| | | system(python_cmd_vad.c_str()); |
| | | std::string down_vad_path = s_download_model_dir+"/"+s_vad_path; |
| | | std::string down_vad_model = s_download_model_dir+"/"+s_vad_path+"/model_quant.onnx"; |
| | | if (access(down_vad_model.c_str(), F_OK) != 0){ |
| | | LOG(ERROR) << down_vad_model << " do not exists."; |
| | | exit(-1); |
| | | }else{ |
| | | model_path[VAD_DIR]=down_vad_path; |
| | | LOG(INFO) << "Set " << VAD_DIR << " : " << model_path[VAD_DIR]; |
| | | } |
| | | }else{ |
| | | LOG(INFO) << "VAD model is not set, use default."; |
| | | } |
| | | if(model_dir.isSet() && !s_asr_path.empty()){ |
| | | std::string python_cmd_asr = python_cmd + " --model-name " + s_asr_path + " --export-dir " + s_download_model_dir; |
| | | LOG(INFO) << "Download model: " << s_asr_path << " from modelscope: "; |
| | | system(python_cmd_asr.c_str()); |
| | | std::string down_asr_path = s_download_model_dir+"/"+s_asr_path; |
| | | std::string down_asr_model = s_download_model_dir+"/"+s_asr_path+"/model_quant.onnx"; |
| | | if (access(down_asr_model.c_str(), F_OK) != 0){ |
| | | LOG(ERROR) << down_asr_model << " do not exists."; |
| | | exit(-1); |
| | | }else{ |
| | | model_path[MODEL_DIR]=down_asr_path; |
| | | LOG(INFO) << "Set " << MODEL_DIR << " : " << model_path[MODEL_DIR]; |
| | | } |
| | | }else{ |
| | | LOG(INFO) << "ASR model is not set, use default."; |
| | | } |
| | | if(punc_dir.isSet() && !s_punc_path.empty()){ |
| | | std::string python_cmd_punc = python_cmd + " --model-name " + s_punc_path + " --export-dir " + s_download_model_dir; |
| | | LOG(INFO) << "Download model: " << s_punc_path << " from modelscope: "; |
| | | system(python_cmd_punc.c_str()); |
| | | std::string down_punc_path = s_download_model_dir+"/"+s_punc_path; |
| | | std::string down_punc_model = s_download_model_dir+"/"+s_punc_path+"/model_quant.onnx"; |
| | | if (access(down_punc_model.c_str(), F_OK) != 0){ |
| | | LOG(ERROR) << down_punc_model << " do not exists."; |
| | | exit(-1); |
| | | }else{ |
| | | model_path[PUNC_DIR]=down_punc_path; |
| | | LOG(INFO) << "Set " << PUNC_DIR << " : " << model_path[PUNC_DIR]; |
| | | } |
| | | }else{ |
| | | LOG(INFO) << "PUNC model is not set, use default."; |
| | | } |
| | | } |
| | | } catch (std::exception const& e) { |
| | | LOG(ERROR) << "Error: " << e.what(); |
| | | } |
| | | |
| | | std::string s_listen_ip = listen_ip.getValue(); |
| | | int s_port = port.getValue(); |
| | | int s_io_thread_num = io_thread_num.getValue(); |