diff --git a/fft/src/KokkosFFT_Cuda_plans.hpp b/fft/src/KokkosFFT_Cuda_plans.hpp index 7f4547f9..dc8423c8 100644 --- a/fft/src/KokkosFFT_Cuda_plans.hpp +++ b/fft/src/KokkosFFT_Cuda_plans.hpp @@ -15,15 +15,14 @@ namespace KokkosFFT { namespace Impl { // 1D transform template , std::nullptr_t> = nullptr> auto create_plan(const ExecutionSpace& exec_space, std::unique_ptr& plan, const InViewType& in, - const OutViewType& out, BufferViewType&, InfoType&, - Direction /*direction*/, axis_type<1> axes, shape_type<1> s, - bool is_inplace) { + const OutViewType& out, Direction /*direction*/, + axis_type<1> axes, shape_type<1> s, bool is_inplace) { static_assert( KokkosFFT::Impl::are_operatable_views_v, @@ -34,7 +33,6 @@ auto create_plan(const ExecutionSpace& exec_space, "InViewType and OutViewType."); using in_value_type = typename InViewType::non_const_value_type; using out_value_type = typename OutViewType::non_const_value_type; - auto type = KokkosFFT::Impl::transform_type::type(); auto [in_extents, out_extents, fft_extents, howmany] = @@ -42,29 +40,22 @@ auto create_plan(const ExecutionSpace& exec_space, const int nx = fft_extents.at(0); int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1, std::multiplies<>()); - - plan = std::make_unique(); - cufftResult cufft_rt = cufftPlan1d(&(*plan), nx, type, howmany); - KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftPlan1d failed"); - - cudaStream_t stream = exec_space.cuda_stream(); - cufft_rt = cufftSetStream((*plan), stream); - KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftSetStream failed"); + plan = std::make_unique(nx, type, howmany); + plan->commit(exec_space); return fft_size; } // 2D transform template , std::nullptr_t> = nullptr> auto create_plan(const ExecutionSpace& exec_space, std::unique_ptr& plan, const InViewType& in, - const OutViewType& out, BufferViewType&, InfoType&, - Direction /*direction*/, axis_type<2> axes, shape_type<2> s, - bool is_inplace) { + const OutViewType& out, Direction /*direction*/, + axis_type<2> axes, shape_type<2> s, bool is_inplace) { static_assert( KokkosFFT::Impl::are_operatable_views_v, @@ -75,7 +66,6 @@ auto create_plan(const ExecutionSpace& exec_space, "InViewType and OutViewType."); using in_value_type = typename InViewType::non_const_value_type; using out_value_type = typename OutViewType::non_const_value_type; - auto type = KokkosFFT::Impl::transform_type::type(); [[maybe_unused]] auto [in_extents, out_extents, fft_extents, howmany] = @@ -83,29 +73,22 @@ auto create_plan(const ExecutionSpace& exec_space, const int nx = fft_extents.at(0), ny = fft_extents.at(1); int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1, std::multiplies<>()); - - plan = std::make_unique(); - cufftResult cufft_rt = cufftPlan2d(&(*plan), nx, ny, type); - KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftPlan2d failed"); - - cudaStream_t stream = exec_space.cuda_stream(); - cufft_rt = cufftSetStream((*plan), stream); - KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftSetStream failed"); + plan = std::make_unique(nx, ny, type); + plan->commit(exec_space); return fft_size; } // 3D transform template , std::nullptr_t> = nullptr> auto create_plan(const ExecutionSpace& exec_space, std::unique_ptr& plan, const InViewType& in, - const OutViewType& out, BufferViewType&, InfoType&, - Direction /*direction*/, axis_type<3> axes, shape_type<3> s, - bool is_inplace) { + const OutViewType& out, Direction /*direction*/, + axis_type<3> axes, shape_type<3> s, bool is_inplace) { static_assert( KokkosFFT::Impl::are_operatable_views_v, @@ -116,7 +99,6 @@ auto create_plan(const ExecutionSpace& exec_space, "InViewType and OutViewType."); using in_value_type = typename InViewType::non_const_value_type; using out_value_type = typename OutViewType::non_const_value_type; - auto type = KokkosFFT::Impl::transform_type::type(); [[maybe_unused]] auto [in_extents, out_extents, fft_extents, howmany] = @@ -126,29 +108,22 @@ auto create_plan(const ExecutionSpace& exec_space, nz = fft_extents.at(2); int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1, std::multiplies<>()); - - plan = std::make_unique(); - cufftResult cufft_rt = cufftPlan3d(&(*plan), nx, ny, nz, type); - KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftPlan3d failed"); - - cudaStream_t stream = exec_space.cuda_stream(); - cufft_rt = cufftSetStream((*plan), stream); - KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftSetStream failed"); + plan = std::make_unique(nx, ny, nz, type); + plan->commit(exec_space); return fft_size; } // batched transform, over ND Views template , std::nullptr_t> = nullptr> auto create_plan(const ExecutionSpace& exec_space, std::unique_ptr& plan, const InViewType& in, - const OutViewType& out, BufferViewType&, InfoType&, - Direction /*direction*/, axis_type axes, - shape_type s, bool is_inplace) { + const OutViewType& out, Direction /*direction*/, + axis_type axes, shape_type s, + bool is_inplace) { static_assert( KokkosFFT::Impl::are_operatable_views_v, @@ -179,27 +154,14 @@ auto create_plan(const ExecutionSpace& exec_space, // For the moment, considering the contiguous layout only int istride = 1, ostride = 1; - - plan = std::make_unique(); - cufftResult cufft_rt = cufftPlanMany( - &(*plan), rank, fft_extents.data(), in_extents.data(), istride, idist, - out_extents.data(), ostride, odist, type, howmany); - - KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftPlanMany failed"); - - cudaStream_t stream = exec_space.cuda_stream(); - cufft_rt = cufftSetStream((*plan), stream); - KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftSetStream failed"); + plan = std::make_unique(rank, fft_extents.data(), in_extents.data(), + istride, idist, out_extents.data(), ostride, + odist, type, howmany); + plan->commit(exec_space); return fft_size; } -template , - std::nullptr_t> = nullptr> -void destroy_plan_and_info(std::unique_ptr& plan, InfoType&) { - cufftDestroy(*plan); -} } // namespace Impl } // namespace KokkosFFT diff --git a/fft/src/KokkosFFT_Cuda_transform.hpp b/fft/src/KokkosFFT_Cuda_transform.hpp index 83f0cb45..8a1663da 100644 --- a/fft/src/KokkosFFT_Cuda_transform.hpp +++ b/fft/src/KokkosFFT_Cuda_transform.hpp @@ -7,48 +7,49 @@ #include #include "KokkosFFT_asserts.hpp" +#include "KokkosFFT_Cuda_types.hpp" namespace KokkosFFT { namespace Impl { -template -inline void exec_plan(cufftHandle& plan, cufftReal* idata, cufftComplex* odata, - int /*direction*/, Args...) { - cufftResult cufft_rt = cufftExecR2C(plan, idata, odata); + +inline void exec_plan(const ScopedCufftPlan& scoped_plan, cufftReal* idata, + cufftComplex* odata, int /*direction*/) { + cufftResult cufft_rt = cufftExecR2C(scoped_plan.plan(), idata, odata); KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftExecR2C failed"); } -template -inline void exec_plan(cufftHandle& plan, cufftDoubleReal* idata, - cufftDoubleComplex* odata, int /*direction*/, Args...) { - cufftResult cufft_rt = cufftExecD2Z(plan, idata, odata); +inline void exec_plan(const ScopedCufftPlan& scoped_plan, + cufftDoubleReal* idata, cufftDoubleComplex* odata, + int /*direction*/) { + cufftResult cufft_rt = cufftExecD2Z(scoped_plan.plan(), idata, odata); KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftExecD2Z failed"); } -template -inline void exec_plan(cufftHandle& plan, cufftComplex* idata, cufftReal* odata, - int /*direction*/, Args...) { - cufftResult cufft_rt = cufftExecC2R(plan, idata, odata); +inline void exec_plan(const ScopedCufftPlan& scoped_plan, cufftComplex* idata, + cufftReal* odata, int /*direction*/) { + cufftResult cufft_rt = cufftExecC2R(scoped_plan.plan(), idata, odata); KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftExecC2R failed"); } -template -inline void exec_plan(cufftHandle& plan, cufftDoubleComplex* idata, - cufftDoubleReal* odata, int /*direction*/, Args...) { - cufftResult cufft_rt = cufftExecZ2D(plan, idata, odata); +inline void exec_plan(const ScopedCufftPlan& scoped_plan, + cufftDoubleComplex* idata, cufftDoubleReal* odata, + int /*direction*/) { + cufftResult cufft_rt = cufftExecZ2D(scoped_plan.plan(), idata, odata); KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftExecZ2D failed"); } -template -inline void exec_plan(cufftHandle& plan, cufftComplex* idata, - cufftComplex* odata, int direction, Args...) { - cufftResult cufft_rt = cufftExecC2C(plan, idata, odata, direction); +inline void exec_plan(const ScopedCufftPlan& scoped_plan, cufftComplex* idata, + cufftComplex* odata, int direction) { + cufftResult cufft_rt = + cufftExecC2C(scoped_plan.plan(), idata, odata, direction); KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftExecC2C failed"); } -template -inline void exec_plan(cufftHandle& plan, cufftDoubleComplex* idata, - cufftDoubleComplex* odata, int direction, Args...) { - cufftResult cufft_rt = cufftExecZ2Z(plan, idata, odata, direction); +inline void exec_plan(const ScopedCufftPlan& scoped_plan, + cufftDoubleComplex* idata, cufftDoubleComplex* odata, + int direction) { + cufftResult cufft_rt = + cufftExecZ2Z(scoped_plan.plan(), idata, odata, direction); KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftExecZ2Z failed"); } } // namespace Impl diff --git a/fft/src/KokkosFFT_Cuda_types.hpp b/fft/src/KokkosFFT_Cuda_types.hpp index 52b9d155..f5a1fe62 100644 --- a/fft/src/KokkosFFT_Cuda_types.hpp +++ b/fft/src/KokkosFFT_Cuda_types.hpp @@ -6,7 +6,13 @@ #define KOKKOSFFT_CUDA_TYPES_HPP #include +#include #include "KokkosFFT_common_types.hpp" +#include "KokkosFFT_asserts.hpp" + +#if defined(ENABLE_HOST_AND_DEVICE) +#include "KokkosFFT_FFTW_Types.hpp" +#endif // Check the size of complex type static_assert(sizeof(cufftComplex) == sizeof(Kokkos::complex)); @@ -15,27 +21,59 @@ static_assert(alignof(cufftComplex) <= alignof(Kokkos::complex)); static_assert(sizeof(cufftDoubleComplex) == sizeof(Kokkos::complex)); static_assert(alignof(cufftDoubleComplex) <= alignof(Kokkos::complex)); -#ifdef ENABLE_HOST_AND_DEVICE -#include -#include "KokkosFFT_utils.hpp" -static_assert(sizeof(fftwf_complex) == sizeof(Kokkos::complex)); -static_assert(alignof(fftwf_complex) <= alignof(Kokkos::complex)); - -static_assert(sizeof(fftw_complex) == sizeof(Kokkos::complex)); -static_assert(alignof(fftw_complex) <= alignof(Kokkos::complex)); -#endif - namespace KokkosFFT { namespace Impl { using FFTDirectionType = int; -// Unused -template -using FFTInfoType = int; +/// \brief A class that wraps cufft for RAII +struct ScopedCufftPlan { + private: + cufftHandle m_plan; + + public: + ScopedCufftPlan(int nx, cufftType type, int batch) { + cufftResult cufft_rt = cufftPlan1d(&m_plan, nx, type, batch); + KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftPlan1d failed"); + } + + ScopedCufftPlan(int nx, int ny, cufftType type) { + cufftResult cufft_rt = cufftPlan2d(&m_plan, nx, ny, type); + KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftPlan2d failed"); + } -#ifdef ENABLE_HOST_AND_DEVICE -enum class FFTWTransformType { R2C, D2Z, C2R, Z2D, C2C, Z2Z }; + ScopedCufftPlan(int nx, int ny, int nz, cufftType type) { + cufftResult cufft_rt = cufftPlan3d(&m_plan, nx, ny, nz, type); + KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftPlan3d failed"); + } + + ScopedCufftPlan(int rank, int *n, int *inembed, int istride, int idist, + int *onembed, int ostride, int odist, cufftType type, + int batch) { + cufftResult cufft_rt = + cufftPlanMany(&m_plan, rank, n, inembed, istride, idist, onembed, + ostride, odist, type, batch); + KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftPlanMany failed"); + } + + ~ScopedCufftPlan() noexcept { + cufftResult cufft_rt = cufftDestroy(m_plan); + if (cufft_rt != CUFFT_SUCCESS) Kokkos::abort("cufftDestroy failed"); + } + + ScopedCufftPlan() = delete; + ScopedCufftPlan(const ScopedCufftPlan &) = delete; + ScopedCufftPlan &operator=(const ScopedCufftPlan &) = delete; + ScopedCufftPlan &operator=(ScopedCufftPlan &&) = delete; + ScopedCufftPlan(ScopedCufftPlan &&) = delete; + + cufftHandle plan() const noexcept { return m_plan; } + void commit(const Kokkos::Cuda &exec_space) const { + cufftResult cufft_rt = cufftSetStream(m_plan, exec_space.cuda_stream()); + KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftSetStream failed"); + } +}; +#if defined(ENABLE_HOST_AND_DEVICE) template struct FFTDataType { using float32 = @@ -52,15 +90,6 @@ struct FFTDataType { cufftDoubleComplex, fftw_complex>; }; -template -struct FFTPlanType { - using fftwHandle = std::conditional_t< - std::is_same_v, float>, - fftwf_plan, fftw_plan>; - using type = std::conditional_t, - cufftHandle, fftwHandle>; -}; - template using TransformType = std::conditional_t, cufftType, @@ -136,6 +165,14 @@ struct transform_type, } }; +template +struct FFTPlanType { + using fftw_plan_type = ScopedFFTWPlan; + using cufft_plan_type = ScopedCufftPlan; + using type = std::conditional_t, + cufft_plan_type, fftw_plan_type>; +}; + template auto direction_type(Direction direction) { static constexpr FFTDirectionType FORWARD = @@ -155,11 +192,6 @@ struct FFTDataType { using complex128 = cufftDoubleComplex; }; -template -struct FFTPlanType { - using type = cufftHandle; -}; - template using TransformType = cufftType; @@ -197,6 +229,11 @@ struct transform_type, static constexpr cufftType type() { return m_type; }; }; +template +struct FFTPlanType { + using type = ScopedCufftPlan; +}; + template auto direction_type(Direction direction) { return direction == Direction::forward ? CUFFT_FORWARD : CUFFT_INVERSE; diff --git a/fft/src/KokkosFFT_FFTW_Types.hpp b/fft/src/KokkosFFT_FFTW_Types.hpp new file mode 100644 index 00000000..8f686577 --- /dev/null +++ b/fft/src/KokkosFFT_FFTW_Types.hpp @@ -0,0 +1,144 @@ +// SPDX-FileCopyrightText: (C) The kokkos-fft development team, see COPYRIGHT.md file +// +// SPDX-License-Identifier: MIT OR Apache-2.0 WITH LLVM-exception + +#ifndef KOKKOSFFT_FFTW_TYPES_HPP +#define KOKKOSFFT_FFTW_TYPES_HPP + +#include +#include +#include "KokkosFFT_common_types.hpp" +#include "KokkosFFT_utils.hpp" + +// Check the size of complex type +static_assert(sizeof(fftwf_complex) == sizeof(Kokkos::complex)); +static_assert(alignof(fftwf_complex) <= alignof(Kokkos::complex)); + +static_assert(sizeof(fftw_complex) == sizeof(Kokkos::complex)); +static_assert(alignof(fftw_complex) <= alignof(Kokkos::complex)); + +namespace KokkosFFT { +namespace Impl { + +enum class FFTWTransformType { R2C, D2Z, C2R, Z2D, C2C, Z2Z }; + +// Define fft transform types +template +struct fftw_transform_type { + static_assert(std::is_same_v, + "Real to real transform is unavailable"); +}; + +template +struct fftw_transform_type> { + static_assert(std::is_same_v, + "T1 and T2 should have the same precision"); + static constexpr FFTWTransformType m_type = std::is_same_v + ? FFTWTransformType::R2C + : FFTWTransformType::D2Z; + static constexpr FFTWTransformType type() { return m_type; }; +}; + +template +struct fftw_transform_type, T2> { + static_assert(std::is_same_v, + "T1 and T2 should have the same precision"); + static constexpr FFTWTransformType m_type = std::is_same_v + ? FFTWTransformType::C2R + : FFTWTransformType::Z2D; + static constexpr FFTWTransformType type() { return m_type; }; +}; + +template +struct fftw_transform_type, Kokkos::complex> { + static_assert(std::is_same_v, + "T1 and T2 should have the same precision"); + static constexpr FFTWTransformType m_type = std::is_same_v + ? FFTWTransformType::C2C + : FFTWTransformType::Z2Z; + static constexpr FFTWTransformType type() { return m_type; }; +}; + +/// \brief A class that wraps fftw_plan and fftwf_plan for RAII +template +struct ScopedFFTWPlan { + private: + using floating_point_type = KokkosFFT::Impl::base_floating_point_type; + using plan_type = + std::conditional_t, fftwf_plan, + fftw_plan>; + plan_type m_plan; + + public: + template + ScopedFFTWPlan(const ExecutionSpace &exec_space, int rank, const int *n, + int howmany, InScalarType *in, const int *inembed, int istride, + int idist, OutScalarType *out, const int *onembed, int ostride, + int odist, [[maybe_unused]] int sign, unsigned flags) { + init_threads(exec_space); + constexpr auto type = fftw_transform_type::type(); + if constexpr (type == KokkosFFT::Impl::FFTWTransformType::R2C) { + m_plan = + fftwf_plan_many_dft_r2c(rank, n, howmany, in, inembed, istride, idist, + out, onembed, ostride, odist, flags); + } else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::D2Z) { + m_plan = + fftw_plan_many_dft_r2c(rank, n, howmany, in, inembed, istride, idist, + out, onembed, ostride, odist, flags); + } else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::C2R) { + m_plan = + fftwf_plan_many_dft_c2r(rank, n, howmany, in, inembed, istride, idist, + out, onembed, ostride, odist, flags); + } else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::Z2D) { + m_plan = + fftw_plan_many_dft_c2r(rank, n, howmany, in, inembed, istride, idist, + out, onembed, ostride, odist, flags); + } else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::C2C) { + m_plan = + fftwf_plan_many_dft(rank, n, howmany, in, inembed, istride, idist, + out, onembed, ostride, odist, sign, flags); + } else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::Z2Z) { + m_plan = fftw_plan_many_dft(rank, n, howmany, in, inembed, istride, idist, + out, onembed, ostride, odist, sign, flags); + } + } + + ~ScopedFFTWPlan() noexcept { + if constexpr (std::is_same_v) { + fftwf_destroy_plan(m_plan); + } else { + fftw_destroy_plan(m_plan); + } + } + + ScopedFFTWPlan() = delete; + ScopedFFTWPlan(const ScopedFFTWPlan &) = delete; + ScopedFFTWPlan &operator=(const ScopedFFTWPlan &) = delete; + ScopedFFTWPlan &operator=(ScopedFFTWPlan &&) = delete; + ScopedFFTWPlan(ScopedFFTWPlan &&) = delete; + + plan_type plan() const noexcept { return m_plan; } + + private: + void init_threads([[maybe_unused]] const ExecutionSpace &exec_space) { +#if defined(KOKKOS_ENABLE_OPENMP) || defined(KOKKOS_ENABLE_THREADS) + if constexpr (std::is_same_v) { + int nthreads = exec_space.concurrency(); + + if constexpr (std::is_same_v) { + fftwf_init_threads(); + fftwf_plan_with_nthreads(nthreads); + } else { + fftw_init_threads(); + fftw_plan_with_nthreads(nthreads); + } + } +#endif + } +}; + +} // namespace Impl +} // namespace KokkosFFT + +#endif diff --git a/fft/src/KokkosFFT_HIP_plans.hpp b/fft/src/KokkosFFT_HIP_plans.hpp index 1e0ba043..9859cd44 100644 --- a/fft/src/KokkosFFT_HIP_plans.hpp +++ b/fft/src/KokkosFFT_HIP_plans.hpp @@ -15,15 +15,14 @@ namespace KokkosFFT { namespace Impl { // 1D transform template , std::nullptr_t> = nullptr> auto create_plan(const ExecutionSpace& exec_space, std::unique_ptr& plan, const InViewType& in, - const OutViewType& out, BufferViewType&, InfoType&, - Direction /*direction*/, axis_type<1> axes, shape_type<1> s, - bool is_inplace) { + const OutViewType& out, Direction /*direction*/, + axis_type<1> axes, shape_type<1> s, bool is_inplace) { static_assert( KokkosFFT::Impl::are_operatable_views_v, @@ -34,7 +33,6 @@ auto create_plan(const ExecutionSpace& exec_space, "InViewType and OutViewType."); using in_value_type = typename InViewType::non_const_value_type; using out_value_type = typename OutViewType::non_const_value_type; - auto type = KokkosFFT::Impl::transform_type::type(); auto [in_extents, out_extents, fft_extents, howmany] = @@ -42,29 +40,22 @@ auto create_plan(const ExecutionSpace& exec_space, const int nx = fft_extents.at(0); int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1, std::multiplies<>()); - - plan = std::make_unique(); - hipfftResult hipfft_rt = hipfftPlan1d(&(*plan), nx, type, howmany); - KOKKOSFFT_THROW_IF(hipfft_rt != HIPFFT_SUCCESS, "hipfftPlan1d failed"); - - hipStream_t stream = exec_space.hip_stream(); - hipfft_rt = hipfftSetStream((*plan), stream); - KOKKOSFFT_THROW_IF(hipfft_rt != HIPFFT_SUCCESS, "hipfftSetStream failed"); + plan = std::make_unique(nx, type, howmany); + plan->commit(exec_space); return fft_size; } // 2D transform template , std::nullptr_t> = nullptr> auto create_plan(const ExecutionSpace& exec_space, std::unique_ptr& plan, const InViewType& in, - const OutViewType& out, BufferViewType&, InfoType&, - Direction /*direction*/, axis_type<2> axes, shape_type<2> s, - bool is_inplace) { + const OutViewType& out, Direction /*direction*/, + axis_type<2> axes, shape_type<2> s, bool is_inplace) { static_assert( KokkosFFT::Impl::are_operatable_views_v, @@ -75,7 +66,6 @@ auto create_plan(const ExecutionSpace& exec_space, "InViewType and OutViewType."); using in_value_type = typename InViewType::non_const_value_type; using out_value_type = typename OutViewType::non_const_value_type; - auto type = KokkosFFT::Impl::transform_type::type(); [[maybe_unused]] auto [in_extents, out_extents, fft_extents, howmany] = @@ -83,29 +73,22 @@ auto create_plan(const ExecutionSpace& exec_space, const int nx = fft_extents.at(0), ny = fft_extents.at(1); int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1, std::multiplies<>()); - - plan = std::make_unique(); - hipfftResult hipfft_rt = hipfftPlan2d(&(*plan), nx, ny, type); - KOKKOSFFT_THROW_IF(hipfft_rt != HIPFFT_SUCCESS, "hipfftPlan2d failed"); - - hipStream_t stream = exec_space.hip_stream(); - hipfft_rt = hipfftSetStream((*plan), stream); - KOKKOSFFT_THROW_IF(hipfft_rt != HIPFFT_SUCCESS, "hipfftSetStream failed"); + plan = std::make_unique(nx, ny, type); + plan->commit(exec_space); return fft_size; } // 3D transform template , std::nullptr_t> = nullptr> auto create_plan(const ExecutionSpace& exec_space, std::unique_ptr& plan, const InViewType& in, - const OutViewType& out, BufferViewType&, InfoType&, - Direction /*direction*/, axis_type<3> axes, shape_type<3> s, - bool is_inplace) { + const OutViewType& out, Direction /*direction*/, + axis_type<3> axes, shape_type<3> s, bool is_inplace) { static_assert( KokkosFFT::Impl::are_operatable_views_v, @@ -116,7 +99,6 @@ auto create_plan(const ExecutionSpace& exec_space, "InViewType and OutViewType."); using in_value_type = typename InViewType::non_const_value_type; using out_value_type = typename OutViewType::non_const_value_type; - auto type = KokkosFFT::Impl::transform_type::type(); [[maybe_unused]] auto [in_extents, out_extents, fft_extents, howmany] = @@ -126,29 +108,22 @@ auto create_plan(const ExecutionSpace& exec_space, nz = fft_extents.at(2); int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1, std::multiplies<>()); - - plan = std::make_unique(); - hipfftResult hipfft_rt = hipfftPlan3d(&(*plan), nx, ny, nz, type); - KOKKOSFFT_THROW_IF(hipfft_rt != HIPFFT_SUCCESS, "hipfftPlan3d failed"); - - hipStream_t stream = exec_space.hip_stream(); - hipfft_rt = hipfftSetStream((*plan), stream); - KOKKOSFFT_THROW_IF(hipfft_rt != HIPFFT_SUCCESS, "hipfftSetStream failed"); + plan = std::make_unique(nx, ny, nz, type); + plan->commit(exec_space); return fft_size; } // batched transform, over ND Views template , std::nullptr_t> = nullptr> auto create_plan(const ExecutionSpace& exec_space, std::unique_ptr& plan, const InViewType& in, - const OutViewType& out, BufferViewType&, InfoType&, - Direction /*direction*/, axis_type axes, - shape_type s, bool is_inplace) { + const OutViewType& out, Direction /*direction*/, + axis_type axes, shape_type s, + bool is_inplace) { static_assert( KokkosFFT::Impl::are_operatable_views_v, @@ -179,27 +154,14 @@ auto create_plan(const ExecutionSpace& exec_space, // For the moment, considering the contiguous layout only int istride = 1, ostride = 1; - - plan = std::make_unique(); - hipfftResult hipfft_rt = hipfftPlanMany( - &(*plan), rank, fft_extents.data(), in_extents.data(), istride, idist, - out_extents.data(), ostride, odist, type, howmany); - - KOKKOSFFT_THROW_IF(hipfft_rt != HIPFFT_SUCCESS, "hipfftPlanMany failed"); - - hipStream_t stream = exec_space.hip_stream(); - hipfft_rt = hipfftSetStream((*plan), stream); - KOKKOSFFT_THROW_IF(hipfft_rt != HIPFFT_SUCCESS, "hipfftSetStream failed"); + plan = std::make_unique(rank, fft_extents.data(), in_extents.data(), + istride, idist, out_extents.data(), ostride, + odist, type, howmany); + plan->commit(exec_space); return fft_size; } -template , - std::nullptr_t> = nullptr> -void destroy_plan_and_info(std::unique_ptr& plan, InfoType&) { - hipfftDestroy(*plan); -} } // namespace Impl } // namespace KokkosFFT diff --git a/fft/src/KokkosFFT_HIP_transform.hpp b/fft/src/KokkosFFT_HIP_transform.hpp index 6e131150..ba103afc 100644 --- a/fft/src/KokkosFFT_HIP_transform.hpp +++ b/fft/src/KokkosFFT_HIP_transform.hpp @@ -7,48 +7,49 @@ #include #include "KokkosFFT_asserts.hpp" +#include "KokkosFFT_HIP_types.hpp" namespace KokkosFFT { namespace Impl { -template -inline void exec_plan(hipfftHandle& plan, hipfftReal* idata, - hipfftComplex* odata, int /*direction*/, Args...) { - hipfftResult hipfft_rt = hipfftExecR2C(plan, idata, odata); + +inline void exec_plan(const ScopedHIPfftPlan& scoped_plan, hipfftReal* idata, + hipfftComplex* odata, int /*direction*/) { + hipfftResult hipfft_rt = hipfftExecR2C(scoped_plan.plan(), idata, odata); KOKKOSFFT_THROW_IF(hipfft_rt != HIPFFT_SUCCESS, "hipfftExecR2C failed"); } -template -inline void exec_plan(hipfftHandle& plan, hipfftDoubleReal* idata, - hipfftDoubleComplex* odata, int /*direction*/, Args...) { - hipfftResult hipfft_rt = hipfftExecD2Z(plan, idata, odata); +inline void exec_plan(const ScopedHIPfftPlan& scoped_plan, + hipfftDoubleReal* idata, hipfftDoubleComplex* odata, + int /*direction*/) { + hipfftResult hipfft_rt = hipfftExecD2Z(scoped_plan.plan(), idata, odata); KOKKOSFFT_THROW_IF(hipfft_rt != HIPFFT_SUCCESS, "hipfftExecD2Z failed"); } -template -inline void exec_plan(hipfftHandle& plan, hipfftComplex* idata, - hipfftReal* odata, int /*direction*/, Args...) { - hipfftResult hipfft_rt = hipfftExecC2R(plan, idata, odata); +inline void exec_plan(const ScopedHIPfftPlan& scoped_plan, hipfftComplex* idata, + hipfftReal* odata, int /*direction*/) { + hipfftResult hipfft_rt = hipfftExecC2R(scoped_plan.plan(), idata, odata); KOKKOSFFT_THROW_IF(hipfft_rt != HIPFFT_SUCCESS, "hipfftExecC2R failed"); } -template -inline void exec_plan(hipfftHandle& plan, hipfftDoubleComplex* idata, - hipfftDoubleReal* odata, int /*direction*/, Args...) { - hipfftResult hipfft_rt = hipfftExecZ2D(plan, idata, odata); +inline void exec_plan(const ScopedHIPfftPlan& scoped_plan, + hipfftDoubleComplex* idata, hipfftDoubleReal* odata, + int /*direction*/) { + hipfftResult hipfft_rt = hipfftExecZ2D(scoped_plan.plan(), idata, odata); KOKKOSFFT_THROW_IF(hipfft_rt != HIPFFT_SUCCESS, "hipfftExecZ2D failed"); } -template -inline void exec_plan(hipfftHandle& plan, hipfftComplex* idata, - hipfftComplex* odata, int direction, Args...) { - hipfftResult hipfft_rt = hipfftExecC2C(plan, idata, odata, direction); +inline void exec_plan(const ScopedHIPfftPlan& scoped_plan, hipfftComplex* idata, + hipfftComplex* odata, int direction) { + hipfftResult hipfft_rt = + hipfftExecC2C(scoped_plan.plan(), idata, odata, direction); KOKKOSFFT_THROW_IF(hipfft_rt != HIPFFT_SUCCESS, "hipfftExecC2C failed"); } -template -inline void exec_plan(hipfftHandle& plan, hipfftDoubleComplex* idata, - hipfftDoubleComplex* odata, int direction, Args...) { - hipfftResult hipfft_rt = hipfftExecZ2Z(plan, idata, odata, direction); +inline void exec_plan(const ScopedHIPfftPlan& scoped_plan, + hipfftDoubleComplex* idata, hipfftDoubleComplex* odata, + int direction) { + hipfftResult hipfft_rt = + hipfftExecZ2Z(scoped_plan.plan(), idata, odata, direction); KOKKOSFFT_THROW_IF(hipfft_rt != HIPFFT_SUCCESS, "hipfftExecZ2Z failed"); } } // namespace Impl diff --git a/fft/src/KokkosFFT_HIP_types.hpp b/fft/src/KokkosFFT_HIP_types.hpp index 38663905..4be4d397 100644 --- a/fft/src/KokkosFFT_HIP_types.hpp +++ b/fft/src/KokkosFFT_HIP_types.hpp @@ -6,7 +6,13 @@ #define KOKKOSFFT_HIP_TYPES_HPP #include +#include #include "KokkosFFT_common_types.hpp" +#include "KokkosFFT_asserts.hpp" + +#if defined(ENABLE_HOST_AND_DEVICE) +#include "KokkosFFT_FFTW_Types.hpp" +#endif // Check the size of complex type static_assert(sizeof(hipfftComplex) == sizeof(Kokkos::complex)); @@ -15,27 +21,59 @@ static_assert(alignof(hipfftComplex) <= alignof(Kokkos::complex)); static_assert(sizeof(hipfftDoubleComplex) == sizeof(Kokkos::complex)); static_assert(alignof(hipfftDoubleComplex) <= alignof(Kokkos::complex)); -#ifdef ENABLE_HOST_AND_DEVICE -#include -#include "KokkosFFT_utils.hpp" -static_assert(sizeof(fftwf_complex) == sizeof(Kokkos::complex)); -static_assert(alignof(fftwf_complex) <= alignof(Kokkos::complex)); - -static_assert(sizeof(fftw_complex) == sizeof(Kokkos::complex)); -static_assert(alignof(fftw_complex) <= alignof(Kokkos::complex)); -#endif - namespace KokkosFFT { namespace Impl { using FFTDirectionType = int; -// Unused -template -using FFTInfoType = int; +/// \brief A class that wraps hipfft for RAII +struct ScopedHIPfftPlan { + private: + hipfftHandle m_plan; + + public: + ScopedHIPfftPlan(int nx, hipfftType type, int batch) { + hipfftResult hipfft_rt = hipfftPlan1d(&m_plan, nx, type, batch); + KOKKOSFFT_THROW_IF(hipfft_rt != HIPFFT_SUCCESS, "hipfftPlan1d failed"); + } + + ScopedHIPfftPlan(int nx, int ny, hipfftType type) { + hipfftResult hipfft_rt = hipfftPlan2d(&m_plan, nx, ny, type); + KOKKOSFFT_THROW_IF(hipfft_rt != HIPFFT_SUCCESS, "hipfftPlan2d failed"); + } -#ifdef ENABLE_HOST_AND_DEVICE -enum class FFTWTransformType { R2C, D2Z, C2R, Z2D, C2C, Z2Z }; + ScopedHIPfftPlan(int nx, int ny, int nz, hipfftType type) { + hipfftResult hipfft_rt = hipfftPlan3d(&m_plan, nx, ny, nz, type); + KOKKOSFFT_THROW_IF(hipfft_rt != HIPFFT_SUCCESS, "hipfftPlan3d failed"); + } + + ScopedHIPfftPlan(int rank, int *n, int *inembed, int istride, int idist, + int *onembed, int ostride, int odist, hipfftType type, + int batch) { + hipfftResult hipfft_rt = + hipfftPlanMany(&m_plan, rank, n, inembed, istride, idist, onembed, + ostride, odist, type, batch); + KOKKOSFFT_THROW_IF(hipfft_rt != HIPFFT_SUCCESS, "hipfftPlanMany failed"); + } + + ~ScopedHIPfftPlan() noexcept { + hipfftResult hipfft_rt = hipfftDestroy(m_plan); + if (hipfft_rt != HIPFFT_SUCCESS) Kokkos::abort("hipfftDestroy failed"); + } + + ScopedHIPfftPlan() = delete; + ScopedHIPfftPlan(const ScopedHIPfftPlan &) = delete; + ScopedHIPfftPlan &operator=(const ScopedHIPfftPlan &) = delete; + ScopedHIPfftPlan &operator=(ScopedHIPfftPlan &&) = delete; + ScopedHIPfftPlan(ScopedHIPfftPlan &&) = delete; + + hipfftHandle plan() const noexcept { return m_plan; } + void commit(const Kokkos::HIP &exec_space) const { + hipfftResult hipfft_rt = hipfftSetStream(m_plan, exec_space.hip_stream()); + KOKKOSFFT_THROW_IF(hipfft_rt != HIPFFT_SUCCESS, "hipfftSetStream failed"); + } +}; +#if defined(ENABLE_HOST_AND_DEVICE) template struct FFTDataType { using float32 = @@ -52,15 +90,6 @@ struct FFTDataType { hipfftDoubleComplex, fftw_complex>; }; -template -struct FFTPlanType { - using fftwHandle = std::conditional_t< - std::is_same_v, float>, - fftwf_plan, fftw_plan>; - using type = std::conditional_t, - hipfftHandle, fftwHandle>; -}; - template using TransformType = std::conditional_t, hipfftType, @@ -136,6 +165,14 @@ struct transform_type, } }; +template +struct FFTPlanType { + using fftw_plan_type = ScopedFFTWPlan; + using hipfft_plan_type = ScopedHIPfftPlan; + using type = std::conditional_t, + hipfft_plan_type, fftw_plan_type>; +}; + template auto direction_type(Direction direction) { static constexpr FFTDirectionType FORWARD = @@ -155,11 +192,6 @@ struct FFTDataType { using complex128 = hipfftDoubleComplex; }; -template -struct FFTPlanType { - using type = hipfftHandle; -}; - template using TransformType = hipfftType; @@ -197,6 +229,11 @@ struct transform_type, static constexpr hipfftType type() { return m_type; }; }; +template +struct FFTPlanType { + using type = ScopedHIPfftPlan; +}; + template auto direction_type(Direction direction) { return direction == Direction::forward ? HIPFFT_FORWARD : HIPFFT_BACKWARD; diff --git a/fft/src/KokkosFFT_Host_plans.hpp b/fft/src/KokkosFFT_Host_plans.hpp index 7b66522e..00042f74 100644 --- a/fft/src/KokkosFFT_Host_plans.hpp +++ b/fft/src/KokkosFFT_Host_plans.hpp @@ -12,33 +12,16 @@ namespace KokkosFFT { namespace Impl { - -template -void init_threads([[maybe_unused]] const ExecutionSpace& exec_space) { -#if defined(KOKKOS_ENABLE_OPENMP) || defined(KOKKOS_ENABLE_THREADS) - int nthreads = exec_space.concurrency(); - - if constexpr (std::is_same_v) { - fftwf_init_threads(); - fftwf_plan_with_nthreads(nthreads); - } else { - fftw_init_threads(); - fftw_plan_with_nthreads(nthreads); - } -#endif -} - // batched transform, over ND Views template , std::nullptr_t> = nullptr> auto create_plan(const ExecutionSpace& exec_space, std::unique_ptr& plan, const InViewType& in, - const OutViewType& out, BufferViewType&, InfoType&, - Direction direction, axis_type axes, - shape_type s, bool is_inplace) { + const OutViewType& out, Direction direction, + axis_type axes, shape_type s, + bool is_inplace) { static_assert( KokkosFFT::Impl::are_operatable_views_v, @@ -56,13 +39,6 @@ auto create_plan(const ExecutionSpace& exec_space, using out_value_type = typename OutViewType::non_const_value_type; const int rank = fft_rank; - init_threads>( - exec_space); - - constexpr auto type = - KokkosFFT::Impl::transform_type::type(); auto [in_extents, out_extents, fft_extents, howmany] = KokkosFFT::Impl::get_extents(in, out, axes, s, is_inplace); int idist = std::accumulate(in_extents.begin(), in_extents.end(), 1, @@ -82,46 +58,14 @@ auto create_plan(const ExecutionSpace& exec_space, [[maybe_unused]] auto sign = KokkosFFT::Impl::direction_type(direction); - plan = std::make_unique(); - if constexpr (type == KokkosFFT::Impl::FFTWTransformType::R2C) { - *plan = fftwf_plan_many_dft_r2c( - rank, fft_extents.data(), howmany, idata, in_extents.data(), istride, - idist, odata, out_extents.data(), ostride, odist, FFTW_ESTIMATE); - } else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::D2Z) { - *plan = fftw_plan_many_dft_r2c( - rank, fft_extents.data(), howmany, idata, in_extents.data(), istride, - idist, odata, out_extents.data(), ostride, odist, FFTW_ESTIMATE); - } else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::C2R) { - *plan = fftwf_plan_many_dft_c2r( - rank, fft_extents.data(), howmany, idata, in_extents.data(), istride, - idist, odata, out_extents.data(), ostride, odist, FFTW_ESTIMATE); - } else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::Z2D) { - *plan = fftw_plan_many_dft_c2r( - rank, fft_extents.data(), howmany, idata, in_extents.data(), istride, - idist, odata, out_extents.data(), ostride, odist, FFTW_ESTIMATE); - } else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::C2C) { - *plan = fftwf_plan_many_dft( - rank, fft_extents.data(), howmany, idata, in_extents.data(), istride, - idist, odata, out_extents.data(), ostride, odist, sign, FFTW_ESTIMATE); - } else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::Z2Z) { - *plan = fftw_plan_many_dft( - rank, fft_extents.data(), howmany, idata, in_extents.data(), istride, - idist, odata, out_extents.data(), ostride, odist, sign, FFTW_ESTIMATE); - } + plan = std::make_unique(exec_space, rank, fft_extents.data(), + howmany, idata, in_extents.data(), istride, + idist, odata, out_extents.data(), ostride, + odist, sign, FFTW_ESTIMATE); return fft_size; } -template , std::nullptr_t> = - nullptr> -void destroy_plan_and_info(std::unique_ptr& plan, InfoType&) { - if constexpr (std::is_same_v) { - fftwf_destroy_plan(*plan); - } else { - fftw_destroy_plan(*plan); - } -} } // namespace Impl } // namespace KokkosFFT diff --git a/fft/src/KokkosFFT_Host_transform.hpp b/fft/src/KokkosFFT_Host_transform.hpp index 4dfc04bb..33db513e 100644 --- a/fft/src/KokkosFFT_Host_transform.hpp +++ b/fft/src/KokkosFFT_Host_transform.hpp @@ -9,40 +9,41 @@ namespace KokkosFFT { namespace Impl { -template -void exec_plan(PlanType& plan, float* idata, fftwf_complex* odata, - int /*direction*/, Args...) { - fftwf_execute_dft_r2c(plan, idata, odata); + +template +void exec_plan(const ScopedPlanType& scoped_plan, float* idata, + fftwf_complex* odata, int /*direction*/) { + fftwf_execute_dft_r2c(scoped_plan.plan(), idata, odata); } -template -void exec_plan(PlanType& plan, double* idata, fftw_complex* odata, - int /*direction*/, Args...) { - fftw_execute_dft_r2c(plan, idata, odata); +template +void exec_plan(const ScopedPlanType& scoped_plan, double* idata, + fftw_complex* odata, int /*direction*/) { + fftw_execute_dft_r2c(scoped_plan.plan(), idata, odata); } -template -void exec_plan(PlanType& plan, fftwf_complex* idata, float* odata, - int /*direction*/, Args...) { - fftwf_execute_dft_c2r(plan, idata, odata); +template +void exec_plan(const ScopedPlanType& scoped_plan, fftwf_complex* idata, + float* odata, int /*direction*/) { + fftwf_execute_dft_c2r(scoped_plan.plan(), idata, odata); } -template -void exec_plan(PlanType& plan, fftw_complex* idata, double* odata, - int /*direction*/, Args...) { - fftw_execute_dft_c2r(plan, idata, odata); +template +void exec_plan(const ScopedPlanType& scoped_plan, fftw_complex* idata, + double* odata, int /*direction*/) { + fftw_execute_dft_c2r(scoped_plan.plan(), idata, odata); } -template -void exec_plan(PlanType& plan, fftwf_complex* idata, fftwf_complex* odata, - int /*direction*/, Args...) { - fftwf_execute_dft(plan, idata, odata); +template +void exec_plan(const ScopedPlanType& scoped_plan, fftwf_complex* idata, + fftwf_complex* odata, int /*direction*/) { + fftwf_execute_dft(scoped_plan.plan(), idata, odata); } -template -void exec_plan(PlanType plan, fftw_complex* idata, fftw_complex* odata, - int /*direction*/, Args...) { - fftw_execute_dft(plan, idata, odata); +template +void exec_plan(const ScopedPlanType& scoped_plan, fftw_complex* idata, + fftw_complex* odata, int /*direction*/) { + fftw_execute_dft(scoped_plan.plan(), idata, odata); } } // namespace Impl } // namespace KokkosFFT diff --git a/fft/src/KokkosFFT_Host_types.hpp b/fft/src/KokkosFFT_Host_types.hpp index a8a24d75..85e754b2 100644 --- a/fft/src/KokkosFFT_Host_types.hpp +++ b/fft/src/KokkosFFT_Host_types.hpp @@ -5,27 +5,12 @@ #ifndef KOKKOSFFT_HOST_TYPES_HPP #define KOKKOSFFT_HOST_TYPES_HPP -#include -#include "KokkosFFT_common_types.hpp" -#include "KokkosFFT_utils.hpp" - -// Check the size of complex type -static_assert(sizeof(fftwf_complex) == sizeof(Kokkos::complex)); -static_assert(alignof(fftwf_complex) <= alignof(Kokkos::complex)); - -static_assert(sizeof(fftw_complex) == sizeof(Kokkos::complex)); -static_assert(alignof(fftw_complex) <= alignof(Kokkos::complex)); +#include "KokkosFFT_FFTW_Types.hpp" namespace KokkosFFT { namespace Impl { using FFTDirectionType = int; -// Unused -template -using FFTInfoType = int; - -enum class FFTWTransformType { R2C, D2Z, C2R, Z2D, C2C, Z2Z }; - template struct FFTDataType { using float32 = float; @@ -34,52 +19,15 @@ struct FFTDataType { using complex128 = fftw_complex; }; -template -struct FFTPlanType { - using type = std::conditional_t< - std::is_same_v, float>, - fftwf_plan, fftw_plan>; -}; - template using TransformType = FFTWTransformType; -// Define fft transform types template -struct transform_type { - static_assert(std::is_same_v, - "Real to real transform is unavailable"); -}; - -template -struct transform_type> { - static_assert(std::is_same_v, - "T1 and T2 should have the same precision"); - static constexpr FFTWTransformType m_type = std::is_same_v - ? FFTWTransformType::R2C - : FFTWTransformType::D2Z; - static constexpr FFTWTransformType type() { return m_type; }; -}; +using transform_type = fftw_transform_type; template -struct transform_type, T2> { - static_assert(std::is_same_v, - "T1 and T2 should have the same precision"); - static constexpr FFTWTransformType m_type = std::is_same_v - ? FFTWTransformType::C2R - : FFTWTransformType::Z2D; - static constexpr FFTWTransformType type() { return m_type; }; -}; - -template -struct transform_type, - Kokkos::complex> { - static_assert(std::is_same_v, - "T1 and T2 should have the same precision"); - static constexpr FFTWTransformType m_type = std::is_same_v - ? FFTWTransformType::C2C - : FFTWTransformType::Z2Z; - static constexpr FFTWTransformType type() { return m_type; }; +struct FFTPlanType { + using type = ScopedFFTWPlan; }; template diff --git a/fft/src/KokkosFFT_Plans.hpp b/fft/src/KokkosFFT_Plans.hpp index 3bba0fb2..7f6b98af 100644 --- a/fft/src/KokkosFFT_Plans.hpp +++ b/fft/src/KokkosFFT_Plans.hpp @@ -5,7 +5,7 @@ /// \file KokkosFFT_Plans.hpp /// \brief Wrapping fft plans of different fft libraries /// -/// This file provides KokkosFFT::Impl::Plan. +/// This file provides KokkosFFT::Plan. /// This implements a local (no MPI) interface for fft plans #ifndef KOKKOSFFT_PLANS_HPP @@ -22,7 +22,7 @@ #if defined(KOKKOS_ENABLE_CUDA) #include "KokkosFFT_Cuda_plans.hpp" #include "KokkosFFT_Cuda_transform.hpp" -#ifdef ENABLE_HOST_AND_DEVICE +#if defined(ENABLE_HOST_AND_DEVICE) #include "KokkosFFT_Host_plans.hpp" #include "KokkosFFT_Host_transform.hpp" #endif @@ -34,14 +34,14 @@ #include "KokkosFFT_HIP_plans.hpp" #include "KokkosFFT_HIP_transform.hpp" #endif -#ifdef ENABLE_HOST_AND_DEVICE +#if defined(ENABLE_HOST_AND_DEVICE) #include "KokkosFFT_Host_plans.hpp" #include "KokkosFFT_Host_transform.hpp" #endif #elif defined(KOKKOS_ENABLE_SYCL) #include "KokkosFFT_SYCL_plans.hpp" #include "KokkosFFT_SYCL_transform.hpp" -#ifdef ENABLE_HOST_AND_DEVICE +#if defined(ENABLE_HOST_AND_DEVICE) #include "KokkosFFT_Host_plans.hpp" #include "KokkosFFT_Host_transform.hpp" #endif @@ -88,19 +88,12 @@ class Plan { typename KokkosFFT::Impl::FFTPlanType::type; - //! The type of fft info (used for rocfft only) - using fft_info_type = typename KokkosFFT::Impl::FFTInfoType; - //! The type of fft size using fft_size_type = std::size_t; //! The type of map for transpose using map_type = axis_type; - //! Naive 1D View for work buffer - using BufferViewType = - Kokkos::View*, layout_type, execSpace>; - //! The type of extents of input/output views using extents_type = shape_type; @@ -111,9 +104,6 @@ class Plan { //! Dynamically allocatable fft plan. std::unique_ptr m_plan; - //! fft info - fft_info_type m_info; - //! fft size fft_size_type m_fft_size = 1; @@ -143,9 +133,6 @@ class Plan { extents_type m_in_extents, m_out_extents; ///@} - //! Internal work buffer (for rocfft) - BufferViewType m_buffer; - public: /// \brief Constructor /// @@ -209,9 +196,9 @@ class Plan { KOKKOSFFT_THROW_IF(m_is_inplace && m_is_crop_or_pad_needed, "In-place transform is not supported with reshape. " "Please use out-of-place transform."); - m_fft_size = KokkosFFT::Impl::create_plan(exec_space, m_plan, in, out, - m_buffer, m_info, direction, - m_axes, s, m_is_inplace); + + m_fft_size = KokkosFFT::Impl::create_plan( + exec_space, m_plan, in, out, direction, m_axes, s, m_is_inplace); } /// \brief Constructor for multidimensional FFT @@ -272,16 +259,13 @@ class Plan { KOKKOSFFT_THROW_IF(m_is_inplace && m_is_crop_or_pad_needed, "In-place transform is not supported with reshape. " "Please use out-of-place transform."); - m_fft_size = - KokkosFFT::Impl::create_plan(exec_space, m_plan, in, out, m_buffer, - m_info, direction, axes, s, m_is_inplace); - } - ~Plan() { - KokkosFFT::Impl::destroy_plan_and_info(m_plan, m_info); + m_fft_size = KokkosFFT::Impl::create_plan(exec_space, m_plan, in, out, + direction, axes, s, m_is_inplace); } + ~Plan() noexcept = default; + Plan() = delete; Plan(const Plan&) = delete; Plan& operator=(const Plan&) = delete; @@ -358,7 +342,7 @@ class Plan { auto const direction = KokkosFFT::Impl::direction_type(m_direction); - KokkosFFT::Impl::exec_plan(*m_plan, idata, odata, direction, m_info); + KokkosFFT::Impl::exec_plan(*m_plan, idata, odata, direction); if constexpr (KokkosFFT::Impl::is_complex_v && KokkosFFT::Impl::is_real_v) { diff --git a/fft/src/KokkosFFT_ROCM_plans.hpp b/fft/src/KokkosFFT_ROCM_plans.hpp index e1b115e9..56c1d6c3 100644 --- a/fft/src/KokkosFFT_ROCM_plans.hpp +++ b/fft/src/KokkosFFT_ROCM_plans.hpp @@ -5,8 +5,6 @@ #ifndef KOKKOSFFT_ROCM_PLANS_HPP #define KOKKOSFFT_ROCM_PLANS_HPP -#include -#include #include "KokkosFFT_ROCM_types.hpp" #include "KokkosFFT_Extents.hpp" #include "KokkosFFT_traits.hpp" @@ -15,83 +13,15 @@ namespace KokkosFFT { namespace Impl { -// Helper to get input and output array type and direction from transform type -template -auto get_in_out_array_type(TransformType type, Direction direction) { - rocfft_array_type in_array_type, out_array_type; - rocfft_transform_type fft_direction; - - if (type == FFTWTransformType::C2C || type == FFTWTransformType::Z2Z) { - in_array_type = rocfft_array_type_complex_interleaved; - out_array_type = rocfft_array_type_complex_interleaved; - fft_direction = direction == Direction::forward - ? rocfft_transform_type_complex_forward - : rocfft_transform_type_complex_inverse; - } else if (type == FFTWTransformType::R2C || type == FFTWTransformType::D2Z) { - in_array_type = rocfft_array_type_real; - out_array_type = rocfft_array_type_hermitian_interleaved; - fft_direction = rocfft_transform_type_real_forward; - } else if (type == FFTWTransformType::C2R || type == FFTWTransformType::Z2D) { - in_array_type = rocfft_array_type_hermitian_interleaved; - out_array_type = rocfft_array_type_real; - fft_direction = rocfft_transform_type_real_inverse; - } - - return std::tuple( - {in_array_type, out_array_type, fft_direction}); -}; - -template -rocfft_precision get_in_out_array_type() { - return std::is_same_v, - float> - ? rocfft_precision_single - : rocfft_precision_double; -} - -// Helper to convert the integer type of vectors -template -auto convert_int_type_and_reverse(std::vector& in) - -> std::vector { - std::vector out(in.size()); - std::transform( - in.begin(), in.end(), out.begin(), - [](const InType v) -> OutType { return static_cast(v); }); - - std::reverse(out.begin(), out.end()); - return out; -} - -// Helper to compute strides from extents -// (n0, n1, n2) -> (1, n0, n0*n1) -// (n0, n1) -> (1, n0) -// (n0) -> (1) -template -auto compute_strides(const std::vector& extents) - -> std::vector { - std::vector out = {1}; - auto reversed_extents = extents; - std::reverse(reversed_extents.begin(), reversed_extents.end()); - - for (std::size_t i = 1; i < reversed_extents.size(); i++) { - out.push_back(static_cast(reversed_extents.at(i - 1)) * - out.at(i - 1)); - } - - return out; -} // batched transform, over ND Views template , std::nullptr_t> = nullptr> auto create_plan(const ExecutionSpace& exec_space, std::unique_ptr& plan, const InViewType& in, - const OutViewType& out, BufferViewType& buffer, - InfoType& execution_info, Direction direction, + const OutViewType& out, Direction direction, axis_type axes, shape_type s, bool is_inplace) { static_assert( @@ -114,101 +44,19 @@ auto create_plan(const ExecutionSpace& exec_space, out_value_type>::type(); auto [in_extents, out_extents, fft_extents, howmany] = KokkosFFT::Impl::get_extents(in, out, axes, s, is_inplace); - int idist = std::accumulate(in_extents.begin(), in_extents.end(), 1, - std::multiplies<>()); - int odist = std::accumulate(out_extents.begin(), out_extents.end(), 1, - std::multiplies<>()); - int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1, - std::multiplies<>()); - - // For the moment, considering the contiguous layout only - // Create plan - auto in_strides = compute_strides(in_extents); - auto out_strides = compute_strides(out_extents); - auto reversed_fft_extents = - convert_int_type_and_reverse(fft_extents); - - // Create the description - rocfft_plan_description description; - rocfft_status status = rocfft_plan_description_create(&description); - KOKKOSFFT_THROW_IF(status != rocfft_status_success, - "rocfft_plan_description_create failed"); - - auto [in_array_type, out_array_type, fft_direction] = - get_in_out_array_type(type, direction); - rocfft_precision precision = get_in_out_array_type(); - - status = rocfft_plan_description_set_data_layout( - description, // description handle - in_array_type, // input array type - out_array_type, // output array type - nullptr, // offsets to start of input data - nullptr, // offsets to start of output data - in_strides.size(), // input stride length - in_strides.data(), // input stride data - idist, // input batch distance - out_strides.size(), // output stride length - out_strides.data(), // output stride data - odist); // output batch distance - KOKKOSFFT_THROW_IF(status != rocfft_status_success, - "rocfft_plan_description_set_data_layout failed"); - - // Out-of-place transform - const rocfft_result_placement place = - is_inplace ? rocfft_placement_inplace : rocfft_placement_notinplace; // Create a plan - plan = std::make_unique(); - status = rocfft_plan_create(&(*plan), place, fft_direction, precision, - reversed_fft_extents.size(), // Dimension - reversed_fft_extents.data(), // Lengths - howmany, // Number of transforms - description // Description - ); - KOKKOSFFT_THROW_IF(status != rocfft_status_success, - "rocfft_plan_create failed"); - - // Prepare workbuffer and set execution information - status = rocfft_execution_info_create(&execution_info); - KOKKOSFFT_THROW_IF(status != rocfft_status_success, - "rocfft_execution_info_create failed"); - - // set stream - // NOTE: The stream must be of type hipStream_t. - // It is an error to pass the address of a hipStream_t object. - hipStream_t stream = exec_space.hip_stream(); - status = rocfft_execution_info_set_stream(execution_info, stream); - KOKKOSFFT_THROW_IF(status != rocfft_status_success, - "rocfft_execution_info_set_stream failed"); + plan = std::make_unique(type, in_extents, out_extents, fft_extents, + howmany, direction, is_inplace); + plan->commit(exec_space); - std::size_t workbuffersize = 0; - status = rocfft_plan_get_work_buffer_size(*plan, &workbuffersize); - KOKKOSFFT_THROW_IF(status != rocfft_status_success, - "rocfft_plan_get_work_buffer_size failed"); - - if (workbuffersize > 0) { - buffer = BufferViewType("work_buffer", workbuffersize); - status = rocfft_execution_info_set_work_buffer( - execution_info, (void*)buffer.data(), workbuffersize); - KOKKOSFFT_THROW_IF(status != rocfft_status_success, - "rocfft_execution_info_set_work_buffer failed"); - } - - status = rocfft_plan_description_destroy(description); - KOKKOSFFT_THROW_IF(status != rocfft_status_success, - "rocfft_plan_description_destroy failed"); + // Calculate the total size of the FFT + int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1, + std::multiplies<>()); return fft_size; } -template , - std::nullptr_t> = nullptr> -void destroy_plan_and_info(std::unique_ptr& plan, - InfoType& execution_info) { - rocfft_execution_info_destroy(execution_info); - rocfft_plan_destroy(*plan); -} } // namespace Impl } // namespace KokkosFFT diff --git a/fft/src/KokkosFFT_ROCM_transform.hpp b/fft/src/KokkosFFT_ROCM_transform.hpp index 2c6d50b8..04d1ead8 100644 --- a/fft/src/KokkosFFT_ROCM_transform.hpp +++ b/fft/src/KokkosFFT_ROCM_transform.hpp @@ -8,63 +8,69 @@ #include #include #include "KokkosFFT_asserts.hpp" +#include "KokkosFFT_ROCM_types.hpp" namespace KokkosFFT { namespace Impl { -inline void exec_plan(rocfft_plan& plan, float* idata, - std::complex* odata, int /*direction*/, - const rocfft_execution_info& execution_info) { +inline void exec_plan(const ScopedRocfftPlan& scoped_plan, float* idata, + std::complex* odata, int /*direction*/) { rocfft_status status = - rocfft_execute(plan, (void**)&idata, (void**)&odata, execution_info); + rocfft_execute(scoped_plan.plan(), (void**)&idata, (void**)&odata, + scoped_plan.execution_info()); KOKKOSFFT_THROW_IF(status != rocfft_status_success, "rocfft_execute for R2C failed"); } -inline void exec_plan(rocfft_plan& plan, double* idata, - std::complex* odata, int /*direction*/, - const rocfft_execution_info& execution_info) { +inline void exec_plan(const ScopedRocfftPlan& scoped_plan, + double* idata, std::complex* odata, + int /*direction*/) { rocfft_status status = - rocfft_execute(plan, (void**)&idata, (void**)&odata, execution_info); + rocfft_execute(scoped_plan.plan(), (void**)&idata, (void**)&odata, + scoped_plan.execution_info()); KOKKOSFFT_THROW_IF(status != rocfft_status_success, "rocfft_execute for D2Z failed"); } -inline void exec_plan(rocfft_plan& plan, std::complex* idata, - float* odata, int /*direction*/, - const rocfft_execution_info& execution_info) { +inline void exec_plan( + const ScopedRocfftPlan>& scoped_plan, + std::complex* idata, float* odata, int /*direction*/) { rocfft_status status = - rocfft_execute(plan, (void**)&idata, (void**)&odata, execution_info); + rocfft_execute(scoped_plan.plan(), (void**)&idata, (void**)&odata, + scoped_plan.execution_info()); KOKKOSFFT_THROW_IF(status != rocfft_status_success, "rocfft_execute for C2R failed"); } -inline void exec_plan(rocfft_plan& plan, std::complex* idata, - double* odata, int /*direction*/, - const rocfft_execution_info& execution_info) { +inline void exec_plan( + const ScopedRocfftPlan>& scoped_plan, + std::complex* idata, double* odata, int /*direction*/) { rocfft_status status = - rocfft_execute(plan, (void**)&idata, (void**)&odata, execution_info); + rocfft_execute(scoped_plan.plan(), (void**)&idata, (void**)&odata, + scoped_plan.execution_info()); KOKKOSFFT_THROW_IF(status != rocfft_status_success, "rocfft_execute for Z2D failed"); } -inline void exec_plan(rocfft_plan& plan, std::complex* idata, - std::complex* odata, int /*direction*/, - const rocfft_execution_info& execution_info) { +inline void exec_plan( + const ScopedRocfftPlan>& scoped_plan, + std::complex* idata, std::complex* odata, int /*direction*/) { rocfft_status status = - rocfft_execute(plan, (void**)&idata, (void**)&odata, execution_info); + rocfft_execute(scoped_plan.plan(), (void**)&idata, (void**)&odata, + scoped_plan.execution_info()); KOKKOSFFT_THROW_IF(status != rocfft_status_success, "rocfft_execute for C2C failed"); } -inline void exec_plan(rocfft_plan& plan, std::complex* idata, - std::complex* odata, int /*direction*/, - const rocfft_execution_info& execution_info) { +inline void exec_plan( + const ScopedRocfftPlan>& scoped_plan, + std::complex* idata, std::complex* odata, + int /*direction*/) { rocfft_status status = - rocfft_execute(plan, (void**)&idata, (void**)&odata, execution_info); + rocfft_execute(scoped_plan.plan(), (void**)&idata, (void**)&odata, + scoped_plan.execution_info()); KOKKOSFFT_THROW_IF(status != rocfft_status_success, "rocfft_execute for Z2Z failed"); } - } // namespace Impl } // namespace KokkosFFT diff --git a/fft/src/KokkosFFT_ROCM_types.hpp b/fft/src/KokkosFFT_ROCM_types.hpp index 60af7e57..ed8b06a8 100644 --- a/fft/src/KokkosFFT_ROCM_types.hpp +++ b/fft/src/KokkosFFT_ROCM_types.hpp @@ -5,9 +5,17 @@ #ifndef KOKKOSFFT_ROCM_TYPES_HPP #define KOKKOSFFT_ROCM_TYPES_HPP +#include +#include #include #include +#include #include "KokkosFFT_common_types.hpp" +#include "KokkosFFT_traits.hpp" +#include "KokkosFFT_asserts.hpp" +#if defined(ENABLE_HOST_AND_DEVICE) +#include "KokkosFFT_FFTW_Types.hpp" +#endif // Check the size of complex type static_assert(sizeof(std::complex) == sizeof(Kokkos::complex)); @@ -17,27 +25,255 @@ static_assert(sizeof(std::complex) == sizeof(Kokkos::complex)); static_assert(alignof(std::complex) <= alignof(Kokkos::complex)); -#ifdef ENABLE_HOST_AND_DEVICE -#include -#include "KokkosFFT_utils.hpp" -static_assert(sizeof(fftwf_complex) == sizeof(Kokkos::complex)); -static_assert(alignof(fftwf_complex) <= alignof(Kokkos::complex)); - -static_assert(sizeof(fftw_complex) == sizeof(Kokkos::complex)); -static_assert(alignof(fftw_complex) <= alignof(Kokkos::complex)); -#endif - namespace KokkosFFT { namespace Impl { using FFTDirectionType = int; constexpr FFTDirectionType ROCFFT_FORWARD = 1; constexpr FFTDirectionType ROCFFT_BACKWARD = -1; +#if !defined(ENABLE_HOST_AND_DEVICE) enum class FFTWTransformType { R2C, D2Z, C2R, Z2D, C2C, Z2Z }; +#endif template using TransformType = FFTWTransformType; +/// \brief A class that wraps rocfft_plan_description for RAII +struct ScopedRocfftPlanDescription { + private: + rocfft_plan_description m_description; + + public: + ScopedRocfftPlanDescription() { + rocfft_status status = rocfft_plan_description_create(&m_description); + KOKKOSFFT_THROW_IF(status != rocfft_status_success, + "rocfft_plan_description_create failed"); + } + ~ScopedRocfftPlanDescription() noexcept { + rocfft_status status = rocfft_plan_description_destroy(m_description); + if (status != rocfft_status_success) + Kokkos::abort("rocfft_plan_description_destroy failed"); + } + + ScopedRocfftPlanDescription(const ScopedRocfftPlanDescription &) = delete; + ScopedRocfftPlanDescription &operator=(const ScopedRocfftPlanDescription &) = + delete; + ScopedRocfftPlanDescription &operator=(ScopedRocfftPlanDescription &&) = + delete; + ScopedRocfftPlanDescription(ScopedRocfftPlanDescription &&) = delete; + + rocfft_plan_description description() const noexcept { return m_description; } +}; + +/// \brief A class that wraps rocfft_execution_info for RAII +template +struct ScopedRocfftExecutionInfo { + private: + using BufferViewType = + Kokkos::View *, Kokkos::HIP>; + rocfft_execution_info m_execution_info; + + //! Internal work buffer + BufferViewType m_buffer; + + public: + ScopedRocfftExecutionInfo() { + // Prepare workbuffer and set execution information + rocfft_status status = rocfft_execution_info_create(&m_execution_info); + KOKKOSFFT_THROW_IF(status != rocfft_status_success, + "rocfft_execution_info_create failed"); + } + ~ScopedRocfftExecutionInfo() noexcept { + rocfft_status status = rocfft_execution_info_destroy(m_execution_info); + if (status != rocfft_status_success) + Kokkos::abort("rocfft_execution_info_destroy failed"); + } + + ScopedRocfftExecutionInfo(const ScopedRocfftExecutionInfo &) = delete; + ScopedRocfftExecutionInfo &operator=(const ScopedRocfftExecutionInfo &) = + delete; + ScopedRocfftExecutionInfo &operator=(ScopedRocfftExecutionInfo &&) = delete; + ScopedRocfftExecutionInfo(ScopedRocfftExecutionInfo &&) = delete; + + rocfft_execution_info execution_info() const noexcept { + return m_execution_info; + } + + void setup(const Kokkos::HIP &exec_space, std::size_t workbuffersize) { + // set stream + // NOTE: The stream must be of type hipStream_t. + // It is an error to pass the address of a hipStream_t object. + hipStream_t stream = exec_space.hip_stream(); + rocfft_status status = + rocfft_execution_info_set_stream(m_execution_info, stream); + KOKKOSFFT_THROW_IF(status != rocfft_status_success, + "rocfft_execution_info_set_stream failed"); + + // Set work buffer + if (workbuffersize > 0) { + m_buffer = BufferViewType("workbuffer", workbuffersize); + status = rocfft_execution_info_set_work_buffer( + m_execution_info, (void *)m_buffer.data(), workbuffersize); + KOKKOSFFT_THROW_IF(status != rocfft_status_success, + "rocfft_execution_info_set_work_buffer failed"); + } + } +}; + +/// \brief A class that wraps rocfft for RAII +template +struct ScopedRocfftPlan { + private: + using floating_point_type = KokkosFFT::Impl::base_floating_point_type; + using ScopedRocfftExecutionInfoType = + ScopedRocfftExecutionInfo; + rocfft_precision m_precision = std::is_same_v + ? rocfft_precision_single + : rocfft_precision_double; + rocfft_plan m_plan; + std::unique_ptr m_execution_info; + + public: + ScopedRocfftPlan(const FFTWTransformType transform_type, + const std::vector &in_extents, + const std::vector &out_extents, + const std::vector &fft_extents, int howmany, + Direction direction, bool is_inplace) { + auto [in_array_type, out_array_type, fft_direction] = + get_in_out_array_type(transform_type, direction); + + // Compute dist and strides from extents + int idist = std::accumulate(in_extents.begin(), in_extents.end(), 1, + std::multiplies<>()); + int odist = std::accumulate(out_extents.begin(), out_extents.end(), 1, + std::multiplies<>()); + + auto in_strides = compute_strides(in_extents); + auto out_strides = compute_strides(out_extents); + auto reversed_fft_extents = + convert_int_type_and_reverse(fft_extents); + + // Create a plan description + ScopedRocfftPlanDescription scoped_description; + rocfft_status status = rocfft_plan_description_set_data_layout( + scoped_description.description(), // description handle + in_array_type, // input array type + out_array_type, // output array type + nullptr, // offsets to start of input data + nullptr, // offsets to start of output data + in_strides.size(), // input stride length + in_strides.data(), // input stride data + idist, // input batch distance + out_strides.size(), // output stride length + out_strides.data(), // output stride data + odist); // output batch distance + + KOKKOSFFT_THROW_IF(status != rocfft_status_success, + "rocfft_plan_description_set_data_layout failed"); + + // inplace or Out-of-place transform + const rocfft_result_placement place = + is_inplace ? rocfft_placement_inplace : rocfft_placement_notinplace; + + // Create a plan + status = rocfft_plan_create(&m_plan, place, fft_direction, m_precision, + reversed_fft_extents.size(), // Dimension + reversed_fft_extents.data(), // Lengths + howmany, // Number of transforms + scoped_description.description() // Description + ); + KOKKOSFFT_THROW_IF(status != rocfft_status_success, + "rocfft_plan_create failed"); + } + ~ScopedRocfftPlan() noexcept { + rocfft_status status = rocfft_plan_destroy(m_plan); + if (status != rocfft_status_success) + Kokkos::abort("rocfft_plan_destroy failed"); + } + + ScopedRocfftPlan() = delete; + ScopedRocfftPlan(const ScopedRocfftPlan &) = delete; + ScopedRocfftPlan &operator=(const ScopedRocfftPlan &) = delete; + ScopedRocfftPlan &operator=(ScopedRocfftPlan &&) = delete; + ScopedRocfftPlan(ScopedRocfftPlan &&) = delete; + + rocfft_plan plan() const noexcept { return m_plan; } + rocfft_execution_info execution_info() const noexcept { + return m_execution_info->execution_info(); + } + + void commit(const Kokkos::HIP &exec_space) { + std::size_t workbuffersize = 0; + rocfft_status status = + rocfft_plan_get_work_buffer_size(m_plan, &workbuffersize); + KOKKOSFFT_THROW_IF(status != rocfft_status_success, + "rocfft_plan_get_work_buffer_size failed"); + + m_execution_info = std::make_unique(); + m_execution_info->setup(exec_space, workbuffersize); + } + + // Helper to get input and output array type and direction from transform type + auto get_in_out_array_type(FFTWTransformType type, Direction direction) { + rocfft_array_type in_array_type, out_array_type; + rocfft_transform_type fft_direction; + + if (type == FFTWTransformType::C2C || type == FFTWTransformType::Z2Z) { + in_array_type = rocfft_array_type_complex_interleaved; + out_array_type = rocfft_array_type_complex_interleaved; + fft_direction = direction == Direction::forward + ? rocfft_transform_type_complex_forward + : rocfft_transform_type_complex_inverse; + } else if (type == FFTWTransformType::R2C || + type == FFTWTransformType::D2Z) { + in_array_type = rocfft_array_type_real; + out_array_type = rocfft_array_type_hermitian_interleaved; + fft_direction = rocfft_transform_type_real_forward; + } else if (type == FFTWTransformType::C2R || + type == FFTWTransformType::Z2D) { + in_array_type = rocfft_array_type_hermitian_interleaved; + out_array_type = rocfft_array_type_real; + fft_direction = rocfft_transform_type_real_inverse; + } + + return std::tuple( + {in_array_type, out_array_type, fft_direction}); + }; + + // Helper to convert the integer type of vectors + template + auto convert_int_type_and_reverse(const std::vector &in) + -> std::vector { + std::vector out(in.size()); + std::transform( + in.cbegin(), in.cend(), out.begin(), + [](const InType v) -> OutType { return static_cast(v); }); + + std::reverse(out.begin(), out.end()); + return out; + } + + // Helper to compute strides from extents + // (n0, n1, n2) -> (1, n0, n0*n1) + // (n0, n1) -> (1, n0) + // (n0) -> (1) + template + auto compute_strides(const std::vector &extents) + -> std::vector { + std::vector out = {1}; + auto reversed_extents = extents; + std::reverse(reversed_extents.begin(), reversed_extents.end()); + + for (std::size_t i = 1; i < reversed_extents.size(); i++) { + out.push_back(static_cast(reversed_extents.at(i - 1)) * + out.at(i - 1)); + } + + return out; + } +}; + // Define fft transform types template struct transform_type { @@ -76,7 +312,7 @@ struct transform_type, static constexpr FFTWTransformType type() { return m_type; }; }; -#ifdef ENABLE_HOST_AND_DEVICE +#if defined(ENABLE_HOST_AND_DEVICE) template struct FFTDataType { @@ -92,18 +328,12 @@ struct FFTDataType { template struct FFTPlanType { - using fftwHandle = std::conditional_t< - std::is_same_v, float>, - fftwf_plan, fftw_plan>; + using fftw_plan_type = ScopedFFTWPlan; + using rocfft_plan_type = ScopedRocfftPlan; using type = std::conditional_t, - rocfft_plan, fftwHandle>; + rocfft_plan_type, fftw_plan_type>; }; -template -using FFTInfoType = - std::conditional_t, - rocfft_execution_info, int>; - template auto direction_type(Direction direction) { static constexpr FFTDirectionType FORWARD = @@ -126,12 +356,9 @@ struct FFTDataType { template struct FFTPlanType { - using type = rocfft_plan; + using type = ScopedRocfftPlan; }; -template -using FFTInfoType = rocfft_execution_info; - template auto direction_type(Direction direction) { return direction == Direction::forward ? ROCFFT_FORWARD : ROCFFT_BACKWARD; diff --git a/fft/src/KokkosFFT_SYCL_plans.hpp b/fft/src/KokkosFFT_SYCL_plans.hpp index 52c751ef..f386c269 100644 --- a/fft/src/KokkosFFT_SYCL_plans.hpp +++ b/fft/src/KokkosFFT_SYCL_plans.hpp @@ -46,15 +46,14 @@ auto compute_strides(std::vector& extents) -> std::vector { // batched transform, over ND Views template < typename ExecutionSpace, typename PlanType, typename InViewType, - typename OutViewType, typename BufferViewType, typename InfoType, - std::size_t fft_rank = 1, + typename OutViewType, std::size_t fft_rank = 1, std::enable_if_t, std::nullptr_t> = nullptr> auto create_plan(const ExecutionSpace& exec_space, std::unique_ptr& plan, const InViewType& in, - const OutViewType& out, BufferViewType&, InfoType&, - Direction /*direction*/, axis_type axes, - shape_type s, bool is_inplace) { + const OutViewType& out, Direction /*direction*/, + axis_type axes, shape_type s, + bool is_inplace) { static_assert( KokkosFFT::Impl::are_operatable_views_v, @@ -109,14 +108,6 @@ auto create_plan(const ExecutionSpace& exec_space, return fft_size; } - -template < - typename ExecutionSpace, typename PlanType, typename InfoType, - std::enable_if_t, - std::nullptr_t> = nullptr> -void destroy_plan_and_info(std::unique_ptr&, InfoType&) { - // In oneMKL, plans are destroyed by destructor -} } // namespace Impl } // namespace KokkosFFT diff --git a/fft/src/KokkosFFT_SYCL_types.hpp b/fft/src/KokkosFFT_SYCL_types.hpp index df359aca..9c6cb86b 100644 --- a/fft/src/KokkosFFT_SYCL_types.hpp +++ b/fft/src/KokkosFFT_SYCL_types.hpp @@ -12,6 +12,10 @@ #include "KokkosFFT_common_types.hpp" #include "KokkosFFT_utils.hpp" +#if defined(ENABLE_HOST_AND_DEVICE) +#include "KokkosFFT_FFTW_Types.hpp" +#endif + // Check the size of complex type // [TO DO] I guess this kind of test is already made by Kokkos itself static_assert(sizeof(std::complex) == sizeof(Kokkos::complex)); @@ -21,26 +25,15 @@ static_assert(sizeof(std::complex) == sizeof(Kokkos::complex)); static_assert(alignof(std::complex) <= alignof(Kokkos::complex)); -#ifdef ENABLE_HOST_AND_DEVICE -#include -static_assert(sizeof(fftwf_complex) == sizeof(Kokkos::complex)); -static_assert(alignof(fftwf_complex) <= alignof(Kokkos::complex)); - -static_assert(sizeof(fftw_complex) == sizeof(Kokkos::complex)); -static_assert(alignof(fftw_complex) <= alignof(Kokkos::complex)); -#endif - namespace KokkosFFT { namespace Impl { using FFTDirectionType = int; constexpr FFTDirectionType MKL_FFT_FORWARD = 1; constexpr FFTDirectionType MKL_FFT_BACKWARD = -1; -// Unused -template -using FFTInfoType = int; - +#if !defined(ENABLE_HOST_AND_DEVICE) enum class FFTWTransformType { R2C, D2Z, C2R, Z2D, C2C, Z2Z }; +#endif template using TransformType = FFTWTransformType; @@ -83,7 +76,7 @@ struct transform_type, static constexpr FFTWTransformType type() { return m_type; }; }; -#ifdef ENABLE_HOST_AND_DEVICE +#if defined(ENABLE_HOST_AND_DEVICE) template struct FFTDataType { @@ -115,11 +108,7 @@ struct FFTPlanType> { static constexpr oneapi::mkl::dft::domain dom = oneapi::mkl::dft::domain::REAL; - using fftwHandle = std::conditional_t< - std::is_same_v, - float>, - fftwf_plan, fftw_plan>; - + using fftwHandle = ScopedFFTWPlan>; using onemklHandle = oneapi::mkl::dft::descriptor; using type = std::conditional_t< std::is_same_v, onemklHandle, @@ -137,11 +126,7 @@ struct FFTPlanType, T2> { static constexpr oneapi::mkl::dft::domain dom = oneapi::mkl::dft::domain::REAL; - using fftwHandle = std::conditional_t< - std::is_same_v, - float>, - fftwf_plan, fftw_plan>; - + using fftwHandle = ScopedFFTWPlan, T2>; using onemklHandle = oneapi::mkl::dft::descriptor; using type = std::conditional_t< std::is_same_v, onemklHandle, @@ -159,11 +144,8 @@ struct FFTPlanType, Kokkos::complex> { static constexpr oneapi::mkl::dft::domain dom = oneapi::mkl::dft::domain::COMPLEX; - using fftwHandle = std::conditional_t< - std::is_same_v, - float>, - fftwf_plan, fftw_plan>; - + using fftwHandle = + ScopedFFTWPlan, Kokkos::complex>; using onemklHandle = oneapi::mkl::dft::descriptor; using type = std::conditional_t< std::is_same_v, onemklHandle,