Skip to content

Commit

Permalink
Merge pull request #94 from kokkos/refine-fft-impl
Browse files Browse the repository at this point in the history
Use exec space from Plan in _fft function
  • Loading branch information
yasahi-hpc authored Apr 6, 2024
2 parents a322cb5 + 3e31d42 commit 60bfd7d
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 84 deletions.
17 changes: 14 additions & 3 deletions fft/src/KokkosFFT_Plans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ namespace Impl {
template <typename ExecutionSpace, typename InViewType, typename OutViewType,
std::size_t DIM = 1>
class Plan {
public:
//! The type of Kokkos execution space
using execSpace = ExecutionSpace;

Expand Down Expand Up @@ -97,6 +98,10 @@ class Plan {
//! The type of extents of input/output views
using extents_type = shape_type<InViewType::rank()>;

private:
//! Execution space
execSpace m_exec_space;

//! Dynamically allocatable fft plan.
std::unique_ptr<fft_plan_type> m_plan;

Expand Down Expand Up @@ -148,7 +153,8 @@ class Plan {
explicit Plan(const ExecutionSpace& exec_space, InViewType& in,
OutViewType& out, KokkosFFT::Direction direction, int axis,
std::optional<std::size_t> n = std::nullopt)
: m_fft_size(1),
: m_exec_space(exec_space),
m_fft_size(1),
m_is_transpose_needed(false),
m_direction(direction),
m_axes({axis}) {
Expand Down Expand Up @@ -200,7 +206,8 @@ class Plan {
explicit Plan(const ExecutionSpace& exec_space, InViewType& in,
OutViewType& out, KokkosFFT::Direction direction,
axis_type<DIM> axes, shape_type<DIM> s = {0})
: m_fft_size(1),
: m_exec_space(exec_space),
m_fft_size(1),
m_is_transpose_needed(false),
m_direction(direction),
m_axes(axes) {
Expand Down Expand Up @@ -311,14 +318,18 @@ class Plan {
}
}

/// \brief Return the execution space
execSpace const& exec_space() const noexcept { return m_exec_space; }

/// \brief Return the FFT plan
fft_plan_type& plan() const { return *m_plan; }

/// \brief Return the FFT info
const fft_info_type& info() const { return m_info; }
fft_info_type const& info() const { return m_info; }

/// \brief Return the FFT size
fft_size_type fft_size() const { return m_fft_size; }
KokkosFFT::Direction direction() const { return m_direction; }
bool is_transpose_needed() const { return m_is_transpose_needed; }
map_type map() const { return m_map; }
map_type map_inv() const { return m_map_inv; }
Expand Down
117 changes: 36 additions & 81 deletions fft/src/KokkosFFT_Transform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,8 @@
// General Transform Interface
namespace KokkosFFT {
namespace Impl {
template <typename ExecutionSpace, typename PlanType, typename InViewType,
typename OutViewType>
void _fft(const ExecutionSpace& exec_space, PlanType& plan,
const InViewType& in, OutViewType& out,
template <typename PlanType, typename InViewType, typename OutViewType>
void _fft(const PlanType& plan, const InViewType& in, OutViewType& out,
KokkosFFT::Normalization norm = KokkosFFT::Normalization::backward) {
static_assert(Kokkos::is_view<InViewType>::value,
"_fft: InViewType is not a Kokkos::View.");
Expand All @@ -65,6 +63,7 @@ void _fft(const ExecutionSpace& exec_space, PlanType& plan,
"_fft: InViewType and OutViewType must have "
"the same Layout.");

using ExecutionSpace = typename PlanType::execSpace;
static_assert(
Kokkos::SpaceAccessibility<ExecutionSpace,
typename InViewType::memory_space>::accessible,
Expand All @@ -82,57 +81,13 @@ void _fft(const ExecutionSpace& exec_space, PlanType& plan,
auto* odata = reinterpret_cast<typename KokkosFFT::Impl::fft_data_type<
ExecutionSpace, out_value_type>::type*>(out.data());

auto forward = direction_type<ExecutionSpace>(KokkosFFT::Direction::forward);
KokkosFFT::Impl::_exec(plan.plan(), idata, odata, forward, plan.info());
KokkosFFT::Impl::normalize(exec_space, out, KokkosFFT::Direction::forward,
norm, plan.fft_size());
auto const exec_space = plan.exec_space();
auto const fft_direction = direction_type<ExecutionSpace>(plan.direction());
KokkosFFT::Impl::_exec(plan.plan(), idata, odata, fft_direction, plan.info());
KokkosFFT::Impl::normalize(exec_space, out, plan.direction(), norm,
plan.fft_size());
}

template <typename ExecutionSpace, typename PlanType, typename InViewType,
typename OutViewType>
void _ifft(const ExecutionSpace& exec_space, PlanType& plan,
const InViewType& in, OutViewType& out,
KokkosFFT::Normalization norm = KokkosFFT::Normalization::backward) {
static_assert(Kokkos::is_view<InViewType>::value,
"_ifft: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
"_ifft: OutViewType is not a Kokkos::View.");
static_assert(KokkosFFT::Impl::is_layout_left_or_right_v<InViewType>,
"_ifft: InViewType must be either LayoutLeft or LayoutRight.");
static_assert(KokkosFFT::Impl::is_layout_left_or_right_v<OutViewType>,
"_ifft: OutViewType must be either LayoutLeft or LayoutRight.");

static_assert(InViewType::rank() == OutViewType::rank(),
"_ifft: InViewType and OutViewType must have "
"the same rank.");
static_assert(std::is_same_v<typename InViewType::array_layout,
typename OutViewType::array_layout>,
"_ifft: InViewType and OutViewType must have "
"the same Layout.");

static_assert(
Kokkos::SpaceAccessibility<ExecutionSpace,
typename InViewType::memory_space>::accessible,
"_ifft: execution_space cannot access data in InViewType");
static_assert(
Kokkos::SpaceAccessibility<
ExecutionSpace, typename OutViewType::memory_space>::accessible,
"_ifft: execution_space cannot access data in OutViewType");

using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;

auto* idata = reinterpret_cast<typename KokkosFFT::Impl::fft_data_type<
ExecutionSpace, in_value_type>::type*>(in.data());
auto* odata = reinterpret_cast<typename KokkosFFT::Impl::fft_data_type<
ExecutionSpace, out_value_type>::type*>(out.data());

auto backward =
direction_type<ExecutionSpace>(KokkosFFT::Direction::backward);
KokkosFFT::Impl::_exec(plan.plan(), idata, odata, backward, plan.info());
KokkosFFT::Impl::normalize(exec_space, out, KokkosFFT::Direction::backward,
norm, plan.fft_size());
}
} // namespace Impl
} // namespace KokkosFFT

Expand Down Expand Up @@ -198,12 +153,12 @@ void fft(const ExecutionSpace& exec_space, const InViewType& in,
KokkosFFT::Impl::transpose(exec_space, _in, in_T, plan.map());
KokkosFFT::Impl::transpose(exec_space, out, out_T, plan.map());

KokkosFFT::Impl::_fft(exec_space, plan, in_T, out_T, norm);
KokkosFFT::Impl::_fft(plan, in_T, out_T, norm);

KokkosFFT::Impl::transpose(exec_space, out_T, out, plan.map_inv());

} else {
KokkosFFT::Impl::_fft(exec_space, plan, _in, out, norm);
KokkosFFT::Impl::_fft(plan, _in, out, norm);
}
}

Expand Down Expand Up @@ -271,12 +226,12 @@ void fft(const ExecutionSpace& exec_space, const InViewType& in,
KokkosFFT::Impl::transpose(exec_space, _in, in_T, plan.map());
KokkosFFT::Impl::transpose(exec_space, out, out_T, plan.map());

KokkosFFT::Impl::_fft(exec_space, plan, in_T, out_T, norm);
KokkosFFT::Impl::_fft(plan, in_T, out_T, norm);

KokkosFFT::Impl::transpose(exec_space, out_T, out, plan.map_inv());

} else {
KokkosFFT::Impl::_fft(exec_space, plan, _in, out, norm);
KokkosFFT::Impl::_fft(plan, _in, out, norm);
}
}

Expand Down Expand Up @@ -348,12 +303,12 @@ void ifft(const ExecutionSpace& exec_space, const InViewType& in,
KokkosFFT::Impl::transpose(exec_space, _in, in_T, plan.map());
KokkosFFT::Impl::transpose(exec_space, out, out_T, plan.map());

KokkosFFT::Impl::_ifft(exec_space, plan, in_T, out_T, norm);
KokkosFFT::Impl::_fft(plan, in_T, out_T, norm);

KokkosFFT::Impl::transpose(exec_space, out_T, out, plan.map_inv());

} else {
KokkosFFT::Impl::_ifft(exec_space, plan, _in, out, norm);
KokkosFFT::Impl::_fft(plan, _in, out, norm);
}
}

Expand Down Expand Up @@ -427,12 +382,12 @@ void ifft(const ExecutionSpace& exec_space, const InViewType& in,
KokkosFFT::Impl::transpose(exec_space, _in, in_T, plan.map());
KokkosFFT::Impl::transpose(exec_space, out, out_T, plan.map());

KokkosFFT::Impl::_ifft(exec_space, plan, in_T, out_T, norm);
KokkosFFT::Impl::_fft(plan, in_T, out_T, norm);

KokkosFFT::Impl::transpose(exec_space, out_T, out, plan.map_inv());

} else {
KokkosFFT::Impl::_ifft(exec_space, plan, _in, out, norm);
KokkosFFT::Impl::_fft(plan, _in, out, norm);
}
}

Expand Down Expand Up @@ -914,11 +869,11 @@ void fft2(const ExecutionSpace& exec_space, const InViewType& in,
KokkosFFT::Impl::transpose(exec_space, _in, in_T, plan.map());
KokkosFFT::Impl::transpose(exec_space, out, out_T, plan.map());

KokkosFFT::Impl::_fft(exec_space, plan, in_T, out_T, norm);
KokkosFFT::Impl::_fft(plan, in_T, out_T, norm);

KokkosFFT::Impl::transpose(exec_space, out_T, out, plan.map_inv());
} else {
KokkosFFT::Impl::_fft(exec_space, plan, _in, out, norm);
KokkosFFT::Impl::_fft(plan, _in, out, norm);
}
}

Expand Down Expand Up @@ -985,11 +940,11 @@ void fft2(const ExecutionSpace& exec_space, const InViewType& in,
KokkosFFT::Impl::transpose(exec_space, _in, in_T, plan.map());
KokkosFFT::Impl::transpose(exec_space, out, out_T, plan.map());

KokkosFFT::Impl::_fft(exec_space, plan, in_T, out_T, norm);
KokkosFFT::Impl::_fft(plan, in_T, out_T, norm);

KokkosFFT::Impl::transpose(exec_space, out_T, out, plan.map_inv());
} else {
KokkosFFT::Impl::_fft(exec_space, plan, _in, out, norm);
KokkosFFT::Impl::_fft(plan, _in, out, norm);
}
}

Expand Down Expand Up @@ -1057,11 +1012,11 @@ void ifft2(const ExecutionSpace& exec_space, const InViewType& in,
KokkosFFT::Impl::transpose(exec_space, _in, in_T, plan.map());
KokkosFFT::Impl::transpose(exec_space, out, out_T, plan.map());

KokkosFFT::Impl::_ifft(exec_space, plan, in_T, out_T, norm);
KokkosFFT::Impl::_fft(plan, in_T, out_T, norm);

KokkosFFT::Impl::transpose(exec_space, out_T, out, plan.map_inv());
} else {
KokkosFFT::Impl::_ifft(exec_space, plan, _in, out, norm);
KokkosFFT::Impl::_fft(plan, _in, out, norm);
}
}

Expand Down Expand Up @@ -1132,11 +1087,11 @@ void ifft2(const ExecutionSpace& exec_space, const InViewType& in,
KokkosFFT::Impl::transpose(exec_space, _in, in_T, plan.map());
KokkosFFT::Impl::transpose(exec_space, out, out_T, plan.map());

KokkosFFT::Impl::_ifft(exec_space, plan, in_T, out_T, norm);
KokkosFFT::Impl::_fft(plan, in_T, out_T, norm);

KokkosFFT::Impl::transpose(exec_space, out_T, out, plan.map_inv());
} else {
KokkosFFT::Impl::_ifft(exec_space, plan, _in, out, norm);
KokkosFFT::Impl::_fft(plan, _in, out, norm);
}
}

Expand Down Expand Up @@ -1409,11 +1364,11 @@ void fftn(const ExecutionSpace& exec_space, const InViewType& in,
KokkosFFT::Impl::transpose(exec_space, _in, in_T, plan.map());
KokkosFFT::Impl::transpose(exec_space, out, out_T, plan.map());

KokkosFFT::Impl::_fft(exec_space, plan, in_T, out_T, norm);
KokkosFFT::Impl::_fft(plan, in_T, out_T, norm);

KokkosFFT::Impl::transpose(exec_space, out_T, out, plan.map_inv());
} else {
KokkosFFT::Impl::_fft(exec_space, plan, _in, out, norm);
KokkosFFT::Impl::_fft(plan, _in, out, norm);
}
}

