From 1819303f5e8cfc03f4c0ec2495571a54a186d34b Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: Tue, 29 Oct 2024 11:40:18 +0800
Subject: [PATCH] support SenseVoiceSmall in 2pass mode

---
 runtime/onnxruntime/src/sensevoice-small.h |   12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/runtime/onnxruntime/src/sensevoice-small.h b/runtime/onnxruntime/src/sensevoice-small.h
index f987f38..75cbc92 100644
--- a/runtime/onnxruntime/src/sensevoice-small.h
+++ b/runtime/onnxruntime/src/sensevoice-small.h
@@ -12,12 +12,14 @@
     class SenseVoiceSmall : public Model {
     private:
         Vocab* vocab = nullptr;
+        Vocab* online_vocab = nullptr;  // second vocabulary alongside `vocab`; presumably built from the 2pass online_token_file -- confirm in the .cpp
         Vocab* lm_vocab = nullptr;
         SegDict* seg_dict = nullptr;
         PhoneSet* phone_set_ = nullptr;
         const float scale = 1.0;
 
         void LoadConfigFromYaml(const char* filename);
+        void LoadOnlineConfigFromYaml(const char* filename);  // counterpart of LoadConfigFromYaml; presumably parses the online-model YAML config -- confirm in the .cpp
         void LoadCmvn(const char *filename);
         void LfrCmvn(std::vector<std::vector<float>> &asr_feats);
 
@@ -34,9 +36,10 @@
         ~SenseVoiceSmall();
         void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num);
         // online
-        // void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num);
+        void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num);  // online overload: two model paths (en_model/de_model, presumably encoder/decoder -- confirm) instead of a single am_model
         // 2pass
-        // void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num);
+        void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, 
+            const std::string &token_file, const std::string &online_token_file, int thread_num);  // 2pass overload: am_model plus en_model/de_model, and two token files (token_file, online_token_file; the latter presumably for the online pass -- confirm)
         // void InitHwCompiler(const std::string &hw_model, int thread_num);
         // void InitSegDict(const std::string &seg_dict_model);
         std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords);
@@ -44,7 +47,8 @@
         void FbankKaldi(float sample_rate, const float* waves, int len, std::vector<std::vector<float>> &asr_feats);
         std::vector<std::string> Forward(float** din, int* len, bool input_finished=true, std::string svs_lang="auto", bool svs_itn=true, int batch_in=1);
         string CTCSearch( float * in, std::vector<int32_t> paraformer_length, std::vector<int64_t> outputShape);
-
+        string GreedySearch( float* in, int n_len, int64_t token_nums,
+                             bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});  // is_stamp defaults to false; us_alphas/us_cif_peak are taken by value and each defaults to a one-element vector holding 0
         string Rescoring();
         string GetLang(){return language;};
         int GetAsrSampleRate() { return asr_sample_rate; };
@@ -100,6 +104,8 @@
         int asr_sample_rate = MODEL_SAMPLE_RATE;
         int batch_size_ = 1;
         int blank_id = 0;
+        float cif_threshold = 1.0;  // NOTE(review): presumably the CIF firing threshold used by the online decoder -- confirm against the .cpp
+        float tail_alphas = 0.45;  // NOTE(review): presumably the alpha value appended at the tail of the final chunk -- confirm against the .cpp
         //dict
         std::map<std::string, int> lid_map = {
             {"auto", 0},

--
Gitblit v1.9.1