From 331d57253ae25dd42c8e14930dee30cd8d2affa6 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Mon, 24 Apr 2023 11:41:18 +0800
Subject: [PATCH] Merge pull request #408 from alibaba-damo-academy/hnluo-patch-1

---
 funasr/runtime/onnxruntime/src/paraformer_onnx.cpp |   40 ++++++++++++++++++++++------------------
 1 file changed, 22 insertions(+), 18 deletions(-)

diff --git a/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp b/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
index bb00849..695e0f7 100644
--- a/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
+++ b/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
@@ -4,7 +4,7 @@
 using namespace paraformer;
 
 ModelImp::ModelImp(const char* path,int nNumThread, bool quantize)
-{
+:env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),sessionOptions{}{
     string model_path;
     string cmvn_path;
     string config_path;
@@ -18,7 +18,10 @@
     cmvn_path = pathAppend(path, "am.mvn");
     config_path = pathAppend(path, "config.yaml");
 
-    //fe = new FeatureExtract(3);
+    fft_input = (float *)fftwf_malloc(sizeof(float) * fft_size);
+    fft_out = (fftwf_complex *)fftwf_malloc(sizeof(fftwf_complex) * fft_size);
+    memset(fft_input, 0, sizeof(float) * fft_size);
+    plan = fftwf_plan_dft_r2c_1d(fft_size, fft_input, fft_out, FFTW_ESTIMATE);
 
     //sessionOptions.SetInterOpNumThreads(1);
     sessionOptions.SetIntraOpNumThreads(nNumThread);
@@ -26,20 +29,20 @@
 
 #ifdef _WIN32
     wstring wstrPath = strToWstr(model_path);
-    m_session = new Ort::Session(env, wstrPath.c_str(), sessionOptions);
+    m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), sessionOptions);
 #else
-    m_session = new Ort::Session(env, model_path.c_str(), sessionOptions);
+    m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), sessionOptions);
 #endif
 
     string strName;
-    getInputName(m_session, strName);
+    getInputName(m_session.get(), strName);
     m_strInputNames.push_back(strName.c_str());
-    getInputName(m_session, strName,1);
+    getInputName(m_session.get(), strName,1);
     m_strInputNames.push_back(strName);
     
-    getOutputName(m_session, strName);
+    getOutputName(m_session.get(), strName);
     m_strOutputNames.push_back(strName);
-    getOutputName(m_session, strName,1);
+    getOutputName(m_session.get(), strName,1);
     m_strOutputNames.push_back(strName);
 
     for (auto& item : m_strInputNames)
@@ -52,21 +55,16 @@
 
 ModelImp::~ModelImp()
 {
-    //if(fe)
-    //    delete fe;
-    if (m_session)
-    {
-        delete m_session;
-        m_session = nullptr;
-    }
     if(vocab)
         delete vocab;
+    fftwf_free(fft_input);
+    fftwf_free(fft_out);
+    fftwf_destroy_plan(plan);
+    fftwf_cleanup();
 }
 
 void ModelImp::reset()
 {
-    //fe->reset();
-    printf("Not Imp!!!!!!\n");
 }
 
 void ModelImp::apply_lfr(Tensor<float>*& din)
@@ -163,12 +161,18 @@
     Tensor<float>* in;
     FeatureExtract* fe = new FeatureExtract(3);
     fe->reset();
-    fe->insert(din, len, flag);
+    fe->insert(plan, din, len, flag);
     fe->fetch(in);
     apply_lfr(in);
     apply_cmvn(in);
     Ort::RunOptions run_option;
 
+#ifdef _WIN_X86
+        Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
+#else
+        Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
+#endif
+
     std::array<int64_t, 3> input_shape_{ in->size[0],in->size[2],in->size[3] };
     Ort::Value onnx_feats = Ort::Value::CreateTensor<float>(m_memoryInfo,
         in->buff,

--
Gitblit v1.9.1