diff --git a/common/src/KokkosFFT_normalization.hpp b/common/src/KokkosFFT_normalization.hpp index b57e974d..9e9c01ea 100644 --- a/common/src/KokkosFFT_normalization.hpp +++ b/common/src/KokkosFFT_normalization.hpp @@ -26,8 +26,8 @@ void normalize_impl(const ExecutionSpace& exec_space, ViewType& inout, template auto get_coefficients(ViewType, Direction direction, Normalization normalization, std::size_t fft_size) { - using value_type = - KokkosFFT::Impl::real_type_t; + using value_type = KokkosFFT::Impl::base_floating_point_type< + typename ViewType::non_const_value_type>; value_type coef = 1; [[maybe_unused]] bool to_normalize = false; diff --git a/common/src/KokkosFFT_traits.hpp b/common/src/KokkosFFT_traits.hpp new file mode 100644 index 00000000..ec89df49 --- /dev/null +++ b/common/src/KokkosFFT_traits.hpp @@ -0,0 +1,134 @@ +// SPDX-FileCopyrightText: (C) The Kokkos-FFT development team, see COPYRIGHT.md file +// +// SPDX-License-Identifier: MIT OR Apache-2.0 WITH LLVM-exception + +#ifndef KOKKOSFFT_TRAITS_HPP +#define KOKKOSFFT_TRAITS_HPP + +#include + +namespace KokkosFFT { +namespace Impl { +template +struct base_floating_point { + using value_type = T; +}; + +template +struct base_floating_point> { + using value_type = T; +}; + +/// \brief Helper to extract the base floating point type from a complex type +template +using base_floating_point_type = typename base_floating_point::value_type; + +template +struct is_real : std::false_type {}; + +template +struct is_real< + T, std::enable_if_t || std::is_same_v>> + : std::true_type {}; + +/// \brief Helper to check if a type is an acceptable real type (float/double) +/// for Kokkos-FFT +template +inline constexpr bool is_real_v = is_real::value; + +template +struct is_complex : std::false_type {}; + +template +struct is_complex< + Kokkos::complex, + std::enable_if_t || std::is_same_v>> + : std::true_type {}; + +/// \brief Helper to check if a type is an acceptable complex type +/// (Kokkos::complex/Kokkos::complex) for Kokkos-FFT +template +inline constexpr bool is_complex_v = is_complex::value; + +// is value type admissible for KokkosFFT +template +struct is_admissible_value_type : std::false_type {}; + +template +struct is_admissible_value_type< + T, std::enable_if_t || is_complex_v>> : std::true_type {}; + +template +struct is_admissible_value_type< + T, std::enable_if_t::value && + (is_real_v || + is_complex_v)>> + : std::true_type {}; + +/// \brief Helper to check if a type is an acceptable value type +/// (float/double/Kokkos::complex/Kokkos::complex) for Kokkos-FFT +/// When applied to Kokkos::View, then check if a value type is an +/// acceptable real/complex type. +template +inline constexpr bool is_admissible_value_type_v = + is_admissible_value_type::value; + +template +struct is_layout_left_or_right : std::false_type {}; + +template +struct is_layout_left_or_right< + ViewType, + std::enable_if_t< + Kokkos::is_view::value && + (std::is_same_v || + std::is_same_v)>> + : std::true_type {}; + +/// \brief Helper to check if a View layout is an acceptable layout type +/// (Kokkos::LayoutLeft/Kokkos::LayoutRight) for Kokkos-FFT +template +inline constexpr bool is_layout_left_or_right_v = + is_layout_left_or_right::value; + +template +struct is_admissible_view : std::false_type {}; + +template +struct is_admissible_view< + ViewType, std::enable_if_t::value && + is_layout_left_or_right_v && + is_admissible_value_type_v>> + : std::true_type {}; + +/// \brief Helper to check if a View is an acceptable for Kokkos-FFT. Values and +/// layout are checked +template +inline constexpr bool is_admissible_view_v = + is_admissible_view::value; + +/// \brief Helper to define a managable View type from the original view type +template +struct managable_view_type { + using type = Kokkos::View>; +}; + +/// \brief Helper to define a complex 1D View type from a real/complex 1D View +/// type, while keeping other properties +template = nullptr> +struct complex_view_type { + using value_type = typename ViewType::non_const_value_type; + using float_type = KokkosFFT::Impl::base_floating_point_type; + using complex_type = Kokkos::complex; + using array_layout_type = typename ViewType::array_layout; + using type = Kokkos::View; +}; + +} // namespace Impl +} // namespace KokkosFFT + +#endif diff --git a/common/src/KokkosFFT_utils.hpp b/common/src/KokkosFFT_utils.hpp index 5f362316..ea3b990f 100644 --- a/common/src/KokkosFFT_utils.hpp +++ b/common/src/KokkosFFT_utils.hpp @@ -10,60 +10,10 @@ #include #include #include +#include "KokkosFFT_traits.hpp" namespace KokkosFFT { namespace Impl { -template -struct real_type { - using type = T; -}; - -template -struct real_type> { - using type = T; -}; - -template -struct managable_view_type { - using type = Kokkos::View>; -}; - -template -using real_type_t = typename real_type::type; - -template -struct is_complex : std::false_type {}; - -template -struct is_complex> : std::true_type {}; - -template -struct is_layout_left_or_right : std::false_type {}; - -template -struct is_layout_left_or_right< - ViewType, - std::enable_if_t< - std::is_same_v || - std::is_same_v>> - : std::true_type {}; - -template -inline constexpr bool is_layout_left_or_right_v = - is_layout_left_or_right::value; - -template = nullptr> -struct complex_view_type { - using value_type = typename ViewType::non_const_value_type; - using float_type = KokkosFFT::Impl::real_type_t; - using complex_type = Kokkos::complex; - using array_layout_type = typename ViewType::array_layout; - using type = Kokkos::View; -}; template auto convert_negative_axis(ViewType, int _axis = -1) { diff --git a/common/unit_test/CMakeLists.txt b/common/unit_test/CMakeLists.txt index 2032dc7a..5d5ff9f7 100644 --- a/common/unit_test/CMakeLists.txt +++ b/common/unit_test/CMakeLists.txt @@ -5,6 +5,7 @@ add_executable(unit-tests-kokkos-fft-common Test_Main.cpp Test_Utils.cpp + Test_Traits.cpp Test_Normalization.cpp Test_Transpose.cpp Test_Layouts.cpp @@ -20,4 +21,4 @@ target_link_libraries(unit-tests-kokkos-fft-common PUBLIC common GTest::gtest) # Enable GoogleTest include(GoogleTest) -gtest_discover_tests(unit-tests-kokkos-fft-common PROPERTIES DISCOVERY_TIMEOUT 600) +gtest_discover_tests(unit-tests-kokkos-fft-common PROPERTIES DISCOVERY_TIMEOUT 600) \ No newline at end of file diff --git a/common/unit_test/Test_Traits.cpp b/common/unit_test/Test_Traits.cpp new file mode 100644 index 00000000..01eb2e02 --- /dev/null +++ b/common/unit_test/Test_Traits.cpp @@ -0,0 +1,170 @@ +// SPDX-FileCopyrightText: (C) The Kokkos-FFT development team, see COPYRIGHT.md file +// +// SPDX-License-Identifier: MIT OR Apache-2.0 WITH LLVM-exception + +#include +#include "KokkosFFT_traits.hpp" +#include "Test_Utils.hpp" + +using real_types = ::testing::Types; +using view_types = + ::testing::Types, + std::pair, + std::pair, + std::pair, + std::pair, + std::pair, + std::pair, + std::pair, + std::pair>; + +template +struct RealAndComplexTypes : public ::testing::Test { + using real_type = T; + using complex_type = Kokkos::complex; +}; + +template +struct RealAndComplexViewTypes : public ::testing::Test { + using real_type = typename T::first_type; + using complex_type = Kokkos::complex; + using layout_type = typename T::second_type; +}; + +TYPED_TEST_SUITE(RealAndComplexTypes, real_types); +TYPED_TEST_SUITE(RealAndComplexViewTypes, view_types); + +// Tests for real type deduction +template +void test_get_real_type() { + using real_type_from_RealType = + KokkosFFT::Impl::base_floating_point_type; + using real_type_from_ComplexType = + KokkosFFT::Impl::base_floating_point_type; + + static_assert(std::is_same_v, + "Real type not deduced correctly from real type"); + static_assert(std::is_same_v, + "Real type not deduced correctly from complex type"); +} + +// Tests for admissible real types (float or double) +template +void test_admissible_real_type() { + if constexpr (std::is_same_v || std::is_same_v) { + static_assert(KokkosFFT::Impl::is_real_v, + "Real type must be float or double"); + } else { + static_assert(!KokkosFFT::Impl::is_real_v, + "Real type must be float or double"); + } +} + +template +void test_admissible_complex_type() { + using real_type = KokkosFFT::Impl::base_floating_point_type; + if constexpr (std::is_same_v || + std::is_same_v) { + static_assert(KokkosFFT::Impl::is_complex_v, + "Complex type must be Kokkos::complex or " + "Kokkos::complex"); + } else { + static_assert(!KokkosFFT::Impl::is_complex_v, + "Complex type must be Kokkos::complex or " + "Kokkos::complex"); + } +} + +TYPED_TEST(RealAndComplexTypes, get_real_type) { + using real_type = typename TestFixture::real_type; + using complex_type = typename TestFixture::complex_type; + + test_get_real_type(); +} + +TYPED_TEST(RealAndComplexTypes, admissible_real_type) { + using real_type = typename TestFixture::real_type; + + test_admissible_real_type(); +} + +TYPED_TEST(RealAndComplexTypes, admissible_complex_type) { + using complex_type = typename TestFixture::complex_type; + + test_admissible_complex_type(); +} + +// Tests for admissible view types +template +void test_admissible_value_type() { + using ViewType = Kokkos::View; + using real_type = KokkosFFT::Impl::base_floating_point_type; + if constexpr (std::is_same_v || + std::is_same_v) { + static_assert(KokkosFFT::Impl::is_admissible_value_type_v, + "Real type must be float or double"); + } else { + static_assert(!KokkosFFT::Impl::is_admissible_value_type_v, + "Real type must be float or double"); + } +} + +template +void test_admissible_layout_type() { + using ViewType = Kokkos::View; + if constexpr (std::is_same_v || + std::is_same_v) { + static_assert(KokkosFFT::Impl::is_layout_left_or_right_v, + "View Layout must be either LayoutLeft or LayoutRight."); + } else { + static_assert(!KokkosFFT::Impl::is_layout_left_or_right_v, + "View Layout must be either LayoutLeft or LayoutRight."); + } +} + +template +void test_admissible_view_type() { + using ViewType = Kokkos::View; + using real_type = KokkosFFT::Impl::base_floating_point_type; + if constexpr ( + (std::is_same_v || std::is_same_v)&&( + std::is_same_v || + std::is_same_v)) { + static_assert(KokkosFFT::Impl::is_admissible_view_v, + "View value type must be float, double, " + "Kokkos::Complex, Kokkos::Complex. Layout " + "must be either LayoutLeft or LayoutRight."); + } else { + static_assert(!KokkosFFT::Impl::is_admissible_view_v, + "View value type must be float, double, " + "Kokkos::Complex, Kokkos::Complex. Layout " + "must be either LayoutLeft or LayoutRight."); + } +} + +TYPED_TEST(RealAndComplexViewTypes, admissible_value_type) { + using real_type = typename TestFixture::real_type; + using complex_type = typename TestFixture::complex_type; + using layout_type = typename TestFixture::layout_type; + + test_admissible_value_type(); + test_admissible_value_type(); +} + +TYPED_TEST(RealAndComplexViewTypes, admissible_layout_type) { + using real_type = typename TestFixture::real_type; + using complex_type = typename TestFixture::complex_type; + using layout_type = typename TestFixture::layout_type; + + test_admissible_layout_type(); + test_admissible_layout_type(); +} + +TYPED_TEST(RealAndComplexViewTypes, admissible_view_type) { + using real_type = typename TestFixture::real_type; + using complex_type = typename TestFixture::complex_type; + using layout_type = typename TestFixture::layout_type; + + test_admissible_view_type(); + test_admissible_view_type(); +} diff --git a/common/unit_test/Test_Types.hpp b/common/unit_test/Test_Types.hpp index 69d38011..ab7c6011 100644 --- a/common/unit_test/Test_Types.hpp +++ b/common/unit_test/Test_Types.hpp @@ -4,7 +4,7 @@ #ifndef TEST_TYPES_HPP #define TEST_TYPES_HPP - +#include #include using execution_space = Kokkos::DefaultExecutionSpace; template diff --git a/fft/src/KokkosFFT_Cuda_types.hpp b/fft/src/KokkosFFT_Cuda_types.hpp index 076092f7..0a66538a 100644 --- a/fft/src/KokkosFFT_Cuda_types.hpp +++ b/fft/src/KokkosFFT_Cuda_types.hpp @@ -55,8 +55,8 @@ struct FFTDataType { template struct FFTPlanType { using fftwHandle = std::conditional_t< - std::is_same_v, float>, fftwf_plan, - fftw_plan>; + std::is_same_v, float>, + fftwf_plan, fftw_plan>; using type = std::conditional_t, cufftHandle, fftwHandle>; }; diff --git a/fft/src/KokkosFFT_HIP_types.hpp b/fft/src/KokkosFFT_HIP_types.hpp index 682e41f0..d460fc15 100644 --- a/fft/src/KokkosFFT_HIP_types.hpp +++ b/fft/src/KokkosFFT_HIP_types.hpp @@ -55,8 +55,8 @@ struct FFTDataType { template struct FFTPlanType { using fftwHandle = std::conditional_t< - std::is_same_v, float>, fftwf_plan, - fftw_plan>; + std::is_same_v, float>, + fftwf_plan, fftw_plan>; using type = std::conditional_t, hipfftHandle, fftwHandle>; }; diff --git a/fft/src/KokkosFFT_Host_plans.hpp b/fft/src/KokkosFFT_Host_plans.hpp index 2570c0b9..6926b960 100644 --- a/fft/src/KokkosFFT_Host_plans.hpp +++ b/fft/src/KokkosFFT_Host_plans.hpp @@ -50,7 +50,8 @@ auto create_plan(const ExecutionSpace& exec_space, "KokkosFFT::create_plan: Rank of View must be larger than Rank of FFT."); const int rank = fft_rank; - init_threads>( + init_threads>( exec_space); constexpr auto type = diff --git a/fft/src/KokkosFFT_Host_types.hpp b/fft/src/KokkosFFT_Host_types.hpp index 7c29d843..0bcb8c92 100644 --- a/fft/src/KokkosFFT_Host_types.hpp +++ b/fft/src/KokkosFFT_Host_types.hpp @@ -37,8 +37,8 @@ struct FFTDataType { template struct FFTPlanType { using type = std::conditional_t< - std::is_same_v, float>, fftwf_plan, - fftw_plan>; + std::is_same_v, float>, + fftwf_plan, fftw_plan>; }; template diff --git a/fft/src/KokkosFFT_Plans.hpp b/fft/src/KokkosFFT_Plans.hpp index 641fa006..34049d32 100644 --- a/fft/src/KokkosFFT_Plans.hpp +++ b/fft/src/KokkosFFT_Plans.hpp @@ -67,7 +67,7 @@ class Plan { using out_value_type = typename OutViewType::non_const_value_type; //! The real value type of input/output views - using float_type = KokkosFFT::Impl::real_type_t; + using float_type = KokkosFFT::Impl::base_floating_point_type; //! The layout type of input/output views using layout_type = typename InViewType::array_layout; diff --git a/fft/src/KokkosFFT_ROCM_plans.hpp b/fft/src/KokkosFFT_ROCM_plans.hpp index b1ea115d..2d44ba59 100644 --- a/fft/src/KokkosFFT_ROCM_plans.hpp +++ b/fft/src/KokkosFFT_ROCM_plans.hpp @@ -42,7 +42,8 @@ auto get_in_out_array_type(TransformType type, template rocfft_precision get_in_out_array_type() { - return std::is_same_v, float> + return std::is_same_v, + float> ? rocfft_precision_single : rocfft_precision_double; } diff --git a/fft/src/KokkosFFT_ROCM_types.hpp b/fft/src/KokkosFFT_ROCM_types.hpp index 0059e1ad..0d11c259 100644 --- a/fft/src/KokkosFFT_ROCM_types.hpp +++ b/fft/src/KokkosFFT_ROCM_types.hpp @@ -99,8 +99,8 @@ struct FFTDataType { template struct FFTPlanType { using fftwHandle = std::conditional_t< - std::is_same_v, float>, fftwf_plan, - fftw_plan>; + std::is_same_v, float>, + fftwf_plan, fftw_plan>; using type = std::conditional_t, rocfft_plan, fftwHandle>; }; diff --git a/fft/src/KokkosFFT_SYCL_types.hpp b/fft/src/KokkosFFT_SYCL_types.hpp index db5be59f..c36d5ad2 100644 --- a/fft/src/KokkosFFT_SYCL_types.hpp +++ b/fft/src/KokkosFFT_SYCL_types.hpp @@ -114,14 +114,16 @@ template struct FFTPlanType> { using float_type = T1; static constexpr oneapi::mkl::dft::precision prec = - std::is_same_v, float> + std::is_same_v, + float> ? oneapi::mkl::dft::precision::SINGLE : oneapi::mkl::dft::precision::DOUBLE; static constexpr oneapi::mkl::dft::domain dom = oneapi::mkl::dft::domain::REAL; using fftwHandle = std::conditional_t< - std::is_same_v, float>, + std::is_same_v, + float>, fftwf_plan, fftw_plan>; using onemklHandle = oneapi::mkl::dft::descriptor; @@ -134,14 +136,16 @@ template struct FFTPlanType, T2> { using float_type = T2; static constexpr oneapi::mkl::dft::precision prec = - std::is_same_v, float> + std::is_same_v, + float> ? oneapi::mkl::dft::precision::SINGLE : oneapi::mkl::dft::precision::DOUBLE; static constexpr oneapi::mkl::dft::domain dom = oneapi::mkl::dft::domain::REAL; using fftwHandle = std::conditional_t< - std::is_same_v, float>, + std::is_same_v, + float>, fftwf_plan, fftw_plan>; using onemklHandle = oneapi::mkl::dft::descriptor; @@ -152,16 +156,18 @@ struct FFTPlanType, T2> { template struct FFTPlanType, Kokkos::complex> { - using float_type = KokkosFFT::Impl::real_type_t; + using float_type = KokkosFFT::Impl::base_floating_point_type; static constexpr oneapi::mkl::dft::precision prec = - std::is_same_v, float> + std::is_same_v, + float> ? oneapi::mkl::dft::precision::SINGLE : oneapi::mkl::dft::precision::DOUBLE; static constexpr oneapi::mkl::dft::domain dom = oneapi::mkl::dft::domain::COMPLEX; using fftwHandle = std::conditional_t< - std::is_same_v, float>, + std::is_same_v, + float>, fftwf_plan, fftw_plan>; using onemklHandle = oneapi::mkl::dft::descriptor; @@ -201,7 +207,8 @@ template struct FFTPlanType> { using float_type = T1; static constexpr oneapi::mkl::dft::precision prec = - std::is_same_v, float> + std::is_same_v, + float> ? oneapi::mkl::dft::precision::SINGLE : oneapi::mkl::dft::precision::DOUBLE; static constexpr oneapi::mkl::dft::domain dom = @@ -214,7 +221,8 @@ template struct FFTPlanType, T2> { using float_type = T2; static constexpr oneapi::mkl::dft::precision prec = - std::is_same_v, float> + std::is_same_v, + float> ? oneapi::mkl::dft::precision::SINGLE : oneapi::mkl::dft::precision::DOUBLE; static constexpr oneapi::mkl::dft::domain dom = @@ -225,9 +233,10 @@ struct FFTPlanType, T2> { template struct FFTPlanType, Kokkos::complex> { - using float_type = KokkosFFT::Impl::real_type_t; + using float_type = KokkosFFT::Impl::base_floating_point_type; static constexpr oneapi::mkl::dft::precision prec = - std::is_same_v, float> + std::is_same_v, + float> ? oneapi::mkl::dft::precision::SINGLE : oneapi::mkl::dft::precision::DOUBLE; static constexpr oneapi::mkl::dft::domain dom = diff --git a/fft/unit_test/Test_Transform.cpp b/fft/unit_test/Test_Transform.cpp index b5db0655..4bcb5f49 100644 --- a/fft/unit_test/Test_Transform.cpp +++ b/fft/unit_test/Test_Transform.cpp @@ -22,7 +22,7 @@ using shape_type = KokkosFFT::shape_type; template void fft1(ViewType& in, ViewType& out) { using value_type = typename ViewType::non_const_value_type; - using real_value_type = KokkosFFT::Impl::real_type_t; + using real_value_type = KokkosFFT::Impl::base_floating_point_type; static_assert(KokkosFFT::Impl::is_complex::value, "fft1: ViewType must be complex"); @@ -62,7 +62,7 @@ void fft1(ViewType& in, ViewType& out) { template void ifft1(ViewType& in, ViewType& out) { using value_type = typename ViewType::non_const_value_type; - using real_value_type = KokkosFFT::Impl::real_type_t; + using real_value_type = KokkosFFT::Impl::base_floating_point_type; static_assert(KokkosFFT::Impl::is_complex::value, "ifft1: ViewType must be complex");