From 1819303f5e8cfc03f4c0ec2495571a54a186d34b Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期二, 29 十月 2024 11:40:18 +0800
Subject: [PATCH] support SenseVoiceSmall in 2pass mode
---
runtime/onnxruntime/src/sensevoice-small.h | 12 +++++++++---
1 files changed, 9 insertions(+), 3 deletions(-)
diff --git a/runtime/onnxruntime/src/sensevoice-small.h b/runtime/onnxruntime/src/sensevoice-small.h
index f987f38..75cbc92 100644
--- a/runtime/onnxruntime/src/sensevoice-small.h
+++ b/runtime/onnxruntime/src/sensevoice-small.h
@@ -12,12 +12,14 @@
class SenseVoiceSmall : public Model {
private:
Vocab* vocab = nullptr;
+ Vocab* online_vocab = nullptr;
Vocab* lm_vocab = nullptr;
SegDict* seg_dict = nullptr;
PhoneSet* phone_set_ = nullptr;
const float scale = 1.0;
void LoadConfigFromYaml(const char* filename);
+ void LoadOnlineConfigFromYaml(const char* filename);
void LoadCmvn(const char *filename);
void LfrCmvn(std::vector<std::vector<float>> &asr_feats);
@@ -34,9 +36,10 @@
~SenseVoiceSmall();
void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num);
// online
- // void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num);
+ void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num);
// 2pass
- // void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num);
+ void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config,
+ const std::string &token_file, const std::string &online_token_file, int thread_num);
// void InitHwCompiler(const std::string &hw_model, int thread_num);
// void InitSegDict(const std::string &seg_dict_model);
std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords);
@@ -44,7 +47,8 @@
void FbankKaldi(float sample_rate, const float* waves, int len, std::vector<std::vector<float>> &asr_feats);
std::vector<std::string> Forward(float** din, int* len, bool input_finished=true, std::string svs_lang="auto", bool svs_itn=true, int batch_in=1);
string CTCSearch( float * in, std::vector<int32_t> paraformer_length, std::vector<int64_t> outputShape);
-
+ string GreedySearch( float* in, int n_len, int64_t token_nums,
+ bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});
string Rescoring();
string GetLang(){return language;};
int GetAsrSampleRate() { return asr_sample_rate; };
@@ -100,6 +104,8 @@
int asr_sample_rate = MODEL_SAMPLE_RATE;
int batch_size_ = 1;
int blank_id = 0;
+ float cif_threshold = 1.0;
+ float tail_alphas = 0.45;
//dict
std::map<std::string, int> lid_map = {
{"auto", 0},
--
Gitblit v1.9.1