Skip to content

Commit

Permalink
Merge pull request #96 from kokkos/create-plan-with-shape-arg
Browse files Browse the repository at this point in the history
Create plan with shape arg
  • Loading branch information
yasahi-hpc authored Apr 15, 2024
2 parents b15bc9f + 87c0bff commit 8d822ea
Show file tree
Hide file tree
Showing 9 changed files with 148 additions and 187 deletions.
28 changes: 19 additions & 9 deletions common/src/KokkosFFT_layouts.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "KokkosFFT_common_types.hpp"
#include "KokkosFFT_utils.hpp"
#include "KokkosFFT_transpose.hpp"
#include "KokkosFFT_padding.hpp"

namespace KokkosFFT {
namespace Impl {
Expand All @@ -20,14 +21,11 @@ namespace Impl {
*/
template <typename InViewType, typename OutViewType, std::size_t DIM = 1>
auto get_extents(const InViewType& in, const OutViewType& out,
axis_type<DIM> _axes) {
axis_type<DIM> axes, shape_type<DIM> shape = {0}) {
using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;
using array_layout_type = typename InViewType::array_layout;

// index map after transpose over axis
auto [map, map_inv] = KokkosFFT::Impl::get_map_axes(in, _axes);

static_assert(InViewType::rank() >= DIM,
"KokkosFFT::get_map_axes: Rank of View must be larger thane or "
"equal to the Rank of FFT axes.");
Expand All @@ -41,20 +39,32 @@ auto get_extents(const InViewType& in, const OutViewType& out,
? 0
: (rank - 1);

std::vector<int> _in_extents, _out_extents, _fft_extents;
// index map after transpose over axis
auto [map, map_inv] = KokkosFFT::Impl::get_map_axes(in, axes);

// Get new shape based on shape parameter
// [TO DO] get_modified shape should take out as well and check is_C2R
// internally
bool is_C2R = is_complex<in_value_type>::value &&
std::is_floating_point<out_value_type>::value;
auto modified_in_shape =
KokkosFFT::Impl::get_modified_shape(in, shape, axes, is_C2R);

// Get extents for the inner most axes in LayoutRight
// If we allow the FFT on the layoutLeft, this part should be modified
std::vector<int> _in_extents, _out_extents, _fft_extents;
for (std::size_t i = 0; i < rank; i++) {
auto _idx = map.at(i);
_in_extents.push_back(in.extent(_idx));
_out_extents.push_back(out.extent(_idx));
auto _idx = map.at(i);
auto in_extent = modified_in_shape.at(_idx);
auto out_extent = out.extent(_idx);
_in_extents.push_back(in_extent);
_out_extents.push_back(out_extent);

// The extent for transform is always equal to the extent
// of the extent of real type (R2C or C2R)
// For C2C, the in and out extents are the same.
// In the end, we can just use the largest extent among in and out extents.
auto fft_extent = std::max(in.extent(_idx), out.extent(_idx));
auto fft_extent = std::max(in_extent, out_extent);
_fft_extents.push_back(fft_extent);
}

Expand Down
5 changes: 5 additions & 0 deletions common/src/KokkosFFT_padding.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ auto get_modified_shape(const ViewType& view, shape_type<DIM> shape,
"larger than or equal to 1");
constexpr int rank = static_cast<int>(ViewType::rank());

shape_type<DIM> zeros = {0}; // default shape means no crop or pad
if (shape == zeros) {
return KokkosFFT::Impl::extract_extents(view);
}

// Convert the input axes to be in the range of [0, rank-1]
std::vector<int> positive_axes;
for (std::size_t i = 0; i < DIM; i++) {
Expand Down
20 changes: 12 additions & 8 deletions fft/src/KokkosFFT_Cuda_plans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
const InViewType& in, const OutViewType& out,
[[maybe_unused]] BufferViewType& buffer,
[[maybe_unused]] InfoType& execution_info,
[[maybe_unused]] Direction direction, axis_type<1> axes) {
[[maybe_unused]] Direction direction, axis_type<1> axes,
shape_type<1> s) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::_create: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<InViewType>::value,
Expand All @@ -39,7 +40,7 @@ auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
auto type = KokkosFFT::Impl::transform_type<ExecutionSpace, in_value_type,
out_value_type>::type();
auto [in_extents, out_extents, fft_extents, howmany] =
KokkosFFT::Impl::get_extents(in, out, axes);
KokkosFFT::Impl::get_extents(in, out, axes, s);
const int nx = fft_extents.at(0);
int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1,
std::multiplies<>());
Expand All @@ -59,7 +60,8 @@ auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
const InViewType& in, const OutViewType& out,
[[maybe_unused]] BufferViewType& buffer,
[[maybe_unused]] InfoType& execution_info,
[[maybe_unused]] Direction direction, axis_type<2> axes) {
[[maybe_unused]] Direction direction, axis_type<2> axes,
shape_type<2> s) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::_create: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<InViewType>::value,
Expand All @@ -77,7 +79,7 @@ auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
auto type = KokkosFFT::Impl::transform_type<ExecutionSpace, in_value_type,
out_value_type>::type();
[[maybe_unused]] auto [in_extents, out_extents, fft_extents, howmany] =
KokkosFFT::Impl::get_extents(in, out, axes);
KokkosFFT::Impl::get_extents(in, out, axes, s);
const int nx = fft_extents.at(0), ny = fft_extents.at(1);
int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1,
std::multiplies<>());
Expand All @@ -97,7 +99,8 @@ auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
const InViewType& in, const OutViewType& out,
[[maybe_unused]] BufferViewType& buffer,
[[maybe_unused]] InfoType& execution_info,
[[maybe_unused]] Direction direction, axis_type<3> axes) {
[[maybe_unused]] Direction direction, axis_type<3> axes,
shape_type<3> s) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::_create: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<InViewType>::value,
Expand All @@ -115,7 +118,7 @@ auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
auto type = KokkosFFT::Impl::transform_type<ExecutionSpace, in_value_type,
out_value_type>::type();
[[maybe_unused]] auto [in_extents, out_extents, fft_extents, howmany] =
KokkosFFT::Impl::get_extents(in, out, axes);
KokkosFFT::Impl::get_extents(in, out, axes, s);

const int nx = fft_extents.at(0), ny = fft_extents.at(1),
nz = fft_extents.at(2);
Expand All @@ -137,7 +140,8 @@ auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
const InViewType& in, const OutViewType& out,
[[maybe_unused]] BufferViewType& buffer,
[[maybe_unused]] InfoType& execution_info,
[[maybe_unused]] Direction direction, axis_type<fft_rank> axes) {
[[maybe_unused]] Direction direction, axis_type<fft_rank> axes,
shape_type<fft_rank> s) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::_create: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<InViewType>::value,
Expand All @@ -153,7 +157,7 @@ auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
KokkosFFT::Impl::transform_type<ExecutionSpace, in_value_type,
out_value_type>::type();
auto [in_extents, out_extents, fft_extents, howmany] =
KokkosFFT::Impl::get_extents(in, out, axes);
KokkosFFT::Impl::get_extents(in, out, axes, s);
int idist = std::accumulate(in_extents.begin(), in_extents.end(), 1,
std::multiplies<>());
int odist = std::accumulate(out_extents.begin(), out_extents.end(), 1,
Expand Down
20 changes: 12 additions & 8 deletions fft/src/KokkosFFT_HIP_plans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
const InViewType& in, const OutViewType& out,
[[maybe_unused]] BufferViewType& buffer,
[[maybe_unused]] InfoType& execution_info,
[[maybe_unused]] Direction direction, axis_type<1> axes) {
[[maybe_unused]] Direction direction, axis_type<1> axes,
shape_type<1> s) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::_create: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<InViewType>::value,
Expand All @@ -40,7 +41,7 @@ auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
auto type = KokkosFFT::Impl::transform_type<ExecutionSpace, in_value_type,
out_value_type>::type();
auto [in_extents, out_extents, fft_extents, howmany] =
KokkosFFT::Impl::get_extents(in, out, axes);
KokkosFFT::Impl::get_extents(in, out, axes, s);
const int nx = fft_extents.at(0);
int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1,
std::multiplies<>());
Expand All @@ -61,7 +62,8 @@ auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
const InViewType& in, const OutViewType& out,
[[maybe_unused]] BufferViewType& buffer,
[[maybe_unused]] InfoType& execution_info,
[[maybe_unused]] Direction direction, axis_type<2> axes) {
[[maybe_unused]] Direction direction, axis_type<2> axes,
shape_type<2> s) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::_create: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<InViewType>::value,
Expand All @@ -80,7 +82,7 @@ auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
auto type = KokkosFFT::Impl::transform_type<ExecutionSpace, in_value_type,
out_value_type>::type();
[[maybe_unused]] auto [in_extents, out_extents, fft_extents, howmany] =
KokkosFFT::Impl::get_extents(in, out, axes);
KokkosFFT::Impl::get_extents(in, out, axes, s);
const int nx = fft_extents.at(0), ny = fft_extents.at(1);
int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1,
std::multiplies<>());
Expand All @@ -101,7 +103,8 @@ auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
const InViewType& in, const OutViewType& out,
[[maybe_unused]] BufferViewType& buffer,
[[maybe_unused]] InfoType& execution_info,
[[maybe_unused]] Direction direction, axis_type<3> axes) {
[[maybe_unused]] Direction direction, axis_type<3> axes,
shape_type<3> s) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::_create: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<InViewType>::value,
Expand All @@ -120,7 +123,7 @@ auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
auto type = KokkosFFT::Impl::transform_type<ExecutionSpace, in_value_type,
out_value_type>::type();
[[maybe_unused]] auto [in_extents, out_extents, fft_extents, howmany] =
KokkosFFT::Impl::get_extents(in, out, axes);
KokkosFFT::Impl::get_extents(in, out, axes, s);

