雾聪
2024-10-29 1819303f5e8cfc03f4c0ec2495571a54a186d34b
support SenseVoiceSmall in 2pass mode
16个文件已修改
478 ■■■■ 已修改文件
runtime/onnxruntime/include/funasrruntime.h 3 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/include/model.h 4 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/include/tpass-stream.h 2 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/src/funasrruntime.cpp 49 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/src/paraformer-online.cpp 106 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/src/paraformer-online.h 23 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/src/paraformer.cpp 27 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/src/paraformer.h 3 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/src/sensevoice-small.cpp 166 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/src/sensevoice-small.h 12 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/src/tpass-online-stream.cpp 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/src/tpass-stream.cpp 12 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/websocket/bin/funasr-wss-client-2pass.cpp 33 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/websocket/bin/funasr-wss-server-2pass.cpp 6 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/websocket/bin/websocket-server-2pass.cpp 26 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/websocket/bin/websocket-server-2pass.h 4 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/include/funasrruntime.h
@@ -120,7 +120,8 @@
_FUNASRAPI FUNASR_RESULT    FunTpassInferBuffer(FUNASR_HANDLE handle, FUNASR_HANDLE online_handle, const char* sz_buf, 
                                                int n_len, std::vector<std::vector<std::string>> &punc_cache, bool input_finished=true, 
                                                int sampling_rate=16000, std::string wav_format="pcm", ASR_TYPE mode=ASR_TWO_PASS, 
                                                const std::vector<std::vector<float>> &hw_emb={{0.0}}, bool itn=true, FUNASR_DEC_HANDLE dec_handle=nullptr);
                                                const std::vector<std::vector<float>> &hw_emb={{0.0}}, bool itn=true, FUNASR_DEC_HANDLE dec_handle=nullptr,
                                                std::string svs_lang="auto", bool svs_itn=true);
_FUNASRAPI void                FunTpassUninit(FUNASR_HANDLE handle);
_FUNASRAPI void                FunTpassOnlineUninit(FUNASR_HANDLE handle);
runtime/onnxruntime/include/model.h
@@ -16,9 +16,11 @@
    virtual void StartUtterance() = 0;
    virtual void EndUtterance() = 0;
    virtual void Reset() = 0;
    virtual string GreedySearch(float* in, int n_len, int64_t token_nums, bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0}){return "";};
    virtual void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num){};
    virtual void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num){};
    virtual void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num){};
    virtual void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn,
      const std::string &am_config, const std::string &token_file, const std::string &online_token_file, int thread_num){};
    virtual void InitLm(const std::string &lm_file, const std::string &lm_config, const std::string &lex_file){};
    virtual void InitFstDecoder(){};
    virtual std::string Forward(float *din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr){return "";};
