From d85e805c417b0dc2e110a71fcd62e8b7f037a6d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beat=20K=C3=BCng?= Date: Thu, 10 Oct 2024 13:16:02 +0200 Subject: [PATCH] rework graph and add more unit tests, also for pub sub class --- translation_node/CMakeLists.txt | 4 + translation_node/package.xml | 1 + translation_node/src/graph.h | 122 ++++----- translation_node/src/main.cpp | 30 +- translation_node/src/monitor.cpp | 39 +++ translation_node/src/monitor.h | 21 ++ translation_node/src/pub_sub_graph.cpp | 3 - translation_node/src/translation_util.h | 3 +- translation_node/test/graph.cpp | 316 +++++++++++++++++++-- translation_node/test/pub_sub.cpp | 350 ++++++++++++++++++++++++ 10 files changed, 761 insertions(+), 128 deletions(-) create mode 100644 translation_node/src/monitor.cpp create mode 100644 translation_node/src/monitor.h create mode 100644 translation_node/test/pub_sub.cpp diff --git a/translation_node/CMakeLists.txt b/translation_node/CMakeLists.txt index 9b7868a..29e2f1f 100644 --- a/translation_node/CMakeLists.txt +++ b/translation_node/CMakeLists.txt @@ -12,6 +12,7 @@ find_package(px4_msgs REQUIRED) find_package(px4_msgs_old REQUIRED) add_library(${PROJECT_NAME}_lib + src/monitor.cpp src/pub_sub_graph.cpp src/translations.cpp ) @@ -26,6 +27,7 @@ install(TARGETS DESTINATION lib/${PROJECT_NAME}) if(BUILD_TESTING) + find_package(std_msgs REQUIRED) find_package(ament_lint_auto REQUIRED) find_package(ament_cmake_gtest REQUIRED) ament_lint_auto_find_test_dependencies() @@ -34,10 +36,12 @@ if(BUILD_TESTING) ament_add_gtest(${PROJECT_NAME}_unit_tests test/graph.cpp test/main.cpp + test/pub_sub.cpp ) target_include_directories(${PROJECT_NAME}_unit_tests PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_link_libraries(${PROJECT_NAME}_unit_tests ${PROJECT_NAME}_lib) ament_target_dependencies(${PROJECT_NAME}_unit_tests + std_msgs rclcpp ) endif() diff --git a/translation_node/package.xml b/translation_node/package.xml index a34688f..91f083a 100644 --- a/translation_node/package.xml +++ b/translation_node/package.xml @@ -12,6 +12,7 @@ ament_lint_auto ament_lint_common ament_cmake_gtest + std_msgs rclcpp px4_msgs diff --git a/translation_node/src/graph.h b/translation_node/src/graph.h index 2e73841..60a60a3 100644 --- a/translation_node/src/graph.h +++ b/translation_node/src/graph.h @@ -5,14 +5,15 @@ #pragma once #include "util.h" -#include -#include -#include +#include +#include #include #include -#include -#include #include +#include +#include +#include +#include // This implements a directed graph with potential cycles used for translation. // There are 2 types of nodes: messages (e.g. publication/subscription endpoints) and @@ -137,8 +138,6 @@ class MessageNode { NodeData _data; const size_t _index; - MessageNode* _iterating_previous{nullptr}; - bool _want_translation{false}; friend class Graph; }; @@ -195,21 +194,49 @@ class Graph { /** * @brief Translate a message node in the graph. * - * This function performs a two-pass translation of a message node in the graph. - * First, it finds the required nodes that need the translation results, and then - * it runs the translation on these nodes to prevent unnecessary message conversions. - * * @param node The message node to translate. - * @param node_requires_translation_result A callback function that determines whether a node requires the translation result. * @param on_translated A callback function that is called for translated nodes (with an updated message buffer). + * This will not be called for the provided node. */ - void translate(const MessageNodePtr& node, const std::function& node_requires_translation_result, + void translate(const MessageNodePtr& node, const std::function& on_translated) { - // Do translation in 2 passes: first, find the required nodes that require the translation results, - // then run the translation on these nodes to prevent unnecessary message conversions - // (the assumption here is that conversions are more expensive than iterating the graph) - prepareTranslation(node, node_requires_translation_result); - runTranslation(node, on_translated); + resetNodesVisited(); + + // Iterate all reachable nodes from a given node using the BFS (shortest path) algorithm, + // while using translation nodes as barriers (only continue when all inputs are ready) + + std::queue queue; + _node_visited[node->_index] = true; + queue.push(node); + + while (!queue.empty()) { + MessageNodePtr current = queue.front(); + queue.pop(); + for (auto& translation : current->_translations) { + const bool any_output_visited = + std::any_of(translation.node->outputs().begin(), translation.node->outputs().end(), [&](const MessageNodePtr& next_node) { + return _node_visited[next_node->_index]; + }); + // If any output node has already been visited, skip this translation node (prevents translating + // backwards, from where we came from already) + if (any_output_visited) { + continue; + } + translation.node->setInputReady(translation.input_index); + // Iterate the output nodes only if the translation node is ready + if (translation.node->translate()) { + + for (auto &next_node : translation.node->outputs()) { + if (_node_visited[next_node->_index]) { + continue; + } + _node_visited[next_node->_index] = true; + on_translated(next_node); + queue.push(next_node); + } + } + } + } } std::optional findNode(const IdType& id) const { @@ -230,12 +257,10 @@ class Graph { * Iterate all reachable nodes from a given node using the BFS (shortest path) algorithm */ void iterateBFS(const MessageNodePtr& node, const std::function& cb) { - _node_visited.resize(_nodes.size()); - std::fill(_node_visited.begin(), _node_visited.end(), false); + resetNodesVisited(); std::queue queue; _node_visited[node->_index] = true; - node->_iterating_previous = nullptr; queue.push(node); cb(node); @@ -248,7 +273,6 @@ class Graph { continue; } _node_visited[next_node->_index] = true; - next_node->_iterating_previous = current.get(); queue.push(next_node); cb(next_node); @@ -259,59 +283,11 @@ class Graph { private: - void prepareTranslation(const MessageNodePtr& node, const std::function& node_requires_translation_result) { - iterateBFS(node, [&](const MessageNodePtr& node) { - if (node_requires_translation_result(node)) { - auto* previous_node = node.get(); - while (previous_node) { - previous_node->_want_translation = true; - previous_node = previous_node->_iterating_previous; - } - } - }); - } - - void runTranslation(const MessageNodePtr& node, const std::function& on_translated) { - _node_had_update.resize(_nodes.size()); - std::fill(_node_had_update.begin(), _node_had_update.end(), false); - _node_had_update[node->_index] = true; - - iterateBFS(node, [&](const MessageNodePtr& node) { - // If there was no update for this node, there's nothing to do (i.e. want_translation is false or the - // message buffer did not change) - if (!_node_had_update[node->_index]) { - return; - } - - on_translated(node); - - if (node->_want_translation) { - node->_want_translation = false; - for (auto &translation : node->_translations) { - // Skip translation if none of the output nodes has _want_translation set or - // if any of the nodes already had an update. - // This also prevents translating 'backwards' by one step (from where we came from) - bool want_translation = false; - bool had_update = false; - for (auto &next_node: translation.node->outputs()) { - want_translation |= next_node->_want_translation; - had_update |= _node_had_update[next_node->_index]; - } - if (!want_translation || had_update) { - continue; - } - translation.node->setInputReady(translation.input_index); - if (translation.node->translate()) { - for (auto &next_node: translation.node->outputs()) { - _node_had_update[next_node->_index] = true; - } - } - } - } - }); + void resetNodesVisited() { + _node_visited.resize(_nodes.size()); + std::fill(_node_visited.begin(), _node_visited.end(), false); } std::unordered_map _nodes; std::vector _node_visited; ///< Cached, to avoid the need to re-allocate on each iteration - std::vector _node_had_update; ///< Cached, to avoid the need to re-allocate on each iteration }; \ No newline at end of file diff --git a/translation_node/src/main.cpp b/translation_node/src/main.cpp index 4dc03b5..786435c 100644 --- a/translation_node/src/main.cpp +++ b/translation_node/src/main.cpp @@ -10,6 +10,7 @@ #include "vehicle_attitude_v3.h" #include "vehicle_local_global_position_v2.h" #include "pub_sub_graph.h" +#include "monitor.h" using namespace std::chrono_literals; @@ -19,38 +20,13 @@ class RosTranslationNode : public rclcpp::Node RosTranslationNode() : Node("translation_node") { _pub_sub_graph = std::make_unique(*this, RegisteredTranslations::instance().translations()); - - - // Monitor subscriptions & publishers - // TODO: event-based - _node_update_timer = create_wall_timer(1s, [this](){ - std::vector topic_info; - const auto topics = get_topic_names_and_types(); - for (const auto& [topic_name, topic_types] : topics) { - auto publishers = get_publishers_info_by_topic(topic_name); - auto subscribers = get_subscriptions_info_by_topic(topic_name); - // Filter out self - int num_publishers = 0; - for (const auto& publisher : publishers) { - num_publishers += publisher.node_name() != this->get_name(); - } - int num_subscribers = 0; - for (const auto& subscriber : subscribers) { - num_subscribers += subscriber.node_name() != this->get_name(); - } - - if (num_subscribers > 0 || num_publishers > 0) { - topic_info.emplace_back(PubSubGraph::TopicInfo{topic_name, num_subscribers, num_publishers}); - } - } - _pub_sub_graph->updateCurrentTopics(topic_info); - - }); + _monitor = std::make_unique(*this, *_pub_sub_graph); } private: std::unique_ptr _pub_sub_graph; rclcpp::TimerBase::SharedPtr _node_update_timer; + std::unique_ptr _monitor; }; int main(int argc, char * argv[]) diff --git a/translation_node/src/monitor.cpp b/translation_node/src/monitor.cpp new file mode 100644 index 0000000..a647742 --- /dev/null +++ b/translation_node/src/monitor.cpp @@ -0,0 +1,39 @@ +/**************************************************************************** + * Copyright (c) 2024 PX4 Development Team. + * SPDX-License-Identifier: BSD-3-Clause + ****************************************************************************/ +#include "monitor.h" +using namespace std::chrono_literals; + +Monitor::Monitor(rclcpp::Node &node, PubSubGraph& pub_sub_graph) + : _node(node), _pub_sub_graph(pub_sub_graph) { + + // Monitor subscriptions & publishers + // TODO: event-based + _node_update_timer = _node.create_wall_timer(1s, [this]() { + updateNow(); + }); +} + +void Monitor::updateNow() { + std::vector topic_info; + const auto topics = _node.get_topic_names_and_types(); + for (const auto& [topic_name, topic_types] : topics) { + auto publishers = _node.get_publishers_info_by_topic(topic_name); + auto subscribers = _node.get_subscriptions_info_by_topic(topic_name); + // Filter out self + int num_publishers = 0; + for (const auto& publisher : publishers) { + num_publishers += publisher.node_name() != _node.get_name(); + } + int num_subscribers = 0; + for (const auto& subscriber : subscribers) { + num_subscribers += subscriber.node_name() != _node.get_name(); + } + + if (num_subscribers > 0 || num_publishers > 0) { + topic_info.emplace_back(PubSubGraph::TopicInfo{topic_name, num_subscribers, num_publishers}); + } + } + _pub_sub_graph.updateCurrentTopics(topic_info); +} diff --git a/translation_node/src/monitor.h b/translation_node/src/monitor.h new file mode 100644 index 0000000..d3d75c0 --- /dev/null +++ b/translation_node/src/monitor.h @@ -0,0 +1,21 @@ +/**************************************************************************** + * Copyright (c) 2024 PX4 Development Team. + * SPDX-License-Identifier: BSD-3-Clause + ****************************************************************************/ +#pragma once + +#include +#include "pub_sub_graph.h" +#include + +class Monitor { +public: + explicit Monitor(rclcpp::Node &node, PubSubGraph& pub_sub_graph); + + void updateNow(); + +private: + rclcpp::Node &_node; + PubSubGraph& _pub_sub_graph; + rclcpp::TimerBase::SharedPtr _node_update_timer; +}; \ No newline at end of file diff --git a/translation_node/src/pub_sub_graph.cpp b/translation_node/src/pub_sub_graph.cpp index 863ba7e..1bec240 100644 --- a/translation_node/src/pub_sub_graph.cpp +++ b/translation_node/src/pub_sub_graph.cpp @@ -139,9 +139,6 @@ void PubSubGraph::updateCurrentTopics(const std::vector &topics) { void PubSubGraph::onSubscriptionUpdate(const Graph::MessageNodePtr& node) { _pub_sub_graph.translate( node, - [](const Graph::MessageNodePtr& node) { - return node->data().publication != nullptr; - }, [](const Graph::MessageNodePtr& node) { if (node->data().publication != nullptr) { rcl_publish(node->data().publication->get_publisher_handle().get(), diff --git a/translation_node/src/translation_util.h b/translation_node/src/translation_util.h index bee48e6..f28e21e 100644 --- a/translation_node/src/translation_util.h +++ b/translation_node/src/translation_util.h @@ -143,8 +143,9 @@ class RegisteredTranslations { const Translations& translations() const { return _translations; } -private: +protected: RegisteredTranslations() = default; +private: template Topic getTopicForMessageType(const std::string& topic_name) { diff --git a/translation_node/test/graph.cpp b/translation_node/test/graph.cpp index 7a7f25e..e0d1f36 100644 --- a/translation_node/test/graph.cpp +++ b/translation_node/test/graph.cpp @@ -59,11 +59,12 @@ TEST(graph, basic) EXPECT_TRUE(node2->data().iterated); // Test translation - graph.translate(node1, [](auto&& node) { return true; }, - [](auto&& node) { - node->data().translated = true; - }); - EXPECT_TRUE(node1->data().translated); + graph.translate(node1, + [](auto&& node) { + assert(!node->data().translated); + node->data().translated = true; + }); + EXPECT_FALSE(node1->data().translated); EXPECT_TRUE(node2->data().translated); EXPECT_EQ(*buffer1, message1_value); EXPECT_EQ(*buffer2, message1_value + offset); @@ -160,24 +161,19 @@ TEST(graph, multi_path) // Translation graph.translate(graph.findNode(ids[0]).value(), - [&](auto&& node) { - // Skip the last 2 nodes - return node.get() != graph.findNode(ids[num_nodes-1]).value().get() && - node.get() != graph.findNode(ids[num_nodes-2]).value().get(); - }, [](auto&& node) { + assert(!node->data().translated); node->data().translated = true; }); - // Last 2 nodes should not be translated, the rest should be - for (unsigned i =0; i < num_nodes-2; ++i) { - EXPECT_EQ(graph.findNode(ids[i]).value()->data().translated, true); + // All nodes should be translated except the first + EXPECT_EQ(graph.findNode(ids[0]).value()->data().translated, false); + for (unsigned i = 1; i < num_nodes; ++i) { + EXPECT_EQ(graph.findNode(ids[i]).value()->data().translated, true) << "node[" << i << "]"; } - EXPECT_EQ(graph.findNode(ids[num_nodes-2]).value()->data().translated, false); - EXPECT_EQ(graph.findNode(ids[num_nodes-1]).value()->data().translated, false); // Ensure the correct edges were used for translations - EXPECT_EQ("00000000000000000000000001010001", translated.to_string()); + EXPECT_EQ("00000000000000000000100101010001", translated.to_string()); // Ensure correct translation path taken for each node (which is stored in the buffers), // and translation callback got called @@ -185,8 +181,8 @@ TEST(graph, multi_path) EXPECT_EQ(*buffers[1], 0b1); EXPECT_EQ(*buffers[2], 0b1000000); EXPECT_EQ(*buffers[3], 0b1010000); - EXPECT_EQ(*buffers[4], 0); - EXPECT_EQ(*buffers[5], 0); + EXPECT_EQ(*buffers[4], 0b100000000); + EXPECT_EQ(*buffers[5], 0b100100000000); for (unsigned i=0; i < num_nodes; ++i) { printf("node[%i]: translated: %i, buffer: %i\n", i, graph.findNode(ids[i]).value()->data().translated, @@ -286,8 +282,8 @@ TEST(graph, multi_links) { auto translate_node = [&](const MessageIdentifier& id) { graph.translate(graph.findNode(id).value(), - [&](auto&& node) { return true; }, [](auto&& node) { + assert(!node->data().translated); node->data().translated = true; }); @@ -303,7 +299,7 @@ TEST(graph, multi_links) { *buffers[1] = 0x0f00000f; translate_node(ids[1]); EXPECT_EQ(graph.findNode(ids[0]).value()->data().translated, false); - EXPECT_EQ(graph.findNode(ids[1]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[1]).value()->data().translated, false); EXPECT_EQ(graph.findNode(ids[2]).value()->data().translated, false); EXPECT_EQ(graph.findNode(ids[3]).value()->data().translated, true); EXPECT_EQ(graph.findNode(ids[4]).value()->data().translated, true); @@ -315,7 +311,7 @@ TEST(graph, multi_links) { // Now updating node 1 should update nodes 3+6 (merging, both inputs available now) translate_node(ids[0]); - EXPECT_EQ(graph.findNode(ids[0]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[0]).value()->data().translated, false); EXPECT_EQ(graph.findNode(ids[1]).value()->data().translated, false); EXPECT_EQ(graph.findNode(ids[2]).value()->data().translated, true); EXPECT_EQ(graph.findNode(ids[3]).value()->data().translated, false); @@ -328,7 +324,7 @@ TEST(graph, multi_links) { // Another update must not trigger any other updates translate_node(ids[0]); - EXPECT_EQ(graph.findNode(ids[0]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[0]).value()->data().translated, false); EXPECT_EQ(graph.findNode(ids[1]).value()->data().translated, false); EXPECT_EQ(graph.findNode(ids[2]).value()->data().translated, false); EXPECT_EQ(graph.findNode(ids[3]).value()->data().translated, false); @@ -345,11 +341,283 @@ TEST(graph, multi_links) { EXPECT_EQ(graph.findNode(ids[2]).value()->data().translated, true); EXPECT_EQ(graph.findNode(ids[3]).value()->data().translated, true); EXPECT_EQ(graph.findNode(ids[4]).value()->data().translated, true); - EXPECT_EQ(graph.findNode(ids[5]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[5]).value()->data().translated, false); EXPECT_EQ(*buffers[0], 0x000000d0); EXPECT_EQ(*buffers[1], 0xc0000000); EXPECT_EQ(*buffers[2], 0xc00000d0); EXPECT_EQ(*buffers[3], 0); EXPECT_EQ(*buffers[4], 0xc0000000); EXPECT_EQ(*buffers[5], 0xc00000d0); -} \ No newline at end of file +} + +TEST(graph, multi_links2) { + // Multiple topics (merging / splitting) + struct NodeData { + bool translated{false}; + }; + Graph graph; + + static constexpr unsigned num_nodes = 8; + std::array ids{{ + {"topic1", 1}, + {"topic2", 1}, + {"topic3", 1}, + {"topic1", 2}, + {"topic2", 2}, + {"topic1", 3}, + {"topic2", 3}, + {"topic3", 3}, + }}; + + std::array, num_nodes> buffers{{ + std::make_shared(), + std::make_shared(), + std::make_shared(), + std::make_shared(), + std::make_shared(), + std::make_shared(), + std::make_shared(), + std::make_shared(), + }}; + + // Nodes + for (unsigned i = 0; i < num_nodes; ++i) { + EXPECT_TRUE(graph.addNodeIfNotExists(ids[i], {}, buffers[i])); + } + + + // Graph + // ___ ___ + // 1 - | | | | - 6 + // | | - 4 - | | + // 2 - | | | | - 7 + // | | - 5 - | | + // 3 - | | | | - 8 + // --- --- + + // Translations + auto translation_cb_32 = [](const std::vector &a, std::vector &b) { + assert(a.size() == 3); + assert(b.size() == 2); + auto a_value1 = static_cast(a[0].get()); + auto a_value2 = static_cast(a[1].get()); + auto a_value3 = static_cast(a[2].get()); + auto b_value1 = static_cast(b[0].get()); + auto b_value2 = static_cast(b[1].get()); + *b_value1 = *a_value1 | *a_value2; + *b_value2 = *a_value3; + }; + auto translation_cb_23 = [](const std::vector &a, std::vector &b) { + assert(a.size() == 2); + assert(b.size() == 3); + auto a_value1 = static_cast(a[0].get()); + auto a_value2 = static_cast(a[1].get()); + auto b_value1 = static_cast(b[0].get()); + auto b_value2 = static_cast(b[1].get()); + auto b_value3 = static_cast(b[2].get()); + *b_value1 = *a_value1 & 0x0000ffffu; + *b_value2 = *a_value1 & 0xffff0000u; + *b_value3 = *a_value2; + }; + graph.addTranslation(translation_cb_32, {ids[0], ids[1], ids[2]}, {ids[3], ids[4]}); + graph.addTranslation(translation_cb_23, {ids[3], ids[4]}, {ids[0], ids[1], ids[2]}); + + graph.addTranslation(translation_cb_23, {ids[3], ids[4]}, {ids[5], ids[6], ids[7]}); + graph.addTranslation(translation_cb_32, {ids[5], ids[6], ids[7]}, {ids[3], ids[4]}); + + + auto translate_node = [&](const MessageIdentifier& id) { + graph.translate(graph.findNode(id).value(), + [](auto&& node) { + assert(!node->data().translated); + node->data().translated = true; + }); + }; + auto reset_translated = [&]() { + for (const auto& id : ids) { + graph.findNode(id).value()->data().translated = false; + } + }; + + // Updating nodes 1+2+3 should update nodes 6+7+8 + *buffers[0] = 0xa00000b0; + *buffers[1] = 0x0f00000f; + *buffers[2] = 0x0c00000c; + translate_node(ids[1]); + translate_node(ids[0]); + translate_node(ids[2]); + EXPECT_EQ(graph.findNode(ids[3]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[4]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[5]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[6]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[7]).value()->data().translated, true); + EXPECT_EQ(*buffers[3], 0xa00000b0 | 0x0f00000f); + EXPECT_EQ(*buffers[4], 0x0c00000c); + EXPECT_EQ(*buffers[5], (0xa00000b0 | 0x0f00000f) & 0x0000ffffu); + EXPECT_EQ(*buffers[6], (0xa00000b0 | 0x0f00000f) & 0xffff0000u); + EXPECT_EQ(*buffers[7], 0x0c00000c); + + reset_translated(); + + // Now updating nodes 6+7+8 should update nodes 1+2+3 + *buffers[5] = 0xa00000b0; + *buffers[6] = 0x0f00000f; + *buffers[7] = 0x0c00000c; + translate_node(ids[5]); + translate_node(ids[6]); + translate_node(ids[7]); + EXPECT_EQ(graph.findNode(ids[0]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[1]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[2]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[3]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[4]).value()->data().translated, true); + EXPECT_EQ(*buffers[3], 0xa00000b0 | 0x0f00000f); + EXPECT_EQ(*buffers[4], 0x0c00000c); + EXPECT_EQ(*buffers[0], (0xa00000b0 | 0x0f00000f) & 0x0000ffffu); + EXPECT_EQ(*buffers[1], (0xa00000b0 | 0x0f00000f) & 0xffff0000u); + EXPECT_EQ(*buffers[2], 0x0c00000c); +} + +TEST(graph, multi_links3) { + // Multiple topics (cannot use the shortest path) + struct NodeData { + bool translated{false}; + }; + Graph graph; + + static constexpr unsigned num_nodes = 7; + std::array ids{{ + {"topic1", 1}, + {"topic2", 1}, + {"topic1", 2}, + {"topic1", 3}, + {"topic1", 4}, + {"topic2", 4}, + {"topic1", 5}, + }}; + + std::array, num_nodes> buffers{{ + std::make_shared(), + std::make_shared(), + std::make_shared(), + std::make_shared(), + std::make_shared(), + std::make_shared(), + std::make_shared(), + }}; + + // Nodes + for (unsigned i = 0; i < num_nodes; ++i) { + EXPECT_TRUE(graph.addNodeIfNotExists(ids[i], {}, buffers[i])); + } + + + // Graph + // ___ ___ ___ ___ + // 1 - | | - 3 - | | - 4 - | | - 5 - | | - 7 + // | | --- --- | | + // | | | | + // 2 - | | --------------------- 6 - | | + // --- --- + + // Translations + auto translation_cb_21 = [](const std::vector &a, std::vector &b) { + assert(a.size() == 2); + assert(b.size() == 1); + auto a_value1 = static_cast(a[0].get()); + auto a_value2 = static_cast(a[1].get()); + auto b_value1 = static_cast(b[0].get()); + *b_value1 = *a_value1 | *a_value2; + }; + auto translation_cb_22 = [](const std::vector &a, std::vector &b) { + assert(a.size() == 2); + assert(b.size() == 2); + auto a_value1 = static_cast(a[0].get()); + auto a_value2 = static_cast(a[1].get()); + auto b_value1 = static_cast(b[0].get()); + auto b_value2 = static_cast(b[1].get()); + *b_value1 = *a_value1; + *b_value2 = *a_value2; + }; + auto translation_cb_12 = [](const std::vector &a, std::vector &b) { + assert(a.size() == 1); + assert(b.size() == 2); + auto a_value1 = static_cast(a[0].get()); + auto b_value1 = static_cast(b[0].get()); + auto b_value2 = static_cast(b[1].get()); + *b_value1 = *a_value1 & 0x0000ffffu; + *b_value2 = *a_value1 & 0xffff0000u; + }; + auto translation_cb_11 = [](const std::vector &a, std::vector &b) { + assert(a.size() == 1); + assert(b.size() == 1); + auto a_value1 = static_cast(a[0].get()); + auto b_value1 = static_cast(b[0].get()); + *b_value1 = *a_value1 + 1; + }; + graph.addTranslation(translation_cb_22, {ids[0], ids[1]}, {ids[2], ids[5]}); + graph.addTranslation(translation_cb_22, {ids[2], ids[5]}, {ids[0], ids[1]}); + graph.addTranslation(translation_cb_11, {ids[2]}, {ids[3]}); + graph.addTranslation(translation_cb_11, {ids[3]}, {ids[2]}); + graph.addTranslation(translation_cb_11, {ids[3]}, {ids[4]}); + graph.addTranslation(translation_cb_11, {ids[4]}, {ids[3]}); + graph.addTranslation(translation_cb_21, {ids[4], ids[5]}, {ids[6]}); + graph.addTranslation(translation_cb_12, {ids[6]}, {ids[4], ids[5]}); + + + auto translate_node = [&](const MessageIdentifier& id) { + graph.translate(graph.findNode(id).value(), + [](auto&& node) { + assert(!node->data().translated); + assert(!node->data().translated); + node->data().translated = true; + }); + }; + auto reset_translated = [&]() { + for (const auto& id : ids) { + graph.findNode(id).value()->data().translated = false; + } + }; + + // Updating nodes 1+2 should update node 7 + *buffers[0] = 0xa00000b0; + *buffers[1] = 0x0a00000b; + translate_node(ids[1]); + EXPECT_EQ(graph.findNode(ids[2]).value()->data().translated, false); + EXPECT_EQ(graph.findNode(ids[3]).value()->data().translated, false); + EXPECT_EQ(graph.findNode(ids[4]).value()->data().translated, false); + EXPECT_EQ(graph.findNode(ids[5]).value()->data().translated, false); + EXPECT_EQ(graph.findNode(ids[6]).value()->data().translated, false); + translate_node(ids[0]); + EXPECT_EQ(graph.findNode(ids[2]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[3]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[4]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[5]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[6]).value()->data().translated, true); + EXPECT_EQ(*buffers[2], 0xa00000b0); + EXPECT_EQ(*buffers[3], 0xa00000b0 + 1); + EXPECT_EQ(*buffers[4], 0xa00000b0 + 2); + EXPECT_EQ(*buffers[5], 0x0a00000b); + EXPECT_EQ(*buffers[6], ((0xa00000b0 + 2) | 0x0a00000b)); + + reset_translated(); + + // Now updating nodes 4+6 should update the rest + *buffers[3] = 0xa00000b0; + *buffers[5] = 0x0f00000f; + translate_node(ids[3]); + EXPECT_EQ(graph.findNode(ids[0]).value()->data().translated, false); + EXPECT_EQ(graph.findNode(ids[1]).value()->data().translated, false); + EXPECT_EQ(graph.findNode(ids[6]).value()->data().translated, false); + translate_node(ids[5]); + EXPECT_EQ(graph.findNode(ids[0]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[1]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[2]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[4]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[6]).value()->data().translated, true); + EXPECT_EQ(*buffers[0], 0xa00000b0 + 1); + EXPECT_EQ(*buffers[1], 0x0f00000f); + EXPECT_EQ(*buffers[2], 0xa00000b0 + 1); + EXPECT_EQ(*buffers[4], 0xa00000b0 + 1); + EXPECT_EQ(*buffers[6], (0xa00000b0 + 1) | 0x0f00000f); +} diff --git a/translation_node/test/pub_sub.cpp b/translation_node/test/pub_sub.cpp new file mode 100644 index 0000000..3c46a74 --- /dev/null +++ b/translation_node/test/pub_sub.cpp @@ -0,0 +1,350 @@ +/**************************************************************************** + * Copyright (c) 2024 PX4 Development Team. + * SPDX-License-Identifier: BSD-3-Clause + ****************************************************************************/ + +#include +#include +#include +#include + +#include +#include +#include +#include +using namespace std::chrono_literals; + +// Define a custom struct with MESSAGE_VERSION field that can be used in ROS pubs and subs +#define DEFINE_VERSIONED_ROS_MESSAGE_TYPE(CUSTOM_TYPE_NAME, ROS_TYPE_NAME, THIS_MESSAGE_VERSION) \ + struct CUSTOM_TYPE_NAME : public ROS_TYPE_NAME { \ + CUSTOM_TYPE_NAME() = default; \ + CUSTOM_TYPE_NAME(const ROS_TYPE_NAME& msg) : ROS_TYPE_NAME(msg) {} \ + static constexpr uint32_t MESSAGE_VERSION = THIS_MESSAGE_VERSION; \ + }; \ + template<> \ + struct rclcpp::TypeAdapter \ + { \ + using is_specialized = std::true_type; \ + using custom_type = CUSTOM_TYPE_NAME; \ + using ros_message_type = ROS_TYPE_NAME; \ + static void convert_to_ros_message(const custom_type & source, ros_message_type & destination) \ + { \ + destination = source; \ + } \ + static void convert_to_custom(const ros_message_type & source, custom_type & destination) \ + { \ + destination = source; \ + } \ + }; \ + RCLCPP_USING_CUSTOM_TYPE_AS_ROS_MESSAGE_TYPE(CUSTOM_TYPE_NAME, ROS_TYPE_NAME); + +class PubSubGraphTest : public testing::Test +{ +protected: + void SetUp() override + { + _test_node = std::make_shared("test_node"); + _app_node = std::make_shared("app_node"); + _executor.add_node(_test_node); + _executor.add_node(_app_node); + + for (auto& node : {_app_node, _test_node}) { + auto ret = rcutils_logging_set_logger_level( + node->get_logger().get_name(), RCUTILS_LOG_SEVERITY_DEBUG); + if (ret != RCUTILS_RET_OK) { + RCLCPP_ERROR( + node->get_logger(), "Error setting severity: %s", + rcutils_get_error_string().str); + rcutils_reset_error(); + } + } + } + + bool spinWithTimeout(const std::function& predicate) { + const auto start = _app_node->now(); + while (_app_node->now() - start < 5s) { + _executor.spin_some(); + if (predicate()) { + return true; + } + } + return false; + } + + std::shared_ptr _test_node; + std::shared_ptr _app_node; + rclcpp::executors::SingleThreadedExecutor _executor; +}; + +class RegisteredTranslationsTest : public RegisteredTranslations { +public: + RegisteredTranslationsTest() = default; +}; + + +DEFINE_VERSIONED_ROS_MESSAGE_TYPE(Float32Versioned, std_msgs::msg::Float32, 1u); +DEFINE_VERSIONED_ROS_MESSAGE_TYPE(ColorRGBAVersioned, std_msgs::msg::ColorRGBA, 2u); + +class DirectTranslationTest { +public: + using MessageOlder = Float32Versioned; + using MessageNewer = ColorRGBAVersioned; + + static constexpr const char* kTopic = "test/direct_translation"; + + static void fromOlder(const MessageOlder &msg_older, MessageNewer &msg_newer) { + msg_newer.r = 1.f; + msg_newer.g = msg_older.data; + msg_newer.b = 2.f; + } + + static void toOlder(const MessageNewer &msg_newer, MessageOlder &msg_older) { + msg_older.data = msg_newer.r + msg_newer.g + msg_newer.b; + } +}; + + +TEST_F(PubSubGraphTest, DirectTranslation) +{ + RegisteredTranslationsTest registered_translations; + registered_translations.registerDirectTranslation(); + + PubSubGraph graph(*_test_node, registered_translations.translations()); + Monitor monitor(*_test_node, graph); + + const std::string topic_name = DirectTranslationTest::kTopic; + const std::string topic_name_older_version = getVersionedTopicName(topic_name, DirectTranslationTest::MessageOlder::MESSAGE_VERSION); + const std::string topic_name_newer_version = getVersionedTopicName(topic_name, DirectTranslationTest::MessageNewer::MESSAGE_VERSION); + + { + // Create publisher + subscriber + int num_topic_updates = 0; + DirectTranslationTest::MessageNewer latest_data{}; + auto publisher = _app_node->create_publisher(topic_name_older_version, + rclcpp::QoS(1).best_effort()); + auto subscriber = _app_node->create_subscription(topic_name_newer_version, + rclcpp::QoS(1).best_effort(), [&num_topic_updates, &latest_data, this]( + DirectTranslationTest::MessageNewer::UniquePtr msg) -> void { + RCLCPP_DEBUG(_app_node->get_logger(), "Topic updated: %.3f", (double) msg->g); + latest_data = *msg; + ++num_topic_updates; + }); + + monitor.updateNow(); + + // Wait until there is a subscriber & publisher + ASSERT_TRUE(spinWithTimeout([&subscriber, &publisher]() { + return subscriber->get_publisher_count() > 0 && publisher->get_subscription_count() > 0; + })) << "Timeout, no publisher/subscriber found"; + + // Publish some data & wait for it to arrive + for (int i = 0; i < 10; ++i) { + DirectTranslationTest::MessageOlder msg_older; + msg_older.data = (float) i; + publisher->publish(msg_older); + + ASSERT_TRUE(spinWithTimeout([&num_topic_updates, i]() { + return num_topic_updates == i + 1; + })) << "Timeout, topic update not received, i=" << i; + + // Check data + EXPECT_FLOAT_EQ(latest_data.r, 1.f); + EXPECT_FLOAT_EQ(latest_data.g, (float) i); + EXPECT_FLOAT_EQ(latest_data.b, 2.f); + } + } + + // Now check the translation into the other direction + { + int num_topic_updates = 0; + DirectTranslationTest::MessageOlder latest_data{}; + auto publisher = _app_node->create_publisher(topic_name_newer_version, + rclcpp::QoS(1).best_effort()); + auto subscriber = _app_node->create_subscription(topic_name_older_version, + rclcpp::QoS(1).best_effort(), [&num_topic_updates, &latest_data, this]( + DirectTranslationTest::MessageOlder::UniquePtr msg) -> void { + RCLCPP_DEBUG(_app_node->get_logger(), "Topic updated: %.3f", (double) msg->data); + latest_data = *msg; + ++num_topic_updates; + }); + + monitor.updateNow(); + + // Wait until there is a subscriber & publisher + ASSERT_TRUE(spinWithTimeout([&subscriber, &publisher]() { + return subscriber->get_publisher_count() > 0 && publisher->get_subscription_count() > 0; + })) << "Timeout, no publisher/subscriber found"; + + // Publish some data & wait for it to arrive + for (int i = 0; i < 10; ++i) { + DirectTranslationTest::MessageNewer msg_newer; + msg_newer.r = (float)i; + msg_newer.g = (float)i * 10.f; + msg_newer.b = (float)i * 100.f; + publisher->publish(msg_newer); + + ASSERT_TRUE(spinWithTimeout([&num_topic_updates, i]() { + return num_topic_updates == i + 1; + })) << "Timeout, topic update not received, i=" << i; + + // Check data + EXPECT_FLOAT_EQ(latest_data.data, 111.f * (float)i); + } + } +} + + +DEFINE_VERSIONED_ROS_MESSAGE_TYPE(MessageTypeAV1, std_msgs::msg::Float32, 1u); +DEFINE_VERSIONED_ROS_MESSAGE_TYPE(MessageTypeBV1, std_msgs::msg::Float64, 1u); +DEFINE_VERSIONED_ROS_MESSAGE_TYPE(MessageTypeCV1, std_msgs::msg::Int64, 1u); + +DEFINE_VERSIONED_ROS_MESSAGE_TYPE(MessageTypeAV2, std_msgs::msg::ColorRGBA, 2u); +DEFINE_VERSIONED_ROS_MESSAGE_TYPE(MessageTypeBV2, std_msgs::msg::Int64, 2u); + +DEFINE_VERSIONED_ROS_MESSAGE_TYPE(MessageTypeAV3, std_msgs::msg::Float64, 3u); +DEFINE_VERSIONED_ROS_MESSAGE_TYPE(MessageTypeBV3, std_msgs::msg::Int64, 3u); +DEFINE_VERSIONED_ROS_MESSAGE_TYPE(MessageTypeCV3, std_msgs::msg::Float32, 3u); + +class TranslationMultiTestV2 { +public: + using MessagesOlder = TypesArray; + static constexpr const char* kTopicsOlder[] = { + "test/multi_translation_topic_a", + "test/multi_translation_topic_b", + "test/multi_translation_topic_c", + }; + static_assert(MessageTypeAV1::MESSAGE_VERSION == 1); + static_assert(MessageTypeBV1::MESSAGE_VERSION == 1); + static_assert(MessageTypeCV1::MESSAGE_VERSION == 1); + + using MessagesNewer = TypesArray; + static constexpr const char* kTopicsNewer[] = { + "test/multi_translation_topic_a", + "test/multi_translation_topic_b", + }; + static_assert(MessageTypeAV2::MESSAGE_VERSION == 2); + static_assert(MessageTypeBV2::MESSAGE_VERSION == 2); + + static void fromOlder(const MessagesOlder::Type1 &msg_older1, const MessagesOlder::Type2 &msg_older2, + const MessagesOlder::Type3 &msg_older3, + MessagesNewer::Type1 &msg_newer1, MessagesNewer::Type2 &msg_newer2) { + msg_newer1.r = msg_older1.data; + msg_newer1.g = (float)msg_older2.data; + msg_newer1.b = (float)msg_older3.data; + msg_newer2.data = msg_older3.data * 10; + } + static void toOlder(const MessagesNewer::Type1 &msg_newer1, const MessagesNewer::Type2 &msg_newer2, + MessagesOlder::Type1 &msg_older1, MessagesOlder::Type2 &msg_older2, MessagesOlder::Type3 &msg_older3) { + msg_older1.data = msg_newer1.r; + msg_older2.data = msg_newer1.g; + msg_older3.data = msg_newer2.data / 10; + } +}; + +class TranslationMultiTestV3 { +public: + using MessagesOlder = TypesArray; + static constexpr const char* kTopicsOlder[] = { + "test/multi_translation_topic_a", + "test/multi_translation_topic_b", + }; + + using MessagesNewer = TypesArray; + static constexpr const char* kTopicsNewer[] = { + "test/multi_translation_topic_a", + "test/multi_translation_topic_b", + "test/multi_translation_topic_c", + }; + + static void fromOlder(const MessagesOlder::Type1 &msg_older1, const MessagesOlder::Type2 &msg_older2, + MessagesNewer::Type1 &msg_newer1, MessagesNewer::Type2 &msg_newer2, MessagesNewer::Type3 &msg_newer3) { + msg_newer1.data = msg_older1.r; + msg_newer2.data = (int64_t)msg_older1.g; + msg_newer3.data = (float)msg_older2.data + 100; + } + static void toOlder(const MessagesNewer::Type1 &msg_newer1, const MessagesNewer::Type2 &msg_newer2, const MessagesNewer::Type3 &msg_newer3, + MessagesOlder::Type1 &msg_older1, MessagesOlder::Type2 &msg_older2) { + msg_older1.r = (float)msg_newer1.data; + msg_older1.g = (float)msg_newer2.data; + msg_older2.data = (int64_t)msg_newer3.data - 100; + } +}; + +TEST_F(PubSubGraphTest, TranslationMulti) { + RegisteredTranslationsTest registered_translations; + // Register 3 different message versions, with 3 types -> 2 types -> 3 types + registered_translations.registerTranslation(); + registered_translations.registerTranslation(); + + PubSubGraph graph(*_test_node, registered_translations.translations()); + Monitor monitor(*_test_node, graph); + + const std::string topic_name_a = TranslationMultiTestV2::kTopicsOlder[0]; + const std::string topic_name_b = TranslationMultiTestV2::kTopicsOlder[1]; + const std::string topic_name_c = TranslationMultiTestV2::kTopicsOlder[2]; + + // Create publishers for version 1 + subscribers for version 3 + int num_topic_updates = 0; + MessageTypeAV3 latest_data_a{}; + MessageTypeBV3 latest_data_b{}; + MessageTypeCV3 latest_data_c{}; + auto publisher_a = _app_node->create_publisher(getVersionedTopicName(topic_name_a, MessageTypeAV1::MESSAGE_VERSION), + rclcpp::QoS(1).best_effort()); + auto publisher_b = _app_node->create_publisher(getVersionedTopicName(topic_name_b, MessageTypeBV1::MESSAGE_VERSION), + rclcpp::QoS(1).best_effort()); + auto publisher_c = _app_node->create_publisher(getVersionedTopicName(topic_name_c, MessageTypeCV1::MESSAGE_VERSION), + rclcpp::QoS(1).best_effort()); + auto subscriber_a = _app_node->create_subscription(getVersionedTopicName(topic_name_a, MessageTypeAV3::MESSAGE_VERSION), + rclcpp::QoS(1).best_effort(), [&num_topic_updates, &latest_data_a, this]( + MessageTypeAV3::UniquePtr msg) -> void { + RCLCPP_DEBUG(_app_node->get_logger(), "Topic updated (A): %.3f", (double) msg->data); + latest_data_a = *msg; + ++num_topic_updates; + }); + auto subscriber_b = _app_node->create_subscription(getVersionedTopicName(topic_name_b, MessageTypeBV3::MESSAGE_VERSION), + rclcpp::QoS(1).best_effort(), [&num_topic_updates, &latest_data_b, this]( + MessageTypeBV3::UniquePtr msg) -> void { + RCLCPP_DEBUG(_app_node->get_logger(), "Topic updated (B): %.3f", (double) msg->data); + latest_data_b = *msg; + ++num_topic_updates; + }); + auto subscriber_c = _app_node->create_subscription(getVersionedTopicName(topic_name_c, MessageTypeCV3::MESSAGE_VERSION), + rclcpp::QoS(1).best_effort(), [&num_topic_updates, &latest_data_c, this]( + MessageTypeCV3::UniquePtr msg) -> void { + RCLCPP_DEBUG(_app_node->get_logger(), "Topic updated (C): %.3f", (double) msg->data); + latest_data_c = *msg; + ++num_topic_updates; + }); + + monitor.updateNow(); + + // Wait until there is a subscriber & publisher + ASSERT_TRUE(spinWithTimeout([&]() { + return subscriber_a->get_publisher_count() > 0 && subscriber_b->get_publisher_count() > 0 && subscriber_c->get_publisher_count() > 0 && + publisher_a->get_subscription_count() > 0 && publisher_b->get_subscription_count() > 0 && publisher_c->get_subscription_count() > 0; + })) << "Timeout, no publisher/subscriber found"; + + // Publish some data & wait for it to arrive + for (int i = 0; i < 10; ++i) { + MessageTypeAV1 msg_older_a; + msg_older_a.data = (float) i; + publisher_a->publish(msg_older_a); + + MessageTypeBV1 msg_older_b; + msg_older_b.data = (float) i * 10.f; + publisher_b->publish(msg_older_b); + + MessageTypeCV1 msg_older_c; + msg_older_c.data = i * 100; + publisher_c->publish(msg_older_c); + + ASSERT_TRUE(spinWithTimeout([&num_topic_updates, i]() { + return num_topic_updates == (i + 1) * 3; + })) << "Timeout, topic update not received, i=" << i << ", num updates=" << num_topic_updates; + + // Check data + EXPECT_FLOAT_EQ(latest_data_a.data, (float)i); + EXPECT_FLOAT_EQ(latest_data_b.data, (float)i * 10.f); + EXPECT_FLOAT_EQ(latest_data_c.data, ((float)i * 100.f) * 10.f + 100.f); + } +} \ No newline at end of file