From 914de11b8b19258aa4b5f21de5f768233e89d4ba Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期三, 28 六月 2023 10:40:07 +0800
Subject: [PATCH] add funasr-onnx-online-punc

---
 funasr/runtime/onnxruntime/include/punc-model.h             |    7 +-
 funasr/runtime/onnxruntime/bin/CMakeLists.txt               |    3 +
 funasr/runtime/onnxruntime/include/com-define.h             |    1 
 funasr/runtime/onnxruntime/include/funasrruntime.h          |   11 +++
 funasr/runtime/onnxruntime/bin/funasr-onnx-offline-punc.cpp |    6 +
 funasr/runtime/onnxruntime/src/commonfunc.h                 |    5 +
 funasr/runtime/onnxruntime/bin/funasr-onnx-online-punc.cpp  |  130 +++++++++++++++++++++++++++++++++++++++++++
 7 files changed, 156 insertions(+), 7 deletions(-)

diff --git a/funasr/runtime/onnxruntime/bin/CMakeLists.txt b/funasr/runtime/onnxruntime/bin/CMakeLists.txt
index 962da0b..03c3a64 100644
--- a/funasr/runtime/onnxruntime/bin/CMakeLists.txt
+++ b/funasr/runtime/onnxruntime/bin/CMakeLists.txt
@@ -12,5 +12,8 @@
 add_executable(funasr-onnx-offline-punc "funasr-onnx-offline-punc.cpp")
 target_link_libraries(funasr-onnx-offline-punc PUBLIC funasr)
 
+add_executable(funasr-onnx-online-punc "funasr-onnx-online-punc.cpp")
+target_link_libraries(funasr-onnx-online-punc PUBLIC funasr)
+
 add_executable(funasr-onnx-offline-rtf "funasr-onnx-offline-rtf.cpp")
 target_link_libraries(funasr-onnx-offline-rtf PUBLIC funasr)
diff --git a/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-punc.cpp b/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-punc.cpp
index e18c27e..92c0525 100644
--- a/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-punc.cpp
+++ b/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-punc.cpp
@@ -84,11 +84,13 @@
     long taking_micros = 0;
     for(auto& txt_str : txt_list){
         gettimeofday(&start, NULL);
-        string result=CTTransformerInfer(punc_hanlde, txt_str.c_str(), RASR_NONE, NULL);
+        FUNASR_RESULT result=CTTransformerInfer(punc_hanlde, txt_str.c_str(), RASR_NONE, NULL);
         gettimeofday(&end, NULL);
         seconds = (end.tv_sec - start.tv_sec);
         taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
-        LOG(INFO)<<"Results: "<<result;
+        string msg = FunASRGetResult(result, 0);
+        LOG(INFO)<<"Results: "<<msg;
+        CTTransformerFreeResult(result);
     }
 
     LOG(INFO) << "Model inference takes: " << (double)taking_micros / 1000000 <<" s";
