Skip to content

Commit

Permalink
MDSpan issues expose by Kokkos View refactor (#358)
Browse files Browse the repository at this point in the history
* Add test mapping(other_mapping) where other_mapping has none-convertible extents

Specifically the 1D layout_left_padded <-> layout_right_padded ctors
didn't compile for cases where the new mapping has a static extent,
but the source mapping has dynamic extent.

* Fix some layout_padded conversions and compilation with CUDA

* Use proper index_pair_like constraint for submdspan

* Don't use get for std::complex slice specifier

and test that complex<double> works like pair as slice specifier *barf*

* Remove out of date comment
  • Loading branch information
crtrott authored Sep 6, 2024
1 parent c2494ad commit 92a1297
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 8 deletions.
42 changes: 40 additions & 2 deletions include/experimental/__p2630_bits/submdspan_extents.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#pragma once

#include <tuple>
#include <complex>

#include "strided_slice.hpp"
namespace MDSPAN_IMPL_STANDARD_NAMESPACE {
Expand Down Expand Up @@ -52,6 +53,31 @@ template <class OffsetType, class ExtentType, class StrideType>
struct is_strided_slice<
strided_slice<OffsetType, ExtentType, StrideType>> : std::true_type {};

// Helper for identifying valid pair like things
template <class T, class IndexType> struct index_pair_like : std::false_type {};

template <class IdxT1, class IdxT2, class IndexType>
struct index_pair_like<std::pair<IdxT1, IdxT2>, IndexType> {
static constexpr bool value = std::is_convertible_v<IdxT1, IndexType> &&
std::is_convertible_v<IdxT2, IndexType>;
};

template <class IdxT1, class IdxT2, class IndexType>
struct index_pair_like<std::tuple<IdxT1, IdxT2>, IndexType> {
static constexpr bool value = std::is_convertible_v<IdxT1, IndexType> &&
std::is_convertible_v<IdxT2, IndexType>;
};

template <class IdxT, class IndexType>
struct index_pair_like<std::complex<IdxT>, IndexType> {
static constexpr bool value = std::is_convertible_v<IdxT, IndexType>;
};

template <class IdxT, class IndexType>
struct index_pair_like<std::array<IdxT, 2>, IndexType> {
static constexpr bool value = std::is_convertible_v<IdxT, IndexType>;
};

// first_of(slice): getting begin of slice specifier range
MDSPAN_TEMPLATE_REQUIRES(
class Integral,
Expand All @@ -70,13 +96,19 @@ first_of(const ::MDSPAN_IMPL_STANDARD_NAMESPACE::full_extent_t &) {

MDSPAN_TEMPLATE_REQUIRES(
class Slice,
/* requires */(std::is_convertible_v<Slice, std::tuple<size_t, size_t>>)
/* requires */(index_pair_like<Slice, size_t>::value)
)
MDSPAN_INLINE_FUNCTION
constexpr auto first_of(const Slice &i) {
return std::get<0>(i);
}

template<class T>
MDSPAN_INLINE_FUNCTION
constexpr auto first_of(const std::complex<T> &i) {
return i.real();
}

template <class OffsetType, class ExtentType, class StrideType>
MDSPAN_INLINE_FUNCTION
constexpr OffsetType
Expand All @@ -100,14 +132,20 @@ constexpr Integral

MDSPAN_TEMPLATE_REQUIRES(
size_t k, class Extents, class Slice,
/* requires */(std::is_convertible_v<Slice, std::tuple<size_t, size_t>>)
/* requires */(index_pair_like<Slice, size_t>::value)
)
MDSPAN_INLINE_FUNCTION
constexpr auto last_of(std::integral_constant<size_t, k>, const Extents &,
const Slice &i) {
return std::get<1>(i);
}

template<size_t k, class Extents, class T>
MDSPAN_INLINE_FUNCTION
constexpr auto last_of(std::integral_constant<size_t, k>, const Extents &, const std::complex<T> &i) {
return i.imag();
}

// Suppress spurious warning with NVCC about no return statement.
// This is a known issue in NVCC and NVC++
// Depending on the CUDA and GCC version we need both the builtin
Expand Down
3 changes: 1 addition & 2 deletions include/experimental/__p2630_bits/submdspan_mapping.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,7 @@ template<class SliceSpecifier, class IndexType>
struct is_range_slice {
constexpr static bool value =
std::is_same_v<SliceSpecifier, full_extent_t> ||
std::is_convertible_v<SliceSpecifier,
std::tuple<IndexType, IndexType>>;
index_pair_like<SliceSpecifier, IndexType>::value;
};

template<class SliceSpecifier, class IndexType>
Expand Down
9 changes: 5 additions & 4 deletions include/experimental/__p2642_bits/layout_padded.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ struct padded_extent {
using static_array_type = typename static_array_type_for_padded_extent<
padding_value, _Extents, _ExtentToPadIdx, _Extents::rank()>::type;

MDSPAN_INLINE_FUNCTION
static constexpr auto static_value() { return static_array_type::static_value(0); }

MDSPAN_INLINE_FUNCTION
Expand Down Expand Up @@ -203,7 +204,7 @@ class layout_left_padded<PaddingValue>::mapping {
}

public:
#if !MDSPAN_HAS_CXX_20
#if !MDSPAN_HAS_CXX_20 || defined(__NVCC__)
MDSPAN_INLINE_FUNCTION_DEFAULTED
constexpr mapping()
: mapping(extents_type{})
Expand Down Expand Up @@ -347,7 +348,7 @@ class layout_left_padded<PaddingValue>::mapping {
MDSPAN_INLINE_FUNCTION
constexpr mapping(const _Mapping &other_mapping) noexcept
: padded_stride(padded_stride_type::init_padding(
other_mapping.extents(),
static_cast<extents_type>(other_mapping.extents()),
other_mapping.extents().extent(extent_to_pad_idx))),
exts(other_mapping.extents()) {}

Expand Down Expand Up @@ -566,7 +567,7 @@ class layout_right_padded<PaddingValue>::mapping {
}

public:
#if !MDSPAN_HAS_CXX_20
#if !MDSPAN_HAS_CXX_20 || defined(__NVCC__)
MDSPAN_INLINE_FUNCTION_DEFAULTED
constexpr mapping()
: mapping(extents_type{})
Expand Down Expand Up @@ -707,7 +708,7 @@ class layout_right_padded<PaddingValue>::mapping {
MDSPAN_INLINE_FUNCTION
constexpr mapping(const _Mapping &other_mapping) noexcept
: padded_stride(padded_stride_type::init_padding(
other_mapping.extents(),
static_cast<extents_type>(other_mapping.extents()),
other_mapping.extents().extent(extent_to_pad_idx))),
exts(other_mapping.extents()) {}

Expand Down
4 changes: 4 additions & 0 deletions tests/test_layout_padded_left.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,8 @@ TEST(LayoutLeftTests, construction)
ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>(KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>()).padded_stride.value(0)), 0);
ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 4, 7>>(KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 4, 7>>()).extents()), (Kokkos::extents<std::size_t, 4, 7>()));
ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 4, 7>>(KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 4, 7>>()).padded_stride.value(0)), 4);
ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 4, 7>>(KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, Kokkos::dynamic_extent, 7>>(Kokkos::extents<std::size_t, Kokkos::dynamic_extent, 7>(4))).extents()), (Kokkos::extents<std::size_t, 4, 7>()));
ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 4, 7>>(KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, Kokkos::dynamic_extent, 7>>(Kokkos::extents<std::size_t, Kokkos::dynamic_extent, 7>(4))).padded_stride.value(0)), 4);
ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, Kokkos::dynamic_extent, 7>>(KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 4, 7>>()).extents()), (Kokkos::extents<std::size_t, 4, 7>()));
ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, Kokkos::dynamic_extent, 7>>(KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 4, 7>>()).padded_stride.value(0)), 4);