runtime/onnxruntime/include/tpass-stream.h
@@ -26,11 +26,13 @@
    bool UseVad(){return use_vad;};
    bool UsePunc(){return use_punc;}; 
    bool UseITN(){return use_itn;};
    std::string GetModelType(){return model_type;};
    
  private:
    bool use_vad=false;
    bool use_punc=false;
    bool use_itn=false;
    std::string model_type = MODEL_PARA;
};
TpassStream *CreateTpassStream(std::map<std::string, std::string>& model_path, int thread_num=1);
runtime/onnxruntime/src/funasrruntime.cpp
@@ -482,7 +482,8 @@
    _FUNASRAPI FUNASR_RESULT FunTpassInferBuffer(FUNASR_HANDLE handle, FUNASR_HANDLE online_handle, const char* sz_buf, 
                                                 int n_len, std::vector<std::vector<std::string>> &punc_cache, bool input_finished, 
                                                 int sampling_rate, std::string wav_format, ASR_TYPE mode, 
                                                 const std::vector<std::vector<float>> &hw_emb, bool itn, FUNASR_DEC_HANDLE dec_handle)
                                                 const std::vector<std::vector<float>> &hw_emb, bool itn, FUNASR_DEC_HANDLE dec_handle,
                                                 std::string svs_lang, bool svs_itn)
    {
        funasr::TpassStream* tpass_stream = (funasr::TpassStream*)handle;
        funasr::TpassOnlineStream* tpass_online_stream = (funasr::TpassOnlineStream*)online_handle;
@@ -525,7 +526,7 @@
        funasr::AudioFrame* frame = nullptr;
        while(audio->FetchChunck(frame) > 0){
            string msg = ((funasr::ParaformerOnline*)asr_online_handle)->Forward(frame->data, frame->len, frame->is_final);
            string msg = (asr_online_handle)->Forward(frame->data, frame->len, frame->is_final);
            if(mode == ASR_ONLINE){
                ((funasr::ParaformerOnline*)asr_online_handle)->online_res += msg;
                if(frame->is_final){
@@ -567,7 +568,12 @@
            len = new int[1];
            buff[0] = frame->data;
            len[0] = frame->len;
            vector<string> msgs = ((funasr::Paraformer*)asr_handle)->Forward(buff, len, frame->is_final, hw_emb, dec_handle);
            vector<string> msgs;
            if(tpass_stream->GetModelType() == MODEL_SVS){
                msgs = (tpass_stream->asr_handle)->Forward(buff, len, true, svs_lang, svs_itn, 1);
            }else{
                msgs = (tpass_stream->asr_handle)->Forward(buff, len, true, hw_emb, dec_handle, 1);
            }
            string msg = msgs.size()>0?msgs[0]:"";
            std::vector<std::string> msg_vec = funasr::SplitStr(msg, " | ");  // split with timestamp
            if(msg_vec.size()==0){
@@ -589,24 +595,29 @@
                p_result->stamp += cur_stamp + "]";
            }
            string msg_punc = punc_online_handle->AddPunc(msg.c_str(), punc_cache[1]);
            if(input_finished){
                msg_punc += "。";
            }
            p_result->tpass_msg = msg_punc;
#if !defined(__APPLE__)
            if(tpass_stream->UseITN() && itn){
                string msg_itn = tpass_stream->itn_handle->Normalize(msg_punc);
                // TimestampSmooth
                if(!(p_result->stamp).empty()){
                    std::string new_stamp = funasr::TimestampSmooth(p_result->tpass_msg, msg_itn, p_result->stamp);
                    if(!new_stamp.empty()){
                        p_result->stamp = new_stamp;
                    }
            if (tpass_stream->GetModelType() == MODEL_PARA){
                string msg_punc = punc_online_handle->AddPunc(msg.c_str(), punc_cache[1]);
                if(input_finished){
                    msg_punc += "。";
                }
                p_result->tpass_msg = msg_itn;
            }
                p_result->tpass_msg = msg_punc;
#if !defined(__APPLE__)
                if(tpass_stream->UseITN() && itn){
                    string msg_itn = tpass_stream->itn_handle->Normalize(msg_punc);
                    // TimestampSmooth
                    if(!(p_result->stamp).empty()){
                        std::string new_stamp = funasr::TimestampSmooth(p_result->tpass_msg, msg_itn, p_result->stamp);
                        if(!new_stamp.empty()){
                            p_result->stamp = new_stamp;
                        }
                    }
                    p_result->tpass_msg = msg_itn;
                }
#endif
            }else{
                p_result->tpass_msg = msg;
            }
            if (!(p_result->stamp).empty()){
                p_result->stamp_sents = funasr::TimestampSentence(p_result->tpass_msg, p_result->stamp);
            }
runtime/onnxruntime/src/paraformer-online.cpp
@@ -9,18 +9,55 @@
namespace funasr {
ParaformerOnline::ParaformerOnline(Paraformer* para_handle, std::vector<int> chunk_size)
:para_handle_(std::move(para_handle)),chunk_size(chunk_size),session_options_{}{
    InitOnline(
        para_handle_->fbank_opts_,
        para_handle_->encoder_session_,
        para_handle_->decoder_session_,
        para_handle_->en_szInputNames_,
        para_handle_->en_szOutputNames_,
        para_handle_->de_szInputNames_,
        para_handle_->de_szOutputNames_,
        para_handle_->means_list_,
        para_handle_->vars_list_);
ParaformerOnline::ParaformerOnline(Model* offline_handle, std::vector<int> chunk_size, std::string model_type)
:offline_handle_(std::move(offline_handle)),chunk_size(chunk_size),session_options_{}{
    if(model_type == MODEL_PARA){
        Paraformer* para_handle = dynamic_cast<Paraformer*>(offline_handle_);
        InitOnline(
        para_handle->fbank_opts_,
        para_handle->encoder_session_,
        para_handle->decoder_session_,
        para_handle->en_szInputNames_,
        para_handle->en_szOutputNames_,
        para_handle->de_szInputNames_,
        para_handle->de_szOutputNames_,
        para_handle->means_list_,
        para_handle->vars_list_,
        para_handle->frame_length,
        para_handle->frame_shift,
        para_handle->n_mels,
        para_handle->lfr_m,
        para_handle->lfr_n,
        para_handle->encoder_size,
        para_handle->fsmn_layers,
        para_handle->fsmn_lorder,
        para_handle->fsmn_dims,
        para_handle->cif_threshold,
        para_handle->tail_alphas);
    }else if(model_type == MODEL_SVS){
        SenseVoiceSmall* svs_handle = dynamic_cast<SenseVoiceSmall*>(offline_handle_);
        InitOnline(
        svs_handle->fbank_opts_,
        svs_handle->encoder_session_,
        svs_handle->decoder_session_,
        svs_handle->en_szInputNames_,
        svs_handle->en_szOutputNames_,
        svs_handle->de_szInputNames_,
        svs_handle->de_szOutputNames_,
        svs_handle->means_list_,
        svs_handle->vars_list_,
        svs_handle->frame_length,
        svs_handle->frame_shift,
        svs_handle->n_mels,
        svs_handle->lfr_m,
        svs_handle->lfr_n,
        svs_handle->encoder_size,
        svs_handle->fsmn_layers,
        svs_handle->fsmn_lorder,
        svs_handle->fsmn_dims,
        svs_handle->cif_threshold,
        svs_handle->tail_alphas);
    }
    InitCache();
}
@@ -33,7 +70,18 @@
        vector<const char*> &de_szInputNames,
        vector<const char*> &de_szOutputNames,
        vector<float> &means_list,
        vector<float> &vars_list){
        vector<float> &vars_list,
        int frame_length_,
        int frame_shift_,
        int n_mels_,
        int lfr_m_,
        int lfr_n_,
        int encoder_size_,
        int fsmn_layers_,
        int fsmn_lorder_,
        int fsmn_dims_,
        float cif_threshold_,
        float tail_alphas_){
    fbank_opts_ = fbank_opts;
    encoder_session_ = encoder_session;
    decoder_session_ = decoder_session;
@@ -44,27 +92,27 @@
    means_list_ = means_list;
    vars_list_ = vars_list;
    frame_length = para_handle_->frame_length;
    frame_shift = para_handle_->frame_shift;
    n_mels = para_handle_->n_mels;
    lfr_m = para_handle_->lfr_m;
    lfr_n = para_handle_->lfr_n;
    encoder_size = para_handle_->encoder_size;
    fsmn_layers = para_handle_->fsmn_layers;
    fsmn_lorder = para_handle_->fsmn_lorder;
    fsmn_dims = para_handle_->fsmn_dims;
    cif_threshold = para_handle_->cif_threshold;
    tail_alphas = para_handle_->tail_alphas;
    frame_length = frame_length_;
    frame_shift = frame_shift_;
    n_mels = n_mels_;
    lfr_m = lfr_m_;
    lfr_n = lfr_n_;
    encoder_size = encoder_size_;
    fsmn_layers = fsmn_layers_;
    fsmn_lorder = fsmn_lorder_;
    fsmn_dims = fsmn_dims_;
    cif_threshold = cif_threshold_;
    tail_alphas = tail_alphas_;
    // other vars
    sqrt_factor = std::sqrt(encoder_size);
    for(int i=0; i<fsmn_lorder*fsmn_dims; i++){
        fsmn_init_cache_.emplace_back(0);
    }
    chunk_len = chunk_size[1]*frame_shift*lfr_n*para_handle_->asr_sample_rate/1000;
    chunk_len = chunk_size[1]*frame_shift*lfr_n*offline_handle_->GetAsrSampleRate()/1000;
    frame_sample_length_ = para_handle_->asr_sample_rate / 1000 * frame_length;
    frame_shift_sample_length_ = para_handle_->asr_sample_rate / 1000 * frame_shift;
    frame_sample_length_ = offline_handle_->GetAsrSampleRate() / 1000 * frame_length;
    frame_shift_sample_length_ = offline_handle_->GetAsrSampleRate() / 1000 * frame_shift;
}
@@ -464,7 +512,7 @@
            std::vector<int64_t> decoder_shape = decoder_tensor[0].GetTensorTypeAndShapeInfo().GetShape();
            float* float_data = decoder_tensor[0].GetTensorMutableData<float>();
            result = para_handle_->GreedySearch(float_data, list_frame.size(), decoder_shape[2]);
            result = offline_handle_->GreedySearch(float_data, list_frame.size(), decoder_shape[2]);
        }
    }catch (std::exception const &e)
    {
@@ -493,7 +541,7 @@
        if(is_first_chunk){
            is_first_chunk = false;
        }
        ExtractFeats(para_handle_->asr_sample_rate, wav_feats, waves, input_finished);
        ExtractFeats(offline_handle_->GetAsrSampleRate(), wav_feats, waves, input_finished);
        if(wav_feats.size() == 0){
            return result;
        }
runtime/onnxruntime/src/paraformer-online.h
@@ -38,7 +38,18 @@
            vector<const char*> &de_szInputNames,
            vector<const char*> &de_szOutputNames,
            vector<float> &means_list,
            vector<float> &vars_list);
            vector<float> &vars_list,
            int frame_length_,
            int frame_shift_,
            int n_mels_,
            int lfr_m_,
            int lfr_n_,
            int encoder_size_,
            int fsmn_layers_,
            int fsmn_lorder_,
            int fsmn_dims_,
            float cif_threshold_,
            float tail_alphas_);
        void StartUtterance()
        {
@@ -48,8 +59,8 @@
        {
        }
        
        Paraformer* para_handle_ = nullptr;
        // from para_handle_
        Model* offline_handle_ = nullptr;
        // from offline_handle_
        knf::FbankOptions fbank_opts_;
        std::shared_ptr<Ort::Session> encoder_session_ = nullptr;
        std::shared_ptr<Ort::Session> decoder_session_ = nullptr;
@@ -60,7 +71,7 @@
        vector<const char*> de_szOutputNames_;
        vector<float> means_list_;
        vector<float> vars_list_;
        // configs from para_handle_
        // configs from offline_handle_
        int frame_length = 25;
        int frame_shift = 10;
        int n_mels = 80;
@@ -100,7 +111,7 @@
        double sqrt_factor;
    public:
        ParaformerOnline(Paraformer* para_handle, std::vector<int> chunk_size);
        ParaformerOnline(Model* offline_handle, std::vector<int> chunk_size, std::string model_type=MODEL_PARA);
        ~ParaformerOnline();
        void Reset();
        void ResetCache();
@@ -112,7 +123,7 @@
        string Forward(float* din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr);
        string Rescoring();
        int GetAsrSampleRate() { return para_handle_->asr_sample_rate; };
        int GetAsrSampleRate() { return offline_handle_->GetAsrSampleRate(); };
        // 2pass
        std::string online_res;
runtime/onnxruntime/src/paraformer.cpp
@@ -131,9 +131,10 @@
}
// 2pass
void Paraformer::InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num){
void Paraformer::InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model,
    const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, const std::string &online_token_file, int thread_num){
    // online
    InitAsr(en_model, de_model, am_cmvn, am_config, token_file, thread_num);
    InitAsr(en_model, de_model, am_cmvn, am_config, online_token_file, thread_num);
    // offline
    try {
@@ -144,28 +145,6 @@
        exit(-1);
    }
    // string strName;
    // GetInputName(m_session_.get(), strName);
    // m_strInputNames.push_back(strName.c_str());
    // GetInputName(m_session_.get(), strName,1);
    // m_strInputNames.push_back(strName);
    // if (use_hotword) {
    //     GetInputName(m_session_.get(), strName, 2);
    //     m_strInputNames.push_back(strName);
    // }
    // // support time stamp
    // size_t numOutputNodes = m_session_->GetOutputCount();
    // for(int index=0; index<numOutputNodes; index++){
    //     GetOutputName(m_session_.get(), strName, index);
    //     m_strOutputNames.push_back(strName);
    // }
    // for (auto& item : m_strInputNames)
    //     m_szInputNames.push_back(item.c_str());
    // for (auto& item : m_strOutputNames)
    //     m_szOutputNames.push_back(item.c_str());
    GetInputNames(m_session_.get(), m_strInputNames, m_szInputNames);
    GetOutputNames(m_session_.get(), m_strOutputNames, m_szOutputNames);
}
runtime/onnxruntime/src/paraformer.h
@@ -46,7 +46,8 @@
        // online
        void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num);
        // 2pass
        void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num);
        void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn,
            const std::string &am_config, const std::string &token_file, const std::string &online_token_file, int thread_num);
        void InitHwCompiler(const std::string &hw_model, int thread_num);
        void InitSegDict(const std::string &seg_dict_model);
        std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords);
