Skip to content

Commit

Permalink
rework graph and add more unit tests, also for pub sub class
Browse files Browse the repository at this point in the history
  • Loading branch information
bkueng committed Oct 10, 2024
1 parent 7d97be1 commit d85e805
Show file tree
Hide file tree
Showing 10 changed files with 761 additions and 128 deletions.
4 changes: 4 additions & 0 deletions translation_node/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ find_package(px4_msgs REQUIRED)
find_package(px4_msgs_old REQUIRED)

add_library(${PROJECT_NAME}_lib
src/monitor.cpp
src/pub_sub_graph.cpp
src/translations.cpp
)
Expand All @@ -26,6 +27,7 @@ install(TARGETS
DESTINATION lib/${PROJECT_NAME})

if(BUILD_TESTING)
find_package(std_msgs REQUIRED)
find_package(ament_lint_auto REQUIRED)
find_package(ament_cmake_gtest REQUIRED)
ament_lint_auto_find_test_dependencies()
Expand All @@ -34,10 +36,12 @@ if(BUILD_TESTING)
ament_add_gtest(${PROJECT_NAME}_unit_tests
test/graph.cpp
test/main.cpp
test/pub_sub.cpp
)
target_include_directories(${PROJECT_NAME}_unit_tests PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_link_libraries(${PROJECT_NAME}_unit_tests ${PROJECT_NAME}_lib)
ament_target_dependencies(${PROJECT_NAME}_unit_tests
std_msgs
rclcpp
)
endif()
Expand Down
1 change: 1 addition & 0 deletions translation_node/package.xml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
<test_depend>ament_lint_auto</test_depend>
<test_depend>ament_lint_common</test_depend>
<test_depend>ament_cmake_gtest</test_depend>
<test_depend>std_msgs</test_depend>

<depend>rclcpp</depend>
<depend>px4_msgs</depend>
Expand Down
122 changes: 49 additions & 73 deletions translation_node/src/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
#pragma once

#include "util.h"
#include <string>
#include <utility>
#include <vector>
#include <algorithm>
#include <bitset>
#include <functional>
#include <memory>
#include <bitset>
#include <queue>
#include <optional>
#include <queue>
#include <string>
#include <utility>
#include <vector>

// This implements a directed graph with potential cycles used for translation.
// There are 2 types of nodes: messages (e.g. publication/subscription endpoints) and
Expand Down Expand Up @@ -137,8 +138,6 @@ class MessageNode {
NodeData _data;

const size_t _index;
MessageNode<NodeData, IdType>* _iterating_previous{nullptr};
bool _want_translation{false};

friend class Graph<NodeData, IdType>;
};
Expand Down Expand Up @@ -195,21 +194,49 @@ class Graph {
/**
* @brief Translate a message node in the graph.
*
* This function performs a two-pass translation of a message node in the graph.
* First, it finds the required nodes that need the translation results, and then
* it runs the translation on these nodes to prevent unnecessary message conversions.
*
* @param node The message node to translate.
* @param node_requires_translation_result A callback function that determines whether a node requires the translation result.
* @param on_translated A callback function that is called for translated nodes (with an updated message buffer).
* This will not be called for the provided node.
*/
void translate(const MessageNodePtr& node, const std::function<bool(const MessageNodePtr&)>& node_requires_translation_result,
void translate(const MessageNodePtr& node,
const std::function<void(const MessageNodePtr&)>& on_translated) {
// Do translation in 2 passes: first, find the required nodes that require the translation results,
// then run the translation on these nodes to prevent unnecessary message conversions
// (the assumption here is that conversions are more expensive than iterating the graph)
prepareTranslation(node, node_requires_translation_result);
runTranslation(node, on_translated);
resetNodesVisited();

// Iterate all reachable nodes from a given node using the BFS (shortest path) algorithm,
// while using translation nodes as barriers (only continue when all inputs are ready)

std::queue<MessageNodePtr> queue;
_node_visited[node->_index] = true;
queue.push(node);

while (!queue.empty()) {
MessageNodePtr current = queue.front();
queue.pop();
for (auto& translation : current->_translations) {
const bool any_output_visited =
std::any_of(translation.node->outputs().begin(), translation.node->outputs().end(), [&](const MessageNodePtr& next_node) {
return _node_visited[next_node->_index];
});
// If any output node has already been visited, skip this translation node (prevents translating
// backwards, from where we came from already)
if (any_output_visited) {
continue;
}
translation.node->setInputReady(translation.input_index);
// Iterate the output nodes only if the translation node is ready
if (translation.node->translate()) {

for (auto &next_node : translation.node->outputs()) {
if (_node_visited[next_node->_index]) {
continue;
}
_node_visited[next_node->_index] = true;
on_translated(next_node);
queue.push(next_node);
}
}
}
}
}

std::optional<MessageNodePtr> findNode(const IdType& id) const {
Expand All @@ -230,12 +257,10 @@ class Graph {
* Iterate all reachable nodes from a given node using the BFS (shortest path) algorithm
*/
void iterateBFS(const MessageNodePtr& node, const std::function<void(const MessageNodePtr&)>& cb) {
_node_visited.resize(_nodes.size());
std::fill(_node_visited.begin(), _node_visited.end(), false);
resetNodesVisited();

std::queue<MessageNodePtr> queue;
_node_visited[node->_index] = true;
node->_iterating_previous = nullptr;
queue.push(node);
cb(node);

Expand All @@ -248,7 +273,6 @@ class Graph {
continue;
}
_node_visited[next_node->_index] = true;
next_node->_iterating_previous = current.get();
queue.push(next_node);

cb(next_node);
Expand All @@ -259,59 +283,11 @@ class Graph {


private:
void prepareTranslation(const MessageNodePtr& node, const std::function<bool(const MessageNodePtr&)>& node_requires_translation_result) {
iterateBFS(node, [&](const MessageNodePtr& node) {
if (node_requires_translation_result(node)) {
auto* previous_node = node.get();
while (previous_node) {
previous_node->_want_translation = true;
previous_node = previous_node->_iterating_previous;
}
}
});
}

void runTranslation(const MessageNodePtr& node, const std::function<void(const MessageNodePtr&)>& on_translated) {
_node_had_update.resize(_nodes.size());
std::fill(_node_had_update.begin(), _node_had_update.end(), false);
_node_had_update[node->_index] = true;

iterateBFS(node, [&](const MessageNodePtr& node) {
// If there was no update for this node, there's nothing to do (i.e. want_translation is false or the
// message buffer did not change)
if (!_node_had_update[node->_index]) {
return;
}

on_translated(node);

if (node->_want_translation) {
node->_want_translation = false;
for (auto &translation : node->_translations) {
// Skip translation if none of the output nodes has _want_translation set or
// if any of the nodes already had an update.
// This also prevents translating 'backwards' by one step (from where we came from)
bool want_translation = false;
bool had_update = false;
for (auto &next_node: translation.node->outputs()) {
want_translation |= next_node->_want_translation;
had_update |= _node_had_update[next_node->_index];
}
if (!want_translation || had_update) {
continue;
}
translation.node->setInputReady(translation.input_index);
if (translation.node->translate()) {
for (auto &next_node: translation.node->outputs()) {
_node_had_update[next_node->_index] = true;
}
}
}
}
});
void resetNodesVisited() {
_node_visited.resize(_nodes.size());
std::fill(_node_visited.begin(), _node_visited.end(), false);
}

std::unordered_map<IdType, MessageNodePtr> _nodes;
std::vector<bool> _node_visited; ///< Cached, to avoid the need to re-allocate on each iteration
std::vector<bool> _node_had_update; ///< Cached, to avoid the need to re-allocate on each iteration
};
30 changes: 3 additions & 27 deletions translation_node/src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "vehicle_attitude_v3.h"
#include "vehicle_local_global_position_v2.h"
#include "pub_sub_graph.h"
#include "monitor.h"

using namespace std::chrono_literals;

Expand All @@ -19,38 +20,13 @@ class RosTranslationNode : public rclcpp::Node
RosTranslationNode() : Node("translation_node")
{
_pub_sub_graph = std::make_unique<PubSubGraph>(*this, RegisteredTranslations::instance().translations());


// Monitor subscriptions & publishers
// TODO: event-based
_node_update_timer = create_wall_timer(1s, [this](){
std::vector<PubSubGraph::TopicInfo> topic_info;
const auto topics = get_topic_names_and_types();
for (const auto& [topic_name, topic_types] : topics) {
auto publishers = get_publishers_info_by_topic(topic_name);
auto subscribers = get_subscriptions_info_by_topic(topic_name);
// Filter out self
int num_publishers = 0;
for (const auto& publisher : publishers) {
num_publishers += publisher.node_name() != this->get_name();
}
int num_subscribers = 0;
for (const auto& subscriber : subscribers) {
num_subscribers += subscriber.node_name() != this->get_name();
}

if (num_subscribers > 0 || num_publishers > 0) {
topic_info.emplace_back(PubSubGraph::TopicInfo{topic_name, num_subscribers, num_publishers});
}
}
_pub_sub_graph->updateCurrentTopics(topic_info);

});
_monitor = std::make_unique<Monitor>(*this, *_pub_sub_graph);
}

private:
std::unique_ptr<PubSubGraph> _pub_sub_graph;
rclcpp::TimerBase::SharedPtr _node_update_timer;
std::unique_ptr<Monitor> _monitor;
};

int main(int argc, char * argv[])
Expand Down
39 changes: 39 additions & 0 deletions translation_node/src/monitor.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/****************************************************************************
* Copyright (c) 2024 PX4 Development Team.
* SPDX-License-Identifier: BSD-3-Clause
****************************************************************************/
#include "monitor.h"
using namespace std::chrono_literals;

Monitor::Monitor(rclcpp::Node &node, PubSubGraph& pub_sub_graph)
: _node(node), _pub_sub_graph(pub_sub_graph) {

// Monitor subscriptions & publishers
// TODO: event-based
_node_update_timer = _node.create_wall_timer(1s, [this]() {
updateNow();
});
}

void Monitor::updateNow() {
std::vector<PubSubGraph::TopicInfo> topic_info;
const auto topics = _node.get_topic_names_and_types();
for (const auto& [topic_name, topic_types] : topics) {
auto publishers = _node.get_publishers_info_by_topic(topic_name);
auto subscribers = _node.get_subscriptions_info_by_topic(topic_name);
// Filter out self
int num_publishers = 0;
for (const auto& publisher : publishers) {
num_publishers += publisher.node_name() != _node.get_name();
}
int num_subscribers = 0;
for (const auto& subscriber : subscribers) {
num_subscribers += subscriber.node_name() != _node.get_name();
}

if (num_subscribers > 0 || num_publishers > 0) {
topic_info.emplace_back(PubSubGraph::TopicInfo{topic_name, num_subscribers, num_publishers});
}
}
_pub_sub_graph.updateCurrentTopics(topic_info);
}
21 changes: 21 additions & 0 deletions translation_node/src/monitor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/****************************************************************************
* Copyright (c) 2024 PX4 Development Team.
* SPDX-License-Identifier: BSD-3-Clause
****************************************************************************/
#pragma once

#include <rclcpp/rclcpp.hpp>
#include "pub_sub_graph.h"
#include <functional>

class Monitor {
public:
explicit Monitor(rclcpp::Node &node, PubSubGraph& pub_sub_graph);

void updateNow();

private:
rclcpp::Node &_node;
PubSubGraph& _pub_sub_graph;
rclcpp::TimerBase::SharedPtr _node_update_timer;
};
3 changes: 0 additions & 3 deletions translation_node/src/pub_sub_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,6 @@ void PubSubGraph::updateCurrentTopics(const std::vector<TopicInfo> &topics) {
void PubSubGraph::onSubscriptionUpdate(const Graph<NodeDataPubSub>::MessageNodePtr& node) {
_pub_sub_graph.translate(
node,
[](const Graph<NodeDataPubSub>::MessageNodePtr& node) {
return node->data().publication != nullptr;
},
[](const Graph<NodeDataPubSub>::MessageNodePtr& node) {
if (node->data().publication != nullptr) {
rcl_publish(node->data().publication->get_publisher_handle().get(),
Expand Down
3 changes: 2 additions & 1 deletion translation_node/src/translation_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,9 @@ class RegisteredTranslations {

const Translations& translations() const { return _translations; }

private:
protected:
RegisteredTranslations() = default;
private:

template<typename RosMessageType>
Topic getTopicForMessageType(const std::string& topic_name) {
Expand Down
Loading

0 comments on commit d85e805

Please sign in to comment.