Expand All @@ -311,6 +313,8 @@ TEST(LayoutLeftTests, construction)
ASSERT_EQ(KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t>>(KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t>>()).extents(), Kokkos::extents<std::size_t>());
ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>(KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>()).extents()), (Kokkos::extents<std::size_t, 3>()));
ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>(KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>()).padded_stride.value(0)), 0);
ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>(KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, Kokkos::dynamic_extent>>(Kokkos::dextents<size_t,1>(3))).extents()), (Kokkos::extents<std::size_t, 3>()));
ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>(KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, Kokkos::dynamic_extent>>(Kokkos::dextents<size_t,1>(3))).padded_stride.value(0)), 0);
ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>(KokkosEx::layout_right_padded<Kokkos::dynamic_extent>::mapping<Kokkos::extents<std::size_t, 3>>({}, 4)).extents()), (Kokkos::extents<std::size_t, 3>()));
ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>(KokkosEx::layout_right_padded<Kokkos::dynamic_extent>::mapping<Kokkos::extents<std::size_t, 3>>({}, 4)).padded_stride.value(0)), 0);

Expand Down
4 changes: 4 additions & 0 deletions tests/test_layout_padded_right.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -304,13 +304,17 @@ TEST(LayoutrightTests, construction)
ASSERT_EQ((KokkosEx::layout_right_padded<Kokkos::dynamic_extent>::mapping<Kokkos::extents<std::size_t, 3>>(KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>()).padded_stride.value(0)), 0);
ASSERT_EQ((KokkosEx::layout_right_padded<Kokkos::dynamic_extent>::mapping<Kokkos::extents<std::size_t, 7, 5>>(KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, 7, 5>>()).extents()), (Kokkos::extents<std::size_t, 7, 5>()));
ASSERT_EQ((KokkosEx::layout_right_padded<Kokkos::dynamic_extent>::mapping<Kokkos::extents<std::size_t, 7, 5>>(KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, 7, 5>>()).padded_stride.value(0)), 8);
ASSERT_EQ((KokkosEx::layout_right_padded<Kokkos::dynamic_extent>::mapping<Kokkos::extents<std::size_t, 7, 5>>(KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, Kokkos::dynamic_extent, 5>>(Kokkos::extents<size_t, Kokkos::dynamic_extent, 5>(7))).extents()), (Kokkos::extents<std::size_t, 7, 5>()));
ASSERT_EQ((KokkosEx::layout_right_padded<Kokkos::dynamic_extent>::mapping<Kokkos::extents<std::size_t, 7, 5>>(KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, Kokkos::dynamic_extent, 5>>(Kokkos::extents<size_t, Kokkos::dynamic_extent, 5>(7))).padded_stride.value(0)), 8);
ASSERT_EQ((KokkosEx::layout_right_padded<Kokkos::dynamic_extent>::mapping<Kokkos::extents<std::size_t, 7, Kokkos::dynamic_extent>>(KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, 7, 5>>()).extents()), (Kokkos::extents<std::size_t, 7, 5>()));
ASSERT_EQ((KokkosEx::layout_right_padded<Kokkos::dynamic_extent>::mapping<Kokkos::extents<std::size_t, 7, Kokkos::dynamic_extent>>(KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, 7, 5>>()).padded_stride.value(0)), 8);

