Skip to content

Commit

Permalink
Merge pull request #35 from CExA-project/hotfix-utils
Browse files Browse the repository at this point in the history
[Bugfix] assertions in debug mode fixed
  • Loading branch information
yasahi-hpc authored Jan 24, 2024
2 parents f1cff77 + af48de4 commit 0ea0f17
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 32 deletions.
53 changes: 28 additions & 25 deletions common/src/KokkosFFT_layouts.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ auto get_extents(InViewType& in, OutViewType& out, axis_type<DIM> _axes) {
// Then R2C
if (is_complex<out_value_type>::value) {
assert(out_extents.at(inner_most_axis) ==
in.extent(inner_most_axis) / 2 + 1);
in_extents.at(inner_most_axis) / 2 + 1);
} else {
throw std::runtime_error(
"If the input type is real, the output type should be complex");
Expand All @@ -57,7 +57,7 @@ auto get_extents(InViewType& in, OutViewType& out, axis_type<DIM> _axes) {
// Then C2R
if (is_complex<in_value_type>::value) {
assert(in_extents.at(inner_most_axis) ==
out.extent(inner_most_axis) / 2 + 1);
out_extents.at(inner_most_axis) / 2 + 1);
} else {
throw std::runtime_error(
"If the output type is real, the input type should be complex");
Expand Down Expand Up @@ -95,7 +95,10 @@ auto get_extents_batched(InViewType& in, OutViewType& out,
"or equal to 1.");

constexpr std::size_t rank = InViewType::rank;
int inner_most_axis = rank - 1;
int inner_most_axis =
std::is_same_v<array_layout_type, typename Kokkos::LayoutLeft>
? 0
: (rank - 1);

std::vector<int> _in_extents, _out_extents, _fft_extents;

Expand All @@ -114,29 +117,11 @@ auto get_extents_batched(InViewType& in, OutViewType& out,
_fft_extents.push_back(fft_extent);
}

if (std::is_same<array_layout_type, Kokkos::LayoutLeft>::value) {
std::reverse(_in_extents.begin(), _in_extents.end());
std::reverse(_out_extents.begin(), _out_extents.end());
std::reverse(_fft_extents.begin(), _fft_extents.end());
}

// Define subvectors starting from last - DIM
// Dimensions relevant to FFTs
std::vector<int> in_extents(_in_extents.end() - DIM, _in_extents.end());
std::vector<int> out_extents(_out_extents.end() - DIM, _out_extents.end());
std::vector<int> fft_extents(_fft_extents.end() - DIM, _fft_extents.end());

int total_fft_size = std::accumulate(_fft_extents.begin(), _fft_extents.end(),
1, std::multiplies<>());
int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1,
std::multiplies<>());
int howmany = total_fft_size / fft_size;

if (std::is_floating_point<in_value_type>::value) {
// Then R2C
if (is_complex<out_value_type>::value) {
assert(out_extents.at(inner_most_axis) ==
in.extent(inner_most_axis) / 2 + 1);
assert(_out_extents.at(inner_most_axis) ==
_in_extents.at(inner_most_axis) / 2 + 1);
} else {
throw std::runtime_error(
"If the input type is real, the output type should be complex");
Expand All @@ -146,14 +131,32 @@ auto get_extents_batched(InViewType& in, OutViewType& out,
if (std::is_floating_point<out_value_type>::value) {
// Then C2R
if (is_complex<in_value_type>::value) {
assert(in_extents.at(inner_most_axis) ==
out.extent(inner_most_axis) / 2 + 1);
assert(_in_extents.at(inner_most_axis) ==
_out_extents.at(inner_most_axis) / 2 + 1);
} else {
throw std::runtime_error(
"If the output type is real, the input type should be complex");
}
}

if (std::is_same<array_layout_type, Kokkos::LayoutLeft>::value) {
std::reverse(_in_extents.begin(), _in_extents.end());
std::reverse(_out_extents.begin(), _out_extents.end());
std::reverse(_fft_extents.begin(), _fft_extents.end());
}

// Define subvectors starting from last - DIM
// Dimensions relevant to FFTs
std::vector<int> in_extents(_in_extents.end() - DIM, _in_extents.end());
std::vector<int> out_extents(_out_extents.end() - DIM, _out_extents.end());
std::vector<int> fft_extents(_fft_extents.end() - DIM, _fft_extents.end());

int total_fft_size = std::accumulate(_fft_extents.begin(), _fft_extents.end(),
1, std::multiplies<>());
int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1,
std::multiplies<>());
int howmany = total_fft_size / fft_size;

return std::tuple<std::vector<int>, std::vector<int>, std::vector<int>, int>(
{in_extents, out_extents, fft_extents, howmany});
}
Expand Down
9 changes: 4 additions & 5 deletions common/src/KokkosFFT_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ auto convert_negative_axis(const ViewType& view, int _axis = -1) {
static_assert(Kokkos::is_view<ViewType>::value,
"convert_negative_axis: ViewType is not a Kokkos::View.");
int rank = static_cast<int>(ViewType::rank());
assert(abs(_axis) < rank); // axis should be in [-(rank-1), rank-1]
assert(_axis >= -rank && _axis < rank); // axis should be in [-rank, rank-1]
int axis = _axis < 0 ? rank + _axis : _axis;
return axis;
}
Expand Down Expand Up @@ -105,10 +105,9 @@ bool has_duplicate_values(const std::vector<T>& values) {
return set_values.size() < values.size();
}

template <
typename IntType, std::size_t DIM = 1,
std::enable_if_t<std::is_integral_v<IntType>, std::nullptr_t> = nullptr>
bool is_out_of_range_value_included(const std::array<IntType, DIM>& values,
template <typename IntType, std::enable_if_t<std::is_integral_v<IntType>,
std::nullptr_t> = nullptr>
bool is_out_of_range_value_included(const std::vector<IntType>& values,
IntType max) {
bool is_included = false;
for (auto value : values) {
Expand Down
2 changes: 1 addition & 1 deletion common/unit_test/Test_Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ TEST(GetIndex, Vectors) {
}

TEST(IsOutOfRangeValueIncluded, Array) {
std::array<int, 4> v = {0, 1, 2, 3};
std::vector<int> v = {0, 1, 2, 3};

EXPECT_TRUE(KokkosFFT::Impl::is_out_of_range_value_included(v, 2));
EXPECT_TRUE(KokkosFFT::Impl::is_out_of_range_value_included(v, 3));
Expand Down
2 changes: 1 addition & 1 deletion fft/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
add_subdirectory(src)
if(BUILD_TESTING)
add_subdirectory(unit_test)
endif()
endif()

0 comments on commit 0ea0f17

Please sign in to comment.