Skip to content

Commit

Permalink
Merge pull request #51 from CExA-project/add-high-dimensionality-support
Browse files Browse the repository at this point in the history
Add high dimensionality support
  • Loading branch information
yasahi-hpc authored Feb 15, 2024
2 parents caf52c1 + cf75649 commit 36a5df2
Show file tree
Hide file tree
Showing 6 changed files with 2,059 additions and 173 deletions.
194 changes: 178 additions & 16 deletions common/src/KokkosFFT_padding.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ namespace Impl {
template <typename ViewType, std::size_t DIM>
auto get_modified_shape(const ViewType& view, shape_type<DIM> shape) {
static_assert(ViewType::rank() >= DIM,
"KokkosFFT::get_modified_shape: Rank of View must be larger "
"get_modified_shape: Rank of View must be larger "
"than or equal to the Rank of new shape");
static_assert(DIM > 0,
"KokkosFFT::get_modified_shape: Rank of FFT axes must be "
"get_modified_shape: Rank of FFT axes must be "
"larger than or equal to 1");

// [TO DO] Add a is_C2R arg. If is_C2R is true, then shape should be shape/2+1
Expand Down Expand Up @@ -41,7 +41,7 @@ template <typename ViewType, std::size_t DIM>
auto is_crop_or_pad_needed(const ViewType& view,
const shape_type<DIM>& modified_shape) {
static_assert(ViewType::rank() == DIM,
"KokkosFFT::_crop_or_pad: Rank of View must be equal to Rank "
"is_crop_or_pad_needed: Rank of View must be equal to Rank "
"of extended shape.");

// [TO DO] Add a is_C2R arg. If is_C2R is true, then shape should be shape/2+1
Expand All @@ -61,11 +61,6 @@ auto is_crop_or_pad_needed(const ViewType& view,
template <typename ExecutionSpace, typename ViewType>
void _crop_or_pad(const ExecutionSpace& exec_space, const ViewType& in,
ViewType& out, shape_type<1> s) {
constexpr std::size_t DIM = 1;
static_assert(ViewType::rank() == DIM,
"KokkosFFT::_crop_or_pad: Rank of View must be equal to Rank "
"of extended shape.");

auto _n0 = s.at(0);
out = ViewType("out", _n0);

Expand All @@ -81,9 +76,6 @@ template <typename ExecutionSpace, typename ViewType>
void _crop_or_pad(const ExecutionSpace& exec_space, const ViewType& in,
ViewType& out, shape_type<2> s) {
constexpr std::size_t DIM = 2;
static_assert(ViewType::rank() == DIM,
"KokkosFFT::_crop_or_pad: Rank of View must be equal to Rank "
"of extended shape.");

auto [_n0, _n1] = s;
out = ViewType("out", _n0, _n1);
Expand All @@ -93,7 +85,7 @@ void _crop_or_pad(const ExecutionSpace& exec_space, const ViewType& in,

using range_type = Kokkos::MDRangePolicy<
ExecutionSpace,
Kokkos::Rank<2, Kokkos::Iterate::Default, Kokkos::Iterate::Default>>;
Kokkos::Rank<DIM, Kokkos::Iterate::Default, Kokkos::Iterate::Default>>;
using tile_type = typename range_type::tile_type;
using point_type = typename range_type::point_type;

Expand All @@ -109,9 +101,6 @@ template <typename ExecutionSpace, typename ViewType>
void _crop_or_pad(const ExecutionSpace& exec_space, const ViewType& in,
ViewType& out, shape_type<3> s) {
constexpr std::size_t DIM = 3;
static_assert(ViewType::rank() == DIM,
"KokkosFFT::_crop_or_pad: Rank of View must be equal to Rank "
"of extended shape.");

auto [_n0, _n1, _n2] = s;
out = ViewType("out", _n0, _n1, _n2);
Expand All @@ -122,7 +111,7 @@ void _crop_or_pad(const ExecutionSpace& exec_space, const ViewType& in,

using range_type = Kokkos::MDRangePolicy<
ExecutionSpace,
Kokkos::Rank<3, Kokkos::Iterate::Default, Kokkos::Iterate::Default>>;
Kokkos::Rank<DIM, Kokkos::Iterate::Default, Kokkos::Iterate::Default>>;
using tile_type = typename range_type::tile_type;
using point_type = typename range_type::point_type;

Expand All @@ -137,9 +126,182 @@ void _crop_or_pad(const ExecutionSpace& exec_space, const ViewType& in,
});
}

// Copies the overlapping region of a rank-4 View into a freshly allocated
// View whose extents are given by s.
//
// exec_space: execution space instance on which the copy kernel is launched
// in:         source View
// out:        destination View; reassigned to a new allocation labelled "out"
//             with extents (s[0], s[1], s[2], s[3])
// s:          requested extents of out
//
// Only indices below min(s[i], in.extent(i)) are copied in each dimension.
// Every other element of out keeps whatever value the ViewType constructor
// gave it (presumably zero-initialised - confirm against the View type used).
template <typename ExecutionSpace, typename ViewType>
void _crop_or_pad(const ExecutionSpace& exec_space, const ViewType& in,
                  ViewType& out, shape_type<4> s) {
  constexpr std::size_t DIM = 4;

  auto [_n0, _n1, _n2, _n3] = s;
  out                       = ViewType("out", _n0, _n1, _n2, _n3);

  // Extents of the region present in both in and out.
  int n0 = std::min(_n0, in.extent(0));
  int n1 = std::min(_n1, in.extent(1));
  int n2 = std::min(_n2, in.extent(2));
  int n3 = std::min(_n3, in.extent(3));

  using range_type = Kokkos::MDRangePolicy<
      ExecutionSpace,
      Kokkos::Rank<DIM, Kokkos::Iterate::Default, Kokkos::Iterate::Default>>;
  using tile_type  = typename range_type::tile_type;
  using point_type = typename range_type::point_type;

  // Bind the policy to exec_space so that the copy runs on the instance the
  // caller supplied instead of a default-constructed ExecutionSpace.
  range_type range(exec_space, point_type{{0, 0, 0, 0}},
                   point_type{{n0, n1, n2, n3}}, tile_type{{4, 4, 4, 4}}
                   // [TO DO] Choose optimal tile sizes for each device
  );

  Kokkos::parallel_for(
      range, KOKKOS_LAMBDA(int i0, int i1, int i2, int i3) {
        out(i0, i1, i2, i3) = in(i0, i1, i2, i3);
      });
}

// Copies the overlapping region of a rank-5 View into a freshly allocated
// View whose extents are given by s.
//
// exec_space: execution space instance on which the copy kernel is launched
// in:         source View
// out:        destination View; reassigned to a new allocation labelled "out"
//             with extents (s[0], ..., s[4])
// s:          requested extents of out
//
// Only indices below min(s[i], in.extent(i)) are copied in each dimension.
// Every other element of out keeps whatever value the ViewType constructor
// gave it (presumably zero-initialised - confirm against the View type used).
template <typename ExecutionSpace, typename ViewType>
void _crop_or_pad(const ExecutionSpace& exec_space, const ViewType& in,
                  ViewType& out, shape_type<5> s) {
  constexpr std::size_t DIM = 5;

  auto [_n0, _n1, _n2, _n3, _n4] = s;
  out                            = ViewType("out", _n0, _n1, _n2, _n3, _n4);

  // Extents of the region present in both in and out.
  int n0 = std::min(_n0, in.extent(0));
  int n1 = std::min(_n1, in.extent(1));
  int n2 = std::min(_n2, in.extent(2));
  int n3 = std::min(_n3, in.extent(3));
  int n4 = std::min(_n4, in.extent(4));

  using range_type = Kokkos::MDRangePolicy<
      ExecutionSpace,
      Kokkos::Rank<DIM, Kokkos::Iterate::Default, Kokkos::Iterate::Default>>;
  using tile_type  = typename range_type::tile_type;
  using point_type = typename range_type::point_type;

  // Bind the policy to exec_space so that the copy runs on the instance the
  // caller supplied instead of a default-constructed ExecutionSpace.
  range_type range(exec_space, point_type{{0, 0, 0, 0, 0}},
                   point_type{{n0, n1, n2, n3, n4}},
                   tile_type{{4, 4, 4, 4, 1}}
                   // [TO DO] Choose optimal tile sizes for each device
  );

  Kokkos::parallel_for(
      range, KOKKOS_LAMBDA(int i0, int i1, int i2, int i3, int i4) {
        out(i0, i1, i2, i3, i4) = in(i0, i1, i2, i3, i4);
      });
}

// Copies the overlapping region of a rank-6 View into a freshly allocated
// View whose extents are given by s.
//
// exec_space: execution space instance on which the copy kernel is launched
// in:         source View
// out:        destination View; reassigned to a new allocation labelled "out"
//             with extents (s[0], ..., s[5])
// s:          requested extents of out
//
// Only indices below min(s[i], in.extent(i)) are copied in each dimension.
// Every other element of out keeps whatever value the ViewType constructor
// gave it (presumably zero-initialised - confirm against the View type used).
template <typename ExecutionSpace, typename ViewType>
void _crop_or_pad(const ExecutionSpace& exec_space, const ViewType& in,
                  ViewType& out, shape_type<6> s) {
  constexpr std::size_t DIM = 6;

  auto [_n0, _n1, _n2, _n3, _n4, _n5] = s;
  out = ViewType("out", _n0, _n1, _n2, _n3, _n4, _n5);

  // Extents of the region present in both in and out.
  int n0 = std::min(_n0, in.extent(0));
  int n1 = std::min(_n1, in.extent(1));
  int n2 = std::min(_n2, in.extent(2));
  int n3 = std::min(_n3, in.extent(3));
  int n4 = std::min(_n4, in.extent(4));
  int n5 = std::min(_n5, in.extent(5));

  using range_type = Kokkos::MDRangePolicy<
      ExecutionSpace,
      Kokkos::Rank<DIM, Kokkos::Iterate::Default, Kokkos::Iterate::Default>>;
  using tile_type  = typename range_type::tile_type;
  using point_type = typename range_type::point_type;

  // Bind the policy to exec_space so that the copy runs on the instance the
  // caller supplied instead of a default-constructed ExecutionSpace.
  range_type range(exec_space, point_type{{0, 0, 0, 0, 0, 0}},
                   point_type{{n0, n1, n2, n3, n4, n5}},
                   tile_type{{4, 4, 4, 4, 1, 1}}
                   // [TO DO] Choose optimal tile sizes for each device
  );

  Kokkos::parallel_for(
      range, KOKKOS_LAMBDA(int i0, int i1, int i2, int i3, int i4, int i5) {
        out(i0, i1, i2, i3, i4, i5) = in(i0, i1, i2, i3, i4, i5);
      });
}

// Copies the overlapping region of a rank-7 View into a freshly allocated
// View whose extents are given by s.
//
// exec_space: execution space instance on which the copy kernel is launched
// in:         source View
// out:        destination View; reassigned to a new allocation labelled "out"
//             with extents (s[0], ..., s[6])
// s:          requested extents of out
//
// Only indices below min(s[i], in.extent(i)) are copied in each dimension.
// Every other element of out keeps whatever value the ViewType constructor
// gave it (presumably zero-initialised - confirm against the View type used).
template <typename ExecutionSpace, typename ViewType>
void _crop_or_pad(const ExecutionSpace& exec_space, const ViewType& in,
                  ViewType& out, shape_type<7> s) {
  // DIM is the rank of the iteration policy, not of the View: the policy
  // spans the six leading dimensions and the last one is looped over inside
  // the kernel (presumably because MDRangePolicy is limited to rank 6).
  constexpr std::size_t DIM = 6;

  auto [_n0, _n1, _n2, _n3, _n4, _n5, _n6] = s;
  out = ViewType("out", _n0, _n1, _n2, _n3, _n4, _n5, _n6);

  // Extents of the region present in both in and out.
  int n0 = std::min(_n0, in.extent(0));
  int n1 = std::min(_n1, in.extent(1));
  int n2 = std::min(_n2, in.extent(2));
  int n3 = std::min(_n3, in.extent(3));
  int n4 = std::min(_n4, in.extent(4));
  int n5 = std::min(_n5, in.extent(5));
  int n6 = std::min(_n6, in.extent(6));

  using range_type = Kokkos::MDRangePolicy<
      ExecutionSpace,
      Kokkos::Rank<DIM, Kokkos::Iterate::Default, Kokkos::Iterate::Default>>;
  using tile_type  = typename range_type::tile_type;
  using point_type = typename range_type::point_type;

  // Bind the policy to exec_space so that the copy runs on the instance the
  // caller supplied instead of a default-constructed ExecutionSpace.
  range_type range(exec_space, point_type{{0, 0, 0, 0, 0, 0}},
                   point_type{{n0, n1, n2, n3, n4, n5}},
                   tile_type{{4, 4, 4, 4, 1, 1}}
                   // [TO DO] Choose optimal tile sizes for each device
  );

  Kokkos::parallel_for(
      range, KOKKOS_LAMBDA(int i0, int i1, int i2, int i3, int i4, int i5) {
        // Seventh dimension is traversed serially by each work item.
        for (int i6 = 0; i6 < n6; i6++) {
          out(i0, i1, i2, i3, i4, i5, i6) = in(i0, i1, i2, i3, i4, i5, i6);
        }
      });
}

// Copies the overlapping region of a rank-8 View into a freshly allocated
// View whose extents are given by s.
//
// exec_space: execution space instance on which the copy kernel is launched
// in:         source View
// out:        destination View; reassigned to a new allocation labelled "out"
//             with extents (s[0], ..., s[7])
// s:          requested extents of out
//
// Only indices below min(s[i], in.extent(i)) are copied in each dimension.
// Every other element of out keeps whatever value the ViewType constructor
// gave it (presumably zero-initialised - confirm against the View type used).
template <typename ExecutionSpace, typename ViewType>
void _crop_or_pad(const ExecutionSpace& exec_space, const ViewType& in,
                  ViewType& out, shape_type<8> s) {
  // DIM is the rank of the iteration policy, not of the View: the policy
  // spans the six leading dimensions and the last two are looped over inside
  // the kernel (presumably because MDRangePolicy is limited to rank 6).
  constexpr std::size_t DIM = 6;

  auto [_n0, _n1, _n2, _n3, _n4, _n5, _n6, _n7] = s;
  out = ViewType("out", _n0, _n1, _n2, _n3, _n4, _n5, _n6, _n7);

  // Extents of the region present in both in and out.
  int n0 = std::min(_n0, in.extent(0));
  int n1 = std::min(_n1, in.extent(1));
  int n2 = std::min(_n2, in.extent(2));
  int n3 = std::min(_n3, in.extent(3));
  int n4 = std::min(_n4, in.extent(4));
  int n5 = std::min(_n5, in.extent(5));
  int n6 = std::min(_n6, in.extent(6));
  int n7 = std::min(_n7, in.extent(7));

  using range_type = Kokkos::MDRangePolicy<
      ExecutionSpace,
      Kokkos::Rank<DIM, Kokkos::Iterate::Default, Kokkos::Iterate::Default>>;
  using tile_type  = typename range_type::tile_type;
  using point_type = typename range_type::point_type;

  // Bind the policy to exec_space so that the copy runs on the instance the
  // caller supplied instead of a default-constructed ExecutionSpace.
  range_type range(exec_space, point_type{{0, 0, 0, 0, 0, 0}},
                   point_type{{n0, n1, n2, n3, n4, n5}},
                   tile_type{{4, 4, 4, 4, 1, 1}}
                   // [TO DO] Choose optimal tile sizes for each device
  );

  Kokkos::parallel_for(
      range, KOKKOS_LAMBDA(int i0, int i1, int i2, int i3, int i4, int i5) {
        // Seventh and eighth dimensions are traversed serially by each work
        // item.
        for (int i6 = 0; i6 < n6; i6++) {
          for (int i7 = 0; i7 < n7; i7++) {
            out(i0, i1, i2, i3, i4, i5, i6, i7) =
                in(i0, i1, i2, i3, i4, i5, i6, i7);
          }
        }
      });
}

// Rank-checked entry point for cropping/padding a View.
//
// Verifies at compile time that the rank of ViewType equals the number of
// entries in s, then forwards all arguments to the _crop_or_pad overload
// selected by shape_type<DIM>.
//
// exec_space: execution space instance forwarded to _crop_or_pad
// in:         source View
// out:        destination View, forwarded to _crop_or_pad
// s:          requested extents, one entry per View dimension
template <typename ExecutionSpace, typename ViewType, std::size_t DIM = 1>
void crop_or_pad(const ExecutionSpace& exec_space, const ViewType& in,
                 ViewType& out, shape_type<DIM> s) {
  constexpr bool rank_matches_shape = (ViewType::rank() == DIM);
  static_assert(rank_matches_shape,
                "crop_or_pad: Rank of View must be equal to Rank "
                "of extended shape.");

  _crop_or_pad(exec_space, in, out, s);
}
} // namespace Impl
Expand Down
Loading

0 comments on commit 36a5df2

Please sign in to comment.