// Construct layout_right_padded mapping from layout_left_padded mapping
ASSERT_EQ(KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t>>(KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t>>()).extents(), Kokkos::extents<std::size_t>());
ASSERT_EQ((KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>(KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>()).extents()), (Kokkos::extents<std::size_t, 3>()));
ASSERT_EQ((KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>(KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>()).padded_stride.value(0)), 0);
ASSERT_EQ((KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>(KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, Kokkos::dynamic_extent>>(Kokkos::dextents<size_t, 1>(3))).extents()), (Kokkos::extents<std::size_t, 3>()));
ASSERT_EQ((KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>(KokkosEx::layout_left_padded<4>::mapping<Kokkos::extents<std::size_t, Kokkos::dynamic_extent>>(Kokkos::dextents<size_t, 1>(3))).padded_stride.value(0)), 0);
ASSERT_EQ((KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>(KokkosEx::layout_left_padded<Kokkos::dynamic_extent>::mapping<Kokkos::extents<std::size_t, 3>>({}, 4)).extents()), (Kokkos::extents<std::size_t, 3>()));
ASSERT_EQ((KokkosEx::layout_right_padded<4>::mapping<Kokkos::extents<std::size_t, 3>>(KokkosEx::layout_left_padded<Kokkos::dynamic_extent>::mapping<Kokkos::extents<std::size_t, 3>>({}, 4)).padded_stride.value(0)), 0);

