Skip to content

Commit

Permalink
#2094: Add unit tests for multicast
Browse files Browse the repository at this point in the history
  • Loading branch information
JacobDomagala committed Nov 8, 2023
1 parent 9699708 commit 7fb28eb
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 6 deletions.
4 changes: 2 additions & 2 deletions examples/hello_world/objgroup.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ int main(int argc, char** argv) {
auto proxy =
vt::theObjGroup()->makeCollective<MyObjGroup>("examples_hello_world");

// Create group of odd nodes and broadcast to them (from root node)
// Create group of odd nodes and multicast to them (from root node)
vt::theGroup()->newGroupCollective(
this_node % 2, [proxy, this_node](::vt::GroupType type) {
if (this_node == 0) {
Expand All @@ -83,7 +83,7 @@ int main(int argc, char** argv) {

using namespace ::vt::group::region;

// Create list of nodes and broadcast to them
// Create list of nodes and multicast to them
List::ListType range;
for (vt::NodeType node = 0; node < num_nodes; ++node) {
if (node % 2 == 0) {
Expand Down
17 changes: 14 additions & 3 deletions src/vt/group/group_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,22 @@ struct GroupManager : runtime::component::Component<GroupManager> {
*/
void setupDefaultGroup();

void AddNewTempGroup(const region::Region::ListType& key, GroupType value) {
temporary_groups_[key] = value;
/**
* \internal \brief Cache group created by multicast. This allows for reusing the same group.
*
* \param[in] range list of nodes that are part of given group
* \param[in] group group to cache
*/
void AddNewTempGroup(const region::Region::ListType& range, GroupType group) {
temporary_groups_[range] = group;
}

/**
* \internal \brief Return (if any) group associated with given range of nodes
*/
std::optional<GroupType>
GetTempGroupForRange(const region::Region::ListType& range);

/**
* \brief Create a new rooted group.
*
Expand Down Expand Up @@ -438,7 +448,8 @@ struct GroupManager : runtime::component::Component<GroupManager> {
ActionContainerType continuation_actions_ = {};
ActionListType cleanup_actions_ = {};
CollectiveScopeType collective_scope_;
std::unordered_map<region::Region::ListType, GroupType, region::ListHash> temporary_groups_ = {};
std::unordered_map<region::Region::ListType, GroupType, region::ListHash>
temporary_groups_ = {};
};

/**
Expand Down
6 changes: 5 additions & 1 deletion src/vt/objgroup/proxy/proxy_objgroup.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,11 @@ template <typename ObjT>
template <auto f, typename... Params>
typename Proxy<ObjT>::PendingSendType Proxy<ObjT>::multicast(
group::region::Region::RegionUPtrType&& nodes, Params&&... params) const {
// This will work for list-type ranges only
vtAssert(
not dynamic_cast<group::region::ShallowList*>(nodes.get()),
"multicast: range of nodes is not supported for ShallowList!"
);

nodes->sort();
auto& range = nodes->makeList();

Expand Down
26 changes: 26 additions & 0 deletions tests/unit/objgroup/test_objgroup.cc
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,32 @@ TEST_F(TestObjGroup, test_proxy_invoke) {
EXPECT_EQ(proxy.get()->recv_, 3);
}

TEST_F(TestObjGroup, test_proxy_multicast) {
using namespace ::vt::group::region;
auto const this_node = theContext()->getNode();
auto const num_nodes = theContext()->getNumNodes();

auto proxy =
vt::theObjGroup()->makeCollective<MyObjA>("test_proxy_multicast");

vt::runInEpochCollective([this_node, num_nodes, proxy] {
if (this_node == 0) {
// Create list of nodes and multicast to them
List::ListType range;
for (vt::NodeType node = 0; node < num_nodes; ++node) {
if (node % 2 == 0) {
range.push_back(node);
}
}

proxy.multicast<&MyObjA::handler>(std::make_unique<List>(range));
}
});

const auto expected = this_node % 2 == 0 ? 1 : 0;
EXPECT_EQ(proxy.get()->recv_, expected);
}

TEST_F(TestObjGroup, test_pending_send) {
auto my_node = vt::theContext()->getNode();
// create a proxy to a object group
Expand Down

0 comments on commit 7fb28eb

Please sign in to comment.