Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add nvtx equivalent for rocm #940

Open
wants to merge 5 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,10 @@ if(SIRIUS_USE_ROCM)
if (NOT DEFINED CMAKE_HIP_ARCHITECTURES)
set(CMAKE_HIP_ARCHITECTURES gfx801 gfx900 gfx90a)
endif()
if(SIRIUS_USE_NVTX)
find_package(RocTX REQUIRED)
find_package(Roctracer REQUIRED)
endif()

enable_language(HIP)
find_package(rocblas CONFIG REQUIRED)
Expand Down
22 changes: 22 additions & 0 deletions cmake/modules/FindRocTX.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# find roctracer roctx

include(FindPackageHandleStandardArgs)

find_library(SIRIUS_ROCTX_LIBRARIES
NAMES roctx64 roctx
PATH_SUFFIXES lib
DOC "roctx libraries list")

find_path(SIRIUS_ROCTX_INCLUDE_DIR
NAMES roctx.h
PATH_SUFFIXES include include/roctracer
)

find_package_handle_standard_args(RocTX "DEFAULT_MSG" SIRIUS_ROCTX_LIBRARIES SIRIUS_ROCTX_INCLUDE_DIR)

if(RocTX_FOUND AND NOT TARGET sirius::roctx)
add_library(sirius::roctx INTERFACE IMPORTED)
set_target_properties(sirius::roctx PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES "${SIRIUS_ROCTX_INCLUDE_DIR}"
INTERFACE_LINK_LIBRARIES "${SIRIUS_ROCTX_LIBRARIES}")
endif()
22 changes: 22 additions & 0 deletions cmake/modules/FindRoctracer.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# find roctracer roctracer

include(FindPackageHandleStandardArgs)

find_library(SIRIUS_ROCTRACER_LIBRARIES
NAMES roctracer64 roctracer
PATH_SUFFIXES lib
DOC "roctracer libraries list")

find_path(SIRIUS_ROCTRACER_INCLUDE_DIR
NAMES roctracer.h
PATH_SUFFIXES include include/roctracer
)

find_package_handle_standard_args(Roctracer "DEFAULT_MSG" SIRIUS_ROCTRACER_LIBRARIES SIRIUS_ROCTRACER_INCLUDE_DIR)

if(Roctracer_FOUND AND NOT TARGET sirius::roctracer)
add_library(sirius::roctracer INTERFACE IMPORTED)
set_target_properties(sirius::roctracer PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES "${SIRIUS_ROCTRACER_INCLUDE_DIR}"
INTERFACE_LINK_LIBRARIES "${SIRIUS_ROCTRACER_LIBRARIES}")
endif()
6 changes: 5 additions & 1 deletion spack/packages/sirius/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,9 @@ class Sirius(CMakePackage, CudaPackage, ROCmPackage):
variant(
"profiler", default=True, description="Use internal profiler to measure execution time"
)
variant("nvtx", default=False, description="Use NVTX profiler")
variant(
"nvtx", default=False, description="Use NVTX/ROCTX profiler"
)