Expand Down
11 changes: 11 additions & 0 deletions tests/test_submdspan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ using submdspan_test_types =
// layout_right to layout_right Check Extents Preservation
, std::tuple<Kokkos::layout_right, Kokkos::layout_right, Kokkos::extents<size_t,10>, args_t<10>, Kokkos::extents<size_t,10>, Kokkos::full_extent_t>
, std::tuple<Kokkos::layout_right, Kokkos::layout_right, Kokkos::extents<size_t,10>, args_t<10>, Kokkos::extents<size_t,dyn>, std::pair<int,int>>
, std::tuple<Kokkos::layout_right, Kokkos::layout_right, Kokkos::extents<size_t,10>, args_t<10>, Kokkos::extents<size_t,dyn>, std::complex<double>>
, std::tuple<Kokkos::layout_right, Kokkos::layout_right, Kokkos::extents<size_t,10>, args_t<10>, Kokkos::extents<size_t>, int>
, std::tuple<Kokkos::layout_right, Kokkos::layout_right, Kokkos::extents<size_t,10,20>, args_t<10,20>, Kokkos::extents<size_t,10,20>, Kokkos::full_extent_t, Kokkos::full_extent_t>
, std::tuple<Kokkos::layout_right, Kokkos::layout_right, Kokkos::extents<size_t,10,20>, args_t<10,20>, Kokkos::extents<size_t,dyn,20>, std::pair<int,int>, Kokkos::full_extent_t>
Expand Down Expand Up @@ -274,6 +275,10 @@ struct TestSubMDSpan<
return std::pair<int,int>(1,3);
}
MDSPAN_INLINE_FUNCTION
static auto create_slice_arg(std::complex<double>) {
return std::complex<double>{1.,3.};
}
MDSPAN_INLINE_FUNCTION
static auto create_slice_arg(Kokkos::strided_slice<int,int,int>) {
return Kokkos::strided_slice<int,int,int>{1,3,2};
}
Expand All @@ -300,6 +305,12 @@ struct TestSubMDSpan<
}
template<class SrcMDSpan, class SubMDSpan, size_t ... SrcIdx, size_t ... SubIdx, class ... SliceArgs>
MDSPAN_INLINE_FUNCTION
static bool check_submdspan_match(int src_idx, int sub_idx, SrcMDSpan src_mds, SubMDSpan sub_mds, std::index_sequence<SrcIdx...>, std::index_sequence<SubIdx...>, std::complex<double> p, SliceArgs ... slices) {
using idx_t = typename SubMDSpan::index_type;
return (sub_mds.extent(sub_idx)==static_cast<idx_t>(p.imag()-p.real())) && check_submdspan_match(++src_idx, ++sub_idx, src_mds, sub_mds, std::index_sequence<SrcIdx...,2>(), std::index_sequence<SubIdx...,1>(), slices...);
}
template<class SrcMDSpan, class SubMDSpan, size_t ... SrcIdx, size_t ... SubIdx, class ... SliceArgs>
MDSPAN_INLINE_FUNCTION
static bool check_submdspan_match(int src_idx, int sub_idx, SrcMDSpan src_mds, SubMDSpan sub_mds, std::index_sequence<SrcIdx...>, std::index_sequence<SubIdx...>,
Kokkos::strided_slice<int,int,int> p, SliceArgs ... slices) {
using idx_t = typename SubMDSpan::index_type;
Expand Down

0 comments on commit 92a1297

Please sign in to comment.