| | |
| | | LOG(ERROR) << "Error when load vad onnx model: " << e.what(); |
| | | exit(-1); |
| | | } |
| | | GetInputOutputInfo(vad_session_, &vad_in_names_, &vad_out_names_, &vad_allocator); |
| | | GetInputOutputInfo(vad_session_, &vad_in_names_, &vad_out_names_); |
| | | } |
| | | |
| | | void FsmnVad::GetInputOutputInfo( |
| | | const std::shared_ptr<Ort::Session> &session, |
| | | std::vector<const char *> *in_names, std::vector<const char *> *out_names, |
| | | Ort::AllocatorWithDefaultOptions *allocator) { |
| | | std::vector<const char *> *in_names, std::vector<const char *> *out_names) { |
| | | Ort::AllocatorWithDefaultOptions allocator; |
| | | // Input info |
| | | int num_nodes = session->GetInputCount(); |
| | | in_names->resize(num_nodes); |
| | | for (int i = 0; i < num_nodes; ++i) { |
| | | std::unique_ptr<char, Ort::detail::AllocatedFree> name = session->GetInputNameAllocated(i, *allocator); |
| | | std::unique_ptr<char, Ort::detail::AllocatedFree> name = session->GetInputNameAllocated(i, allocator); |
| | | Ort::TypeInfo type_info = session->GetInputTypeInfo(i); |
| | | auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); |
| | | ONNXTensorElementDataType type = tensor_info.GetElementType(); |
| | |
| | | num_nodes = session->GetOutputCount(); |
| | | out_names->resize(num_nodes); |
| | | for (int i = 0; i < num_nodes; ++i) { |
| | | std::unique_ptr<char, Ort::detail::AllocatedFree> name = session->GetOutputNameAllocated(i, *allocator); |
| | | std::unique_ptr<char, Ort::detail::AllocatedFree> name = session->GetOutputNameAllocated(i, allocator); |
| | | Ort::TypeInfo type_info = session->GetOutputTypeInfo(i); |
| | | auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); |
| | | ONNXTensorElementDataType type = tensor_info.GetElementType(); |
| | |
| | | } |
| | | |
| | | FsmnVad::~FsmnVad() { |
| | | for (auto vad_in_name_item : vad_in_names_) vad_allocator.Free((void*)vad_in_name_item); |
| | | for (auto vad_out_name_item : vad_out_names_) vad_allocator.Free((void*)vad_out_name_item); |
| | | } |
| | | |
| | | FsmnVad::FsmnVad():env_(ORT_LOGGING_LEVEL_ERROR, ""),session_options_{} { |