runtime/onnxruntime/src/sensevoice-small.cpp
@@ -48,6 +48,145 @@
    LoadCmvn(am_cmvn.c_str());
}
// online
void SenseVoiceSmall::InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num){
    LoadOnlineConfigFromYaml(am_config.c_str());
    // knf options
    fbank_opts_.frame_opts.dither = 0;
    fbank_opts_.mel_opts.num_bins = n_mels;
    fbank_opts_.frame_opts.samp_freq = asr_sample_rate;
    fbank_opts_.frame_opts.window_type = window_type;
    fbank_opts_.frame_opts.frame_shift_ms = frame_shift;
    fbank_opts_.frame_opts.frame_length_ms = frame_length;
    fbank_opts_.energy_floor = 0;
    fbank_opts_.mel_opts.debug_mel = false;
    // session_options_.SetInterOpNumThreads(1);
    session_options_.SetIntraOpNumThreads(thread_num);
    session_options_.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
    // DisableCpuMemArena can improve performance
    session_options_.DisableCpuMemArena();
    try {
        encoder_session_ = std::make_unique<Ort::Session>(env_, ORTSTRING(en_model).c_str(), session_options_);
        LOG(INFO) << "Successfully load model from " << en_model;
    } catch (std::exception const &e) {
        LOG(ERROR) << "Error when load am encoder model: " << e.what();
        exit(-1);
    }
    try {
        decoder_session_ = std::make_unique<Ort::Session>(env_, ORTSTRING(de_model).c_str(), session_options_);
        LOG(INFO) << "Successfully load model from " << de_model;
    } catch (std::exception const &e) {
        LOG(ERROR) << "Error when load am decoder model: " << e.what();
        exit(-1);
    }
    // encoder
    string strName;
    GetInputName(encoder_session_.get(), strName);
    en_strInputNames.push_back(strName.c_str());
    GetInputName(encoder_session_.get(), strName,1);
    en_strInputNames.push_back(strName);
    GetOutputName(encoder_session_.get(), strName);
    en_strOutputNames.push_back(strName);
    GetOutputName(encoder_session_.get(), strName,1);
    en_strOutputNames.push_back(strName);
    GetOutputName(encoder_session_.get(), strName,2);
    en_strOutputNames.push_back(strName);
    for (auto& item : en_strInputNames)
        en_szInputNames_.push_back(item.c_str());
    for (auto& item : en_strOutputNames)
        en_szOutputNames_.push_back(item.c_str());
    // decoder
    int de_input_len = 4 + fsmn_layers;
    int de_out_len = 2 + fsmn_layers;
    for(int i=0;i<de_input_len; i++){
        GetInputName(decoder_session_.get(), strName, i);
        de_strInputNames.push_back(strName.c_str());
    }
    for(int i=0;i<de_out_len; i++){
        GetOutputName(decoder_session_.get(), strName,i);
        de_strOutputNames.push_back(strName);
    }
    for (auto& item : de_strInputNames)
        de_szInputNames_.push_back(item.c_str());
    for (auto& item : de_strOutputNames)
        de_szOutputNames_.push_back(item.c_str());
    online_vocab = new Vocab(token_file.c_str());
    phone_set_ = new PhoneSet(token_file.c_str());
    LoadCmvn(am_cmvn.c_str());
}
// 2pass
void SenseVoiceSmall::InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model,
    const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, const std::string &online_token_file, int thread_num){
    // online
    InitAsr(en_model, de_model, am_cmvn, am_config, online_token_file, thread_num);
    // offline
    try {
        m_session_ = std::make_unique<Ort::Session>(env_, ORTSTRING(am_model).c_str(), session_options_);
        LOG(INFO) << "Successfully load model from " << am_model;
    } catch (std::exception const &e) {
        LOG(ERROR) << "Error when load am onnx model: " << e.what();
        exit(-1);
    }
    GetInputNames(m_session_.get(), m_strInputNames, m_szInputNames);
    GetOutputNames(m_session_.get(), m_strOutputNames, m_szOutputNames);
    vocab = new Vocab(token_file.c_str());
}
void SenseVoiceSmall::LoadOnlineConfigFromYaml(const char* filename){
    YAML::Node config;
    try{
        config = YAML::LoadFile(filename);
    }catch(exception const &e){
        LOG(ERROR) << "Error loading file, yaml file error or not exist.";
        exit(-1);
    }
    try{
        YAML::Node frontend_conf = config["frontend_conf"];
        YAML::Node encoder_conf = config["encoder_conf"];
        YAML::Node decoder_conf = config["decoder_conf"];
        YAML::Node predictor_conf = config["predictor_conf"];
        this->window_type = frontend_conf["window"].as<string>();
        this->n_mels = frontend_conf["n_mels"].as<int>();
        this->frame_length = frontend_conf["frame_length"].as<int>();
        this->frame_shift = frontend_conf["frame_shift"].as<int>();
        this->lfr_m = frontend_conf["lfr_m"].as<int>();
        this->lfr_n = frontend_conf["lfr_n"].as<int>();
        this->encoder_size = encoder_conf["output_size"].as<int>();
        this->fsmn_dims = encoder_conf["output_size"].as<int>();
        this->fsmn_layers = decoder_conf["num_blocks"].as<int>();
        this->fsmn_lorder = decoder_conf["kernel_size"].as<int>()-1;
        this->cif_threshold = predictor_conf["threshold"].as<double>();
        this->tail_alphas = predictor_conf["tail_threshold"].as<double>();
        this->asr_sample_rate = frontend_conf["fs"].as<int>();
    }catch(exception const &e){
        LOG(ERROR) << "Error when load argument from vad config YAML.";
        exit(-1);
    }
}
void SenseVoiceSmall::LoadConfigFromYaml(const char* filename){
    YAML::Node config;
@@ -83,6 +222,9 @@
{
    if(vocab){
        delete vocab;
    }
    if(online_vocab){
        delete online_vocab;
    }
    if(lm_vocab){
        delete lm_vocab;
@@ -212,6 +354,30 @@
    return str_lang + str_emo + str_event + " " + text;
}
string SenseVoiceSmall::GreedySearch(float * in, int n_len,  int64_t token_nums, bool is_stamp, std::vector<float> us_alphas, std::vector<float> us_cif_peak)
{
    vector<int> hyps;
    int Tmax = n_len;
    for (int i = 0; i < Tmax; i++) {
        int max_idx;
        float max_val;
        FindMax(in + i * token_nums, token_nums, max_val, max_idx);
        hyps.push_back(max_idx);
    }
    if(!is_stamp){
        return online_vocab->Vector2StringV2(hyps, language);
    }else{
        std::vector<string> char_list;
        std::vector<std::vector<float>> timestamp_list;
        std::string res_str;
        online_vocab->Vector2String(hyps, char_list);
        std::vector<string> raw_char(char_list);
        TimestampOnnx(us_alphas, us_cif_peak, char_list, res_str, timestamp_list);
        return PostProcess(raw_char, timestamp_list);
    }
}
void SenseVoiceSmall::LfrCmvn(std::vector<std::vector<float>> &asr_feats) {
    std::vector<std::vector<float>> out_feats;
runtime/onnxruntime/src/sensevoice-small.h
@@ -12,12 +12,14 @@
    class SenseVoiceSmall : public Model {
    private:
        Vocab* vocab = nullptr;
        Vocab* online_vocab = nullptr;
        Vocab* lm_vocab = nullptr;
        SegDict* seg_dict = nullptr;
        PhoneSet* phone_set_ = nullptr;
        const float scale = 1.0;
        void LoadConfigFromYaml(const char* filename);
        void LoadOnlineConfigFromYaml(const char* filename);
        void LoadCmvn(const char *filename);
        void LfrCmvn(std::vector<std::vector<float>> &asr_feats);
@@ -34,9 +36,10 @@
        ~SenseVoiceSmall();
        void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num);
        // online
        // void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num);
        void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num);
        // 2pass
        // void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num);
        void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config,
            const std::string &token_file, const std::string &online_token_file, int thread_num);
        // void InitHwCompiler(const std::string &hw_model, int thread_num);
        // void InitSegDict(const std::string &seg_dict_model);
        std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords);
