雾聪
2024-03-21 6bbbf14c0080df7888aa7c54b48d6dadf27867ba
runtime/onnxruntime/src/funasrruntime.cpp
@@ -767,16 +767,45 @@
      funasr::WfstDecoder* mm = nullptr;
      if (asr_type == ASR_OFFLINE) {
         funasr::OfflineStream* offline_stream = (funasr::OfflineStream*)handle;
         funasr::Model* paraformer = offline_stream->asr_handle.get();
         if (paraformer->lm_)
            mm = new funasr::WfstDecoder(paraformer->lm_.get(),
               paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
         auto paraformer = dynamic_cast<funasr::Paraformer*>(offline_stream->asr_handle.get());
         if(paraformer !=nullptr){
            if (paraformer->lm_){
               mm = new funasr::WfstDecoder(paraformer->lm_.get(),
                  paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
            }
            return mm;
         }
         #ifdef USE_GPU
         auto paraformer_torch = dynamic_cast<funasr::ParaformerTorch*>(offline_stream->asr_handle.get());
         if(paraformer_torch !=nullptr){
            if (paraformer_torch->lm_){
               mm = new funasr::WfstDecoder(paraformer_torch->lm_.get(),
                  paraformer_torch->GetPhoneSet(), paraformer_torch->GetLmVocab(), glob_beam, lat_beam, am_scale);
            }
            return mm;
         }
         #endif
      } else if (asr_type == ASR_TWO_PASS){
         funasr::TpassStream* tpass_stream = (funasr::TpassStream*)handle;
         funasr::Model* paraformer = tpass_stream->asr_handle.get();
         if (paraformer->lm_)
            mm = new funasr::WfstDecoder(paraformer->lm_.get(),
               paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
         auto paraformer = dynamic_cast<funasr::Paraformer*>(tpass_stream->asr_handle.get());
         if(paraformer !=nullptr){
            if (paraformer->lm_){
               mm = new funasr::WfstDecoder(paraformer->lm_.get(),
                  paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
            }
            return mm;
         }
         #ifdef USE_GPU
         auto paraformer_torch = dynamic_cast<funasr::ParaformerTorch*>(tpass_stream->asr_handle.get());
         if(paraformer_torch !=nullptr){
            if (paraformer_torch->lm_){
               mm = new funasr::WfstDecoder(paraformer_torch->lm_.get(),
                  paraformer_torch->GetPhoneSet(), paraformer_torch->GetLmVocab(), glob_beam, lat_beam, am_scale);
            }
            return mm;
         }
         #endif
      }
      return mm;
   }