From 367e7ccd8d87e449d3072e35433e75c5916fad0f Mon Sep 17 00:00:00 2001 From: Thomas Padioleau Date: Tue, 5 Nov 2024 13:30:47 +0100 Subject: [PATCH] Pass views by const& in fftshift (#192) * Pass views by const& in fftshift * Review from yasahi-hpc Co-authored-by: yasahi-hpc <57478230+yasahi-hpc@users.noreply.github.com> --------- Co-authored-by: yasahi-hpc <57478230+yasahi-hpc@users.noreply.github.com> --- common/src/KokkosFFT_Helpers.hpp | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/common/src/KokkosFFT_Helpers.hpp b/common/src/KokkosFFT_Helpers.hpp index f3e06b4b..7b06aa52 100644 --- a/common/src/KokkosFFT_Helpers.hpp +++ b/common/src/KokkosFFT_Helpers.hpp @@ -39,8 +39,8 @@ auto get_shift(const ViewType& inout, axis_type axes, int direction = 1) { } template -void roll(const ExecutionSpace& exec_space, ViewType& inout, axis_type<1> shift, - axis_type<1>) { +void roll(const ExecutionSpace& exec_space, const ViewType& inout, + axis_type<1> shift, axis_type<1> /* axes */) { // Last parameter is ignored but present for keeping the interface consistent static_assert(ViewType::rank() == 1, "roll: Rank of View must be 1."); int n0 = inout.extent_int(0); @@ -67,13 +67,13 @@ void roll(const ExecutionSpace& exec_space, ViewType& inout, axis_type<1> shift, } }); - inout = tmp; + Kokkos::deep_copy(inout, tmp); } } template -void roll(const ExecutionSpace& exec_space, ViewType& inout, axis_type<2> shift, - axis_type axes) { +void roll(const ExecutionSpace& exec_space, const ViewType& inout, + axis_type<2> shift, axis_type axes) { constexpr int DIM0 = 2; static_assert(ViewType::rank() == DIM0, "roll: Rank of View must be 2."); int n0 = inout.extent_int(0), n1 = inout.extent_int(1); @@ -128,18 +128,18 @@ void roll(const ExecutionSpace& exec_space, ViewType& inout, axis_type<2> shift, } }); - inout = tmp; + Kokkos::deep_copy(inout, tmp); } template -void fftshift_impl(const ExecutionSpace& exec_space, ViewType& inout, +void fftshift_impl(const ExecutionSpace& exec_space, const ViewType& inout, axis_type axes) { auto shift = get_shift(inout, axes); roll(exec_space, inout, shift, axes); } template -void ifftshift_impl(const ExecutionSpace& exec_space, ViewType& inout, +void ifftshift_impl(const ExecutionSpace& exec_space, const ViewType& inout, axis_type axes) { auto shift = get_shift(inout, axes, -1); roll(exec_space, inout, shift, axes); @@ -219,7 +219,7 @@ auto rfftfreq(const ExecutionSpace&, const std::size_t n, /// \param axes [in] Axes over which to shift (default: nullopt, shifting over /// all axes) template -void fftshift(const ExecutionSpace& exec_space, ViewType& inout, +void fftshift(const ExecutionSpace& exec_space, const ViewType& inout, std::optional axes = std::nullopt) { static_assert(KokkosFFT::Impl::is_operatable_view_v, "fftshift: View value type must be float, double, " @@ -246,7 +246,7 @@ void fftshift(const ExecutionSpace& exec_space, ViewType& inout, /// \param inout [in,out] Spectrum /// \param axes [in] Axes over which to shift template -void fftshift(const ExecutionSpace& exec_space, ViewType& inout, +void fftshift(const ExecutionSpace& exec_space, const ViewType& inout, axis_type axes) { static_assert(KokkosFFT::Impl::is_operatable_view_v, "fftshift: View value type must be float, double, " @@ -269,7 +269,7 @@ void fftshift(const ExecutionSpace& exec_space, ViewType& inout, /// \param axes [in] Axes over which to shift (default: nullopt, shifting over /// all axes) template -void ifftshift(const ExecutionSpace& exec_space, ViewType& inout, +void ifftshift(const ExecutionSpace& exec_space, const ViewType& inout, std::optional axes = std::nullopt) { static_assert(KokkosFFT::Impl::is_operatable_view_v, "ifftshift: View value type must be float, double, " @@ -295,7 +295,7 @@ void ifftshift(const ExecutionSpace& exec_space, ViewType& inout, /// \param inout [in,out] Spectrum /// \param axes [in] Axes over which to shift template -void ifftshift(const ExecutionSpace& exec_space, ViewType& inout, +void ifftshift(const ExecutionSpace& exec_space, const ViewType& inout, axis_type axes) { static_assert(KokkosFFT::Impl::is_operatable_view_v, "ifftshift: View value type must be float, double, "