From 1819303f5e8cfc03f4c0ec2495571a54a186d34b Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: Tue, 29 Oct 2024 11:40:18 +0800
Subject: [PATCH] support SenseVoiceSmall in 2pass mode
---
runtime/onnxruntime/src/paraformer-online.cpp | 106 ++++++++++++++++++++++++++++++++++++++--------------
1 file changed, 77 insertions(+), 29 deletions(-)
diff --git a/runtime/onnxruntime/src/paraformer-online.cpp b/runtime/onnxruntime/src/paraformer-online.cpp
index 55a4fd1..88951aa 100644
--- a/runtime/onnxruntime/src/paraformer-online.cpp
+++ b/runtime/onnxruntime/src/paraformer-online.cpp
@@ -9,18 +9,65 @@
namespace funasr {
-ParaformerOnline::ParaformerOnline(Paraformer* para_handle, std::vector<int> chunk_size)
-:para_handle_(std::move(para_handle)),chunk_size(chunk_size),session_options_{}{
- InitOnline(
- para_handle_->fbank_opts_,
- para_handle_->encoder_session_,
- para_handle_->decoder_session_,
- para_handle_->en_szInputNames_,
- para_handle_->en_szOutputNames_,
- para_handle_->de_szInputNames_,
- para_handle_->de_szOutputNames_,
- para_handle_->means_list_,
- para_handle_->vars_list_);
+// Builds the online recognizer from an existing offline model handle.
+// model_type selects the concrete Model subclass that offline_handle is cast to
+// (MODEL_PARA -> Paraformer, MODEL_SVS -> SenseVoiceSmall); that object's sessions,
+// tensor names, means/vars lists and numeric settings are forwarded to InitOnline().
+ParaformerOnline::ParaformerOnline(Model* offline_handle, std::vector<int> chunk_size, std::string model_type)
+:offline_handle_(offline_handle),chunk_size(chunk_size),session_options_{}{
+ if(model_type == MODEL_PARA){
+ // Reference cast: a handle of the wrong dynamic type throws std::bad_cast
+ // instead of producing a null pointer that would be dereferenced below.
+ Paraformer& para_handle = dynamic_cast<Paraformer&>(*offline_handle_);
+ InitOnline(
+ para_handle.fbank_opts_,
+ para_handle.encoder_session_,
+ para_handle.decoder_session_,
+ para_handle.en_szInputNames_,
+ para_handle.en_szOutputNames_,
+ para_handle.de_szInputNames_,
+ para_handle.de_szOutputNames_,
+ para_handle.means_list_,
+ para_handle.vars_list_,
+ para_handle.frame_length,
+ para_handle.frame_shift,
+ para_handle.n_mels,
+ para_handle.lfr_m,
+ para_handle.lfr_n,
+ para_handle.encoder_size,
+ para_handle.fsmn_layers,
+ para_handle.fsmn_lorder,
+ para_handle.fsmn_dims,
+ para_handle.cif_threshold,
+ para_handle.tail_alphas);
+ }else if(model_type == MODEL_SVS){
+ // Same reasoning as above: fail with std::bad_cast rather than
+ // dereferencing a null SenseVoiceSmall pointer on a type mismatch.
+ SenseVoiceSmall& svs_handle = dynamic_cast<SenseVoiceSmall&>(*offline_handle_);
+ InitOnline(
+ svs_handle.fbank_opts_,
+ svs_handle.encoder_session_,
+ svs_handle.decoder_session_,
+ svs_handle.en_szInputNames_,
+ svs_handle.en_szOutputNames_,
+ svs_handle.de_szInputNames_,
+ svs_handle.de_szOutputNames_,
+ svs_handle.means_list_,
+ svs_handle.vars_list_,
+ svs_handle.frame_length,
+ svs_handle.frame_shift,
+ svs_handle.n_mels,
+ svs_handle.lfr_m,
+ svs_handle.lfr_n,
+ svs_handle.encoder_size,
+ svs_handle.fsmn_layers,
+ svs_handle.fsmn_lorder,
+ svs_handle.fsmn_dims,
+ svs_handle.cif_threshold,
+ svs_handle.tail_alphas);
+ }
+ // NOTE(review): any other model_type skips InitOnline() and falls through to
+ // InitCache(); confirm callers only ever pass MODEL_PARA or MODEL_SVS.
InitCache();
}
@@ -33,7 +70,18 @@
vector<const char*> &de_szInputNames,
vector<const char*> &de_szOutputNames,
vector<float> &means_list,
- vector<float> &vars_list){
+ vector<float> &vars_list,
+ int frame_length_,
+ int frame_shift_,
+ int n_mels_,
+ int lfr_m_,
+ int lfr_n_,
+ int encoder_size_,
+ int fsmn_layers_,
+ int fsmn_lorder_,
+ int fsmn_dims_,
+ float cif_threshold_,
+ float tail_alphas_){
fbank_opts_ = fbank_opts;
encoder_session_ = encoder_session;
decoder_session_ = decoder_session;
@@ -44,27 +92,27 @@
means_list_ = means_list;
vars_list_ = vars_list;
- frame_length = para_handle_->frame_length;
- frame_shift = para_handle_->frame_shift;
- n_mels = para_handle_->n_mels;
- lfr_m = para_handle_->lfr_m;
- lfr_n = para_handle_->lfr_n;
- encoder_size = para_handle_->encoder_size;
- fsmn_layers = para_handle_->fsmn_layers;
- fsmn_lorder = para_handle_->fsmn_lorder;
- fsmn_dims = para_handle_->fsmn_dims;
- cif_threshold = para_handle_->cif_threshold;
- tail_alphas = para_handle_->tail_alphas;
+ frame_length = frame_length_;
+ frame_shift = frame_shift_;
+ n_mels = n_mels_;
+ lfr_m = lfr_m_;
+ lfr_n = lfr_n_;
+ encoder_size = encoder_size_;
+ fsmn_layers = fsmn_layers_;
+ fsmn_lorder = fsmn_lorder_;
+ fsmn_dims = fsmn_dims_;
+ cif_threshold = cif_threshold_;
+ tail_alphas = tail_alphas_;
// other vars
sqrt_factor = std::sqrt(encoder_size);
for(int i=0; i<fsmn_lorder*fsmn_dims; i++){
fsmn_init_cache_.emplace_back(0);
}
- chunk_len = chunk_size[1]*frame_shift*lfr_n*para_handle_->asr_sample_rate/1000;
+ chunk_len = chunk_size[1]*frame_shift*lfr_n*offline_handle_->GetAsrSampleRate()/1000;
- frame_sample_length_ = para_handle_->asr_sample_rate / 1000 * frame_length;
- frame_shift_sample_length_ = para_handle_->asr_sample_rate / 1000 * frame_shift;
+ frame_sample_length_ = offline_handle_->GetAsrSampleRate() / 1000 * frame_length;
+ frame_shift_sample_length_ = offline_handle_->GetAsrSampleRate() / 1000 * frame_shift;
}
@@ -464,7 +512,7 @@
std::vector<int64_t> decoder_shape = decoder_tensor[0].GetTensorTypeAndShapeInfo().GetShape();
float* float_data = decoder_tensor[0].GetTensorMutableData<float>();
- result = para_handle_->GreedySearch(float_data, list_frame.size(), decoder_shape[2]);
+ result = offline_handle_->GreedySearch(float_data, list_frame.size(), decoder_shape[2]);
}
}catch (std::exception const &e)
{
@@ -493,7 +541,7 @@
if(is_first_chunk){
is_first_chunk = false;
}
- ExtractFeats(para_handle_->asr_sample_rate, wav_feats, waves, input_finished);
+ ExtractFeats(offline_handle_->GetAsrSampleRate(), wav_feats, waves, input_finished);
if(wav_feats.size() == 0){
return result;
}
--
Gitblit v1.9.1