雾聪
2024-10-10 1480dcf5d571c4920b4f18717d580646794b8d28
runtime/onnxruntime/src/commonfunc.h
@@ -65,6 +65,19 @@
    }
}
inline void GetInputNames(Ort::Session* session, std::vector<std::string> &m_strInputNames,
                   std::vector<const char *> &m_szInputNames) {
    Ort::AllocatorWithDefaultOptions allocator;
    size_t numNodes = session->GetInputCount();
    m_strInputNames.resize(numNodes);
    m_szInputNames.resize(numNodes);
    for (size_t i = 0; i != numNodes; ++i) {
        auto t = session->GetInputNameAllocated(i, allocator);
        m_strInputNames[i] = t.get();
        m_szInputNames[i] = m_strInputNames[i].c_str();
    }
}
inline void GetOutputName(Ort::Session* session, string& outputName, int nIndex = 0) {
    size_t numOutputNodes = session->GetOutputCount();
    if (numOutputNodes > 0) {
@@ -76,6 +89,19 @@
    }
}
inline void GetOutputNames(Ort::Session* session, std::vector<std::string> &m_strOutputNames,
                   std::vector<const char *> &m_szOutputNames) {
    Ort::AllocatorWithDefaultOptions allocator;
    size_t numNodes = session->GetOutputCount();
    m_strOutputNames.resize(numNodes);
    m_szOutputNames.resize(numNodes);
    for (size_t i = 0; i != numNodes; ++i) {
        auto t = session->GetOutputNameAllocated(i, allocator);
        m_strOutputNames[i] = t.get();
        m_szOutputNames[i] = m_strOutputNames[i].c_str();
    }
}
template <class ForwardIterator>
inline static size_t Argmax(ForwardIterator first, ForwardIterator last) {
    return std::distance(first, std::max_element(first, last));