| | |
| | | LOG(ERROR) << "Error when load vad onnx model: " << e.what(); |
| | | exit(-1); |
| | | } |
| | | GetInputOutputInfo(vad_session_, &vad_in_names_, &vad_out_names_); |
| | | GetInputNames(vad_session_.get(), m_strInputNames, vad_in_names_); |
| | | GetOutputNames(vad_session_.get(), m_strOutputNames, vad_out_names_); |
| | | } |
| | | |
| | | void FsmnVad::GetInputOutputInfo( |
| | | const std::shared_ptr<Ort::Session> &session, |
| | | std::vector<const char *> *in_names, std::vector<const char *> *out_names) { |
| | | Ort::AllocatorWithDefaultOptions allocator; |
| | | // Input info |
| | | int num_nodes = session->GetInputCount(); |
| | | in_names->resize(num_nodes); |
| | | for (int i = 0; i < num_nodes; ++i) { |
| | | std::unique_ptr<char, Ort::detail::AllocatedFree> name = session->GetInputNameAllocated(i, allocator); |
| | | (*in_names)[i] = name.get(); |
| | | } |
| | | // Output info |
| | | num_nodes = session->GetOutputCount(); |
| | | out_names->resize(num_nodes); |
| | | for (int i = 0; i < num_nodes; ++i) { |
| | | std::unique_ptr<char, Ort::detail::AllocatedFree> name = session->GetOutputNameAllocated(i, allocator); |
| | | (*out_names)[i] = name.get(); |
| | | } |
| | | } |
| | | |
| | | |
| | | void FsmnVad::Forward( |
| | | const std::vector<std::vector<float>> &chunk_feats, |