diff --git a/common/src/KokkosFFT_layouts.hpp b/common/src/KokkosFFT_layouts.hpp index 4e474431..aea31d8c 100644 --- a/common/src/KokkosFFT_layouts.hpp +++ b/common/src/KokkosFFT_layouts.hpp @@ -46,7 +46,7 @@ auto get_extents(InViewType& in, OutViewType& out, axis_type _axes) { // Then R2C if (is_complex::value) { assert(out_extents.at(inner_most_axis) == - in.extent(inner_most_axis) / 2 + 1); + in_extents.at(inner_most_axis) / 2 + 1); } else { throw std::runtime_error( "If the input type is real, the output type should be complex"); @@ -57,7 +57,7 @@ auto get_extents(InViewType& in, OutViewType& out, axis_type _axes) { // Then C2R if (is_complex::value) { assert(in_extents.at(inner_most_axis) == - out.extent(inner_most_axis) / 2 + 1); + out_extents.at(inner_most_axis) / 2 + 1); } else { throw std::runtime_error( "If the output type is real, the input type should be complex"); @@ -95,7 +95,10 @@ auto get_extents_batched(InViewType& in, OutViewType& out, "or equal to 1."); constexpr std::size_t rank = InViewType::rank; - int inner_most_axis = rank - 1; + int inner_most_axis = + std::is_same_v + ? 0 + : (rank - 1); std::vector _in_extents, _out_extents, _fft_extents; @@ -114,29 +117,11 @@ auto get_extents_batched(InViewType& in, OutViewType& out, _fft_extents.push_back(fft_extent); } - if (std::is_same::value) { - std::reverse(_in_extents.begin(), _in_extents.end()); - std::reverse(_out_extents.begin(), _out_extents.end()); - std::reverse(_fft_extents.begin(), _fft_extents.end()); - } - - // Define subvectors starting from last - DIM - // Dimensions relevant to FFTs - std::vector in_extents(_in_extents.end() - DIM, _in_extents.end()); - std::vector out_extents(_out_extents.end() - DIM, _out_extents.end()); - std::vector fft_extents(_fft_extents.end() - DIM, _fft_extents.end()); - - int total_fft_size = std::accumulate(_fft_extents.begin(), _fft_extents.end(), - 1, std::multiplies<>()); - int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1, - std::multiplies<>()); - int howmany = total_fft_size / fft_size; - if (std::is_floating_point::value) { // Then R2C if (is_complex::value) { - assert(out_extents.at(inner_most_axis) == - in.extent(inner_most_axis) / 2 + 1); + assert(_out_extents.at(inner_most_axis) == + _in_extents.at(inner_most_axis) / 2 + 1); } else { throw std::runtime_error( "If the input type is real, the output type should be complex"); @@ -146,14 +131,32 @@ auto get_extents_batched(InViewType& in, OutViewType& out, if (std::is_floating_point::value) { // Then C2R if (is_complex::value) { - assert(in_extents.at(inner_most_axis) == - out.extent(inner_most_axis) / 2 + 1); + assert(_in_extents.at(inner_most_axis) == + _out_extents.at(inner_most_axis) / 2 + 1); } else { throw std::runtime_error( "If the output type is real, the input type should be complex"); } } + if (std::is_same::value) { + std::reverse(_in_extents.begin(), _in_extents.end()); + std::reverse(_out_extents.begin(), _out_extents.end()); + std::reverse(_fft_extents.begin(), _fft_extents.end()); + } + + // Define subvectors starting from last - DIM + // Dimensions relevant to FFTs + std::vector in_extents(_in_extents.end() - DIM, _in_extents.end()); + std::vector out_extents(_out_extents.end() - DIM, _out_extents.end()); + std::vector fft_extents(_fft_extents.end() - DIM, _fft_extents.end()); + + int total_fft_size = std::accumulate(_fft_extents.begin(), _fft_extents.end(), + 1, std::multiplies<>()); + int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1, + std::multiplies<>()); + int howmany = total_fft_size / fft_size; + return std::tuple, std::vector, std::vector, int>( {in_extents, out_extents, fft_extents, howmany}); } diff --git a/common/src/KokkosFFT_utils.hpp b/common/src/KokkosFFT_utils.hpp index 40b43795..e80b0d3c 100644 --- a/common/src/KokkosFFT_utils.hpp +++ b/common/src/KokkosFFT_utils.hpp @@ -58,7 +58,7 @@ auto convert_negative_axis(const ViewType& view, int _axis = -1) { static_assert(Kokkos::is_view::value, "convert_negative_axis: ViewType is not a Kokkos::View."); int rank = static_cast(ViewType::rank()); - assert(abs(_axis) < rank); // axis should be in [-(rank-1), rank-1] + assert(_axis >= -rank && _axis < rank); // axis should be in [-rank, rank-1] int axis = _axis < 0 ? rank + _axis : _axis; return axis; } @@ -105,10 +105,9 @@ bool has_duplicate_values(const std::vector& values) { return set_values.size() < values.size(); } -template < - typename IntType, std::size_t DIM = 1, - std::enable_if_t, std::nullptr_t> = nullptr> -bool is_out_of_range_value_included(const std::array& values, +template , + std::nullptr_t> = nullptr> +bool is_out_of_range_value_included(const std::vector& values, IntType max) { bool is_included = false; for (auto value : values) { diff --git a/common/unit_test/Test_Utils.cpp b/common/unit_test/Test_Utils.cpp index bef3d2b9..f44dff99 100644 --- a/common/unit_test/Test_Utils.cpp +++ b/common/unit_test/Test_Utils.cpp @@ -246,7 +246,7 @@ TEST(GetIndex, Vectors) { } TEST(IsOutOfRangeValueIncluded, Array) { - std::array v = {0, 1, 2, 3}; + std::vector v = {0, 1, 2, 3}; EXPECT_TRUE(KokkosFFT::Impl::is_out_of_range_value_included(v, 2)); EXPECT_TRUE(KokkosFFT::Impl::is_out_of_range_value_included(v, 3)); diff --git a/fft/CMakeLists.txt b/fft/CMakeLists.txt index faf5524b..ffddf45b 100644 --- a/fft/CMakeLists.txt +++ b/fft/CMakeLists.txt @@ -1,4 +1,4 @@ add_subdirectory(src) if(BUILD_TESTING) add_subdirectory(unit_test) -endif() +endif() \ No newline at end of file