jmwang66
2023-08-09 651737380b2be42ae5182a777abb0938a36aedc1
funasr/runtime/onnxruntime/src/paraformer-online.cpp
New file
@@ -0,0 +1,555 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 * MIT License  (https://opensource.org/licenses/MIT)
*/
#include "precomp.h"
using namespace std;
namespace funasr {
/**
 * Builds a streaming recognizer on top of an existing Paraformer handle.
 *
 * The handle's fbank options, ONNX sessions, tensor names and CMVN vectors are
 * forwarded to InitOnline(), after which the streaming caches are set up by
 * InitCache().
 *
 * @param para_handle  Handle whose members are read here; it is dereferenced
 *                     without a null check, so it must be non-null.
 * @param chunk_size   Chunk layout; indices 0, 1 and 2 are read by this class.
 */
ParaformerOnline::ParaformerOnline(Paraformer* para_handle, std::vector<int> chunk_size)
:para_handle_(para_handle),chunk_size(chunk_size),session_options_{}{
    auto &offline = *para_handle_;
    InitOnline(offline.fbank_opts_,
               offline.encoder_session_,
               offline.decoder_session_,
               offline.en_szInputNames_,
               offline.en_szOutputNames_,
               offline.de_szInputNames_,
               offline.de_szOutputNames_,
               offline.means_list_,
               offline.vars_list_);
    InitCache();
}
void ParaformerOnline::InitOnline(
        knf::FbankOptions &fbank_opts,
        std::shared_ptr<Ort::Session> &encoder_session,
        std::shared_ptr<Ort::Session> &decoder_session,
        vector<const char*> &en_szInputNames,
        vector<const char*> &en_szOutputNames,
        vector<const char*> &de_szInputNames,
        vector<const char*> &de_szOutputNames,
        vector<float> &means_list,
        vector<float> &vars_list){
    fbank_opts_ = fbank_opts;
    encoder_session_ = encoder_session;
    decoder_session_ = decoder_session;
    en_szInputNames_ = en_szInputNames;
    en_szOutputNames_ = en_szOutputNames;
    de_szInputNames_ = de_szInputNames;
    de_szOutputNames_ = de_szOutputNames;
    means_list_ = means_list;
    vars_list_ = vars_list;
    frame_length = para_handle_->frame_length;
    frame_shift = para_handle_->frame_shift;
    n_mels = para_handle_->n_mels;
    lfr_m = para_handle_->lfr_m;
    lfr_n = para_handle_->lfr_n;
    encoder_size = para_handle_->encoder_size;
    fsmn_layers = para_handle_->fsmn_layers;
    fsmn_lorder = para_handle_->fsmn_lorder;
    fsmn_dims = para_handle_->fsmn_dims;
    cif_threshold = para_handle_->cif_threshold;
    tail_alphas = para_handle_->tail_alphas;
    // other vars
    sqrt_factor = std::sqrt(encoder_size);
    for(int i=0; i<fsmn_lorder*fsmn_dims; i++){
        fsmn_init_cache_.emplace_back(0);
    }
    chunk_len = chunk_size[1]*frame_shift*lfr_n*MODEL_SAMPLE_RATE/1000;
}
void ParaformerOnline::FbankKaldi(float sample_rate, std::vector<std::vector<float>> &wav_feats,
                               std::vector<float> &waves) {
    knf::OnlineFbank fbank(fbank_opts_);
    // cache merge
    waves.insert(waves.begin(), input_cache_.begin(), input_cache_.end());
    int frame_number = ComputeFrameNum(waves.size(), frame_sample_length_, frame_shift_sample_length_);
    // Send the audio after the last frame shift position to the cache
    input_cache_.clear();
    input_cache_.insert(input_cache_.begin(), waves.begin() + frame_number * frame_shift_sample_length_, waves.end());
    if (frame_number == 0) {
        return;
    }
    // Delete audio that haven't undergone fbank processing
    waves.erase(waves.begin() + (frame_number - 1) * frame_shift_sample_length_ + frame_sample_length_, waves.end());
    std::vector<float> buf(waves.size());
    for (int32_t i = 0; i != waves.size(); ++i) {
        buf[i] = waves[i] * 32768;
    }
    fbank.AcceptWaveform(sample_rate, buf.data(), buf.size());
    int32_t frames = fbank.NumFramesReady();
    for (int32_t i = 0; i != frames; ++i) {
        const float *frame = fbank.GetFrame(i);
        vector<float> frame_vector(frame, frame + fbank_opts_.mel_opts.num_bins);
        wav_feats.emplace_back(frame_vector);
    }
}
/**
 * Streaming feature extraction: fbank via FbankKaldi(), then LFR splicing and
 * CMVN via OnlineLfrCmvn(), with state carried across calls in
 * reserve_waveforms_ and lfr_splice_cache_.
 *
 * @param sample_rate     Forwarded to FbankKaldi().
 * @param wav_feats       Out: LFR+CMVN features when enough frames are
 *                        available; otherwise whatever FbankKaldi() produced
 *                        (possibly empty).
 * @param waves           In/out sample buffer; rewritten to the samples that
 *                        correspond to the processed frames.
 * @param input_finished  When true, remaining cached frames are flushed (with
 *                        padding inside OnlineLfrCmvn()) and ResetCache() is
 *                        called before returning.
 */
void ParaformerOnline::ExtractFeats(float sample_rate, vector<std::vector<float>> &wav_feats,
                                 vector<float> &waves, bool input_finished) {
    FbankKaldi(sample_rate, wav_feats, waves);
    // cache deal & online lfr,cmvn
    if (wav_feats.size() > 0) {
        // Put the waveform kept from the previous call back in front of the new samples.
        if (!reserve_waveforms_.empty()) {
        waves.insert(waves.begin(), reserve_waveforms_.begin(), reserve_waveforms_.end());
        }
        // First call of a stream: seed the splice cache with (lfr_m - 1) / 2 copies of the first frame.
        if (lfr_splice_cache_.empty()) {
        for (int i = 0; i < (lfr_m - 1) / 2; i++) {
            lfr_splice_cache_.emplace_back(wav_feats[0]);
        }
        }
        if (wav_feats.size() + lfr_splice_cache_.size() >= lfr_m) {
        // Enough frames for at least one LFR window: splice cached + new frames.
        wav_feats.insert(wav_feats.begin(), lfr_splice_cache_.begin(), lfr_splice_cache_.end());
        int frame_from_waves = (waves.size() - frame_sample_length_) / frame_shift_sample_length_ + 1;
        // The seeded copies have no waveform behind them, so they are discounted
        // only when no waveform was reserved (i.e. the cache was just seeded).
        int minus_frame = reserve_waveforms_.empty() ? (lfr_m - 1) / 2 : 0;
        // OnlineLfrCmvn() replaces wav_feats with the LFR output and refills lfr_splice_cache_.
        int lfr_splice_frame_idxs = OnlineLfrCmvn(wav_feats, input_finished);
        int reserve_frame_idx = std::abs(lfr_splice_frame_idxs - minus_frame);
        // Keep the waveform starting at the first frame that was left in the splice cache.
        reserve_waveforms_.clear();
        reserve_waveforms_.insert(reserve_waveforms_.begin(),
                                    waves.begin() + reserve_frame_idx * frame_shift_sample_length_,
                                    waves.begin() + frame_from_waves * frame_shift_sample_length_);
        // Trim waves to the samples covered by frame_from_waves frames.
        int sample_length = (frame_from_waves - 1) * frame_shift_sample_length_ + frame_sample_length_;
        waves.erase(waves.begin() + sample_length, waves.end());
        } else {
        // Not enough frames for one LFR window yet: stash the waveform and the
        // frames and wait for more input. wav_feats is left as raw fbank frames.
        reserve_waveforms_.clear();
        reserve_waveforms_.insert(reserve_waveforms_.begin(),
                                    waves.begin() + frame_sample_length_ - frame_shift_sample_length_, waves.end());
        lfr_splice_cache_.insert(lfr_splice_cache_.end(), wav_feats.begin(), wav_feats.end());
        }
    } else {
        // No new fbank frame was produced by this call.
        if (input_finished) {
            // End of stream: flush the frames still held in the splice cache.
            if (!reserve_waveforms_.empty()) {
                waves = reserve_waveforms_;
            }
            wav_feats = lfr_splice_cache_;
            if(wav_feats.size() == 0){
                LOG(ERROR) << "wav_feats's size is 0";
            }else{
                OnlineLfrCmvn(wav_feats, input_finished);
            }
        }
    }
    // End of stream: drop all front-end caches so the next stream starts clean.
    if(input_finished){
        ResetCache();
    }
}
int ParaformerOnline::OnlineLfrCmvn(vector<vector<float>> &wav_feats, bool input_finished) {
    vector<vector<float>> out_feats;
    int T = wav_feats.size();
    int T_lrf = ceil((T - (lfr_m - 1) / 2) / (float)lfr_n);
    int lfr_splice_frame_idxs = T_lrf;
    vector<float> p;
    for (int i = 0; i < T_lrf; i++) {
        if (lfr_m <= T - i * lfr_n) {
            for (int j = 0; j < lfr_m; j++) {
                p.insert(p.end(), wav_feats[i * lfr_n + j].begin(), wav_feats[i * lfr_n + j].end());
            }
            out_feats.emplace_back(p);
            p.clear();
        } else {
            if (input_finished) {
                int num_padding = lfr_m - (T - i * lfr_n);
                for (int j = 0; j < (wav_feats.size() - i * lfr_n); j++) {
                    p.insert(p.end(), wav_feats[i * lfr_n + j].begin(), wav_feats[i * lfr_n + j].end());
                }
                for (int j = 0; j < num_padding; j++) {
                    p.insert(p.end(), wav_feats[wav_feats.size() - 1].begin(), wav_feats[wav_feats.size() - 1].end());
                }
                out_feats.emplace_back(p);
            } else {
                lfr_splice_frame_idxs = i;
                break;
            }
        }
    }
    lfr_splice_frame_idxs = std::min(T - 1, lfr_splice_frame_idxs * lfr_n);
    lfr_splice_cache_.clear();
    lfr_splice_cache_.insert(lfr_splice_cache_.begin(), wav_feats.begin() + lfr_splice_frame_idxs, wav_feats.end());
    // Apply cmvn
    for (auto &out_feat: out_feats) {
        for (int j = 0; j < means_list_.size(); j++) {
            out_feat[j] = (out_feat[j] + means_list_[j]) * vars_list_[j];
        }
    }
    wav_feats = out_feats;
    return lfr_splice_frame_idxs;
}
void ParaformerOnline::GetPosEmb(std::vector<std::vector<float>> &wav_feats, int timesteps, int feat_dim)
{
    int start_idx = start_idx_cache_;
    start_idx_cache_ += timesteps;
    int mm = start_idx_cache_;
    int i;
    float scale = -0.0330119726594128;
    std::vector<float> tmp(mm*feat_dim);
    for (i = 0; i < feat_dim/2; i++) {
        float tmptime = exp(i * scale);
        int j;
        for (j = 0; j < mm; j++) {
            int sin_idx = j * feat_dim + i;
            int cos_idx = j * feat_dim + i + feat_dim/2;
            float coe = tmptime * (j + 1);
            tmp[sin_idx] = sin(coe);
            tmp[cos_idx] = cos(coe);
        }
    }
    for (i = start_idx; i < start_idx + timesteps; i++) {
        for (int j = 0; j < feat_dim; j++) {
            wav_feats[i-start_idx][j] += tmp[i*feat_dim+j];
        }
    }
}
void ParaformerOnline::CifSearch(std::vector<std::vector<float>> hidden, std::vector<float> alphas, bool is_final, std::vector<std::vector<float>>& list_frame)
{
    try{
        int hidden_size = 0;
        if(hidden.size() > 0){
            hidden_size = hidden[0].size();
        }
        // cache
        int i,j;
        int chunk_size_pre = chunk_size[0];
        for (i = 0; i < chunk_size_pre; i++)
            alphas[i] = 0.0;
        int chunk_size_suf = std::accumulate(chunk_size.begin(), chunk_size.end()-1, 0);
        for (i = chunk_size_suf; i < alphas.size(); i++){
            alphas[i] = 0.0;
        }
        if(hidden_cache_.size()>0){
            hidden.insert(hidden.begin(), hidden_cache_.begin(), hidden_cache_.end());
            alphas.insert(alphas.begin(), alphas_cache_.begin(), alphas_cache_.end());
            hidden_cache_.clear();
            alphas_cache_.clear();
        }
        if (is_last_chunk) {
            std::vector<float> tail_hidden(hidden_size, 0);
            hidden.emplace_back(tail_hidden);
            alphas.emplace_back(tail_alphas);
        }
        float intergrate = 0.0;
        int len_time = alphas.size();
        std::vector<float> frames(hidden_size, 0);
        std::vector<float> list_fire;
        for (i = 0; i < len_time; i++) {
            float alpha = alphas[i];
            if (alpha + intergrate < cif_threshold) {
                intergrate += alpha;
                list_fire.emplace_back(intergrate);
                for (j = 0; j < hidden_size; j++) {
                    frames[j] += alpha * hidden[i][j];
                }
            } else {
                for (j = 0; j < hidden_size; j++) {
                    frames[j] += (cif_threshold - intergrate) * hidden[i][j];
                }
                std::vector<float> frames_cp(frames);
                list_frame.emplace_back(frames_cp);
                intergrate += alpha;
                list_fire.emplace_back(intergrate);
                intergrate -= cif_threshold;
                for (j = 0; j < hidden_size; j++) {
                    frames[j] = intergrate * hidden[i][j];
                }
            }
        }
        // cache
        alphas_cache_.emplace_back(intergrate);
        if (intergrate > 0.0) {
            std::vector<float> hidden_cache(hidden_size, 0);
            for (i = 0; i < hidden_size; i++) {
                hidden_cache[i] = frames[i] / intergrate;
            }
            hidden_cache_.emplace_back(hidden_cache);
        } else {
            std::vector<float> frames_cp(frames);
            hidden_cache_.emplace_back(frames_cp);
        }
    }catch (std::exception const &e)
    {
        LOG(ERROR)<<e.what();
    }
}
void ParaformerOnline::InitCache(){
    start_idx_cache_ = 0;
    is_first_chunk = true;
    is_last_chunk = false;
    hidden_cache_.clear();
    alphas_cache_.clear();
    feats_cache_.clear();
    decoder_onnx.clear();
    // cif cache
    std::vector<float> hidden_cache(encoder_size, 0);
    hidden_cache_.emplace_back(hidden_cache);
    alphas_cache_.emplace_back(0);
    // feats
    std::vector<float> feat_cache(feat_dims, 0);
    for(int i=0; i<(chunk_size[0]+chunk_size[2]); i++){
        feats_cache_.emplace_back(feat_cache);
    }
    // fsmn cache
#ifdef _WIN_X86
    Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
#else
    Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
#endif
    const int64_t fsmn_shape_[3] = {1, fsmn_dims, fsmn_lorder};
    for(int l=0; l<fsmn_layers; l++){
        Ort::Value onnx_fsmn_cache = Ort::Value::CreateTensor<float>(
            m_memoryInfo,
            fsmn_init_cache_.data(),
            fsmn_init_cache_.size(),
            fsmn_shape_,
            3);
        decoder_onnx.emplace_back(std::move(onnx_fsmn_cache));
    }
};
void ParaformerOnline::Reset()
{
    InitCache();
}
void ParaformerOnline::ResetCache() {
    reserve_waveforms_.clear();
    input_cache_.clear();
    lfr_splice_cache_.clear();
}
void ParaformerOnline::AddOverlapChunk(std::vector<std::vector<float>> &wav_feats, bool input_finished){
    wav_feats.insert(wav_feats.begin(), feats_cache_.begin(), feats_cache_.end());
    if(input_finished){
        feats_cache_.clear();
        feats_cache_.insert(feats_cache_.begin(), wav_feats.end()-chunk_size[0], wav_feats.end());
        if(!is_last_chunk){
            int padding_length = std::accumulate(chunk_size.begin(), chunk_size.end(), 0) - wav_feats.size();
            std::vector<float> tmp(feat_dims, 0);
            for(int i=0; i<padding_length; i++){
                wav_feats.emplace_back(feat_dims);
            }
        }
    }else{
        feats_cache_.clear();
        feats_cache_.insert(feats_cache_.begin(), wav_feats.end()-chunk_size[0]-chunk_size[2], wav_feats.end());
    }
}
string ParaformerOnline::ForwardChunk(std::vector<std::vector<float>> &chunk_feats, bool input_finished)
{
    string result;
    try{
        int32_t num_frames = chunk_feats.size();
    #ifdef _WIN_X86
            Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
    #else
            Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
    #endif
        const int64_t input_shape_[3] = {1, num_frames, feat_dims};
        std::vector<float> wav_feats;
        for (const auto &chunk_feat: chunk_feats) {
            wav_feats.insert(wav_feats.end(), chunk_feat.begin(), chunk_feat.end());
        }
        Ort::Value onnx_feats = Ort::Value::CreateTensor<float>(
            m_memoryInfo,
            wav_feats.data(),
            wav_feats.size(),
            input_shape_,
            3);
        const int64_t paraformer_length_shape[1] = {1};
        std::vector<int32_t> paraformer_length;
        paraformer_length.emplace_back(num_frames);
        Ort::Value onnx_feats_len = Ort::Value::CreateTensor<int32_t>(
            m_memoryInfo, paraformer_length.data(), paraformer_length.size(), paraformer_length_shape, 1);
        std::vector<Ort::Value> input_onnx;
        input_onnx.emplace_back(std::move(onnx_feats));
        input_onnx.emplace_back(std::move(onnx_feats_len));
        auto encoder_tensor = encoder_session_->Run(Ort::RunOptions{nullptr}, en_szInputNames_.data(), input_onnx.data(), input_onnx.size(), en_szOutputNames_.data(), en_szOutputNames_.size());
        // get enc_vec
        std::vector<int64_t> enc_shape = encoder_tensor[0].GetTensorTypeAndShapeInfo().GetShape();
        float* enc_data = encoder_tensor[0].GetTensorMutableData<float>();
        std::vector<std::vector<float>> enc_vec(enc_shape[1], std::vector<float>(enc_shape[2]));
        for (int i = 0; i < enc_shape[1]; i++) {
            for (int j = 0; j < enc_shape[2]; j++) {
                enc_vec[i][j] = enc_data[i * enc_shape[2] + j];
            }
        }
        // get alpha_vec
        std::vector<int64_t> alpha_shape = encoder_tensor[2].GetTensorTypeAndShapeInfo().GetShape();
        float* alpha_data = encoder_tensor[2].GetTensorMutableData<float>();
        std::vector<float> alpha_vec(alpha_shape[1]);
        for (int i = 0; i < alpha_shape[1]; i++) {
            alpha_vec[i] = alpha_data[i];
        }
        std::vector<std::vector<float>> list_frame;
        CifSearch(enc_vec, alpha_vec, input_finished, list_frame);
        if(list_frame.size()>0){
            // enc
            decoder_onnx.insert(decoder_onnx.begin(), std::move(encoder_tensor[0]));
            // enc_lens
            decoder_onnx.insert(decoder_onnx.begin()+1, std::move(encoder_tensor[1]));
            // acoustic_embeds
            const int64_t emb_shape_[3] = {1, (int64_t)list_frame.size(), (int64_t)list_frame[0].size()};
            std::vector<float> emb_input;
            for (const auto &list_frame_: list_frame) {
                emb_input.insert(emb_input.end(), list_frame_.begin(), list_frame_.end());
            }
            Ort::Value onnx_emb = Ort::Value::CreateTensor<float>(
                m_memoryInfo,
                emb_input.data(),
                emb_input.size(),
                emb_shape_,
                3);
            decoder_onnx.insert(decoder_onnx.begin()+2, std::move(onnx_emb));
            // acoustic_embeds_len
            const int64_t emb_length_shape[1] = {1};
            std::vector<int32_t> emb_length;
            emb_length.emplace_back(list_frame.size());
            Ort::Value onnx_emb_len = Ort::Value::CreateTensor<int32_t>(
                m_memoryInfo, emb_length.data(), emb_length.size(), emb_length_shape, 1);
            decoder_onnx.insert(decoder_onnx.begin()+3, std::move(onnx_emb_len));
            auto decoder_tensor = decoder_session_->Run(Ort::RunOptions{nullptr}, de_szInputNames_.data(), decoder_onnx.data(), decoder_onnx.size(), de_szOutputNames_.data(), de_szOutputNames_.size());
            // fsmn cache
            try{
                decoder_onnx.clear();
            }catch (std::exception const &e)
            {
                LOG(ERROR)<<e.what();
                return result;
            }
            for(int l=0;l<fsmn_layers;l++){
                decoder_onnx.emplace_back(std::move(decoder_tensor[2+l]));
            }
            std::vector<int64_t> decoder_shape = decoder_tensor[0].GetTensorTypeAndShapeInfo().GetShape();
            float* float_data = decoder_tensor[0].GetTensorMutableData<float>();
            result = para_handle_->GreedySearch(float_data, list_frame.size(), decoder_shape[2]);
        }
    }catch (std::exception const &e)
    {
        LOG(ERROR)<<e.what();
        return result;
    }
    return result;
}
string ParaformerOnline::Forward(float* din, int len, bool input_finished)
{
    std::vector<std::vector<float>> wav_feats;
    std::vector<float> waves(din, din+len);
    string result="";
    try{
        if(len <16*60 && input_finished && !is_first_chunk){
            is_last_chunk = true;
            wav_feats = feats_cache_;
            result = ForwardChunk(wav_feats, is_last_chunk);
            // reset
            ResetCache();
            Reset();
            return result;
        }
        if(is_first_chunk){
            is_first_chunk = false;
        }
        ExtractFeats(MODEL_SAMPLE_RATE, wav_feats, waves, input_finished);
        if(wav_feats.size() == 0){
            return result;
        }
        for (auto& row : wav_feats) {
            for (auto& val : row) {
                val *= sqrt_factor;
            }
        }
        GetPosEmb(wav_feats, wav_feats.size(), wav_feats[0].size());
        if(input_finished){
            if(wav_feats.size()+chunk_size[2] <= chunk_size[1]){
                is_last_chunk = true;
                AddOverlapChunk(wav_feats, input_finished);
            }else{
                // first chunk
                std::vector<std::vector<float>> first_chunk;
                first_chunk.insert(first_chunk.begin(), wav_feats.begin(), wav_feats.end());
                AddOverlapChunk(first_chunk, input_finished);
                string str_first_chunk = ForwardChunk(first_chunk, is_last_chunk);
                // last chunk
                is_last_chunk = true;
                std::vector<std::vector<float>> last_chunk;
                last_chunk.insert(last_chunk.begin(), wav_feats.end()-(wav_feats.size()+chunk_size[2]-chunk_size[1]), wav_feats.end());
                AddOverlapChunk(last_chunk, input_finished);
                string str_last_chunk = ForwardChunk(last_chunk, is_last_chunk);
                result = str_first_chunk+str_last_chunk;
                // reset
                ResetCache();
                Reset();
                return result;
            }
        }else{
            AddOverlapChunk(wav_feats, input_finished);
        }
        result = ForwardChunk(wav_feats, is_last_chunk);
        if(input_finished){
            // reset
            ResetCache();
            Reset();
        }
    }catch (std::exception const &e)
    {
        LOG(ERROR)<<e.what();
        return result;
    }
    return result;
}
// Nothing to release explicitly: the destructor body performs no work.
ParaformerOnline::~ParaformerOnline() = default;
string ParaformerOnline::Rescoring()
{
    LOG(ERROR)<<"Not Imp!!!!!!";
    return "";
}
} // namespace funasr