Skip to content

Commit

Permalink
add nvtx equivalent for rocm
Browse files Browse the repository at this point in the history
  • Loading branch information
simonpintarelli committed Dec 15, 2023
1 parent ef93a25 commit 120e750
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 81 deletions.
4 changes: 2 additions & 2 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ target_link_libraries(sirius PUBLIC ${GSL_LIBRARY}
SPLA::spla
"${SIRIUS_LINALG_LIB}"
$<$<BOOL:${SIRIUS_USE_MEMORY_POOL}>:umpire>
$<$<BOOL:${SIRIUS_USE_NVTX}>:nvToolsExt>
$<TARGET_NAME_IF_EXISTS:nvToolsExt>
$<TARGET_NAME_IF_EXISTS:sirius::cudalibs>
$<$<BOOL:${SIRIUS_USE_ROCM}>:roc::rocsolver>
$<$<BOOL:${SIRIUS_USE_ROCM}>:roc::rocblas>
Expand All @@ -139,7 +139,7 @@ target_compile_definitions(sirius PUBLIC
$<$<BOOL:${SIRIUS_USE_ELPA}>:SIRIUS_ELPA>
$<$<BOOL:${SIRIUS_USE_NLCGLIB}>:SIRIUS_NLCGLIB>
$<$<BOOL:${SIRIUS_USE_CUDA}>:SIRIUS_GPU SIRIUS_CUDA>
$<$<BOOL:${SIRIUS_USE_NVTX}>:SIRIUS_CUDA_NVTX>
$<$<BOOL:${SIRIUS_USE_NVTX}>:SIRIUS_TX>
$<$<BOOL:${SIRIUS_USE_MAGMA}>:SIRIUS_MAGMA>
$<$<BOOL:${SIRIUS_USE_ROCM}>:SIRIUS_GPU SIRIUS_ROCM>
$<$<BOOL:${SIRIUS_USE_VDWXC}>:SIRIUS_USE_VDWXC>
Expand Down
69 changes: 0 additions & 69 deletions src/core/acc/nvtx_profiler.hpp

This file was deleted.

135 changes: 135 additions & 0 deletions src/core/acc/tx_profiler.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
#ifndef __TX_PROFILER_HPP__
#define __TX_PROFILER_HPP__

#if defined(SIRIUS_TX)
#include <string>
#include <unordered_map>
#include <utility>

#if defined(SIRIUS_CUDA)
#include "nvToolsExt.h"
#endif

#if defined(SIRIUS_ROCM)
#include "roctx.h"
#endif

namespace sirius {

namespace acc {

namespace txprofiler {
/* roctx and nvtx ns */

/* Compile-time tag naming the GPU tool-extension back end a TimerVendor specialization wraps. */
enum class _vendor
{
    cuda, /* selects the specialization built on the nvtx* calls */
    rocm, /* selects the specialization built on the roctx* calls */
};

template <enum _vendor>
class TimerVendor
{
};

#if defined(SIRIUS_CUDA)
template <>
class TimerVendor<_vendor::cuda>
{
public:
void
start(std::string const& str)
{
timers_[str] = nvtxRangeStartA(str.c_str());
}

void
stop(std::string const& str)
{
auto result = timers_.find(str);
if (result == timers_.end())
return;
nvtxRangeEnd(result->second);
timers_.erase(result);
}

private:
std::unordered_map<std::string, nvtxRangeId_t> timers_;
};
using timer_vendor_t = TimerVendor<_vendor::cuda>;
#endif /* SIRIUS_CUDA */

#if defined(SIRIUS_ROCM)
template <>
class TimerVendor<_vendor::rocm>
{
public:
void
start(std::string const& str)
{
timers_[str] = roctxRangeStartA(str.c_str());
}

void
stop(std::string const& str)
{
auto result = timers_.find(str);
if (result == timers_.end())
return;
rotcxRangeEnd(result->second);
timers_.erase(result);
}

private:
std::unordered_map<std::string, roctxRangeId_t> timers_;
};
using timer_vendor_t = TimerVendor<_vendor::rocm>;
#endif /* SIRIUS_ROCM */

class Timer : timer_vendor_t
{
public:
void
start(std::string const& str)
{
timer_vendor_t::start(str);
}

void
stop(std::string const& str)
{
timer_vendor_t::stop(str);
}
};

/* Scope guard around Timer: the constructor starts a range named `identifier` on `timer`,
 * the destructor stops it. Copying and moving are disabled so each range is stopped
 * exactly once by the object that started it. */
class ScopedTiming
{
  public:
    /* `identifier` is taken by value and moved into the member, so the label is not copied
     * a second time. `timer` is held by reference and must outlive this object. */
    ScopedTiming(std::string identifier, Timer& timer)
        : identifier_(std::move(identifier))
        , timer_(timer)
    {
        timer_.start(identifier_);
    }

    ScopedTiming(const ScopedTiming&) = delete;
    ScopedTiming(ScopedTiming&&)      = delete;
    ScopedTiming&
    operator=(const ScopedTiming&) = delete;
    ScopedTiming&
    operator=(ScopedTiming&&) = delete;

    /* Stop the range that the constructor started. */
    ~ScopedTiming()
    {
        timer_.stop(identifier_);
    }

  private:
    std::string identifier_; /* label passed to both start() and stop() */
    Timer& timer_;           /* timer that owns the range; not owned by this object */
};

} // namespace txprofiler
} // namespace acc
} // namespace sirius
#endif /* defined(SIRIUS_TX) */
#endif /* __TX_PROFILER_HPP__ */
4 changes: 2 additions & 2 deletions src/core/profiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
namespace sirius {
::rt_graph::Timer global_rtgraph_timer;

#if defined(SIRIUS_CUDA_NVTX)
acc::nvtxprofiler::Timer global_nvtx_timer;
#if defined(SIRIUS_TX)
acc::txprofiler::Timer global_tx_timer;
#endif
} // namespace sirius
16 changes: 8 additions & 8 deletions src/core/profiler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,16 @@
#include <apex_api.hpp>
#endif
#include "core/rt_graph.hpp"
#if defined(SIRIUS_GPU) && defined(SIRIUS_CUDA_NVTX)
#include "core/acc/nvtx_profiler.hpp"
#if defined(SIRIUS_GPU) && defined(SIRIUS_TX)
#include "core/acc/tx_profiler.hpp"
#endif

namespace sirius {

extern ::rt_graph::Timer global_rtgraph_timer;

#if defined(SIRIUS_CUDA_NVTX)
extern acc::nvtxprofiler::Timer global_nvtx_timer;
#if defined(SIRIUS_TX)
extern acc::txprofiler::Timer global_tx_timer;
#endif

// TODO: add calls to apex and cudaNvtx
Expand All @@ -49,22 +49,22 @@ extern acc::nvtxprofiler::Timer global_nvtx_timer;
#define PROFILER_CONCAT_IMPL(x, y) x##y
#define PROFILER_CONCAT(x, y) PROFILER_CONCAT_IMPL(x, y)

#if defined(SIRIUS_CUDA_NVTX)
#if defined(SIRIUS_TX)
/* Scoped profiling for the SIRIUS_TX build: declares one vendor-range guard on global_tx_timer
 * and one rt_graph guard on global_rtgraph_timer, both labelled `identifier`. __COUNTER__ gives
 * each guard a distinct variable name. */
#define PROFILE(identifier) \
    acc::txprofiler::ScopedTiming PROFILER_CONCAT(GeneratedScopedTimer, __COUNTER__)(identifier, global_tx_timer); \
    ::rt_graph::ScopedTiming PROFILER_CONCAT(GeneratedScopedTimer, __COUNTER__)(identifier, global_rtgraph_timer);
/* Manually start `identifier` on both timers (vendor range first, then rt_graph). */
#define PROFILE_START(identifier) \
    global_tx_timer.start(identifier); \
    global_rtgraph_timer.start(identifier);
/* Manually stop `identifier` on both timers, in the reverse order of PROFILE_START. */
#define PROFILE_STOP(identifier) \
    global_rtgraph_timer.stop(identifier); \
    global_tx_timer.stop(identifier);
#else
#else /* NVTX and ROCTX are not defined -> just use rt_graph */
#define PROFILE(identifier) \
::rt_graph::ScopedTiming PROFILER_CONCAT(GeneratedScopedTimer, __COUNTER__)(identifier, global_rtgraph_timer);
#define PROFILE_START(identifier) global_rtgraph_timer.start(identifier);
#define PROFILE_STOP(identifier) global_rtgraph_timer.stop(identifier);
#endif
#endif // SIRIUS_CUDA_NVTX

#else
#define PROFILE(...)
Expand Down

0 comments on commit 120e750

Please sign in to comment.