/** * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. * MIT License (https://opensource.org/licenses/MIT) */ #include "precomp.h" using namespace std; namespace funasr { ParaformerOnline::ParaformerOnline(Paraformer* para_handle, std::vector chunk_size) :para_handle_(std::move(para_handle)),chunk_size(chunk_size),session_options_{}{ InitOnline( para_handle_->fbank_opts_, para_handle_->encoder_session_, para_handle_->decoder_session_, para_handle_->en_szInputNames_, para_handle_->en_szOutputNames_, para_handle_->de_szInputNames_, para_handle_->de_szOutputNames_, para_handle_->means_list_, para_handle_->vars_list_); InitCache(); } void ParaformerOnline::InitOnline( knf::FbankOptions &fbank_opts, std::shared_ptr &encoder_session, std::shared_ptr &decoder_session, vector &en_szInputNames, vector &en_szOutputNames, vector &de_szInputNames, vector &de_szOutputNames, vector &means_list, vector &vars_list){ fbank_opts_ = fbank_opts; encoder_session_ = encoder_session; decoder_session_ = decoder_session; en_szInputNames_ = en_szInputNames; en_szOutputNames_ = en_szOutputNames; de_szInputNames_ = de_szInputNames; de_szOutputNames_ = de_szOutputNames; means_list_ = means_list; vars_list_ = vars_list; frame_length = para_handle_->frame_length; frame_shift = para_handle_->frame_shift; n_mels = para_handle_->n_mels; lfr_m = para_handle_->lfr_m; lfr_n = para_handle_->lfr_n; encoder_size = para_handle_->encoder_size; fsmn_layers = para_handle_->fsmn_layers; fsmn_lorder = para_handle_->fsmn_lorder; fsmn_dims = para_handle_->fsmn_dims; cif_threshold = para_handle_->cif_threshold; tail_alphas = para_handle_->tail_alphas; // other vars sqrt_factor = std::sqrt(encoder_size); for(int i=0; i> &wav_feats, std::vector &waves) { knf::OnlineFbank fbank(fbank_opts_); // cache merge waves.insert(waves.begin(), input_cache_.begin(), input_cache_.end()); int frame_number = ComputeFrameNum(waves.size(), frame_sample_length_, frame_shift_sample_length_); // Send the audio after the last frame shift position to the cache input_cache_.clear(); input_cache_.insert(input_cache_.begin(), waves.begin() + frame_number * frame_shift_sample_length_, waves.end()); if (frame_number == 0) { return; } // Delete audio that haven't undergone fbank processing waves.erase(waves.begin() + (frame_number - 1) * frame_shift_sample_length_ + frame_sample_length_, waves.end()); std::vector buf(waves.size()); for (int32_t i = 0; i != waves.size(); ++i) { buf[i] = waves[i] * 32768; } fbank.AcceptWaveform(sample_rate, buf.data(), buf.size()); int32_t frames = fbank.NumFramesReady(); for (int32_t i = 0; i != frames; ++i) { const float *frame = fbank.GetFrame(i); vector frame_vector(frame, frame + fbank_opts_.mel_opts.num_bins); wav_feats.emplace_back(frame_vector); } } void ParaformerOnline::ExtractFeats(float sample_rate, vector> &wav_feats, vector &waves, bool input_finished) { FbankKaldi(sample_rate, wav_feats, waves); // cache deal & online lfr,cmvn if (wav_feats.size() > 0) { if (!reserve_waveforms_.empty()) { waves.insert(waves.begin(), reserve_waveforms_.begin(), reserve_waveforms_.end()); } if (lfr_splice_cache_.empty()) { for (int i = 0; i < (lfr_m - 1) / 2; i++) { lfr_splice_cache_.emplace_back(wav_feats[0]); } } if (wav_feats.size() + lfr_splice_cache_.size() >= lfr_m) { wav_feats.insert(wav_feats.begin(), lfr_splice_cache_.begin(), lfr_splice_cache_.end()); int frame_from_waves = (waves.size() - frame_sample_length_) / frame_shift_sample_length_ + 1; int minus_frame = reserve_waveforms_.empty() ? (lfr_m - 1) / 2 : 0; int lfr_splice_frame_idxs = OnlineLfrCmvn(wav_feats, input_finished); int reserve_frame_idx = std::abs(lfr_splice_frame_idxs - minus_frame); reserve_waveforms_.clear(); reserve_waveforms_.insert(reserve_waveforms_.begin(), waves.begin() + reserve_frame_idx * frame_shift_sample_length_, waves.begin() + frame_from_waves * frame_shift_sample_length_); int sample_length = (frame_from_waves - 1) * frame_shift_sample_length_ + frame_sample_length_; waves.erase(waves.begin() + sample_length, waves.end()); } else { reserve_waveforms_.clear(); reserve_waveforms_.insert(reserve_waveforms_.begin(), waves.begin() + frame_sample_length_ - frame_shift_sample_length_, waves.end()); lfr_splice_cache_.insert(lfr_splice_cache_.end(), wav_feats.begin(), wav_feats.end()); } } else { if (input_finished) { if (!reserve_waveforms_.empty()) { waves = reserve_waveforms_; } wav_feats = lfr_splice_cache_; OnlineLfrCmvn(wav_feats, input_finished); } } if(input_finished){ ResetCache(); } } int ParaformerOnline::OnlineLfrCmvn(vector> &wav_feats, bool input_finished) { vector> out_feats; int T = wav_feats.size(); int T_lrf = ceil((T - (lfr_m - 1) / 2) / (float)lfr_n); int lfr_splice_frame_idxs = T_lrf; vector p; for (int i = 0; i < T_lrf; i++) { if (lfr_m <= T - i * lfr_n) { for (int j = 0; j < lfr_m; j++) { p.insert(p.end(), wav_feats[i * lfr_n + j].begin(), wav_feats[i * lfr_n + j].end()); } out_feats.emplace_back(p); p.clear(); } else { if (input_finished) { int num_padding = lfr_m - (T - i * lfr_n); for (int j = 0; j < (wav_feats.size() - i * lfr_n); j++) { p.insert(p.end(), wav_feats[i * lfr_n + j].begin(), wav_feats[i * lfr_n + j].end()); } for (int j = 0; j < num_padding; j++) { p.insert(p.end(), wav_feats[wav_feats.size() - 1].begin(), wav_feats[wav_feats.size() - 1].end()); } out_feats.emplace_back(p); } else { lfr_splice_frame_idxs = i; break; } } } lfr_splice_frame_idxs = std::min(T - 1, lfr_splice_frame_idxs * lfr_n); lfr_splice_cache_.clear(); lfr_splice_cache_.insert(lfr_splice_cache_.begin(), wav_feats.begin() + lfr_splice_frame_idxs, wav_feats.end()); // Apply cmvn for (auto &out_feat: out_feats) { for (int j = 0; j < means_list_.size(); j++) { out_feat[j] = (out_feat[j] + means_list_[j]) * vars_list_[j]; } } wav_feats = out_feats; return lfr_splice_frame_idxs; } void ParaformerOnline::GetPosEmb(std::vector> &wav_feats, int timesteps, int feat_dim) { int start_idx = start_idx_cache_; start_idx_cache_ += timesteps; int mm = start_idx_cache_; int i; float scale = -0.0330119726594128; std::vector tmp(mm*feat_dim); for (i = 0; i < feat_dim/2; i++) { float tmptime = exp(i * scale); int j; for (j = 0; j < mm; j++) { int sin_idx = j * feat_dim + i; int cos_idx = j * feat_dim + i + feat_dim/2; float coe = tmptime * (j + 1); tmp[sin_idx] = sin(coe); tmp[cos_idx] = cos(coe); } } for (i = start_idx; i < start_idx + timesteps; i++) { for (int j = 0; j < feat_dim; j++) { wav_feats[i-start_idx][j] += tmp[i*feat_dim+j]; } } } void ParaformerOnline::CifSearch(std::vector> hidden, std::vector alphas, bool is_final, std::vector>& list_frame) { try{ int hidden_size = 0; if(hidden.size() > 0){ hidden_size = hidden[0].size(); } // cache int i,j; int chunk_size_pre = chunk_size[0]; for (i = 0; i < chunk_size_pre; i++) alphas[i] = 0.0; int chunk_size_suf = std::accumulate(chunk_size.begin(), chunk_size.end()-1, 0); for (i = chunk_size_suf; i < alphas.size(); i++){ alphas[i] = 0.0; } if(hidden_cache_.size()>0){ hidden.insert(hidden.begin(), hidden_cache_.begin(), hidden_cache_.end()); alphas.insert(alphas.begin(), alphas_cache_.begin(), alphas_cache_.end()); hidden_cache_.clear(); alphas_cache_.clear(); } if (is_last_chunk) { std::vector tail_hidden(hidden_size, 0); hidden.emplace_back(tail_hidden); alphas.emplace_back(tail_alphas); } float intergrate = 0.0; int len_time = alphas.size(); std::vector frames(hidden_size, 0); std::vector list_fire; for (i = 0; i < len_time; i++) { float alpha = alphas[i]; if (alpha + intergrate < cif_threshold) { intergrate += alpha; list_fire.emplace_back(intergrate); for (j = 0; j < hidden_size; j++) { frames[j] += alpha * hidden[i][j]; } } else { for (j = 0; j < hidden_size; j++) { frames[j] += (cif_threshold - intergrate) * hidden[i][j]; } std::vector frames_cp(frames); list_frame.emplace_back(frames_cp); intergrate += alpha; list_fire.emplace_back(intergrate); intergrate -= cif_threshold; for (j = 0; j < hidden_size; j++) { frames[j] = intergrate * hidden[i][j]; } } } // cache alphas_cache_.emplace_back(intergrate); if (intergrate > 0.0) { std::vector hidden_cache(hidden_size, 0); for (i = 0; i < hidden_size; i++) { hidden_cache[i] = frames[i] / intergrate; } hidden_cache_.emplace_back(hidden_cache); } else { std::vector frames_cp(frames); hidden_cache_.emplace_back(frames_cp); } }catch (std::exception const &e) { LOG(ERROR)< hidden_cache(encoder_size, 0); hidden_cache_.emplace_back(hidden_cache); alphas_cache_.emplace_back(0); // feats std::vector feat_cache(feat_dims, 0); for(int i=0; i<(chunk_size[0]+chunk_size[2]); i++){ feats_cache_.emplace_back(feat_cache); } // fsmn cache #ifdef _WIN_X86 Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); #else Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); #endif const int64_t fsmn_shape_[3] = {1, fsmn_dims, fsmn_lorder}; for(int l=0; l( m_memoryInfo, fsmn_init_cache_.data(), fsmn_init_cache_.size(), fsmn_shape_, 3); decoder_onnx.emplace_back(std::move(onnx_fsmn_cache)); } }; void ParaformerOnline::Reset() { InitCache(); } void ParaformerOnline::ResetCache() { reserve_waveforms_.clear(); input_cache_.clear(); lfr_splice_cache_.clear(); } void ParaformerOnline::AddOverlapChunk(std::vector> &wav_feats, bool input_finished){ wav_feats.insert(wav_feats.begin(), feats_cache_.begin(), feats_cache_.end()); if(input_finished){ feats_cache_.clear(); feats_cache_.insert(feats_cache_.begin(), wav_feats.end()-chunk_size[0], wav_feats.end()); if(!is_last_chunk){ int padding_length = std::accumulate(chunk_size.begin(), chunk_size.end(), 0) - wav_feats.size(); std::vector tmp(feat_dims, 0); for(int i=0; i> &chunk_feats, bool input_finished) { string result; try{ int32_t num_frames = chunk_feats.size(); #ifdef _WIN_X86 Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); #else Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); #endif const int64_t input_shape_[3] = {1, num_frames, feat_dims}; std::vector wav_feats; for (const auto &chunk_feat: chunk_feats) { wav_feats.insert(wav_feats.end(), chunk_feat.begin(), chunk_feat.end()); } Ort::Value onnx_feats = Ort::Value::CreateTensor( m_memoryInfo, wav_feats.data(), wav_feats.size(), input_shape_, 3); const int64_t paraformer_length_shape[1] = {1}; std::vector paraformer_length; paraformer_length.emplace_back(num_frames); Ort::Value onnx_feats_len = Ort::Value::CreateTensor( m_memoryInfo, paraformer_length.data(), paraformer_length.size(), paraformer_length_shape, 1); std::vector input_onnx; input_onnx.emplace_back(std::move(onnx_feats)); input_onnx.emplace_back(std::move(onnx_feats_len)); auto encoder_tensor = encoder_session_->Run(Ort::RunOptions{nullptr}, en_szInputNames_.data(), input_onnx.data(), input_onnx.size(), en_szOutputNames_.data(), en_szOutputNames_.size()); // get enc_vec std::vector enc_shape = encoder_tensor[0].GetTensorTypeAndShapeInfo().GetShape(); float* enc_data = encoder_tensor[0].GetTensorMutableData(); std::vector> enc_vec(enc_shape[1], std::vector(enc_shape[2])); for (int i = 0; i < enc_shape[1]; i++) { for (int j = 0; j < enc_shape[2]; j++) { enc_vec[i][j] = enc_data[i * enc_shape[2] + j]; } } // get alpha_vec std::vector alpha_shape = encoder_tensor[2].GetTensorTypeAndShapeInfo().GetShape(); float* alpha_data = encoder_tensor[2].GetTensorMutableData(); std::vector alpha_vec(alpha_shape[1]); for (int i = 0; i < alpha_shape[1]; i++) { alpha_vec[i] = alpha_data[i]; } std::vector> list_frame; CifSearch(enc_vec, alpha_vec, input_finished, list_frame); if(list_frame.size()>0){ // enc decoder_onnx.insert(decoder_onnx.begin(), std::move(encoder_tensor[0])); // enc_lens decoder_onnx.insert(decoder_onnx.begin()+1, std::move(encoder_tensor[1])); // acoustic_embeds const int64_t emb_shape_[3] = {1, (int64_t)list_frame.size(), (int64_t)list_frame[0].size()}; std::vector emb_input; for (const auto &list_frame_: list_frame) { emb_input.insert(emb_input.end(), list_frame_.begin(), list_frame_.end()); } Ort::Value onnx_emb = Ort::Value::CreateTensor( m_memoryInfo, emb_input.data(), emb_input.size(), emb_shape_, 3); decoder_onnx.insert(decoder_onnx.begin()+2, std::move(onnx_emb)); // acoustic_embeds_len const int64_t emb_length_shape[1] = {1}; std::vector emb_length; emb_length.emplace_back(list_frame.size()); Ort::Value onnx_emb_len = Ort::Value::CreateTensor( m_memoryInfo, emb_length.data(), emb_length.size(), emb_length_shape, 1); decoder_onnx.insert(decoder_onnx.begin()+3, std::move(onnx_emb_len)); auto decoder_tensor = decoder_session_->Run(Ort::RunOptions{nullptr}, de_szInputNames_.data(), decoder_onnx.data(), decoder_onnx.size(), de_szOutputNames_.data(), de_szOutputNames_.size()); // fsmn cache try{ decoder_onnx.clear(); }catch (std::exception const &e) { LOG(ERROR)< decoder_shape = decoder_tensor[0].GetTensorTypeAndShapeInfo().GetShape(); float* float_data = decoder_tensor[0].GetTensorMutableData(); result = para_handle_->GreedySearch(float_data, list_frame.size(), decoder_shape[2]); } }catch (std::exception const &e) { LOG(ERROR)<> wav_feats; std::vector waves(din, din+len); string result=""; try{ if(len <16*60 && input_finished && !is_first_chunk){ is_last_chunk = true; wav_feats = feats_cache_; result = ForwardChunk(wav_feats, is_last_chunk); // reset ResetCache(); Reset(); return result; } if(is_first_chunk){ is_first_chunk = false; } ExtractFeats(MODEL_SAMPLE_RATE, wav_feats, waves, input_finished); if(wav_feats.size() == 0){ return result; } for (auto& row : wav_feats) { for (auto& val : row) { val *= sqrt_factor; } } GetPosEmb(wav_feats, wav_feats.size(), wav_feats[0].size()); if(input_finished){ if(wav_feats.size()+chunk_size[2] <= chunk_size[1]){ is_last_chunk = true; AddOverlapChunk(wav_feats, input_finished); }else{ // first chunk std::vector> first_chunk; first_chunk.insert(first_chunk.begin(), wav_feats.begin(), wav_feats.end()); AddOverlapChunk(first_chunk, input_finished); string str_first_chunk = ForwardChunk(first_chunk, is_last_chunk); // last chunk is_last_chunk = true; std::vector> last_chunk; last_chunk.insert(last_chunk.begin(), wav_feats.end()-(wav_feats.size()+chunk_size[2]-chunk_size[1]), wav_feats.end()); AddOverlapChunk(last_chunk, input_finished); string str_last_chunk = ForwardChunk(last_chunk, is_last_chunk); result = str_first_chunk+str_last_chunk; // reset ResetCache(); Reset(); return result; } }else{ AddOverlapChunk(wav_feats, input_finished); } result = ForwardChunk(wav_feats, is_last_chunk); if(input_finished){ // reset ResetCache(); Reset(); } }catch (std::exception const &e) { LOG(ERROR)<