From 914de11b8b19258aa4b5f21de5f768233e89d4ba Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期三, 28 六月 2023 10:40:07 +0800
Subject: [PATCH] add funasr-onnx-online-punc
---
funasr/runtime/onnxruntime/include/punc-model.h | 7 +-
funasr/runtime/onnxruntime/bin/CMakeLists.txt | 3 +
funasr/runtime/onnxruntime/include/com-define.h | 1
funasr/runtime/onnxruntime/include/funasrruntime.h | 11 +++
funasr/runtime/onnxruntime/bin/funasr-onnx-offline-punc.cpp | 6 +
funasr/runtime/onnxruntime/src/commonfunc.h | 5 +
funasr/runtime/onnxruntime/bin/funasr-onnx-online-punc.cpp | 130 +++++++++++++++++++++++++++++++++++++++++++
7 files changed, 156 insertions(+), 7 deletions(-)
diff --git a/funasr/runtime/onnxruntime/bin/CMakeLists.txt b/funasr/runtime/onnxruntime/bin/CMakeLists.txt
index 962da0b..03c3a64 100644
--- a/funasr/runtime/onnxruntime/bin/CMakeLists.txt
+++ b/funasr/runtime/onnxruntime/bin/CMakeLists.txt
@@ -12,5 +12,8 @@
add_executable(funasr-onnx-offline-punc "funasr-onnx-offline-punc.cpp")
target_link_libraries(funasr-onnx-offline-punc PUBLIC funasr)
+add_executable(funasr-onnx-online-punc "funasr-onnx-online-punc.cpp")
+target_link_libraries(funasr-onnx-online-punc PUBLIC funasr)
+
add_executable(funasr-onnx-offline-rtf "funasr-onnx-offline-rtf.cpp")
target_link_libraries(funasr-onnx-offline-rtf PUBLIC funasr)
diff --git a/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-punc.cpp b/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-punc.cpp
index e18c27e..92c0525 100644
--- a/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-punc.cpp
+++ b/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-punc.cpp
@@ -84,11 +84,13 @@
long taking_micros = 0;
for(auto& txt_str : txt_list){
gettimeofday(&start, NULL);
- string result=CTTransformerInfer(punc_hanlde, txt_str.c_str(), RASR_NONE, NULL);
+ FUNASR_RESULT result=CTTransformerInfer(punc_hanlde, txt_str.c_str(), RASR_NONE, NULL);
gettimeofday(&end, NULL);
seconds = (end.tv_sec - start.tv_sec);
taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
- LOG(INFO)<<"Results: "<<result;
+ string msg = FunASRGetResult(result, 0);
+ LOG(INFO)<<"Results: "<<msg;
+ CTTransformerFreeResult(result);
}
LOG(INFO) << "Model inference takes: " << (double)taking_micros / 1000000 <<" s";
diff --git a/funasr/runtime/onnxruntime/bin/funasr-onnx-online-punc.cpp b/funasr/runtime/onnxruntime/bin/funasr-onnx-online-punc.cpp
new file mode 100644
index 0000000..c592616
--- /dev/null
+++ b/funasr/runtime/onnxruntime/bin/funasr-onnx-online-punc.cpp
@@ -0,0 +1,130 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License (https://opensource.org/licenses/MIT)
+*/
+
+#ifndef _WIN32
+#include <sys/time.h>
+#else
+#include <win_func.h>
+#endif
+
+#include <iostream>
+#include <fstream>
+#include <sstream>
+#include <map>
+#include <glog/logging.h>
+#include "funasrruntime.h"
+#include "tclap/CmdLine.h"
+#include "com-define.h"
+
+using namespace std;
+
+/**
+ * Record a command-line option in the option map, but only when the user
+ * actually supplied it.
+ *
+ * @param value_arg  parsed TCLAP string option
+ * @param key        map key under which the option value is stored
+ * @param model_path option map that receives the (key, value) pair; an
+ *                   already existing entry for key is left unchanged
+ */
+void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key, std::map<std::string, std::string>& model_path)
+{
+    // Option omitted on the command line: nothing to store, nothing to log.
+    if (!value_arg.isSet()) {
+        return;
+    }
+
+    const std::string& supplied = value_arg.getValue();
+    model_path.emplace(key, supplied);
+    LOG(INFO)<< key << " : " << supplied;
+}
+
+/**
+ * Split org_string on every occurrence of the separator seq and append the
+ * non-empty pieces to strings (elements already in strings are kept).
+ *
+ * Empty pieces produced by leading, trailing or consecutive separators are
+ * dropped. An empty separator cannot split anything, so the whole input is
+ * appended as a single piece when it is non-empty.
+ *
+ * @param strings    output vector the pieces are appended to
+ * @param org_string text to split
+ * @param seq        separator; may be longer than one character
+ */
+void splitString(vector<string>& strings, const string& org_string, const string& seq) {
+    if (seq.empty()) {
+        // find("") matches at every position, which used to walk the cursor
+        // past the end of the string and make substr() throw.
+        if (!org_string.empty()) {
+            strings.push_back(org_string);
+        }
+        return;
+    }
+
+    string::size_type piece_begin = 0;
+    string::size_type sep_pos = org_string.find(seq, piece_begin);
+
+    while (sep_pos != string::npos) {
+        if (sep_pos > piece_begin) {
+            strings.push_back(org_string.substr(piece_begin, sep_pos - piece_begin));
+        }
+        // Always step over the complete separator, also when the piece in
+        // front of it was empty; advancing by one character only is wrong
+        // for multi-character separators.
+        piece_begin = sep_pos + seq.size();
+        sep_pos = org_string.find(seq, piece_begin);
+    }
+
+    if (piece_begin < org_string.size()) {
+        strings.push_back(org_string.substr(piece_begin));
+    }
+}
+
+/**
+ * Command-line demo for the online punctuation model.
+ *
+ * Loads the model from the MODEL_DIR option, reads the file given by the
+ * TXT_PATH option line by line, splits every line on '|' into segments and
+ * feeds the segments one after another to CTTransformerInfer() in
+ * PUNC_ONLINE mode, handing the previous result back in as pre_result.
+ * Every partial result, the concatenated result per line and the
+ * accumulated inference time are logged.
+ *
+ * @return 0 on success, -1 when the model cannot be initialised or the text
+ *         file cannot be opened.
+ */
+int main(int argc, char *argv[])
+{
+    google::InitGoogleLogging(argv[0]);
+    FLAGS_logtostderr = true;
+
+    TCLAP::CmdLine cmd("funasr-onnx-online-punc", ' ', "1.0");
+    TCLAP::ValueArg<std::string> model_dir("", MODEL_DIR, "the punc model path, which contains model.onnx, punc.yaml", true, "", "string");
+    TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
+    TCLAP::ValueArg<std::string> txt_path("", TXT_PATH, "txt file path, one sentence per line", true, "", "string");
+
+    cmd.add(model_dir);
+    cmd.add(quantize);
+    cmd.add(txt_path);
+    cmd.parse(argc, argv);
+
+    std::map<std::string, std::string> model_path;
+    GetValue(model_dir, MODEL_DIR, model_path);
+    GetValue(quantize, QUANTIZE, model_path);
+    GetValue(txt_path, TXT_PATH, model_path);
+
+    struct timeval start, end;
+    gettimeofday(&start, NULL);
+    int thread_num = 1;
+    FUNASR_HANDLE punc_handle=CTTransformerInit(model_path, thread_num, PUNC_ONLINE);
+
+    if (!punc_handle)
+    {
+        LOG(ERROR) << "FunASR init failed";
+        return -1;
+    }
+
+    gettimeofday(&end, NULL);
+    long seconds = (end.tv_sec - start.tv_sec);
+    long model_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
+    LOG(INFO) << "Model initialization takes " << (double)model_init_micros / 1000000 << " s";
+
+    // read txt_path
+    vector<string> txt_list;
+
+    if(model_path.find(TXT_PATH)!=model_path.end()){
+        ifstream in(model_path.at(TXT_PATH));
+        if (!in.is_open()) {
+            LOG(ERROR) << "Failed to open file: " << model_path.at(TXT_PATH) ;
+            // Report the failure to the caller and release the model that
+            // was already loaded instead of exiting with status 0.
+            CTTransformerUninit(punc_handle);
+            return -1;
+        }
+        string line;
+        while(getline(in, line))
+        {
+            txt_list.emplace_back(line);
+        }
+        // The stream closes itself when it goes out of scope here.
+    }
+
+    long taking_micros = 0;
+    for(auto& txt_str : txt_list){
+        // Each input line holds '|'-separated segments that are fed to the
+        // model one at a time.
+        vector<string> vad_strs;
+        splitString(vad_strs, txt_str, "|");
+        string str_out;
+        FUNASR_RESULT result = nullptr;
+        gettimeofday(&start, NULL);
+        for(auto& vad_str:vad_strs){
+            // NOTE(review): the previous result is handed back as pre_result
+            // and only the last result of a line is freed below. This assumes
+            // CTTransformerInfer reuses or takes over pre_result - confirm
+            // against the runtime implementation, otherwise the intermediate
+            // results leak.
+            result=CTTransformerInfer(punc_handle, vad_str.c_str(), RASR_NONE, NULL, PUNC_ONLINE, result);
+            if(result){
+                string msg = CTTransformerGetResult(result, 0);
+                str_out += msg;
+                LOG(INFO)<<"Online result: "<<msg;
+            }
+        }
+        gettimeofday(&end, NULL);
+        seconds = (end.tv_sec - start.tv_sec);
+        taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
+        LOG(INFO)<<"Results: "<<str_out;
+        // A line without any non-empty segment never ran inference, so there
+        // is no result object to release.
+        if(result){
+            CTTransformerFreeResult(result);
+        }
+    }
+
+    LOG(INFO) << "Model inference takes: " << (double)taking_micros / 1000000 <<" s";
+    CTTransformerUninit(punc_handle);
+    return 0;
+}
+
diff --git a/funasr/runtime/onnxruntime/include/com-define.h b/funasr/runtime/onnxruntime/include/com-define.h
index 7a6345b..0d3aee0 100644
--- a/funasr/runtime/onnxruntime/include/com-define.h
+++ b/funasr/runtime/onnxruntime/include/com-define.h
@@ -69,6 +69,7 @@
#define CANDIDATE_NUM 6
#define UNKNOW_INDEX 0
+#define NOTPUNC "_"
#define NOTPUNC_INDEX 1
#define COMMA_INDEX 2
#define PERIOD_INDEX 3
diff --git a/funasr/runtime/onnxruntime/include/funasrruntime.h b/funasr/runtime/onnxruntime/include/funasrruntime.h
index af430f7..98727bd 100644
--- a/funasr/runtime/onnxruntime/include/funasrruntime.h
+++ b/funasr/runtime/onnxruntime/include/funasrruntime.h
@@ -46,6 +46,11 @@
FUNASR_MODEL_PARAFORMER = 3,
}FUNASR_MODEL_TYPE;
+// Selects the punctuation model variant; passed as the `type` argument of
+// CTTransformerInit / CTTransformerInfer.
+typedef enum {
+ PUNC_OFFLINE=0, // default value of the `type` parameters
+ PUNC_ONLINE=1, // used by funasr-onnx-online-punc, which feeds the text segment by segment
+}PUNC_TYPE;
+
typedef void (* QM_CALLBACK)(int cur_step, int n_total); // n_total: total steps; cur_step: Current Step.
// ASR
@@ -75,8 +80,10 @@
_FUNASRAPI const float FsmnVadGetRetSnippetTime(FUNASR_RESULT result);
// PUNC
-_FUNASRAPI FUNASR_HANDLE CTTransformerInit(std::map<std::string, std::string>& model_path, int thread_num);
-_FUNASRAPI const std::string CTTransformerInfer(FUNASR_HANDLE handle, const char* sz_sentence, FUNASR_MODE mode, QM_CALLBACK fn_callback);
+// Creates a punctuation model handle from the options in model_path; `type`
+// selects the model variant. The command-line tools treat a null handle as
+// an initialisation failure.
+_FUNASRAPI FUNASR_HANDLE CTTransformerInit(std::map<std::string, std::string>& model_path, int thread_num, PUNC_TYPE type=PUNC_OFFLINE);
+// Runs punctuation on sz_sentence and returns a result object. The online
+// tool passes the result of the previous call back in as pre_result.
+// NOTE(review): ownership of pre_result after the call is not visible from
+// this header - confirm whether it is reused, consumed or must be freed
+// separately by the caller.
+_FUNASRAPI FUNASR_RESULT CTTransformerInfer(FUNASR_HANDLE handle, const char* sz_sentence, FUNASR_MODE mode, QM_CALLBACK fn_callback, PUNC_TYPE type=PUNC_OFFLINE, FUNASR_RESULT pre_result=nullptr);
+// Returns the text of a result obtained from CTTransformerInfer; the online
+// tool passes 0 as n_index.
+// NOTE(review): presumably the pointer is only valid until the result is freed - confirm.
+_FUNASRAPI const char* CTTransformerGetResult(FUNASR_RESULT result,int n_index);
+// Releases a result obtained from CTTransformerInfer.
+_FUNASRAPI void CTTransformerFreeResult(FUNASR_RESULT result);
_FUNASRAPI void CTTransformerUninit(FUNASR_HANDLE handle);
//OfflineStream
diff --git a/funasr/runtime/onnxruntime/include/punc-model.h b/funasr/runtime/onnxruntime/include/punc-model.h
index da7ff60..e50f9f7 100644
--- a/funasr/runtime/onnxruntime/include/punc-model.h
+++ b/funasr/runtime/onnxruntime/include/punc-model.h
@@ -5,16 +5,17 @@
#include <string>
#include <map>
#include <vector>
+#include "funasrruntime.h"
namespace funasr {
+// Abstract interface of the punctuation models created by CreatePuncModel().
class PuncModel {
public:
virtual ~PuncModel(){};
virtual void InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num)=0;
- virtual std::vector<int> Infer(std::vector<int32_t> input_data)=0;
- virtual std::string AddPunc(const char* sz_input)=0;
+ // The two AddPunc overloads are not pure virtual, so a concrete model only
+ // has to override the one it supports (presumably the offline model the
+ // first and the online model the cache-taking one - confirm). The base
+ // versions return an empty string: a value-returning function with an
+ // empty body is undefined behaviour as soon as it is called.
+ virtual std::string AddPunc(const char* sz_input){ return std::string(); }
+ virtual std::string AddPunc(const char* sz_input, std::vector<std::string>& arr_cache){ return std::string(); }
};
-PuncModel *CreatePuncModel(std::map<std::string, std::string>& model_path, int thread_num);
+PuncModel *CreatePuncModel(std::map<std::string, std::string>& model_path, int thread_num, PUNC_TYPE type=PUNC_OFFLINE);
} // namespace funasr
#endif
diff --git a/funasr/runtime/onnxruntime/src/commonfunc.h b/funasr/runtime/onnxruntime/src/commonfunc.h
index d0882c6..b74c1c1 100644
--- a/funasr/runtime/onnxruntime/src/commonfunc.h
+++ b/funasr/runtime/onnxruntime/src/commonfunc.h
@@ -14,6 +14,11 @@
float snippet_time;
}FUNASR_VAD_RESULT;
+// Result record of the punctuation API.
+// NOTE(review): presumably the object behind the FUNASR_RESULT returned by
+// CTTransformerInfer - confirm against the runtime implementation.
+typedef struct
+{
+ string msg; // presumably the punctuated text exposed through CTTransformerGetResult - confirm
+ vector<string> arr_cache; // presumably the cache passed to PuncModel::AddPunc(sz_input, arr_cache) between online calls - confirm
+}FUNASR_PUNC_RESULT;
#ifdef _WIN32
#include <codecvt>
--
Gitblit v1.9.1