Skip to content

Commit

Permalink
Merge pull request #26 from CExA-project/add-more-static-assertions
Browse files Browse the repository at this point in the history
Add more static assertions
  • Loading branch information
yasahi-hpc authored Jan 23, 2024
2 parents 231a82e + 189031a commit 15e4bad
Show file tree
Hide file tree
Showing 8 changed files with 1,061 additions and 163 deletions.
58 changes: 39 additions & 19 deletions common/src/KokkosFFT_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,21 @@ struct is_complex : std::false_type {};
template <typename T>
struct is_complex<Kokkos::complex<T>> : std::true_type {};

template <typename ViewType, typename Enable = void>
struct is_layout_left_or_right : std::false_type {};

template <typename ViewType>
struct is_layout_left_or_right<
ViewType,
std::enable_if_t<
std::is_same_v<typename ViewType::array_layout, Kokkos::LayoutLeft> ||
std::is_same_v<typename ViewType::array_layout, Kokkos::LayoutRight>>>
: std::true_type {};

template <typename ViewType>
inline constexpr bool is_layout_left_or_right_v =
is_layout_left_or_right<ViewType>::value;

template <typename ExecutionSpace, typename ViewType,
std::enable_if_t<ViewType::rank() == 1, std::nullptr_t> = nullptr>
struct complex_view_type {
Expand All @@ -40,9 +55,8 @@ struct complex_view_type {

template <typename ViewType>
auto convert_negative_axis(const ViewType& view, int _axis = -1) {
static_assert(
Kokkos::is_view<ViewType>::value,
"KokkosFFT::convert_negative_axis: ViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<ViewType>::value,
"convert_negative_axis: ViewType is not a Kokkos::View.");
int rank = static_cast<int>(ViewType::rank());
assert(abs(_axis) < rank); // axis should be in [-(rank-1), rank-1]
int axis = _axis < 0 ? rank + _axis : _axis;
Expand All @@ -51,9 +65,8 @@ auto convert_negative_axis(const ViewType& view, int _axis = -1) {

template <typename ViewType>
auto convert_negative_shift(const ViewType& view, int _shift, int _axis) {
static_assert(
Kokkos::is_view<ViewType>::value,
"KokkosFFT::convert_negative_shift: ViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<ViewType>::value,
"convert_negative_shift: ViewType is not a Kokkos::View.");
int axis = convert_negative_axis(view, _axis);
int extent = view.extent(axis);
int shift0 = 0, shift1 = 0, shift2 = extent / 2;
Expand Down Expand Up @@ -156,24 +169,18 @@ inline std::vector<ElementType> arange(const ElementType start,
template <typename ExecutionSpace, typename InViewType, typename OutViewType>
void conjugate(const ExecutionSpace& exec_space, const InViewType& in,
OutViewType& out) {
static_assert(Kokkos::is_view<InViewType>::value,
"conjugate: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
"conjugate: OutViewType is not a Kokkos::View.");

using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;

static_assert(KokkosFFT::Impl::is_complex<out_value_type>::value,
"KokkosFFT::Impl::conjugate: OutViewType must be complex");
"conjugate: OutViewType must be complex");
std::size_t size = in.size();

// [TO DO] is there a way to get device mirror?
if constexpr (InViewType::rank() == 1) {
out = OutViewType("out", in.extent(0));
} else if constexpr (InViewType::rank() == 2) {
out = OutViewType("out", in.extent(0), in.extent(1));
} else if constexpr (InViewType::rank() == 3) {
out = OutViewType("out", in.extent(0), in.extent(1), in.extent(2));
} else if constexpr (InViewType::rank() == 4) {
out = OutViewType("out", in.extent(0), in.extent(1), in.extent(2),
in.extent(3));
}
out = OutViewType("out", in.layout());

auto* in_data = in.data();
auto* out_data = out.data();
Expand All @@ -186,6 +193,19 @@ void conjugate(const ExecutionSpace& exec_space, const InViewType& in,
out_data[i] = Kokkos::conj(in_data[i]);
});
}

template <typename ViewType>
auto extract_extents(const ViewType& view) {
static_assert(Kokkos::is_view<ViewType>::value,
"extract_extents: ViewType is not a Kokkos::View.");
constexpr std::size_t rank = ViewType::rank();
std::array<std::size_t, rank> extents;
for (std::size_t i = 0; i < rank; i++) {
extents.at(i) = view.extent(i);
}
return extents;
}

} // namespace Impl
} // namespace KokkosFFT

Expand Down
40 changes: 40 additions & 0 deletions common/unit_test/Test_Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -252,4 +252,44 @@ TEST(IsOutOfRangeValueIncluded, Array) {
EXPECT_TRUE(KokkosFFT::Impl::is_out_of_range_value_included(v, 3));
EXPECT_FALSE(KokkosFFT::Impl::is_out_of_range_value_included(v, 4));
EXPECT_FALSE(KokkosFFT::Impl::is_out_of_range_value_included(v, 5));
}

TEST(ExtractExtents, 1Dto8D) {
using View1Dtype = Kokkos::View<double*, execution_space>;
using View2Dtype = Kokkos::View<double**, execution_space>;
using View3Dtype = Kokkos::View<double***, execution_space>;
using View4Dtype = Kokkos::View<double****, execution_space>;
using View5Dtype = Kokkos::View<double*****, execution_space>;
using View6Dtype = Kokkos::View<double******, execution_space>;
using View7Dtype = Kokkos::View<double*******, execution_space>;
using View8Dtype = Kokkos::View<double********, execution_space>;

std::size_t n1 = 1, n2 = 1, n3 = 2, n4 = 3, n5 = 5, n6 = 8, n7 = 13, n8 = 21;

std::array<std::size_t, 1> ref_extents1D = {n1};
std::array<std::size_t, 2> ref_extents2D = {n1, n2};
std::array<std::size_t, 3> ref_extents3D = {n1, n2, n3};
std::array<std::size_t, 4> ref_extents4D = {n1, n2, n3, n4};
std::array<std::size_t, 5> ref_extents5D = {n1, n2, n3, n4, n5};
std::array<std::size_t, 6> ref_extents6D = {n1, n2, n3, n4, n5, n6};
std::array<std::size_t, 7> ref_extents7D = {n1, n2, n3, n4, n5, n6, n7};
std::array<std::size_t, 8> ref_extents8D = {n1, n2, n3, n4, n5, n6, n7, n8};

View1Dtype view1D("view1D", n1);
View2Dtype view2D("view2D", n1, n2);
View3Dtype view3D("view3D", n1, n2, n3);
View4Dtype view4D("view4D", n1, n2, n3, n4);
View5Dtype view5D("view5D", n1, n2, n3, n4, n5);
View6Dtype view6D("view6D", n1, n2, n3, n4, n5, n6);
View7Dtype view7D("view7D", n1, n2, n3, n4, n5, n6, n7);
View8Dtype view8D("view8D", n1, n2, n3, n4, n5, n6, n7, n8);

EXPECT_EQ(KokkosFFT::Impl::extract_extents(view1D), ref_extents1D);
EXPECT_EQ(KokkosFFT::Impl::extract_extents(view2D), ref_extents2D);
EXPECT_EQ(KokkosFFT::Impl::extract_extents(view3D), ref_extents3D);
EXPECT_EQ(KokkosFFT::Impl::extract_extents(view4D), ref_extents4D);
EXPECT_EQ(KokkosFFT::Impl::extract_extents(view5D), ref_extents5D);
EXPECT_EQ(KokkosFFT::Impl::extract_extents(view6D), ref_extents6D);
EXPECT_EQ(KokkosFFT::Impl::extract_extents(view7D), ref_extents7D);
EXPECT_EQ(KokkosFFT::Impl::extract_extents(view8D), ref_extents8D);
}
15 changes: 15 additions & 0 deletions fft/src/KokkosFFT_Cuda_plans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ auto _create(const ExecutionSpace& exec_space, PlanType& plan,
cufftResult cufft_rt = cufftCreate(&plan);
if (cufft_rt != CUFFT_SUCCESS) throw std::runtime_error("cufftCreate failed");

cudaStream_t stream = exec_space.cuda_stream();
cufftSetStream(plan, stream);

const int batch = 1;
const int axis = 0;

Expand Down Expand Up @@ -61,6 +64,9 @@ auto _create(const ExecutionSpace& exec_space, PlanType& plan,
cufftResult cufft_rt = cufftCreate(&plan);
if (cufft_rt != CUFFT_SUCCESS) throw std::runtime_error("cufftCreate failed");

cudaStream_t stream = exec_space.cuda_stream();
cufftSetStream(plan, stream);

const int axis = 0;
auto type = KokkosFFT::Impl::transform_type<ExecutionSpace, in_value_type,
out_value_type>::type();
Expand Down Expand Up @@ -94,6 +100,9 @@ auto _create(const ExecutionSpace& exec_space, PlanType& plan,
cufftResult cufft_rt = cufftCreate(&plan);
if (cufft_rt != CUFFT_SUCCESS) throw std::runtime_error("cufftCreate failed");

cudaStream_t stream = exec_space.cuda_stream();
cufftSetStream(plan, stream);

const int axis = 0;

auto type = KokkosFFT::Impl::transform_type<ExecutionSpace, in_value_type,
Expand Down Expand Up @@ -130,6 +139,9 @@ auto _create(const ExecutionSpace& exec_space, PlanType& plan,
cufftResult cufft_rt = cufftCreate(&plan);
if (cufft_rt != CUFFT_SUCCESS) throw std::runtime_error("cufftCreate failed");

cudaStream_t stream = exec_space.cuda_stream();
cufftSetStream(plan, stream);

const int rank = InViewType::rank();
const int batch = 1;
const int axis = 0;
Expand Down Expand Up @@ -188,6 +200,9 @@ auto _create(const ExecutionSpace& exec_space, PlanType& plan,
cufftResult cufft_rt = cufftCreate(&plan);
if (cufft_rt != CUFFT_SUCCESS) throw std::runtime_error("cufftCreate failed");

cudaStream_t stream = exec_space.cuda_stream();
cufftSetStream(plan, stream);

cufft_rt =
cufftPlanMany(&plan, rank, fft_extents.data(), in_extents.data(), istride,
idist, out_extents.data(), ostride, odist, type, howmany);
Expand Down
15 changes: 15 additions & 0 deletions fft/src/KokkosFFT_HIP_plans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ auto _create(const ExecutionSpace& exec_space, PlanType& plan,
if (hipfft_rt != HIPFFT_SUCCESS)
throw std::runtime_error("hipfftCreate failed");

hipStream_t stream = exec_space.hip_stream();
hipfftSetStream(plan, stream);

const int batch = 1;
const int axis = 0;

Expand Down Expand Up @@ -64,6 +67,9 @@ auto _create(const ExecutionSpace& exec_space, PlanType& plan,
if (hipfft_rt != HIPFFT_SUCCESS)
throw std::runtime_error("hipfftCreate failed");

hipStream_t stream = exec_space.hip_stream();
hipfftSetStream(plan, stream);

const int axis = 0;
auto type = KokkosFFT::Impl::transform_type<ExecutionSpace, in_value_type,
out_value_type>::type();
Expand Down Expand Up @@ -99,6 +105,9 @@ auto _create(const ExecutionSpace& exec_space, PlanType& plan,
if (hipfft_rt != HIPFFT_SUCCESS)
throw std::runtime_error("hipfftCreate failed");

hipStream_t stream = exec_space.hip_stream();
hipfftSetStream(plan, stream);

const int batch = 1;
const int axis = 0;

Expand Down Expand Up @@ -138,6 +147,9 @@ auto _create(const ExecutionSpace& exec_space, PlanType& plan,
if (hipfft_rt != HIPFFT_SUCCESS)
throw std::runtime_error("hipfftCreate failed");

hipStream_t stream = exec_space.hip_stream();
hipfftSetStream(plan, stream);

const int rank = InViewType::rank();
const int batch = 1;
const int axis = 0;
Expand Down Expand Up @@ -199,6 +211,9 @@ auto _create(const ExecutionSpace& exec_space, PlanType& plan,
if (hipfft_rt != HIPFFT_SUCCESS)
throw std::runtime_error("hipfftCreate failed");

hipStream_t stream = exec_space.hip_stream();
hipfftSetStream(plan, stream);

hipfft_rt = hipfftPlanMany(&plan, rank, fft_extents.data(), in_extents.data(),
istride, idist, out_extents.data(), ostride, odist,
type, howmany);
Expand Down
36 changes: 27 additions & 9 deletions fft/src/KokkosFFT_Helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ template <typename ViewType, std::size_t DIM = 1>
auto _get_shift(const ViewType& inout, axis_type<DIM> _axes,
int direction = 1) {
static_assert(DIM > 0,
"KokkosFFT::Impl::_get_shift: Rank of shift axes must be "
"_get_shift: Rank of shift axes must be "
"larger than or equal to 1.");

// Convert the input axes to be in the range of [0, rank-1]
Expand All @@ -37,8 +37,7 @@ auto _get_shift(const ViewType& inout, axis_type<DIM> _axes,
template <typename ExecutionSpace, typename ViewType>
void _roll(const ExecutionSpace& exec_space, ViewType& inout,
axis_type<1> shift, axis_type<1> axes) {
static_assert(ViewType::rank() == 1,
"KokkosFFT::Impl::_roll: Rank of View must be 1.");
static_assert(ViewType::rank() == 1, "_roll: Rank of View must be 1.");
std::size_t n0 = inout.extent(0);

ViewType tmp("tmp", n0);
Expand Down Expand Up @@ -68,8 +67,7 @@ template <typename ExecutionSpace, typename ViewType, std::size_t DIM1 = 1>
void _roll(const ExecutionSpace& exec_space, ViewType& inout,
axis_type<2> shift, axis_type<DIM1> axes) {
constexpr std::size_t DIM0 = 2;
static_assert(ViewType::rank() == DIM0,
"KokkosFFT::Impl::_roll: Rank of View must be 2.");
static_assert(ViewType::rank() == DIM0, "_roll: Rank of View must be 2.");
int n0 = inout.extent(0), n1 = inout.extent(1);

ViewType tmp("tmp", n0, n1);
Expand Down Expand Up @@ -129,8 +127,18 @@ void _roll(const ExecutionSpace& exec_space, ViewType& inout,
template <typename ExecutionSpace, typename ViewType, std::size_t DIM = 1>
void _fftshift(const ExecutionSpace& exec_space, ViewType& inout,
axis_type<DIM> axes) {
static_assert(Kokkos::is_view<ViewType>::value,
"_fftshift: ViewType is not a Kokkos::View.");
static_assert(
KokkosFFT::Impl::is_layout_left_or_right_v<ViewType>,
"_fftshift: ViewType must be either LayoutLeft or LayoutRight.");
static_assert(
Kokkos::SpaceAccessibility<ExecutionSpace,
typename ViewType::memory_space>::accessible,
"_fftshift: execution_space cannot access data in ViewType");

static_assert(ViewType::rank() >= DIM,
"KokkosFFT::Impl::_fftshift: Rank of View must be larger thane "
"_fftshift: Rank of View must be larger thane "
"or equal to the Rank of shift axes.");
auto shift = _get_shift(inout, axes);
_roll(exec_space, inout, shift, axes);
Expand All @@ -139,8 +147,18 @@ void _fftshift(const ExecutionSpace& exec_space, ViewType& inout,
template <typename ExecutionSpace, typename ViewType, std::size_t DIM = 1>
void _ifftshift(const ExecutionSpace& exec_space, ViewType& inout,
axis_type<DIM> axes) {
static_assert(Kokkos::is_view<ViewType>::value,
"_ifftshift: ViewType is not a Kokkos::View.");
static_assert(
KokkosFFT::Impl::is_layout_left_or_right_v<ViewType>,
"_ifftshift: ViewType must be either LayoutLeft or LayoutRight.");
static_assert(
Kokkos::SpaceAccessibility<ExecutionSpace,
typename ViewType::memory_space>::accessible,
"_ifftshift: execution_space cannot access data in ViewType");

static_assert(ViewType::rank() >= DIM,
"KokkosFFT::Impl::_ifftshift: Rank of View must be larger "
"_ifftshift: Rank of View must be larger "
"thane or equal to the Rank of shift axes.");
auto shift = _get_shift(inout, axes, -1);
_roll(exec_space, inout, shift, axes);
Expand All @@ -153,7 +171,7 @@ template <typename ExecutionSpace, typename RealType>
auto fftfreq(const ExecutionSpace& exec_space, const std::size_t n,
const RealType d = 1.0) {
static_assert(std::is_floating_point<RealType>::value,
"KokkosFFT::fftfreq: d must be real");
"fftfreq: d must be float or double");
using ViewType = Kokkos::View<RealType*, ExecutionSpace>;
ViewType freq("freq", n);

Expand Down Expand Up @@ -181,7 +199,7 @@ template <typename ExecutionSpace, typename RealType>
auto rfftfreq(const ExecutionSpace& exec_space, const std::size_t n,
const RealType d = 1.0) {
static_assert(std::is_floating_point<RealType>::value,
"KokkosFFT::fftfreq: d must be real");
"fftfreq: d must be float or double");
using ViewType = Kokkos::View<RealType*, ExecutionSpace>;

RealType val = 1.0 / (static_cast<RealType>(n) * d);
Expand Down
Loading

0 comments on commit 15e4bad

Please sign in to comment.