Skip to content

Commit

Permalink
#2281: Update documentation for new allreduce algorithms
Browse files Browse the repository at this point in the history
  • Loading branch information
JacobDomagala committed Sep 18, 2024
1 parent 934520d commit 852e53c
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 60 deletions.
21 changes: 0 additions & 21 deletions src/vt/collective/reduce/allreduce/data_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,6 @@ class DataHandler {
return {};
}

/**
 * \brief Fallback \c fromMemory used when no DataHandler specialization
 * matches the data type.
 *
 * Reaching this function means the generic (unspecialized) handler was
 * selected, so it raises an assertion asking the user to provide a
 * specialization.
 *
 * \return Value-initialized \c DataType if execution continues past the
 * assertion
 */
static DataType fromMemory(const Scalar*, size_t) {
  // The condition has to be false for the assertion to trigger; asserting
  // on `true` silently turned this diagnostic into a no-op.
  vtAssert(
    false,
    "Using default DataHandler! This means that you're using custom type for "
    "allreduce. Please provide specialization for you data type."
  );

  return {};
}

static size_t size(void) {
vtAssert(
true,
Expand All @@ -106,9 +96,6 @@ class DataHandler<ScalarType, typename std::enable_if<std::is_arithmetic<ScalarT

static std::vector<ScalarType> toVec(const ScalarType& data) { return std::vector<ScalarType>{data}; }
/**
 * \brief Extract the scalar stored in the first slot of \c data.
 *
 * \param data Vector holding the scalar; it is read at index 0 without a
 * bounds check, so it must not be empty
 * \return Copy of the first element
 */
static ScalarType fromVec(const std::vector<ScalarType>& data) {
  return data.front();
}
/**
 * \brief Read a scalar from raw memory.
 *
 * \param data Pointer to the scalar; it is dereferenced unconditionally
 * \return Copy of the value \c data points to (the count argument is ignored)
 */
static ScalarType fromMemory(const ScalarType* data, size_t) {
  const ScalarType& value = data[0];
  return value;
}

/**
 * \brief Number of elements represented by a scalar.
 *
 * \return Always 1; the argument is not inspected
 */
static size_t size(const ScalarType&) {
  const size_t scalar_count = 1;
  return scalar_count;
}
};
Expand All @@ -120,10 +107,6 @@ class DataHandler<std::vector<T>> {

/**
 * \brief Identity conversion: the data is already a vector.
 *
 * \param data Vector to expose
 * \return Reference to \c data itself (no copy is made)
 */
static const std::vector<T>& toVec(const std::vector<T>& data) {
  const std::vector<T>& same = data;
  return same;
}
/**
 * \brief Build the data value from a vector.
 *
 * \param data Source vector
 * \return Copy of \c data
 */
static std::vector<T> fromVec(const std::vector<T>& data) {
  std::vector<T> duplicate(data);
  return duplicate;
}
static std::vector<T> fromMemory(const T* data, size_t count) {
return std::vector<T>(data, data + count);
}

/**
 * \brief Number of elements held by the vector.
 *
 * \param data Vector to measure
 * \return \c data.size()
 */
static size_t size(const std::vector<T>& data) {
  const size_t element_count = data.size();
  return element_count;
}
};

Expand All @@ -140,10 +123,6 @@ class DataHandler<Kokkos::View<T*, Kokkos::HostSpace, Props...>> {
return std::vector<T>(data.data(), data.data() + data.extent(0));
}

/**
 * \brief Construct a view from a raw pointer and an element count.
 *
 * NOTE(review): presumably the resulting view refers to the caller's buffer
 * rather than copying it (Kokkos pointer constructor) - confirm.
 *
 * \param data Pointer to the first element
 * \param size Number of elements
 * \return \c ViewType constructed from \c data and \c size
 */
static ViewType fromMemory(T* data, size_t size) {
  ViewType view(data, size);
  return view;
}

static ViewType fromVec(const std::vector<T>& data) {
ViewType view("view", data.size());
auto data_view = Kokkos::View<const T*, Kokkos::HostSpace, Kokkos::MemoryUnmanaged>(data.data(), data.size());
Expand Down
71 changes: 59 additions & 12 deletions src/vt/collective/reduce/allreduce/rabenseifner.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,27 +44,15 @@
#if !defined INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_RABENSEIFNER_H
#define INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_RABENSEIFNER_H

#include "vt/config.h"
#include "vt/context/context.h"
#include "vt/messaging/message/message.h"
#include "vt/objgroup/proxy/proxy_objgroup.h"
#include "vt/registry/auto/auto_registry.h"
#include "vt/pipe/pipe_manager.h"
#include "data_handler.h"
#include "type.h"
#include "rabenseifner_msg.h"
#include "helpers.h"
#include "vt/utils/fntraits/fntraits.h"
#include "vt/configs/types/types_type.h"

#include <cstdint>

namespace vt::collective::reduce::allreduce {

/// Empty tag type. NOTE(review): presumably marks an allreduce performed over a collection - confirm against its users.
struct CollectionAllreduceT {};
/// Empty tag type. NOTE(review): presumably marks an allreduce performed over a group - confirm against its users.
struct GroupAllreduceT {};
/// Empty tag type. NOTE(review): presumably marks an allreduce performed over an objgroup - confirm against its users.
struct ObjgroupAllreduceT {};

/**
* \struct Rabenseifner
* \brief Class implementing Rabenseifner's allreduce algorithm.
Expand All @@ -74,31 +62,71 @@ struct ObjgroupAllreduceT {};
*/

struct Rabenseifner {
/**
* \brief Constructor for Collection
*
* \param proxy Collection proxy
* \param group GroupID (for given collection)
* \param num_elems Number of local collection elements
*/
Rabenseifner(detail::StrongVrtProxy proxy, detail::StrongGroup group, size_t num_elems);

/**
* \brief Constructor for Group
*
* \param group GroupID
*/
Rabenseifner(detail::StrongGroup group);

/**
* \brief Constructor for ObjGroup
*
* \param objgroup ObjGroupProxy
*/
Rabenseifner(detail::StrongObjGroup objgroup);
~Rabenseifner();

/**
* \brief Set final handler that will be executed with allreduce result
*
* \param fin Callback to be executed
* \param id Allreduce ID
*/
template <typename DataT, typename CallbackType>
void setFinalHandler(const CallbackType& fin, size_t id);

/**
* \brief Perform the local reduction and, once it completes, start the global allreduce
*
* \param id Allreduce ID
* \param args Data to be allreduced
*/
template <typename DataT, template <typename Arg> class Op, typename... Args>
void localReduce(size_t id, Args&&... args);

/**
* \brief Initialize the allreduce algorithm.
*
* This function sets up the necessary data structures and initial values for the reduction operation.
*
* \param id Allreduce ID
* \param args Additional arguments for initializing the data value.
*/
template <typename DataT, typename ...Args>
void initialize(size_t id, Args&&... args);

/**
* \brief Initialize the internal state of allreduce algorithm.
*
* \param id Allreduce ID
*/
template <typename DataT>
void initializeState(size_t id);

/**
* \brief Execute the final handler callback with the reduced result.
*
* \param id Allreduce ID
*/
template <typename DataT>
void executeFinalHan(size_t id);
Expand All @@ -107,6 +135,8 @@ struct Rabenseifner {
* \brief Perform the allreduce operation.
*
* This function starts the allreduce operation, adjusting for non-power-of-two process counts if necessary.
*
* \param id Allreduce ID
*/
template <typename DataT, template <typename Arg> class Op>
void allreduce(size_t id);
Expand All @@ -116,6 +146,8 @@ struct Rabenseifner {
*
* This function performs additional steps to handle non-power-of-two process counts, ensuring that the
* main scatter-reduce and gather-allgather phases can proceed with a power-of-two number of processes.
*
* \param id Allreduce ID
*/
template <typename DataT, template <typename Arg> class Op>
void adjustForPowerOfTwo(size_t id);
Expand Down Expand Up @@ -153,6 +185,7 @@ struct Rabenseifner {
/**
* \brief Check if all scatter messages have been received.
*
* \param id Allreduce ID
* \return True if all scatter messages have been received, false otherwise.
*/
template <typename DataT>
Expand All @@ -161,6 +194,7 @@ struct Rabenseifner {
/**
* \brief Check if the scatter phase is complete.
*
* \param id Allreduce ID
* \return True if the scatter phase is complete, false otherwise.
*/
template <typename DataT>
Expand All @@ -169,6 +203,7 @@ struct Rabenseifner {
/**
* \brief Check if the scatter phase is ready to proceed.
*
* \param id Allreduce ID
* \return True if the scatter phase is ready to proceed, false otherwise.
*/
template <typename DataT>
Expand All @@ -186,6 +221,8 @@ struct Rabenseifner {
* \brief Perform the scatter-reduce iteration.
*
* This function sends data to the appropriate partner process and proceeds to the next step in the scatter phase.
*
* \param id Allreduce ID
*/
template <typename DataT, template <typename Arg> class Op>
void scatterReduceIter(size_t id);
Expand All @@ -203,6 +240,7 @@ struct Rabenseifner {
/**
* \brief Check if all gather messages have been received.
*
* \param id Allreduce ID
* \return True if all gather messages have been received, false otherwise.
*/
template <typename DataT>
Expand All @@ -211,6 +249,7 @@ struct Rabenseifner {
/**
* \brief Check if the gather phase is complete.
*
* \param id Allreduce ID
* \return True if the gather phase is complete, false otherwise.
*/
template <typename DataT>
Expand All @@ -219,6 +258,7 @@ struct Rabenseifner {
/**
* \brief Check if the gather phase is ready to proceed.
*
* \param id Allreduce ID
* \return True if the gather phase is ready to proceed, false otherwise.
*/
template <typename DataT>
Expand All @@ -227,6 +267,7 @@ struct Rabenseifner {
/**
* \brief Try to reduce the received gather messages.
*
* \param id Allreduce ID
* \param step The current step in the gather phase.
*/
template <typename DataT>
Expand All @@ -236,6 +277,8 @@ struct Rabenseifner {
* \brief Perform the gather iteration.
*
* This function sends data to the appropriate partner process and proceeds to the next step in the gather phase.
*
* \param id Allreduce ID
*/
template <typename DataT>
void gatherIter(size_t id);
Expand All @@ -254,6 +297,8 @@ struct Rabenseifner {
* \brief Perform the final part of the allreduce operation.
*
* This function completes the allreduce operation, handling any remaining steps and invoking the final handler.
*
* \param id Allreduce ID
*/
template <typename DataT>
void finalPart(size_t id);
Expand All @@ -262,6 +307,8 @@ struct Rabenseifner {
* \brief Send the result to excluded nodes.
*
* This function handles the final step for non-power-of-two process counts, sending the reduced result to excluded nodes.
*
* \param id Allreduce ID
*/
template <typename DataT>
void sendToExcludedNodes(size_t id);
Expand Down
Loading

0 comments on commit 852e53c

Please sign in to comment.