| | |
| | | LoadCmvn(am_cmvn.c_str()); |
| | | } |
| | | |
| | | // online |
| | | void SenseVoiceSmall::InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num){ |
| | | |
| | | LoadOnlineConfigFromYaml(am_config.c_str()); |
| | | // knf options |
| | | fbank_opts_.frame_opts.dither = 0; |
| | | fbank_opts_.mel_opts.num_bins = n_mels; |
| | | fbank_opts_.frame_opts.samp_freq = asr_sample_rate; |
| | | fbank_opts_.frame_opts.window_type = window_type; |
| | | fbank_opts_.frame_opts.frame_shift_ms = frame_shift; |
| | | fbank_opts_.frame_opts.frame_length_ms = frame_length; |
| | | fbank_opts_.energy_floor = 0; |
| | | fbank_opts_.mel_opts.debug_mel = false; |
| | | |
| | | // session_options_.SetInterOpNumThreads(1); |
| | | session_options_.SetIntraOpNumThreads(thread_num); |
| | | session_options_.SetGraphOptimizationLevel(ORT_ENABLE_ALL); |
| | | // DisableCpuMemArena can improve performance |
| | | session_options_.DisableCpuMemArena(); |
| | | |
| | | try { |
| | | encoder_session_ = std::make_unique<Ort::Session>(env_, ORTSTRING(en_model).c_str(), session_options_); |
| | | LOG(INFO) << "Successfully load model from " << en_model; |
| | | } catch (std::exception const &e) { |
| | | LOG(ERROR) << "Error when load am encoder model: " << e.what(); |
| | | exit(-1); |
| | | } |
| | | |
| | | try { |
| | | decoder_session_ = std::make_unique<Ort::Session>(env_, ORTSTRING(de_model).c_str(), session_options_); |
| | | LOG(INFO) << "Successfully load model from " << de_model; |
| | | } catch (std::exception const &e) { |
| | | LOG(ERROR) << "Error when load am decoder model: " << e.what(); |
| | | exit(-1); |
| | | } |
| | | |
| | | // encoder |
| | | string strName; |
| | | GetInputName(encoder_session_.get(), strName); |
| | | en_strInputNames.push_back(strName.c_str()); |
| | | GetInputName(encoder_session_.get(), strName,1); |
| | | en_strInputNames.push_back(strName); |
| | | |
| | | GetOutputName(encoder_session_.get(), strName); |
| | | en_strOutputNames.push_back(strName); |
| | | GetOutputName(encoder_session_.get(), strName,1); |
| | | en_strOutputNames.push_back(strName); |
| | | GetOutputName(encoder_session_.get(), strName,2); |
| | | en_strOutputNames.push_back(strName); |
| | | |
| | | for (auto& item : en_strInputNames) |
| | | en_szInputNames_.push_back(item.c_str()); |
| | | for (auto& item : en_strOutputNames) |
| | | en_szOutputNames_.push_back(item.c_str()); |
| | | |
| | | // decoder |
| | | int de_input_len = 4 + fsmn_layers; |
| | | int de_out_len = 2 + fsmn_layers; |
| | | for(int i=0;i<de_input_len; i++){ |
| | | GetInputName(decoder_session_.get(), strName, i); |
| | | de_strInputNames.push_back(strName.c_str()); |
| | | } |
| | | |
| | | for(int i=0;i<de_out_len; i++){ |
| | | GetOutputName(decoder_session_.get(), strName,i); |
| | | de_strOutputNames.push_back(strName); |
| | | } |
| | | |
| | | for (auto& item : de_strInputNames) |
| | | de_szInputNames_.push_back(item.c_str()); |
| | | for (auto& item : de_strOutputNames) |
| | | de_szOutputNames_.push_back(item.c_str()); |
| | | |
| | | online_vocab = new Vocab(token_file.c_str()); |
| | | phone_set_ = new PhoneSet(token_file.c_str()); |
| | | LoadCmvn(am_cmvn.c_str()); |
| | | } |
| | | |
| | | // 2pass |
| | | void SenseVoiceSmall::InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, |
| | | const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, const std::string &online_token_file, int thread_num){ |
| | | // online |
| | | InitAsr(en_model, de_model, am_cmvn, am_config, online_token_file, thread_num); |
| | | |
| | | // offline |
| | | try { |
| | | m_session_ = std::make_unique<Ort::Session>(env_, ORTSTRING(am_model).c_str(), session_options_); |
| | | LOG(INFO) << "Successfully load model from " << am_model; |
| | | } catch (std::exception const &e) { |
| | | LOG(ERROR) << "Error when load am onnx model: " << e.what(); |
| | | exit(-1); |
| | | } |
| | | |
| | | GetInputNames(m_session_.get(), m_strInputNames, m_szInputNames); |
| | | GetOutputNames(m_session_.get(), m_strOutputNames, m_szOutputNames); |
| | | vocab = new Vocab(token_file.c_str()); |
| | | } |
| | | |
| | | void SenseVoiceSmall::LoadOnlineConfigFromYaml(const char* filename){ |
| | | |
| | | YAML::Node config; |
| | | try{ |
| | | config = YAML::LoadFile(filename); |
| | | }catch(exception const &e){ |
| | | LOG(ERROR) << "Error loading file, yaml file error or not exist."; |
| | | exit(-1); |
| | | } |
| | | |
| | | try{ |
| | | YAML::Node frontend_conf = config["frontend_conf"]; |
| | | YAML::Node encoder_conf = config["encoder_conf"]; |
| | | YAML::Node decoder_conf = config["decoder_conf"]; |
| | | YAML::Node predictor_conf = config["predictor_conf"]; |
| | | |
| | | this->window_type = frontend_conf["window"].as<string>(); |
| | | this->n_mels = frontend_conf["n_mels"].as<int>(); |
| | | this->frame_length = frontend_conf["frame_length"].as<int>(); |
| | | this->frame_shift = frontend_conf["frame_shift"].as<int>(); |
| | | this->lfr_m = frontend_conf["lfr_m"].as<int>(); |
| | | this->lfr_n = frontend_conf["lfr_n"].as<int>(); |
| | | |
| | | this->encoder_size = encoder_conf["output_size"].as<int>(); |
| | | this->fsmn_dims = encoder_conf["output_size"].as<int>(); |
| | | |
| | | this->fsmn_layers = decoder_conf["num_blocks"].as<int>(); |
| | | this->fsmn_lorder = decoder_conf["kernel_size"].as<int>()-1; |
| | | |
| | | this->cif_threshold = predictor_conf["threshold"].as<double>(); |
| | | this->tail_alphas = predictor_conf["tail_threshold"].as<double>(); |
| | | |
| | | this->asr_sample_rate = frontend_conf["fs"].as<int>(); |
| | | |
| | | |
| | | }catch(exception const &e){ |
| | | LOG(ERROR) << "Error when load argument from vad config YAML."; |
| | | exit(-1); |
| | | } |
| | | } |
| | | |
| | | void SenseVoiceSmall::LoadConfigFromYaml(const char* filename){ |
| | | |
| | | YAML::Node config; |
| | |
| | | { |
| | | if(vocab){ |
| | | delete vocab; |
| | | } |
| | | if(online_vocab){ |
| | | delete online_vocab; |
| | | } |
| | | if(lm_vocab){ |
| | | delete lm_vocab; |
| | |
| | | return str_lang + str_emo + str_event + " " + text; |
| | | } |
| | | |
| | | string SenseVoiceSmall::GreedySearch(float * in, int n_len, int64_t token_nums, bool is_stamp, std::vector<float> us_alphas, std::vector<float> us_cif_peak) |
| | | { |
| | | vector<int> hyps; |
| | | int Tmax = n_len; |
| | | for (int i = 0; i < Tmax; i++) { |
| | | int max_idx; |
| | | float max_val; |
| | | FindMax(in + i * token_nums, token_nums, max_val, max_idx); |
| | | hyps.push_back(max_idx); |
| | | } |
| | | if(!is_stamp){ |
| | | return online_vocab->Vector2StringV2(hyps, language); |
| | | }else{ |
| | | std::vector<string> char_list; |
| | | std::vector<std::vector<float>> timestamp_list; |
| | | std::string res_str; |
| | | online_vocab->Vector2String(hyps, char_list); |
| | | std::vector<string> raw_char(char_list); |
| | | TimestampOnnx(us_alphas, us_cif_peak, char_list, res_str, timestamp_list); |
| | | |
| | | return PostProcess(raw_char, timestamp_list); |
| | | } |
| | | } |
| | | |
| | | void SenseVoiceSmall::LfrCmvn(std::vector<std::vector<float>> &asr_feats) { |
| | | |
| | | std::vector<std::vector<float>> out_feats; |