@@ -44,7 +47,8 @@
        void FbankKaldi(float sample_rate, const float* waves, int len, std::vector<std::vector<float>> &asr_feats);
        std::vector<std::string> Forward(float** din, int* len, bool input_finished=true, std::string svs_lang="auto", bool svs_itn=true, int batch_in=1);
        string CTCSearch( float * in, std::vector<int32_t> paraformer_length, std::vector<int64_t> outputShape);
        string GreedySearch( float* in, int n_len, int64_t token_nums,
                             bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});
        string Rescoring();
        string GetLang(){return language;};
        int GetAsrSampleRate() { return asr_sample_rate; };
@@ -100,6 +104,8 @@
        int asr_sample_rate = MODEL_SAMPLE_RATE;
        int batch_size_ = 1;
        int blank_id = 0;
        float cif_threshold = 1.0;
        float tail_alphas = 0.45;
        //dict
        std::map<std::string, int> lid_map = {
            {"auto", 0},
runtime/onnxruntime/src/tpass-online-stream.cpp
@@ -11,7 +11,7 @@
    }
    if(tpass_obj->asr_handle){
        asr_online_handle = make_unique<ParaformerOnline>((Paraformer*)(tpass_obj->asr_handle).get(), chunk_size);
        asr_online_handle = make_unique<ParaformerOnline>((tpass_obj->asr_handle).get(), chunk_size, tpass_stream->GetModelType());
    }else{
        LOG(ERROR)<<"asr_handle is null";
        exit(-1);
runtime/onnxruntime/src/tpass-stream.cpp
@@ -36,10 +36,17 @@
        string am_cmvn_path;
        string am_config_path;
        string token_path;
        string online_token_path;
        string hw_compile_model_path;
        string seg_dict_path;
        
        asr_handle = make_unique<Paraformer>();
        if (model_path.at(MODEL_DIR).find(MODEL_SVS) != std::string::npos)
        {
            asr_handle = make_unique<SenseVoiceSmall>();
            model_type = MODEL_SVS;
        }else{
            asr_handle = make_unique<Paraformer>();
        }
        bool enable_hotword = false;
        hw_compile_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_EB_NAME);
