From 24f73665e2d8ea8e4de2fe4f900bc539d7f7b989 Mon Sep 17 00:00:00 2001
From: hnluo <haoneng.lhn@alibaba-inc.com>
Date: Mon, 17 Apr 2023 15:49:45 +0800
Subject: [PATCH] Merge pull request #367 from alibaba-damo-academy/dev_lhn2

---
 funasr/runtime/onnxruntime/src/paraformer_onnx.cpp |  101 +++++++++++++++++++++++++++++++++++++-------------
 1 files changed, 74 insertions(+), 27 deletions(-)

diff --git a/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp b/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
index 8eb0e89..695e0f7 100644
--- a/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
+++ b/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
@@ -4,18 +4,24 @@
 using namespace paraformer;
 
 ModelImp::ModelImp(const char* path,int nNumThread, bool quantize)
-{
+:env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),sessionOptions{}{
     string model_path;
-    string vocab_path;
+    string cmvn_path;
+    string config_path;
+
     if(quantize)
     {
         model_path = pathAppend(path, "model_quant.onnx");
     }else{
         model_path = pathAppend(path, "model.onnx");
     }
-    vocab_path = pathAppend(path, "vocab.txt");
+    cmvn_path = pathAppend(path, "am.mvn");
+    config_path = pathAppend(path, "config.yaml");
 
-    fe = new FeatureExtract(3);
+    fft_input = (float *)fftwf_malloc(sizeof(float) * fft_size);
+    fft_out = (fftwf_complex *)fftwf_malloc(sizeof(fftwf_complex) * fft_size);
+    memset(fft_input, 0, sizeof(float) * fft_size);
+    plan = fftwf_plan_dft_r2c_1d(fft_size, fft_input, fft_out, FFTW_ESTIMATE);
 
     //sessionOptions.SetInterOpNumThreads(1);
     sessionOptions.SetIntraOpNumThreads(nNumThread);
@@ -23,45 +29,42 @@
 
 #ifdef _WIN32
     wstring wstrPath = strToWstr(model_path);
-    m_session = new Ort::Session(env, wstrPath.c_str(), sessionOptions);
+    m_session = std::make_unique<Ort::Session>(env_, wstrPath.c_str(), sessionOptions);  // Windows builds of ONNX Runtime take a wide-character (ORTCHAR_T) model path
 #else
-    m_session = new Ort::Session(env, model_path.c_str(), sessionOptions);
+    m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), sessionOptions);
 #endif
 
     string strName;
-    getInputName(m_session, strName);
+    getInputName(m_session.get(), strName);
     m_strInputNames.push_back(strName.c_str());
-    getInputName(m_session, strName,1);
+    getInputName(m_session.get(), strName,1);
     m_strInputNames.push_back(strName);
     
-    getOutputName(m_session, strName);
+    getOutputName(m_session.get(), strName);
     m_strOutputNames.push_back(strName);
-    getOutputName(m_session, strName,1);
+    getOutputName(m_session.get(), strName,1);
     m_strOutputNames.push_back(strName);
 
     for (auto& item : m_strInputNames)
         m_szInputNames.push_back(item.c_str());
     for (auto& item : m_strOutputNames)
         m_szOutputNames.push_back(item.c_str());
-    vocab = new Vocab(vocab_path.c_str());
+    vocab = new Vocab(config_path.c_str());
+    load_cmvn(cmvn_path.c_str());
 }
 
 ModelImp::~ModelImp()
 {
-    if(fe)
-        delete fe;
-    if (m_session)
-    {
-        delete m_session;
-        m_session = nullptr;
-    }
     if(vocab)
         delete vocab;
+    fftwf_destroy_plan(plan);  // release the plan before the buffers it was planned on
+    fftwf_free(fft_input);
+    fftwf_free(fft_out);
+    // fftwf_cleanup() is deliberately NOT called: it resets FFTW's global state and would invalidate plans held by other live ModelImp instances.
 }
 
 void ModelImp::reset()
 {
-    fe->reset();
 }
 
 void ModelImp::apply_lfr(Tensor<float>*& din)
@@ -88,16 +91,49 @@
     din = tmp;
 }
 
+// Reads the CMVN file: the line following an "<AddShift>" tag feeds means_list and the
+// line following a "<Rescale>" tag feeds vars_list (each value multiplied by `scale`).
+void ModelImp::load_cmvn(const char *filename)
+{
+    ifstream cmvn_stream(filename);
+    string line;
+    // Whitespace-split helper; a blank line yields an empty vector, so callers check before [0].
+    auto split = [](const string &text) {
+        istringstream iss(text);
+        return vector<string>{istream_iterator<string>{iss}, istream_iterator<string>{}};
+    };
+    while (getline(cmvn_stream, line)) {
+        const vector<string> line_item = split(line);
+        if (line_item.empty())
+            continue;
+        const bool is_shift = line_item[0] == "<AddShift>";
+        if (!is_shift && line_item[0] != "<Rescale>")
+            continue;
+        if (!getline(cmvn_stream, line))
+            break;
+        const vector<string> values = split(line);
+        if (values.empty() || values[0] != "<LearnRateCoef>")
+            continue;
+        // Skip the first three tokens and the last one (presumably "<LearnRateCoef> <coef> [" and "]" - confirm against am.mvn).
+        for (size_t j = 3; j + 1 < values.size(); j++) {
+            if (is_shift)
+                means_list.push_back(stof(values[j]));
+            else
+                vars_list.push_back(stof(values[j]) * scale);
+        }
+    }
+}
+
 void ModelImp::apply_cmvn(Tensor<float>* din)
 {
     const float* var;
     const float* mean;
-    float scale = 22.6274169979695;
+    var = vars_list.data();
+    mean= means_list.data();
+
     int m = din->size[2];
     int n = din->size[3];
 
-    var = (const float*)paraformer_cmvn_var_hex;
-    mean = (const float*)paraformer_cmvn_mean_hex;
     for (int i = 0; i < m; i++) {
         for (int j = 0; j < n; j++) {
             int idx = i * n + j;
@@ -122,13 +158,20 @@
 
 string ModelImp::forward(float* din, int len, int flag)
 {
-
     Tensor<float>* in;
-    fe->insert(din, len, flag);
+    FeatureExtract* fe = new FeatureExtract(3);
+    fe->reset();
+    fe->insert(plan, din, len, flag);
     fe->fetch(in);
     apply_lfr(in);
     apply_cmvn(in);
     Ort::RunOptions run_option;
+
+#ifdef _WIN_X86
+        Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
+#else
+        Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
+#endif
 
     std::array<int64_t, 3> input_shape_{ in->size[0],in->size[2],in->size[3] };
     Ort::Value onnx_feats = Ort::Value::CreateTensor<float>(m_memoryInfo,
@@ -155,7 +198,6 @@
         auto outputTensor = m_session->Run(run_option, m_szInputNames.data(), input_onnx.data(), m_szInputNames.size(), m_szOutputNames.data(), m_szOutputNames.size());
         std::vector<int64_t> outputShape = outputTensor[0].GetTensorTypeAndShapeInfo().GetShape();
 
-
         int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>());
         float* floatData = outputTensor[0].GetTensorMutableData<float>();
         auto encoder_out_lens = outputTensor[1].GetTensorMutableData<int64_t>();
@@ -166,9 +208,14 @@
         result = "";
     }
 
-
-    if(in)
+    if(in){
         delete in;
+        in = nullptr;
+    }
+    if(fe){
+        delete fe;
+        fe = nullptr;
+    }
 
     return result;
 }

--
Gitblit v1.9.1