Skip to content

Commit

Permalink
feats: add ROCM related ⬆️
Browse files Browse the repository at this point in the history
  • Loading branch information
Joker2770 committed Aug 16, 2024
1 parent e08f1df commit 43da8c5
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 0 deletions.
6 changes: 6 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ if(ENABLE_TENSORRT)
add_definitions(-DUSE_TENSORRT)
endif(ENABLE_TENSORRT)

option(ENABLE_ROCM OFF)
if(ENABLE_ROCM)
message("build project with ROCM")
add_definitions(-DUSE_ROCM)
endif(ENABLE_ROCM)

include_directories("${ONNXRUNTIME_ROOTDIR}/include"
"${CMAKE_SOURCE_DIR}/src"
"${CMAKE_SOURCE_DIR}/src/pbrain-Z2I/toml11"
Expand Down
1 change: 1 addition & 0 deletions src/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ SOFTWARE.
// #define USE_CUDA
// #define USE_OPENVINO
// #define USE_TENSORRT
// #define USE_ROCM

#define CHANNEL_SIZE 3

Expand Down
8 changes: 8 additions & 0 deletions src/onnx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,14 @@ NeuralNetwork::NeuralNetwork(const std::string &model_path, const unsigned int b
CheckStatus(g_ort, g_ort->SessionOptionsAppendExecutionProvider_TensorRT(session_options, &TensorRT_Options));
#endif

#ifdef USE_ROCM
const OrtApiBase *ptr_api_base = OrtGetApiBase();
const OrtApi *g_ort = ptr_api_base->GetApi(ORT_API_VERSION);
OrtROCMProviderOptions ROCM_Options;
ROCM_Options.device_id = 0;
CheckStatus(g_ort, g_ort->SessionOptionsAppendExecutionProvider_ROCM(session_options, &ROCM_Options));
#endif

#ifdef _WIN32
// std::wstring widestr = std::wstring(model_path.begin(), model_path.end());
std::wstring wstr(model_path.length(), L' ');
Expand Down
12 changes: 12 additions & 0 deletions xmake.lua
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,21 @@ target("Z2I")
add_files("src/*.cpp")
add_packages("onnxruntime")
add_rpathdirs("@loader_path", "@loader_path/lib", "@executable_path", "@executable_path/lib")
-- add_defines("USE_CUDA")
-- add_defines("USE_OPENVINO")
-- add_defines("USE_TENSORRT")
-- add_defines("USE_ROCM")

target("pbrain-Z2I")
set_kind("binary")
add_files("src/pbrain-Z2I/*.cpp")
add_deps("Z2I")
add_packages("onnxruntime")
add_rpathdirs("@loader_path", "@loader_path/lib", "@executable_path", "@executable_path/lib")
-- add_defines("USE_CUDA")
-- add_defines("USE_OPENVINO")
-- add_defines("USE_TENSORRT")
-- add_defines("USE_ROCM")
after_build(function (target)
os.cp("$(scriptdir)/config/*.toml", target:targetdir())
end)
Expand All @@ -26,3 +34,7 @@ target("train_and_eval")
add_deps("Z2I")
add_packages("onnxruntime")
add_rpathdirs("@loader_path", "@loader_path/lib", "@executable_path", "@executable_path/lib")
-- add_defines("USE_CUDA")
-- add_defines("USE_OPENVINO")
-- add_defines("USE_TENSORRT")
-- add_defines("USE_ROCM")

0 comments on commit 43da8c5

Please sign in to comment.