const int nx = fft_extents.at(0), ny = fft_extents.at(1),
nz = fft_extents.at(2);
Expand All @@ -143,7 +146,8 @@ auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
const InViewType& in, const OutViewType& out,
[[maybe_unused]] BufferViewType& buffer,
[[maybe_unused]] InfoType& execution_info,
[[maybe_unused]] Direction direction, axis_type<fft_rank> axes) {
[[maybe_unused]] Direction direction, axis_type<fft_rank> axes,
shape_type<fft_rank> s) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::_create: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<InViewType>::value,
Expand All @@ -159,7 +163,7 @@ auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
KokkosFFT::Impl::transform_type<ExecutionSpace, in_value_type,
out_value_type>::type();
auto [in_extents, out_extents, fft_extents, howmany] =
KokkosFFT::Impl::get_extents(in, out, axes);
KokkosFFT::Impl::get_extents(in, out, axes, s);
int idist = std::accumulate(in_extents.begin(), in_extents.end(), 1,
std::multiplies<>());
int odist = std::accumulate(out_extents.begin(), out_extents.end(), 1,
Expand Down
5 changes: 3 additions & 2 deletions fft/src/KokkosFFT_OpenMP_plans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
const InViewType& in, const OutViewType& out,
[[maybe_unused]] BufferViewType& buffer,
[[maybe_unused]] InfoType& execution_info,
[[maybe_unused]] Direction direction, axis_type<fft_rank> axes) {
[[maybe_unused]] Direction direction, axis_type<fft_rank> axes,
shape_type<fft_rank> s) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::_create: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<InViewType>::value,
Expand All @@ -57,7 +58,7 @@ auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
KokkosFFT::Impl::transform_type<ExecutionSpace, in_value_type,
out_value_type>::type();
auto [in_extents, out_extents, fft_extents, howmany] =
KokkosFFT::Impl::get_extents(in, out, axes);
KokkosFFT::Impl::get_extents(in, out, axes, s);
int idist = std::accumulate(in_extents.begin(), in_extents.end(), 1,
std::multiplies<>());
int odist = std::accumulate(out_extents.begin(), out_extents.end(), 1,
Expand Down
30 changes: 27 additions & 3 deletions fft/src/KokkosFFT_Plans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <Kokkos_Core.hpp>
#include "KokkosFFT_default_types.hpp"
#include "KokkosFFT_transpose.hpp"
#include "KokkosFFT_padding.hpp"
#include "KokkosFFT_utils.hpp"

#if defined(KOKKOS_ENABLE_CUDA)
Expand Down Expand Up @@ -117,11 +118,14 @@ class Plan {
//! whether transpose is needed or not
bool m_is_transpose_needed;

//! whether crop or pad is needed or not
bool m_is_crop_or_pad_needed;

//! axes for fft
axis_type<DIM> m_axes;

//! Shape of the transformed axis of the output
shape_type<DIM> m_shape;
extents_type m_shape;

//! directions of fft
KokkosFFT::Direction m_direction;
Expand Down Expand Up @@ -186,12 +190,24 @@ class Plan {
ExecutionSpace, typename OutViewType::memory_space>::accessible,
"Plan::Plan: execution_space cannot access data in OutViewType");

shape_type<1> s = {0};
if (n) {
std::size_t _n = n.value();
s = shape_type<1>({_n});
}

bool is_C2R = is_complex<in_value_type>::value &&
std::is_floating_point<out_value_type>::value;

m_in_extents = KokkosFFT::Impl::extract_extents(in);
m_out_extents = KokkosFFT::Impl::extract_extents(out);
std::tie(m_map, m_map_inv) = KokkosFFT::Impl::get_map_axes(in, axis);
m_is_transpose_needed = KokkosFFT::Impl::is_transpose_needed(m_map);
m_shape = KokkosFFT::Impl::get_modified_shape(in, s, m_axes, is_C2R);
m_is_crop_or_pad_needed =
KokkosFFT::Impl::is_crop_or_pad_needed(in, m_shape);
m_fft_size = KokkosFFT::Impl::_create(exec_space, m_plan, in, out, m_buffer,
m_info, direction, m_axes);
m_info, direction, m_axes, s);
}

