From 1d205d340ff5129e457fa462eb5b31b152086339 Mon Sep 17 00:00:00 2001
From: lyblsgo <lyblsgo@163.com>
Date: 星期二, 25 四月 2023 20:17:12 +0800
Subject: [PATCH] add option parser; add copyright;
---
funasr/runtime/onnxruntime/src/ct-transformer.cpp | 31 +-
funasr/runtime/onnxruntime/include/com-define.h | 19 +
funasr/runtime/onnxruntime/src/model.cpp | 4
funasr/runtime/onnxruntime/src/funasr-onnx-offline-rtf.cpp | 112 ++++++++---
funasr/runtime/onnxruntime/src/online-feature.h | 4
funasr/runtime/onnxruntime/src/tokenizer.cpp | 7
funasr/runtime/onnxruntime/include/model.h | 5
funasr/runtime/onnxruntime/src/fsmn-vad.cpp | 63 ++++---
funasr/runtime/onnxruntime/src/paraformer.h | 12 +
funasr/runtime/onnxruntime/include/libfunasrapi.h | 11
funasr/runtime/onnxruntime/src/ct-transformer.h | 8
funasr/runtime/onnxruntime/src/tokenizer.h | 5
funasr/runtime/onnxruntime/src/e2e-vad.h | 4
funasr/runtime/onnxruntime/src/funasr-onnx-offline.cpp | 133 ++++++++++----
funasr/runtime/onnxruntime/src/libfunasrapi.cpp | 28 +-
funasr/runtime/onnxruntime/src/online-feature.cpp | 4
funasr/runtime/onnxruntime/src/paraformer.cpp | 86 ++++++--
funasr/runtime/onnxruntime/src/fsmn-vad.h | 4
18 files changed, 367 insertions(+), 173 deletions(-)
diff --git a/funasr/runtime/onnxruntime/include/com-define.h b/funasr/runtime/onnxruntime/include/com-define.h
index e2c22f4..0369278 100644
--- a/funasr/runtime/onnxruntime/include/com-define.h
+++ b/funasr/runtime/onnxruntime/include/com-define.h
@@ -12,6 +12,18 @@
#define MODEL_SAMPLE_RATE 16000
#endif
+// model path
+#define VAD_MODEL_PATH "vad-model"
+#define VAD_CMVN_PATH "vad-cmvn"
+#define AM_MODEL_PATH "am-model"
+#define AM_CMVN_PATH "am-cmvn"
+#define AM_CONFIG_PATH "am-config"
+#define PUNC_MODEL_PATH "punc-model"
+#define PUNC_CONFIG_PATH "punc-config"
+#define WAV_PATH "wav-path"
+#define WAV_SCP "wav-scp"
+#define THREAD_NUM "thread-num"
+
// vad
#ifndef VAD_SILENCE_DYRATION
#define VAD_SILENCE_DYRATION 15000
@@ -26,14 +38,7 @@
#endif
// punc
-#define PUNC_MODEL_FILE "punc_model.onnx"
-#define PUNC_YAML_FILE "punc.yaml"
#define UNK_CHAR "<unk>"
-
-#define INPUT_NUM 2
-#define INPUT_NAME1 "input"
-#define INPUT_NAME2 "text_lengths"
-#define OUTPUT_NAME "logits"
#define TOKEN_LEN 20
#define CANDIDATE_NUM 6
diff --git a/funasr/runtime/onnxruntime/include/libfunasrapi.h b/funasr/runtime/onnxruntime/include/libfunasrapi.h
index 6b6e148..8dca7f4 100644
--- a/funasr/runtime/onnxruntime/include/libfunasrapi.h
+++ b/funasr/runtime/onnxruntime/include/libfunasrapi.h
@@ -1,4 +1,5 @@
#pragma once
+#include <map>
#ifdef WIN32
#ifdef _FUNASR_API_EXPORT
@@ -47,13 +48,13 @@
typedef void (* QM_CALLBACK)(int cur_step, int n_total); // n_total: total steps; cur_step: Current Step.
// APIs for funasr
-_FUNASRAPI FUNASR_HANDLE FunASRInit(const char* sz_model_dir, int thread_num, bool quantize=false, bool use_vad=false, bool use_punc=false);
+_FUNASRAPI FUNASR_HANDLE FunASRInit(std::map<std::string, std::string>& model_path, int thread_num);
// if not give a fn_callback ,it should be NULL
-_FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool use_vad=false, bool use_punc=false);
-_FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool use_vad=false, bool use_punc=false);
-_FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* sz_filename, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool use_vad=false, bool use_punc=false);
-_FUNASRAPI FUNASR_RESULT FunASRRecogFile(FUNASR_HANDLE handle, const char* sz_wavfile, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool use_vad=false, bool use_punc=false);
+_FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback);
+_FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback);
+_FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* sz_filename, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback);
+_FUNASRAPI FUNASR_RESULT FunASRRecogFile(FUNASR_HANDLE handle, const char* sz_wavfile, FUNASR_MODE mode, QM_CALLBACK fn_callback);
_FUNASRAPI const char* FunASRGetResult(FUNASR_RESULT result,int n_index);
_FUNASRAPI const int FunASRGetRetNumber(FUNASR_RESULT result);
diff --git a/funasr/runtime/onnxruntime/include/model.h b/funasr/runtime/onnxruntime/include/model.h
index 26a67f0..4b4b582 100644
--- a/funasr/runtime/onnxruntime/include/model.h
+++ b/funasr/runtime/onnxruntime/include/model.h
@@ -3,6 +3,7 @@
#define MODEL_H
#include <string>
+#include <map>
class Model {
public:
@@ -13,7 +14,9 @@
virtual std::string Rescoring() = 0;
virtual std::vector<std::vector<int>> VadSeg(std::vector<float>& pcm_data)=0;
virtual std::string AddPunc(const char* sz_input)=0;
+ virtual bool UseVad() =0;
+ virtual bool UsePunc() =0;
};
-Model *CreateModel(const char *path,int thread_num=1,bool quantize=false, bool use_vad=false, bool use_punc=false);
+Model *CreateModel(std::map<std::string, std::string>& model_path,int thread_num=1);
#endif
diff --git a/funasr/runtime/onnxruntime/src/ct-transformer.cpp b/funasr/runtime/onnxruntime/src/ct-transformer.cpp
index 3d66dcd..39ab2bc 100644
--- a/funasr/runtime/onnxruntime/src/ct-transformer.cpp
+++ b/funasr/runtime/onnxruntime/src/ct-transformer.cpp
@@ -1,28 +1,28 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License (https://opensource.org/licenses/MIT)
+*/
+
#include "precomp.h"
-CTTransformer::CTTransformer(const char* sz_model_dir, int thread_num)
+CTTransformer::CTTransformer()
:env_(ORT_LOGGING_LEVEL_ERROR, ""),session_options{}
{
+}
+
+void CTTransformer::InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num){
session_options.SetIntraOpNumThreads(thread_num);
session_options.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
session_options.DisableCpuMemArena();
- string strModelPath = PathAppend(sz_model_dir, PUNC_MODEL_FILE);
- string strYamlPath = PathAppend(sz_model_dir, PUNC_YAML_FILE);
-
try{
-#ifdef _WIN32
- std::wstring detPath = strToWstr(strModelPath);
- m_session = std::make_unique<Ort::Session>(env_, detPath.c_str(), session_options);
-#else
- m_session = std::make_unique<Ort::Session>(env_, strModelPath.c_str(), session_options);
-#endif
+ m_session = std::make_unique<Ort::Session>(env_, punc_model.c_str(), session_options);
}
- catch(exception e)
- {
- printf(e.what());
+ catch (std::exception const &e) {
+ LOG(ERROR) << "Error when load punc onnx model: " << e.what();
+ exit(0);
}
- // read inputnames outputnamess
+ // read inputnames outputnames
string strName;
GetInputName(m_session.get(), strName);
m_strInputNames.push_back(strName.c_str());
@@ -37,9 +37,10 @@
for (auto& item : m_strOutputNames)
m_szOutputNames.push_back(item.c_str());
- m_tokenizer.OpenYaml(strYamlPath.c_str());
+ m_tokenizer.OpenYaml(punc_config.c_str());
}
+
CTTransformer::~CTTransformer()
{
}
diff --git a/funasr/runtime/onnxruntime/src/ct-transformer.h b/funasr/runtime/onnxruntime/src/ct-transformer.h
index 77972c7..d965bb3 100644
--- a/funasr/runtime/onnxruntime/src/ct-transformer.h
+++ b/funasr/runtime/onnxruntime/src/ct-transformer.h
@@ -1,3 +1,8 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License (https://opensource.org/licenses/MIT)
+*/
+
#pragma once
class CTTransformer {
@@ -19,7 +24,8 @@
Ort::SessionOptions session_options;
public:
- CTTransformer(const char* sz_model_dir, int thread_num);
+ CTTransformer();
+ void InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num);
~CTTransformer();
vector<int> Infer(vector<int64_t> input_data);
string AddPunc(const char* sz_input);
diff --git a/funasr/runtime/onnxruntime/src/e2e-vad.h b/funasr/runtime/onnxruntime/src/e2e-vad.h
index e029dc3..90f2635 100644
--- a/funasr/runtime/onnxruntime/src/e2e-vad.h
+++ b/funasr/runtime/onnxruntime/src/e2e-vad.h
@@ -1,3 +1,7 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License (https://opensource.org/licenses/MIT)
+*/
#include <utility>
#include <vector>
diff --git a/funasr/runtime/onnxruntime/src/fsmn-vad.cpp b/funasr/runtime/onnxruntime/src/fsmn-vad.cpp
index 2988350..7360a9a 100644
--- a/funasr/runtime/onnxruntime/src/fsmn-vad.cpp
+++ b/funasr/runtime/onnxruntime/src/fsmn-vad.cpp
@@ -1,3 +1,7 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License (https://opensource.org/licenses/MIT)
+*/
#include <fstream>
#include "precomp.h"
@@ -32,7 +36,7 @@
vad_session_ = std::make_shared<Ort::Session>(
env_, vad_model.c_str(), session_options_);
} catch (std::exception const &e) {
- LOG(ERROR) << "Error when load onnx model: " << e.what();
+ LOG(ERROR) << "Error when load vad onnx model: " << e.what();
exit(0);
}
LOG(INFO) << "vad onnx:";
@@ -161,36 +165,41 @@
void FsmnVad::LoadCmvn(const char *filename)
{
- using namespace std;
- ifstream cmvn_stream(filename);
- string line;
+ try{
+ using namespace std;
+ ifstream cmvn_stream(filename);
+ string line;
- while (getline(cmvn_stream, line)) {
- istringstream iss(line);
- vector<string> line_item{istream_iterator<string>{iss}, istream_iterator<string>{}};
- if (line_item[0] == "<AddShift>") {
- getline(cmvn_stream, line);
- istringstream means_lines_stream(line);
- vector<string> means_lines{istream_iterator<string>{means_lines_stream}, istream_iterator<string>{}};
- if (means_lines[0] == "<LearnRateCoef>") {
- for (int j = 3; j < means_lines.size() - 1; j++) {
- means_list.push_back(stof(means_lines[j]));
+ while (getline(cmvn_stream, line)) {
+ istringstream iss(line);
+ vector<string> line_item{istream_iterator<string>{iss}, istream_iterator<string>{}};
+ if (line_item[0] == "<AddShift>") {
+ getline(cmvn_stream, line);
+ istringstream means_lines_stream(line);
+ vector<string> means_lines{istream_iterator<string>{means_lines_stream}, istream_iterator<string>{}};
+ if (means_lines[0] == "<LearnRateCoef>") {
+ for (int j = 3; j < means_lines.size() - 1; j++) {
+ means_list.push_back(stof(means_lines[j]));
+ }
+ continue;
}
- continue;
+ }
+ else if (line_item[0] == "<Rescale>") {
+ getline(cmvn_stream, line);
+ istringstream vars_lines_stream(line);
+ vector<string> vars_lines{istream_iterator<string>{vars_lines_stream}, istream_iterator<string>{}};
+ if (vars_lines[0] == "<LearnRateCoef>") {
+ for (int j = 3; j < vars_lines.size() - 1; j++) {
+ // vars_list.push_back(stof(vars_lines[j])*scale);
+ vars_list.push_back(stof(vars_lines[j]));
+ }
+ continue;
+ }
}
}
- else if (line_item[0] == "<Rescale>") {
- getline(cmvn_stream, line);
- istringstream vars_lines_stream(line);
- vector<string> vars_lines{istream_iterator<string>{vars_lines_stream}, istream_iterator<string>{}};
- if (vars_lines[0] == "<LearnRateCoef>") {
- for (int j = 3; j < vars_lines.size() - 1; j++) {
- // vars_list.push_back(stof(vars_lines[j])*scale);
- vars_list.push_back(stof(vars_lines[j]));
- }
- continue;
- }
- }
+ }catch(std::exception const &e) {
+ LOG(ERROR) << "Error when load vad cmvn : " << e.what();
+ exit(0);
}
}
diff --git a/funasr/runtime/onnxruntime/src/fsmn-vad.h b/funasr/runtime/onnxruntime/src/fsmn-vad.h
index e8569f9..a27ea0f 100644
--- a/funasr/runtime/onnxruntime/src/fsmn-vad.h
+++ b/funasr/runtime/onnxruntime/src/fsmn-vad.h
@@ -1,3 +1,7 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License (https://opensource.org/licenses/MIT)
+*/
#ifndef VAD_SERVER_FSMNVAD_H
#define VAD_SERVER_FSMNVAD_H
diff --git a/funasr/runtime/onnxruntime/src/funasr-onnx-offline-rtf.cpp b/funasr/runtime/onnxruntime/src/funasr-onnx-offline-rtf.cpp
index 1d822a0..366d993 100644
--- a/funasr/runtime/onnxruntime/src/funasr-onnx-offline-rtf.cpp
+++ b/funasr/runtime/onnxruntime/src/funasr-onnx-offline-rtf.cpp
@@ -1,3 +1,7 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License (https://opensource.org/licenses/MIT)
+*/
#ifndef _WIN32
#include <sys/time.h>
@@ -5,7 +9,10 @@
#include <win_func.h>
#endif
+#include <glog/logging.h>
#include "libfunasrapi.h"
+#include "tclap/CmdLine.h"
+#include "com-define.h"
#include <iostream>
#include <fstream>
@@ -14,9 +21,11 @@
#include <atomic>
#include <mutex>
#include <thread>
+#include <map>
+
using namespace std;
-std::atomic<int> index(0);
+std::atomic<int> wav_index(0);
std::mutex mtx;
void runReg(FUNASR_HANDLE asr_handle, vector<string> wav_list,
@@ -35,7 +44,7 @@
while (true) {
// 浣跨敤鍘熷瓙鍙橀噺鑾峰彇绱㈠紩骞堕�掑
- int i = index.fetch_add(1);
+ int i = wav_index.fetch_add(1);
if (i >= wav_list.size()) {
break;
}
@@ -68,59 +77,94 @@
}
}
+void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key, std::map<std::string, std::string>& model_path)
+{
+ if (value_arg.isSet()){
+ model_path.insert({key, value_arg.getValue()});
+ LOG(INFO)<< key << " : " << value_arg.getValue();
+ }
+}
+
int main(int argc, char *argv[])
{
+ //google::InitGoogleLogging(argv[0]);
- if (argc < 5)
- {
- printf("Usage: %s /path/to/model_dir /path/to/wav.scp quantize(true or false) thread_num \n", argv[0]);
- exit(-1);
- }
+ TCLAP::CmdLine cmd("funasr-onnx-offline", ' ', "1.0");
+ TCLAP::ValueArg<std::string> vad_model("", VAD_MODEL_PATH, "vad model path", false, "", "string");
+ TCLAP::ValueArg<std::string> vad_cmvn("", VAD_CMVN_PATH, "vad cmvn path", false, "", "string");
- // read wav.scp
- vector<string> wav_list;
- ifstream in(argv[2]);
- if (!in.is_open()) {
- printf("Failed to open file: %s", argv[2]);
- return 0;
- }
- string line;
- while(getline(in, line))
- {
- istringstream iss(line);
- string column1, column2;
- iss >> column1 >> column2;
- wav_list.push_back(column2);
- }
- in.close();
+ TCLAP::ValueArg<std::string> am_model("", AM_MODEL_PATH, "am model path", false, "", "string");
+ TCLAP::ValueArg<std::string> am_cmvn("", AM_CMVN_PATH, "am cmvn path", false, "", "string");
+ TCLAP::ValueArg<std::string> am_config("", AM_CONFIG_PATH, "am config path", false, "", "string");
- // model init
+ TCLAP::ValueArg<std::string> punc_model("", PUNC_MODEL_PATH, "punc model path", false, "", "string");
+ TCLAP::ValueArg<std::string> punc_config("", PUNC_CONFIG_PATH, "punc config path", false, "", "string");
+
+ TCLAP::ValueArg<std::string> wav_scp("", WAV_SCP, "wave scp path", true, "", "string");
+ TCLAP::ValueArg<std::int32_t> thread_num("", THREAD_NUM, "multi-thread num for rtf", true, 0, "int32_t");
+
+ cmd.add(vad_model);
+ cmd.add(vad_cmvn);
+ cmd.add(am_model);
+ cmd.add(am_cmvn);
+ cmd.add(am_config);
+ cmd.add(punc_model);
+ cmd.add(punc_config);
+ cmd.add(wav_scp);
+ cmd.add(thread_num);
+ cmd.parse(argc, argv);
+
+ std::map<std::string, std::string> model_path;
+ GetValue(vad_model, VAD_MODEL_PATH, model_path);
+ GetValue(vad_cmvn, VAD_CMVN_PATH, model_path);
+ GetValue(am_model, AM_MODEL_PATH, model_path);
+ GetValue(am_cmvn, AM_CMVN_PATH, model_path);
+ GetValue(am_config, AM_CONFIG_PATH, model_path);
+ GetValue(punc_model, PUNC_MODEL_PATH, model_path);
+ GetValue(punc_config, PUNC_CONFIG_PATH, model_path);
+ GetValue(wav_scp, WAV_SCP, model_path);
+
struct timeval start, end;
gettimeofday(&start, NULL);
- // is quantize
- bool quantize = false;
- istringstream(argv[3]) >> boolalpha >> quantize;
- // thread num
- int thread_num = 1;
- thread_num = atoi(argv[4]);
+ FUNASR_HANDLE asr_handle=FunASRInit(model_path, 1);
- FUNASR_HANDLE asr_handle=FunASRInit(argv[1], 1, quantize);
if (!asr_handle)
{
- printf("Cannot load ASR Model from: %s, there must be files model.onnx and vocab.txt", argv[1]);
+ LOG(ERROR) << ("Cannot load ASR Model from: %s, there must be files model.onnx and vocab.txt", argv[1]);
exit(-1);
}
+
gettimeofday(&end, NULL);
long seconds = (end.tv_sec - start.tv_sec);
long modle_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
- printf("Model initialization takes %lfs.\n", (double)modle_init_micros / 1000000);
+ printf("Model initialization takes %lfs.", (double)modle_init_micros / 1000000);
+
+ // read wav_scp
+ vector<string> wav_list;
+ if(model_path.find(WAV_SCP)!=model_path.end()){
+ ifstream in(model_path.at(WAV_SCP));
+ if (!in.is_open()) {
+ LOG(ERROR) << ("Failed to open file: %s", model_path.at(WAV_SCP));
+ return 0;
+ }
+ string line;
+ while(getline(in, line))
+ {
+ istringstream iss(line);
+ string column1, column2;
+ iss >> column1 >> column2;
+ wav_list.emplace_back(column2);
+ }
+ in.close();
+ }
// 澶氱嚎绋嬫祴璇�
float total_length = 0.0f;
long total_time = 0;
std::vector<std::thread> threads;
- for (int i = 0; i < thread_num; i++)
+ int rtf_threds = thread_num.getValue();
+ for (int i = 0; i < rtf_threds; i++)
{
threads.emplace_back(thread(runReg, asr_handle, wav_list, &total_length, &total_time, i));
}
diff --git a/funasr/runtime/onnxruntime/src/funasr-onnx-offline.cpp b/funasr/runtime/onnxruntime/src/funasr-onnx-offline.cpp
index 80c50ab..bc6224c 100644
--- a/funasr/runtime/onnxruntime/src/funasr-onnx-offline.cpp
+++ b/funasr/runtime/onnxruntime/src/funasr-onnx-offline.cpp
@@ -1,3 +1,7 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License (https://opensource.org/licenses/MIT)
+*/
#ifndef _WIN32
#include <sys/time.h>
@@ -5,82 +9,133 @@
#include <win_func.h>
#endif
+#include <iostream>
+#include <fstream>
#include <sstream>
+#include <map>
#include <glog/logging.h>
#include "libfunasrapi.h"
#include "tclap/CmdLine.h"
+#include "com-define.h"
using namespace std;
+void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key, std::map<std::string, std::string>& model_path)
+{
+ if (value_arg.isSet()){
+ model_path.insert({key, value_arg.getValue()});
+ LOG(INFO)<< key << " : " << value_arg.getValue();
+ }
+}
+
int main(int argc, char *argv[])
{
- google::InitGoogleLogging(argv[0]);
+ //google::InitGoogleLogging(argv[0]);
- TCLAP::CmdLine cmd("Command description message", ' ', "1.0");
- TCLAP::ValueArg<std::string> nameArg("n", "name", "Name of user", true, "", "string");
- TCLAP::SwitchArg reverseSwitch("r","reverse","Print name backwards", cmd, false);
- cmd.add(nameArg);
+ TCLAP::CmdLine cmd("funasr-onnx-offline", ' ', "1.0");
+ TCLAP::ValueArg<std::string> vad_model("", VAD_MODEL_PATH, "vad model path", false, "", "string");
+ TCLAP::ValueArg<std::string> vad_cmvn("", VAD_CMVN_PATH, "vad cmvn path", false, "", "string");
+ TCLAP::ValueArg<std::string> am_model("", AM_MODEL_PATH, "am model path", false, "", "string");
+ TCLAP::ValueArg<std::string> am_cmvn("", AM_CMVN_PATH, "am cmvn path", false, "", "string");
+ TCLAP::ValueArg<std::string> am_config("", AM_CONFIG_PATH, "am config path", false, "", "string");
+
+ TCLAP::ValueArg<std::string> punc_model("", PUNC_MODEL_PATH, "punc model path", false, "", "string");
+ TCLAP::ValueArg<std::string> punc_config("", PUNC_CONFIG_PATH, "punc config path", false, "", "string");
+
+ TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "wave file path", false, "", "string");
+ TCLAP::ValueArg<std::string> wav_scp("", WAV_SCP, "wave scp path", false, "", "string");
+
+ cmd.add(vad_model);
+ cmd.add(vad_cmvn);
+ cmd.add(am_model);
+ cmd.add(am_cmvn);
+ cmd.add(am_config);
+ cmd.add(punc_model);
+ cmd.add(punc_config);
+ cmd.add(wav_path);
+ cmd.add(wav_scp);
cmd.parse(argc, argv);
- string name = nameArg.getValue();
- printf(name.c_str());
+ std::map<std::string, std::string> model_path;
+ GetValue(vad_model, VAD_MODEL_PATH, model_path);
+ GetValue(vad_cmvn, VAD_CMVN_PATH, model_path);
+ GetValue(am_model, AM_MODEL_PATH, model_path);
+ GetValue(am_cmvn, AM_CMVN_PATH, model_path);
+ GetValue(am_config, AM_CONFIG_PATH, model_path);
+ GetValue(punc_model, PUNC_MODEL_PATH, model_path);
+ GetValue(punc_config, PUNC_CONFIG_PATH, model_path);
+ GetValue(wav_path, WAV_PATH, model_path);
+ GetValue(wav_scp, WAV_SCP, model_path);
- if (argc < 6)
- {
- printf("Usage: %s /path/to/model_dir /path/to/wav/file quantize(true or false) use_vad(true or false) use_punc(true or false)\n", argv[0]);
- exit(-1);
- }
struct timeval start, end;
gettimeofday(&start, NULL);
int thread_num = 1;
- // is quantize
- bool quantize = false;
- bool use_vad = false;
- bool use_punc = false;
- istringstream(argv[3]) >> boolalpha >> quantize;
- istringstream(argv[4]) >> boolalpha >> use_vad;
- istringstream(argv[5]) >> boolalpha >> use_punc;
- FUNASR_HANDLE asr_hanlde=FunASRInit(argv[1], thread_num, quantize, use_vad, use_punc);
+ FUNASR_HANDLE asr_hanlde=FunASRInit(model_path, thread_num);
if (!asr_hanlde)
{
- printf("Cannot load ASR Model from: %s, there must be files model.onnx and vocab.txt", argv[1]);
+ LOG(ERROR) << ("Cannot load ASR Model from: %s, there must be files model.onnx and vocab.txt", argv[1]);
exit(-1);
}
gettimeofday(&end, NULL);
long seconds = (end.tv_sec - start.tv_sec);
long modle_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
- printf("Model initialization takes %lfs.\n", (double)modle_init_micros / 1000000);
+ printf("Model initialization takes %lfs.", (double)modle_init_micros / 1000000);
- gettimeofday(&start, NULL);
- FUNASR_RESULT result=FunASRRecogFile(asr_hanlde, argv[2], RASR_NONE, NULL, use_vad, use_punc);
- gettimeofday(&end, NULL);
+ // read wav_path and wav_scp
+ vector<string> wav_list;
- float snippet_time = 0.0f;
- if (result)
- {
- string msg = FunASRGetResult(result, 0);
- setbuf(stdout, NULL);
- printf("Result: %s \n", msg.c_str());
- snippet_time = FunASRGetRetSnippetTime(result);
- FunASRFreeResult(result);
+ if(model_path.find(WAV_PATH)!=model_path.end()){
+ wav_list.emplace_back(model_path.at(WAV_PATH));
}
- else
- {
- printf("no return data!\n");
+ if(model_path.find(WAV_SCP)!=model_path.end()){
+ ifstream in(model_path.at(WAV_SCP));
+ if (!in.is_open()) {
+ LOG(ERROR) << ("Failed to open file: %s", model_path.at(WAV_SCP));
+ return 0;
+ }
+ string line;
+ while(getline(in, line))
+ {
+ istringstream iss(line);
+ string column1, column2;
+ iss >> column1 >> column2;
+ wav_list.emplace_back(column2);
+ }
+ in.close();
+ }
+
+ float snippet_time = 0.0f;
+ long taking_micros = 0;
+ for(auto& wav_file : wav_list){
+ gettimeofday(&start, NULL);
+ FUNASR_RESULT result=FunASRRecogFile(asr_hanlde, wav_file.c_str(), RASR_NONE, NULL);
+ gettimeofday(&end, NULL);
+ seconds = (end.tv_sec - start.tv_sec);
+ taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
+
+ if (result)
+ {
+ string msg = FunASRGetResult(result, 0);
+ setbuf(stdout, NULL);
+ printf("Result: %s \n", msg.c_str());
+ snippet_time += FunASRGetRetSnippetTime(result);
+ FunASRFreeResult(result);
+ }
+ else
+ {
+ LOG(ERROR) << ("no return data!\n");
+ }
}
printf("Audio length %lfs.\n", (double)snippet_time);
- seconds = (end.tv_sec - start.tv_sec);
- long taking_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
printf("Model inference takes %lfs.\n", (double)taking_micros / 1000000);
printf("Model inference RTF: %04lf.\n", (double)taking_micros/ (snippet_time*1000000));
FunASRUninit(asr_hanlde);
-
return 0;
}
diff --git a/funasr/runtime/onnxruntime/src/libfunasrapi.cpp b/funasr/runtime/onnxruntime/src/libfunasrapi.cpp
index 10c061e..93434bb 100644
--- a/funasr/runtime/onnxruntime/src/libfunasrapi.cpp
+++ b/funasr/runtime/onnxruntime/src/libfunasrapi.cpp
@@ -5,13 +5,13 @@
#endif
// APIs for funasr
- _FUNASRAPI FUNASR_HANDLE FunASRInit(const char* sz_model_dir, int thread_num, bool quantize, bool use_vad, bool use_punc)
+ _FUNASRAPI FUNASR_HANDLE FunASRInit(std::map<std::string, std::string>& model_path, int thread_num)
{
- Model* mm = CreateModel(sz_model_dir, thread_num, quantize, use_vad, use_punc);
+ Model* mm = CreateModel(model_path, thread_num);
return mm;
}
- _FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool use_vad, bool use_punc)
+ _FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback)
{
Model* recog_obj = (Model*)handle;
if (!recog_obj)
@@ -21,7 +21,7 @@
Audio audio(1);
if (!audio.LoadWav(sz_buf, n_len, &sampling_rate))
return nullptr;
- if(use_vad){
+ if(recog_obj->UseVad()){
audio.Split(recog_obj);
}
@@ -39,7 +39,7 @@
if (fn_callback)
fn_callback(n_step, n_total);
}
- if(use_punc){
+ if(recog_obj->UsePunc()){
string punc_res = recog_obj->AddPunc((p_result->msg).c_str());
p_result->msg = punc_res;
}
@@ -47,7 +47,7 @@
return p_result;
}
- _FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool use_vad, bool use_punc)
+ _FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback)
{
Model* recog_obj = (Model*)handle;
if (!recog_obj)
@@ -56,7 +56,7 @@
Audio audio(1);
if (!audio.LoadPcmwav(sz_buf, n_len, &sampling_rate))
return nullptr;
- if(use_vad){
+ if(recog_obj->UseVad()){
audio.Split(recog_obj);
}
@@ -74,7 +74,7 @@
if (fn_callback)
fn_callback(n_step, n_total);
}
- if(use_punc){
+ if(recog_obj->UsePunc()){
string punc_res = recog_obj->AddPunc((p_result->msg).c_str());
p_result->msg = punc_res;
}
@@ -82,7 +82,7 @@
return p_result;
}
- _FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* sz_filename, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool use_vad, bool use_punc)
+ _FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* sz_filename, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback)
{
Model* recog_obj = (Model*)handle;
if (!recog_obj)
@@ -91,7 +91,7 @@
Audio audio(1);
if (!audio.LoadPcmwav(sz_filename, &sampling_rate))
return nullptr;
- if(use_vad){
+ if(recog_obj->UseVad()){
audio.Split(recog_obj);
}
@@ -109,7 +109,7 @@
if (fn_callback)
fn_callback(n_step, n_total);
}
- if(use_punc){
+ if(recog_obj->UsePunc()){
string punc_res = recog_obj->AddPunc((p_result->msg).c_str());
p_result->msg = punc_res;
}
@@ -117,7 +117,7 @@
return p_result;
}
- _FUNASRAPI FUNASR_RESULT FunASRRecogFile(FUNASR_HANDLE handle, const char* sz_wavfile, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool use_vad, bool use_punc)
+ _FUNASRAPI FUNASR_RESULT FunASRRecogFile(FUNASR_HANDLE handle, const char* sz_wavfile, FUNASR_MODE mode, QM_CALLBACK fn_callback)
{
Model* recog_obj = (Model*)handle;
if (!recog_obj)
@@ -127,7 +127,7 @@
Audio audio(1);
if(!audio.LoadWav(sz_wavfile, &sampling_rate))
return nullptr;
- if(use_vad){
+ if(recog_obj->UseVad()){
audio.Split(recog_obj);
}
@@ -145,7 +145,7 @@
if (fn_callback)
fn_callback(n_step, n_total);
}
- if(use_punc){
+ if(recog_obj->UsePunc()){
string punc_res = recog_obj->AddPunc((p_result->msg).c_str());
p_result->msg = punc_res;
}
diff --git a/funasr/runtime/onnxruntime/src/model.cpp b/funasr/runtime/onnxruntime/src/model.cpp
index a582f82..52ce7ba 100644
--- a/funasr/runtime/onnxruntime/src/model.cpp
+++ b/funasr/runtime/onnxruntime/src/model.cpp
@@ -1,8 +1,8 @@
#include "precomp.h"
-Model *CreateModel(const char *path, int thread_num, bool quantize, bool use_vad, bool use_punc)
+Model *CreateModel(std::map<std::string, std::string>& model_path, int thread_num)
{
Model *mm;
- mm = new paraformer::Paraformer(path, thread_num, quantize, use_vad, use_punc);
+ mm = new paraformer::Paraformer(model_path, thread_num);
return mm;
}
diff --git a/funasr/runtime/onnxruntime/src/online-feature.cpp b/funasr/runtime/onnxruntime/src/online-feature.cpp
index 36e2770..3f57e0b 100644
--- a/funasr/runtime/onnxruntime/src/online-feature.cpp
+++ b/funasr/runtime/onnxruntime/src/online-feature.cpp
@@ -1,3 +1,7 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License (https://opensource.org/licenses/MIT)
+*/
#include "online-feature.h"
#include <utility>
diff --git a/funasr/runtime/onnxruntime/src/online-feature.h b/funasr/runtime/onnxruntime/src/online-feature.h
index 78245de..decaaf4 100644
--- a/funasr/runtime/onnxruntime/src/online-feature.h
+++ b/funasr/runtime/onnxruntime/src/online-feature.h
@@ -1,3 +1,7 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License (https://opensource.org/licenses/MIT)
+*/
#include <vector>
#include "precomp.h"
diff --git a/funasr/runtime/onnxruntime/src/paraformer.cpp b/funasr/runtime/onnxruntime/src/paraformer.cpp
index 72127f8..6d1a909 100644
--- a/funasr/runtime/onnxruntime/src/paraformer.cpp
+++ b/funasr/runtime/onnxruntime/src/paraformer.cpp
@@ -1,36 +1,70 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License (https://opensource.org/licenses/MIT)
+*/
+
#include "precomp.h"
using namespace std;
using namespace paraformer;
-Paraformer::Paraformer(const char* path,int thread_num, bool quantize, bool use_vad, bool use_punc)
+Paraformer::Paraformer(std::map<std::string, std::string>& model_path,int thread_num)
:env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),session_options{}{
- string model_path;
- string cmvn_path;
- string config_path;
// VAD model
- if(use_vad){
- string vad_path = PathAppend(path, "vad_model.onnx");
- string mvn_path = PathAppend(path, "vad.mvn");
+ if(model_path.find(VAD_MODEL_PATH) != model_path.end()){
+ use_vad = true;
+ string vad_model_path;
+ string vad_cmvn_path;
+
+ try{
+ vad_model_path = model_path.at(VAD_MODEL_PATH);
+ vad_cmvn_path = model_path.at(VAD_CMVN_PATH);
+ }catch(const out_of_range& e){
+ LOG(ERROR) << "Error when read "<< VAD_CMVN_PATH <<" :" << e.what();
+ exit(0);
+ }
vad_handle = make_unique<FsmnVad>();
- vad_handle->InitVad(vad_path, mvn_path, MODEL_SAMPLE_RATE, VAD_MAX_LEN, VAD_SILENCE_DYRATION, VAD_SPEECH_NOISE_THRES);
+ vad_handle->InitVad(vad_model_path, vad_cmvn_path, MODEL_SAMPLE_RATE, VAD_MAX_LEN, VAD_SILENCE_DYRATION, VAD_SPEECH_NOISE_THRES);
+ }
+
+ // AM model
+ if(model_path.find(AM_MODEL_PATH) != model_path.end()){
+ string am_model_path;
+ string am_cmvn_path;
+ string am_config_path;
+
+ try{
+ am_model_path = model_path.at(AM_MODEL_PATH);
+ am_cmvn_path = model_path.at(AM_CMVN_PATH);
+ am_config_path = model_path.at(AM_CONFIG_PATH);
+ }catch(const out_of_range& e){
+ LOG(ERROR) << "Error when read "<< AM_CONFIG_PATH << " or " << AM_CMVN_PATH <<" :" << e.what();
+ exit(0);
+ }
+ InitAM(am_model_path, am_cmvn_path, am_config_path, thread_num);
}
// PUNC model
- if(use_punc){
- punc_handle = make_unique<CTTransformer>(path, thread_num);
- }
+ if(model_path.find(PUNC_MODEL_PATH) != model_path.end()){
+ use_punc = true;
+ string punc_model_path;
+ string punc_config_path;
+
+ try{
+ punc_model_path = model_path.at(PUNC_MODEL_PATH);
+ punc_config_path = model_path.at(PUNC_CONFIG_PATH);
+ }catch(const out_of_range& e){
+ LOG(ERROR) << "Error when read "<< PUNC_CONFIG_PATH <<" :" << e.what();
+ exit(0);
+ }
- if(quantize)
- {
- model_path = PathAppend(path, "model_quant.onnx");
- }else{
- model_path = PathAppend(path, "model.onnx");
+ punc_handle = make_unique<CTTransformer>();
+ punc_handle->InitPunc(punc_model_path, punc_config_path, thread_num);
}
- cmvn_path = PathAppend(path, "am.mvn");
- config_path = PathAppend(path, "config.yaml");
+}
+void Paraformer::InitAM(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){
// knf options
fbank_opts.frame_opts.dither = 0;
fbank_opts.mel_opts.num_bins = 80;
@@ -48,12 +82,12 @@
// DisableCpuMemArena can improve performance
session_options.DisableCpuMemArena();
-#ifdef _WIN32
- wstring wstrPath = strToWstr(model_path);
- m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), session_options);
-#else
- m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), session_options);
-#endif
+ try {
+ m_session = std::make_unique<Ort::Session>(env_, am_model.c_str(), session_options);
+ } catch (std::exception const &e) {
+ LOG(ERROR) << "Error when load am onnx model: " << e.what();
+ exit(0);
+ }
string strName;
GetInputName(m_session.get(), strName);
@@ -70,8 +104,8 @@
m_szInputNames.push_back(item.c_str());
for (auto& item : m_strOutputNames)
m_szOutputNames.push_back(item.c_str());
- vocab = new Vocab(config_path.c_str());
- LoadCmvn(cmvn_path.c_str());
+ vocab = new Vocab(am_config.c_str());
+ LoadCmvn(am_cmvn.c_str());
}
Paraformer::~Paraformer()
diff --git a/funasr/runtime/onnxruntime/src/paraformer.h b/funasr/runtime/onnxruntime/src/paraformer.h
index 5301932..f3eb059 100644
--- a/funasr/runtime/onnxruntime/src/paraformer.h
+++ b/funasr/runtime/onnxruntime/src/paraformer.h
@@ -1,3 +1,8 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License (https://opensource.org/licenses/MIT)
+*/
+
#pragma once
@@ -41,10 +46,13 @@
vector<string> m_strInputNames, m_strOutputNames;
vector<const char*> m_szInputNames;
vector<const char*> m_szOutputNames;
+ bool use_vad=false;
+ bool use_punc=false;
public:
- Paraformer(const char* path, int thread_num=0, bool quantize=false, bool use_vad=false, bool use_punc=false);
+ Paraformer(std::map<std::string, std::string>& model_path, int thread_num=0);
~Paraformer();
+ void InitAM(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num);
void Reset();
vector<float> FbankKaldi(float sample_rate, const float* waves, int len);
string ForwardChunk(float* din, int len, int flag);
@@ -52,6 +60,8 @@
string Rescoring();
std::vector<std::vector<int>> VadSeg(std::vector<float>& pcm_data);
string AddPunc(const char* sz_input);
+ bool UseVad(){return use_vad;};
+ bool UsePunc(){return use_punc;};
};
} // namespace paraformer
diff --git a/funasr/runtime/onnxruntime/src/tokenizer.cpp b/funasr/runtime/onnxruntime/src/tokenizer.cpp
index ef0c533..c38664a 100644
--- a/funasr/runtime/onnxruntime/src/tokenizer.cpp
+++ b/funasr/runtime/onnxruntime/src/tokenizer.cpp
@@ -1,4 +1,9 @@
- #include "precomp.h"
+ /**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License (https://opensource.org/licenses/MIT)
+*/
+
+#include "precomp.h"
CTokenizer::CTokenizer(const char* sz_yamlfile):m_ready(false)
{
diff --git a/funasr/runtime/onnxruntime/src/tokenizer.h b/funasr/runtime/onnxruntime/src/tokenizer.h
index 7326db8..4ff1809 100644
--- a/funasr/runtime/onnxruntime/src/tokenizer.h
+++ b/funasr/runtime/onnxruntime/src/tokenizer.h
@@ -1,3 +1,8 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License (https://opensource.org/licenses/MIT)
+*/
+
#pragma once
#include <yaml-cpp/yaml.h>
--
Gitblit v1.9.1