From 1819303f5e8cfc03f4c0ec2495571a54a186d34b Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期二, 29 十月 2024 11:40:18 +0800
Subject: [PATCH] support SenseVoiceSmall in 2pass mode
---
runtime/onnxruntime/src/paraformer-online.h | 23 +++++++++++++++++------
1 files changed, 17 insertions(+), 6 deletions(-)
diff --git a/runtime/onnxruntime/src/paraformer-online.h b/runtime/onnxruntime/src/paraformer-online.h
index 8c9bb88..8ab473d 100644
--- a/runtime/onnxruntime/src/paraformer-online.h
+++ b/runtime/onnxruntime/src/paraformer-online.h
@@ -38,7 +38,18 @@
vector<const char*> &de_szInputNames,
vector<const char*> &de_szOutputNames,
vector<float> &means_list,
- vector<float> &vars_list);
+ vector<float> &vars_list,
+ int frame_length_,
+ int frame_shift_,
+ int n_mels_,
+ int lfr_m_,
+ int lfr_n_,
+ int encoder_size_,
+ int fsmn_layers_,
+ int fsmn_lorder_,
+ int fsmn_dims_,
+ float cif_threshold_,
+ float tail_alphas_);
void StartUtterance()
{
@@ -48,8 +59,8 @@
{
}
- Paraformer* para_handle_ = nullptr;
- // from para_handle_
+ Model* offline_handle_ = nullptr;
+ // from offline_handle_
knf::FbankOptions fbank_opts_;
std::shared_ptr<Ort::Session> encoder_session_ = nullptr;
std::shared_ptr<Ort::Session> decoder_session_ = nullptr;
@@ -60,7 +71,7 @@
vector<const char*> de_szOutputNames_;
vector<float> means_list_;
vector<float> vars_list_;
- // configs from para_handle_
+ // configs from offline_handle_
int frame_length = 25;
int frame_shift = 10;
int n_mels = 80;
@@ -100,7 +111,7 @@
double sqrt_factor;
public:
- ParaformerOnline(Paraformer* para_handle, std::vector<int> chunk_size);
+ ParaformerOnline(Model* offline_handle, std::vector<int> chunk_size, std::string model_type=MODEL_PARA);
~ParaformerOnline();
void Reset();
void ResetCache();
@@ -112,7 +123,7 @@
string Forward(float* din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr);
string Rescoring();
- int GetAsrSampleRate() { return para_handle_->asr_sample_rate; };
+ int GetAsrSampleRate() { return offline_handle_->GetAsrSampleRate(); };
// 2pass
std::string online_res;
--
Gitblit v1.9.1