Skip to content

Commit

Permalink
For padding capability, use the output type to compute the modified e…
Browse files Browse the repository at this point in the history
…xtents (#100)

* For padding capability, use the output type to compute the modified extents

* removed unused lines

* remove unused is_C2R

* formatting

* cleanup: get_modified_shape function and tests

* formatting

* fix: docstring for get_modified_shape

* fix: recover missing test and remove static_cast

* formatting

---------

Co-authored-by: Yuuichi Asahi <[email protected]>
  • Loading branch information
yasahi-hpc and Yuuichi Asahi authored Jun 26, 2024
1 parent 6b2ccb2 commit cf35a43
Show file tree
Hide file tree
Showing 4 changed files with 385 additions and 255 deletions.
6 changes: 1 addition & 5 deletions common/src/KokkosFFT_layouts.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,8 @@ auto get_extents(const InViewType& in, const OutViewType& out,
auto [map, map_inv] = KokkosFFT::Impl::get_map_axes(in, axes);

// Get new shape based on shape parameter
// [TO DO] get_modified shape should take out as well and check is_C2R
// internally
bool is_C2R = is_complex<in_value_type>::value &&
std::is_floating_point<out_value_type>::value;
auto modified_in_shape =
KokkosFFT::Impl::get_modified_shape(in, shape, axes, is_C2R);
KokkosFFT::Impl::get_modified_shape(in, out, shape, axes);

// Get extents for the inner most axes in LayoutRight
// If we allow the FFT on the layoutLeft, this part should be modified
Expand Down
40 changes: 31 additions & 9 deletions common/src/KokkosFFT_padding.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,26 +11,42 @@

namespace KokkosFFT {
namespace Impl {
template <typename ViewType, std::size_t DIM>
auto get_modified_shape(const ViewType& view, shape_type<DIM> shape,
axis_type<DIM> axes, bool is_C2R = false) {
static_assert(ViewType::rank() >= DIM,
"get_modified_shape: Rank of View must be larger "

/// \brief Return a new shape of the input view based on the
/// specified input shape and axes.
///
/// \tparam InViewType The input view type
/// \tparam OutViewType The output view type
/// \tparam DIM The dimensionality of the shape and axes
///
/// \param in [in] Input view from which to derive the new shape
/// \param out [in] Output view (unused but necessary for type deduction)
/// \param shape [in] The new shape of the input view. If the shape is zero,
/// no modifications are made.
/// \param axes [in] Axes over which the shape modification is applied.
template <typename InViewType, typename OutViewType, std::size_t DIM>
auto get_modified_shape(const InViewType in, const OutViewType /* out */,
shape_type<DIM> shape, axis_type<DIM> axes) {
static_assert(InViewType::rank() >= DIM,
"get_modified_shape: Rank of Input View must be larger "
"than or equal to the Rank of new shape");
static_assert(OutViewType::rank() >= DIM,
"get_modified_shape: Rank of Output View must be larger "
"than or equal to the Rank of new shape");
static_assert(DIM > 0,
"get_modified_shape: Rank of FFT axes must be "
"larger than or equal to 1");
constexpr int rank = static_cast<int>(ViewType::rank());
constexpr int rank = static_cast<int>(InViewType::rank());

shape_type<DIM> zeros = {0}; // default shape means no crop or pad
if (shape == zeros) {
return KokkosFFT::Impl::extract_extents(view);
return KokkosFFT::Impl::extract_extents(in);
}

// Convert the input axes to be in the range of [0, rank-1]
std::vector<int> positive_axes;
for (std::size_t i = 0; i < DIM; i++) {
int axis = KokkosFFT::Impl::convert_negative_axis(view, axes.at(i));
int axis = KokkosFFT::Impl::convert_negative_axis(in, axes.at(i));
positive_axes.push_back(axis);
}

Expand All @@ -41,7 +57,7 @@ auto get_modified_shape(const ViewType& view, shape_type<DIM> shape,
using full_shape_type = shape_type<rank>;
full_shape_type modified_shape;
for (int i = 0; i < rank; i++) {
modified_shape.at(i) = view.extent(i);
modified_shape.at(i) = in.extent(i);
}

// Update shapes based on newly given shape
Expand All @@ -51,6 +67,12 @@ auto get_modified_shape(const ViewType& view, shape_type<DIM> shape,
modified_shape.at(positive_axis) = shape.at(i);
}

using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;

bool is_C2R = is_complex<in_value_type>::value &&
std::is_floating_point_v<out_value_type>;

if (is_C2R) {
int reshaped_axis = positive_axes.back();
modified_shape.at(reshaped_axis) = modified_shape.at(reshaped_axis) / 2 + 1;
Expand Down
Loading

0 comments on commit cf35a43

Please sign in to comment.