| | |
| | | using namespace paraformer; |
| | | |
| | | ModelImp::ModelImp(const char* path,int nNumThread, bool quantize) |
| | | { |
| | | :env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),sessionOptions{}{ |
| | | string model_path; |
| | | string vocab_path; |
| | | string cmvn_path; |
| | | string config_path; |
| | | |
| | | if(quantize) |
| | | { |
| | | model_path = pathAppend(path, "model_quant.onnx"); |
| | | }else{ |
| | | model_path = pathAppend(path, "model.onnx"); |
| | | } |
| | | vocab_path = pathAppend(path, "vocab.txt"); |
| | | cmvn_path = pathAppend(path, "am.mvn"); |
| | | config_path = pathAppend(path, "config.yaml"); |
| | | |
| | | fe = new FeatureExtract(3); |
| | | fft_input = (float *)fftwf_malloc(sizeof(float) * fft_size); |
| | | fft_out = (fftwf_complex *)fftwf_malloc(sizeof(fftwf_complex) * fft_size); |
| | | memset(fft_input, 0, sizeof(float) * fft_size); |
| | | plan = fftwf_plan_dft_r2c_1d(fft_size, fft_input, fft_out, FFTW_ESTIMATE); |
| | | |
| | | //sessionOptions.SetInterOpNumThreads(1); |
| | | sessionOptions.SetIntraOpNumThreads(nNumThread); |
| | |
| | | |
| | | #ifdef _WIN32 |
| | | wstring wstrPath = strToWstr(model_path); |
| | | m_session = new Ort::Session(env, wstrPath.c_str(), sessionOptions); |
| | | m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), sessionOptions); |
| | | #else |
| | | m_session = new Ort::Session(env, model_path.c_str(), sessionOptions); |
| | | m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), sessionOptions); |
| | | #endif |
| | | |
| | | string strName; |
| | | getInputName(m_session, strName); |
| | | getInputName(m_session.get(), strName); |
| | | m_strInputNames.push_back(strName.c_str()); |
| | | getInputName(m_session, strName,1); |
| | | getInputName(m_session.get(), strName,1); |
| | | m_strInputNames.push_back(strName); |
| | | |
| | | getOutputName(m_session, strName); |
| | | getOutputName(m_session.get(), strName); |
| | | m_strOutputNames.push_back(strName); |
| | | getOutputName(m_session, strName,1); |
| | | getOutputName(m_session.get(), strName,1); |
| | | m_strOutputNames.push_back(strName); |
| | | |
| | | for (auto& item : m_strInputNames) |
| | | m_szInputNames.push_back(item.c_str()); |
| | | for (auto& item : m_strOutputNames) |
| | | m_szOutputNames.push_back(item.c_str()); |
| | | vocab = new Vocab(vocab_path.c_str()); |
| | | vocab = new Vocab(config_path.c_str()); |
| | | load_cmvn(cmvn_path.c_str()); |
| | | } |
| | | |
| | | ModelImp::~ModelImp() |
| | | { |
| | | if(fe) |
| | | delete fe; |
| | | if (m_session) |
| | | { |
| | | delete m_session; |
| | | m_session = nullptr; |
| | | } |
| | | if(vocab) |
| | | delete vocab; |
| | | fftwf_free(fft_input); |
| | | fftwf_free(fft_out); |
| | | fftwf_destroy_plan(plan); |
| | | fftwf_cleanup(); |
| | | } |
| | | |
| | | void ModelImp::reset() |
| | | { |
| | | fe->reset(); |
| | | } |
| | | |
| | | void ModelImp::apply_lfr(Tensor<float>*& din) |
| | |
| | | din = tmp; |
| | | } |
| | | |
| | | void ModelImp::load_cmvn(const char *filename) |
| | | { |
| | | ifstream cmvn_stream(filename); |
| | | string line; |
| | | |
| | | while (getline(cmvn_stream, line)) { |
| | | istringstream iss(line); |
| | | vector<string> line_item{istream_iterator<string>{iss}, istream_iterator<string>{}}; |
| | | if (line_item[0] == "<AddShift>") { |
| | | getline(cmvn_stream, line); |
| | | istringstream means_lines_stream(line); |
| | | vector<string> means_lines{istream_iterator<string>{means_lines_stream}, istream_iterator<string>{}}; |
| | | if (means_lines[0] == "<LearnRateCoef>") { |
| | | for (int j = 3; j < means_lines.size() - 1; j++) { |
| | | means_list.push_back(stof(means_lines[j])); |
| | | } |
| | | continue; |
| | | } |
| | | } |
| | | else if (line_item[0] == "<Rescale>") { |
| | | getline(cmvn_stream, line); |
| | | istringstream vars_lines_stream(line); |
| | | vector<string> vars_lines{istream_iterator<string>{vars_lines_stream}, istream_iterator<string>{}}; |
| | | if (vars_lines[0] == "<LearnRateCoef>") { |
| | | for (int j = 3; j < vars_lines.size() - 1; j++) { |
| | | vars_list.push_back(stof(vars_lines[j])*scale); |
| | | } |
| | | continue; |
| | | } |
| | | } |
| | | } |
| | | } |
| | | |
| | | void ModelImp::apply_cmvn(Tensor<float>* din) |
| | | { |
| | | const float* var; |
| | | const float* mean; |
| | | float scale = 22.6274169979695; |
| | | var = vars_list.data(); |
| | | mean= means_list.data(); |
| | | |
| | | int m = din->size[2]; |
| | | int n = din->size[3]; |
| | | |
| | | var = (const float*)paraformer_cmvn_var_hex; |
| | | mean = (const float*)paraformer_cmvn_mean_hex; |
| | | for (int i = 0; i < m; i++) { |
| | | for (int j = 0; j < n; j++) { |
| | | int idx = i * n + j; |
| | |
| | | |
| | | string ModelImp::forward(float* din, int len, int flag) |
| | | { |
| | | |
| | | Tensor<float>* in; |
| | | fe->insert(din, len, flag); |
| | | FeatureExtract* fe = new FeatureExtract(3); |
| | | fe->reset(); |
| | | fe->insert(plan, din, len, flag); |
| | | fe->fetch(in); |
| | | apply_lfr(in); |
| | | apply_cmvn(in); |
| | | Ort::RunOptions run_option; |
| | | |
| | | #ifdef _WIN_X86 |
| | | Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); |
| | | #else |
| | | Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); |
| | | #endif |
| | | |
| | | std::array<int64_t, 3> input_shape_{ in->size[0],in->size[2],in->size[3] }; |
| | | Ort::Value onnx_feats = Ort::Value::CreateTensor<float>(m_memoryInfo, |
| | |
| | | auto outputTensor = m_session->Run(run_option, m_szInputNames.data(), input_onnx.data(), m_szInputNames.size(), m_szOutputNames.data(), m_szOutputNames.size()); |
| | | std::vector<int64_t> outputShape = outputTensor[0].GetTensorTypeAndShapeInfo().GetShape(); |
| | | |
| | | |
| | | int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>()); |
| | | float* floatData = outputTensor[0].GetTensorMutableData<float>(); |
| | | auto encoder_out_lens = outputTensor[1].GetTensorMutableData<int64_t>(); |
| | |
| | | result = ""; |
| | | } |
| | | |
| | | |
| | | if(in) |
| | | if(in){ |
| | | delete in; |
| | | in = nullptr; |
| | | } |
| | | if(fe){ |
| | | delete fe; |
| | | fe = nullptr; |
| | | } |
| | | |
| | | return result; |
| | | } |