From 9ba0dbd98bf69c830dfcfde8f109a400cb65e4e5 Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期五, 29 三月 2024 17:24:59 +0800
Subject: [PATCH] fix func Forward

---
 runtime/onnxruntime/CMakeLists.txt |   24 ++++++++++++++++++------
 1 files changed, 18 insertions(+), 6 deletions(-)

diff --git a/runtime/onnxruntime/CMakeLists.txt b/runtime/onnxruntime/CMakeLists.txt
index ab0e842..3756dd6 100644
--- a/runtime/onnxruntime/CMakeLists.txt
+++ b/runtime/onnxruntime/CMakeLists.txt
@@ -4,6 +4,7 @@
 
 option(ENABLE_GLOG "Whether to build glog" ON)
 option(ENABLE_FST "Whether to build openfst" ON) # ITN need openfst compiled
+option(GPU "Whether to build with GPU" OFF)
 
 # set(CMAKE_CXX_STANDARD 11)
 set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
@@ -20,12 +21,12 @@
 
 # for onnxruntime
 IF(WIN32)
-#    file(REMOVE ${PROJECT_SOURCE_DIR}/third_party/glog/src/config.h 
-#                ${PROJECT_SOURCE_DIR}/third_party/glog/src/glog/export.h 
-#                ${PROJECT_SOURCE_DIR}/third_party/glog/src/glog/logging.h 
-#                ${PROJECT_SOURCE_DIR}/third_party/glog/src/glog/raw_logging.h 
-#                ${PROJECT_SOURCE_DIR}/third_party/glog/src/glog/stl_logging.h 
-#                ${PROJECT_SOURCE_DIR}/third_party/glog/src/glog/vlog_is_on.h)
+    file(REMOVE ${PROJECT_SOURCE_DIR}/third_party/glog/src/config.h 
+                ${PROJECT_SOURCE_DIR}/third_party/glog/src/glog/export.h 
+                ${PROJECT_SOURCE_DIR}/third_party/glog/src/glog/logging.h 
+                ${PROJECT_SOURCE_DIR}/third_party/glog/src/glog/raw_logging.h 
+                ${PROJECT_SOURCE_DIR}/third_party/glog/src/glog/stl_logging.h 
+                ${PROJECT_SOURCE_DIR}/third_party/glog/src/glog/vlog_is_on.h)
 ELSE()
     link_directories(${ONNXRUNTIME_DIR}/lib)
     link_directories(${FFMPEG_DIR}/lib)
@@ -37,6 +38,17 @@
 include_directories(${PROJECT_SOURCE_DIR}/third_party/jieba/include/limonp/include)
 include_directories(${PROJECT_SOURCE_DIR}/third_party/kaldi)
 
+if(GPU)
+    add_definitions(-DUSE_GPU)
+    set(TORCH_DIR "/usr/local/lib/python3.8/dist-packages/torch")
+    set(TORCH_BLADE_DIR "/usr/local/lib/python3.8/dist-packages/torch_blade")
+    include_directories(${TORCH_DIR}/include)
+    include_directories(${TORCH_DIR}/include/torch/csrc/api/include)
+    link_directories(${TORCH_DIR}/lib)
+    link_directories(${TORCH_BLADE_DIR})
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -D_GLIBCXX_USE_CXX11_ABI=0")
+endif()
+
 if(ENABLE_GLOG)
     include_directories(${PROJECT_SOURCE_DIR}/third_party/glog/src)
     set(BUILD_TESTING OFF)

--
Gitblit v1.9.1