雾聪
2023-06-28 914de11b8b19258aa4b5f21de5f768233e89d4ba
add funasr-onnx-online-punc
6个文件已修改
1个文件已添加
163 ■■■■■ 已修改文件
funasr/runtime/onnxruntime/bin/CMakeLists.txt 3 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/bin/funasr-onnx-offline-punc.cpp 6 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/bin/funasr-onnx-online-punc.cpp 130 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/include/com-define.h 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/include/funasrruntime.h 11 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/include/punc-model.h 7 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/commonfunc.h 5 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/bin/CMakeLists.txt
@@ -12,5 +12,8 @@
add_executable(funasr-onnx-offline-punc "funasr-onnx-offline-punc.cpp")
target_link_libraries(funasr-onnx-offline-punc PUBLIC funasr)
add_executable(funasr-onnx-online-punc "funasr-onnx-online-punc.cpp")
target_link_libraries(funasr-onnx-online-punc PUBLIC funasr)
add_executable(funasr-onnx-offline-rtf "funasr-onnx-offline-rtf.cpp")
target_link_libraries(funasr-onnx-offline-rtf PUBLIC funasr)
funasr/runtime/onnxruntime/bin/funasr-onnx-offline-punc.cpp
@@ -84,11 +84,13 @@
    long taking_micros = 0;
    for(auto& txt_str : txt_list){
        gettimeofday(&start, NULL);
        string result=CTTransformerInfer(punc_hanlde, txt_str.c_str(), RASR_NONE, NULL);
        FUNASR_RESULT result=CTTransformerInfer(punc_hanlde, txt_str.c_str(), RASR_NONE, NULL);
        gettimeofday(&end, NULL);
        seconds = (end.tv_sec - start.tv_sec);
        taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
        LOG(INFO)<<"Results: "<<result;
        string msg = FunASRGetResult(result, 0);
        LOG(INFO)<<"Results: "<<msg;
        CTTransformerFreeResult(result);
    }
    LOG(INFO) << "Model inference takes: " << (double)taking_micros / 1000000 <<" s";
funasr/runtime/onnxruntime/bin/funasr-onnx-online-punc.cpp
New file
@@ -0,0 +1,130 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 * MIT License  (https://opensource.org/licenses/MIT)
*/
#ifndef _WIN32
#include <sys/time.h>
#else
#include <win_func.h>
#endif
#include <iostream>
#include <fstream>
#include <sstream>
#include <map>
#include <glog/logging.h>
#include "funasrruntime.h"
#include "tclap/CmdLine.h"
#include "com-define.h"
using namespace std;
void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key, std::map<std::string, std::string>& model_path)
{
    if (value_arg.isSet()){
        model_path.insert({key, value_arg.getValue()});
        LOG(INFO)<< key << " : " << value_arg.getValue();
    }
}
void splitString(vector<string>& strings, const string& org_string, const string& seq) {
    string::size_type p1 = 0;
    string::size_type p2 = org_string.find(seq);
    while (p2 != string::npos) {
        if (p2 == p1) {
            ++p1;
            p2 = org_string.find(seq, p1);
            continue;
        }
        strings.push_back(org_string.substr(p1, p2 - p1));
        p1 = p2 + seq.size();
        p2 = org_string.find(seq, p1);
    }
    if (p1 != org_string.size()) {
        strings.push_back(org_string.substr(p1));
    }
}
int main(int argc, char *argv[])
{
    google::InitGoogleLogging(argv[0]);
    FLAGS_logtostderr = true;
    TCLAP::CmdLine cmd("funasr-onnx-online-punc", ' ', "1.0");
    TCLAP::ValueArg<std::string>    model_dir("", MODEL_DIR, "the punc model path, which contains model.onnx, punc.yaml", true, "", "string");
    TCLAP::ValueArg<std::string>    quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
    TCLAP::ValueArg<std::string> txt_path("", TXT_PATH, "txt file path, one sentence per line", true, "", "string");
    cmd.add(model_dir);
    cmd.add(quantize);
    cmd.add(txt_path);
    cmd.parse(argc, argv);
    std::map<std::string, std::string> model_path;
    GetValue(model_dir, MODEL_DIR, model_path);
    GetValue(quantize, QUANTIZE, model_path);
    GetValue(txt_path, TXT_PATH, model_path);
    struct timeval start, end;
    gettimeofday(&start, NULL);
    int thread_num = 1;
    FUNASR_HANDLE punc_hanlde=CTTransformerInit(model_path, thread_num, PUNC_ONLINE);
    if (!punc_hanlde)
    {
        LOG(ERROR) << "FunASR init failed";
        exit(-1);
    }
    gettimeofday(&end, NULL);
    long seconds = (end.tv_sec - start.tv_sec);
    long modle_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
    LOG(INFO) << "Model initialization takes " << (double)modle_init_micros / 1000000 << " s";
    // read txt_path
    vector<string> txt_list;
    if(model_path.find(TXT_PATH)!=model_path.end()){
        ifstream in(model_path.at(TXT_PATH));
        if (!in.is_open()) {
            LOG(ERROR) << "Failed to open file: " << model_path.at(TXT_PATH) ;
            return 0;
        }
        string line;
        while(getline(in, line))
        {
            txt_list.emplace_back(line);
        }
        in.close();
    }
    long taking_micros = 0;
    for(auto& txt_str : txt_list){
        vector<string> vad_strs;
        splitString(vad_strs, txt_str, "|");
        string str_out;
        FUNASR_RESULT result = nullptr;
        gettimeofday(&start, NULL);
        for(auto& vad_str:vad_strs){
            result=CTTransformerInfer(punc_hanlde, vad_str.c_str(), RASR_NONE, NULL, PUNC_ONLINE, result);
            if(result){
                string msg = CTTransformerGetResult(result, 0);
                str_out += msg;
                LOG(INFO)<<"Online result: "<<msg;
            }
        }
        gettimeofday(&end, NULL);
        seconds = (end.tv_sec - start.tv_sec);
        taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
        LOG(INFO)<<"Results: "<<str_out;
        CTTransformerFreeResult(result);
    }
    LOG(INFO) << "Model inference takes: " << (double)taking_micros / 1000000 <<" s";
    CTTransformerUninit(punc_hanlde);
    return 0;
}
funasr/runtime/onnxruntime/include/com-define.h
@@ -69,6 +69,7 @@
#define CANDIDATE_NUM   6
#define UNKNOW_INDEX 0
#define NOTPUNC  "_"
#define NOTPUNC_INDEX 1
#define COMMA_INDEX 2
#define PERIOD_INDEX 3
funasr/runtime/onnxruntime/include/funasrruntime.h
@@ -46,6 +46,11 @@
    FUNASR_MODEL_PARAFORMER = 3,
}FUNASR_MODEL_TYPE;
typedef enum {
    PUNC_OFFLINE=0,
    PUNC_ONLINE=1,
}PUNC_TYPE;
typedef void (* QM_CALLBACK)(int cur_step, int n_total); // n_total: total steps; cur_step: Current Step.
    