diff --git a/funasr/runtime/onnxruntime/bin/funasr-onnx-online-punc.cpp b/funasr/runtime/onnxruntime/bin/funasr-onnx-online-punc.cpp
new file mode 100644
index 0000000..c592616
--- /dev/null
+++ b/funasr/runtime/onnxruntime/bin/funasr-onnx-online-punc.cpp
@@ -0,0 +1,130 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License  (https://opensource.org/licenses/MIT)
+*/
+
+#ifndef _WIN32
+#include <sys/time.h>
+#else
+#include <win_func.h>
+#endif
+
+#include <iostream>
+#include <fstream>
+#include <sstream>
+#include <map>
+#include <glog/logging.h>
+#include "funasrruntime.h"
+#include "tclap/CmdLine.h"
+#include "com-define.h"
+
+using namespace std;
+
+void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key, std::map<std::string, std::string>& model_path)
+{
+    if (value_arg.isSet()){
+        model_path.insert({key, value_arg.getValue()});
+        LOG(INFO)<< key << " : " << value_arg.getValue();
+    }
+}
+
+void splitString(vector<string>& strings, const string& org_string, const string& seq) {
+	string::size_type p1 = 0;
+	string::size_type p2 = org_string.find(seq);
+
+	while (p2 != string::npos) {
+		if (p2 == p1) {
+			++p1;
+			p2 = org_string.find(seq, p1);
+			continue;
+		}
+		strings.push_back(org_string.substr(p1, p2 - p1));
+		p1 = p2 + seq.size();
+		p2 = org_string.find(seq, p1);
+	}
+
+	if (p1 != org_string.size()) {
+		strings.push_back(org_string.substr(p1));
+	}
+}
+
+int main(int argc, char *argv[])
+{
+    google::InitGoogleLogging(argv[0]);
+    FLAGS_logtostderr = true;
+
+    TCLAP::CmdLine cmd("funasr-onnx-online-punc", ' ', "1.0");
+    TCLAP::ValueArg<std::string>    model_dir("", MODEL_DIR, "the punc model path, which contains model.onnx, punc.yaml", true, "", "string");
+    TCLAP::ValueArg<std::string>    quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
+    TCLAP::ValueArg<std::string> txt_path("", TXT_PATH, "txt file path, one sentence per line", true, "", "string");
+
+    cmd.add(model_dir);
+    cmd.add(quantize);
+    cmd.add(txt_path);
+    cmd.parse(argc, argv);
+
+    std::map<std::string, std::string> model_path;
+    GetValue(model_dir, MODEL_DIR, model_path);
+    GetValue(quantize, QUANTIZE, model_path);
+    GetValue(txt_path, TXT_PATH, model_path);
+
+    struct timeval start, end;
+    gettimeofday(&start, NULL);
+    int thread_num = 1;
+    FUNASR_HANDLE punc_hanlde=CTTransformerInit(model_path, thread_num, PUNC_ONLINE);
+
+    if (!punc_hanlde)
+    {
+        LOG(ERROR) << "FunASR init failed";
+        exit(-1);
+    }
+
+    gettimeofday(&end, NULL);
+    long seconds = (end.tv_sec - start.tv_sec);
+    long modle_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
+    LOG(INFO) << "Model initialization takes " << (double)modle_init_micros / 1000000 << " s";
+
+    // read txt_path
+    vector<string> txt_list;
+
+    if(model_path.find(TXT_PATH)!=model_path.end()){
+        ifstream in(model_path.at(TXT_PATH));
+        if (!in.is_open()) {
+            LOG(ERROR) << "Failed to open file: " << model_path.at(TXT_PATH) ;
+            return 0;
+        }
+        string line;
+        while(getline(in, line))
+        {
+            txt_list.emplace_back(line); 
+        }
+        in.close();
+    }
+    
+    long taking_micros = 0;
+    for(auto& txt_str : txt_list){
+        vector<string> vad_strs;
+        splitString(vad_strs, txt_str, "|");
+        string str_out;
+        FUNASR_RESULT result = nullptr;
+        gettimeofday(&start, NULL);
+        for(auto& vad_str:vad_strs){
+            result=CTTransformerInfer(punc_hanlde, vad_str.c_str(), RASR_NONE, NULL, PUNC_ONLINE, result);
+            if(result){
+                string msg = CTTransformerGetResult(result, 0);
+                str_out += msg;
+                LOG(INFO)<<"Online result: "<<msg;
+            }
+        }
+        gettimeofday(&end, NULL);
+        seconds = (end.tv_sec - start.tv_sec);
+        taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
+        LOG(INFO)<<"Results: "<<str_out;
+        CTTransformerFreeResult(result);
+    }
+
+    LOG(INFO) << "Model inference takes: " << (double)taking_micros / 1000000 <<" s";
+    CTTransformerUninit(punc_hanlde);
+    return 0;
+}
+
diff --git a/funasr/runtime/onnxruntime/include/com-define.h b/funasr/runtime/onnxruntime/include/com-define.h
index 7a6345b..0d3aee0 100644
--- a/funasr/runtime/onnxruntime/include/com-define.h
+++ b/funasr/runtime/onnxruntime/include/com-define.h
@@ -69,6 +69,7 @@
 
 #define CANDIDATE_NUM   6
 #define UNKNOW_INDEX 0
