/** * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. * MIT License (https://opensource.org/licenses/MIT) */ #pragma once #include "precomp.h" namespace funasr { class ParaformerOnline : public Model { /** * Author: Speech Lab of DAMO Academy, Alibaba Group * ParaformerOnline: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition * https://arxiv.org/pdf/2206.08317.pdf */ private: void FbankKaldi(float sample_rate, std::vector> &wav_feats, std::vector &waves); int OnlineLfrCmvn(vector> &wav_feats, bool input_finished); void GetPosEmb(std::vector> &wav_feats, int timesteps, int feat_dim); void CifSearch(std::vector> hidden, std::vector alphas, bool is_final, std::vector> &list_frame); static int ComputeFrameNum(int sample_length, int frame_sample_length, int frame_shift_sample_length) { int frame_num = static_cast((sample_length - frame_sample_length) / frame_shift_sample_length + 1); if (frame_num >= 1 && sample_length >= frame_sample_length) return frame_num; else return 0; } void InitOnline( knf::FbankOptions &fbank_opts, std::shared_ptr &encoder_session, std::shared_ptr &decoder_session, vector &en_szInputNames, vector &en_szOutputNames, vector &de_szInputNames, vector &de_szOutputNames, vector &means_list, vector &vars_list); void StartUtterance() { } void EndUtterance() { } Paraformer* para_handle_ = nullptr; // from para_handle_ knf::FbankOptions fbank_opts_; std::shared_ptr encoder_session_ = nullptr; std::shared_ptr decoder_session_ = nullptr; Ort::SessionOptions session_options_; vector en_szInputNames_; vector en_szOutputNames_; vector de_szInputNames_; vector de_szOutputNames_; vector means_list_; vector vars_list_; // configs from para_handle_ int frame_length = 25; int frame_shift = 10; int n_mels = 80; int lfr_m = PARA_LFR_M; int lfr_n = PARA_LFR_N; int encoder_size = 512; int fsmn_layers = 16; int fsmn_lorder = 10; int fsmn_dims = 512; float cif_threshold = 1.0; float tail_alphas = 0.45; // configs int feat_dims = lfr_m*n_mels; std::vector chunk_size = {5,10,5}; int frame_sample_length_ = MODEL_SAMPLE_RATE / 1000 * frame_length; int frame_shift_sample_length_ = MODEL_SAMPLE_RATE / 1000 * frame_shift; // The reserved waveforms by fbank std::vector reserve_waveforms_; // waveforms reserved after last shift position std::vector input_cache_; // lfr reserved cache std::vector> lfr_splice_cache_; // position index cache int start_idx_cache_ = 0; // cif alpha std::vector alphas_cache_; std::vector> hidden_cache_; std::vector> feats_cache_; // fsmn init caches std::vector fsmn_init_cache_; std::vector decoder_onnx; bool is_first_chunk = true; bool is_last_chunk = false; double sqrt_factor; public: ParaformerOnline(Paraformer* para_handle, std::vector chunk_size); ~ParaformerOnline(); void Reset(); void ResetCache(); void InitCache(); void ExtractFeats(float sample_rate, vector> &wav_feats, vector &waves, bool input_finished); void AddOverlapChunk(std::vector> &wav_feats, bool input_finished); string ForwardChunk(std::vector> &wav_feats, bool input_finished); string Forward(float* din, int len, bool input_finished, const std::vector> &hw_emb={{0.0}}, void* wfst_decoder=nullptr); string Rescoring(); int GetAsrSampleRate() { return para_handle_->asr_sample_rate; }; // 2pass std::string online_res; int chunk_len; }; } // namespace funasr