Skip to content

Commit

Permalink
Update mpi_scatter and add more tests
Browse files Browse the repository at this point in the history
- change the type stored in the lazy object
- make the checks more consistent
- restrict to C order arrays/views
- simplify by using the range functions from mpi
  • Loading branch information
Thoemi09 committed Oct 3, 2024
1 parent b457f61 commit fe5fbe9
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 54 deletions.
109 changes: 65 additions & 44 deletions c++/nda/mpi/scatter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,20 @@

#pragma once

#include "../basic_functions.hpp"
#include "./utils.hpp"
#include "../concepts.hpp"
#include "../exceptions.hpp"
#include "../macros.hpp"
#include "../traits.hpp"

#include <mpi.h>
#include <mpi/mpi.hpp>

#include <cstddef>
#include <functional>
#include <numeric>
#include <span>
#include <type_traits>
#include <utility>
#include <vector>

/**
* @ingroup av_mpi
Expand All @@ -39,9 +43,10 @@
* @details An object of this class is returned when scattering nda::Array objects across multiple MPI processes.
*
* It models an nda::ArrayInitializer, that means it can be used to initialize and assign to nda::basic_array and
* nda::basic_array_view objects. The input array will be a chunked along its first dimension using `mpi::chunk_length`.
* nda::basic_array_view objects. The input array/view on the root process will be chunked along the first dimension
* into equal parts using `mpi::chunk_length` and scattered across all processes in the communicator.
*
* See nda::mpi_scatter for an example.
* See nda::mpi_scatter for an example and more information.
*
* @tparam A nda::Array type to be scattered.
*/
Expand All @@ -50,11 +55,11 @@ struct mpi::lazy<mpi::tag::scatter, A> {
/// Value type of the array/view.
using value_type = typename std::decay_t<A>::value_type;

/// Const view type of the array/view stored in the lazy object.
using const_view_type = decltype(std::declval<const A>()());
/// Type of the array/view stored in the lazy object.
using stored_type = A;

/// View of the array/view to be scattered.
const_view_type rhs;
/// Array/View to be scattered.
stored_type rhs;

/// MPI communicator.
mpi::communicator comm;
Expand All @@ -65,62 +70,63 @@ struct mpi::lazy<mpi::tag::scatter, A> {
/// Should all processes receive the result. (doesn't make sense for scatter)
const bool all{false}; // NOLINT (const is fine here)

/// Size of the array/view to be scattered.
mutable long scatter_size{0};

/**
* @brief Compute the shape of the target array.
* @brief Compute the shape of the nda::ArrayInitializer object.
*
* @details The input array/view on the root process is chunked along the first dimension into equal (as much as
* possible) parts using `mpi::chunk_length`.
*
* @details The target shape will be the same as the input shape, except that the first dimension of the input array
* is chunked into equal (as much as possible) parts using `mpi::chunk_length` and assigned to each MPI process.
* If the extent of the input array along the first dimension is not divisible by the number of processes, processes
* with lower ranks will receive more data than processes with higher ranks.
*
* @warning This makes an MPI call.
*
* @return Shape of the target array.
* @return Shape of the nda::ArrayInitializer object.
*/
[[nodiscard]] auto shape() const {
auto dims = rhs.shape();
long dim0 = dims[0];
mpi::broadcast(dim0, comm, root);
dims[0] = mpi::chunk_length(dim0, comm.size(), comm.rank());
mpi::broadcast(dims, comm, root);
scatter_size = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<>());
dims[0] = mpi::chunk_length(dims[0], comm.size(), comm.rank());
return dims;
}

/**
* @brief Execute the lazy MPI operation and write the result to a target array/view.
*
* @tparam T nda::Array type of the target array/view.
* @details The data will be scattered directly into the memory handle of the target array/view.
*
* Throws an exception, if the target array/view is not contiguous with positive strides or if a target view does not
* have the correct shape.
*
* @tparam T nda::Array type with C-layout.
* @param target Target array/view.
*/
template <nda::Array T>
requires(std::decay_t<T>::is_stride_order_C())
void invoke(T &&target) const { // NOLINT (temporary views are allowed here)
if (not target.is_contiguous() or not target.has_positive_strides())
NDA_RUNTIME_ERROR << "Error in MPI scatter for nda::Array: Target array needs to be contiguous with positive strides";

static_assert(std::decay_t<A>::layout_t::stride_order_encoded == std::decay_t<T>::layout_t::stride_order_encoded,
"Error in MPI scatter for nda::Array: Incompatible stride orders");
using namespace nda::detail;

// special case for non-mpi runs
if (not mpi::has_env) {
target = rhs;
return;
}

// get target shape and resize or check the target array
// check if the target array/view can be used in the MPI call
check_mpi_contiguous_layout(target, "mpi_scatter");

// get target shape and resize or check the target array/view
auto dims = shape();
resize_or_check_if_view(target, dims);

// compute send counts, receive counts and memory displacements
auto dim0 = rhs.extent(0);
auto stride0 = rhs.indexmap().strides()[0];
auto sendcounts = std::vector<int>(comm.size());
auto displs = std::vector<int>(comm.size() + 1, 0);
int recvcount = mpi::chunk_length(dim0, comm.size(), comm.rank()) * stride0;
for (int r = 0; r < comm.size(); ++r) {
sendcounts[r] = mpi::chunk_length(dim0, comm.size(), r) * stride0;
displs[r + 1] = sendcounts[r] + displs[r];
}

// scatter the data
auto mpi_value_type = mpi::mpi_type<value_type>::get();
MPI_Scatterv((void *)rhs.data(), &sendcounts[0], &displs[0], mpi_value_type, (void *)target.data(), recvcount, mpi_value_type, root, comm.get());
auto target_span = std::span{target.data(), static_cast<std::size_t>(target.size())};
auto rhs_span = std::span{rhs.data(), static_cast<std::size_t>(rhs.size())};
mpi::scatter_range(rhs_span, target_span, scatter_size, comm, root, rhs.indexmap().strides()[0]);
}
};

