| | |
| | | /** |
| | | * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | * MIT License (https://opensource.org/licenses/MIT) |
| | | */ |
| | | |
| | | #include "precomp.h" |
| | | |
| | | using namespace std; |
| | | using namespace paraformer; |
| | | |
| | | Paraformer::Paraformer(const char* path,int thread_num, bool quantize, bool use_vad, bool use_punc) |
| | | namespace funasr { |
| | | |
| | | /** |
| | | * Default constructor. Initializes env_ with ORT_LOGGING_LEVEL_ERROR and the |
| | | * log id "paraformer", and value-initializes session_options. |
| | | * The body is empty: no model, vocabulary or CMVN data is loaded here. |
| | | */ |
| | | Paraformer::Paraformer() |
| | | :env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),session_options{}{ |
| | | } |
| | | |
| | | // VAD model |
| | | if(use_vad){ |
| | | string vad_path = PathAppend(path, "vad_model.onnx"); |
| | | string mvn_path = PathAppend(path, "vad.mvn"); |
| | | vad_handle = make_unique<FsmnVad>(); |
| | | vad_handle->InitVad(vad_path, mvn_path, MODEL_SAMPLE_RATE, VAD_MAX_LEN, VAD_SILENCE_DYRATION, VAD_SPEECH_NOISE_THRES); |
| | | } |
| | | |
| | | // PUNC model |
| | | if(use_punc){ |
| | | punc_handle = make_unique<CTTransformer>(path, thread_num); |
| | | } |
| | | |
| | | if(quantize) |
| | | { |
| | | model_path = PathAppend(path, "model_quant.onnx"); |
| | | }else{ |
| | | model_path = PathAppend(path, "model.onnx"); |
| | | } |
| | | cmvn_path = PathAppend(path, "am.mvn"); |
| | | config_path = PathAppend(path, "config.yaml"); |
| | | |
| | | /** |
| | | * Sets the fbank options, creates the ONNX Runtime session for the acoustic |
| | | * model, collects C-string pointers into m_szInputNames / m_szOutputNames, |
| | | * and loads the vocabulary and CMVN data. |
| | | * |
| | | * @param am_model   path handed to Ort::Session inside the try block |
| | | * @param am_cmvn    path handed to LoadCmvn |
| | | * @param am_config  path handed to the Vocab constructor |
| | | * @param thread_num not referenced in the rows visible here |
| | | * |
| | | * NOTE(review): several statements below use model_path, cmvn_path and |
| | | * config_path, which are not parameters of this function, and repeat work |
| | | * that is done again with the am_model / am_cmvn / am_config parameters. |
| | | * Confirm which set is intended. Two rows inside this body are empty in this |
| | | * view, so some statements may be missing. |
| | | */ void Paraformer::InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){ |
| | | // knf options |
| | | fbank_opts.frame_opts.dither = 0; |
| | | fbank_opts.mel_opts.num_bins = 80; |
| | |
| | | // DisableCpuMemArena can improve performance |
| | | session_options.DisableCpuMemArena(); |
| | | 
| | | #ifdef _WIN32 |
| | | wstring wstrPath = strToWstr(model_path); /* NOTE(review): wstrPath is not used by the next statement, which passes model_path.c_str() */ |
| | | m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), session_options); |
| | | #else |
| | | m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), session_options); |
| | | #endif |
| | | try { |
| | | m_session = std::make_unique<Ort::Session>(env_, am_model.c_str(), session_options); /* m_session is assigned again here, this time from am_model */ |
| | | } catch (std::exception const &e) { |
| | | LOG(ERROR) << "Error when load am onnx model: " << e.what(); |
| | | exit(0); /* NOTE(review): the process exits with status 0 although the model failed to load */ |
| | | } |
| | | 
| | | string strName; |
| | | GetInputName(m_session.get(), strName); |
| | |
| | | m_szInputNames.push_back(item.c_str()); /* NOTE(review): the declaration of item for this statement is not visible in this view */ |
| | | for (auto& item : m_strOutputNames) |
| | | m_szOutputNames.push_back(item.c_str()); /* stores pointers into the strings owned by m_strOutputNames */ |
| | | vocab = new Vocab(config_path.c_str()); |
| | | LoadCmvn(cmvn_path.c_str()); |
| | | vocab = new Vocab(am_config.c_str()); /* NOTE(review): second assignment to vocab; the object allocated two rows above is not deleted in between */ |
| | | LoadCmvn(am_cmvn.c_str()); /* second LoadCmvn call, now with am_cmvn */ |
| | | } |
| | | |
| | | Paraformer::~Paraformer() |
| | |
| | | |
| | | /** |
| | | * No-op: the body contains no statements, so calling this changes nothing. |
| | | */ void Paraformer::Reset() |
| | | { |
| | | } |
| | | |
| | | /** |
| | | * Forwards pcm_data to vad_handle->Infer and returns the result. |
| | | * |
| | | * @param pcm_data sample buffer, passed by non-const reference to Infer |
| | | * @return whatever vad_handle->Infer yields for pcm_data |
| | | * |
| | | * NOTE(review): vad_handle is dereferenced without a null check. Confirm |
| | | * callers only reach this once the VAD handle has been created. |
| | | */ vector<std::vector<int>> Paraformer::VadSeg(std::vector<float>& pcm_data){ |
| | | return vad_handle->Infer(pcm_data); |
| | | } |
| | | |
| | | /** |
| | | * Forwards sz_input to punc_handle->AddPunc and returns the result. |
| | | * |
| | | * @param sz_input C string handed to the punctuation handle as-is |
| | | * @return whatever punc_handle->AddPunc yields for sz_input |
| | | * |
| | | * NOTE(review): punc_handle is dereferenced without a null check. Confirm |
| | | * callers only reach this once the punctuation handle has been created. |
| | | */ string Paraformer::AddPunc(const char* sz_input){ |
| | | return punc_handle->AddPunc(sz_input); |
| | | } |
| | | |
| | | vector<float> Paraformer::FbankKaldi(float sample_rate, const float* waves, int len) { |
| | |
| | | /** |
| | | * Opens the file at filename and reads it line by line with getline. |
| | | * If the file cannot be opened, logs an error and terminates the process. |
| | | * |
| | | * @param filename path handed to the ifstream constructor |
| | | * |
| | | * NOTE(review): the per-line handling inside the loop is not visible in this |
| | | * view (the loop body row is empty), so what is parsed or stored is unknown. |
| | | */ void Paraformer::LoadCmvn(const char *filename) |
| | | { |
| | | ifstream cmvn_stream(filename); |
| | | if (!cmvn_stream.is_open()) { |
| | | LOG(ERROR) << "Failed to open file: " << filename; |
| | | exit(0); /* NOTE(review): status 0 is reported to the parent process even though this is a failure path */ |
| | | } |
| | | string line; |
| | | 
| | | while (getline(cmvn_stream, line)) { |
| | |
| | | } |
| | | } |
| | | |
| | | string Paraformer::GreedySearch(float * in, int n_len ) |
| | | string Paraformer::GreedySearch(float * in, int n_len, int64_t token_nums) |
| | | { |
| | | vector<int> hyps; |
| | | int Tmax = n_len; |
| | | for (int i = 0; i < Tmax; i++) { |
| | | int max_idx; |
| | | float max_val; |
| | | FindMax(in + i * 8404, 8404, max_val, max_idx); |
| | | FindMax(in + i * token_nums, token_nums, max_val, max_idx); |
| | | hyps.push_back(max_idx); |
| | | } |
| | | |
| | |
| | | int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>()); |
| | | float* floatData = outputTensor[0].GetTensorMutableData<float>(); |
| | | auto encoder_out_lens = outputTensor[1].GetTensorMutableData<int64_t>(); |
| | | result = GreedySearch(floatData, *encoder_out_lens); |
| | | result = GreedySearch(floatData, *encoder_out_lens, outputShape[2]); |
| | | } |
| | | catch (std::exception const &e) |
| | | { |
| | | printf(e.what()); |
| | | LOG(ERROR)<<e.what(); |
| | | } |
| | | |
| | | return result; |
| | |
| | | /** |
| | | * Placeholder: this method is not implemented. |
| | | * Logs an error and returns an empty string; din, len and flag are ignored. |
| | | * |
| | | * @return always the empty string |
| | | */ |
| | | string Paraformer::ForwardChunk(float* din, int len, int flag) |
| | | { |
| | | // Report through LOG(ERROR) like the other failure paths in this file, and |
| | | // return an empty string instead of placeholder text a caller could mistake for output. |
| | | LOG(ERROR)<<"Not Imp!!!!!!"; |
| | | return ""; |
| | | } |
| | | |
| | | /** |
| | | * Placeholder: this method is not implemented. |
| | | * Logs an error and returns an empty string. |
| | | * |
| | | * @return always the empty string |
| | | */ |
| | | string Paraformer::Rescoring() |
| | | { |
| | | // Report through LOG(ERROR) like the other failure paths in this file, and |
| | | // return an empty string instead of placeholder text a caller could mistake for output. |
| | | LOG(ERROR)<<"Not Imp!!!!!!"; |
| | | return ""; |
| | | } |
| | | } // namespace funasr |