From 58c119c50882b1b2b60e1331687ea07678b15e5e Mon Sep 17 00:00:00 2001
From: lyblsgo <lyblsgo@163.com>
Date: 星期五, 14 四月 2023 17:39:07 +0800
Subject: [PATCH] modify paraformer onnx init
---
funasr/runtime/onnxruntime/src/paraformer_onnx.h | 12 +++---------
funasr/runtime/onnxruntime/src/paraformer_onnx.cpp | 25 +++++++++++++------------
2 files changed, 16 insertions(+), 21 deletions(-)
diff --git a/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp b/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
index 0d9c658..695e0f7 100644
--- a/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
+++ b/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
@@ -4,7 +4,7 @@
using namespace paraformer;
ModelImp::ModelImp(const char* path,int nNumThread, bool quantize)
-{
+:env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),sessionOptions{}{
string model_path;
string cmvn_path;
string config_path;
@@ -29,20 +29,20 @@
#ifdef _WIN32
wstring wstrPath = strToWstr(model_path);
- m_session = new Ort::Session(env, wstrPath.c_str(), sessionOptions);
+ m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), sessionOptions);
#else
- m_session = new Ort::Session(env, model_path.c_str(), sessionOptions);
+ m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), sessionOptions);
#endif
string strName;
- getInputName(m_session, strName);
+ getInputName(m_session.get(), strName);
m_strInputNames.push_back(strName.c_str());
- getInputName(m_session, strName,1);
+ getInputName(m_session.get(), strName,1);
m_strInputNames.push_back(strName);
- getOutputName(m_session, strName);
+ getOutputName(m_session.get(), strName);
m_strOutputNames.push_back(strName);
- getOutputName(m_session, strName,1);
+ getOutputName(m_session.get(), strName,1);
m_strOutputNames.push_back(strName);
for (auto& item : m_strInputNames)
@@ -55,11 +55,6 @@
ModelImp::~ModelImp()
{
- if (m_session)
- {
- delete m_session;
- m_session = nullptr;
- }
if(vocab)
delete vocab;
fftwf_free(fft_input);
@@ -172,6 +167,12 @@
apply_cmvn(in);
Ort::RunOptions run_option;
+#ifdef _WIN_X86
+ Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
+#else
+ Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
+#endif
+
std::array<int64_t, 3> input_shape_{ in->size[0],in->size[2],in->size[3] };
Ort::Value onnx_feats = Ort::Value::CreateTensor<float>(m_memoryInfo,
in->buff,
diff --git a/funasr/runtime/onnxruntime/src/paraformer_onnx.h b/funasr/runtime/onnxruntime/src/paraformer_onnx.h
index e763be2..8946ae1 100644
--- a/funasr/runtime/onnxruntime/src/paraformer_onnx.h
+++ b/funasr/runtime/onnxruntime/src/paraformer_onnx.h
@@ -24,15 +24,9 @@
string greedy_search( float* in, int nLen);
-#ifdef _WIN_X86
- Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
-#else
- Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
-#endif
-
- Ort::Session* m_session = nullptr;
- Ort::Env env = Ort::Env(ORT_LOGGING_LEVEL_ERROR, "paraformer");
- Ort::SessionOptions sessionOptions = Ort::SessionOptions();
+ std::unique_ptr<Ort::Session> m_session;
+ Ort::Env env_;
+ Ort::SessionOptions sessionOptions;
vector<string> m_strInputNames, m_strOutputNames;
vector<const char*> m_szInputNames;
--
Gitblit v1.9.1