chenmengzheAAA
2023-09-14 2a66366be4c2715870e4859fd5a5db6e8a9dc00a
funasr/runtime/onnxruntime/src/ct-transformer-online.cpp
@@ -22,7 +22,7 @@
    }
    catch (std::exception const &e) {
        LOG(ERROR) << "Error when load punc onnx model: " << e.what();
        exit(0);
        exit(-1);
    }
    // read inputnames outputnames
    string strName;
@@ -181,11 +181,12 @@
        text_lengths_dim.size()); //, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32);
    //vad_mask
    vector<float> arVadMask,arSubMask;
    // vector<float> arVadMask,arSubMask;
    vector<float> arVadMask;
    int nTextLength = input_data.size();
    VadMask(nTextLength, nCacheSize, arVadMask);
    Triangle(nTextLength, arSubMask);
    // Triangle(nTextLength, arSubMask);
    std::array<int64_t, 4> VadMask_Dim{ 1,1, nTextLength ,nTextLength };
    Ort::Value onnx_vad_mask = Ort::Value::CreateTensor<float>(
        m_memoryInfo,
@@ -198,8 +199,8 @@
    std::array<int64_t, 4> SubMask_Dim{ 1,1, nTextLength ,nTextLength };
    Ort::Value onnx_sub_mask = Ort::Value::CreateTensor<float>(
        m_memoryInfo,
        arSubMask.data(),
        arSubMask.size() ,
        arVadMask.data(),
        arVadMask.size(),
        SubMask_Dim.data(),
        SubMask_Dim.size()); // , ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
@@ -225,7 +226,6 @@
    catch (std::exception const &e)
    {
        LOG(ERROR) << "Error when run punc onnx forword: " << (e.what());
        exit(0);
    }
    return punction;
}