Expand Down Expand Up @@ -1478,11 +1433,11 @@ void fftn(const ExecutionSpace& exec_space, const InViewType& in,
KokkosFFT::Impl::transpose(exec_space, _in, in_T, plan.map());
KokkosFFT::Impl::transpose(exec_space, out, out_T, plan.map());

KokkosFFT::Impl::_fft(exec_space, plan, in_T, out_T, norm);
KokkosFFT::Impl::_fft(plan, in_T, out_T, norm);

KokkosFFT::Impl::transpose(exec_space, out_T, out, plan.map_inv());
} else {
KokkosFFT::Impl::_fft(exec_space, plan, _in, out, norm);
KokkosFFT::Impl::_fft(plan, _in, out, norm);
}
}

Expand Down Expand Up @@ -1549,11 +1504,11 @@ void fftn(const ExecutionSpace& exec_space, const InViewType& in,
KokkosFFT::Impl::transpose(exec_space, _in, in_T, plan.map());
KokkosFFT::Impl::transpose(exec_space, out, out_T, plan.map());

KokkosFFT::Impl::_fft(exec_space, plan, in_T, out_T, norm);
KokkosFFT::Impl::_fft(plan, in_T, out_T, norm);

KokkosFFT::Impl::transpose(exec_space, out_T, out, plan.map_inv());
} else {
KokkosFFT::Impl::_fft(exec_space, plan, _in, out, norm);
KokkosFFT::Impl::_fft(plan, _in, out, norm);
}
}

