From 1819303f5e8cfc03f4c0ec2495571a54a186d34b Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: Tue, 29 Oct 2024 11:40:18 +0800
Subject: [PATCH] support SenseVoiceSmall in 2pass mode

---
 runtime/onnxruntime/src/paraformer-online.cpp |  106 ++++++++++++++++++++++++++++++++++++++--------------
 1 files changed, 77 insertions(+), 29 deletions(-)

diff --git a/runtime/onnxruntime/src/paraformer-online.cpp b/runtime/onnxruntime/src/paraformer-online.cpp
index 55a4fd1..88951aa 100644
--- a/runtime/onnxruntime/src/paraformer-online.cpp
+++ b/runtime/onnxruntime/src/paraformer-online.cpp
@@ -9,18 +9,55 @@
 
 namespace funasr {
 
-ParaformerOnline::ParaformerOnline(Paraformer* para_handle, std::vector<int> chunk_size)
-:para_handle_(std::move(para_handle)),chunk_size(chunk_size),session_options_{}{
-    InitOnline(
-        para_handle_->fbank_opts_,
-        para_handle_->encoder_session_,
-        para_handle_->decoder_session_,
-        para_handle_->en_szInputNames_,
-        para_handle_->en_szOutputNames_,
-        para_handle_->de_szInputNames_,
-        para_handle_->de_szOutputNames_,
-        para_handle_->means_list_,
-        para_handle_->vars_list_);
+ParaformerOnline::ParaformerOnline(Model* offline_handle, std::vector<int> chunk_size, std::string model_type) // Builds the online recognizer from an offline model handle; model_type selects which concrete model class the parameters are read from.
+:offline_handle_(std::move(offline_handle)),chunk_size(chunk_size),session_options_{}{ // std::move on a raw pointer is a plain copy of the pointer value.
+    if(model_type == MODEL_PARA){ // Paraformer branch.
+        Paraformer* para_handle = dynamic_cast<Paraformer*>(offline_handle_); // NOTE(review): the cast result is not null-checked; if the handle is not a Paraformer, the dereferences below are invalid — confirm callers guarantee the type.
+        InitOnline( // Forward the offline model's fbank options, sessions, tensor names and numeric parameters.
+        para_handle->fbank_opts_,
+        para_handle->encoder_session_,
+        para_handle->decoder_session_,
+        para_handle->en_szInputNames_,
+        para_handle->en_szOutputNames_,
+        para_handle->de_szInputNames_,
+        para_handle->de_szOutputNames_,
+        para_handle->means_list_,
+        para_handle->vars_list_,
+        para_handle->frame_length,
+        para_handle->frame_shift,
+        para_handle->n_mels,
+        para_handle->lfr_m,
+        para_handle->lfr_n,
+        para_handle->encoder_size,
+        para_handle->fsmn_layers,
+        para_handle->fsmn_lorder,
+        para_handle->fsmn_dims,
+        para_handle->cif_threshold,
+        para_handle->tail_alphas);
+    }else if(model_type == MODEL_SVS){ // SenseVoiceSmall branch: forwards the same set of fields as the Paraformer branch.
+        SenseVoiceSmall* svs_handle = dynamic_cast<SenseVoiceSmall*>(offline_handle_); // NOTE(review): same unchecked dynamic_cast as above — confirm the handle is always a SenseVoiceSmall here.
+        InitOnline( // Argument order must match the InitOnline parameter list.
+        svs_handle->fbank_opts_,
+        svs_handle->encoder_session_,
+        svs_handle->decoder_session_,
+        svs_handle->en_szInputNames_,
+        svs_handle->en_szOutputNames_,
+        svs_handle->de_szInputNames_,
+        svs_handle->de_szOutputNames_,
+        svs_handle->means_list_,
+        svs_handle->vars_list_,
+        svs_handle->frame_length,
+        svs_handle->frame_shift,
+        svs_handle->n_mels,
+        svs_handle->lfr_m,
+        svs_handle->lfr_n,
+        svs_handle->encoder_size,
+        svs_handle->fsmn_layers,
+        svs_handle->fsmn_lorder,
+        svs_handle->fsmn_dims,
+        svs_handle->cif_threshold,
+        svs_handle->tail_alphas);
+    } // NOTE(review): there is no else branch — for any other model_type InitOnline is skipped and InitCache() below still runs.
     InitCache();
 }
 
@@ -33,7 +70,18 @@
         vector<const char*> &de_szInputNames,
         vector<const char*> &de_szOutputNames,
         vector<float> &means_list,
-        vector<float> &vars_list){
+        vector<float> &vars_list,
+        int frame_length_,
+        int frame_shift_,
+        int n_mels_,
+        int lfr_m_,
+        int lfr_n_,
+        int encoder_size_,
+        int fsmn_layers_,
+        int fsmn_lorder_,
+        int fsmn_dims_,
+        float cif_threshold_,
+        float tail_alphas_){
     fbank_opts_ = fbank_opts;
     encoder_session_ = encoder_session;
     decoder_session_ = decoder_session;
@@ -44,27 +92,27 @@
     means_list_ = means_list;
     vars_list_ = vars_list;
 
-    frame_length = para_handle_->frame_length;
-    frame_shift = para_handle_->frame_shift;
-    n_mels = para_handle_->n_mels;
-    lfr_m = para_handle_->lfr_m;
-    lfr_n = para_handle_->lfr_n;
-    encoder_size = para_handle_->encoder_size;
-    fsmn_layers = para_handle_->fsmn_layers;
-    fsmn_lorder = para_handle_->fsmn_lorder;
-    fsmn_dims = para_handle_->fsmn_dims;
-    cif_threshold = para_handle_->cif_threshold;
-    tail_alphas = para_handle_->tail_alphas;
+    frame_length = frame_length_;
+    frame_shift = frame_shift_;
+    n_mels = n_mels_;
+    lfr_m = lfr_m_;
+    lfr_n = lfr_n_;
+    encoder_size = encoder_size_;
+    fsmn_layers = fsmn_layers_;
+    fsmn_lorder = fsmn_lorder_;
+    fsmn_dims = fsmn_dims_;
+    cif_threshold = cif_threshold_;
+    tail_alphas = tail_alphas_;
 
     // other vars
     sqrt_factor = std::sqrt(encoder_size);
     for(int i=0; i<fsmn_lorder*fsmn_dims; i++){
         fsmn_init_cache_.emplace_back(0);
     }
-    chunk_len = chunk_size[1]*frame_shift*lfr_n*para_handle_->asr_sample_rate/1000;
+    chunk_len = chunk_size[1]*frame_shift*lfr_n*offline_handle_->GetAsrSampleRate()/1000;
 
-    frame_sample_length_ = para_handle_->asr_sample_rate / 1000 * frame_length;
-    frame_shift_sample_length_ = para_handle_->asr_sample_rate / 1000 * frame_shift;
+    frame_sample_length_ = offline_handle_->GetAsrSampleRate() / 1000 * frame_length;
+    frame_shift_sample_length_ = offline_handle_->GetAsrSampleRate() / 1000 * frame_shift;
 
 }
 
@@ -464,7 +512,7 @@
 
             std::vector<int64_t> decoder_shape = decoder_tensor[0].GetTensorTypeAndShapeInfo().GetShape();
             float* float_data = decoder_tensor[0].GetTensorMutableData<float>();
-            result = para_handle_->GreedySearch(float_data, list_frame.size(), decoder_shape[2]);
+            result = offline_handle_->GreedySearch(float_data, list_frame.size(), decoder_shape[2]);
         }
     }catch (std::exception const &e)
     {
@@ -493,7 +541,7 @@
         if(is_first_chunk){
             is_first_chunk = false;
         }
-        ExtractFeats(para_handle_->asr_sample_rate, wav_feats, waves, input_finished);
+        ExtractFeats(offline_handle_->GetAsrSampleRate(), wav_feats, waves, input_finished);
         if(wav_feats.size() == 0){
             return result;
         }

--
Gitblit v1.9.1