diff --git a/include/experimental/__p0009_bits/mdspan.hpp b/include/experimental/__p0009_bits/mdspan.hpp index e976ab87..44c884d3 100644 --- a/include/experimental/__p0009_bits/mdspan.hpp +++ b/include/experimental/__p0009_bits/mdspan.hpp @@ -22,6 +22,9 @@ #include "trait_backports.hpp" #include "compressed_pair.hpp" +#include +#include + namespace MDSPAN_IMPL_STANDARD_NAMESPACE { template < class ElementType, @@ -218,6 +221,68 @@ class mdspan //-------------------------------------------------------------------------------- // [mdspan.basic.mapping], mdspan mapping domain multidimensional index to access codomain element + MDSPAN_TEMPLATE_REQUIRES( + class... SizeTypes, + /* requires */ ( + _MDSPAN_FOLD_AND(_MDSPAN_TRAIT(std::is_convertible, SizeTypes, index_type) /* && ... */) && + _MDSPAN_FOLD_AND(_MDSPAN_TRAIT(std::is_nothrow_constructible, index_type, SizeTypes) /* && ... */) && + (rank() == sizeof...(SizeTypes)) + ) + ) + constexpr reference at(SizeTypes... indices) const + { + size_t r = 0; + for (const auto& index : {indices...}) { + if (index >= __mapping_ref().extents().extent(r)) { + throw std::out_of_range( + "mdspan::at(...," + std::to_string(index) + ",...) out-of-range at rank index " + std::to_string(r) + + " for mdspan with extent {...," + std::to_string(__mapping_ref().extents().extent(r)) + ",...}"); + } + ++r; + } + return __accessor_ref().access(__ptr_ref(), __mapping_ref()(static_cast(std::move(indices))...)); + } + + MDSPAN_TEMPLATE_REQUIRES( + class SizeType, + /* requires */ ( + _MDSPAN_TRAIT(std::is_convertible, const SizeType&, index_type) && + _MDSPAN_TRAIT(std::is_nothrow_constructible, index_type, const SizeType&) + ) + ) + constexpr reference at(const std::array< SizeType, rank()>& indices) const + { + for (size_t r = 0; r < indices.size(); ++r) { + if (indices[r] >= __mapping_ref().extents().extent(r)) { + throw std::out_of_range( + "mdspan::at({...," + std::to_string(indices[r]) + ",...}) out-of-range at rank index " + std::to_string(r) + + " for mdspan with extent {...," + std::to_string(__mapping_ref().extents().extent(r)) + ",...}"); + } + } + return __impl::template __callop(*this, indices); + } + + #ifdef __cpp_lib_span + MDSPAN_TEMPLATE_REQUIRES( + class SizeType, + /* requires */ ( + _MDSPAN_TRAIT(std::is_convertible, const SizeType&, index_type) && + _MDSPAN_TRAIT(std::is_nothrow_constructible, index_type, const SizeType&) + ) + ) + constexpr reference at(std::span indices) const + { + for (size_t r = 0; r < indices.size(); ++r) { + if (indices[r] >= __mapping_ref().extents().extent(r)) { + throw std::out_of_range( + "mdspan::at({...," + std::to_string(indices[r]) + ",...}) out-of-range at rank index " + std::to_string(r) + + " for mdspan with extent {...," + std::to_string(__mapping_ref().extents().extent(r)) + ",...}"); + } + } + return __impl::template __callop(*this, indices); + } + #endif // __cpp_lib_span + #if MDSPAN_USE_BRACKET_OPERATOR MDSPAN_TEMPLATE_REQUIRES( class... SizeTypes, diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index f8420b42..1cb8cf76 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -49,6 +49,7 @@ else() endif() mdspan_add_test(test_extents) +mdspan_add_test(test_mdspan_at) mdspan_add_test(test_mdspan_ctors) mdspan_add_test(test_mdspan_swap) mdspan_add_test(test_mdspan_conversion) diff --git a/tests/test_mdspan_at.cpp b/tests/test_mdspan_at.cpp new file mode 100644 index 00000000..55496812 --- /dev/null +++ b/tests/test_mdspan_at.cpp @@ -0,0 +1,34 @@ +//@HEADER +// ************************************************************************ +// +// Kokkos v. 4.0 +// Copyright (2022) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. +// See https://kokkos.org/LICENSE for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//@HEADER + +#include +#include + +#include + + +TEST(TestMdspanAt, test_mdspan_at) { + std::array a{}; + Kokkos::mdspan> s(a.data()); + + s.at(0, 0) = 3.14; + s.at(std::array{1, 2}) = 2.72; + ASSERT_EQ(s.at(0, 0), 3.14); + ASSERT_EQ(s.at(std::array{1, 2}), 2.72); + + EXPECT_THROW(s.at(2, 3), std::out_of_range); + EXPECT_THROW(s.at(std::array{3, 1}), std::out_of_range); +}