From 7fb28eba5f162fda61b7c02d7e10e03a23647bd9 Mon Sep 17 00:00:00 2001 From: Jacob Domagala Date: Wed, 8 Nov 2023 16:54:47 +0100 Subject: [PATCH] #2094: Add unit tests for multicast --- examples/hello_world/objgroup.cc | 4 ++-- src/vt/group/group_manager.h | 17 +++++++++++--- src/vt/objgroup/proxy/proxy_objgroup.impl.h | 6 ++++- tests/unit/objgroup/test_objgroup.cc | 26 +++++++++++++++++++++ 4 files changed, 47 insertions(+), 6 deletions(-) diff --git a/examples/hello_world/objgroup.cc b/examples/hello_world/objgroup.cc index 77df7f3921..6642bc7ee8 100644 --- a/examples/hello_world/objgroup.cc +++ b/examples/hello_world/objgroup.cc @@ -60,7 +60,7 @@ int main(int argc, char** argv) { auto proxy = vt::theObjGroup()->makeCollective("examples_hello_world"); - // Create group of odd nodes and broadcast to them (from root node) + // Create group of odd nodes and multicast to them (from root node) vt::theGroup()->newGroupCollective( this_node % 2, [proxy, this_node](::vt::GroupType type) { if (this_node == 0) { @@ -83,7 +83,7 @@ int main(int argc, char** argv) { using namespace ::vt::group::region; - // Create list of nodes and broadcast to them + // Create list of nodes and multicast to them List::ListType range; for (vt::NodeType node = 0; node < num_nodes; ++node) { if (node % 2 == 0) { diff --git a/src/vt/group/group_manager.h b/src/vt/group/group_manager.h index 4b4c270284..23cdaf0e63 100644 --- a/src/vt/group/group_manager.h +++ b/src/vt/group/group_manager.h @@ -121,12 +121,22 @@ struct GroupManager : runtime::component::Component { */ void setupDefaultGroup(); - void AddNewTempGroup(const region::Region::ListType& key, GroupType value) { - temporary_groups_[key] = value; + /** + * \internal \brief Cache group created by multicast. This allows for reusing the same group. + * + * \param[in] range list of nodes that are part of given group + * \param[in] group group to cache + */ + void AddNewTempGroup(const region::Region::ListType& range, GroupType group) { + temporary_groups_[range] = group; } + /** + * \internal \brief Return (if any) group associated with given range of nodes + */ std::optional GetTempGroupForRange(const region::Region::ListType& range); + /** * \brief Create a new rooted group. * @@ -438,7 +448,8 @@ struct GroupManager : runtime::component::Component { ActionContainerType continuation_actions_ = {}; ActionListType cleanup_actions_ = {}; CollectiveScopeType collective_scope_; - std::unordered_map temporary_groups_ = {}; + std::unordered_map + temporary_groups_ = {}; }; /** diff --git a/src/vt/objgroup/proxy/proxy_objgroup.impl.h b/src/vt/objgroup/proxy/proxy_objgroup.impl.h index 813d57d3d4..906e6d0738 100644 --- a/src/vt/objgroup/proxy/proxy_objgroup.impl.h +++ b/src/vt/objgroup/proxy/proxy_objgroup.impl.h @@ -138,7 +138,11 @@ template template typename Proxy::PendingSendType Proxy::multicast( group::region::Region::RegionUPtrType&& nodes, Params&&... params) const { - // This will work for list-type ranges only + vtAssert( + not dynamic_cast(nodes.get()), + "multicast: range of nodes is not supported for ShallowList!" + ); + nodes->sort(); auto& range = nodes->makeList(); diff --git a/tests/unit/objgroup/test_objgroup.cc b/tests/unit/objgroup/test_objgroup.cc index 102217b942..ddfc690cd0 100644 --- a/tests/unit/objgroup/test_objgroup.cc +++ b/tests/unit/objgroup/test_objgroup.cc @@ -289,6 +289,32 @@ TEST_F(TestObjGroup, test_proxy_invoke) { EXPECT_EQ(proxy.get()->recv_, 3); } +TEST_F(TestObjGroup, test_proxy_multicast) { + using namespace ::vt::group::region; + auto const this_node = theContext()->getNode(); + auto const num_nodes = theContext()->getNumNodes(); + + auto proxy = + vt::theObjGroup()->makeCollective("test_proxy_multicast"); + + vt::runInEpochCollective([this_node, num_nodes, proxy] { + if (this_node == 0) { + // Create list of nodes and multicast to them + List::ListType range; + for (vt::NodeType node = 0; node < num_nodes; ++node) { + if (node % 2 == 0) { + range.push_back(node); + } + } + + proxy.multicast<&MyObjA::handler>(std::make_unique(range)); + } + }); + + const auto expected = this_node % 2 == 0 ? 1 : 0; + EXPECT_EQ(proxy.get()->recv_, expected); +} + TEST_F(TestObjGroup, test_pending_send) { auto my_node = vt::theContext()->getNode(); // create a proxy to a object group