游雁
2024-10-29 7edad6fba36a7527c1857a38b77a0277e8fde582
runtime/onnxruntime/src/paraformer-online.cpp
@@ -9,18 +9,55 @@
namespace funasr {
ParaformerOnline::ParaformerOnline(Paraformer* para_handle, std::vector<int> chunk_size)
:para_handle_(std::move(para_handle)),chunk_size(chunk_size),session_options_{}{
    InitOnline(
        para_handle_->fbank_opts_,
        para_handle_->encoder_session_,
        para_handle_->decoder_session_,
        para_handle_->en_szInputNames_,
        para_handle_->en_szOutputNames_,
        para_handle_->de_szInputNames_,
        para_handle_->de_szOutputNames_,
        para_handle_->means_list_,
        para_handle_->vars_list_);
ParaformerOnline::ParaformerOnline(Model* offline_handle, std::vector<int> chunk_size, std::string model_type)
:offline_handle_(std::move(offline_handle)),chunk_size(chunk_size),session_options_{}{
    if(model_type == MODEL_PARA){
        Paraformer* para_handle = dynamic_cast<Paraformer*>(offline_handle_);
        InitOnline(
        para_handle->fbank_opts_,
        para_handle->encoder_session_,
        para_handle->decoder_session_,
        para_handle->en_szInputNames_,
        para_handle->en_szOutputNames_,
        para_handle->de_szInputNames_,
        para_handle->de_szOutputNames_,
        para_handle->means_list_,
        para_handle->vars_list_,
        para_handle->frame_length,
        para_handle->frame_shift,
        para_handle->n_mels,
        para_handle->lfr_m,
        para_handle->lfr_n,
        para_handle->encoder_size,
        para_handle->fsmn_layers,
        para_handle->fsmn_lorder,
        para_handle->fsmn_dims,
        para_handle->cif_threshold,
        para_handle->tail_alphas);
    }else if(model_type == MODEL_SVS){
        SenseVoiceSmall* svs_handle = dynamic_cast<SenseVoiceSmall*>(offline_handle_);
        InitOnline(
        svs_handle->fbank_opts_,
        svs_handle->encoder_session_,
        svs_handle->decoder_session_,
        svs_handle->en_szInputNames_,
        svs_handle->en_szOutputNames_,
        svs_handle->de_szInputNames_,
        svs_handle->de_szOutputNames_,
        svs_handle->means_list_,
        svs_handle->vars_list_,
        svs_handle->frame_length,
        svs_handle->frame_shift,
        svs_handle->n_mels,
        svs_handle->lfr_m,
        svs_handle->lfr_n,
        svs_handle->encoder_size,
        svs_handle->fsmn_layers,
        svs_handle->fsmn_lorder,
        svs_handle->fsmn_dims,
        svs_handle->cif_threshold,
        svs_handle->tail_alphas);
    }
    InitCache();
}
@@ -33,7 +70,18 @@
        vector<const char*> &de_szInputNames,
        vector<const char*> &de_szOutputNames,
        vector<float> &means_list,
        vector<float> &vars_list){
        vector<float> &vars_list,
        int frame_length_,
        int frame_shift_,
        int n_mels_,
        int lfr_m_,
        int lfr_n_,
        int encoder_size_,
        int fsmn_layers_,
        int fsmn_lorder_,
        int fsmn_dims_,
        float cif_threshold_,
        float tail_alphas_){
    fbank_opts_ = fbank_opts;
    encoder_session_ = encoder_session;
    decoder_session_ = decoder_session;
@@ -44,27 +92,27 @@
    means_list_ = means_list;
    vars_list_ = vars_list;
    frame_length = para_handle_->frame_length;
    frame_shift = para_handle_->frame_shift;
    n_mels = para_handle_->n_mels;
    lfr_m = para_handle_->lfr_m;
    lfr_n = para_handle_->lfr_n;
    encoder_size = para_handle_->encoder_size;
    fsmn_layers = para_handle_->fsmn_layers;
    fsmn_lorder = para_handle_->fsmn_lorder;
    fsmn_dims = para_handle_->fsmn_dims;
    cif_threshold = para_handle_->cif_threshold;
    tail_alphas = para_handle_->tail_alphas;
    frame_length = frame_length_;
    frame_shift = frame_shift_;
    n_mels = n_mels_;
    lfr_m = lfr_m_;
    lfr_n = lfr_n_;
    encoder_size = encoder_size_;
    fsmn_layers = fsmn_layers_;
    fsmn_lorder = fsmn_lorder_;
    fsmn_dims = fsmn_dims_;
    cif_threshold = cif_threshold_;
    tail_alphas = tail_alphas_;
    // other vars
    sqrt_factor = std::sqrt(encoder_size);
    for(int i=0; i<fsmn_lorder*fsmn_dims; i++){
        fsmn_init_cache_.emplace_back(0);
    }
    chunk_len = chunk_size[1]*frame_shift*lfr_n*para_handle_->asr_sample_rate/1000;
    chunk_len = chunk_size[1]*frame_shift*lfr_n*offline_handle_->GetAsrSampleRate()/1000;
    frame_sample_length_ = para_handle_->asr_sample_rate / 1000 * frame_length;
    frame_shift_sample_length_ = para_handle_->asr_sample_rate / 1000 * frame_shift;
    frame_sample_length_ = offline_handle_->GetAsrSampleRate() / 1000 * frame_length;
    frame_shift_sample_length_ = offline_handle_->GetAsrSampleRate() / 1000 * frame_shift;
}
@@ -464,7 +512,7 @@
            std::vector<int64_t> decoder_shape = decoder_tensor[0].GetTensorTypeAndShapeInfo().GetShape();
            float* float_data = decoder_tensor[0].GetTensorMutableData<float>();
            result = para_handle_->GreedySearch(float_data, list_frame.size(), decoder_shape[2]);
            result = offline_handle_->GreedySearch(float_data, list_frame.size(), decoder_shape[2]);
        }
    }catch (std::exception const &e)
    {
@@ -493,7 +541,7 @@
        if(is_first_chunk){
            is_first_chunk = false;
        }
        ExtractFeats(para_handle_->asr_sample_rate, wav_feats, waves, input_finished);
        ExtractFeats(offline_handle_->GetAsrSampleRate(), wav_feats, waves, input_finished);
        if(wav_feats.size() == 0){
            return result;
        }