#include namespace funasr { WfstDecoder::WfstDecoder(fst::Fst* lm, PhoneSet* phone_set, Vocab* vocab, float glob_beam, float lat_beam, float am_scale) :dec_opts_(glob_beam, lat_beam, am_scale), decodable_(dec_opts_.acoustic_scale), lm_(lm), phone_set_(phone_set), vocab_(vocab) { decoder_ = std::shared_ptr( new kaldi::LatticeFasterOnlineDecoder(*lm_, dec_opts_)); } WfstDecoder::~WfstDecoder() { } void WfstDecoder::StartUtterance() { if (decoder_) { cur_frame_ = 0; cur_token_ = 0; decodable_.Reset(); decoder_->InitDecoding(); } } void WfstDecoder::EndUtterance() { } string WfstDecoder::Search(float *in, int len, int64_t token_num) { string result; if (len == 0) { return ""; } std::vector> logp_vec; int blk_phn_id = phone_set_->GetBlkPhnId(); for (int i = 0; i < len - 1; i++) { std::vector tmp_logp; for (int j = 0; j < token_num; j++) { tmp_logp.push_back((in + i * token_num)[j]); } logp_vec.push_back(tmp_logp); } for (int i = 0; i < logp_vec.size(); i++) { cur_frame_++; decodable_.AcceptLoglikes(logp_vec[i]); decoder_->AdvanceDecoding(&decodable_, 1); cur_token_++; } if (cur_token_ > 0) { std::vector words; kaldi::Lattice lattice; decoder_->GetBestPath(&lattice, false); std::vector alignment; kaldi::LatticeWeight weight; fst::GetLinearSymbolSequence(lattice, &alignment, &words, &weight); result = vocab_->Vector2StringV2(words); } return result; } string WfstDecoder::FinalizeDecode(bool is_stamp, std::vector us_alphas, std::vector us_cif_peak) { string result; if (cur_token_ > 0) { std::vector words; kaldi::Lattice lattice; decodable_.SetFinished(); decoder_->FinalizeDecoding(); decoder_->GetBestPath(&lattice, true); std::vector alignment; kaldi::LatticeWeight weight; fst::GetLinearSymbolSequence(lattice, &alignment, &words, &weight); if(!is_stamp){ return vocab_->Vector2StringV2(words); }else{ std::vector char_list; std::vector> timestamp_list; std::string res_str; vocab_->Vector2String(words, char_list); // split chinese word to char std::vector split_chars; for(auto& word:char_list){ std::vector word2char; SplitChiEngCharacters(word, word2char); split_chars.insert(split_chars.end(), word2char.begin(), word2char.end()); } // std::vector raw_char(char_list); TimestampOnnx(us_alphas, us_cif_peak, split_chars, res_str, timestamp_list); return PostProcess(split_chars, timestamp_list); } } return result; } void WfstDecoder::LoadHwsRes(int inc_bias, unordered_map &hws_map) { try { if (!hws_map.empty()) { bias_lm_ = std::make_shared(hws_map, inc_bias, *phone_set_, *vocab_); decoder_->SetBiasLm(bias_lm_); } } catch (std::exception const &e) { LOG(ERROR) << "Error when load wfst hotwords resource: " << e.what(); exit(0); } } void WfstDecoder::UnloadHwsRes() { if (bias_lm_) { decoder_->ClearBiasLm(); bias_lm_.reset(); } } } // namespace funasr