Expand All @@ -130,36 +136,51 @@ namespace nda {
* @ingroup av_mpi
* @brief Implementation of an MPI scatter for nda::basic_array or nda::basic_array_view types.
*
* @details Since the returned `mpi::lazy` object models an nda::ArrayInitializer, it can be used to initialize/assign
* to nda::basic_array and nda::basic_array_view objects:
* @details The function scatters a C-ordered input array/view from a root process across all processes in the given
* communicator. The array/view is chunked into equal parts along the first dimension using `mpi::chunk_length`.
*
* Throws an exception, if the given array/view on the root process is not contiguous with positive strides or if it
* doesn't have a C-layout. Furthermore, it is expected that the input arrays/views have the same rank on all
* processes.
*
* This function is lazy, i.e. it returns an mpi::lazy<mpi::tag::scatter, A> object without performing the actual MPI
* operation. Since the returned object models an nda::ArrayInitializer, it can be used to initialize/assign to
* nda::basic_array and nda::basic_array_view objects:
*
* @code{.cpp}
* // create an array on all processes
* nda::array<int, 2> arr(10, 4);
* nda::array<int, 2> A(10, 4);
*
* // ...
* // fill array on root process
* // ...
*
* // scatter the array to all processes
* nda::array<int, 2> res = mpi::scatter(arr);
* nda::array<int, 2> B = mpi::scatter(A);
* @endcode
*
* Here, the array `res` will have a shape of `(10 / comm.size(), 4)`.
* Here, the array `B` has the shape `(10 / comm.size(), 4)` on each process (assuming that 10 is a multiple of
* `comm.size()`).
*
* @warning MPI calls are done in the `invoke` and `shape` methods of the `mpi::lazy` object. If one rank calls one of
* these methods, all ranks in the communicator need to call the same method. Otherwise, the program will deadlock.
*
* @tparam A nda::basic_array or nda::basic_array_view type.
* @param a Array or view to be scattered.
* @param comm `mpi::communicator` object.
* @param root Rank of the root process.
* @param all Should all processes receive the result of the scatter (not used).
* @return An `mpi::lazy` object modelling an nda::ArrayInitializer.
* @return An mpi::lazy<mpi::tag::scatter, A> object modelling an nda::ArrayInitializer.
*/
template <typename A>
ArrayInitializer<std::remove_reference_t<A>> auto mpi_scatter(A &&a, mpi::communicator comm = {}, int root = 0, bool all = false)
requires(is_regular_or_view_v<A>)
{
if (not a.is_contiguous() or not a.has_positive_strides())
NDA_RUNTIME_ERROR << "Error in MPI scatter for nda::Array: Array needs to be contiguous with positive strides";
EXPECTS_WITH_MESSAGE(detail::have_mpi_equal_ranks(a, comm), "Ranks of arrays/views must be equal in nda::mpi_scatter")
if (comm.rank() == root) {
detail::check_mpi_contiguous_layout(a, "mpi_scatter");
detail::check_mpi_c_layout(a, "mpi_scatter");
}
return mpi::lazy<mpi::tag::scatter, A>{std::forward<A>(a), comm, root, all};
}

Expand Down
49 changes: 39 additions & 10 deletions test/c++/nda_mpi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,28 @@ TEST_F(NDAMpi, GatherOtherLayouts) {
}
}