+#define NOTPUNC  "_"
 #define NOTPUNC_INDEX 1
 #define COMMA_INDEX 2
 #define PERIOD_INDEX 3
diff --git a/funasr/runtime/onnxruntime/include/funasrruntime.h b/funasr/runtime/onnxruntime/include/funasrruntime.h
index af430f7..98727bd 100644
--- a/funasr/runtime/onnxruntime/include/funasrruntime.h
+++ b/funasr/runtime/onnxruntime/include/funasrruntime.h
@@ -46,6 +46,11 @@
 	FUNASR_MODEL_PARAFORMER = 3,
 }FUNASR_MODEL_TYPE;
 
+typedef enum {
+	PUNC_OFFLINE=0,
+	PUNC_ONLINE=1,
+}PUNC_TYPE;
+
 typedef void (* QM_CALLBACK)(int cur_step, int n_total); // n_total: total steps; cur_step: Current Step.
 	
 // ASR
@@ -75,8 +80,10 @@
 _FUNASRAPI const float		FsmnVadGetRetSnippetTime(FUNASR_RESULT result);
 
 // PUNC
-_FUNASRAPI FUNASR_HANDLE  		CTTransformerInit(std::map<std::string, std::string>& model_path, int thread_num);
-_FUNASRAPI const std::string	CTTransformerInfer(FUNASR_HANDLE handle, const char* sz_sentence, FUNASR_MODE mode, QM_CALLBACK fn_callback);
+_FUNASRAPI FUNASR_HANDLE  		CTTransformerInit(std::map<std::string, std::string>& model_path, int thread_num, PUNC_TYPE type=PUNC_OFFLINE);
+_FUNASRAPI FUNASR_RESULT     	CTTransformerInfer(FUNASR_HANDLE handle, const char* sz_sentence, FUNASR_MODE mode, QM_CALLBACK fn_callback, PUNC_TYPE type=PUNC_OFFLINE, FUNASR_RESULT pre_result=nullptr);
+_FUNASRAPI const char* 			CTTransformerGetResult(FUNASR_RESULT result,int n_index);
+_FUNASRAPI void					CTTransformerFreeResult(FUNASR_RESULT result);
 _FUNASRAPI void					CTTransformerUninit(FUNASR_HANDLE handle);
 
 //OfflineStream
diff --git a/funasr/runtime/onnxruntime/include/punc-model.h b/funasr/runtime/onnxruntime/include/punc-model.h
index da7ff60..e50f9f7 100644
--- a/funasr/runtime/onnxruntime/include/punc-model.h
+++ b/funasr/runtime/onnxruntime/include/punc-model.h
@@ -5,16 +5,17 @@
 #include <string>
 #include <map>
 #include <vector>
+#include "funasrruntime.h"
 
 namespace funasr {
 class PuncModel {
   public:
     virtual ~PuncModel(){};
 	  virtual void InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num)=0;
-	  virtual std::vector<int>  Infer(std::vector<int32_t> input_data)=0;
-	  virtual std::string AddPunc(const char* sz_input)=0;
+	  virtual std::string AddPunc(const char* sz_input){};
+	  virtual std::string AddPunc(const char* sz_input, std::vector<std::string>& arr_cache){};
 };
 
-PuncModel *CreatePuncModel(std::map<std::string, std::string>& model_path, int thread_num);
+PuncModel *CreatePuncModel(std::map<std::string, std::string>& model_path, int thread_num, PUNC_TYPE type=PUNC_OFFLINE);
 } // namespace funasr
 #endif
diff --git a/funasr/runtime/onnxruntime/src/commonfunc.h b/funasr/runtime/onnxruntime/src/commonfunc.h
index d0882c6..b74c1c1 100644
--- a/funasr/runtime/onnxruntime/src/commonfunc.h
+++ b/funasr/runtime/onnxruntime/src/commonfunc.h
@@ -14,6 +14,11 @@
     float  snippet_time;
 }FUNASR_VAD_RESULT;
 
+typedef struct
+{
+    string msg;
+    vector<string> arr_cache;
+}FUNASR_PUNC_RESULT;
 
 #ifdef _WIN32
 #include <codecvt>

--
Gitblit v1.9.1