with when("@7.6:"):
variant(
Expand Down Expand Up @@ -163,6 +165,7 @@ class Sirius(CMakePackage, CudaPackage, ROCmPackage):
depends_on("nlcglib+cuda", when="+nlcglib+cuda")

depends_on("[email protected]:+mpi", when="+vdwxc")
depends_on("roctracer-dev", when="+nvtx+rocm")

depends_on("scalapack", when="+scalapack")

Expand All @@ -184,6 +187,7 @@ class Sirius(CMakePackage, CudaPackage, ROCmPackage):
conflicts("^[email protected]") # known to produce incorrect results
conflicts("+single_precision", when="@:7.2.4")
conflicts("+scalapack", when="^cray-libsci")
conflicts("+nvtx", when="~cuda~rocm")

# Propagate openmp to blas
depends_on("openblas threads=openmp", when="+openmp ^[virtuals=blas,lapack] openblas")
Expand Down
6 changes: 4 additions & 2 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,9 @@ target_link_libraries(sirius_cxx PUBLIC ${GSL_LIBRARY}
"${SIRIUS_LINALG_LIB}"
$<$<BOOL:${SIRIUS_USE_PUGIXML}>:pugixml::pugixml>
$<$<BOOL:${SIRIUS_USE_MEMORY_POOL}>:umpire>
$<$<BOOL:${SIRIUS_USE_NVTX}>:nvToolsExt>
$<TARGET_NAME_IF_EXISTS:CUDA::nvToolsExt>
$<TARGET_NAME_IF_EXISTS:sirius::roctx>
$<TARGET_NAME_IF_EXISTS:sirius::roctracer>
$<TARGET_NAME_IF_EXISTS:sirius::cudalibs>
$<$<BOOL:${SIRIUS_USE_ROCM}>:roc::rocsolver>
$<$<BOOL:${SIRIUS_USE_ROCM}>:roc::rocblas>
Expand All @@ -137,7 +139,7 @@ target_compile_definitions(sirius_cxx PUBLIC
$<$<BOOL:${SIRIUS_USE_ELPA}>:SIRIUS_ELPA>
$<$<BOOL:${SIRIUS_USE_NLCGLIB}>:SIRIUS_NLCGLIB>
$<$<BOOL:${SIRIUS_USE_CUDA}>:SIRIUS_GPU SIRIUS_CUDA>
$<$<BOOL:${SIRIUS_USE_NVTX}>:SIRIUS_CUDA_NVTX>
$<$<BOOL:${SIRIUS_USE_NVTX}>:SIRIUS_TX>
$<$<BOOL:${SIRIUS_USE_MAGMA}>:SIRIUS_MAGMA>
$<$<BOOL:${SIRIUS_USE_ROCM}>:SIRIUS_GPU SIRIUS_ROCM>
$<$<BOOL:${SIRIUS_USE_VDWXC}>:SIRIUS_USE_VDWXC>
Expand Down
77 changes: 0 additions & 77 deletions src/core/acc/nvtx_profiler.hpp

This file was deleted.

135 changes: 135 additions & 0 deletions src/core/acc/tx_profiler.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
#ifndef __TX_PROFILER_HPP__
#define __TX_PROFILER_HPP__

#if defined(SIRIUS_TX)
#include <unordered_map>
#include <string>

#if defined(SIRIUS_CUDA)
#include "nvToolsExt.h"
#endif

#if defined(SIRIUS_ROCM)
#include "roctx.h"
#endif

namespace sirius {

namespace acc {

namespace txprofiler {
/* roctx and nvtx ns */

enum class _vendor
{
cuda,
rocm
};

template <enum _vendor>
class TimerVendor
{
};

#if defined(SIRIUS_CUDA)
template <>
class TimerVendor<_vendor::cuda>
{
public:
void
start(std::string const& str)
{
timers_[str] = nvtxRangeStartA(str.c_str());
}

void
stop(std::string const& str)
{
auto result = timers_.find(str);
if (result == timers_.end())
return;
nvtxRangeEnd(result->second);
timers_.erase(result);
}

private:
std::unordered_map<std::string, nvtxRangeId_t> timers_;
};
using timer_vendor_t = TimerVendor<_vendor::cuda>;
#endif /* SIRIUS_CUDA_NVTX */

#if defined(SIRIUS_ROCM)
template <>
class TimerVendor<_vendor::rocm>
{
public:
void
start(std::string const& str)
{
timers_[str] = roctxRangeStartA(str.c_str());
}

void
stop(std::string const& str)
{
auto result = timers_.find(str);
if (result == timers_.end())
return;
roctxRangeStop(result->second);
timers_.erase(result);
}

private:
std::unordered_map<std::string, roctx_range_id_t> timers_;
};
using timer_vendor_t = TimerVendor<_vendor::rocm>;
#endif /* SIRIUS_ROCTX */

class Timer : timer_vendor_t
{
public:
void
start(std::string const& str)
{
timer_vendor_t::start(str);
}

void
stop(std::string const& str)
{
timer_vendor_t::stop(str);
}
};

class ScopedTiming
{
public:
ScopedTiming(std::string identifier, Timer& timer)
: identifier_(identifier)
, timer_(timer)
{
timer.start(identifier_);
}

ScopedTiming(const ScopedTiming&) = delete;
ScopedTiming(ScopedTiming&&) = delete;
auto
operator=(const ScopedTiming&) -> ScopedTiming& = delete;
auto
operator=(ScopedTiming&&) -> ScopedTiming& = delete;

~ScopedTiming()
{
timer_.stop(identifier_);
}

private:
std::string identifier_;
Timer& timer_;
};

} // namespace txprofiler
} // namespace acc
} // namespace sirius
#endif /* defined(SIRIUS_CUDA_NVTX) || defined(SIRIUS_ROCTX) */
#endif /* __TX_PROFILER_HPP__ */
4 changes: 2 additions & 2 deletions src/core/profiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
namespace sirius {
::rt_graph::Timer global_rtgraph_timer;

#if defined(SIRIUS_CUDA_NVTX)
acc::nvtxprofiler::Timer global_nvtx_timer;
#if defined(SIRIUS_TX)
acc::txprofiler::Timer global_tx_timer;
#endif
} // namespace sirius
22 changes: 11 additions & 11 deletions src/core/profiler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,40 +20,40 @@
#include <apex_api.hpp>
#endif
#include "core/rt_graph.hpp"
#if defined(SIRIUS_GPU) && defined(SIRIUS_CUDA_NVTX)
#include "core/acc/nvtx_profiler.hpp"
#if defined(SIRIUS_GPU) && defined(SIRIUS_TX)
#include "core/acc/tx_profiler.hpp"
#endif

namespace sirius {

extern ::rt_graph::Timer global_rtgraph_timer;

#if defined(SIRIUS_CUDA_NVTX)
extern acc::nvtxprofiler::Timer global_nvtx_timer;
#if defined(SIRIUS_TX)
extern acc::txprofiler::Timer global_tx_timer;
#endif

// TODO: add calls to apex and cudaNvtx
// TODO: add calls to apex

#if defined(SIRIUS_PROFILE)
#define PROFILER_CONCAT_IMPL(x, y) x##y
#define PROFILER_CONCAT(x, y) PROFILER_CONCAT_IMPL(x, y)

#if defined(SIRIUS_CUDA_NVTX)
#if defined(SIRIUS_TX)
#define PROFILE(identifier) \
acc::nvtxprofiler::ScopedTiming PROFILER_CONCAT(GeneratedScopedTimer, __COUNTER__)(identifier, global_nvtx_timer); \
acc::txprofiler::ScopedTiming PROFILER_CONCAT(GeneratedScopedTimer, __COUNTER__)(identifier, global_tx_timer); \
::rt_graph::ScopedTiming PROFILER_CONCAT(GeneratedScopedTimer, __COUNTER__)(identifier, global_rtgraph_timer);
#define PROFILE_START(identifier) \
global_nvtx_timer.start(identifier); \
global_tx_timer.start(identifier); \
global_rtgraph_timer.start(identifier);
#define PROFILE_STOP(identifier) \
global_rtgraph_timer.stop(identifier); \
global_nvtx_timer.stop(identifier);
#else
global_tx_timer.stop(identifier);
#else /* NVTX and ROCTX are not defined -> just use rt_graph */
#define PROFILE(identifier) \
::rt_graph::ScopedTiming PROFILER_CONCAT(GeneratedScopedTimer, __COUNTER__)(identifier, global_rtgraph_timer);
#define PROFILE_START(identifier) global_rtgraph_timer.start(identifier);
#define PROFILE_STOP(identifier) global_rtgraph_timer.stop(identifier);
#endif
#endif // SIRIUS_CUDA_NVTX

#else
#define PROFILE(...)
Expand Down