From 651737380b2be42ae5182a777abb0938a36aedc1 Mon Sep 17 00:00:00 2001
From: jmwang66 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期三, 09 八月 2023 16:48:02 +0800
Subject: [PATCH] Merge branch 'main' into dev_wjm_modelscope

---
 funasr/runtime/onnxruntime/src/paraformer-online.cpp |  555 +++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 files changed, 555 insertions(+), 0 deletions(-)

diff --git a/funasr/runtime/onnxruntime/src/paraformer-online.cpp b/funasr/runtime/onnxruntime/src/paraformer-online.cpp
new file mode 100644
index 0000000..267d30a
--- /dev/null
+++ b/funasr/runtime/onnxruntime/src/paraformer-online.cpp
@@ -0,0 +1,555 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License  (https://opensource.org/licenses/MIT)
+*/
+
+#include "precomp.h"
+
+using namespace std;
+
+namespace funasr {
+
+ParaformerOnline::ParaformerOnline(Paraformer* para_handle, std::vector<int> chunk_size)
+:para_handle_(std::move(para_handle)),chunk_size(chunk_size),session_options_{}{
+    InitOnline(
+        para_handle_->fbank_opts_,
+        para_handle_->encoder_session_,
+        para_handle_->decoder_session_,
+        para_handle_->en_szInputNames_,
+        para_handle_->en_szOutputNames_,
+        para_handle_->de_szInputNames_,
+        para_handle_->de_szOutputNames_,
+        para_handle_->means_list_,
+        para_handle_->vars_list_);
+    InitCache();
+}
+
+void ParaformerOnline::InitOnline(
+        knf::FbankOptions &fbank_opts,
+        std::shared_ptr<Ort::Session> &encoder_session,
+        std::shared_ptr<Ort::Session> &decoder_session,
+        vector<const char*> &en_szInputNames,
+        vector<const char*> &en_szOutputNames,
+        vector<const char*> &de_szInputNames,
+        vector<const char*> &de_szOutputNames,
+        vector<float> &means_list,
+        vector<float> &vars_list){
+    fbank_opts_ = fbank_opts;
+    encoder_session_ = encoder_session;
+    decoder_session_ = decoder_session;
+    en_szInputNames_ = en_szInputNames;
+    en_szOutputNames_ = en_szOutputNames;
+    de_szInputNames_ = de_szInputNames;
+    de_szOutputNames_ = de_szOutputNames;
+    means_list_ = means_list;
+    vars_list_ = vars_list;
+
+    frame_length = para_handle_->frame_length;
+    frame_shift = para_handle_->frame_shift;
+    n_mels = para_handle_->n_mels;
+    lfr_m = para_handle_->lfr_m;
+    lfr_n = para_handle_->lfr_n;
+    encoder_size = para_handle_->encoder_size;
+    fsmn_layers = para_handle_->fsmn_layers;
+    fsmn_lorder = para_handle_->fsmn_lorder;
+    fsmn_dims = para_handle_->fsmn_dims;
+    cif_threshold = para_handle_->cif_threshold;
+    tail_alphas = para_handle_->tail_alphas;
+
+    // other vars
+    sqrt_factor = std::sqrt(encoder_size);
+    for(int i=0; i<fsmn_lorder*fsmn_dims; i++){
+        fsmn_init_cache_.emplace_back(0);
+    }
+    chunk_len = chunk_size[1]*frame_shift*lfr_n*MODEL_SAMPLE_RATE/1000;
+}
+
+void ParaformerOnline::FbankKaldi(float sample_rate, std::vector<std::vector<float>> &wav_feats,
+                               std::vector<float> &waves) {
+    knf::OnlineFbank fbank(fbank_opts_);
+    // cache merge
+    waves.insert(waves.begin(), input_cache_.begin(), input_cache_.end());
+    int frame_number = ComputeFrameNum(waves.size(), frame_sample_length_, frame_shift_sample_length_);
+    // Send the audio after the last frame shift position to the cache
+    input_cache_.clear();
+    input_cache_.insert(input_cache_.begin(), waves.begin() + frame_number * frame_shift_sample_length_, waves.end());
+    if (frame_number == 0) {
+        return;
+    }
+    // Delete audio that haven't undergone fbank processing
+    waves.erase(waves.begin() + (frame_number - 1) * frame_shift_sample_length_ + frame_sample_length_, waves.end());
+
+    std::vector<float> buf(waves.size());
+    for (int32_t i = 0; i != waves.size(); ++i) {
+        buf[i] = waves[i] * 32768;
+    }
+    fbank.AcceptWaveform(sample_rate, buf.data(), buf.size());
+    int32_t frames = fbank.NumFramesReady();
+    for (int32_t i = 0; i != frames; ++i) {
+        const float *frame = fbank.GetFrame(i);
+        vector<float> frame_vector(frame, frame + fbank_opts_.mel_opts.num_bins);
+        wav_feats.emplace_back(frame_vector);
+    }
+}
+
+void ParaformerOnline::ExtractFeats(float sample_rate, vector<std::vector<float>> &wav_feats,
+                                 vector<float> &waves, bool input_finished) {
+    FbankKaldi(sample_rate, wav_feats, waves);
+    // cache deal & online lfr,cmvn
+    if (wav_feats.size() > 0) {
+        if (!reserve_waveforms_.empty()) {
+        waves.insert(waves.begin(), reserve_waveforms_.begin(), reserve_waveforms_.end());
+        }
+        if (lfr_splice_cache_.empty()) {
+        for (int i = 0; i < (lfr_m - 1) / 2; i++) {
+            lfr_splice_cache_.emplace_back(wav_feats[0]);
+        }
+        }
+        if (wav_feats.size() + lfr_splice_cache_.size() >= lfr_m) {
+        wav_feats.insert(wav_feats.begin(), lfr_splice_cache_.begin(), lfr_splice_cache_.end());
+        int frame_from_waves = (waves.size() - frame_sample_length_) / frame_shift_sample_length_ + 1;
+        int minus_frame = reserve_waveforms_.empty() ? (lfr_m - 1) / 2 : 0;
+        int lfr_splice_frame_idxs = OnlineLfrCmvn(wav_feats, input_finished);
+        int reserve_frame_idx = std::abs(lfr_splice_frame_idxs - minus_frame);
+        reserve_waveforms_.clear();
+        reserve_waveforms_.insert(reserve_waveforms_.begin(),
+                                    waves.begin() + reserve_frame_idx * frame_shift_sample_length_,
+                                    waves.begin() + frame_from_waves * frame_shift_sample_length_);
+        int sample_length = (frame_from_waves - 1) * frame_shift_sample_length_ + frame_sample_length_;
+        waves.erase(waves.begin() + sample_length, waves.end());
+        } else {
+        reserve_waveforms_.clear();
+        reserve_waveforms_.insert(reserve_waveforms_.begin(),
+                                    waves.begin() + frame_sample_length_ - frame_shift_sample_length_, waves.end());
+        lfr_splice_cache_.insert(lfr_splice_cache_.end(), wav_feats.begin(), wav_feats.end());
+        }
+    } else {
+        if (input_finished) {
+            if (!reserve_waveforms_.empty()) {
+                waves = reserve_waveforms_;
+            }
+            wav_feats = lfr_splice_cache_;
+            if(wav_feats.size() == 0){
+                LOG(ERROR) << "wav_feats's size is 0";
+            }else{
+                OnlineLfrCmvn(wav_feats, input_finished);
+            }
+        }
+    }
+    if(input_finished){
+        ResetCache();
+    }
+}
+
+int ParaformerOnline::OnlineLfrCmvn(vector<vector<float>> &wav_feats, bool input_finished) {
+    vector<vector<float>> out_feats;
+    int T = wav_feats.size();
+    int T_lrf = ceil((T - (lfr_m - 1) / 2) / (float)lfr_n);
+    int lfr_splice_frame_idxs = T_lrf;
+    vector<float> p;
+    for (int i = 0; i < T_lrf; i++) {
+        if (lfr_m <= T - i * lfr_n) {
+            for (int j = 0; j < lfr_m; j++) {
+                p.insert(p.end(), wav_feats[i * lfr_n + j].begin(), wav_feats[i * lfr_n + j].end());
+            }
+            out_feats.emplace_back(p);
+            p.clear();
+        } else {
+            if (input_finished) {
+                int num_padding = lfr_m - (T - i * lfr_n);
+                for (int j = 0; j < (wav_feats.size() - i * lfr_n); j++) {
+                    p.insert(p.end(), wav_feats[i * lfr_n + j].begin(), wav_feats[i * lfr_n + j].end());
+                }
+                for (int j = 0; j < num_padding; j++) {
+                    p.insert(p.end(), wav_feats[wav_feats.size() - 1].begin(), wav_feats[wav_feats.size() - 1].end());
+                }
+                out_feats.emplace_back(p);
+            } else {
+                lfr_splice_frame_idxs = i;
+                break;
+            }
+        }
+    }
+    lfr_splice_frame_idxs = std::min(T - 1, lfr_splice_frame_idxs * lfr_n);
+    lfr_splice_cache_.clear();
+    lfr_splice_cache_.insert(lfr_splice_cache_.begin(), wav_feats.begin() + lfr_splice_frame_idxs, wav_feats.end());
+
+    // Apply cmvn
+    for (auto &out_feat: out_feats) {
+        for (int j = 0; j < means_list_.size(); j++) {
+            out_feat[j] = (out_feat[j] + means_list_[j]) * vars_list_[j];
+        }
+    }
+    wav_feats = out_feats;
+    return lfr_splice_frame_idxs;
+}
+
+void ParaformerOnline::GetPosEmb(std::vector<std::vector<float>> &wav_feats, int timesteps, int feat_dim)
+{
+    int start_idx = start_idx_cache_;
+    start_idx_cache_ += timesteps;
+    int mm = start_idx_cache_;
+
+    int i;
+    float scale = -0.0330119726594128;
+
+    std::vector<float> tmp(mm*feat_dim);
+
+    for (i = 0; i < feat_dim/2; i++) {
+        float tmptime = exp(i * scale);
+        int j;
+        for (j = 0; j < mm; j++) {
+            int sin_idx = j * feat_dim + i;
+            int cos_idx = j * feat_dim + i + feat_dim/2;
+            float coe = tmptime * (j + 1);
+            tmp[sin_idx] = sin(coe);
+            tmp[cos_idx] = cos(coe);
+        }
+    }
+
+    for (i = start_idx; i < start_idx + timesteps; i++) {
+        for (int j = 0; j < feat_dim; j++) {
+            wav_feats[i-start_idx][j] += tmp[i*feat_dim+j];
+        }
+    }
+}
+
+void ParaformerOnline::CifSearch(std::vector<std::vector<float>> hidden, std::vector<float> alphas, bool is_final, std::vector<std::vector<float>>& list_frame)
+{
+    try{
+        int hidden_size = 0;
+        if(hidden.size() > 0){
+            hidden_size = hidden[0].size();
+        }
+        // cache
+        int i,j;
+        int chunk_size_pre = chunk_size[0];
+        for (i = 0; i < chunk_size_pre; i++)
+            alphas[i] = 0.0;
+
+        int chunk_size_suf = std::accumulate(chunk_size.begin(), chunk_size.end()-1, 0);
+        for (i = chunk_size_suf; i < alphas.size(); i++){
+            alphas[i] = 0.0;
+        }
+
+        if(hidden_cache_.size()>0){
+            hidden.insert(hidden.begin(), hidden_cache_.begin(), hidden_cache_.end());
+            alphas.insert(alphas.begin(), alphas_cache_.begin(), alphas_cache_.end());
+            hidden_cache_.clear();
+            alphas_cache_.clear();
+        }
+        
+        if (is_last_chunk) {
+            std::vector<float> tail_hidden(hidden_size, 0);
+            hidden.emplace_back(tail_hidden);
+            alphas.emplace_back(tail_alphas);
+        }
+
+        float intergrate = 0.0;
+        int len_time = alphas.size();
+        std::vector<float> frames(hidden_size, 0);
+        std::vector<float> list_fire;
+
+        for (i = 0; i < len_time; i++) {
+            float alpha = alphas[i];
+            if (alpha + intergrate < cif_threshold) {
+                intergrate += alpha;
+                list_fire.emplace_back(intergrate);
+                for (j = 0; j < hidden_size; j++) {
+                    frames[j] += alpha * hidden[i][j];
+                }
+            } else {
+                for (j = 0; j < hidden_size; j++) {
+                    frames[j] += (cif_threshold - intergrate) * hidden[i][j];
+                }
+                std::vector<float> frames_cp(frames);
+                list_frame.emplace_back(frames_cp);
+                intergrate += alpha;
+                list_fire.emplace_back(intergrate);
+                intergrate -= cif_threshold;
+                for (j = 0; j < hidden_size; j++) {
+                    frames[j] = intergrate * hidden[i][j];
+                }
+            }
+        }
+
+        // cache
+        alphas_cache_.emplace_back(intergrate);
+        if (intergrate > 0.0) {
+            std::vector<float> hidden_cache(hidden_size, 0);
+            for (i = 0; i < hidden_size; i++) {
+                hidden_cache[i] = frames[i] / intergrate;
+            }
+            hidden_cache_.emplace_back(hidden_cache);
+        } else {
+            std::vector<float> frames_cp(frames);
+            hidden_cache_.emplace_back(frames_cp);
+        }
+    }catch (std::exception const &e)
+    {
+        LOG(ERROR)<<e.what();
+    }
+}
+
+void ParaformerOnline::InitCache(){
+
+    start_idx_cache_ = 0;
+    is_first_chunk = true;
+    is_last_chunk = false;
+    hidden_cache_.clear();
+    alphas_cache_.clear();
+    feats_cache_.clear();
+    decoder_onnx.clear();
+
+    // cif cache
+    std::vector<float> hidden_cache(encoder_size, 0);
+    hidden_cache_.emplace_back(hidden_cache);
+    alphas_cache_.emplace_back(0);
+
+    // feats
+    std::vector<float> feat_cache(feat_dims, 0);
+    for(int i=0; i<(chunk_size[0]+chunk_size[2]); i++){
+        feats_cache_.emplace_back(feat_cache);
+    }
+
+    // fsmn cache
+#ifdef _WIN_X86
+    Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
+#else
+    Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
+#endif
+    const int64_t fsmn_shape_[3] = {1, fsmn_dims, fsmn_lorder};
+    for(int l=0; l<fsmn_layers; l++){
+        Ort::Value onnx_fsmn_cache = Ort::Value::CreateTensor<float>(
+            m_memoryInfo,
+            fsmn_init_cache_.data(),
+            fsmn_init_cache_.size(),
+            fsmn_shape_,
+            3);
+        decoder_onnx.emplace_back(std::move(onnx_fsmn_cache));
+    }
+};
+
+void ParaformerOnline::Reset()
+{
+    InitCache();
+}
+
+void ParaformerOnline::ResetCache() {
+    reserve_waveforms_.clear();
+    input_cache_.clear();
+    lfr_splice_cache_.clear();
+}
+
+void ParaformerOnline::AddOverlapChunk(std::vector<std::vector<float>> &wav_feats, bool input_finished){
+    wav_feats.insert(wav_feats.begin(), feats_cache_.begin(), feats_cache_.end());
+    if(input_finished){
+        feats_cache_.clear();
+        feats_cache_.insert(feats_cache_.begin(), wav_feats.end()-chunk_size[0], wav_feats.end());
+        if(!is_last_chunk){
+            int padding_length = std::accumulate(chunk_size.begin(), chunk_size.end(), 0) - wav_feats.size();
+            std::vector<float> tmp(feat_dims, 0);
+            for(int i=0; i<padding_length; i++){
+                wav_feats.emplace_back(feat_dims);
+            }
+        }
+    }else{
+        feats_cache_.clear();
+        feats_cache_.insert(feats_cache_.begin(), wav_feats.end()-chunk_size[0]-chunk_size[2], wav_feats.end());        
+    }
+}
+
+string ParaformerOnline::ForwardChunk(std::vector<std::vector<float>> &chunk_feats, bool input_finished)
+{
+    string result;
+    try{
+        int32_t num_frames = chunk_feats.size();
+
+    #ifdef _WIN_X86
+            Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
+    #else
+            Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
+    #endif
+        const int64_t input_shape_[3] = {1, num_frames, feat_dims};
+        std::vector<float> wav_feats;
+        for (const auto &chunk_feat: chunk_feats) {
+            wav_feats.insert(wav_feats.end(), chunk_feat.begin(), chunk_feat.end());
+        }
+        Ort::Value onnx_feats = Ort::Value::CreateTensor<float>(
+            m_memoryInfo,
+            wav_feats.data(),
+            wav_feats.size(),
+            input_shape_,
+            3);
+
+        const int64_t paraformer_length_shape[1] = {1};
+        std::vector<int32_t> paraformer_length;
+        paraformer_length.emplace_back(num_frames);
+        Ort::Value onnx_feats_len = Ort::Value::CreateTensor<int32_t>(
+            m_memoryInfo, paraformer_length.data(), paraformer_length.size(), paraformer_length_shape, 1);
+        
+        std::vector<Ort::Value> input_onnx;
+        input_onnx.emplace_back(std::move(onnx_feats));
+        input_onnx.emplace_back(std::move(onnx_feats_len)); 
+        
+        auto encoder_tensor = encoder_session_->Run(Ort::RunOptions{nullptr}, en_szInputNames_.data(), input_onnx.data(), input_onnx.size(), en_szOutputNames_.data(), en_szOutputNames_.size());
+
+        // get enc_vec
+        std::vector<int64_t> enc_shape = encoder_tensor[0].GetTensorTypeAndShapeInfo().GetShape();
+        float* enc_data = encoder_tensor[0].GetTensorMutableData<float>();
+        std::vector<std::vector<float>> enc_vec(enc_shape[1], std::vector<float>(enc_shape[2]));
+        for (int i = 0; i < enc_shape[1]; i++) {
+            for (int j = 0; j < enc_shape[2]; j++) {
+                enc_vec[i][j] = enc_data[i * enc_shape[2] + j];
+            }
+        }
+
+        // get alpha_vec
+        std::vector<int64_t> alpha_shape = encoder_tensor[2].GetTensorTypeAndShapeInfo().GetShape();
+        float* alpha_data = encoder_tensor[2].GetTensorMutableData<float>();
+        std::vector<float> alpha_vec(alpha_shape[1]);
+        for (int i = 0; i < alpha_shape[1]; i++) {
+            alpha_vec[i] = alpha_data[i];
+        } 
+
+        std::vector<std::vector<float>> list_frame;
+        CifSearch(enc_vec, alpha_vec, input_finished, list_frame);
+
+        
+        if(list_frame.size()>0){
+            // enc
+            decoder_onnx.insert(decoder_onnx.begin(), std::move(encoder_tensor[0]));
+            // enc_lens
+            decoder_onnx.insert(decoder_onnx.begin()+1, std::move(encoder_tensor[1]));
+
+            // acoustic_embeds
+            const int64_t emb_shape_[3] = {1, (int64_t)list_frame.size(), (int64_t)list_frame[0].size()};
+            std::vector<float> emb_input;
+            for (const auto &list_frame_: list_frame) {
+                emb_input.insert(emb_input.end(), list_frame_.begin(), list_frame_.end());
+            }
+            Ort::Value onnx_emb = Ort::Value::CreateTensor<float>(
+                m_memoryInfo,
+                emb_input.data(),
+                emb_input.size(),
+                emb_shape_,
+                3);
+            decoder_onnx.insert(decoder_onnx.begin()+2, std::move(onnx_emb));
+
+            // acoustic_embeds_len
+            const int64_t emb_length_shape[1] = {1};
+            std::vector<int32_t> emb_length;
+            emb_length.emplace_back(list_frame.size());
+            Ort::Value onnx_emb_len = Ort::Value::CreateTensor<int32_t>(
+                m_memoryInfo, emb_length.data(), emb_length.size(), emb_length_shape, 1);
+            decoder_onnx.insert(decoder_onnx.begin()+3, std::move(onnx_emb_len));
+
+            auto decoder_tensor = decoder_session_->Run(Ort::RunOptions{nullptr}, de_szInputNames_.data(), decoder_onnx.data(), decoder_onnx.size(), de_szOutputNames_.data(), de_szOutputNames_.size());
+            // fsmn cache
+            try{
+                decoder_onnx.clear();
+            }catch (std::exception const &e)
+            {
+                LOG(ERROR)<<e.what();
+                return result;
+            }
+            for(int l=0;l<fsmn_layers;l++){
+                decoder_onnx.emplace_back(std::move(decoder_tensor[2+l]));
+            }
+
+            std::vector<int64_t> decoder_shape = decoder_tensor[0].GetTensorTypeAndShapeInfo().GetShape();
+            float* float_data = decoder_tensor[0].GetTensorMutableData<float>();
+            result = para_handle_->GreedySearch(float_data, list_frame.size(), decoder_shape[2]);
+        }
+    }catch (std::exception const &e)
+    {
+        LOG(ERROR)<<e.what();
+        return result;
+    }
+    return result;
+}
+
+string ParaformerOnline::Forward(float* din, int len, bool input_finished)
+{
+    std::vector<std::vector<float>> wav_feats;
+    std::vector<float> waves(din, din+len);
+
+    string result="";
+    try{
+        if(len <16*60 && input_finished && !is_first_chunk){
+            is_last_chunk = true;
+            wav_feats = feats_cache_;
+            result = ForwardChunk(wav_feats, is_last_chunk);
+            // reset
+            ResetCache();
+            Reset();
+            return result;
+        }
+        if(is_first_chunk){
+            is_first_chunk = false;
+        }
+        ExtractFeats(MODEL_SAMPLE_RATE, wav_feats, waves, input_finished);
+        if(wav_feats.size() == 0){
+            return result;
+        }
+        
+        for (auto& row : wav_feats) {
+            for (auto& val : row) {
+                val *= sqrt_factor;
+            }
+        }
+
+        GetPosEmb(wav_feats, wav_feats.size(), wav_feats[0].size());
+        if(input_finished){
+            if(wav_feats.size()+chunk_size[2] <= chunk_size[1]){
+                is_last_chunk = true;
+                AddOverlapChunk(wav_feats, input_finished);
+            }else{
+                // first chunk
+                std::vector<std::vector<float>> first_chunk;
+                first_chunk.insert(first_chunk.begin(), wav_feats.begin(), wav_feats.end());
+                AddOverlapChunk(first_chunk, input_finished);
+                string str_first_chunk = ForwardChunk(first_chunk, is_last_chunk);
+
+                // last chunk
+                is_last_chunk = true;
+                std::vector<std::vector<float>> last_chunk;
+                last_chunk.insert(last_chunk.begin(), wav_feats.end()-(wav_feats.size()+chunk_size[2]-chunk_size[1]), wav_feats.end());
+                AddOverlapChunk(last_chunk, input_finished);
+                string str_last_chunk = ForwardChunk(last_chunk, is_last_chunk);
+
+                result = str_first_chunk+str_last_chunk;
+                // reset
+                ResetCache();
+                Reset();
+                return result;
+            }
+        }else{
+            AddOverlapChunk(wav_feats, input_finished);
+        }
+
+        result = ForwardChunk(wav_feats, is_last_chunk);
+        if(input_finished){
+            // reset
+            ResetCache();
+            Reset();
+        }
+    }catch (std::exception const &e)
+    {
+        LOG(ERROR)<<e.what();
+        return result;
+    }
+
+    return result;
+}
+
+ParaformerOnline::~ParaformerOnline()
+{
+}
+
+string ParaformerOnline::Rescoring()
+{
+    LOG(ERROR)<<"Not Imp!!!!!!";
+    return "";
+}
+} // namespace funasr

--
Gitblit v1.9.1