From 852e53c5afa0ff2234d39f3bc8b383fb4c1742b4 Mon Sep 17 00:00:00 2001 From: Jacob Domagala Date: Wed, 18 Sep 2024 23:56:58 +0200 Subject: [PATCH] #2281: Update documentation for new allreduce algorithms --- .../reduce/allreduce/data_handler.h | 21 ----- .../reduce/allreduce/rabenseifner.h | 71 ++++++++++++++--- .../reduce/allreduce/recursive_doubling.h | 76 +++++++++++++------ src/vt/group/group_manager.impl.h | 1 + tests/unit/objgroup/test_objgroup_common.h | 3 - 5 files changed, 112 insertions(+), 60 deletions(-) diff --git a/src/vt/collective/reduce/allreduce/data_handler.h b/src/vt/collective/reduce/allreduce/data_handler.h index 6ecb133579..1d190b49d2 100644 --- a/src/vt/collective/reduce/allreduce/data_handler.h +++ b/src/vt/collective/reduce/allreduce/data_handler.h @@ -78,16 +78,6 @@ class DataHandler { return {}; } - static DataType fromMemory(const Scalar*, size_t) { - vtAssert( - true, - "Using default DataHandler! This means that you're using custom type for " - "allreduce. Please provide specialization for you data type." - ); - - return {}; - } - static size_t size(void) { vtAssert( true, @@ -106,9 +96,6 @@ class DataHandler toVec(const ScalarType& data) { return std::vector{data}; } static ScalarType fromVec(const std::vector& data) { return data[0]; } - static ScalarType fromMemory(const ScalarType* data, size_t) { - return *data; - } static size_t size(const ScalarType&) { return 1; } }; @@ -120,10 +107,6 @@ class DataHandler> { static const std::vector& toVec(const std::vector& data) { return data; } static std::vector fromVec(const std::vector& data) { return data; } - static std::vector fromMemory(const T* data, size_t count) { - return std::vector(data, data + count); - } - static size_t size(const std::vector& data) { return data.size(); } }; @@ -140,10 +123,6 @@ class DataHandler> { return std::vector(data.data(), data.data() + data.extent(0)); } - static ViewType fromMemory(T* data, size_t size) { - return ViewType(data, size); - } - static ViewType fromVec(const std::vector& data) { ViewType view("view", data.size()); auto data_view = Kokkos::View(data.data(), data.size()); diff --git a/src/vt/collective/reduce/allreduce/rabenseifner.h b/src/vt/collective/reduce/allreduce/rabenseifner.h index 1735ff7e43..98cde0b1bf 100644 --- a/src/vt/collective/reduce/allreduce/rabenseifner.h +++ b/src/vt/collective/reduce/allreduce/rabenseifner.h @@ -44,27 +44,15 @@ #if !defined INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_RABENSEIFNER_H #define INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_RABENSEIFNER_H -#include "vt/config.h" -#include "vt/context/context.h" -#include "vt/messaging/message/message.h" #include "vt/objgroup/proxy/proxy_objgroup.h" -#include "vt/registry/auto/auto_registry.h" -#include "vt/pipe/pipe_manager.h" -#include "data_handler.h" #include "type.h" #include "rabenseifner_msg.h" -#include "helpers.h" -#include "vt/utils/fntraits/fntraits.h" #include "vt/configs/types/types_type.h" #include namespace vt::collective::reduce::allreduce { -struct CollectionAllreduceT {}; -struct GroupAllreduceT {}; -struct ObjgroupAllreduceT {}; - /** * \struct Rabenseifner * \brief Class implementing Rabenseifner's allreduce algorithm. @@ -74,31 +62,71 @@ struct ObjgroupAllreduceT {}; */ struct Rabenseifner { + /** + * \brief Constructor for Collection + * + * \param proxy Collection proxy + * \param group GroupID (for given collection) + * \param num_elems Number of local collection elements + */ Rabenseifner(detail::StrongVrtProxy proxy, detail::StrongGroup group, size_t num_elems); + + /** + * \brief Constructor for Group + * + * \param group GroupID + */ Rabenseifner(detail::StrongGroup group); + + /** + * \brief Constructor for ObjGroup + * + * \param objgroup ObjGroupProxy + */ Rabenseifner(detail::StrongObjGroup objgroup); ~Rabenseifner(); + /** + * \brief Set final handler that will be executed with allreduce result + * + * \param fin Callback to be executed + * \param id Allreduce ID + */ template void setFinalHandler(const CallbackType& fin, size_t id); + /** + * \brief Performs local reduce, and once the local one is done it starts up the global allreduce + * + * \param id Allreduce ID + * \param args Data to be allreduced + */ template class Op, typename... Args> void localReduce(size_t id, Args&&... args); + /** * \brief Initialize the allreduce algorithm. * * This function sets up the necessary data structures and initial values for the reduction operation. * + * \param id Allreduce ID * \param args Additional arguments for initializing the data value. */ template void initialize(size_t id, Args&&... args); + /** + * \brief Initialize the internal state of allreduce algorithm. + * + * \param id Allreduce ID + */ template void initializeState(size_t id); /** * \brief Execute the final handler callback with the reduced result. + * + * \param id Allreduce ID */ template void executeFinalHan(size_t id); @@ -107,6 +135,8 @@ struct Rabenseifner { * \brief Perform the allreduce operation. * * This function starts the allreduce operation, adjusting for non-power-of-two process counts if necessary. + * + * \param id Allreduce ID */ template class Op> void allreduce(size_t id); @@ -116,6 +146,8 @@ struct Rabenseifner { * * This function performs additional steps to handle non-power-of-two process counts, ensuring that the * main scatter-reduce and gather-allgather phases can proceed with a power-of-two number of processes. + * + * \param id Allreduce ID */ template class Op> void adjustForPowerOfTwo(size_t id); @@ -153,6 +185,7 @@ struct Rabenseifner { /** * \brief Check if all scatter messages have been received. * + * \param id Allreduce ID * \return True if all scatter messages have been received, false otherwise. */ template @@ -161,6 +194,7 @@ struct Rabenseifner { /** * \brief Check if the scatter phase is complete. * + * \param id Allreduce ID * \return True if the scatter phase is complete, false otherwise. */ template @@ -169,6 +203,7 @@ struct Rabenseifner { /** * \brief Check if the scatter phase is ready to proceed. * + * \param id Allreduce ID * \return True if the scatter phase is ready to proceed, false otherwise. */ template @@ -186,6 +221,8 @@ struct Rabenseifner { * \brief Perform the scatter-reduce iteration. * * This function sends data to the appropriate partner process and proceeds to the next step in the scatter phase. + * + * \param id Allreduce ID */ template class Op> void scatterReduceIter(size_t id); @@ -203,6 +240,7 @@ struct Rabenseifner { /** * \brief Check if all gather messages have been received. * + * \param id Allreduce ID * \return True if all gather messages have been received, false otherwise. */ template @@ -211,6 +249,7 @@ struct Rabenseifner { /** * \brief Check if the gather phase is complete. * + * \param id Allreduce ID * \return True if the gather phase is complete, false otherwise. */ template @@ -219,6 +258,7 @@ struct Rabenseifner { /** * \brief Check if the gather phase is ready to proceed. * + * \param id Allreduce ID * \return True if the gather phase is ready to proceed, false otherwise. */ template @@ -227,6 +267,7 @@ struct Rabenseifner { /** * \brief Try to reduce the received gather messages. * + * \param id Allreduce ID * \param step The current step in the gather phase. */ template @@ -236,6 +277,8 @@ struct Rabenseifner { * \brief Perform the gather iteration. * * This function sends data to the appropriate partner process and proceeds to the next step in the gather phase. + * + * \param id Allreduce ID */ template void gatherIter(size_t id); @@ -254,6 +297,8 @@ struct Rabenseifner { * \brief Perform the final part of the allreduce operation. * * This function completes the allreduce operation, handling any remaining steps and invoking the final handler. + * + * \param id Allreduce ID */ template void finalPart(size_t id); @@ -262,6 +307,8 @@ struct Rabenseifner { * \brief Send the result to excluded nodes. * * This function handles the final step for non-power-of-two process counts, sending the reduced result to excluded nodes. + * + * \param id Allreduce ID */ template void sendToExcludedNodes(size_t id); diff --git a/src/vt/collective/reduce/allreduce/recursive_doubling.h b/src/vt/collective/reduce/allreduce/recursive_doubling.h index 6fefe12029..3422a3c041 100644 --- a/src/vt/collective/reduce/allreduce/recursive_doubling.h +++ b/src/vt/collective/reduce/allreduce/recursive_doubling.h @@ -44,22 +44,11 @@ #if !defined INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_RECURSIVE_DOUBLING_H #define INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_RECURSIVE_DOUBLING_H -#include "vt/config.h" -#include "vt/context/context.h" -#include "vt/messaging/message/message.h" #include "vt/objgroup/proxy/proxy_objgroup.h" -#include "vt/configs/error/config_assert.h" -#include "vt/messaging/message/smart_ptr.h" -#include "data_handler.h" -#include "vt/pipe/pipe_manager.h" -#include "vt/utils/fntraits/fntraits.h" #include "type.h" #include "vt/configs/types/types_type.h" #include "vt/collective/reduce/allreduce/recursive_doubling_msg.h" -#include -#include - namespace vt::collective::reduce::allreduce { /** @@ -71,42 +60,62 @@ namespace vt::collective::reduce::allreduce { */ struct RecursiveDoubling { - RecursiveDoubling(detail::StrongVrtProxy proxy, detail::StrongGroup group, size_t num_elems); + /** - * \brief Constructor for RecursiveDoubling class. + * \brief Constructor for Collection * - * Initializes the RecursiveDoubling object with the provided parameters. + * \param proxy Collection proxy + * \param group GroupID (for given collection) + * \param num_elems Number of local collection elements + */ + RecursiveDoubling(detail::StrongVrtProxy proxy, detail::StrongGroup group, size_t num_elems); + + /** + * \brief Constructor for ObjGroup * - * \param parentProxy The parent proxy. - * \param num_nodes The number of nodes. - * \param args Additional arguments for data initialization. + * \param objgroup ObjGroupProxy */ RecursiveDoubling(detail::StrongObjGroup objgroup); - /** - * \brief Constructor for RecursiveDoubling class. - * - * Initializes the RecursiveDoubling object with the provided parameters. + /** + * \brief Constructor for Group * - * \param parentProxy The parent proxy. - * \param num_nodes The number of nodes. - * \param args Additional arguments for data initialization. + * \param group GroupID */ RecursiveDoubling(detail::StrongGroup group); ~RecursiveDoubling(); + /** + * \brief Execute the final handler callback with the reduced result. + * + * \param id Allreduce ID + */ template void executeFinalHan(size_t id); + /** + * \brief Set final handler that will be executed with allreduce result + * + * \param fin Callback to be executed + * \param id Allreduce ID + */ template void setFinalHandler(const CallbackType& fin, size_t id); + /** + * \brief Performs local reduce, and once the local one is done it starts up the global allreduce + * + * \param id Allreduce ID + * \param args Data to be allreduced + */ template class Op, typename... Args> void localReduce(size_t id, Args&&... args); /** * \brief Start the allreduce operation. + * + * \param id Allreduce ID */ template class Op> void allreduce(size_t id); @@ -114,16 +123,24 @@ struct RecursiveDoubling { /** * \brief Initialize the RecursiveDoubling object. * + * \param id Allreduce ID * \param args Additional arguments for data initialization. */ template void initialize(size_t id, Args&&... data); + /** + * \brief Initialize the internal state of allreduce algorithm. + * + * \param id Allreduce ID + */ template void initializeState(size_t id); /** * \brief Adjust for power of two nodes. + * + * \param id Allreduce ID */ template class Op> void adjustForPowerOfTwo(size_t id); @@ -139,6 +156,7 @@ struct RecursiveDoubling { /** * \brief Check if the allreduce operation is done. * + * \param id Allreduce ID * \return True if the operation is done, otherwise false. */ template @@ -147,6 +165,7 @@ struct RecursiveDoubling { /** * \brief Check if the current state is valid for allreduce. * + * \param id Allreduce ID * \return True if the state is valid, otherwise false. */ template @@ -155,6 +174,7 @@ struct RecursiveDoubling { /** * \brief Check if all messages are received for the current step. * + * \param id Allreduce ID * \return True if all messages are received, otherwise false. */ template @@ -163,6 +183,7 @@ struct RecursiveDoubling { /** * \brief Check if the object is ready for the next step of allreduce. * + * \param id Allreduce ID * \return True if ready, otherwise false. */ template @@ -170,6 +191,8 @@ struct RecursiveDoubling { /** * \brief Perform the next step of the allreduce operation. + * + * \param id Allreduce ID */ template class Op> void reduceIter(size_t id); @@ -177,6 +200,7 @@ struct RecursiveDoubling { /** * \brief Try to reduce the message at the specified step. * + * \param id Allreduce ID * \param step The step at which to try reduction. */ template class Op> @@ -192,6 +216,8 @@ struct RecursiveDoubling { /** * \brief Send data to excluded nodes for finalization. + * + * \param id Allreduce ID */ template void sendToExcludedNodes(size_t id); @@ -206,6 +232,8 @@ struct RecursiveDoubling { /** * \brief Perform the final part of the allreduce operation. + * + * \param id Allreduce ID */ template void finalPart(size_t id); diff --git a/src/vt/group/group_manager.impl.h b/src/vt/group/group_manager.impl.h index 4334d6a739..e67e743b97 100644 --- a/src/vt/group/group_manager.impl.h +++ b/src/vt/group/group_manager.impl.h @@ -57,6 +57,7 @@ #include "vt/group/group_info.h" #include "vt/collective/reduce/allreduce/rabenseifner.h" #include "vt/objgroup/manager.h" +#include "vt/pipe/pipe_manager.impl.h" namespace vt { namespace group { diff --git a/tests/unit/objgroup/test_objgroup_common.h b/tests/unit/objgroup/test_objgroup_common.h index 6219ed93c4..e9da06ead6 100644 --- a/tests/unit/objgroup/test_objgroup_common.h +++ b/tests/unit/objgroup/test_objgroup_common.h @@ -219,9 +219,6 @@ class DataHandler { static size_t size(const DataT& data) { return data.vec_.size(); } static const std::vector& toVec(const DataT& data) { return data.vec_; } static DataT fromVec(const std::vector& data) { return DataT{data}; } - static DataT fromMemory(const Scalar* data, size_t count) { - return DataT{std::vector(data, data + count)}; - } }; } // namespace vt::collective::reduce::allreduce