| | |
| | | /** |
| | | * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | * MIT License (https://opensource.org/licenses/MIT) |
| | | */ |
| | | |
| | | #include "precomp.h" |
| | | |
| | | using namespace std; |
| | | using namespace paraformer; |
| | | |
| | | Paraformer::Paraformer(const char* path,int thread_num, bool quantize, bool use_vad, bool use_punc) |
| | | Paraformer::Paraformer(std::map<std::string, std::string>& model_path,int thread_num) |
| | | :env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),session_options{}{ |
| | | string model_path; |
| | | string cmvn_path; |
| | | string config_path; |
| | | |
| | | // VAD model |
| | | if(use_vad){ |
| | | string vad_path = PathAppend(path, "vad_model.onnx"); |
| | | string mvn_path = PathAppend(path, "vad.mvn"); |
| | | if(model_path.find(VAD_MODEL_PATH) != model_path.end()){ |
| | | use_vad = true; |
| | | string vad_model_path; |
| | | string vad_cmvn_path; |
| | | string vad_config_path; |
| | | |
| | | try{ |
| | | vad_model_path = model_path.at(VAD_MODEL_PATH); |
| | | vad_cmvn_path = model_path.at(VAD_CMVN_PATH); |
| | | vad_config_path = model_path.at(VAD_CONFIG_PATH); |
| | | }catch(const out_of_range& e){ |
| | | LOG(ERROR) << "Error when read "<< VAD_CMVN_PATH << " or " << VAD_CONFIG_PATH <<" :" << e.what(); |
| | | exit(0); |
| | | } |
| | | vad_handle = make_unique<FsmnVad>(); |
| | | vad_handle->InitVad(vad_path, mvn_path, MODEL_SAMPLE_RATE, VAD_MAX_LEN, VAD_SILENCE_DYRATION, VAD_SPEECH_NOISE_THRES); |
| | | vad_handle->InitVad(vad_model_path, vad_cmvn_path, vad_config_path); |
| | | } |
| | | |
| | | // AM model |
| | | if(model_path.find(AM_MODEL_PATH) != model_path.end()){ |
| | | string am_model_path; |
| | | string am_cmvn_path; |
| | | string am_config_path; |
| | | |
| | | try{ |
| | | am_model_path = model_path.at(AM_MODEL_PATH); |
| | | am_cmvn_path = model_path.at(AM_CMVN_PATH); |
| | | am_config_path = model_path.at(AM_CONFIG_PATH); |
| | | }catch(const out_of_range& e){ |
| | | LOG(ERROR) << "Error when read "<< AM_CONFIG_PATH << " or " << AM_CMVN_PATH <<" :" << e.what(); |
| | | exit(0); |
| | | } |
| | | InitAM(am_model_path, am_cmvn_path, am_config_path, thread_num); |
| | | } |
| | | |
| | | // PUNC model |
| | | if(use_punc){ |
| | | punc_handle = make_unique<CTTransformer>(path, thread_num); |
| | | } |
| | | if(model_path.find(PUNC_MODEL_PATH) != model_path.end()){ |
| | | use_punc = true; |
| | | string punc_model_path; |
| | | string punc_config_path; |
| | | |
| | | try{ |
| | | punc_model_path = model_path.at(PUNC_MODEL_PATH); |
| | | punc_config_path = model_path.at(PUNC_CONFIG_PATH); |
| | | }catch(const out_of_range& e){ |
| | | LOG(ERROR) << "Error when read "<< PUNC_CONFIG_PATH <<" :" << e.what(); |
| | | exit(0); |
| | | } |
| | | |
| | | if(quantize) |
| | | { |
| | | model_path = PathAppend(path, "model_quant.onnx"); |
| | | }else{ |
| | | model_path = PathAppend(path, "model.onnx"); |
| | | punc_handle = make_unique<CTTransformer>(); |
| | | punc_handle->InitPunc(punc_model_path, punc_config_path, thread_num); |
| | | } |
| | | cmvn_path = PathAppend(path, "am.mvn"); |
| | | config_path = PathAppend(path, "config.yaml"); |
| | | } |
| | | |
// Initialises the acoustic model: sets fbank options, creates the ONNX Runtime
// session from `am_model`, gathers input/output name pointers, constructs a
// Vocab from `am_config` and calls LoadCmvn on `am_cmvn`.
// NOTE(review): this listing looks like an unmarked diff. Superseded and
// replacement lines are interleaved, and the short `| |` rows appear to stand
// for omitted lines. `thread_num` is not used in any of the visible lines.
| | | void Paraformer::InitAM(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){ |
| | | // knf options |
| | | fbank_opts.frame_opts.dither = 0; |
| | | fbank_opts.mel_opts.num_bins = 80; |
| |
| | | // DisableCpuMemArena can improve performance |
| | | session_options.DisableCpuMemArena(); |
| | | |
// NOTE(review): `model_path` is not declared in this function's visible scope,
// and `wstrPath` is computed but never used. This #ifdef block appears to be
// left over from the earlier constructor-based code path and duplicates the
// session creation in the try-block below.
| | | #ifdef _WIN32 |
| | | wstring wstrPath = strToWstr(model_path); |
| | | m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), session_options); |
| | | #else |
| | | m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), session_options); |
| | | #endif |
// Create the session from the `am_model` path. Any std::exception is logged
// and the process exits with status 0.
| | | try { |
| | | m_session = std::make_unique<Ort::Session>(env_, am_model.c_str(), session_options); |
| | | } catch (std::exception const &e) { |
| | | LOG(ERROR) << "Error when load am onnx model: " << e.what(); |
| | | exit(0); |
| | | } |
| | | |
| | | string strName; |
| | | GetInputName(m_session.get(), strName); |
| |
// NOTE(review): `item` has no visible declaration at this point. The loop
// header, and whatever fills the name vectors, is presumably in the omitted
// rows above.
| | | m_szInputNames.push_back(item.c_str()); |
// Store a C-string pointer for every entry of m_strOutputNames.
| | | for (auto& item : m_strOutputNames) |
| | | m_szOutputNames.push_back(item.c_str()); |
// NOTE(review): `config_path` / `cmvn_path` are not declared in the visible
// scope. These two calls repeat the pair that follows, using the old variable
// names. If both pairs ran, `vocab` would be allocated twice with `new`.
| | | vocab = new Vocab(config_path.c_str()); |
| | | LoadCmvn(cmvn_path.c_str()); |
| | | vocab = new Vocab(am_config.c_str()); |
| | | LoadCmvn(am_cmvn.c_str()); |
| | | } |
| | | |
| | | Paraformer::~Paraformer() |
| | |
| | | void Paraformer::LoadCmvn(const char *filename) |
| | | { |
| | | ifstream cmvn_stream(filename); |
| | | if (!cmvn_stream.is_open()) { |
| | | LOG(ERROR) << "Failed to open file: " << filename; |
| | | exit(0); |
| | | } |
| | | string line; |
| | | |
| | | while (getline(cmvn_stream, line)) { |