diff --git a/c++/nda/mpi/scatter.hpp b/c++/nda/mpi/scatter.hpp index be0e5f6f..8a0115e3 100644 --- a/c++/nda/mpi/scatter.hpp +++ b/c++/nda/mpi/scatter.hpp @@ -21,16 +21,20 @@ #pragma once -#include "../basic_functions.hpp" +#include "./utils.hpp" #include "../concepts.hpp" -#include "../exceptions.hpp" +#include "../macros.hpp" #include "../traits.hpp" +#include #include +#include +#include +#include +#include #include #include -#include /** * @ingroup av_mpi @@ -39,9 +43,10 @@ * @details An object of this class is returned when scattering nda::Array objects across multiple MPI processes. * * It models an nda::ArrayInitializer, that means it can be used to initialize and assign to nda::basic_array and - * nda::basic_array_view objects. The input array will be a chunked along its first dimension using `mpi::chunk_length`. + * nda::basic_array_view objects. The input array/view on the root process will be chunked along the first dimension + * into equal parts using `mpi::chunk_length` and scattered across all processes in the communicator. * - * See nda::mpi_scatter for an example. + * See nda::mpi_scatter for an example and more information. * * @tparam A nda::Array type to be scattered. */ @@ -50,11 +55,11 @@ struct mpi::lazy { /// Value type of the array/view. using value_type = typename std::decay_t::value_type; - /// Const view type of the array/view stored in the lazy object. - using const_view_type = decltype(std::declval()()); + /// Type of the array/view stored in the lazy object. + using stored_type = A; - /// View of the array/view to be scattered. - const_view_type rhs; + /// Array/View to be scattered. + stored_type rhs; /// MPI communicator. mpi::communicator comm; @@ -65,37 +70,45 @@ struct mpi::lazy { /// Should all processes receive the result. (doesn't make sense for scatter) const bool all{false}; // NOLINT (const is fine here) + /// Size of the array/view to be scattered. + mutable long scatter_size{0}; + /** - * @brief Compute the shape of the target array. + * @brief Compute the shape of the nda::ArrayInitializer object. + * + * @details The input array/view on the root process is chunked along the first dimension into equal (as much as + * possible) parts using `mpi::chunk_length`. * - * @details The target shape will be the same as the input shape, except that the first dimension of the input array - * is chunked into equal (as much as possible) parts using `mpi::chunk_length` and assigned to each MPI process. + * If the extent of the input array along the first dimension is not divisible by the number of processes, processes + * with lower ranks will receive more data than processes with higher ranks. * * @warning This makes an MPI call. * - * @return Shape of the target array. + * @return Shape of the nda::ArrayInitializer object. */ [[nodiscard]] auto shape() const { auto dims = rhs.shape(); - long dim0 = dims[0]; - mpi::broadcast(dim0, comm, root); - dims[0] = mpi::chunk_length(dim0, comm.size(), comm.rank()); + mpi::broadcast(dims, comm, root); + scatter_size = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<>()); + dims[0] = mpi::chunk_length(dims[0], comm.size(), comm.rank()); return dims; } /** * @brief Execute the lazy MPI operation and write the result to a target array/view. * - * @tparam T nda::Array type of the target array/view. + * @details The data will be scattered directly into the memory handle of the target array/view. + * + * Throws an exception, if the target array/view is not contiguous with positive strides or if a target view does not + * have the correct shape. + * + * @tparam T nda::Array type with C-layout. * @param target Target array/view. */ template + requires(std::decay_t::is_stride_order_C()) void invoke(T &&target) const { // NOLINT (temporary views are allowed here) - if (not target.is_contiguous() or not target.has_positive_strides()) - NDA_RUNTIME_ERROR << "Error in MPI scatter for nda::Array: Target array needs to be contiguous with positive strides"; - - static_assert(std::decay_t::layout_t::stride_order_encoded == std::decay_t::layout_t::stride_order_encoded, - "Error in MPI scatter for nda::Array: Incompatible stride orders"); + using namespace nda::detail; // special case for non-mpi runs if (not mpi::has_env) { @@ -103,24 +116,17 @@ struct mpi::lazy { return; } - // get target shape and resize or check the target array + // check if the target array/view can be used in the MPI call + check_mpi_contiguous_layout(target, "mpi_scatter"); + + // get target shape and resize or check the target array/view auto dims = shape(); resize_or_check_if_view(target, dims); - // compute send counts, receive counts and memory displacements - auto dim0 = rhs.extent(0); - auto stride0 = rhs.indexmap().strides()[0]; - auto sendcounts = std::vector(comm.size()); - auto displs = std::vector(comm.size() + 1, 0); - int recvcount = mpi::chunk_length(dim0, comm.size(), comm.rank()) * stride0; - for (int r = 0; r < comm.size(); ++r) { - sendcounts[r] = mpi::chunk_length(dim0, comm.size(), r) * stride0; - displs[r + 1] = sendcounts[r] + displs[r]; - } - // scatter the data - auto mpi_value_type = mpi::mpi_type::get(); - MPI_Scatterv((void *)rhs.data(), &sendcounts[0], &displs[0], mpi_value_type, (void *)target.data(), recvcount, mpi_value_type, root, comm.get()); + auto target_span = std::span{target.data(), static_cast(target.size())}; + auto rhs_span = std::span{rhs.data(), static_cast(rhs.size())}; + mpi::scatter_range(rhs_span, target_span, scatter_size, comm, root, rhs.indexmap().strides()[0]); } }; @@ -130,36 +136,51 @@ namespace nda { * @ingroup av_mpi * @brief Implementation of an MPI scatter for nda::basic_array or nda::basic_array_view types. * - * @details Since the returned `mpi::lazy` object models an nda::ArrayInitializer, it can be used to initialize/assign - * to nda::basic_array and nda::basic_array_view objects: + * @details The function scatters a C-ordered input array/view from a root process across all processes in the given + * communicator. The array/view is chunked into equal parts along the first dimension using `mpi::chunk_length`. + * + * Throws an exception, if the given array/view on the root process is not contiguous with positive strides or if it + * doesn't have a C-layout. Furthermore, it is expected that the input arrays/views have the same rank on all + * processes. + * + * This function is lazy, i.e. it returns an mpi::lazy object without performing the actual MPI + * operation. Since the returned object models an nda::ArrayInitializer, it can be used to initialize/assign to + * nda::basic_array and nda::basic_array_view objects: * * @code{.cpp} * // create an array on all processes - * nda::array arr(10, 4); + * nda::array A(10, 4); * * // ... * // fill array on root process * // ... * * // scatter the array to all processes - * nda::array res = mpi::scatter(arr); + * nda::array B = mpi::scatter(A); * @endcode * - * Here, the array `res` will have a shape of `(10 / comm.size(), 4)`. + * Here, the array `B` has the shape `(10 / comm.size(), 4)` on each process (assuming that 10 is a multiple of + * `comm.size()`). + * + * @warning MPI calls are done in the `invoke` and `shape` methods of the `mpi::lazy` object. If one rank calls one of + * these methods, all ranks in the communicator need to call the same method. Otherwise, the program will deadlock. * * @tparam A nda::basic_array or nda::basic_array_view type. * @param a Array or view to be scattered. * @param comm `mpi::communicator` object. * @param root Rank of the root process. * @param all Should all processes receive the result of the scatter (not used). - * @return An `mpi::lazy` object modelling an nda::ArrayInitializer. + * @return An mpi::lazy object modelling an nda::ArrayInitializer. */ template ArrayInitializer> auto mpi_scatter(A &&a, mpi::communicator comm = {}, int root = 0, bool all = false) requires(is_regular_or_view_v) { - if (not a.is_contiguous() or not a.has_positive_strides()) - NDA_RUNTIME_ERROR << "Error in MPI scatter for nda::Array: Array needs to be contiguous with positive strides"; + EXPECTS_WITH_MESSAGE(detail::have_mpi_equal_ranks(a, comm), "Ranks of arrays/views must be equal in nda::mpi_scatter") + if (comm.rank() == root) { + detail::check_mpi_contiguous_layout(a, "mpi_scatter"); + detail::check_mpi_c_layout(a, "mpi_scatter"); + } return mpi::lazy{std::forward(a), comm, root, all}; } diff --git a/test/c++/nda_mpi.cpp b/test/c++/nda_mpi.cpp index 00490965..8bdc141e 100644 --- a/test/c++/nda_mpi.cpp +++ b/test/c++/nda_mpi.cpp @@ -216,6 +216,28 @@ TEST_F(NDAMpi, GatherOtherLayouts) { } } +TEST_F(NDAMpi, ScatterCLayout) { + // scatter a C-layout array + decltype(A) A_scatter = mpi::scatter(A, comm); + auto chunked_rg = itertools::chunk_range(0, A.shape()[0], comm.size(), comm.rank()); + auto exp_shape = std::array{chunked_rg.second - chunked_rg.first, shape_3d[1], shape_3d[2]}; + EXPECT_EQ(exp_shape, A_scatter.shape()); + EXPECT_ARRAY_EQ(A(nda::range(chunked_rg.first, chunked_rg.second), nda::ellipsis{}), A_scatter); +} + +TEST_F(NDAMpi, ScatterOtherLayouts) { + // scatter a non C-layout array by first reshaping it + constexpr auto perm = decltype(A2)::layout_t::stride_order; + constexpr auto inv_perm = nda::permutations::inverse(perm); + + decltype(A) A2_scatter = mpi::scatter(nda::permuted_indices_view(A2), comm); + auto A2_scatter_v = nda::permuted_indices_view(A2_scatter); + auto chunked_rg = itertools::chunk_range(0, A2.shape()[1], comm.size(), comm.rank()); + auto exp_shape = std::array{shape_3d[0], chunked_rg.second - chunked_rg.first, shape_3d[2]}; + EXPECT_EQ(exp_shape, A2_scatter_v.shape()); + EXPECT_ARRAY_EQ(A2(nda::range::all, nda::range(chunked_rg.first, chunked_rg.second), nda::range::all), A2_scatter_v); +} + TEST_F(NDAMpi, ReduceCLayout) { // reduce an array decltype(A) A_sum = mpi::reduce(A, comm); @@ -263,16 +285,6 @@ TEST_F(NDAMpi, ReduceCustomType) { EXPECT_ARRAY_EQ(B_sum, exp_sum); } -TEST_F(NDAMpi, Scatter) { - // scatter an array - decltype(A) A_scatter = mpi::scatter(A, comm); - auto chunked_rg = itertools::chunk_range(0, A.shape()[0], comm.size(), comm.rank()); - auto exp_shape = A.shape(); - exp_shape[0] = chunked_rg.second - chunked_rg.first; - EXPECT_EQ(exp_shape, A_scatter.shape()); - EXPECT_ARRAY_EQ(A(nda::range(chunked_rg.first, chunked_rg.second), nda::ellipsis{}), A_scatter); -} - TEST_F(NDAMpi, BroadcastTransposedMatrix) { nda::matrix> M_t = transpose(M); nda::matrix> N; @@ -327,4 +339,21 @@ TEST_F(NDAMpi, VariousCollectiveCommunications) { EXPECT_ARRAY_NEAR(R2, comm.size() * A); } +TEST_F(NDAMpi, PassingTemporaryObjects) { + auto A = nda::array{1, 2, 3}; + auto lazy_arr = mpi::gather(nda::array{1, 2, 3}, comm); + auto res_arr = nda::array(lazy_arr); + auto lazy_view = mpi::gather(A(), comm); + auto res_view = nda::array(lazy_view); + if (comm.rank() == 0) { + for (long i = 0; i < comm.size(); ++i) { + EXPECT_ARRAY_EQ(res_arr(nda::range(i * 3, (i + 1) * 3)), A); + EXPECT_ARRAY_EQ(res_view(nda::range(i * 3, (i + 1) * 3)), A); + } + } else { + EXPECT_TRUE(res_arr.empty()); + EXPECT_TRUE(res_view.empty()); + } +} + MPI_TEST_MAIN