From d4aaa84ad16c2c862ffcb5d73bf7852c8ee90d24 Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期四, 21 三月 2024 14:17:22 +0800
Subject: [PATCH] fix func FunASRWfstDecoderInit

---
 runtime/onnxruntime/src/funasrruntime.cpp |    4 ++--
 runtime/onnxruntime/include/model.h       |    4 ++++
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/runtime/onnxruntime/include/model.h b/runtime/onnxruntime/include/model.h
index 33caec8..b72db92 100644
--- a/runtime/onnxruntime/include/model.h
+++ b/runtime/onnxruntime/include/model.h
@@ -24,7 +24,11 @@
     virtual std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords){return std::vector<std::vector<float>>();};
     virtual std::string GetLang(){return "";};
     virtual int GetAsrSampleRate() = 0;
+    virtual Vocab* GetVocab(){};
+    virtual Vocab* GetLmVocab(){};
+    virtual PhoneSet* GetPhoneSet(){};
 
+    std::shared_ptr<fst::Fst<fst::StdArc>> lm_ = nullptr;
 };
 
 Model *CreateModel(std::map<std::string, std::string>& model_path, int thread_num=1, ASR_TYPE type=ASR_OFFLINE);
diff --git a/runtime/onnxruntime/src/funasrruntime.cpp b/runtime/onnxruntime/src/funasrruntime.cpp
index d795cb0..b283772 100644
--- a/runtime/onnxruntime/src/funasrruntime.cpp
+++ b/runtime/onnxruntime/src/funasrruntime.cpp
@@ -767,13 +767,13 @@
 		funasr::WfstDecoder* mm = nullptr;
 		if (asr_type == ASR_OFFLINE) {
 			funasr::OfflineStream* offline_stream = (funasr::OfflineStream*)handle;
-			funasr::Paraformer* paraformer = (funasr::Paraformer*)offline_stream->asr_handle.get();
+			funasr::Model* paraformer = offline_stream->asr_handle.get();
 			if (paraformer->lm_)
 				mm = new funasr::WfstDecoder(paraformer->lm_.get(),
 					paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
 		} else if (asr_type == ASR_TWO_PASS){
 			funasr::TpassStream* tpass_stream = (funasr::TpassStream*)handle;
-			funasr::Paraformer* paraformer = (funasr::Paraformer*)tpass_stream->asr_handle.get();
+			funasr::Model* paraformer = tpass_stream->asr_handle.get();
 			if (paraformer->lm_)
 				mm = new funasr::WfstDecoder(paraformer->lm_.get(), 
 					paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);

--
Gitblit v1.9.1