| | |
| | | LOG(ERROR) << "Error when load vad onnx model: " << e.what(); |
| | | exit(-1); |
| | | } |
| | | GetInputOutputInfo(vad_session_, &vad_in_names_, &vad_out_names_); |
| | | GetInputNames(vad_session_.get(), m_strInputNames, vad_in_names_); |
| | | GetOutputNames(vad_session_.get(), m_strOutputNames, vad_out_names_); |
| | | } |
| | | |
| | | void FsmnVad::GetInputOutputInfo( |
| | | const std::shared_ptr<Ort::Session> &session, |
| | | std::vector<const char *> *in_names, std::vector<const char *> *out_names) { |
| | | Ort::AllocatorWithDefaultOptions allocator; |
| | | // Input info |
| | | int num_nodes = session->GetInputCount(); |
| | | in_names->resize(num_nodes); |
| | | for (int i = 0; i < num_nodes; ++i) { |
| | | std::unique_ptr<char, Ort::detail::AllocatedFree> name = session->GetInputNameAllocated(i, allocator); |
| | | Ort::TypeInfo type_info = session->GetInputTypeInfo(i); |
| | | auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); |
| | | ONNXTensorElementDataType type = tensor_info.GetElementType(); |
| | | std::vector<int64_t> node_dims = tensor_info.GetShape(); |
| | | std::stringstream shape; |
| | | for (auto j: node_dims) { |
| | | shape << j; |
| | | shape << " "; |
| | | } |
| | | // LOG(INFO) << "\tInput " << i << " : name=" << name.get() << " type=" << type |
| | | // << " dims=" << shape.str(); |
| | | (*in_names)[i] = name.get(); |
| | | name.release(); |
| | | } |
| | | // Output info |
| | | num_nodes = session->GetOutputCount(); |
| | | out_names->resize(num_nodes); |
| | | for (int i = 0; i < num_nodes; ++i) { |
| | | std::unique_ptr<char, Ort::detail::AllocatedFree> name = session->GetOutputNameAllocated(i, allocator); |
| | | Ort::TypeInfo type_info = session->GetOutputTypeInfo(i); |
| | | auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); |
| | | ONNXTensorElementDataType type = tensor_info.GetElementType(); |
| | | std::vector<int64_t> node_dims = tensor_info.GetShape(); |
| | | std::stringstream shape; |
| | | for (auto j: node_dims) { |
| | | shape << j; |
| | | shape << " "; |
| | | } |
| | | // LOG(INFO) << "\tOutput " << i << " : name=" << name.get() << " type=" << type |
| | | // << " dims=" << shape.str(); |
| | | (*out_names)[i] = name.get(); |
| | | name.release(); |
| | | } |
| | | } |
| | | |
| | | |
| | | void FsmnVad::Forward( |
| | | const std::vector<std::vector<float>> &chunk_feats, |