/// \brief Constructor for multidimensional FFT
Expand Down Expand Up @@ -240,12 +256,18 @@ class Plan {
ExecutionSpace, typename OutViewType::memory_space>::accessible,
"Plan::Plan: execution_space cannot access data in OutViewType");

bool is_C2R = is_complex<in_value_type>::value &&
std::is_floating_point<out_value_type>::value;

m_in_extents = KokkosFFT::Impl::extract_extents(in);
m_out_extents = KokkosFFT::Impl::extract_extents(out);
std::tie(m_map, m_map_inv) = KokkosFFT::Impl::get_map_axes(in, axes);
m_is_transpose_needed = KokkosFFT::Impl::is_transpose_needed(m_map);
m_shape = KokkosFFT::Impl::get_modified_shape(in, s, m_axes, is_C2R);
m_is_crop_or_pad_needed =
KokkosFFT::Impl::is_crop_or_pad_needed(in, m_shape);
m_fft_size = KokkosFFT::Impl::_create(exec_space, m_plan, in, out, m_buffer,
m_info, direction, axes);
m_info, direction, axes, s);
}

~Plan() {
Expand Down Expand Up @@ -331,6 +353,8 @@ class Plan {
fft_size_type fft_size() const { return m_fft_size; }
KokkosFFT::Direction direction() const { return m_direction; }
bool is_transpose_needed() const { return m_is_transpose_needed; }
bool is_crop_or_pad_needed() const { return m_is_crop_or_pad_needed; }
extents_type shape() const { return m_shape; }
map_type map() const { return m_map; }
map_type map_inv() const { return m_map_inv; }
nonConstInViewType& in_T() { return m_in_T; }
Expand Down
5 changes: 3 additions & 2 deletions fft/src/KokkosFFT_ROCM_plans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ template <typename ExecutionSpace, typename PlanType, typename InViewType,
auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
const InViewType& in, const OutViewType& out,
BufferViewType& buffer, InfoType& execution_info,
[[maybe_unused]] Direction direction, axis_type<fft_rank> axes) {
[[maybe_unused]] Direction direction, axis_type<fft_rank> axes,
shape_type<fft_rank> s) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::_create: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<InViewType>::value,
Expand All @@ -105,7 +106,7 @@ auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
KokkosFFT::Impl::transform_type<ExecutionSpace, in_value_type,
out_value_type>::type();
auto [in_extents, out_extents, fft_extents, howmany] =
KokkosFFT::Impl::get_extents(in, out, axes);
KokkosFFT::Impl::get_extents(in, out, axes, s);
int idist = std::accumulate(in_extents.begin(), in_extents.end(), 1,
std::multiplies<>());
int odist = std::accumulate(out_extents.begin(), out_extents.end(), 1,
Expand Down
5 changes: 3 additions & 2 deletions fft/src/KokkosFFT_SYCL_plans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
const InViewType& in, const OutViewType& out,
[[maybe_unused]] BufferViewType& buffer,
[[maybe_unused]] InfoType& execution_info,
[[maybe_unused]] Direction direction, axis_type<fft_rank> axes) {
[[maybe_unused]] Direction direction, axis_type<fft_rank> axes,
shape_type<fft_rank> s) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::_create: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<InViewType>::value,
Expand All @@ -69,7 +70,7 @@ auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
KokkosFFT::Impl::transform_type<ExecutionSpace, in_value_type,
out_value_type>::type();
auto [in_extents, out_extents, fft_extents, howmany] =
KokkosFFT::Impl::get_extents(in, out, axes);
KokkosFFT::Impl::get_extents(in, out, axes, s);
int idist = std::accumulate(in_extents.begin(), in_extents.end(), 1,
std::multiplies<>());
int odist = std::accumulate(out_extents.begin(), out_extents.end(), 1,
Expand Down
Loading

0 comments on commit 8d822ea

Please sign in to comment.