雾聪
2024-03-21 462355c002131c105b29b8821f67f97c532b6808
fix func FunASRWfstDecoderInit
2个文件已修改
47 ■■■■ 已修改文件
runtime/onnxruntime/include/model.h 2 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/src/funasrruntime.cpp 45 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/include/model.h
@@ -31,8 +31,6 @@
    virtual Vocab* GetVocab() {return nullptr;};
    virtual Vocab* GetLmVocab() {return nullptr;};
    virtual PhoneSet* GetPhoneSet() {return nullptr;};
    // Optional WFST language-model graph. nullptr (the default) means no LM is
    // attached. Callers test it for emptiness before building a WfstDecoder and
    // pass the decoder the raw pointer from lm_.get().
    // NOTE(review): the decoder receives a non-owning pointer, so this graph
    // presumably must outlive any decoder built from it — confirm.
    std::shared_ptr<fst::Fst<fst::StdArc>> lm_ = nullptr;
};
Model *CreateModel(std::map<std::string, std::string>& model_path, int thread_num=1, ASR_TYPE type=ASR_OFFLINE);
runtime/onnxruntime/src/funasrruntime.cpp
@@ -767,16 +767,45 @@
        funasr::WfstDecoder* mm = nullptr;
        if (asr_type == ASR_OFFLINE) {
            funasr::OfflineStream* offline_stream = (funasr::OfflineStream*)handle;
            funasr::Model* paraformer = offline_stream->asr_handle.get();
            if (paraformer->lm_)
                mm = new funasr::WfstDecoder(paraformer->lm_.get(),
                    paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
            auto paraformer = dynamic_cast<funasr::Paraformer*>(offline_stream->asr_handle.get());
            if(paraformer !=nullptr){
                if (paraformer->lm_){
                    mm = new funasr::WfstDecoder(paraformer->lm_.get(),
                        paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
                }
                return mm;
            }
            #ifdef USE_GPU
            auto paraformer_torch = dynamic_cast<funasr::ParaformerTorch*>(offline_stream->asr_handle.get());
            if(paraformer_torch !=nullptr){
                if (paraformer_torch->lm_){
                    mm = new funasr::WfstDecoder(paraformer_torch->lm_.get(),
                        paraformer_torch->GetPhoneSet(), paraformer_torch->GetLmVocab(), glob_beam, lat_beam, am_scale);
                }
                return mm;
            }
            #endif
        } else if (asr_type == ASR_TWO_PASS){
            funasr::TpassStream* tpass_stream = (funasr::TpassStream*)handle;
            funasr::Model* paraformer = tpass_stream->asr_handle.get();
            if (paraformer->lm_)
                mm = new funasr::WfstDecoder(paraformer->lm_.get(),
                    paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
            auto paraformer = dynamic_cast<funasr::Paraformer*>(tpass_stream->asr_handle.get());
            if(paraformer !=nullptr){
                if (paraformer->lm_){
                    mm = new funasr::WfstDecoder(paraformer->lm_.get(),
                        paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
                }
                return mm;
            }
            #ifdef USE_GPU
            auto paraformer_torch = dynamic_cast<funasr::ParaformerTorch*>(offline_stream->asr_handle.get());
            if(paraformer_torch !=nullptr){
                if (paraformer_torch->lm_){
                    mm = new funasr::WfstDecoder(paraformer_torch->lm_.get(),
                        paraformer_torch->GetPhoneSet(), paraformer_torch->GetLmVocab(), glob_beam, lat_beam, am_scale);
                }
                return mm;
            }
            #endif
        }
        return mm;
    }