雾聪
2023-05-17 8706e767affc6bdc8cb7a67ca3a20a62779ff048
funasr/runtime/onnxruntime/src/ct-transformer.cpp
@@ -5,6 +5,7 @@
#include "precomp.h"
namespace funasr {
CTTransformer::CTTransformer()
:env_(ORT_LOGGING_LEVEL_ERROR, ""),session_options{}
{
@@ -40,7 +41,6 @@
   m_tokenizer.OpenYaml(punc_config.c_str());
}
// Destructor: performs no explicit cleanup of its own; every member is
// released by its own destructor when the object is destroyed.
CTTransformer::~CTTransformer() = default;
@@ -55,7 +55,7 @@
    int nTotalBatch = ceil((float)InputData.size() / TOKEN_LEN);
    int nCurBatch = -1;
    int nSentEnd = -1, nLastCommaIndex = -1;
    vector<int64_t> RemainIDs; //
    vector<int32_t> RemainIDs; //
    vector<string> RemainStr; //
    vector<int> NewPunctuation; //
    vector<string> NewString; //
@@ -65,7 +65,7 @@
    for (size_t i = 0; i < InputData.size(); i += TOKEN_LEN)
    {
        nDiff = (i + TOKEN_LEN) < InputData.size() ? (0) : (i + TOKEN_LEN - InputData.size());
        vector<int64_t> InputIDs(InputData.begin() + i, InputData.begin() + i + TOKEN_LEN - nDiff);
        vector<int32_t> InputIDs(InputData.begin() + i, InputData.begin() + i + TOKEN_LEN - nDiff);
        vector<string> InputStr(strOut.begin() + i, strOut.begin() + i + TOKEN_LEN - nDiff);
        InputIDs.insert(InputIDs.begin(), RemainIDs.begin(), RemainIDs.end()); // RemainIDs+InputIDs;
        InputStr.insert(InputStr.begin(), RemainStr.begin(), RemainStr.end()); // RemainStr+InputStr;
@@ -142,12 +142,13 @@
    return strResult;
}
vector<int> CTTransformer::Infer(vector<int64_t> input_data)
vector<int> CTTransformer::Infer(vector<int32_t> input_data)
{
    Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
    vector<int> punction;
    std::array<int64_t, 2> input_shape_{ 1, (int64_t)input_data.size()};
    Ort::Value onnx_input = Ort::Value::CreateTensor<int64_t>(m_memoryInfo,
    Ort::Value onnx_input = Ort::Value::CreateTensor<int32_t>(
        m_memoryInfo,
        input_data.data(),
        input_data.size(),
        input_shape_.data(),
@@ -180,10 +181,10 @@
    }
    catch (std::exception const &e)
    {
        printf(e.what());
        LOG(ERROR) << "Error when run punc onnx forword: " << (e.what());
        exit(0);
    }
    return punction;
}
} // namespace funasr