Expand Down Expand Up @@ -1622,11 +1577,11 @@ void ifftn(const ExecutionSpace& exec_space, const InViewType& in,
KokkosFFT::Impl::transpose(exec_space, _in, in_T, plan.map());
KokkosFFT::Impl::transpose(exec_space, out, out_T, plan.map());

KokkosFFT::Impl::_ifft(exec_space, plan, in_T, out_T, norm);
KokkosFFT::Impl::_fft(plan, in_T, out_T, norm);

KokkosFFT::Impl::transpose(exec_space, out_T, out, plan.map_inv());
} else {
KokkosFFT::Impl::_ifft(exec_space, plan, _in, out, norm);
KokkosFFT::Impl::_fft(plan, _in, out, norm);
}
}

Expand Down Expand Up @@ -1695,11 +1650,11 @@ void ifftn(const ExecutionSpace& exec_space, const InViewType& in,
KokkosFFT::Impl::transpose(exec_space, _in, in_T, plan.map());
KokkosFFT::Impl::transpose(exec_space, out, out_T, plan.map());

KokkosFFT::Impl::_ifft(exec_space, plan, in_T, out_T, norm);
KokkosFFT::Impl::_fft(plan, in_T, out_T, norm);

KokkosFFT::Impl::transpose(exec_space, out_T, out, plan.map_inv());
} else {
KokkosFFT::Impl::_ifft(exec_space, plan, _in, out, norm);
KokkosFFT::Impl::_fft(plan, _in, out, norm);
}
}

Expand Down Expand Up @@ -1770,11 +1725,11 @@ void ifftn(const ExecutionSpace& exec_space, const InViewType& in,
KokkosFFT::Impl::transpose(exec_space, _in, in_T, plan.map());
KokkosFFT::Impl::transpose(exec_space, out, out_T, plan.map());

KokkosFFT::Impl::_ifft(exec_space, plan, in_T, out_T, norm);
KokkosFFT::Impl::_fft(plan, in_T, out_T, norm);

KokkosFFT::Impl::transpose(exec_space, out_T, out, plan.map_inv());
} else {
KokkosFFT::Impl::_ifft(exec_space, plan, _in, out, norm);
KokkosFFT::Impl::_fft(plan, _in, out, norm);
}
}

Expand Down

0 comments on commit 60bfd7d

Please sign in to comment.