From b78d47f1efb3d0662fce1b8d45a9eb11b3caef02 Mon Sep 17 00:00:00 2001
From: Lizerui9926 <110582652+Lizerui9926@users.noreply.github.com>
Date: 星期三, 26 四月 2023 17:17:52 +0800
Subject: [PATCH] Merge pull request #427 from alibaba-damo-academy/dev_gflags

---
 funasr/runtime/onnxruntime/src/paraformer.cpp |   92 +++++++++++++++++++++++++++++++++-------------
 1 files changed, 66 insertions(+), 26 deletions(-)

diff --git a/funasr/runtime/onnxruntime/src/paraformer.cpp b/funasr/runtime/onnxruntime/src/paraformer.cpp
index 72127f8..136d228 100644
--- a/funasr/runtime/onnxruntime/src/paraformer.cpp
+++ b/funasr/runtime/onnxruntime/src/paraformer.cpp
@@ -1,36 +1,72 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License  (https://opensource.org/licenses/MIT)
+*/
+
 #include "precomp.h"
 
 using namespace std;
 using namespace paraformer;
 
-Paraformer::Paraformer(const char* path,int thread_num, bool quantize, bool use_vad, bool use_punc)
+Paraformer::Paraformer(std::map<std::string, std::string>& model_path,int thread_num)
 :env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),session_options{}{
-    string model_path;
-    string cmvn_path;
-    string config_path;
 
     // VAD model
-    if(use_vad){
-        string vad_path = PathAppend(path, "vad_model.onnx");
-        string mvn_path = PathAppend(path, "vad.mvn");
+    if(model_path.find(VAD_MODEL_PATH) != model_path.end()){
+        use_vad = true;
+        string vad_model_path;
+        string vad_cmvn_path;
+        string vad_config_path;
+    
+        try{
+            vad_model_path = model_path.at(VAD_MODEL_PATH);
+            vad_cmvn_path = model_path.at(VAD_CMVN_PATH);
+            vad_config_path = model_path.at(VAD_CONFIG_PATH);
+        }catch(const out_of_range& e){
+            LOG(ERROR) << "Error when read "<< VAD_CMVN_PATH << " or " << VAD_CONFIG_PATH <<" :" << e.what();
+            exit(0);
+        }
         vad_handle = make_unique<FsmnVad>();
-        vad_handle->InitVad(vad_path, mvn_path, MODEL_SAMPLE_RATE, VAD_MAX_LEN, VAD_SILENCE_DYRATION, VAD_SPEECH_NOISE_THRES);
+        vad_handle->InitVad(vad_model_path, vad_cmvn_path, vad_config_path);
+    }
+
+    // AM model
+    if(model_path.find(AM_MODEL_PATH) != model_path.end()){
+        string am_model_path;
+        string am_cmvn_path;
+        string am_config_path;
+    
+        try{
+            am_model_path = model_path.at(AM_MODEL_PATH);
+            am_cmvn_path = model_path.at(AM_CMVN_PATH);
+            am_config_path = model_path.at(AM_CONFIG_PATH);
+        }catch(const out_of_range& e){
+            LOG(ERROR) << "Error when read "<< AM_CONFIG_PATH << " or " << AM_CMVN_PATH <<" :" << e.what();
+            exit(0);
+        }
+        InitAM(am_model_path, am_cmvn_path, am_config_path, thread_num);
     }
 
     // PUNC model
-    if(use_punc){
-        punc_handle = make_unique<CTTransformer>(path, thread_num);
-    }
+    if(model_path.find(PUNC_MODEL_PATH) != model_path.end()){
+        use_punc = true;
+        string punc_model_path;
+        string punc_config_path;
+    
+        try{
+            punc_model_path = model_path.at(PUNC_MODEL_PATH);
+            punc_config_path = model_path.at(PUNC_CONFIG_PATH);
+        }catch(const out_of_range& e){
+            LOG(ERROR) << "Error when read "<< PUNC_CONFIG_PATH <<" :" << e.what();
+            exit(0);
+        }
 
-    if(quantize)
-    {
-        model_path = PathAppend(path, "model_quant.onnx");
-    }else{
-        model_path = PathAppend(path, "model.onnx");
+        punc_handle = make_unique<CTTransformer>();
+        punc_handle->InitPunc(punc_model_path, punc_config_path, thread_num);
     }
-    cmvn_path = PathAppend(path, "am.mvn");
-    config_path = PathAppend(path, "config.yaml");
+}
 
+void Paraformer::InitAM(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){
     // knf options
     fbank_opts.frame_opts.dither = 0;
     fbank_opts.mel_opts.num_bins = 80;
@@ -48,12 +84,12 @@
     // DisableCpuMemArena can improve performance
     session_options.DisableCpuMemArena();
 
-#ifdef _WIN32
-    wstring wstrPath = strToWstr(model_path);
-    m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), session_options);
-#else
-    m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), session_options);
-#endif
+    try {
+        m_session = std::make_unique<Ort::Session>(env_, am_model.c_str(), session_options);
+    } catch (std::exception const &e) {
+        LOG(ERROR) << "Error when load am onnx model: " << e.what();
+        exit(0);
+    }
 
     string strName;
     GetInputName(m_session.get(), strName);
@@ -70,8 +106,8 @@
         m_szInputNames.push_back(item.c_str());
     for (auto& item : m_strOutputNames)
         m_szOutputNames.push_back(item.c_str());
-    vocab = new Vocab(config_path.c_str());
-    LoadCmvn(cmvn_path.c_str());
+    vocab = new Vocab(am_config.c_str());
+    LoadCmvn(am_cmvn.c_str());
 }
 
 Paraformer::~Paraformer()
@@ -113,6 +149,10 @@
 void Paraformer::LoadCmvn(const char *filename)
 {
     ifstream cmvn_stream(filename);
+    if (!cmvn_stream.is_open()) {
+        LOG(ERROR) << "Failed to open file: " << filename;
+        exit(0);
+    }
     string line;
 
     while (getline(cmvn_stream, line)) {

--
Gitblit v1.9.1