Skip to content

Commit

Permalink
Add traits to help type checks (#117)
Browse files Browse the repository at this point in the history
* Add traits and tests

* remove half types from tests

* rename real_type to base_floating_point

* Add docstrings for traits

---------

Co-authored-by: Yuuichi Asahi <[email protected]>
  • Loading branch information
yasahi-hpc and Yuuichi Asahi authored Jul 12, 2024
1 parent c09d8e3 commit 0e29048
Show file tree
Hide file tree
Showing 15 changed files with 345 additions and 79 deletions.
4 changes: 2 additions & 2 deletions common/src/KokkosFFT_normalization.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ void normalize_impl(const ExecutionSpace& exec_space, ViewType& inout,
template <typename ViewType>
auto get_coefficients(ViewType, Direction direction,
Normalization normalization, std::size_t fft_size) {
using value_type =
KokkosFFT::Impl::real_type_t<typename ViewType::non_const_value_type>;
using value_type = KokkosFFT::Impl::base_floating_point_type<
typename ViewType::non_const_value_type>;
value_type coef = 1;
[[maybe_unused]] bool to_normalize = false;

Expand Down
134 changes: 134 additions & 0 deletions common/src/KokkosFFT_traits.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
// SPDX-FileCopyrightText: (C) The Kokkos-FFT development team, see COPYRIGHT.md file
//
// SPDX-License-Identifier: MIT OR Apache-2.0 WITH LLVM-exception

#ifndef KOKKOSFFT_TRAITS_HPP
#define KOKKOSFFT_TRAITS_HPP

#include <Kokkos_Core.hpp>

namespace KokkosFFT {
namespace Impl {
template <typename T>
struct base_floating_point {
using value_type = T;
};

template <typename T>
struct base_floating_point<Kokkos::complex<T>> {
using value_type = T;
};

/// \brief Helper to extract the base floating point type from a complex type
template <typename T>
using base_floating_point_type = typename base_floating_point<T>::value_type;

template <typename T, typename Enable = void>
struct is_real : std::false_type {};

template <typename T>
struct is_real<
T, std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double>>>
: std::true_type {};

/// \brief Helper to check if a type is an acceptable real type (float/double)
/// for Kokkos-FFT
template <typename T>
inline constexpr bool is_real_v = is_real<T>::value;

template <typename T, typename Enable = void>
struct is_complex : std::false_type {};

template <typename T>
struct is_complex<
Kokkos::complex<T>,
std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double>>>
: std::true_type {};

/// \brief Helper to check if a type is an acceptable complex type
/// (Kokkos::complex<float>/Kokkos::complex<double>) for Kokkos-FFT
template <typename T>
inline constexpr bool is_complex_v = is_complex<T>::value;

// is value type admissible for KokkosFFT
template <typename T, typename Enable = void>
struct is_admissible_value_type : std::false_type {};

template <typename T>
struct is_admissible_value_type<
T, std::enable_if_t<is_real_v<T> || is_complex_v<T>>> : std::true_type {};

template <typename T>
struct is_admissible_value_type<
T, std::enable_if_t<Kokkos::is_view<T>::value &&
(is_real_v<typename T::non_const_value_type> ||
is_complex_v<typename T::non_const_value_type>)>>
: std::true_type {};

/// \brief Helper to check if a type is an acceptable value type
/// (float/double/Kokkos::complex<float>/Kokkos::complex<double>) for Kokkos-FFT
/// When applied to Kokkos::View, then check if a value type is an
/// acceptable real/complex type.
template <typename T>
inline constexpr bool is_admissible_value_type_v =
is_admissible_value_type<T>::value;

template <typename ViewType, typename Enable = void>
struct is_layout_left_or_right : std::false_type {};

template <typename ViewType>
struct is_layout_left_or_right<
ViewType,
std::enable_if_t<
Kokkos::is_view<ViewType>::value &&
(std::is_same_v<typename ViewType::array_layout, Kokkos::LayoutLeft> ||
std::is_same_v<typename ViewType::array_layout, Kokkos::LayoutRight>)>>
: std::true_type {};

/// \brief Helper to check if a View layout is an acceptable layout type
/// (Kokkos::LayoutLeft/Kokkos::LayoutRight) for Kokkos-FFT
template <typename ViewType>
inline constexpr bool is_layout_left_or_right_v =
is_layout_left_or_right<ViewType>::value;

template <typename ViewType, typename Enable = void>
struct is_admissible_view : std::false_type {};

template <typename ViewType>
struct is_admissible_view<
ViewType, std::enable_if_t<Kokkos::is_view<ViewType>::value &&
is_layout_left_or_right_v<ViewType> &&
is_admissible_value_type_v<ViewType>>>
: std::true_type {};

/// \brief Helper to check if a View is an acceptable for Kokkos-FFT. Values and
/// layout are checked
template <typename ViewType>
inline constexpr bool is_admissible_view_v =
is_admissible_view<ViewType>::value;

/// \brief Helper to define a managable View type from the original view type
template <typename T>
struct managable_view_type {
using type = Kokkos::View<typename T::data_type, typename T::array_layout,
typename T::memory_space,
Kokkos::MemoryTraits<T::memory_traits::impl_value &
~unsigned(Kokkos::Unmanaged)>>;
};

/// \brief Helper to define a complex 1D View type from a real/complex 1D View
/// type, while keeping other properties
template <typename ExecutionSpace, typename ViewType,
std::enable_if_t<ViewType::rank() == 1, std::nullptr_t> = nullptr>
struct complex_view_type {
using value_type = typename ViewType::non_const_value_type;
using float_type = KokkosFFT::Impl::base_floating_point_type<value_type>;
using complex_type = Kokkos::complex<float_type>;
using array_layout_type = typename ViewType::array_layout;
using type = Kokkos::View<complex_type*, array_layout_type, ExecutionSpace>;
};

} // namespace Impl
} // namespace KokkosFFT

#endif
52 changes: 1 addition & 51 deletions common/src/KokkosFFT_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,60 +10,10 @@
#include <set>
#include <algorithm>
#include <numeric>
#include "KokkosFFT_traits.hpp"

namespace KokkosFFT {
namespace Impl {
template <typename T>
struct real_type {
using type = T;
};

template <typename T>
struct real_type<Kokkos::complex<T>> {
using type = T;
};

template <typename T>
struct managable_view_type {
using type = Kokkos::View<typename T::data_type, typename T::array_layout,
typename T::memory_space,
Kokkos::MemoryTraits<T::memory_traits::impl_value &
~unsigned(Kokkos::Unmanaged)>>;
};

template <typename T>
using real_type_t = typename real_type<T>::type;

template <typename T>
struct is_complex : std::false_type {};

template <typename T>
struct is_complex<Kokkos::complex<T>> : std::true_type {};

template <typename ViewType, typename Enable = void>
struct is_layout_left_or_right : std::false_type {};

template <typename ViewType>
struct is_layout_left_or_right<
ViewType,
std::enable_if_t<
std::is_same_v<typename ViewType::array_layout, Kokkos::LayoutLeft> ||
std::is_same_v<typename ViewType::array_layout, Kokkos::LayoutRight>>>
: std::true_type {};

template <typename ViewType>
inline constexpr bool is_layout_left_or_right_v =
is_layout_left_or_right<ViewType>::value;

template <typename ExecutionSpace, typename ViewType,
std::enable_if_t<ViewType::rank() == 1, std::nullptr_t> = nullptr>
struct complex_view_type {
using value_type = typename ViewType::non_const_value_type;
using float_type = KokkosFFT::Impl::real_type_t<value_type>;
using complex_type = Kokkos::complex<float_type>;
using array_layout_type = typename ViewType::array_layout;
using type = Kokkos::View<complex_type*, array_layout_type, ExecutionSpace>;
};

template <typename ViewType>
auto convert_negative_axis(ViewType, int _axis = -1) {
Expand Down
3 changes: 2 additions & 1 deletion common/unit_test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
add_executable(unit-tests-kokkos-fft-common
Test_Main.cpp
Test_Utils.cpp
Test_Traits.cpp
Test_Normalization.cpp
Test_Transpose.cpp
Test_Layouts.cpp
Expand All @@ -20,4 +21,4 @@ target_link_libraries(unit-tests-kokkos-fft-common PUBLIC common GTest::gtest)

# Enable GoogleTest
include(GoogleTest)
gtest_discover_tests(unit-tests-kokkos-fft-common PROPERTIES DISCOVERY_TIMEOUT 600)
gtest_discover_tests(unit-tests-kokkos-fft-common PROPERTIES DISCOVERY_TIMEOUT 600)
170 changes: 170 additions & 0 deletions common/unit_test/Test_Traits.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
// SPDX-FileCopyrightText: (C) The Kokkos-FFT development team, see COPYRIGHT.md file
//
// SPDX-License-Identifier: MIT OR Apache-2.0 WITH LLVM-exception

#include <gtest/gtest.h>
#include "KokkosFFT_traits.hpp"
#include "Test_Utils.hpp"

using real_types = ::testing::Types<float, double, long double>;
using view_types =
::testing::Types<std::pair<float, Kokkos::LayoutLeft>,
std::pair<float, Kokkos::LayoutRight>,
std::pair<float, Kokkos::LayoutStride>,
std::pair<double, Kokkos::LayoutLeft>,
std::pair<double, Kokkos::LayoutRight>,
std::pair<double, Kokkos::LayoutStride>,
std::pair<long double, Kokkos::LayoutLeft>,
std::pair<long double, Kokkos::LayoutRight>,
std::pair<long double, Kokkos::LayoutStride>>;

template <typename T>
struct RealAndComplexTypes : public ::testing::Test {
using real_type = T;
using complex_type = Kokkos::complex<T>;
};

template <typename T>
struct RealAndComplexViewTypes : public ::testing::Test {
using real_type = typename T::first_type;
using complex_type = Kokkos::complex<real_type>;
using layout_type = typename T::second_type;
};

TYPED_TEST_SUITE(RealAndComplexTypes, real_types);
TYPED_TEST_SUITE(RealAndComplexViewTypes, view_types);

// Tests for real type deduction
template <typename RealType, typename ComplexType>
void test_get_real_type() {
using real_type_from_RealType =
KokkosFFT::Impl::base_floating_point_type<RealType>;
using real_type_from_ComplexType =
KokkosFFT::Impl::base_floating_point_type<ComplexType>;

static_assert(std::is_same_v<real_type_from_RealType, RealType>,
"Real type not deduced correctly from real type");
static_assert(std::is_same_v<real_type_from_ComplexType, RealType>,
"Real type not deduced correctly from complex type");
}

// Tests for admissible real types (float or double)
template <typename T>
void test_admissible_real_type() {
if constexpr (std::is_same_v<T, float> || std::is_same_v<T, double>) {
static_assert(KokkosFFT::Impl::is_real_v<T>,
"Real type must be float or double");
} else {
static_assert(!KokkosFFT::Impl::is_real_v<T>,
"Real type must be float or double");
}
}

template <typename T>
void test_admissible_complex_type() {
using real_type = KokkosFFT::Impl::base_floating_point_type<T>;
if constexpr (std::is_same_v<real_type, float> ||
std::is_same_v<real_type, double>) {
static_assert(KokkosFFT::Impl::is_complex_v<T>,
"Complex type must be Kokkos::complex<float> or "
"Kokkos::complex<double>");
} else {
static_assert(!KokkosFFT::Impl::is_complex_v<T>,
"Complex type must be Kokkos::complex<float> or "
"Kokkos::complex<double>");
}
}

TYPED_TEST(RealAndComplexTypes, get_real_type) {
using real_type = typename TestFixture::real_type;
using complex_type = typename TestFixture::complex_type;

test_get_real_type<real_type, complex_type>();
}

TYPED_TEST(RealAndComplexTypes, admissible_real_type) {
using real_type = typename TestFixture::real_type;

test_admissible_real_type<real_type>();
}

TYPED_TEST(RealAndComplexTypes, admissible_complex_type) {
using complex_type = typename TestFixture::complex_type;

test_admissible_complex_type<complex_type>();
}

// Tests for admissible view types
template <typename T, typename LayoutType>
void test_admissible_value_type() {
using ViewType = Kokkos::View<T*, LayoutType>;
using real_type = KokkosFFT::Impl::base_floating_point_type<T>;
if constexpr (std::is_same_v<real_type, float> ||
std::is_same_v<real_type, double>) {
static_assert(KokkosFFT::Impl::is_admissible_value_type_v<ViewType>,
"Real type must be float or double");
} else {
static_assert(!KokkosFFT::Impl::is_admissible_value_type_v<ViewType>,
"Real type must be float or double");
}
}

template <typename T, typename LayoutType>
void test_admissible_layout_type() {
using ViewType = Kokkos::View<T*, LayoutType>;
if constexpr (std::is_same_v<LayoutType, Kokkos::LayoutLeft> ||
std::is_same_v<LayoutType, Kokkos::LayoutRight>) {
static_assert(KokkosFFT::Impl::is_layout_left_or_right_v<ViewType>,
"View Layout must be either LayoutLeft or LayoutRight.");
} else {
static_assert(!KokkosFFT::Impl::is_layout_left_or_right_v<ViewType>,
"View Layout must be either LayoutLeft or LayoutRight.");
}
}

template <typename T, typename LayoutType>
void test_admissible_view_type() {
using ViewType = Kokkos::View<T*, LayoutType>;
using real_type = KokkosFFT::Impl::base_floating_point_type<T>;
if constexpr (
(std::is_same_v<real_type, float> || std::is_same_v<real_type, double>)&&(
std::is_same_v<LayoutType, Kokkos::LayoutLeft> ||
std::is_same_v<LayoutType, Kokkos::LayoutRight>)) {
static_assert(KokkosFFT::Impl::is_admissible_view_v<ViewType>,
"View value type must be float, double, "
"Kokkos::Complex<float>, Kokkos::Complex<double>. Layout "
"must be either LayoutLeft or LayoutRight.");
} else {
static_assert(!KokkosFFT::Impl::is_admissible_view_v<ViewType>,
"View value type must be float, double, "
"Kokkos::Complex<float>, Kokkos::Complex<double>. Layout "
"must be either LayoutLeft or LayoutRight.");
}
}

TYPED_TEST(RealAndComplexViewTypes, admissible_value_type) {
using real_type = typename TestFixture::real_type;
using complex_type = typename TestFixture::complex_type;
using layout_type = typename TestFixture::layout_type;

test_admissible_value_type<real_type, layout_type>();
test_admissible_value_type<complex_type, layout_type>();
}

TYPED_TEST(RealAndComplexViewTypes, admissible_layout_type) {
using real_type = typename TestFixture::real_type;
using complex_type = typename TestFixture::complex_type;
using layout_type = typename TestFixture::layout_type;

test_admissible_layout_type<real_type, layout_type>();
test_admissible_layout_type<complex_type, layout_type>();
}

TYPED_TEST(RealAndComplexViewTypes, admissible_view_type) {
using real_type = typename TestFixture::real_type;
using complex_type = typename TestFixture::complex_type;
using layout_type = typename TestFixture::layout_type;

test_admissible_view_type<real_type, layout_type>();
test_admissible_view_type<complex_type, layout_type>();
}
Loading

0 comments on commit 0e29048

Please sign in to comment.