From 18d113f49a00ac5d0dd068e0262a5c9b039b8319 Mon Sep 17 00:00:00 2001 From: yasahi-hpc <57478230+yasahi-hpc@users.noreply.github.com> Date: Wed, 27 Nov 2024 15:22:07 +0100 Subject: [PATCH] Add missing checks (#196) * Add assertions for completely mismatched extents * Raise an error if inplace plan is executed on out-of-place views * remove unnecessary fence --------- Co-authored-by: Yuuichi Asahi --- common/src/KokkosFFT_Extents.hpp | 16 ++++++++++++ common/unit_test/Test_Extents.cpp | 43 +++++++++++++++++++++++++++++++ fft/src/KokkosFFT_Plans.hpp | 5 ++++ fft/unit_test/Test_Transform.cpp | 39 ++++++++++++++++++++++++++-- 4 files changed, 101 insertions(+), 2 deletions(-) diff --git a/common/src/KokkosFFT_Extents.hpp b/common/src/KokkosFFT_Extents.hpp index e3bea237..e235d05e 100644 --- a/common/src/KokkosFFT_Extents.hpp +++ b/common/src/KokkosFFT_Extents.hpp @@ -75,6 +75,22 @@ auto get_extents(const InViewType& in, const OutViewType& out, static_assert(!(is_real_v && is_real_v), "get_extents: real to real transform is not supported"); + for (std::size_t i = 0; i < rank; i++) { + // The requirement for inner_most_axis is different for transform type + if (static_cast(i) == inner_most_axis) continue; + KOKKOSFFT_THROW_IF(in_extents_full.at(i) != out_extents_full.at(i), + "input and output extents must be the same except for " + "the transform axis"); + } + + if constexpr (is_complex_v && is_complex_v) { + // Then C2C + KOKKOSFFT_THROW_IF( + in_extents_full.at(inner_most_axis) != + out_extents_full.at(inner_most_axis), + "input and output extents must be the same for C2C transform"); + } + if constexpr (is_real_v) { // Then R2C if (is_inplace) { diff --git a/common/unit_test/Test_Extents.cpp b/common/unit_test/Test_Extents.cpp index 28ddbc97..312af11a 100644 --- a/common/unit_test/Test_Extents.cpp +++ b/common/unit_test/Test_Extents.cpp @@ -176,6 +176,14 @@ void test_extents_1d_batched_FFT_2d() { EXPECT_TRUE(fft_extents_c2c_axis1 == ref_fft_extents_r2c_axis1); EXPECT_TRUE(out_extents_c2c_axis1 == ref_in_extents_r2c_axis1); EXPECT_EQ(howmany_c2c_axis1, ref_howmany_r2c_axis1); + + // Check if errors are correctly raised aginst invalid extents + ComplexView2Dtype xcout2_wrong("xcout2_wrong", n0 + 3, n1); + for (int i = 0; i < 2; i++) { + EXPECT_THROW( + { KokkosFFT::Impl::get_extents(xcin2, xcout2_wrong, axes_type({i})); }, + std::runtime_error); + } } template @@ -306,6 +314,14 @@ void test_extents_1d_batched_FFT_3d() { EXPECT_TRUE(fft_extents_c2c_axis2 == ref_fft_extents_r2c_axis2); EXPECT_TRUE(out_extents_c2c_axis2 == ref_in_extents_r2c_axis2); EXPECT_EQ(howmany_c2c_axis2, ref_howmany_r2c_axis2); + + // Check if errors are correctly raised aginst invalid extents + ComplexView3Dtype xcout3_wrong("xcout3_wrong", n0 + 3, n1, n2); + for (int i = 0; i < 3; i++) { + EXPECT_THROW( + { KokkosFFT::Impl::get_extents(xcin3, xcout3_wrong, axes_type({i})); }, + std::runtime_error); + } } TYPED_TEST(Extents1D, 1DFFT_1DView) { @@ -429,6 +445,20 @@ void test_extents_2d() { EXPECT_EQ(howmany_c2c_axis01, 1); EXPECT_EQ(howmany_c2c_axis10, 1); + + // Check if errors are correctly raised aginst invalid extents + ComplexView2Dtype xcout2_wrong("xcout2_wrong", n0 + 3, n1); + for (int axis0 = 0; axis0 < 2; axis0++) { + for (int axis1 = 0; axis1 < 2; axis1++) { + if (axis0 == axis1) continue; + EXPECT_THROW( + { + KokkosFFT::Impl::get_extents(xcin2, xcout2_wrong, + axes_type({axis0, axis1})); + }, + std::runtime_error); + } + } } template @@ -709,6 +739,19 @@ void test_extents_2d_batched_FFT_3d() { EXPECT_TRUE(fft_extents_c2c_axis_21 == ref_fft_extents_r2c_axis_21); EXPECT_TRUE(out_extents_c2c_axis_21 == ref_in_extents_r2c_axis_21); EXPECT_EQ(howmany_c2c_axis_21, ref_howmany_r2c_axis_21); + + ComplexView3Dtype xcout3_wrong("xcout3_wrong", n0 + 3, n1, n2 + 2); + for (int axis0 = 0; axis0 < 3; axis0++) { + for (int axis1 = 0; axis1 < 3; axis1++) { + if (axis0 == axis1) continue; + EXPECT_THROW( + { + KokkosFFT::Impl::get_extents(xcin3, xcout3_wrong, + axes_type({axis0, axis1})); + }, + std::runtime_error); + } + } } TYPED_TEST(Extents2D, 2DFFT_2DView) { diff --git a/fft/src/KokkosFFT_Plans.hpp b/fft/src/KokkosFFT_Plans.hpp index 1025e006..5ec1e744 100644 --- a/fft/src/KokkosFFT_Plans.hpp +++ b/fft/src/KokkosFFT_Plans.hpp @@ -394,6 +394,11 @@ class Plan { KOKKOSFFT_THROW_IF(out_extents != m_out_extents, "extents of output View for plan and " "execution are not identical."); + + bool is_inplace = KokkosFFT::Impl::are_aliasing(in.data(), out.data()); + KOKKOSFFT_THROW_IF(is_inplace != m_is_inplace, + "If the plan is in-place, the input and output Views " + "must be identical."); } }; } // namespace KokkosFFT diff --git a/fft/unit_test/Test_Transform.cpp b/fft/unit_test/Test_Transform.cpp index 1a793533..60b36ba3 100644 --- a/fft/unit_test/Test_Transform.cpp +++ b/fft/unit_test/Test_Transform.cpp @@ -103,8 +103,6 @@ void test_fft1_identity_inplace(T atol = 1.0e-12) { Kokkos::deep_copy(a_ref, a); Kokkos::deep_copy(ar_ref, ar); - Kokkos::fence(); - KokkosFFT::fft(execution_space(), a, a_hat); KokkosFFT::ifft(execution_space(), a_hat, inv_a_hat); @@ -115,6 +113,43 @@ void test_fft1_identity_inplace(T atol = 1.0e-12) { EXPECT_TRUE(allclose(inv_a_hat, a_ref, 1.e-5, atol)); EXPECT_TRUE(allclose(inv_ar_hat, ar_ref, 1.e-5, atol)); + + // Create a plan for inplace transform + Kokkos::deep_copy(a_ref, a); + Kokkos::deep_copy(ar_ref, ar); + + int axis = -1; + KokkosFFT::Plan fft_plan(execution_space(), a, a_hat, + KokkosFFT::Direction::forward, axis); + fft_plan.execute(a, a_hat); + + KokkosFFT::Plan ifft_plan(execution_space(), a_hat, inv_a_hat, + KokkosFFT::Direction::backward, axis); + ifft_plan.execute(a_hat, inv_a_hat); + + KokkosFFT::Plan rfft_plan(execution_space(), ar, ar_hat, + KokkosFFT::Direction::forward, axis); + rfft_plan.execute(ar, ar_hat); + + KokkosFFT::Plan irfft_plan(execution_space(), ar_hat, inv_ar_hat, + KokkosFFT::Direction::backward, axis); + irfft_plan.execute(ar_hat, inv_ar_hat); + + EXPECT_TRUE(allclose(inv_a_hat, a_ref, 1.e-5, atol)); + EXPECT_TRUE(allclose(inv_ar_hat, ar_ref, 1.e-5, atol)); + + // inplace Plan cannot be reused for out-of-place case + ComplexView1DType a_hat_out("a_hat_out", i), + inv_a_hat_out("inv_a_hat_out", i); + + RealView1DType inv_ar_hat_out("inv_ar_hat_out", i); + ComplexView1DType ar_hat_out("ar_hat_out", i / 2 + 1); + EXPECT_THROW(fft_plan.execute(a, a_hat_out), std::runtime_error); + EXPECT_THROW(ifft_plan.execute(a_hat_out, inv_a_hat_out), + std::runtime_error); + EXPECT_THROW(rfft_plan.execute(ar, ar_hat_out), std::runtime_error); + EXPECT_THROW(irfft_plan.execute(ar_hat_out, inv_ar_hat_out), + std::runtime_error); } }