From 651737380b2be42ae5182a777abb0938a36aedc1 Mon Sep 17 00:00:00 2001
From: jmwang66 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期三, 09 八月 2023 16:48:02 +0800
Subject: [PATCH] Merge branch 'main' into dev_wjm_modelscope
---
funasr/runtime/onnxruntime/src/paraformer-online.cpp | 555 +++++++++++++++++++++++++++++++++++++++++++++++++++++++
1 files changed, 555 insertions(+), 0 deletions(-)
diff --git a/funasr/runtime/onnxruntime/src/paraformer-online.cpp b/funasr/runtime/onnxruntime/src/paraformer-online.cpp
new file mode 100644
index 0000000..267d30a
--- /dev/null
+++ b/funasr/runtime/onnxruntime/src/paraformer-online.cpp
@@ -0,0 +1,555 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License (https://opensource.org/licenses/MIT)
+*/
+
+#include "precomp.h"
+
+using namespace std;
+
+namespace funasr {
+
+ParaformerOnline::ParaformerOnline(Paraformer* para_handle, std::vector<int> chunk_size)
+:para_handle_(std::move(para_handle)),chunk_size(chunk_size),session_options_{}{
+ InitOnline(
+ para_handle_->fbank_opts_,
+ para_handle_->encoder_session_,
+ para_handle_->decoder_session_,
+ para_handle_->en_szInputNames_,
+ para_handle_->en_szOutputNames_,
+ para_handle_->de_szInputNames_,
+ para_handle_->de_szOutputNames_,
+ para_handle_->means_list_,
+ para_handle_->vars_list_);
+ InitCache();
+}
+
+void ParaformerOnline::InitOnline(
+ knf::FbankOptions &fbank_opts,
+ std::shared_ptr<Ort::Session> &encoder_session,
+ std::shared_ptr<Ort::Session> &decoder_session,
+ vector<const char*> &en_szInputNames,
+ vector<const char*> &en_szOutputNames,
+ vector<const char*> &de_szInputNames,
+ vector<const char*> &de_szOutputNames,
+ vector<float> &means_list,
+ vector<float> &vars_list){
+ fbank_opts_ = fbank_opts;
+ encoder_session_ = encoder_session;
+ decoder_session_ = decoder_session;
+ en_szInputNames_ = en_szInputNames;
+ en_szOutputNames_ = en_szOutputNames;
+ de_szInputNames_ = de_szInputNames;
+ de_szOutputNames_ = de_szOutputNames;
+ means_list_ = means_list;
+ vars_list_ = vars_list;
+
+ frame_length = para_handle_->frame_length;
+ frame_shift = para_handle_->frame_shift;
+ n_mels = para_handle_->n_mels;
+ lfr_m = para_handle_->lfr_m;
+ lfr_n = para_handle_->lfr_n;
+ encoder_size = para_handle_->encoder_size;
+ fsmn_layers = para_handle_->fsmn_layers;
+ fsmn_lorder = para_handle_->fsmn_lorder;
+ fsmn_dims = para_handle_->fsmn_dims;
+ cif_threshold = para_handle_->cif_threshold;
+ tail_alphas = para_handle_->tail_alphas;
+
+ // other vars
+ sqrt_factor = std::sqrt(encoder_size);
+ for(int i=0; i<fsmn_lorder*fsmn_dims; i++){
+ fsmn_init_cache_.emplace_back(0);
+ }
+ chunk_len = chunk_size[1]*frame_shift*lfr_n*MODEL_SAMPLE_RATE/1000;
+}
+
+void ParaformerOnline::FbankKaldi(float sample_rate, std::vector<std::vector<float>> &wav_feats,
+ std::vector<float> &waves) {
+ knf::OnlineFbank fbank(fbank_opts_);
+ // cache merge
+ waves.insert(waves.begin(), input_cache_.begin(), input_cache_.end());
+ int frame_number = ComputeFrameNum(waves.size(), frame_sample_length_, frame_shift_sample_length_);
+ // Send the audio after the last frame shift position to the cache
+ input_cache_.clear();
+ input_cache_.insert(input_cache_.begin(), waves.begin() + frame_number * frame_shift_sample_length_, waves.end());
+ if (frame_number == 0) {
+ return;
+ }
+ // Delete audio that haven't undergone fbank processing
+ waves.erase(waves.begin() + (frame_number - 1) * frame_shift_sample_length_ + frame_sample_length_, waves.end());
+
+ std::vector<float> buf(waves.size());
+ for (int32_t i = 0; i != waves.size(); ++i) {
+ buf[i] = waves[i] * 32768;
+ }
+ fbank.AcceptWaveform(sample_rate, buf.data(), buf.size());
+ int32_t frames = fbank.NumFramesReady();
+ for (int32_t i = 0; i != frames; ++i) {
+ const float *frame = fbank.GetFrame(i);
+ vector<float> frame_vector(frame, frame + fbank_opts_.mel_opts.num_bins);
+ wav_feats.emplace_back(frame_vector);
+ }
+}
+
+void ParaformerOnline::ExtractFeats(float sample_rate, vector<std::vector<float>> &wav_feats,
+ vector<float> &waves, bool input_finished) {
+ FbankKaldi(sample_rate, wav_feats, waves);
+ // cache deal & online lfr,cmvn
+ if (wav_feats.size() > 0) {
+ if (!reserve_waveforms_.empty()) {
+ waves.insert(waves.begin(), reserve_waveforms_.begin(), reserve_waveforms_.end());
+ }
+ if (lfr_splice_cache_.empty()) {
+ for (int i = 0; i < (lfr_m - 1) / 2; i++) {
+ lfr_splice_cache_.emplace_back(wav_feats[0]);
+ }
+ }
+ if (wav_feats.size() + lfr_splice_cache_.size() >= lfr_m) {
+ wav_feats.insert(wav_feats.begin(), lfr_splice_cache_.begin(), lfr_splice_cache_.end());
+ int frame_from_waves = (waves.size() - frame_sample_length_) / frame_shift_sample_length_ + 1;
+ int minus_frame = reserve_waveforms_.empty() ? (lfr_m - 1) / 2 : 0;
+ int lfr_splice_frame_idxs = OnlineLfrCmvn(wav_feats, input_finished);
+ int reserve_frame_idx = std::abs(lfr_splice_frame_idxs - minus_frame);
+ reserve_waveforms_.clear();
+ reserve_waveforms_.insert(reserve_waveforms_.begin(),
+ waves.begin() + reserve_frame_idx * frame_shift_sample_length_,
+ waves.begin() + frame_from_waves * frame_shift_sample_length_);
+ int sample_length = (frame_from_waves - 1) * frame_shift_sample_length_ + frame_sample_length_;
+ waves.erase(waves.begin() + sample_length, waves.end());
+ } else {
+ reserve_waveforms_.clear();
+ reserve_waveforms_.insert(reserve_waveforms_.begin(),
+ waves.begin() + frame_sample_length_ - frame_shift_sample_length_, waves.end());
+ lfr_splice_cache_.insert(lfr_splice_cache_.end(), wav_feats.begin(), wav_feats.end());
+ }
+ } else {
+ if (input_finished) {
+ if (!reserve_waveforms_.empty()) {
+ waves = reserve_waveforms_;
+ }
+ wav_feats = lfr_splice_cache_;
+ if(wav_feats.size() == 0){
+ LOG(ERROR) << "wav_feats's size is 0";
+ }else{
+ OnlineLfrCmvn(wav_feats, input_finished);
+ }
+ }
+ }
+ if(input_finished){
+ ResetCache();
+ }
+}
+
+int ParaformerOnline::OnlineLfrCmvn(vector<vector<float>> &wav_feats, bool input_finished) {
+ vector<vector<float>> out_feats;
+ int T = wav_feats.size();
+ int T_lrf = ceil((T - (lfr_m - 1) / 2) / (float)lfr_n);
+ int lfr_splice_frame_idxs = T_lrf;
+ vector<float> p;
+ for (int i = 0; i < T_lrf; i++) {
+ if (lfr_m <= T - i * lfr_n) {
+ for (int j = 0; j < lfr_m; j++) {
+ p.insert(p.end(), wav_feats[i * lfr_n + j].begin(), wav_feats[i * lfr_n + j].end());
+ }
+ out_feats.emplace_back(p);
+ p.clear();
+ } else {
+ if (input_finished) {
+ int num_padding = lfr_m - (T - i * lfr_n);
+ for (int j = 0; j < (wav_feats.size() - i * lfr_n); j++) {
+ p.insert(p.end(), wav_feats[i * lfr_n + j].begin(), wav_feats[i * lfr_n + j].end());
+ }
+ for (int j = 0; j < num_padding; j++) {
+ p.insert(p.end(), wav_feats[wav_feats.size() - 1].begin(), wav_feats[wav_feats.size() - 1].end());
+ }
+ out_feats.emplace_back(p);
+ } else {
+ lfr_splice_frame_idxs = i;
+ break;
+ }
+ }
+ }
+ lfr_splice_frame_idxs = std::min(T - 1, lfr_splice_frame_idxs * lfr_n);
+ lfr_splice_cache_.clear();
+ lfr_splice_cache_.insert(lfr_splice_cache_.begin(), wav_feats.begin() + lfr_splice_frame_idxs, wav_feats.end());
+
+ // Apply cmvn
+ for (auto &out_feat: out_feats) {
+ for (int j = 0; j < means_list_.size(); j++) {
+ out_feat[j] = (out_feat[j] + means_list_[j]) * vars_list_[j];
+ }
+ }
+ wav_feats = out_feats;
+ return lfr_splice_frame_idxs;
+}
+
+void ParaformerOnline::GetPosEmb(std::vector<std::vector<float>> &wav_feats, int timesteps, int feat_dim)
+{
+ int start_idx = start_idx_cache_;
+ start_idx_cache_ += timesteps;
+ int mm = start_idx_cache_;
+
+ int i;
+ float scale = -0.0330119726594128;
+
+ std::vector<float> tmp(mm*feat_dim);
+
+ for (i = 0; i < feat_dim/2; i++) {
+ float tmptime = exp(i * scale);
+ int j;
+ for (j = 0; j < mm; j++) {
+ int sin_idx = j * feat_dim + i;
+ int cos_idx = j * feat_dim + i + feat_dim/2;
+ float coe = tmptime * (j + 1);
+ tmp[sin_idx] = sin(coe);
+ tmp[cos_idx] = cos(coe);
+ }
+ }
+
+ for (i = start_idx; i < start_idx + timesteps; i++) {
+ for (int j = 0; j < feat_dim; j++) {
+ wav_feats[i-start_idx][j] += tmp[i*feat_dim+j];
+ }
+ }
+}
+
+void ParaformerOnline::CifSearch(std::vector<std::vector<float>> hidden, std::vector<float> alphas, bool is_final, std::vector<std::vector<float>>& list_frame)
+{
+ try{
+ int hidden_size = 0;
+ if(hidden.size() > 0){
+ hidden_size = hidden[0].size();
+ }
+ // cache
+ int i,j;
+ int chunk_size_pre = chunk_size[0];
+ for (i = 0; i < chunk_size_pre; i++)
+ alphas[i] = 0.0;
+
+ int chunk_size_suf = std::accumulate(chunk_size.begin(), chunk_size.end()-1, 0);
+ for (i = chunk_size_suf; i < alphas.size(); i++){
+ alphas[i] = 0.0;
+ }
+
+ if(hidden_cache_.size()>0){
+ hidden.insert(hidden.begin(), hidden_cache_.begin(), hidden_cache_.end());
+ alphas.insert(alphas.begin(), alphas_cache_.begin(), alphas_cache_.end());
+ hidden_cache_.clear();
+ alphas_cache_.clear();
+ }
+
+ if (is_last_chunk) {
+ std::vector<float> tail_hidden(hidden_size, 0);
+ hidden.emplace_back(tail_hidden);
+ alphas.emplace_back(tail_alphas);
+ }
+
+ float intergrate = 0.0;
+ int len_time = alphas.size();
+ std::vector<float> frames(hidden_size, 0);
+ std::vector<float> list_fire;
+
+ for (i = 0; i < len_time; i++) {
+ float alpha = alphas[i];
+ if (alpha + intergrate < cif_threshold) {
+ intergrate += alpha;
+ list_fire.emplace_back(intergrate);
+ for (j = 0; j < hidden_size; j++) {
+ frames[j] += alpha * hidden[i][j];
+ }
+ } else {
+ for (j = 0; j < hidden_size; j++) {
+ frames[j] += (cif_threshold - intergrate) * hidden[i][j];
+ }
+ std::vector<float> frames_cp(frames);
+ list_frame.emplace_back(frames_cp);
+ intergrate += alpha;
+ list_fire.emplace_back(intergrate);
+ intergrate -= cif_threshold;
+ for (j = 0; j < hidden_size; j++) {
+ frames[j] = intergrate * hidden[i][j];
+ }
+ }
+ }
+
+ // cache
+ alphas_cache_.emplace_back(intergrate);
+ if (intergrate > 0.0) {
+ std::vector<float> hidden_cache(hidden_size, 0);
+ for (i = 0; i < hidden_size; i++) {
+ hidden_cache[i] = frames[i] / intergrate;
+ }
+ hidden_cache_.emplace_back(hidden_cache);
+ } else {
+ std::vector<float> frames_cp(frames);
+ hidden_cache_.emplace_back(frames_cp);
+ }
+ }catch (std::exception const &e)
+ {
+ LOG(ERROR)<<e.what();
+ }
+}
+
+void ParaformerOnline::InitCache(){
+
+ start_idx_cache_ = 0;
+ is_first_chunk = true;
+ is_last_chunk = false;
+ hidden_cache_.clear();
+ alphas_cache_.clear();
+ feats_cache_.clear();
+ decoder_onnx.clear();
+
+ // cif cache
+ std::vector<float> hidden_cache(encoder_size, 0);
+ hidden_cache_.emplace_back(hidden_cache);
+ alphas_cache_.emplace_back(0);
+
+ // feats
+ std::vector<float> feat_cache(feat_dims, 0);
+ for(int i=0; i<(chunk_size[0]+chunk_size[2]); i++){
+ feats_cache_.emplace_back(feat_cache);
+ }
+
+ // fsmn cache
+#ifdef _WIN_X86
+ Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
+#else
+ Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
+#endif
+ const int64_t fsmn_shape_[3] = {1, fsmn_dims, fsmn_lorder};
+ for(int l=0; l<fsmn_layers; l++){
+ Ort::Value onnx_fsmn_cache = Ort::Value::CreateTensor<float>(
+ m_memoryInfo,
+ fsmn_init_cache_.data(),
+ fsmn_init_cache_.size(),
+ fsmn_shape_,
+ 3);
+ decoder_onnx.emplace_back(std::move(onnx_fsmn_cache));
+ }
+};
+
+void ParaformerOnline::Reset()
+{
+ InitCache();
+}
+
+void ParaformerOnline::ResetCache() {
+ reserve_waveforms_.clear();
+ input_cache_.clear();
+ lfr_splice_cache_.clear();
+}
+
+void ParaformerOnline::AddOverlapChunk(std::vector<std::vector<float>> &wav_feats, bool input_finished){
+ wav_feats.insert(wav_feats.begin(), feats_cache_.begin(), feats_cache_.end());
+ if(input_finished){
+ feats_cache_.clear();
+ feats_cache_.insert(feats_cache_.begin(), wav_feats.end()-chunk_size[0], wav_feats.end());
+ if(!is_last_chunk){
+ int padding_length = std::accumulate(chunk_size.begin(), chunk_size.end(), 0) - wav_feats.size();
+ std::vector<float> tmp(feat_dims, 0);
+ for(int i=0; i<padding_length; i++){
+ wav_feats.emplace_back(feat_dims);
+ }
+ }
+ }else{
+ feats_cache_.clear();
+ feats_cache_.insert(feats_cache_.begin(), wav_feats.end()-chunk_size[0]-chunk_size[2], wav_feats.end());
+ }
+}
+
+string ParaformerOnline::ForwardChunk(std::vector<std::vector<float>> &chunk_feats, bool input_finished)
+{
+ string result;
+ try{
+ int32_t num_frames = chunk_feats.size();
+
+ #ifdef _WIN_X86
+ Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
+ #else
+ Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
+ #endif
+ const int64_t input_shape_[3] = {1, num_frames, feat_dims};
+ std::vector<float> wav_feats;
+ for (const auto &chunk_feat: chunk_feats) {
+ wav_feats.insert(wav_feats.end(), chunk_feat.begin(), chunk_feat.end());
+ }
+ Ort::Value onnx_feats = Ort::Value::CreateTensor<float>(
+ m_memoryInfo,
+ wav_feats.data(),
+ wav_feats.size(),
+ input_shape_,
+ 3);
+
+ const int64_t paraformer_length_shape[1] = {1};
+ std::vector<int32_t> paraformer_length;
+ paraformer_length.emplace_back(num_frames);
+ Ort::Value onnx_feats_len = Ort::Value::CreateTensor<int32_t>(
+ m_memoryInfo, paraformer_length.data(), paraformer_length.size(), paraformer_length_shape, 1);
+
+ std::vector<Ort::Value> input_onnx;
+ input_onnx.emplace_back(std::move(onnx_feats));
+ input_onnx.emplace_back(std::move(onnx_feats_len));
+
+ auto encoder_tensor = encoder_session_->Run(Ort::RunOptions{nullptr}, en_szInputNames_.data(), input_onnx.data(), input_onnx.size(), en_szOutputNames_.data(), en_szOutputNames_.size());
+
+ // get enc_vec
+ std::vector<int64_t> enc_shape = encoder_tensor[0].GetTensorTypeAndShapeInfo().GetShape();
+ float* enc_data = encoder_tensor[0].GetTensorMutableData<float>();
+ std::vector<std::vector<float>> enc_vec(enc_shape[1], std::vector<float>(enc_shape[2]));
+ for (int i = 0; i < enc_shape[1]; i++) {
+ for (int j = 0; j < enc_shape[2]; j++) {
+ enc_vec[i][j] = enc_data[i * enc_shape[2] + j];
+ }
+ }
+
+ // get alpha_vec
+ std::vector<int64_t> alpha_shape = encoder_tensor[2].GetTensorTypeAndShapeInfo().GetShape();
+ float* alpha_data = encoder_tensor[2].GetTensorMutableData<float>();
+ std::vector<float> alpha_vec(alpha_shape[1]);
+ for (int i = 0; i < alpha_shape[1]; i++) {
+ alpha_vec[i] = alpha_data[i];
+ }
+
+ std::vector<std::vector<float>> list_frame;
+ CifSearch(enc_vec, alpha_vec, input_finished, list_frame);
+
+
+ if(list_frame.size()>0){
+ // enc
+ decoder_onnx.insert(decoder_onnx.begin(), std::move(encoder_tensor[0]));
+ // enc_lens
+ decoder_onnx.insert(decoder_onnx.begin()+1, std::move(encoder_tensor[1]));
+
+ // acoustic_embeds
+ const int64_t emb_shape_[3] = {1, (int64_t)list_frame.size(), (int64_t)list_frame[0].size()};
+ std::vector<float> emb_input;
+ for (const auto &list_frame_: list_frame) {
+ emb_input.insert(emb_input.end(), list_frame_.begin(), list_frame_.end());
+ }
+ Ort::Value onnx_emb = Ort::Value::CreateTensor<float>(
+ m_memoryInfo,
+ emb_input.data(),
+ emb_input.size(),
+ emb_shape_,
+ 3);
+ decoder_onnx.insert(decoder_onnx.begin()+2, std::move(onnx_emb));
+
+ // acoustic_embeds_len
+ const int64_t emb_length_shape[1] = {1};
+ std::vector<int32_t> emb_length;
+ emb_length.emplace_back(list_frame.size());
+ Ort::Value onnx_emb_len = Ort::Value::CreateTensor<int32_t>(
+ m_memoryInfo, emb_length.data(), emb_length.size(), emb_length_shape, 1);
+ decoder_onnx.insert(decoder_onnx.begin()+3, std::move(onnx_emb_len));
+
+ auto decoder_tensor = decoder_session_->Run(Ort::RunOptions{nullptr}, de_szInputNames_.data(), decoder_onnx.data(), decoder_onnx.size(), de_szOutputNames_.data(), de_szOutputNames_.size());
+ // fsmn cache
+ try{
+ decoder_onnx.clear();
+ }catch (std::exception const &e)
+ {
+ LOG(ERROR)<<e.what();
+ return result;
+ }
+ for(int l=0;l<fsmn_layers;l++){
+ decoder_onnx.emplace_back(std::move(decoder_tensor[2+l]));
+ }
+
+ std::vector<int64_t> decoder_shape = decoder_tensor[0].GetTensorTypeAndShapeInfo().GetShape();
+ float* float_data = decoder_tensor[0].GetTensorMutableData<float>();
+ result = para_handle_->GreedySearch(float_data, list_frame.size(), decoder_shape[2]);
+ }
+ }catch (std::exception const &e)
+ {
+ LOG(ERROR)<<e.what();
+ return result;
+ }
+ return result;
+}
+
+string ParaformerOnline::Forward(float* din, int len, bool input_finished)
+{
+ std::vector<std::vector<float>> wav_feats;
+ std::vector<float> waves(din, din+len);
+
+ string result="";
+ try{
+ if(len <16*60 && input_finished && !is_first_chunk){
+ is_last_chunk = true;
+ wav_feats = feats_cache_;
+ result = ForwardChunk(wav_feats, is_last_chunk);
+ // reset
+ ResetCache();
+ Reset();
+ return result;
+ }
+ if(is_first_chunk){
+ is_first_chunk = false;
+ }
+ ExtractFeats(MODEL_SAMPLE_RATE, wav_feats, waves, input_finished);
+ if(wav_feats.size() == 0){
+ return result;
+ }
+
+ for (auto& row : wav_feats) {
+ for (auto& val : row) {
+ val *= sqrt_factor;
+ }
+ }
+
+ GetPosEmb(wav_feats, wav_feats.size(), wav_feats[0].size());
+ if(input_finished){
+ if(wav_feats.size()+chunk_size[2] <= chunk_size[1]){
+ is_last_chunk = true;
+ AddOverlapChunk(wav_feats, input_finished);
+ }else{
+ // first chunk
+ std::vector<std::vector<float>> first_chunk;
+ first_chunk.insert(first_chunk.begin(), wav_feats.begin(), wav_feats.end());
+ AddOverlapChunk(first_chunk, input_finished);
+ string str_first_chunk = ForwardChunk(first_chunk, is_last_chunk);
+
+ // last chunk
+ is_last_chunk = true;
+ std::vector<std::vector<float>> last_chunk;
+ last_chunk.insert(last_chunk.begin(), wav_feats.end()-(wav_feats.size()+chunk_size[2]-chunk_size[1]), wav_feats.end());
+ AddOverlapChunk(last_chunk, input_finished);
+ string str_last_chunk = ForwardChunk(last_chunk, is_last_chunk);
+
+ result = str_first_chunk+str_last_chunk;
+ // reset
+ ResetCache();
+ Reset();
+ return result;
+ }
+ }else{
+ AddOverlapChunk(wav_feats, input_finished);
+ }
+
+ result = ForwardChunk(wav_feats, is_last_chunk);
+ if(input_finished){
+ // reset
+ ResetCache();
+ Reset();
+ }
+ }catch (std::exception const &e)
+ {
+ LOG(ERROR)<<e.what();
+ return result;
+ }
+
+ return result;
+}
+
+ParaformerOnline::~ParaformerOnline()
+{
+}
+
+string ParaformerOnline::Rescoring()
+{
+ LOG(ERROR)<<"Not Imp!!!!!!";
+ return "";
+}
+} // namespace funasr
--
Gitblit v1.9.1