雾聪
2024-03-21 d4aaa84ad16c2c862ffcb5d73bf7852c8ee90d24
fix func FunASRWfstDecoderInit
2个文件已修改
8 ■■■■ 已修改文件
runtime/onnxruntime/include/model.h 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/src/funasrruntime.cpp 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/include/model.h
@@ -24,7 +24,11 @@
    virtual std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords){return std::vector<std::vector<float>>();};
    virtual std::string GetLang(){return "";};
    virtual int GetAsrSampleRate() = 0;
    virtual Vocab* GetVocab(){};
    virtual Vocab* GetLmVocab(){};
    virtual PhoneSet* GetPhoneSet(){};
    std::shared_ptr<fst::Fst<fst::StdArc>> lm_ = nullptr;
};
Model *CreateModel(std::map<std::string, std::string>& model_path, int thread_num=1, ASR_TYPE type=ASR_OFFLINE);
runtime/onnxruntime/src/funasrruntime.cpp
@@ -767,13 +767,13 @@
        funasr::WfstDecoder* mm = nullptr;
        if (asr_type == ASR_OFFLINE) {
            funasr::OfflineStream* offline_stream = (funasr::OfflineStream*)handle;
            funasr::Paraformer* paraformer = (funasr::Paraformer*)offline_stream->asr_handle.get();
            funasr::Model* paraformer = offline_stream->asr_handle.get();
            if (paraformer->lm_)
                mm = new funasr::WfstDecoder(paraformer->lm_.get(),
                    paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
        } else if (asr_type == ASR_TWO_PASS){
            funasr::TpassStream* tpass_stream = (funasr::TpassStream*)handle;
            funasr::Paraformer* paraformer = (funasr::Paraformer*)tpass_stream->asr_handle.get();
            funasr::Model* paraformer = tpass_stream->asr_handle.get();
            if (paraformer->lm_)
                mm = new funasr::WfstDecoder(paraformer->lm_.get(), 
                    paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);