From 8f26a9acc2461ce0c77eacc3d36d3cef3457f520 Mon Sep 17 00:00:00 2001
From: speech_asr <wangjiaming.wjm@alibaba-inc.com>
Date: Wed, 29 Mar 2023 15:49:46 +0800
Subject: [PATCH] Merge branch 'dev_wjm' of https://github.com/alibaba/FunASR into dev_wjm

---
 funasr/runtime/onnxruntime/src/paraformer_onnx.cpp |   61 ++++++++++++++++++++++++++----
 1 file changed, 53 insertions(+), 8 deletions(-)

diff --git a/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp b/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
index 46b5211..a49069c 100644
--- a/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
+++ b/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
@@ -3,14 +3,25 @@
 using namespace std;
 using namespace paraformer;
 
-ModelImp::ModelImp(const char* path,int nNumThread)
+ModelImp::ModelImp(const char* path,int nNumThread, bool quantize)
 {
-    string model_path = pathAppend(path, "model.onnx");
-    string vocab_path = pathAppend(path, "vocab.txt");
+    string model_path;
+    string cmvn_path;
+    string config_path;
+
+    if(quantize)
+    {
+        model_path = pathAppend(path, "model_quant.onnx");
+    }else{
+        model_path = pathAppend(path, "model.onnx");
+    }
+    cmvn_path = pathAppend(path, "am.mvn");
+    config_path = pathAppend(path, "config.yaml");
 
     fe = new FeatureExtract(3);
 
-    sessionOptions.SetInterOpNumThreads(nNumThread);
+    //sessionOptions.SetInterOpNumThreads(1);
+    sessionOptions.SetIntraOpNumThreads(nNumThread);
     sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
 
 #ifdef _WIN32
@@ -35,7 +46,8 @@
         m_szInputNames.push_back(item.c_str());
     for (auto& item : m_strOutputNames)
         m_szOutputNames.push_back(item.c_str());
-    vocab = new Vocab(vocab_path.c_str());
+    vocab = new Vocab(config_path.c_str());
+    load_cmvn(cmvn_path.c_str());
 }
 
 ModelImp::~ModelImp()
@@ -80,16 +92,49 @@
     din = tmp;
 }
 
+void ModelImp::load_cmvn(const char *filename)
+{
+    ifstream cmvn_stream(filename);
+    string line;
+
+    while (getline(cmvn_stream, line)) {
+        istringstream iss(line);
+        vector<string> line_item{istream_iterator<string>{iss}, istream_iterator<string>{}};
+        if (line_item[0] == "<AddShift>") {
+            getline(cmvn_stream, line);
+            istringstream means_lines_stream(line);
+            vector<string> means_lines{istream_iterator<string>{means_lines_stream}, istream_iterator<string>{}};
+            if (means_lines[0] == "<LearnRateCoef>") {
+                for (int j = 3; j < means_lines.size() - 1; j++) {
+                    means_list.push_back(stof(means_lines[j]));
+                }
+                continue;
+            }
+        }
+        else if (line_item[0] == "<Rescale>") {
+            getline(cmvn_stream, line);
+            istringstream vars_lines_stream(line);
+            vector<string> vars_lines{istream_iterator<string>{vars_lines_stream}, istream_iterator<string>{}};
+            if (vars_lines[0] == "<LearnRateCoef>") {
+                for (int j = 3; j < vars_lines.size() - 1; j++) {
+                    vars_list.push_back(stof(vars_lines[j])*scale);
+                }
+                continue;
+            }
+        }
+    }
+}
+
 void ModelImp::apply_cmvn(Tensor<float>* din)
 {
     const float* var;
     const float* mean;
-    float scale = 22.6274169979695;
+    var = vars_list.data();
+    mean= means_list.data();
+
     int m = din->size[2];
     int n = din->size[3];
 
-    var = (const float*)paraformer_cmvn_var_hex;
-    mean = (const float*)paraformer_cmvn_mean_hex;
     for (int i = 0; i < m; i++) {
         for (int j = 0; j < n; j++) {
             int idx = i * n + j;

--
Gitblit v1.9.1