@@ -54,6 +61,7 @@
        am_model_path = PathAppend(model_path.at(OFFLINE_MODEL_DIR), MODEL_NAME);
        en_model_path = PathAppend(model_path.at(ONLINE_MODEL_DIR), ENCODER_NAME);
        de_model_path = PathAppend(model_path.at(ONLINE_MODEL_DIR), DECODER_NAME);
        online_token_path = PathAppend(model_path.at(ONLINE_MODEL_DIR), TOKEN_PATH);
        if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
            am_model_path = PathAppend(model_path.at(OFFLINE_MODEL_DIR), QUANT_MODEL_NAME);
            en_model_path = PathAppend(model_path.at(ONLINE_MODEL_DIR), QUANT_ENCODER_NAME);
@@ -63,7 +71,7 @@
        am_config_path = PathAppend(model_path.at(ONLINE_MODEL_DIR), AM_CONFIG_NAME);
        token_path = PathAppend(model_path.at(MODEL_DIR), TOKEN_PATH);
        asr_handle->InitAsr(am_model_path, en_model_path, de_model_path, am_cmvn_path, am_config_path, token_path, thread_num);
        asr_handle->InitAsr(am_model_path, en_model_path, de_model_path, am_cmvn_path, am_config_path, token_path, online_token_path, thread_num);
    }else{
        LOG(ERROR) <<"Can not find offline-model-dir or online-model-dir";
        exit(-1);
runtime/websocket/bin/funasr-wss-client-2pass.cpp
@@ -124,7 +124,7 @@
  void run(const std::string& uri, const std::vector<string>& wav_list,
           const std::vector<string>& wav_ids, int audio_fs, std::string asr_mode,
           std::vector<int> chunk_size, const std::unordered_map<std::string, int>& hws_map,
           bool is_record=false, int use_itn=1) {
           bool is_record=false, int use_itn=1, int svs_itn=1) {
    // Create a new connection to the given URI
    websocketpp::lib::error_code ec;
    typename websocketpp::client<T>::connection_ptr con =
@@ -146,9 +146,9 @@
    websocketpp::lib::thread asio_thread(&websocketpp::client<T>::run,
                                         &m_client);
    if(is_record){
      send_rec_data(asr_mode, chunk_size, hws_map, use_itn);
      send_rec_data(asr_mode, chunk_size, hws_map, use_itn, svs_itn);
    }else{
      send_wav_data(wav_list[0], wav_ids[0], audio_fs, asr_mode, chunk_size, hws_map, use_itn);
      send_wav_data(wav_list[0], wav_ids[0], audio_fs, asr_mode, chunk_size, hws_map, use_itn, svs_itn);
    }
    WaitABit();
@@ -185,7 +185,7 @@
  // send wav to server
  void send_wav_data(string wav_path, string wav_id, int audio_fs, std::string asr_mode,
                     std::vector<int> chunk_vector, const std::unordered_map<std::string, int>& hws_map,
                     int use_itn) {
                     int use_itn, int svs_itn) {
    uint64_t count = 0;
    std::stringstream val;
@@ -241,8 +241,12 @@
    jsonbegin["audio_fs"] = sampling_rate;
    jsonbegin["is_speaking"] = true;
    jsonbegin["itn"] = true;
    jsonbegin["svs_itn"] = true;
    if(use_itn == 0){
      jsonbegin["itn"] = false;
    }
    if(svs_itn == 0){
        jsonbegin["svs_itn"] = false;
    }
    if(!hws_map.empty()){
        LOG(INFO) << "hotwords: ";
@@ -335,7 +339,7 @@
  }
  void send_rec_data(std::string asr_mode, std::vector<int> chunk_vector, 
                     const std::unordered_map<std::string, int>& hws_map, int use_itn) {
                     const std::unordered_map<std::string, int>& hws_map, int use_itn, int svs_itn) {
    // first message
    bool wait = false;
    while (1) {
@@ -374,8 +378,12 @@
    jsonbegin["audio_fs"] = sample_rate;
    jsonbegin["is_speaking"] = true;
    jsonbegin["itn"] = true;
    jsonbegin["svs_itn"] = true;
    if(use_itn == 0){
      jsonbegin["itn"] = false;
    }
    if(svs_itn == 0){
        jsonbegin["svs_itn"] = false;
    }
    if(!hws_map.empty()){
        LOG(INFO) << "hotwords: ";
@@ -513,6 +521,9 @@
      "", "use-itn",
      "use-itn is 1 means use itn, 0 means not use itn", false, 1,
      "int");
  TCLAP::ValueArg<int> svs_itn_(
      "", "svs-itn",
      "svs-itn is 1 means use itn and punc, 0 means not use", false, 1, "int");
  TCLAP::ValueArg<std::string> hotword_("", HOTWORD,
      "the hotword file, one hotword perline, Format: Hotword Weight (could be: 阿里巴巴 20)", false, "", "string");
@@ -526,6 +537,7 @@
  cmd.add(thread_num_);
  cmd.add(is_ssl_);
  cmd.add(use_itn_);
  cmd.add(svs_itn_);
  cmd.add(hotword_);
  cmd.parse(argc, argv);
@@ -535,6 +547,7 @@
  std::string asr_mode = asr_mode_.getValue();
  std::string chunk_size_str = chunk_size_.getValue();
  int use_itn = use_itn_.getValue();
  int svs_itn = svs_itn_.getValue();
  // get chunk_size
  std::vector<int> chunk_size;
  std::stringstream ss(chunk_size_str);
@@ -577,11 +590,11 @@
        c.m_client.set_tls_init_handler(bind(&OnTlsInit, ::_1));
        c.run(uri, tmp_wav_list, tmp_wav_ids, audio_fs, asr_mode, chunk_size, hws_map, true, use_itn);
        c.run(uri, tmp_wav_list, tmp_wav_ids, audio_fs, asr_mode, chunk_size, hws_map, true, use_itn, svs_itn);
      } else {
        WebsocketClient<websocketpp::config::asio_client> c(is_ssl);
        c.run(uri, tmp_wav_list, tmp_wav_ids, audio_fs, asr_mode, chunk_size, hws_map, true, use_itn);
        c.run(uri, tmp_wav_list, tmp_wav_ids, audio_fs, asr_mode, chunk_size, hws_map, true, use_itn, svs_itn);
      }
  }else{
@@ -622,17 +635,17 @@
        tmp_wav_ids.emplace_back(wav_ids[wav_i + i]);
        client_threads.emplace_back(
            [uri, tmp_wav_list, tmp_wav_ids, audio_fs, asr_mode, chunk_size, is_ssl, hws_map, use_itn]() {
            [uri, tmp_wav_list, tmp_wav_ids, audio_fs, asr_mode, chunk_size, is_ssl, hws_map, use_itn, svs_itn]() {
              if (is_ssl == 1) {
                WebsocketClient<websocketpp::config::asio_tls_client> c(is_ssl);
                c.m_client.set_tls_init_handler(bind(&OnTlsInit, ::_1));
                c.run(uri, tmp_wav_list, tmp_wav_ids, audio_fs, asr_mode, chunk_size, hws_map, false, use_itn);
                c.run(uri, tmp_wav_list, tmp_wav_ids, audio_fs, asr_mode, chunk_size, hws_map, false, use_itn, svs_itn);
              } else {
                WebsocketClient<websocketpp::config::asio_client> c(is_ssl);
                c.run(uri, tmp_wav_list, tmp_wav_ids, audio_fs, asr_mode, chunk_size, hws_map, false, use_itn);
                c.run(uri, tmp_wav_list, tmp_wav_ids, audio_fs, asr_mode, chunk_size, hws_map, false, use_itn, svs_itn);
              }
            });
      }
runtime/websocket/bin/funasr-wss-server-2pass.cpp
@@ -276,6 +276,12 @@
            s_itn_path="";
            s_lm_path="";
        }
        found = s_offline_asr_path.find(MODEL_SVS);
        if (found != std::string::npos) {
            model_path["model-revision"]="v2.0.5";
            s_lm_path="";
            model_path[LM_DIR]="";
        }
        if (access(s_offline_asr_path.c_str(), F_OK) == 0) {
          // local
runtime/websocket/bin/websocket-server-2pass.cpp
@@ -111,7 +111,9 @@
    int audio_fs,
    std::string wav_format,
    FUNASR_HANDLE& tpass_online_handle,
    FUNASR_DEC_HANDLE& decoder_handle) {
    FUNASR_DEC_HANDLE& decoder_handle,
    std::string svs_lang,
    bool sys_itn) {
  // lock for each connection
  if(!tpass_online_handle){
    scoped_lock guard(thread_lock);
@@ -140,7 +142,8 @@
                                       subvector.data(), subvector.size(),
                                       punc_cache, false, audio_fs,
                                       wav_format, (ASR_TYPE)asr_mode_,
                                       hotwords_embedding, itn, decoder_handle);
                                       hotwords_embedding, itn, decoder_handle,
                                       svs_lang, sys_itn);
        } else {
          scoped_lock guard(thread_lock);
@@ -177,7 +180,8 @@
                                       buffer.data(), buffer.size(), punc_cache,
                                       is_final, audio_fs,
                                       wav_format, (ASR_TYPE)asr_mode_,
                                       hotwords_embedding, itn, decoder_handle);
                                       hotwords_embedding, itn, decoder_handle,
                                       svs_lang, sys_itn);
        } else {
          scoped_lock guard(thread_lock);
          msg["access_num"]=(int)msg["access_num"]-1;     
@@ -250,6 +254,8 @@
    data_msg->msg["audio_fs"] = 16000; // default is 16k
    data_msg->msg["access_num"] = 0; // the number of accesses to this object; when it is 0, we can free it safely
    data_msg->msg["is_eof"]=false; // if this connection is closed
    data_msg->msg["svs_lang"]="auto";
    data_msg->msg["svs_itn"]=true;
    FUNASR_DEC_HANDLE decoder_handle =
      FunASRWfstDecoderInit(tpass_handle, ASR_TWO_PASS, global_beam_, lattice_beam_, am_scale_);
    data_msg->decoder_handle = decoder_handle;
@@ -475,6 +481,12 @@
      if (jsonresult.contains("itn")) {
        msg_data->msg["itn"] = jsonresult["itn"];
      }
      if (jsonresult.contains("svs_lang")) {
        msg_data->msg["svs_lang"] = jsonresult["svs_lang"];
      }
      if (jsonresult.contains("svs_itn")) {
        msg_data->msg["svs_itn"] = jsonresult["svs_itn"];
      }
      LOG(INFO) << "jsonresult=" << jsonresult
                << ", msg_data->msg=" << msg_data->msg;
      if ((jsonresult["is_speaking"] == false ||
@@ -499,7 +511,9 @@
                        msg_data->msg["audio_fs"],
                        msg_data->msg["wav_format"],
                        std::ref(msg_data->tpass_online_handle),
                        std::ref(msg_data->decoder_handle)));
                        std::ref(msg_data->decoder_handle),
                        msg_data->msg["svs_lang"],
                        msg_data->msg["svs_itn"]));
              msg_data->msg["access_num"]=(int)(msg_data->msg["access_num"])+1;
        }
        catch (std::exception const &e)
@@ -547,7 +561,9 @@
                                  msg_data->msg["audio_fs"],
                                  msg_data->msg["wav_format"],
                                  std::ref(msg_data->tpass_online_handle),
                                  std::ref(msg_data->decoder_handle)));
                                  std::ref(msg_data->decoder_handle),
                                  msg_data->msg["svs_lang"],
                                  msg_data->msg["svs_itn"]));
              msg_data->msg["access_num"]=(int)(msg_data->msg["access_num"])+1;
            }
          }
runtime/websocket/bin/websocket-server-2pass.h
@@ -125,7 +125,9 @@
                  int audio_fs,
                  std::string wav_format,
                  FUNASR_HANDLE& tpass_online_handle,
                  FUNASR_DEC_HANDLE& decoder_handle);
                  FUNASR_DEC_HANDLE& decoder_handle,
                  std::string svs_lang,
                  bool sys_itn);
  void initAsr(std::map<std::string, std::string>& model_path, int thread_num);
  void on_message(websocketpp::connection_hdl hdl, message_ptr msg);