TEST_F(NDAMpi, ScatterCLayout) {
// scatter a C-layout array
decltype(A) A_scatter = mpi::scatter(A, comm);
auto chunked_rg = itertools::chunk_range(0, A.shape()[0], comm.size(), comm.rank());
auto exp_shape = std::array<long, 3>{chunked_rg.second - chunked_rg.first, shape_3d[1], shape_3d[2]};
EXPECT_EQ(exp_shape, A_scatter.shape());
EXPECT_ARRAY_EQ(A(nda::range(chunked_rg.first, chunked_rg.second), nda::ellipsis{}), A_scatter);
}

TEST_F(NDAMpi, ScatterOtherLayouts) {
// scatter a non C-layout array by first reshaping it
constexpr auto perm = decltype(A2)::layout_t::stride_order;
constexpr auto inv_perm = nda::permutations::inverse(perm);

decltype(A) A2_scatter = mpi::scatter(nda::permuted_indices_view<nda::encode(inv_perm)>(A2), comm);
auto A2_scatter_v = nda::permuted_indices_view<nda::encode(perm)>(A2_scatter);
auto chunked_rg = itertools::chunk_range(0, A2.shape()[1], comm.size(), comm.rank());
auto exp_shape = std::array<long, 3>{shape_3d[0], chunked_rg.second - chunked_rg.first, shape_3d[2]};
EXPECT_EQ(exp_shape, A2_scatter_v.shape());
EXPECT_ARRAY_EQ(A2(nda::range::all, nda::range(chunked_rg.first, chunked_rg.second), nda::range::all), A2_scatter_v);
}

TEST_F(NDAMpi, ReduceCLayout) {
// reduce an array
decltype(A) A_sum = mpi::reduce(A, comm);
Expand Down Expand Up @@ -263,16 +285,6 @@ TEST_F(NDAMpi, ReduceCustomType) {
EXPECT_ARRAY_EQ(B_sum, exp_sum);
}

TEST_F(NDAMpi, Scatter) {
// scatter an array
decltype(A) A_scatter = mpi::scatter(A, comm);
auto chunked_rg = itertools::chunk_range(0, A.shape()[0], comm.size(), comm.rank());
auto exp_shape = A.shape();
exp_shape[0] = chunked_rg.second - chunked_rg.first;
EXPECT_EQ(exp_shape, A_scatter.shape());
EXPECT_ARRAY_EQ(A(nda::range(chunked_rg.first, chunked_rg.second), nda::ellipsis{}), A_scatter);
}

TEST_F(NDAMpi, BroadcastTransposedMatrix) {
nda::matrix<std::complex<double>> M_t = transpose(M);
nda::matrix<std::complex<double>> N;
Expand Down Expand Up @@ -327,4 +339,21 @@ TEST_F(NDAMpi, VariousCollectiveCommunications) {
EXPECT_ARRAY_NEAR(R2, comm.size() * A);
}

TEST_F(NDAMpi, PassingTemporaryObjects) {
auto A = nda::array<int, 1>{1, 2, 3};
auto lazy_arr = mpi::gather(nda::array<int, 1>{1, 2, 3}, comm);
auto res_arr = nda::array<int, 1>(lazy_arr);
auto lazy_view = mpi::gather(A(), comm);
auto res_view = nda::array<int, 1>(lazy_view);
if (comm.rank() == 0) {
for (long i = 0; i < comm.size(); ++i) {
EXPECT_ARRAY_EQ(res_arr(nda::range(i * 3, (i + 1) * 3)), A);
EXPECT_ARRAY_EQ(res_view(nda::range(i * 3, (i + 1) * 3)), A);
}
} else {
EXPECT_TRUE(res_arr.empty());
EXPECT_TRUE(res_view.empty());
}
}

MPI_TEST_MAIN

0 comments on commit fe5fbe9

Please sign in to comment.