雾聪
2023-06-28 287fd0202b1590d29b77e962bc6ddb750aee5d04
add online func for funasrruntime.cpp
2个文件已修改
42 ■■■■ 已修改文件
funasr/runtime/onnxruntime/src/funasrruntime.cpp 41 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/precomp.h 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/funasrruntime.cpp
@@ -23,9 +23,9 @@
        return mm;
    }
    _FUNASRAPI FUNASR_HANDLE  CTTransformerInit(std::map<std::string, std::string>& model_path, int thread_num)
    _FUNASRAPI FUNASR_HANDLE  CTTransformerInit(std::map<std::string, std::string>& model_path, int thread_num, PUNC_TYPE type)
    {
        funasr::PuncModel* mm = funasr::CreatePuncModel(model_path, thread_num);
        funasr::PuncModel* mm = funasr::CreatePuncModel(model_path, thread_num, type);
        return mm;
    }
@@ -164,14 +164,28 @@
    }
    // APIs for PUNC Infer
    _FUNASRAPI const std::string CTTransformerInfer(FUNASR_HANDLE handle, const char* sz_sentence, FUNASR_MODE mode, QM_CALLBACK fn_callback)
    _FUNASRAPI FUNASR_RESULT CTTransformerInfer(FUNASR_HANDLE handle, const char* sz_sentence, FUNASR_MODE mode, QM_CALLBACK fn_callback, PUNC_TYPE type, FUNASR_RESULT pre_result)
    {
        funasr::PuncModel* punc_obj = (funasr::PuncModel*)handle;
        if (!punc_obj)
            return nullptr;
        FUNASR_RESULT p_result = nullptr;
        if (type==PUNC_OFFLINE){
            p_result = (FUNASR_RESULT)new funasr::FUNASR_PUNC_RESULT;
            ((funasr::FUNASR_PUNC_RESULT*)p_result)->msg = punc_obj->AddPunc(sz_sentence);
        }else if(type==PUNC_ONLINE){
            if (!pre_result)
                p_result = (FUNASR_RESULT)new funasr::FUNASR_PUNC_RESULT;
            else
                p_result = pre_result;
            ((funasr::FUNASR_PUNC_RESULT*)p_result)->msg = punc_obj->AddPunc(sz_sentence, ((funasr::FUNASR_PUNC_RESULT*)p_result)->arr_cache);
        }else{
            LOG(ERROR) << "Wrong PUNC_TYPE";
            exit(-1);
        }
        string punc_res = punc_obj->AddPunc(sz_sentence);
        return punc_res;
        return p_result;
    }
    // APIs for Offline-stream Infer
@@ -296,6 +310,15 @@
        return p_result->msg.c_str();
    }
    _FUNASRAPI const char* CTTransformerGetResult(FUNASR_RESULT result,int n_index)
    {
        funasr::FUNASR_PUNC_RESULT * p_result = (funasr::FUNASR_PUNC_RESULT*)result;
        if(!p_result)
            return nullptr;
        return p_result->msg.c_str();
    }
    _FUNASRAPI vector<std::vector<int>>* FsmnVadGetResult(FUNASR_RESULT result,int n_index)
    {
        funasr::FUNASR_VAD_RESULT * p_result = (funasr::FUNASR_VAD_RESULT*)result;
@@ -314,6 +337,14 @@
        }
    }
    _FUNASRAPI void CTTransformerFreeResult(FUNASR_RESULT result)
    {
        if (result)
        {
            delete (funasr::FUNASR_PUNC_RESULT*)result;
        }
    }
    _FUNASRAPI void FsmnVadFreeResult(FUNASR_RESULT result)
    {
        funasr::FUNASR_VAD_RESULT * p_result = (funasr::FUNASR_VAD_RESULT*)result;
funasr/runtime/onnxruntime/src/precomp.h
@@ -36,6 +36,7 @@
#include "offline-stream.h"
#include "tokenizer.h"
#include "ct-transformer.h"
#include "ct-transformer-online.h"
#include "e2e-vad.h"
#include "fsmn-vad.h"
#include "fsmn-vad-online.h"