zhifu gao
2023-04-14 d7440147aabb2e8e6f411073ebac08f1e498e07e
funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
@@ -4,7 +4,7 @@
using namespace paraformer;
ModelImp::ModelImp(const char* path,int nNumThread, bool quantize)
{
:env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),sessionOptions{}{
    string model_path;
    string cmvn_path;
    string config_path;
@@ -29,20 +29,20 @@
#ifdef _WIN32
    wstring wstrPath = strToWstr(model_path);
    m_session = new Ort::Session(env, wstrPath.c_str(), sessionOptions);
    m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), sessionOptions);
#else
    m_session = new Ort::Session(env, model_path.c_str(), sessionOptions);
    m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), sessionOptions);
#endif
    string strName;
    getInputName(m_session, strName);
    getInputName(m_session.get(), strName);
    m_strInputNames.push_back(strName.c_str());
    getInputName(m_session, strName,1);
    getInputName(m_session.get(), strName,1);
    m_strInputNames.push_back(strName);
    
    getOutputName(m_session, strName);
    getOutputName(m_session.get(), strName);
    m_strOutputNames.push_back(strName);
    getOutputName(m_session, strName,1);
    getOutputName(m_session.get(), strName,1);
    m_strOutputNames.push_back(strName);
    for (auto& item : m_strInputNames)
@@ -55,11 +55,6 @@
ModelImp::~ModelImp()
{
    if (m_session)
    {
        delete m_session;
        m_session = nullptr;
    }
    if(vocab)
        delete vocab;
    fftwf_free(fft_input);
@@ -172,6 +167,12 @@
    apply_cmvn(in);
    Ort::RunOptions run_option;
#ifdef _WIN_X86
        Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
#else
        Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
#endif
    std::array<int64_t, 3> input_shape_{ in->size[0],in->size[2],in->size[3] };
    Ort::Value onnx_feats = Ort::Value::CreateTensor<float>(m_memoryInfo,
        in->buff,