yhliang
2023-05-11 1d1ef01b4e23630a99a3be7e9d1dce9550a793e9
funasr/runtime/onnxruntime/src/ct-transformer.cpp
@@ -5,6 +5,7 @@
#include "precomp.h"
namespace funasr {
CTTransformer::CTTransformer()
:env_(ORT_LOGGING_LEVEL_ERROR, ""),session_options{}
{
@@ -54,7 +55,7 @@
    int nTotalBatch = ceil((float)InputData.size() / TOKEN_LEN);
    int nCurBatch = -1;
    int nSentEnd = -1, nLastCommaIndex = -1;
    vector<int64_t> RemainIDs; //
    vector<int32_t> RemainIDs; //
    vector<string> RemainStr; //
    vector<int> NewPunctuation; //
    vector<string> NewString; //
@@ -64,7 +65,7 @@
    for (size_t i = 0; i < InputData.size(); i += TOKEN_LEN)
    {
        nDiff = (i + TOKEN_LEN) < InputData.size() ? (0) : (i + TOKEN_LEN - InputData.size());
        vector<int64_t> InputIDs(InputData.begin() + i, InputData.begin() + i + TOKEN_LEN - nDiff);
        vector<int32_t> InputIDs(InputData.begin() + i, InputData.begin() + i + TOKEN_LEN - nDiff);
        vector<string> InputStr(strOut.begin() + i, strOut.begin() + i + TOKEN_LEN - nDiff);
        InputIDs.insert(InputIDs.begin(), RemainIDs.begin(), RemainIDs.end()); // RemainIDs+InputIDs;
        InputStr.insert(InputStr.begin(), RemainStr.begin(), RemainStr.end()); // RemainStr+InputStr;
@@ -141,12 +142,13 @@
    return strResult;
}
vector<int> CTTransformer::Infer(vector<int64_t> input_data)
vector<int> CTTransformer::Infer(vector<int32_t> input_data)
{
    Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
    vector<int> punction;
    std::array<int64_t, 2> input_shape_{ 1, (int64_t)input_data.size()};
    Ort::Value onnx_input = Ort::Value::CreateTensor<int64_t>(m_memoryInfo,
    Ort::Value onnx_input = Ort::Value::CreateTensor<int32_t>(
        m_memoryInfo,
        input_data.data(),
        input_data.size(),
        input_shape_.data(),
@@ -185,3 +187,4 @@
    return punction;
}
} // namespace funasr