From b78d47f1efb3d0662fce1b8d45a9eb11b3caef02 Mon Sep 17 00:00:00 2001
From: Lizerui9926 <110582652+Lizerui9926@users.noreply.github.com>
Date: 星期三, 26 四月 2023 17:17:52 +0800
Subject: [PATCH] Merge pull request #427 from alibaba-damo-academy/dev_gflags
---
funasr/runtime/onnxruntime/src/paraformer.cpp | 92 +++++++++++++++++++++++++++++++++-------------
1 files changed, 66 insertions(+), 26 deletions(-)
diff --git a/funasr/runtime/onnxruntime/src/paraformer.cpp b/funasr/runtime/onnxruntime/src/paraformer.cpp
index 72127f8..136d228 100644
--- a/funasr/runtime/onnxruntime/src/paraformer.cpp
+++ b/funasr/runtime/onnxruntime/src/paraformer.cpp
@@ -1,36 +1,72 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License (https://opensource.org/licenses/MIT)
+*/
+
#include "precomp.h"
using namespace std;
using namespace paraformer;
-Paraformer::Paraformer(const char* path,int thread_num, bool quantize, bool use_vad, bool use_punc)
+Paraformer::Paraformer(std::map<std::string, std::string>& model_path,int thread_num)
:env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),session_options{}{
- string model_path;
- string cmvn_path;
- string config_path;
// VAD model
- if(use_vad){
- string vad_path = PathAppend(path, "vad_model.onnx");
- string mvn_path = PathAppend(path, "vad.mvn");
+ if(model_path.find(VAD_MODEL_PATH) != model_path.end()){
+ use_vad = true;
+ string vad_model_path;
+ string vad_cmvn_path;
+ string vad_config_path;
+
+ try{
+ vad_model_path = model_path.at(VAD_MODEL_PATH);
+ vad_cmvn_path = model_path.at(VAD_CMVN_PATH);
+ vad_config_path = model_path.at(VAD_CONFIG_PATH);
+ }catch(const out_of_range& e){
+ LOG(ERROR) << "Error when read "<< VAD_CMVN_PATH << " or " << VAD_CONFIG_PATH <<" :" << e.what();
+ exit(0);
+ }
vad_handle = make_unique<FsmnVad>();
- vad_handle->InitVad(vad_path, mvn_path, MODEL_SAMPLE_RATE, VAD_MAX_LEN, VAD_SILENCE_DYRATION, VAD_SPEECH_NOISE_THRES);
+ vad_handle->InitVad(vad_model_path, vad_cmvn_path, vad_config_path);
+ }
+
+ // AM model
+ if(model_path.find(AM_MODEL_PATH) != model_path.end()){
+ string am_model_path;
+ string am_cmvn_path;
+ string am_config_path;
+
+ try{
+ am_model_path = model_path.at(AM_MODEL_PATH);
+ am_cmvn_path = model_path.at(AM_CMVN_PATH);
+ am_config_path = model_path.at(AM_CONFIG_PATH);
+ }catch(const out_of_range& e){
+ LOG(ERROR) << "Error when read "<< AM_CONFIG_PATH << " or " << AM_CMVN_PATH <<" :" << e.what();
+ exit(0);
+ }
+ InitAM(am_model_path, am_cmvn_path, am_config_path, thread_num);
}
// PUNC model
- if(use_punc){
- punc_handle = make_unique<CTTransformer>(path, thread_num);
- }
+ if(model_path.find(PUNC_MODEL_PATH) != model_path.end()){
+ use_punc = true;
+ string punc_model_path;
+ string punc_config_path;
+
+ try{
+ punc_model_path = model_path.at(PUNC_MODEL_PATH);
+ punc_config_path = model_path.at(PUNC_CONFIG_PATH);
+ }catch(const out_of_range& e){
+ LOG(ERROR) << "Error when read "<< PUNC_CONFIG_PATH <<" :" << e.what();
+ exit(0);
+ }
- if(quantize)
- {
- model_path = PathAppend(path, "model_quant.onnx");
- }else{
- model_path = PathAppend(path, "model.onnx");
+ punc_handle = make_unique<CTTransformer>();
+ punc_handle->InitPunc(punc_model_path, punc_config_path, thread_num);
}
- cmvn_path = PathAppend(path, "am.mvn");
- config_path = PathAppend(path, "config.yaml");
+}
+void Paraformer::InitAM(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){
// knf options
fbank_opts.frame_opts.dither = 0;
fbank_opts.mel_opts.num_bins = 80;
@@ -48,12 +84,12 @@
// DisableCpuMemArena can improve performance
session_options.DisableCpuMemArena();
-#ifdef _WIN32
- wstring wstrPath = strToWstr(model_path);
- m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), session_options);
-#else
- m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), session_options);
-#endif
+ try {
+ m_session = std::make_unique<Ort::Session>(env_, am_model.c_str(), session_options);
+ } catch (std::exception const &e) {
+ LOG(ERROR) << "Error when load am onnx model: " << e.what();
+ exit(0);
+ }
string strName;
GetInputName(m_session.get(), strName);
@@ -70,8 +106,8 @@
m_szInputNames.push_back(item.c_str());
for (auto& item : m_strOutputNames)
m_szOutputNames.push_back(item.c_str());
- vocab = new Vocab(config_path.c_str());
- LoadCmvn(cmvn_path.c_str());
+ vocab = new Vocab(am_config.c_str());
+ LoadCmvn(am_cmvn.c_str());
}
Paraformer::~Paraformer()
@@ -113,6 +149,10 @@
void Paraformer::LoadCmvn(const char *filename)
{
ifstream cmvn_stream(filename);
+ if (!cmvn_stream.is_open()) {
+ LOG(ERROR) << "Failed to open file: " << filename;
+ exit(0);
+ }
string line;
while (getline(cmvn_stream, line)) {
--
Gitblit v1.9.1