Skip to content

Commit

Permalink
Fix RAII issue by introducing wrapper classes for backend plans (#208)
Browse files Browse the repository at this point in the history
* fix: conflicts

* Fix and wrapper for FFTW handle

* Wrapper for cufft handle

* fix: conflicts

* Wrapper for rocfft handle

* fix: conflicts

* Cleanup plan class based on the introduction of wrappers

* fix: conflicts

* fix: conflicts

* fix: unused variable

* fix: work buffer allocation

* remove unused variable

* remove unused lines

* Add missing include header file in KokkosFFT_ROCM_types.hpp

* fix: fftwHandle type in SYCL types

* Do not return const plan type for fftw

* fix: remove const

* fix: fftw plan creation

* fix: set created

* fix: cleanup

* fix constructor of fftw wrapper

* fix: conflicts

* Remove non-default constructors from FFTW wrapper

* Remove non-default constructors from cufft wrapper

* Remove non-default constructors from hipfft wrapper

* Remove non-default constructors from rocfft wrapper

* update FFTW wrapper class name

* fix: host plan type

* fix: fftw wrapper name in ROCM_types

* update cuda backend based on reviews

* update hip backend based on reviews

* update rocm backend based on reviews

* update host backend based on reviews

* fix: Rocm types

* fix: ROCM types

* fix: Rocm types

* fix: header files

* fix: rocm types

* fix: rocm types

* remove unused lines

* fix: rocm types

* Improve the cleanup logic for cufft plan

* Improve the cleanup logic for hipfft plan

* Improve the cleanup logic for rocfft plan

* simplify fftw plan wrapper

* fix: rocm types

* fix: scoped rocfft plan type

* return execution_info by value in scoped rocfft plan

* Add commit method to scoped cufft plan

* Add commit method to scoped hipfft plan

* Add commit method to scoped rocfft plan

* Add const qualifier for host transforms

* fix: ROCM types

* fix cleanup of ScopedCufft and ScopedHIPfft plan

* Add ScopedExecutionInfo for rocm backend

* fix KokkosFFT_ROCM_types.hpp

* fix: KokkosFFT_ROCM_types.hpp

* make commit method const

* call fftw_cleanup_threads only once

* remove static from init and cleanup methods

* use local static object for global initialization and finalization

* remove cleanup threads for safety

* remove unused header from KokkosFFT_FFTW_Types.hpp

* delete non-default constructors for Rocfft wrappers

* fix: KokkosFFT_ROCM_types.hpp

* Add Thomas as a co-author

Co-authored-by: Yuuichi Asahi <[email protected]>
Co-authored-by: Thomas Padioleau <[email protected]>

---------

Co-authored-by: Yuuichi Asahi <[email protected]>
Co-authored-by: Thomas Padioleau <[email protected]>
  • Loading branch information
3 people authored Jan 7, 2025
1 parent 95bd3b3 commit b4b272a
Show file tree
Hide file tree
Showing 16 changed files with 726 additions and 651 deletions.
84 changes: 23 additions & 61 deletions fft/src/KokkosFFT_Cuda_plans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,14 @@ namespace KokkosFFT {
namespace Impl {
// 1D transform
template <typename ExecutionSpace, typename PlanType, typename InViewType,
typename OutViewType, typename BufferViewType, typename InfoType,
typename OutViewType,
std::enable_if_t<InViewType::rank() == 1 &&
std::is_same_v<ExecutionSpace, Kokkos::Cuda>,
std::nullptr_t> = nullptr>
auto create_plan(const ExecutionSpace& exec_space,
std::unique_ptr<PlanType>& plan, const InViewType& in,
const OutViewType& out, BufferViewType&, InfoType&,
Direction /*direction*/, axis_type<1> axes, shape_type<1> s,
bool is_inplace) {
const OutViewType& out, Direction /*direction*/,
axis_type<1> axes, shape_type<1> s, bool is_inplace) {
static_assert(
KokkosFFT::Impl::are_operatable_views_v<ExecutionSpace, InViewType,
OutViewType>,
Expand All @@ -34,37 +33,29 @@ auto create_plan(const ExecutionSpace& exec_space,
"InViewType and OutViewType.");
using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;

auto type = KokkosFFT::Impl::transform_type<ExecutionSpace, in_value_type,
out_value_type>::type();
auto [in_extents, out_extents, fft_extents, howmany] =
KokkosFFT::Impl::get_extents(in, out, axes, s, is_inplace);
const int nx = fft_extents.at(0);
int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1,
std::multiplies<>());

plan = std::make_unique<PlanType>();
cufftResult cufft_rt = cufftPlan1d(&(*plan), nx, type, howmany);
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftPlan1d failed");

cudaStream_t stream = exec_space.cuda_stream();
cufft_rt = cufftSetStream((*plan), stream);
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftSetStream failed");
plan = std::make_unique<PlanType>(nx, type, howmany);
plan->commit(exec_space);

return fft_size;
}

// 2D transform
template <typename ExecutionSpace, typename PlanType, typename InViewType,
typename OutViewType, typename BufferViewType, typename InfoType,
typename OutViewType,
std::enable_if_t<InViewType::rank() == 2 &&
std::is_same_v<ExecutionSpace, Kokkos::Cuda>,
std::nullptr_t> = nullptr>
auto create_plan(const ExecutionSpace& exec_space,
std::unique_ptr<PlanType>& plan, const InViewType& in,
const OutViewType& out, BufferViewType&, InfoType&,
Direction /*direction*/, axis_type<2> axes, shape_type<2> s,
bool is_inplace) {
const OutViewType& out, Direction /*direction*/,
axis_type<2> axes, shape_type<2> s, bool is_inplace) {
static_assert(
KokkosFFT::Impl::are_operatable_views_v<ExecutionSpace, InViewType,
OutViewType>,
Expand All @@ -75,37 +66,29 @@ auto create_plan(const ExecutionSpace& exec_space,
"InViewType and OutViewType.");
using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;

auto type = KokkosFFT::Impl::transform_type<ExecutionSpace, in_value_type,
out_value_type>::type();
[[maybe_unused]] auto [in_extents, out_extents, fft_extents, howmany] =
KokkosFFT::Impl::get_extents(in, out, axes, s, is_inplace);
const int nx = fft_extents.at(0), ny = fft_extents.at(1);
int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1,
std::multiplies<>());

plan = std::make_unique<PlanType>();
cufftResult cufft_rt = cufftPlan2d(&(*plan), nx, ny, type);
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftPlan2d failed");

cudaStream_t stream = exec_space.cuda_stream();
cufft_rt = cufftSetStream((*plan), stream);
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftSetStream failed");
plan = std::make_unique<PlanType>(nx, ny, type);
plan->commit(exec_space);

return fft_size;
}

// 3D transform
template <typename ExecutionSpace, typename PlanType, typename InViewType,
typename OutViewType, typename BufferViewType, typename InfoType,
typename OutViewType,
std::enable_if_t<InViewType::rank() == 3 &&
std::is_same_v<ExecutionSpace, Kokkos::Cuda>,
std::nullptr_t> = nullptr>
auto create_plan(const ExecutionSpace& exec_space,
std::unique_ptr<PlanType>& plan, const InViewType& in,
const OutViewType& out, BufferViewType&, InfoType&,
Direction /*direction*/, axis_type<3> axes, shape_type<3> s,
bool is_inplace) {
const OutViewType& out, Direction /*direction*/,
axis_type<3> axes, shape_type<3> s, bool is_inplace) {
static_assert(
KokkosFFT::Impl::are_operatable_views_v<ExecutionSpace, InViewType,
OutViewType>,
Expand All @@ -116,7 +99,6 @@ auto create_plan(const ExecutionSpace& exec_space,
"InViewType and OutViewType.");
using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;

auto type = KokkosFFT::Impl::transform_type<ExecutionSpace, in_value_type,
out_value_type>::type();
[[maybe_unused]] auto [in_extents, out_extents, fft_extents, howmany] =
Expand All @@ -126,29 +108,22 @@ auto create_plan(const ExecutionSpace& exec_space,
nz = fft_extents.at(2);
int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1,
std::multiplies<>());

plan = std::make_unique<PlanType>();
cufftResult cufft_rt = cufftPlan3d(&(*plan), nx, ny, nz, type);
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftPlan3d failed");

cudaStream_t stream = exec_space.cuda_stream();
cufft_rt = cufftSetStream((*plan), stream);
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftSetStream failed");
plan = std::make_unique<PlanType>(nx, ny, nz, type);
plan->commit(exec_space);

return fft_size;
}

// batched transform, over ND Views
template <typename ExecutionSpace, typename PlanType, typename InViewType,
typename OutViewType, typename BufferViewType, typename InfoType,
std::size_t fft_rank = 1,
typename OutViewType, std::size_t fft_rank = 1,
std::enable_if_t<std::is_same_v<ExecutionSpace, Kokkos::Cuda>,
std::nullptr_t> = nullptr>
auto create_plan(const ExecutionSpace& exec_space,
std::unique_ptr<PlanType>& plan, const InViewType& in,
const OutViewType& out, BufferViewType&, InfoType&,
Direction /*direction*/, axis_type<fft_rank> axes,
shape_type<fft_rank> s, bool is_inplace) {
const OutViewType& out, Direction /*direction*/,
axis_type<fft_rank> axes, shape_type<fft_rank> s,
bool is_inplace) {
static_assert(
KokkosFFT::Impl::are_operatable_views_v<ExecutionSpace, InViewType,
OutViewType>,
Expand Down Expand Up @@ -179,27 +154,14 @@ auto create_plan(const ExecutionSpace& exec_space,

// For the moment, considering the contiguous layout only
int istride = 1, ostride = 1;

plan = std::make_unique<PlanType>();
cufftResult cufft_rt = cufftPlanMany(
&(*plan), rank, fft_extents.data(), in_extents.data(), istride, idist,
out_extents.data(), ostride, odist, type, howmany);

KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftPlanMany failed");

cudaStream_t stream = exec_space.cuda_stream();
cufft_rt = cufftSetStream((*plan), stream);
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftSetStream failed");
plan = std::make_unique<PlanType>(rank, fft_extents.data(), in_extents.data(),
istride, idist, out_extents.data(), ostride,
odist, type, howmany);
plan->commit(exec_space);

return fft_size;
}

template <typename ExecutionSpace, typename PlanType, typename InfoType,
std::enable_if_t<std::is_same_v<ExecutionSpace, Kokkos::Cuda>,
std::nullptr_t> = nullptr>
void destroy_plan_and_info(std::unique_ptr<PlanType>& plan, InfoType&) {
cufftDestroy(*plan);
}
} // namespace Impl
} // namespace KokkosFFT

Expand Down
49 changes: 25 additions & 24 deletions fft/src/KokkosFFT_Cuda_transform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,48 +7,49 @@

#include <cufft.h>
#include "KokkosFFT_asserts.hpp"
#include "KokkosFFT_Cuda_types.hpp"

namespace KokkosFFT {
namespace Impl {
template <typename... Args>
inline void exec_plan(cufftHandle& plan, cufftReal* idata, cufftComplex* odata,
int /*direction*/, Args...) {
cufftResult cufft_rt = cufftExecR2C(plan, idata, odata);

inline void exec_plan(const ScopedCufftPlan& scoped_plan, cufftReal* idata,
cufftComplex* odata, int /*direction*/) {
cufftResult cufft_rt = cufftExecR2C(scoped_plan.plan(), idata, odata);
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftExecR2C failed");
}

template <typename... Args>
inline void exec_plan(cufftHandle& plan, cufftDoubleReal* idata,
cufftDoubleComplex* odata, int /*direction*/, Args...) {
cufftResult cufft_rt = cufftExecD2Z(plan, idata, odata);
inline void exec_plan(const ScopedCufftPlan& scoped_plan,
cufftDoubleReal* idata, cufftDoubleComplex* odata,
int /*direction*/) {
cufftResult cufft_rt = cufftExecD2Z(scoped_plan.plan(), idata, odata);
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftExecD2Z failed");
}

template <typename... Args>
inline void exec_plan(cufftHandle& plan, cufftComplex* idata, cufftReal* odata,
int /*direction*/, Args...) {
cufftResult cufft_rt = cufftExecC2R(plan, idata, odata);
inline void exec_plan(const ScopedCufftPlan& scoped_plan, cufftComplex* idata,
cufftReal* odata, int /*direction*/) {
cufftResult cufft_rt = cufftExecC2R(scoped_plan.plan(), idata, odata);
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftExecC2R failed");
}

template <typename... Args>
inline void exec_plan(cufftHandle& plan, cufftDoubleComplex* idata,
cufftDoubleReal* odata, int /*direction*/, Args...) {
cufftResult cufft_rt = cufftExecZ2D(plan, idata, odata);
inline void exec_plan(const ScopedCufftPlan& scoped_plan,
cufftDoubleComplex* idata, cufftDoubleReal* odata,
int /*direction*/) {
cufftResult cufft_rt = cufftExecZ2D(scoped_plan.plan(), idata, odata);
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftExecZ2D failed");
}

template <typename... Args>
inline void exec_plan(cufftHandle& plan, cufftComplex* idata,
cufftComplex* odata, int direction, Args...) {
cufftResult cufft_rt = cufftExecC2C(plan, idata, odata, direction);
inline void exec_plan(const ScopedCufftPlan& scoped_plan, cufftComplex* idata,
cufftComplex* odata, int direction) {
cufftResult cufft_rt =
cufftExecC2C(scoped_plan.plan(), idata, odata, direction);
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftExecC2C failed");
}

template <typename... Args>
inline void exec_plan(cufftHandle& plan, cufftDoubleComplex* idata,
cufftDoubleComplex* odata, int direction, Args...) {
cufftResult cufft_rt = cufftExecZ2Z(plan, idata, odata, direction);
inline void exec_plan(const ScopedCufftPlan& scoped_plan,
cufftDoubleComplex* idata, cufftDoubleComplex* odata,
int direction) {
cufftResult cufft_rt =
cufftExecZ2Z(scoped_plan.plan(), idata, odata, direction);
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftExecZ2Z failed");
}
} // namespace Impl
Expand Down
95 changes: 66 additions & 29 deletions fft/src/KokkosFFT_Cuda_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@
#define KOKKOSFFT_CUDA_TYPES_HPP

#include <cufft.h>
#include <Kokkos_Abort.hpp>
#include "KokkosFFT_common_types.hpp"
#include "KokkosFFT_asserts.hpp"

#if defined(ENABLE_HOST_AND_DEVICE)
#include "KokkosFFT_FFTW_Types.hpp"
#endif

// Check the size of complex type
static_assert(sizeof(cufftComplex) == sizeof(Kokkos::complex<float>));
Expand All @@ -15,27 +21,59 @@ static_assert(alignof(cufftComplex) <= alignof(Kokkos::complex<float>));
static_assert(sizeof(cufftDoubleComplex) == sizeof(Kokkos::complex<double>));
static_assert(alignof(cufftDoubleComplex) <= alignof(Kokkos::complex<double>));

#ifdef ENABLE_HOST_AND_DEVICE
#include <fftw3.h>
#include "KokkosFFT_utils.hpp"
static_assert(sizeof(fftwf_complex) == sizeof(Kokkos::complex<float>));
static_assert(alignof(fftwf_complex) <= alignof(Kokkos::complex<float>));

static_assert(sizeof(fftw_complex) == sizeof(Kokkos::complex<double>));
static_assert(alignof(fftw_complex) <= alignof(Kokkos::complex<double>));
#endif

namespace KokkosFFT {
namespace Impl {
using FFTDirectionType = int;

// Unused
template <typename ExecutionSpace>
using FFTInfoType = int;
/// \brief A class that wraps cufft for RAII
struct ScopedCufftPlan {
private:
cufftHandle m_plan;

public:
ScopedCufftPlan(int nx, cufftType type, int batch) {
cufftResult cufft_rt = cufftPlan1d(&m_plan, nx, type, batch);
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftPlan1d failed");
}

ScopedCufftPlan(int nx, int ny, cufftType type) {
cufftResult cufft_rt = cufftPlan2d(&m_plan, nx, ny, type);
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftPlan2d failed");
}

#ifdef ENABLE_HOST_AND_DEVICE
enum class FFTWTransformType { R2C, D2Z, C2R, Z2D, C2C, Z2Z };
ScopedCufftPlan(int nx, int ny, int nz, cufftType type) {
cufftResult cufft_rt = cufftPlan3d(&m_plan, nx, ny, nz, type);
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftPlan3d failed");
}

ScopedCufftPlan(int rank, int *n, int *inembed, int istride, int idist,
int *onembed, int ostride, int odist, cufftType type,
int batch) {
cufftResult cufft_rt =
cufftPlanMany(&m_plan, rank, n, inembed, istride, idist, onembed,
ostride, odist, type, batch);
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftPlanMany failed");
}

~ScopedCufftPlan() noexcept {
cufftResult cufft_rt = cufftDestroy(m_plan);
if (cufft_rt != CUFFT_SUCCESS) Kokkos::abort("cufftDestroy failed");
}

ScopedCufftPlan() = delete;
ScopedCufftPlan(const ScopedCufftPlan &) = delete;
ScopedCufftPlan &operator=(const ScopedCufftPlan &) = delete;
ScopedCufftPlan &operator=(ScopedCufftPlan &&) = delete;
ScopedCufftPlan(ScopedCufftPlan &&) = delete;

cufftHandle plan() const noexcept { return m_plan; }
void commit(const Kokkos::Cuda &exec_space) const {
cufftResult cufft_rt = cufftSetStream(m_plan, exec_space.cuda_stream());
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftSetStream failed");
}
};

#if defined(ENABLE_HOST_AND_DEVICE)
template <typename ExecutionSpace>
struct FFTDataType {
using float32 =
Expand All @@ -52,15 +90,6 @@ struct FFTDataType {
cufftDoubleComplex, fftw_complex>;
};

template <typename ExecutionSpace, typename T1, typename T2>
struct FFTPlanType {
using fftwHandle = std::conditional_t<
std::is_same_v<KokkosFFT::Impl::base_floating_point_type<T1>, float>,
fftwf_plan, fftw_plan>;
using type = std::conditional_t<std::is_same_v<ExecutionSpace, Kokkos::Cuda>,
cufftHandle, fftwHandle>;
};

template <typename ExecutionSpace>
using TransformType =
std::conditional_t<std::is_same_v<ExecutionSpace, Kokkos::Cuda>, cufftType,
Expand Down Expand Up @@ -136,6 +165,14 @@ struct transform_type<ExecutionSpace, Kokkos::complex<T1>,
}
};

template <typename ExecutionSpace, typename T1, typename T2>
struct FFTPlanType {
using fftw_plan_type = ScopedFFTWPlan<ExecutionSpace, T1, T2>;
using cufft_plan_type = ScopedCufftPlan;
using type = std::conditional_t<std::is_same_v<ExecutionSpace, Kokkos::Cuda>,
cufft_plan_type, fftw_plan_type>;
};

template <typename ExecutionSpace>
auto direction_type(Direction direction) {
static constexpr FFTDirectionType FORWARD =
Expand All @@ -155,11 +192,6 @@ struct FFTDataType {
using complex128 = cufftDoubleComplex;
};

template <typename ExecutionSpace, typename T1, typename T2>
struct FFTPlanType {
using type = cufftHandle;
};

template <typename ExecutionSpace>
using TransformType = cufftType;

Expand Down Expand Up @@ -197,6 +229,11 @@ struct transform_type<ExecutionSpace, Kokkos::complex<T1>,
static constexpr cufftType type() { return m_type; };
};

template <typename ExecutionSpace, typename T1, typename T2>
struct FFTPlanType {
using type = ScopedCufftPlan;
};

template <typename ExecutionSpace>
auto direction_type(Direction direction) {
return direction == Direction::forward ? CUFFT_FORWARD : CUFFT_INVERSE;
Expand Down
Loading

0 comments on commit b4b272a

Please sign in to comment.