From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: 星期五, 13 三月 2026 17:41:41 +0800
Subject: [PATCH] 提交

---
 runtime/onnxruntime/src/sensevoice-small.cpp |  188 ++++++++++++++++++++++++++++++++++++++++++-----
 1 files changed, 168 insertions(+), 20 deletions(-)

diff --git a/runtime/onnxruntime/src/sensevoice-small.cpp b/runtime/onnxruntime/src/sensevoice-small.cpp
index 10eb907..5cb1042 100644
--- a/runtime/onnxruntime/src/sensevoice-small.cpp
+++ b/runtime/onnxruntime/src/sensevoice-small.cpp
@@ -42,28 +42,149 @@
         exit(-1);
     }
 
-    string strName;
-    GetInputName(m_session_.get(), strName);
-    m_strInputNames.push_back(strName.c_str());
-    GetInputName(m_session_.get(), strName,1);
-    m_strInputNames.push_back(strName);
-    GetInputName(m_session_.get(), strName,2);
-    m_strInputNames.push_back(strName);
-    GetInputName(m_session_.get(), strName,3);
-    m_strInputNames.push_back(strName);
-
-    size_t numOutputNodes = m_session_->GetOutputCount();
-    for(int index=0; index<numOutputNodes; index++){
-        GetOutputName(m_session_.get(), strName, index);
-        m_strOutputNames.push_back(strName);
-    }
-
-    for (auto& item : m_strInputNames)
-        m_szInputNames.push_back(item.c_str());
-    for (auto& item : m_strOutputNames)
-        m_szOutputNames.push_back(item.c_str());
+    GetInputNames(m_session_.get(), m_strInputNames, m_szInputNames);
+    GetOutputNames(m_session_.get(), m_strOutputNames, m_szOutputNames);
     vocab = new Vocab(token_file.c_str());
     LoadCmvn(am_cmvn.c_str());
