Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: buffer size for rocfft execution info #219

Merged
merged 3 commits into from
Jan 8, 2025
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions fft/src/KokkosFFT_ROCM_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,12 @@ struct ScopedRocfftPlanDescription {
};

/// \brief A class that wraps rocfft_execution_info for RAII
template <typename FloatingPointType>
struct ScopedRocfftExecutionInfo {
private:
using BufferViewType =
Kokkos::View<Kokkos::complex<FloatingPointType> *, Kokkos::HIP>;
rocfft_execution_info m_execution_info;

//! Internal work buffer
BufferViewType m_buffer;
void *m_workbuffer = nullptr;

public:
ScopedRocfftExecutionInfo() {
Expand All @@ -84,6 +81,10 @@ struct ScopedRocfftExecutionInfo {
"rocfft_execution_info_create failed");
}
~ScopedRocfftExecutionInfo() noexcept {
if (m_workbuffer != nullptr) {
hipError_t hip_status = hipFree(m_workbuffer);
if (hip_status != hipSuccess) Kokkos::abort("hipFree failed");
}
rocfft_status status = rocfft_execution_info_destroy(m_execution_info);
if (status != rocfft_status_success)
Kokkos::abort("rocfft_execution_info_destroy failed");
Expand Down Expand Up @@ -111,9 +112,10 @@ struct ScopedRocfftExecutionInfo {

// Set work buffer
if (workbuffersize > 0) {
m_buffer = BufferViewType("workbuffer", workbuffersize);
status = rocfft_execution_info_set_work_buffer(
m_execution_info, (void *)m_buffer.data(), workbuffersize);
hipError_t hip_status = hipMalloc(&m_workbuffer, workbuffersize);
tpadioleau marked this conversation as resolved.
Show resolved Hide resolved
KOKKOSFFT_THROW_IF(hip_status != hipSuccess, "hipMalloc failed");
status = rocfft_execution_info_set_work_buffer(
m_execution_info, m_workbuffer, workbuffersize);
KOKKOSFFT_THROW_IF(status != rocfft_status_success,
"rocfft_execution_info_set_work_buffer failed");
}
Expand All @@ -124,14 +126,12 @@ struct ScopedRocfftExecutionInfo {
template <typename T>
struct ScopedRocfftPlan {
private:
using floating_point_type = KokkosFFT::Impl::base_floating_point_type<T>;
using ScopedRocfftExecutionInfoType =
ScopedRocfftExecutionInfo<floating_point_type>;
using floating_point_type = KokkosFFT::Impl::base_floating_point_type<T>;
rocfft_precision m_precision = std::is_same_v<floating_point_type, float>
? rocfft_precision_single
: rocfft_precision_double;
rocfft_plan m_plan;
std::unique_ptr<ScopedRocfftExecutionInfoType> m_execution_info;
std::unique_ptr<ScopedRocfftExecutionInfo> m_execution_info;

public:
ScopedRocfftPlan(const FFTWTransformType transform_type,
Expand Down Expand Up @@ -209,7 +209,7 @@ struct ScopedRocfftPlan {
KOKKOSFFT_THROW_IF(status != rocfft_status_success,
"rocfft_plan_get_work_buffer_size failed");

m_execution_info = std::make_unique<ScopedRocfftExecutionInfoType>();
m_execution_info = std::make_unique<ScopedRocfftExecutionInfo>();
m_execution_info->setup(exec_space, workbuffersize);
}

Expand Down
Loading