// ASR
@@ -75,8 +80,10 @@
_FUNASRAPI const float        FsmnVadGetRetSnippetTime(FUNASR_RESULT result);
// PUNC
_FUNASRAPI FUNASR_HANDLE          CTTransformerInit(std::map<std::string, std::string>& model_path, int thread_num);
_FUNASRAPI const std::string    CTTransformerInfer(FUNASR_HANDLE handle, const char* sz_sentence, FUNASR_MODE mode, QM_CALLBACK fn_callback);
_FUNASRAPI FUNASR_HANDLE          CTTransformerInit(std::map<std::string, std::string>& model_path, int thread_num, PUNC_TYPE type=PUNC_OFFLINE);
_FUNASRAPI FUNASR_RESULT         CTTransformerInfer(FUNASR_HANDLE handle, const char* sz_sentence, FUNASR_MODE mode, QM_CALLBACK fn_callback, PUNC_TYPE type=PUNC_OFFLINE, FUNASR_RESULT pre_result=nullptr);
_FUNASRAPI const char*             CTTransformerGetResult(FUNASR_RESULT result,int n_index);
_FUNASRAPI void                    CTTransformerFreeResult(FUNASR_RESULT result);
_FUNASRAPI void                    CTTransformerUninit(FUNASR_HANDLE handle);
//OfflineStream
funasr/runtime/onnxruntime/include/punc-model.h
@@ -5,16 +5,17 @@
#include <string>
#include <map>
#include <vector>
#include "funasrruntime.h"
namespace funasr {
class PuncModel {
  public:
    virtual ~PuncModel(){};
      virtual void InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num)=0;
      virtual std::vector<int>  Infer(std::vector<int32_t> input_data)=0;
      virtual std::string AddPunc(const char* sz_input)=0;
      virtual std::string AddPunc(const char* sz_input){};
      virtual std::string AddPunc(const char* sz_input, std::vector<std::string>& arr_cache){};
};
PuncModel *CreatePuncModel(std::map<std::string, std::string>& model_path, int thread_num);
PuncModel *CreatePuncModel(std::map<std::string, std::string>& model_path, int thread_num, PUNC_TYPE type=PUNC_OFFLINE);
} // namespace funasr
#endif
funasr/runtime/onnxruntime/src/commonfunc.h
@@ -14,6 +14,11 @@
    float  snippet_time;
}FUNASR_VAD_RESULT;
typedef struct
{
    string msg;
    vector<string> arr_cache;
}FUNASR_PUNC_RESULT;
#ifdef _WIN32
#include <codecvt>