+}
+
+// online
+void SenseVoiceSmall::InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num){
+    
+    LoadOnlineConfigFromYaml(am_config.c_str());
+    // knf options
+    fbank_opts_.frame_opts.dither = 0;
+    fbank_opts_.mel_opts.num_bins = n_mels;
+    fbank_opts_.frame_opts.samp_freq = asr_sample_rate;
+    fbank_opts_.frame_opts.window_type = window_type;
+    fbank_opts_.frame_opts.frame_shift_ms = frame_shift;
+    fbank_opts_.frame_opts.frame_length_ms = frame_length;
+    fbank_opts_.energy_floor = 0;
+    fbank_opts_.mel_opts.debug_mel = false;
+
+    // session_options_.SetInterOpNumThreads(1);
+    session_options_.SetIntraOpNumThreads(thread_num);
+    session_options_.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
+    // DisableCpuMemArena can improve performance
+    session_options_.DisableCpuMemArena();
+
+    try {
+        encoder_session_ = std::make_unique<Ort::Session>(env_, ORTSTRING(en_model).c_str(), session_options_);
+        LOG(INFO) << "Successfully load model from " << en_model;
+    } catch (std::exception const &e) {
+        LOG(ERROR) << "Error when load am encoder model: " << e.what();
+        exit(-1);
+    }
+
+    try {
+        decoder_session_ = std::make_unique<Ort::Session>(env_, ORTSTRING(de_model).c_str(), session_options_);
+        LOG(INFO) << "Successfully load model from " << de_model;
+    } catch (std::exception const &e) {
+        LOG(ERROR) << "Error when load am decoder model: " << e.what();
+        exit(-1);
+    }
+
+    // encoder
+    string strName;
+    GetInputName(encoder_session_.get(), strName);
+    en_strInputNames.push_back(strName.c_str());
+    GetInputName(encoder_session_.get(), strName,1);
+    en_strInputNames.push_back(strName);
+    
+    GetOutputName(encoder_session_.get(), strName);
+    en_strOutputNames.push_back(strName);
+    GetOutputName(encoder_session_.get(), strName,1);
+    en_strOutputNames.push_back(strName);
+    GetOutputName(encoder_session_.get(), strName,2);
+    en_strOutputNames.push_back(strName);
+
+    for (auto& item : en_strInputNames)
+        en_szInputNames_.push_back(item.c_str());
+    for (auto& item : en_strOutputNames)
+        en_szOutputNames_.push_back(item.c_str());
+
+    // decoder
+    int de_input_len = 4 + fsmn_layers;
+    int de_out_len = 2 + fsmn_layers;
+    for(int i=0;i<de_input_len; i++){
+        GetInputName(decoder_session_.get(), strName, i);
+        de_strInputNames.push_back(strName.c_str());
+    }
+
+    for(int i=0;i<de_out_len; i++){
+        GetOutputName(decoder_session_.get(), strName,i);
+        de_strOutputNames.push_back(strName);
+    }
+
+    for (auto& item : de_strInputNames)
+        de_szInputNames_.push_back(item.c_str());
+    for (auto& item : de_strOutputNames)
+        de_szOutputNames_.push_back(item.c_str());
+
+    online_vocab = new Vocab(token_file.c_str());
+    phone_set_ = new PhoneSet(token_file.c_str());
+    LoadCmvn(am_cmvn.c_str());
+}
+
+// 2pass
+void SenseVoiceSmall::InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, 
+    const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, const std::string &online_token_file, int thread_num){
+    // online
+    InitAsr(en_model, de_model, am_cmvn, am_config, online_token_file, thread_num);
+
+    // offline
+    try {
+        m_session_ = std::make_unique<Ort::Session>(env_, ORTSTRING(am_model).c_str(), session_options_);
+        LOG(INFO) << "Successfully load model from " << am_model;
+    } catch (std::exception const &e) {
+        LOG(ERROR) << "Error when load am onnx model: " << e.what();
+        exit(-1);
+    }
+
+    GetInputNames(m_session_.get(), m_strInputNames, m_szInputNames);
+    GetOutputNames(m_session_.get(), m_strOutputNames, m_szOutputNames);
+    vocab = new Vocab(token_file.c_str());
+}
+
+void SenseVoiceSmall::LoadOnlineConfigFromYaml(const char* filename){
+
+    YAML::Node config;
+    try{
+        config = YAML::LoadFile(filename);
+    }catch(exception const &e){
+        LOG(ERROR) << "Error loading file, yaml file error or not exist.";
+        exit(-1);
+    }
+
+    try{
+        YAML::Node frontend_conf = config["frontend_conf"];
+        YAML::Node encoder_conf = config["encoder_conf"];
+        YAML::Node decoder_conf = config["decoder_conf"];
+        YAML::Node predictor_conf = config["predictor_conf"];
+
+        this->window_type = frontend_conf["window"].as<string>();
+        this->n_mels = frontend_conf["n_mels"].as<int>();
+        this->frame_length = frontend_conf["frame_length"].as<int>();
+        this->frame_shift = frontend_conf["frame_shift"].as<int>();
+        this->lfr_m = frontend_conf["lfr_m"].as<int>();
+        this->lfr_n = frontend_conf["lfr_n"].as<int>();
+
+        this->encoder_size = encoder_conf["output_size"].as<int>();
+        this->fsmn_dims = encoder_conf["output_size"].as<int>();
+
+        this->fsmn_layers = decoder_conf["num_blocks"].as<int>();
+        this->fsmn_lorder = decoder_conf["kernel_size"].as<int>()-1;
+
+        this->cif_threshold = predictor_conf["threshold"].as<double>();
+        this->tail_alphas = predictor_conf["tail_threshold"].as<double>();
+
+        this->asr_sample_rate = frontend_conf["fs"].as<int>();
+
+
+    }catch(exception const &e){
+        LOG(ERROR) << "Error when load argument from vad config YAML.";
+        exit(-1);
+    }
 }
 
 void SenseVoiceSmall::LoadConfigFromYaml(const char* filename){
@@ -101,6 +222,9 @@
 {
     if(vocab){
         delete vocab;
+    }
+    if(online_vocab){
+        delete online_vocab;
     }
     if(lm_vocab){
         delete lm_vocab;
@@ -230,6 +354,30 @@
     return str_lang + str_emo + str_event + " " + text;
 }
 
+string SenseVoiceSmall::GreedySearch(float * in, int n_len,  int64_t token_nums, bool is_stamp, std::vector<float> us_alphas, std::vector<float> us_cif_peak)
+{
+    vector<int> hyps;
+    int Tmax = n_len;
+    for (int i = 0; i < Tmax; i++) {
+        int max_idx;
+        float max_val;
+        FindMax(in + i * token_nums, token_nums, max_val, max_idx);
+        hyps.push_back(max_idx);
+    }
+    if(!is_stamp){
+        return online_vocab->Vector2StringV2(hyps, language);
+    }else{
+        std::vector<string> char_list;
+        std::vector<std::vector<float>> timestamp_list;
+        std::string res_str;
+        online_vocab->Vector2String(hyps, char_list);
+        std::vector<string> raw_char(char_list);
+        TimestampOnnx(us_alphas, us_cif_peak, char_list, res_str, timestamp_list);
+
+        return PostProcess(raw_char, timestamp_list);
+    }
+}
+
 void SenseVoiceSmall::LfrCmvn(std::vector<std::vector<float>> &asr_feats) {
 
     std::vector<std::vector<float>> out_feats;

--
Gitblit v1.9.1