diff --git a/translation_node/CMakeLists.txt b/translation_node/CMakeLists.txt
index 3d8b06eb..9b7868a1 100644
--- a/translation_node/CMakeLists.txt
+++ b/translation_node/CMakeLists.txt
@@ -11,14 +11,35 @@ find_package(rclcpp REQUIRED)
find_package(px4_msgs REQUIRED)
find_package(px4_msgs_old REQUIRED)
-add_executable(translation_node
- src/main.cpp
- src/ros_translations.cpp
+add_library(${PROJECT_NAME}_lib
+ src/pub_sub_graph.cpp
src/translations.cpp
)
-ament_target_dependencies(translation_node rclcpp px4_msgs px4_msgs_old)
+ament_target_dependencies(${PROJECT_NAME}_lib rclcpp px4_msgs px4_msgs_old)
+add_executable(${PROJECT_NAME}
+ src/main.cpp
+)
+target_link_libraries(${PROJECT_NAME} ${PROJECT_NAME}_lib)
+ament_target_dependencies(${PROJECT_NAME} rclcpp px4_msgs px4_msgs_old)
install(TARGETS
- translation_node
+ ${PROJECT_NAME}
DESTINATION lib/${PROJECT_NAME})
+if(BUILD_TESTING)
+ find_package(ament_lint_auto REQUIRED)
+ find_package(ament_cmake_gtest REQUIRED)
+ ament_lint_auto_find_test_dependencies()
+
+ # Unit tests
+ ament_add_gtest(${PROJECT_NAME}_unit_tests
+ test/graph.cpp
+ test/main.cpp
+ )
+ target_include_directories(${PROJECT_NAME}_unit_tests PRIVATE ${CMAKE_CURRENT_LIST_DIR})
+ target_link_libraries(${PROJECT_NAME}_unit_tests ${PROJECT_NAME}_lib)
+ ament_target_dependencies(${PROJECT_NAME}_unit_tests
+ rclcpp
+ )
+endif()
+
ament_package()
diff --git a/translation_node/package.xml b/translation_node/package.xml
index e212547b..a34688f3 100644
--- a/translation_node/package.xml
+++ b/translation_node/package.xml
@@ -11,6 +11,7 @@
ament_lint_auto
ament_lint_common
+ ament_cmake_gtest
rclcpp
px4_msgs
diff --git a/translation_node/src/graph.h b/translation_node/src/graph.h
new file mode 100644
index 00000000..2e73841c
--- /dev/null
+++ b/translation_node/src/graph.h
@@ -0,0 +1,317 @@
+/****************************************************************************
+ * Copyright (c) 2024 PX4 Development Team.
+ * SPDX-License-Identifier: BSD-3-Clause
+ ****************************************************************************/
+#pragma once
+
+#include "util.h"
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+// This implements a directed graph with potential cycles used for translation.
+// There are 2 types of nodes: messages (e.g. publication/subscription endpoints) and
+// translations. Translation nodes are always in between message nodes, and can have N input messages
+// and M output messages.
+
+struct MessageIdentifier {
+ std::string topic_name;
+ MessageVersionType version;
+
+ bool operator==(const MessageIdentifier& other) const {
+ return topic_name == other.topic_name && version == other.version;
+ }
+ bool operator!=(const MessageIdentifier& other) const {
+ return !(*this == other);
+ }
+};
+
+template<>
+struct std::hash
+{
+ std::size_t operator()(const MessageIdentifier& s) const noexcept
+ {
+ std::size_t h1 = std::hash{}(s.topic_name);
+ std::size_t h2 = std::hash{}(s.version);
+ return h1 ^ (h2 << 1);
+ }
+};
+
+
+using MessageBuffer = std::shared_ptr;
+
+template
+class MessageNode;
+template
+class Graph;
+
+template
+using MessageNodePtrT = std::shared_ptr>;
+
+template
+class TranslationNode {
+public:
+ using TranslationCB = std::function&, std::vector&)>;
+
+ TranslationNode(std::vector> inputs,
+ std::vector> outputs,
+ TranslationCB translation_db)
+ : _inputs(std::move(inputs)), _outputs(std::move(outputs)), _translation_cb(std::move(translation_db)) {
+ assert(_inputs.size() <= kMaxNumInputs);
+
+ _input_buffers.resize(_inputs.size());
+ for (unsigned i = 0; i < _inputs.size(); ++i) {
+ _input_buffers[i] = _inputs[i]->buffer();
+ }
+
+ _output_buffers.resize(_outputs.size());
+ for (unsigned i = 0; i < _outputs.size(); ++i) {
+ _output_buffers[i] = _outputs[i]->buffer();
+ }
+ }
+
+ void setInputReady(unsigned index) {
+ _inputs_ready.set(index);
+ }
+
+ bool translate() {
+ if (_inputs_ready.count() == _input_buffers.size()) {
+ _translation_cb(_input_buffers, _output_buffers);
+ _inputs_ready.reset();
+ return true;
+ }
+ return false;
+ }
+
+ const std::vector>& inputs() const { return _inputs; }
+ const std::vector>& outputs() const { return _outputs; }
+
+private:
+ static constexpr int kMaxNumInputs = 32;
+
+ const std::vector> _inputs;
+ std::vector _input_buffers; ///< Cached buffers from _inputs.buffer()
+ const std::vector> _outputs;
+ std::vector _output_buffers;
+ const TranslationCB _translation_cb;
+
+ std::bitset _inputs_ready;
+};
+
+template
+using TranslationNodePtrT = std::shared_ptr>;
+
+
+template
+class MessageNode {
+public:
+
+ explicit MessageNode(NodeData node_data, size_t index, MessageBuffer message_buffer)
+ : _buffer(std::move(message_buffer)), _data(std::move(node_data)), _index(index) {}
+
+ MessageBuffer& buffer() { return _buffer; }
+
+ void addTranslationInput(TranslationNodePtrT node, unsigned input_index) {
+ _translations.push_back(Translation{std::move(node), input_index});
+ }
+
+ NodeData& data() { return _data; }
+
+ void resetNodes() {
+ _translations.clear();
+ }
+
+private:
+ struct Translation {
+ TranslationNodePtrT node; ///< Counterpart to the TranslationNode::_inputs
+ unsigned input_index; ///< Index into the TranslationNode::_inputs
+ };
+ MessageBuffer _buffer;
+ std::vector _translations;
+
+ NodeData _data;
+
+ const size_t _index;
+ MessageNode* _iterating_previous{nullptr};
+ bool _want_translation{false};
+
+ friend class Graph;
+};
+
+template
+class Graph {
+public:
+ using MessageNodePtr = MessageNodePtrT;
+
+ ~Graph() {
+ // Explicitly reset the nodes array to break up potential cycles and prevent memory leaks
+ for (auto& [id, node] : _nodes) {
+ node->resetNodes();
+ }
+ }
+
+ /**
+ * @brief Add a message node if it does not exist already
+ */
+ bool addNodeIfNotExists(const IdType& id, NodeData node_data, const MessageBuffer& message_buffer) {
+ if (_nodes.find(id) != _nodes.end()) {
+ return false;
+ }
+ // Node that we cannot remove nodes due to using the index as an array index
+ const size_t index = _nodes.size();
+ _nodes.insert({id, std::make_shared>(std::move(node_data), index, message_buffer)});
+ return true;
+ }
+
+ /**
+ * @brief Add a translation edge with N inputs and M output nodes. All nodes must already exist.
+ */
+ void addTranslation(const typename TranslationNode::TranslationCB& translation_cb,
+ const std::vector& inputs, const std::vector& outputs) {
+ auto init = [this](const std::vector& from, std::vector>& to) {
+ for (unsigned i=0; i < from.size(); ++i) {
+ auto node_iter = _nodes.find(from[i]);
+ assert(node_iter != _nodes.end());
+ to[i] = node_iter->second;
+ }
+ };
+ std::vector> input_nodes(inputs.size());
+ init(inputs, input_nodes);
+ std::vector> output_nodes(outputs.size());
+ init(outputs, output_nodes);
+
+ auto translation_node = std::make_shared>(std::move(input_nodes), std::move(output_nodes), translation_cb);
+ for (unsigned i=0; i < translation_node->inputs().size(); ++i) {
+ translation_node->inputs()[i]->addTranslationInput(translation_node, i);
+ }
+ }
+
+
+ /**
+ * @brief Translate a message node in the graph.
+ *
+ * This function performs a two-pass translation of a message node in the graph.
+ * First, it finds the required nodes that need the translation results, and then
+ * it runs the translation on these nodes to prevent unnecessary message conversions.
+ *
+ * @param node The message node to translate.
+ * @param node_requires_translation_result A callback function that determines whether a node requires the translation result.
+ * @param on_translated A callback function that is called for translated nodes (with an updated message buffer).
+ */
+ void translate(const MessageNodePtr& node, const std::function& node_requires_translation_result,
+ const std::function& on_translated) {
+ // Do translation in 2 passes: first, find the required nodes that require the translation results,
+ // then run the translation on these nodes to prevent unnecessary message conversions
+ // (the assumption here is that conversions are more expensive than iterating the graph)
+ prepareTranslation(node, node_requires_translation_result);
+ runTranslation(node, on_translated);
+ }
+
+ std::optional findNode(const IdType& id) const {
+ auto iter = _nodes.find(id);
+ if (iter == _nodes.end()) {
+ return std::nullopt;
+ }
+ return iter->second;
+ }
+
+ void iterateNodes(const std::function& cb) const {
+ for (const auto& [id, node] : _nodes) {
+ cb(id, node);
+ }
+ }
+
+ /**
+ * Iterate all reachable nodes from a given node using the BFS (shortest path) algorithm
+ */
+ void iterateBFS(const MessageNodePtr& node, const std::function& cb) {
+ _node_visited.resize(_nodes.size());
+ std::fill(_node_visited.begin(), _node_visited.end(), false);
+
+ std::queue queue;
+ _node_visited[node->_index] = true;
+ node->_iterating_previous = nullptr;
+ queue.push(node);
+ cb(node);
+
+ while (!queue.empty()) {
+ MessageNodePtr current = queue.front();
+ queue.pop();
+ for (auto& translation : current->_translations) {
+ for (auto& next_node : translation.node->outputs()) {
+ if (_node_visited[next_node->_index]) {
+ continue;
+ }
+ _node_visited[next_node->_index] = true;
+ next_node->_iterating_previous = current.get();
+ queue.push(next_node);
+
+ cb(next_node);
+ }
+ }
+ }
+ }
+
+
+private:
+ void prepareTranslation(const MessageNodePtr& node, const std::function& node_requires_translation_result) {
+ iterateBFS(node, [&](const MessageNodePtr& node) {
+ if (node_requires_translation_result(node)) {
+ auto* previous_node = node.get();
+ while (previous_node) {
+ previous_node->_want_translation = true;
+ previous_node = previous_node->_iterating_previous;
+ }
+ }
+ });
+ }
+
+ void runTranslation(const MessageNodePtr& node, const std::function& on_translated) {
+ _node_had_update.resize(_nodes.size());
+ std::fill(_node_had_update.begin(), _node_had_update.end(), false);
+ _node_had_update[node->_index] = true;
+
+ iterateBFS(node, [&](const MessageNodePtr& node) {
+ // If there was no update for this node, there's nothing to do (i.e. want_translation is false or the
+ // message buffer did not change)
+ if (!_node_had_update[node->_index]) {
+ return;
+ }
+
+ on_translated(node);
+
+ if (node->_want_translation) {
+ node->_want_translation = false;
+ for (auto &translation : node->_translations) {
+ // Skip translation if none of the output nodes has _want_translation set or
+ // if any of the nodes already had an update.
+ // This also prevents translating 'backwards' by one step (from where we came from)
+ bool want_translation = false;
+ bool had_update = false;
+ for (auto &next_node: translation.node->outputs()) {
+ want_translation |= next_node->_want_translation;
+ had_update |= _node_had_update[next_node->_index];
+ }
+ if (!want_translation || had_update) {
+ continue;
+ }
+ translation.node->setInputReady(translation.input_index);
+ if (translation.node->translate()) {
+ for (auto &next_node: translation.node->outputs()) {
+ _node_had_update[next_node->_index] = true;
+ }
+ }
+ }
+ }
+ });
+ }
+
+ std::unordered_map _nodes;
+ std::vector _node_visited; ///< Cached, to avoid the need to re-allocate on each iteration
+ std::vector _node_had_update; ///< Cached, to avoid the need to re-allocate on each iteration
+};
\ No newline at end of file
diff --git a/translation_node/src/main.cpp b/translation_node/src/main.cpp
index 6f2442c2..7a4d65d8 100644
--- a/translation_node/src/main.cpp
+++ b/translation_node/src/main.cpp
@@ -1,25 +1,29 @@
+/****************************************************************************
+ * Copyright (c) 2024 PX4 Development Team.
+ * SPDX-License-Identifier: BSD-3-Clause
+ ****************************************************************************/
#include
#include
#include "vehicle_attitude_v2.h"
#include "vehicle_attitude_v3.h"
-#include "ros_translations.h"
+#include "pub_sub_graph.h"
using namespace std::chrono_literals;
-class TranslationNode : public rclcpp::Node
+class RosTranslationNode : public rclcpp::Node
{
public:
- TranslationNode() : Node("translation_node")
+ RosTranslationNode() : Node("translation_node")
{
- _ros_translations = std::make_unique(*this, RegisteredTranslations::instance().translations());
+ _pub_sub_graph = std::make_unique(*this, RegisteredTranslations::instance().translations());
// Monitor subscriptions & publishers
// TODO: event-based
_node_update_timer = create_wall_timer(1s, [this](){
- std::vector topic_info;
+ std::vector topic_info;
const auto topics = get_topic_names_and_types();
for (const auto& [topic_name, topic_types] : topics) {
auto publishers = get_publishers_info_by_topic(topic_name);
@@ -35,22 +39,23 @@ class TranslationNode : public rclcpp::Node
}
if (num_subscribers > 0 || num_publishers > 0) {
- topic_info.emplace_back(RosTranslations::TopicInfo{topic_name, num_subscribers, num_publishers});
+ topic_info.emplace_back(PubSubGraph::TopicInfo{topic_name, num_subscribers, num_publishers});
}
}
- _ros_translations->updateCurrentTopics(topic_info);
+ _pub_sub_graph->updateCurrentTopics(topic_info);
+
});
}
private:
- std::unique_ptr _ros_translations;
+ std::unique_ptr _pub_sub_graph;
rclcpp::TimerBase::SharedPtr _node_update_timer;
};
int main(int argc, char * argv[])
{
rclcpp::init(argc, argv);
- rclcpp::spin(std::make_shared());
+ rclcpp::spin(std::make_shared());
rclcpp::shutdown();
return 0;
}
\ No newline at end of file
diff --git a/translation_node/src/pub_sub_graph.cpp b/translation_node/src/pub_sub_graph.cpp
new file mode 100644
index 00000000..7657b592
--- /dev/null
+++ b/translation_node/src/pub_sub_graph.cpp
@@ -0,0 +1,193 @@
+/****************************************************************************
+ * Copyright (c) 2024 PX4 Development Team.
+ * SPDX-License-Identifier: BSD-3-Clause
+ ****************************************************************************/
+#include "pub_sub_graph.h"
+#include "util.h"
+
+PubSubGraph::PubSubGraph(rclcpp::Node &node, const Translations &translations) : _node(node) {
+
+ std::unordered_map> known_versions;
+
+ for (const auto& topic : translations.topics()) {
+ const std::string full_topic_name = getFullTopicName(_node.get_effective_namespace(), topic.topic_name);
+ _known_topics_warned.insert({full_topic_name, false});
+
+ const MessageIdentifier id{full_topic_name, topic.version};
+ NodeDataPubSub node_data{topic.subscription_factory, topic.publication_factory, id, topic.max_serialized_message_size};
+ _pub_sub_graph.addNodeIfNotExists(id, std::move(node_data), topic.message_buffer);
+ known_versions[full_topic_name].insert(id.version);
+ }
+
+ for (const auto& translation : translations.translations()) {
+ std::vector inputs = translation.inputs;
+ for (auto& input : inputs) {
+ input.topic_name = getFullTopicName(_node.get_effective_namespace(), input.topic_name);
+ }
+ std::vector outputs = translation.outputs;
+ for (auto& output : outputs) {
+ output.topic_name = getFullTopicName(_node.get_effective_namespace(), output.topic_name);
+ }
+ _pub_sub_graph.addTranslation(translation.cb, inputs, outputs);
+ }
+
+ printTopicInfo(known_versions);
+ handleLargestTopic(known_versions);
+}
+
+void PubSubGraph::updateCurrentTopics(const std::vector &topics) {
+
+ _pub_sub_graph.iterateNodes([](const MessageIdentifier& type, const Graph::MessageNodePtr& node) {
+ node->data().has_external_publisher = false;
+ node->data().has_external_subscriber = false;
+ });
+
+ for (const auto& info : topics) {
+ const auto [non_versioned_topic_name, version] = getNonVersionedTopicName(info.topic_name);
+ auto maybe_node = _pub_sub_graph.findNode({non_versioned_topic_name, version});
+ if (!maybe_node) {
+ auto known_topic_iter = _known_topics_warned.find(non_versioned_topic_name);
+ if (known_topic_iter != _known_topics_warned.end() && !known_topic_iter->second) {
+ RCLCPP_WARN(_node.get_logger(), "No translation available for version %i of topic %s", version, non_versioned_topic_name.c_str());
+ known_topic_iter->second = true;
+ }
+ continue;
+ }
+ const auto& node = maybe_node.value();
+
+ if (info.num_publishers > 0) {
+ node->data().has_external_publisher = true;
+ }
+ if (info.num_subscribers > 0) {
+ node->data().has_external_subscriber = true;
+ }
+ }
+
+ // Iterate connected graph segments
+ _pub_sub_graph.iterateNodes([this](const MessageIdentifier& type, const Graph::MessageNodePtr& node) {
+ if (node->data().visited) {
+ return;
+ }
+ node->data().visited = true;
+
+ // Count the number of external subscribers and publishers for each connected graph
+ int num_publishers = 0;
+ int num_subscribers = 0;
+ int num_subscribers_without_publisher = 0;
+
+ _pub_sub_graph.iterateBFS(node, [&](const Graph::MessageNodePtr& node) {
+ if (node->data().has_external_publisher) {
+ ++num_publishers;
+ }
+ if (node->data().has_external_subscriber) {
+ ++num_subscribers;
+ if (!node->data().has_external_publisher) {
+ ++num_subscribers_without_publisher;
+ }
+ }
+ });
+
+ // We need to instantiate publishers and subscribers if:
+ // - there are multiple publishers and at least 1 subscriber
+ // - there is 1 publisher and at least 1 subscriber on another node
+ // Note that in case of splitting or merging topics, this might create more entities than actually needed
+ const bool require_translation = (num_publishers >= 2 && num_subscribers >= 1)
+ || (num_publishers == 1 && num_subscribers_without_publisher >= 1);
+ if (require_translation) {
+ _pub_sub_graph.iterateBFS(node, [&](const Graph::MessageNodePtr& node) {
+ // Has subscriber(s)?
+ if (node->data().has_external_subscriber && !node->data().publication) {
+ RCLCPP_INFO(_node.get_logger(), "Found subscriber for topic '%s', version: %i, adding publisher", node->data().topic_name.c_str(), node->data().version);
+ node->data().publication = node->data().publication_factory(_node);
+ } else if (!node->data().has_external_subscriber && node->data().publication) {
+ RCLCPP_INFO(_node.get_logger(), "No subscribers for topic '%s', version: %i, removing publisher", node->data().topic_name.c_str(), node->data().version);
+ node->data().publication.reset();
+ }
+ // Has publisher(s)?
+ if (node->data().has_external_publisher && !node->data().subscription) {
+ RCLCPP_INFO(_node.get_logger(), "Found publisher for topic '%s', version: %i, adding subscriber", node->data().topic_name.c_str(), node->data().version);
+ node->data().subscription = node->data().subscription_factory(_node, [this, node_cpy=node]() {
+ onSubscriptionUpdate(node_cpy);
+ });
+ } else if (!node->data().has_external_publisher && node->data().subscription) {
+ RCLCPP_INFO(_node.get_logger(), "No publishers for topic '%s', version: %i, removing subscriber", node->data().topic_name.c_str(), node->data().version);
+ node->data().subscription.reset();
+ }
+ });
+
+ } else {
+ // Reset any publishers or subscribers
+ _pub_sub_graph.iterateBFS(node, [&](const Graph::MessageNodePtr& node) {
+ if (node->data().publication) {
+ RCLCPP_INFO(_node.get_logger(), "Removing publisher for topic '%s', version: %i",
+ node->data().topic_name.c_str(), node->data().version);
+ node->data().publication.reset();
+ }
+ if (node->data().subscription) {
+ RCLCPP_INFO(_node.get_logger(), "Removing subscriber for topic '%s', version: %i",
+ node->data().topic_name.c_str(), node->data().version);
+ node->data().subscription.reset();
+ }
+ });
+ }
+ });
+ _pub_sub_graph.iterateNodes([](const MessageIdentifier& type, const Graph::MessageNodePtr& node) {
+ node->data().visited = false;
+ });
+}
+
+void PubSubGraph::onSubscriptionUpdate(const Graph::MessageNodePtr& node) {
+ _pub_sub_graph.translate(
+ node,
+ [](const Graph::MessageNodePtr& node) {
+ return node->data().publication != nullptr;
+ },
+ [](const Graph::MessageNodePtr& node) {
+ if (node->data().publication != nullptr) {
+ rcl_publish(node->data().publication->get_publisher_handle().get(),
+ node->buffer().get(), nullptr);
+ }
+ });
+
+}
+
+void PubSubGraph::printTopicInfo(const std::unordered_map>& known_versions) const {
+ // Print info about known versions
+ RCLCPP_INFO(_node.get_logger(), "Registered pub/sub topics and versions:");
+ for (const auto& [topic_name, version_set] : known_versions) {
+ if (version_set.empty()) {
+ continue;
+ }
+ const std::string versions = std::accumulate(std::next(version_set.begin()), version_set.end(),
+ std::to_string(*version_set.begin()), // start with first element
+ [](std::string a, auto&& b) {
+ return std::move(a) + ", " + std::to_string(b);
+ });
+ RCLCPP_INFO(_node.get_logger(), "- %s: %s", topic_name.c_str(), versions.c_str());
+ }
+}
+
+
+void PubSubGraph::handleLargestTopic(const std::unordered_map> &known_versions) {
+ // FastDDS caches some type information per DDS participant when first creating a publisher or subscriber for a given
+ // type. The information that is relevant for us is the maximum serialized message size.
+ // Since different versions can have different sizes, we need to ensure the first publication or subscription
+ // happens with the version of the largest size. Otherwise, an out-of-memory exception can be triggered.
+ // And the type must continue to be in use (so we cannot delete it)
+ for (const auto& [topic_name, versions] : known_versions) {
+ size_t max_serialized_message_size = 0;
+ const PublicationFactoryCB* publication_factory_for_max = nullptr;
+ for (auto version : versions) {
+ const auto& node = _pub_sub_graph.findNode(MessageIdentifier{topic_name, version});
+ assert(node);
+ const auto& node_data = node.value()->data();
+ if (node_data.max_serialized_message_size > max_serialized_message_size) {
+ max_serialized_message_size = node_data.max_serialized_message_size;
+ publication_factory_for_max = &node_data.publication_factory;
+ }
+ }
+ if (publication_factory_for_max) {
+ _largest_topic_publications.emplace_back((*publication_factory_for_max)(_node));
+ }
+ }
+}
\ No newline at end of file
diff --git a/translation_node/src/pub_sub_graph.h b/translation_node/src/pub_sub_graph.h
new file mode 100644
index 00000000..8294baec
--- /dev/null
+++ b/translation_node/src/pub_sub_graph.h
@@ -0,0 +1,58 @@
+/****************************************************************************
+ * Copyright (c) 2024 PX4 Development Team.
+ * SPDX-License-Identifier: BSD-3-Clause
+ ****************************************************************************/
+#pragma once
+
+#include
+#include
+#include "translations.h"
+#include "translation_util.h"
+#include "graph.h"
+
+class PubSubGraph {
+public:
+ struct TopicInfo {
+ std::string topic_name; ///< fully qualified topic name (with namespace)
+ int num_subscribers; ///< does not include this node's subscribers
+ int num_publishers; ///< does not include this node's publishers
+ };
+
+ PubSubGraph(rclcpp::Node& node, const Translations& translations);
+
+ void updateCurrentTopics(const std::vector& topics);
+
+private:
+ struct NodeDataPubSub {
+ explicit NodeDataPubSub(SubscriptionFactoryCB subscription_factory, PublicationFactoryCB publication_factory,
+ const MessageIdentifier& id, size_t max_serialized_message_size)
+ : subscription_factory(std::move(subscription_factory)), publication_factory(std::move(publication_factory)),
+ topic_name(id.topic_name), version(id.version), max_serialized_message_size(max_serialized_message_size)
+ { }
+
+ const SubscriptionFactoryCB subscription_factory;
+ const PublicationFactoryCB publication_factory;
+ const std::string topic_name;
+ const MessageVersionType version;
+ const size_t max_serialized_message_size;
+
+ // Keep track if there's currently a publisher/subscriber
+ bool has_external_publisher{false};
+ bool has_external_subscriber{false};
+
+ rclcpp::SubscriptionBase::SharedPtr subscription;
+ rclcpp::PublisherBase::SharedPtr publication;
+
+ bool visited{false};
+ };
+
+ void onSubscriptionUpdate(const Graph::MessageNodePtr& node);
+ void printTopicInfo(const std::unordered_map>& known_versions) const;
+ void handleLargestTopic(const std::unordered_map>& known_versions);
+
+ rclcpp::Node& _node;
+ Graph _pub_sub_graph;
+ std::unordered_map _known_topics_warned;
+
+ std::vector _largest_topic_publications;
+};
\ No newline at end of file
diff --git a/translation_node/src/ros_translations.cpp b/translation_node/src/ros_translations.cpp
deleted file mode 100644
index c5a055c7..00000000
--- a/translation_node/src/ros_translations.cpp
+++ /dev/null
@@ -1,189 +0,0 @@
-#include "ros_translations.h"
-
-static std::string getFullTopicName(const std::string& namespace_name, const std::string& topic_name) {
- std::string full_topic_name = topic_name;
- if (!full_topic_name.empty() && full_topic_name[0] != '/') {
- if (namespace_name.empty() || namespace_name.back() != '/') {
- full_topic_name = '/' + full_topic_name;
- }
- full_topic_name = namespace_name + full_topic_name;
- }
- return full_topic_name;
-}
-
-RosTranslationForTopic::RosTranslationForTopic(rclcpp::Node& node, std::string topic_name, const TranslationForTopic &translation)
- : _node(node), _topic_name(std::move(topic_name)) {
- assert(!translation.directTranslations().empty());
- // Find oldest version
- auto version_iter = std::min_element(translation.directTranslations().begin(), translation.directTranslations().end(),
- [](auto && a, auto&& b) {
- return a.newer.version < b.newer.version;
- });
- // Build version chain, use the newer element as the value for the node
- // So we need to add the older element for the lowest version first
- {
- VersionEntry entry{version_iter->older};
- _versions.push_back(std::move(entry));
- }
-
- while (version_iter != translation.directTranslations().end()) {
- // Add current version
- VersionEntry entry{version_iter->newer};
- entry.translation_cb_to_older = version_iter->translation_cb_to_older;
- entry.translation_cb_from_older = version_iter->translation_cb_from_older;
- _versions.push_back(std::move(entry));
-
- // Find next version (assumes there is no cycle in the definitions, otherwise there's an endless loop)
- version_iter = std::find_if(translation.directTranslations().begin(), translation.directTranslations().end(),
- [version_iter](auto&& a) {
- return version_iter->newer.version == a.older.version;
- });
- }
-
- if (_versions.size() != translation.directTranslations().size() + 1) {
- // This means there is a gap in the versions
- throw std::runtime_error(std::string("non-continuous versions for topic") + _topic_name);
- }
-
- handleLargestTopic();
-}
-
-void RosTranslationForTopic::updateSubsAndPubs() {
- const bool has_subscriber = std::any_of(_versions.begin(), _versions.end(), [](auto&& a) {
- return a.has_external_subscriber;
- });
- const bool has_publisher = std::any_of(_versions.begin(), _versions.end(), [](auto&& a) {
- return a.has_external_publisher;
- });
- if (has_subscriber && has_publisher) {
- // TODO: do not add anything if there's only a subscriber & publisher for a specific version
-
- for (unsigned index = 0; index < _versions.size(); ++index) {
- auto& version = _versions[index];
- // Has subscriber(s)?
- if (version.has_external_subscriber && !version.publication) {
- RCLCPP_INFO(_node.get_logger(), "Found subscriber for topic '%s', version: %i, adding publisher", _topic_name.c_str(), version.version);
- version.publication = version.publication_factory(_node);
- } else if (!version.has_external_subscriber && version.publication) {
- RCLCPP_INFO(_node.get_logger(), "No subscribers for topic '%s', version: %i, removing publisher", _topic_name.c_str(), version.version);
- version.publication.reset();
- }
- // Has publisher(s)?
- if (version.has_external_publisher && !version.subscription) {
- RCLCPP_INFO(_node.get_logger(), "Found publisher for topic '%s', version: %i, adding subscriber", _topic_name.c_str(), version.version);
- version.subscription = version.subscription_factory(_node, [this, index](void* data) {
- onSubscriptionUpdated(index, data);
- });
- } else if (!version.has_external_publisher && version.subscription) {
- RCLCPP_INFO(_node.get_logger(), "No publishers for topic '%s', version: %i, removing subscriber", _topic_name.c_str(), version.version);
- version.subscription.reset();
- }
-
- }
- } else {
- // Clear all
- for (auto& version : _versions) {
- version.publication.reset();
- version.subscription.reset();
- }
- }
-}
-
-void RosTranslationForTopic::onSubscriptionUpdated(unsigned version_index, void *data) {
- const auto& entry = _versions[version_index];
- const auto lowest_publisher_iter = std::find_if(_versions.begin(), _versions.end(), [](auto&& a) { return a.publication != nullptr; });
- // Convert to lower versions
- if (lowest_publisher_iter != _versions.end()) {
- const unsigned lowest_index = std::distance(_versions.begin(), lowest_publisher_iter);
- void* current_data = data;
- for (unsigned index = version_index; index > lowest_index; --index) {
- // Convert message
- _versions[index].translation_cb_to_older(current_data, _versions[index-1].message_buffer.get());
- current_data = _versions[index-1].message_buffer.get();
- // Publish if there is a publisher
- if (_versions[index-1].publication) {
- rcl_publish(_versions[index-1].publication->get_publisher_handle().get(), current_data, nullptr);
- }
- }
- }
- // Convert to higher versions
- const auto highest_publisher_iter = std::find_if(_versions.rbegin(), _versions.rend(), [](auto&& a) { return a.publication != nullptr; });
- if (highest_publisher_iter != _versions.rend()) {
- const unsigned highest_index = std::distance(highest_publisher_iter, _versions.rend()) - 1;
- void* current_data = data;
- for (unsigned index = version_index; index < highest_index; ++index) {
- // Convert message
- _versions[index + 1].translation_cb_from_older(current_data, _versions[index+1].message_buffer.get());
- current_data = _versions[index+1].message_buffer.get();
- // Publish if there is a publisher
- if (_versions[index+1].publication) {
- rcl_publish(_versions[index+1].publication->get_publisher_handle().get(), current_data, nullptr);
- }
- }
- }
-}
-
-void RosTranslationForTopic::handleLargestTopic() {
- // FastDDS caches some type information per DDS participant when first creating a publisher or subscriber for a given
- // type. The information that is relevant for us is the maximum serialized message size.
- // Since different versions can have different sizes, we need to ensure the first publication or subscription
- // happens with the version of the largest size. Otherwise, an out-of-memory exception can be triggered.
- // And the type must continue to be in use (so we cannot delete it)
- auto version_iter = std::max_element(_versions.begin(), _versions.end(),
- [](auto && a, auto&& b) {
- return a.max_serialized_message_size < b.max_serialized_message_size;
- });
- _publisher_largest_topic = version_iter->publication_factory(_node);
-}
-
-RosTranslations::RosTranslations(rclcpp::Node &node, const Translations &translations)
- : _node(node) {
-
- for (const auto& [topic_name, translation]: translations.topicTranslations()) {
- const std::string full_topic_name = getFullTopicName(node.get_effective_namespace(), topic_name);
- auto [iter, _] = _topics.emplace(full_topic_name, RosTranslationForTopic{_node, full_topic_name, translation});
-
- // Print versions info
- const std::string versions = std::accumulate(std::next(iter->second.versions().begin()), iter->second.versions().end(),
- std::to_string(iter->second.versions()[0].version), // start with first element
- [](std::string a, auto&& b) {
- return std::move(a) + ", " + std::to_string(b.version);
- });
- RCLCPP_INFO(_node.get_logger(), "Versions for topic '%s': %s", iter->second.topicName().c_str(), versions.c_str());
- }
-}
-
-void RosTranslations::updateCurrentTopics(const std::vector &topics) {
- for (auto& [topic_name, translation] : _topics) {
- translation.resetExternalSubPub();
- }
- for (const auto& topic_info : topics) {
- const auto [non_versioned_topic_name, version] = getNonVersionedTopicName(topic_info.topic_name);
- auto iter = _topics.find(non_versioned_topic_name);
- if (iter == _topics.end()) {
- continue;
- }
- // It's a topic we're interested in, find the version
- bool found_version = false;
- for (auto& entry : iter->second.versions()) {
- if (entry.version == version) {
- found_version = true;
- if (topic_info.num_publishers > 0) {
- entry.has_external_publisher = true;
- }
- if (topic_info.num_subscribers > 0) {
- entry.has_external_subscriber = true;
- }
- }
- }
- if (!found_version && !iter->second.getAndSetErrorPrinted()) {
- RCLCPP_WARN(_node.get_logger(), "Unsupported version for topic '%s': %i", non_versioned_topic_name.c_str(), version);
- }
- }
-
- // Now update the subscriptions / publishers depending on what we found
- for (auto& [topic_name, translation] : _topics) {
- translation.updateSubsAndPubs();
- }
-}
-
diff --git a/translation_node/src/ros_translations.h b/translation_node/src/ros_translations.h
deleted file mode 100644
index 523e44d8..00000000
--- a/translation_node/src/ros_translations.h
+++ /dev/null
@@ -1,84 +0,0 @@
-#pragma once
-
-#include
-#include "translations.h"
-
-class RosTranslationVersion {
-public:
-private:
- rclcpp::SubscriptionBase::SharedPtr _subscriber;
- rclcpp::PublisherBase::SharedPtr _publisher;
-};
-
-class RosTranslationForTopic {
-public:
- struct VersionEntry {
- VersionEntry() = default;
- explicit VersionEntry(const DirectTranslationData::Version& version)
- : version(version.version), message_buffer(version.message_buffer),
- max_serialized_message_size(version.max_serialized_message_size),
- subscription_factory(version.subscription_factory), publication_factory(version.publication_factory) {}
-
- MessageVersionType version; ///< corresponds to the 'newer' version in the translation
- DirectTranslationCB translation_cb_from_older;
- DirectTranslationCB translation_cb_to_older;
- std::shared_ptr message_buffer;
- size_t max_serialized_message_size{};
-
- SubscriptionFactoryCB subscription_factory;
- PublicationFactoryCB publication_factory;
-
- // Keep track if there's currently a publisher/subscriber
- bool has_external_publisher{false};
- bool has_external_subscriber{false};
-
- rclcpp::SubscriptionBase::SharedPtr subscription;
- rclcpp::PublisherBase::SharedPtr publication;
- };
- explicit RosTranslationForTopic(rclcpp::Node& node, std::string topic_name, const TranslationForTopic& translation);
-
- const std::string& topicName() const { return _topic_name; }
- std::vector& versions() { return _versions; }
-
- void resetExternalSubPub() {
- for (auto& entry : _versions) {
- entry.has_external_publisher = false;
- entry.has_external_subscriber = false;
- }
- }
-
- void updateSubsAndPubs();
-
- bool getAndSetErrorPrinted() {
- const bool error_printed = _error_printed;
- _error_printed = true;
- return error_printed;
- }
-
-private:
- void onSubscriptionUpdated(unsigned version_index, void* data);
- void handleLargestTopic();
-
- rclcpp::Node& _node;
- const std::string _topic_name;
- std::vector _versions;
- bool _error_printed{false}; ///< ensure errors are printed only once
-
- std::shared_ptr _publisher_largest_topic;
-};
-
-class RosTranslations {
-public:
- struct TopicInfo {
- std::string topic_name; ///< fully qualified topic name (with namespace)
- int num_subscribers; ///< does not include this node's subscribers
- int num_publishers; ///< does not include this node's publishers
- };
-
- explicit RosTranslations(rclcpp::Node& node, const Translations& translations);
-
- void updateCurrentTopics(const std::vector& topics);
-private:
- rclcpp::Node& _node;
- std::unordered_map _topics;
-};
\ No newline at end of file
diff --git a/translation_node/src/translation_util.h b/translation_node/src/translation_util.h
index 5a7bcc89..3a896c58 100644
--- a/translation_node/src/translation_util.h
+++ b/translation_node/src/translation_util.h
@@ -1,3 +1,7 @@
+/****************************************************************************
+ * Copyright (c) 2024 PX4 Development Team.
+ * SPDX-License-Identifier: BSD-3-Clause
+ ****************************************************************************/
#pragma once
#include "translations.h"
@@ -19,20 +23,23 @@ class RegisteredTranslations {
}
template
- void registerTranslation(const std::string& topic_name) {
- DirectTranslationData data{};
- data.older = getVersionForMessageType(topic_name);
- data.newer = getVersionForMessageType(topic_name);
+ void registerDirectTranslation(const std::string& topic_name) {
+ _translations.addTopic(getTopicForMessageType(topic_name));
+ _translations.addTopic(getTopicForMessageType(topic_name));
// Translation callbacks
- data.translation_cb_from_older = [](const void* older_msg, void* newer_msg) {
- T::fromOlder(*(const typename T::MessageOlder*)older_msg, *(typename T::MessageNewer*)newer_msg);
+ auto translation_cb_from_older = [](const std::vector& older_msg, std::vector& newer_msg) {
+ T::fromOlder(*(const typename T::MessageOlder*)older_msg[0].get(), *(typename T::MessageNewer*)newer_msg[0].get());
};
- data.translation_cb_to_older = [](const void* newer_msg, void* older_msg) {
- T::toOlder(*(const typename T::MessageNewer*)newer_msg, *(typename T::MessageOlder*)older_msg);
+ auto translation_cb_to_older = [](const std::vector& newer_msg, std::vector& older_msg) {
+ T::toOlder(*(const typename T::MessageNewer*)newer_msg[0].get(), *(typename T::MessageOlder*)older_msg[0].get());
};
-
- _translations.registerDirectTranslation(topic_name, std::move(data));
+ _translations.addTranslation({translation_cb_from_older,
+ {MessageIdentifier{topic_name, T::MessageOlder::MESSAGE_VERSION}},
+ {MessageIdentifier{topic_name, T::MessageNewer::MESSAGE_VERSION}}});
+ _translations.addTranslation({translation_cb_to_older,
+ {MessageIdentifier{topic_name, T::MessageNewer::MESSAGE_VERSION}},
+ {MessageIdentifier{topic_name, T::MessageOlder::MESSAGE_VERSION}}});
}
const Translations& translations() const { return _translations; }
@@ -41,19 +48,22 @@ class RegisteredTranslations {
RegisteredTranslations() = default;
template
- DirectTranslationData::Version getVersionForMessageType(const std::string& topic_name) {
- DirectTranslationData::Version ret{};
+ Topic getTopicForMessageType(const std::string& topic_name) {
+ Topic ret{};
+ ret.topic_name = topic_name;
ret.version = RosMessageType::MESSAGE_VERSION;
- ret.message_buffer = std::static_pointer_cast(std::make_shared());
+ auto message_buffer = std::make_shared();
+ ret.message_buffer = std::static_pointer_cast(message_buffer);
// Subscription/Publication factory methods
const std::string topic_name_versioned = getVersionedTopicName(topic_name, ret.version);
- ret.subscription_factory = [topic_name_versioned](rclcpp::Node& node,
- const std::function& on_topic_cb) -> rclcpp::SubscriptionBase::SharedPtr {
+ ret.subscription_factory = [topic_name_versioned, message_buffer](rclcpp::Node& node,
+ const std::function& on_topic_cb) -> rclcpp::SubscriptionBase::SharedPtr {
return std::dynamic_pointer_cast(
node.create_subscription(topic_name_versioned, rclcpp::QoS(1).best_effort(),
- [on_topic_cb=on_topic_cb](typename RosMessageType::UniquePtr msg) -> void {
- on_topic_cb(msg.get());
+ [on_topic_cb=on_topic_cb, message_buffer](typename RosMessageType::UniquePtr msg) -> void {
+ *message_buffer = *msg;
+ on_topic_cb();
}));
};
ret.publication_factory = [topic_name_versioned](rclcpp::Node& node) -> rclcpp::PublisherBase::SharedPtr {
@@ -76,15 +86,15 @@ class RegisteredTranslations {
};
template
-class RegistrationHelper {
+class RegistrationHelperDirect {
public:
- explicit RegistrationHelper(const std::string& topic_name) {
- RegisteredTranslations::instance().registerTranslation(topic_name);
+ explicit RegistrationHelperDirect(const std::string& topic_name) {
+ RegisteredTranslations::instance().registerDirectTranslation(topic_name);
}
- RegistrationHelper(RegistrationHelper const&) = delete;
- void operator=(RegistrationHelper const&) = delete;
+ RegistrationHelperDirect(RegistrationHelperDirect const&) = delete;
+ void operator=(RegistrationHelperDirect const&) = delete;
private:
};
#define REGISTER_MESSAGE_TRANSLATION_DIRECT(topic_name, class_name) \
- static RegistrationHelper class_name##_registration(topic_name);
+ static RegistrationHelperDirect class_name##_registration(topic_name);
diff --git a/translation_node/src/translations.cpp b/translation_node/src/translations.cpp
index 01a3552a..0cca2d5a 100644
--- a/translation_node/src/translations.cpp
+++ b/translation_node/src/translations.cpp
@@ -1,14 +1,6 @@
+/****************************************************************************
+ * Copyright (c) 2024 PX4 Development Team.
+ * SPDX-License-Identifier: BSD-3-Clause
+ ****************************************************************************/
#include "translations.h"
-void TranslationForTopic::registerVersion(DirectTranslationData data) {
- _direct_translations.emplace_back(std::move(data));
-}
-
-void Translations::registerDirectTranslation(const std::string &topic_name, DirectTranslationData data) {
- auto iter = _topic_translations.find(topic_name);
- if (iter == _topic_translations.end()) {
- auto [iter_inserted, _] = _topic_translations.emplace(topic_name, TranslationForTopic(topic_name));
- iter = iter_inserted;
- }
- iter->second.registerVersion(std::move(data));
-}
diff --git a/translation_node/src/translations.h b/translation_node/src/translations.h
index 73ada193..5b5d0497 100644
--- a/translation_node/src/translations.h
+++ b/translation_node/src/translations.h
@@ -1,3 +1,7 @@
+/****************************************************************************
+ * Copyright (c) 2024 PX4 Development Team.
+ * SPDX-License-Identifier: BSD-3-Clause
+ ****************************************************************************/
#pragma once
#include
@@ -9,53 +13,42 @@
#include
#include "util.h"
+#include "graph.h"
#include
-using DirectTranslationCB = std::function;
-using SubscriptionFactoryCB = std::function& on_topic_cb)>;
+using TranslationCB = std::function&, std::vector&)>;
+using SubscriptionFactoryCB = std::function& on_topic_cb)>;
using PublicationFactoryCB = std::function;
-struct DirectTranslationData {
- struct Version {
- MessageVersionType version;
- std::shared_ptr message_buffer;
- SubscriptionFactoryCB subscription_factory;
- PublicationFactoryCB publication_factory;
- size_t max_serialized_message_size{};
- };
+struct Topic {
+ std::string topic_name;
+ MessageVersionType version{};
- Version older;
- Version newer;
+ SubscriptionFactoryCB subscription_factory;
+ PublicationFactoryCB publication_factory;
- DirectTranslationCB translation_cb_from_older;
- DirectTranslationCB translation_cb_to_older;
+ std::shared_ptr message_buffer;
+ size_t max_serialized_message_size{};
};
-class TranslationForTopic {
-public:
- explicit TranslationForTopic(std::string topic_name="") : _topic_name(std::move(topic_name)) {}
-
- void registerVersion(DirectTranslationData data);
-
- const std::string& topicName() const { return _topic_name; }
- const std::vector& directTranslations() const { return _direct_translations; };
-
-private:
- const std::string _topic_name;
- std::vector _direct_translations;
+struct Translation {
+ TranslationCB cb;
+ std::vector inputs;
+ std::vector outputs;
};
class Translations {
public:
-
Translations() = default;
- void registerDirectTranslation(const std::string& topic_name, DirectTranslationData data);
-
- const std::unordered_map& topicTranslations() const { return _topic_translations; }
+ void addTopic(Topic topic) { _topics.push_back(std::move(topic)); }
+ void addTranslation(Translation translation) { _translations.push_back(std::move(translation)); }
+ const std::vector& topics() const { return _topics; }
+ const std::vector& translations() const { return _translations; }
private:
- std::unordered_map _topic_translations;
+ std::vector _topics;
+ std::vector _translations;
};
diff --git a/translation_node/src/util.h b/translation_node/src/util.h
index a4f61da9..02aad1d3 100644
--- a/translation_node/src/util.h
+++ b/translation_node/src/util.h
@@ -1,3 +1,7 @@
+/****************************************************************************
+ * Copyright (c) 2024 PX4 Development Team.
+ * SPDX-License-Identifier: BSD-3-Clause
+ ****************************************************************************/
#pragma once
#include
@@ -30,3 +34,19 @@ static inline std::pair getNonVersionedTopicNam
}
return std::make_pair(non_versioned_topic_name, std::stol(version));
}
+
+/**
+ * Get the full topic name, including namespace from a topic name.
+ * namespace_name should be set to Node::get_effective_namespace()
+ */
+static inline std::string getFullTopicName(const std::string& namespace_name, const std::string& topic_name) {
+ std::string full_topic_name = topic_name;
+ if (!full_topic_name.empty() && full_topic_name[0] != '/') {
+ if (namespace_name.empty() || namespace_name.back() != '/') {
+ full_topic_name = '/' + full_topic_name;
+ }
+ full_topic_name = namespace_name + full_topic_name;
+ }
+ return full_topic_name;
+}
+
diff --git a/translation_node/src/vehicle_attitude_v2.h b/translation_node/src/vehicle_attitude_v2.h
index 4675d7ce..eefc2d72 100644
--- a/translation_node/src/vehicle_attitude_v2.h
+++ b/translation_node/src/vehicle_attitude_v2.h
@@ -1,3 +1,7 @@
+/****************************************************************************
+ * Copyright (c) 2024 PX4 Development Team.
+ * SPDX-License-Identifier: BSD-3-Clause
+ ****************************************************************************/
#pragma once
// Translate VehicleAttitude v1 <--> v2
diff --git a/translation_node/src/vehicle_attitude_v3.h b/translation_node/src/vehicle_attitude_v3.h
index 40c7d811..53fd06d9 100644
--- a/translation_node/src/vehicle_attitude_v3.h
+++ b/translation_node/src/vehicle_attitude_v3.h
@@ -1,3 +1,7 @@
+/****************************************************************************
+ * Copyright (c) 2024 PX4 Development Team.
+ * SPDX-License-Identifier: BSD-3-Clause
+ ****************************************************************************/
#pragma once
// Translate VehicleAttitude v2 <--> v3
diff --git a/translation_node/test/graph.cpp b/translation_node/test/graph.cpp
new file mode 100644
index 00000000..7a7f25ef
--- /dev/null
+++ b/translation_node/test/graph.cpp
@@ -0,0 +1,355 @@
+/****************************************************************************
+ * Copyright (c) 2024 PX4 Development Team.
+ * SPDX-License-Identifier: BSD-3-Clause
+ ****************************************************************************/
+
+#include
+#include
+
+
+TEST(graph, basic)
+{
+ struct NodeData {
+ bool iterated{false};
+ bool translated{false};
+ };
+ Graph graph;
+
+ const int32_t message1_value = 3;
+ const int32_t offset = 4;
+
+ // Add 2 nodes
+ const MessageIdentifier id1{"topic_name", 1};
+ auto buffer1 = std::make_shared();
+ *buffer1 = message1_value;
+ EXPECT_TRUE(graph.addNodeIfNotExists(id1, {}, buffer1));
+ EXPECT_FALSE(graph.addNodeIfNotExists(id1, {}, std::make_shared()));
+ const MessageIdentifier id2{"topic_name", 4};
+ auto buffer2 = std::make_shared();
+ *buffer2 = 773;
+ EXPECT_TRUE(graph.addNodeIfNotExists(id2, {}, buffer2));
+
+ // Search nodes
+ EXPECT_TRUE(graph.findNode(id1).has_value());
+ EXPECT_TRUE(graph.findNode(id2).has_value());
+
+ // Add 1 translation
+ auto translation_cb = [&offset](const std::vector& a, std::vector& b) {
+ auto a_value = static_cast(a[0].get());
+ auto b_value = static_cast(b[0].get());
+ *b_value = *a_value + offset;
+ };
+ graph.addTranslation(translation_cb, {id1}, {id2});
+
+ // Iteration from id1 must reach id2
+ auto node1 = graph.findNode(id1).value();
+ auto node2 = graph.findNode(id2).value();
+ auto iterate_cb = [](const Graph::MessageNodePtr& node) {
+ node->data().iterated = true;
+ };
+ graph.iterateBFS(node1, iterate_cb);
+ EXPECT_TRUE(node1->data().iterated);
+ EXPECT_TRUE(node2->data().iterated);
+ node1->data().iterated = false;
+ node2->data().iterated = false;
+
+ // Iteration from id2 must not reach id1
+ graph.iterateBFS(node2, iterate_cb);
+ EXPECT_FALSE(node1->data().iterated);
+ EXPECT_TRUE(node2->data().iterated);
+
+ // Test translation
+ graph.translate(node1, [](auto&& node) { return true; },
+ [](auto&& node) {
+ node->data().translated = true;
+ });
+ EXPECT_TRUE(node1->data().translated);
+ EXPECT_TRUE(node2->data().translated);
+ EXPECT_EQ(*buffer1, message1_value);
+ EXPECT_EQ(*buffer2, message1_value + offset);
+}
+
+
+TEST(graph, multi_path)
+{
+ // Multiple paths with cycles
+ struct NodeData {
+ unsigned iterated_idx{0};
+ bool translated{false};
+ };
+ Graph graph;
+
+ static constexpr unsigned num_nodes = 6;
+ std::array ids{{
+ {"topic_name", 1},
+ {"topic_name", 2},
+ {"topic_name", 3},
+ {"topic_name", 4},
+ {"topic_name", 5},
+ {"topic_name", 6},
+ }};
+
+ std::array, num_nodes> buffers{{
+ std::make_shared(),
+ std::make_shared(),
+ std::make_shared(),
+ std::make_shared(),
+ std::make_shared(),
+ std::make_shared(),
+ }};
+
+ // Nodes
+ for (unsigned i = 0; i < num_nodes; ++i) {
+ EXPECT_TRUE(graph.addNodeIfNotExists(ids[i], {}, buffers[i]));
+ }
+
+ // Translations
+ std::bitset<32> translated;
+
+ auto get_translation_cb = [&translated](unsigned bit) {
+ auto translation_cb = [&translated, bit](const std::vector &a, std::vector &b) {
+ auto a_value = static_cast(a[0].get());
+ auto b_value = static_cast(b[0].get());
+ *b_value = *a_value | (1 << bit);
+ translated.set(bit);
+ };
+ return translation_cb;
+ };
+
+ // Graph:
+ // ___ 2 -- 3 -- 4
+ // | |
+ // 1 _______|
+ // |
+ // 5
+ // |
+ // 6
+
+ unsigned next_bit = 0;
+ // Connect each node to the previous and next, except the last 3
+ for (unsigned i=0; i < num_nodes - 3; ++i) {
+ graph.addTranslation(get_translation_cb(next_bit++), {ids[i]}, {ids[i+1]});
+ graph.addTranslation(get_translation_cb(next_bit++), {ids[i+1]}, {ids[i]});
+ }
+
+ // Connect the first to the 3rd as well
+ graph.addTranslation(get_translation_cb(next_bit++), {ids[0]}, {ids[2]});
+ graph.addTranslation(get_translation_cb(next_bit++), {ids[2]}, {ids[0]});
+
+ // Connect the second last to the first one
+ graph.addTranslation(get_translation_cb(next_bit++), {ids[0]}, {ids[num_nodes-2]});
+ graph.addTranslation(get_translation_cb(next_bit++), {ids[num_nodes-2]}, {ids[0]});
+
+ // Connect the second last to the last one
+ graph.addTranslation(get_translation_cb(next_bit++), {ids[num_nodes-1]}, {ids[num_nodes-2]});
+ graph.addTranslation(get_translation_cb(next_bit++), {ids[num_nodes-2]}, {ids[num_nodes-1]});
+
+ unsigned iteration_idx = 1;
+ graph.iterateBFS(graph.findNode(ids[0]).value(), [&iteration_idx](const Graph::MessageNodePtr& node) {
+ assert(node->data().iterated_idx == 0);
+ node->data().iterated_idx = iteration_idx++;
+ });
+
+ EXPECT_EQ(graph.findNode(ids[0]).value()->data().iterated_idx, 1);
+ // We're a bit stricter than we would have to be: ids[1,2,4] would be allowed to have any of the values (2,3,4)
+ EXPECT_EQ(graph.findNode(ids[1]).value()->data().iterated_idx, 2);
+ EXPECT_EQ(graph.findNode(ids[2]).value()->data().iterated_idx, 3);
+ EXPECT_EQ(graph.findNode(ids[4]).value()->data().iterated_idx, 4);
+ EXPECT_EQ(graph.findNode(ids[3]).value()->data().iterated_idx, 5);
+
+
+ // Translation
+ graph.translate(graph.findNode(ids[0]).value(),
+ [&](auto&& node) {
+ // Skip the last 2 nodes
+ return node.get() != graph.findNode(ids[num_nodes-1]).value().get() &&
+ node.get() != graph.findNode(ids[num_nodes-2]).value().get();
+ },
+ [](auto&& node) {
+ node->data().translated = true;
+ });
+
+ // Last 2 nodes should not be translated, the rest should be
+ for (unsigned i =0; i < num_nodes-2; ++i) {
+ EXPECT_EQ(graph.findNode(ids[i]).value()->data().translated, true);
+ }
+ EXPECT_EQ(graph.findNode(ids[num_nodes-2]).value()->data().translated, false);
+ EXPECT_EQ(graph.findNode(ids[num_nodes-1]).value()->data().translated, false);
+
+ // Ensure the correct edges were used for translations
+ EXPECT_EQ("00000000000000000000000001010001", translated.to_string());
+
+ // Ensure correct translation path taken for each node (which is stored in the buffers),
+ // and translation callback got called
+ EXPECT_EQ(*buffers[0], 0);
+ EXPECT_EQ(*buffers[1], 0b1);
+ EXPECT_EQ(*buffers[2], 0b1000000);
+ EXPECT_EQ(*buffers[3], 0b1010000);
+ EXPECT_EQ(*buffers[4], 0);
+ EXPECT_EQ(*buffers[5], 0);
+
+ for (unsigned i=0; i < num_nodes; ++i) {
+ printf("node[%i]: translated: %i, buffer: %i\n", i, graph.findNode(ids[i]).value()->data().translated,
+ *buffers[i]);
+ }
+}
+
+TEST(graph, multi_links) {
+ // Multiple topics (merging / splitting)
+ struct NodeData {
+ bool translated{false};
+ };
+ Graph graph;
+
+ static constexpr unsigned num_nodes = 6;
+ std::array ids{{
+ {"topic1", 1},
+ {"topic2", 1},
+ {"topic1", 2},
+ {"topic3", 1},
+ {"topic4", 1},
+ {"topic1", 3},
+ }};
+
+ std::array, num_nodes> buffers{{
+ std::make_shared(),
+ std::make_shared(),
+ std::make_shared(),
+ std::make_shared(),
+ std::make_shared(),
+ std::make_shared(),
+ }};
+
+ // Nodes
+ for (unsigned i = 0; i < num_nodes; ++i) {
+ EXPECT_TRUE(graph.addNodeIfNotExists(ids[i], {}, buffers[i]));
+ }
+
+
+ // Graph
+ // ___
+ // 1 - | | ---
+ // | | - 3 - | | - 6
+ // 2 - | | ---
+ // | ---
+ // | ___
+ // --- | | - 4
+ // | | - 5
+ // ---
+
+ // Translations
+ auto translation_cb_merge = [](const std::vector &a, std::vector &b) {
+ assert(a.size() == 2);
+ assert(b.size() == 1);
+ auto a_value1 = static_cast(a[0].get());
+ auto a_value2 = static_cast(a[1].get());
+ auto b_value = static_cast(b[0].get());
+ *b_value = *a_value1 | *a_value2;
+ };
+ auto translation_cb_split = [](const std::vector &a, std::vector &b) {
+ assert(a.size() == 1);
+ assert(b.size() == 2);
+ auto a_value = static_cast(a[0].get());
+ auto b_value1 = static_cast(b[0].get());
+ auto b_value2 = static_cast(b[1].get());
+ *b_value1 = *a_value & 0x0000ffffu;
+ *b_value2 = *a_value & 0xffff0000u;
+ };
+ auto translation_cb_direct = [](const std::vector &a, std::vector &b) {
+ assert(a.size() == 1);
+ assert(b.size() == 1);
+ auto a_value = static_cast(a[0].get());
+ auto b_value = static_cast(b[0].get());
+ *b_value = *a_value;
+ };
+
+ auto addTranslation = [&](const std::vector& inputs, const std::vector& outputs) {
+ assert(inputs.size() <= 2);
+ assert(outputs.size() <= 2);
+ if (inputs.size() == 1) {
+ if (outputs.size() == 1) {
+ graph.addTranslation(translation_cb_direct, inputs, outputs);
+ graph.addTranslation(translation_cb_direct, outputs, inputs);
+ } else {
+ graph.addTranslation(translation_cb_split, inputs, outputs);
+ graph.addTranslation(translation_cb_merge, outputs, inputs);
+ }
+ } else {
+ assert(outputs.size() == 1);
+ graph.addTranslation(translation_cb_merge, inputs, outputs);
+ graph.addTranslation(translation_cb_split, outputs, inputs);
+ }
+ };
+ addTranslation({ids[0], ids[1]}, {ids[2]});
+ addTranslation({ids[1]}, {ids[3], ids[4]});
+ addTranslation({ids[2]}, {ids[5]});
+
+ auto translate_node = [&](const MessageIdentifier& id) {
+ graph.translate(graph.findNode(id).value(),
+ [&](auto&& node) { return true; },
+ [](auto&& node) {
+ node->data().translated = true;
+ });
+
+ };
+ auto reset_translated = [&]() {
+ for (const auto& id : ids) {
+ graph.findNode(id).value()->data().translated = false;
+ }
+ };
+
+ // Updating node 2 should trigger an output for nodes 4+5 (splitting)
+ *buffers[0] = 0xa00000b0;
+ *buffers[1] = 0x0f00000f;
+ translate_node(ids[1]);
+ EXPECT_EQ(graph.findNode(ids[0]).value()->data().translated, false);
+ EXPECT_EQ(graph.findNode(ids[1]).value()->data().translated, true);
+ EXPECT_EQ(graph.findNode(ids[2]).value()->data().translated, false);
+ EXPECT_EQ(graph.findNode(ids[3]).value()->data().translated, true);
+ EXPECT_EQ(graph.findNode(ids[4]).value()->data().translated, true);
+ EXPECT_EQ(graph.findNode(ids[5]).value()->data().translated, false);
+ EXPECT_EQ(*buffers[3], 0x0000000f);
+ EXPECT_EQ(*buffers[4], 0x0f000000);
+
+ reset_translated();
+
+ // Now updating node 1 should update nodes 3+6 (merging, both inputs available now)
+ translate_node(ids[0]);
+ EXPECT_EQ(graph.findNode(ids[0]).value()->data().translated, true);
+ EXPECT_EQ(graph.findNode(ids[1]).value()->data().translated, false);
+ EXPECT_EQ(graph.findNode(ids[2]).value()->data().translated, true);
+ EXPECT_EQ(graph.findNode(ids[3]).value()->data().translated, false);
+ EXPECT_EQ(graph.findNode(ids[4]).value()->data().translated, false);
+ EXPECT_EQ(graph.findNode(ids[5]).value()->data().translated, true);
+ EXPECT_EQ(*buffers[2], 0xaf0000bf);
+ EXPECT_EQ(*buffers[5], 0xaf0000bf);
+
+ reset_translated();
+
+ // Another update must not trigger any other updates
+ translate_node(ids[0]);
+ EXPECT_EQ(graph.findNode(ids[0]).value()->data().translated, true);
+ EXPECT_EQ(graph.findNode(ids[1]).value()->data().translated, false);
+ EXPECT_EQ(graph.findNode(ids[2]).value()->data().translated, false);
+ EXPECT_EQ(graph.findNode(ids[3]).value()->data().translated, false);
+ EXPECT_EQ(graph.findNode(ids[4]).value()->data().translated, false);
+ EXPECT_EQ(graph.findNode(ids[5]).value()->data().translated, false);
+
+ reset_translated();
+
+ // Backwards: updating node 6 should trigger updates for 1+2, but also 4+5
+ *buffers[5] = 0xc00000d0;
+ translate_node(ids[5]);
+ EXPECT_EQ(graph.findNode(ids[0]).value()->data().translated, true);
+ EXPECT_EQ(graph.findNode(ids[1]).value()->data().translated, true);
+ EXPECT_EQ(graph.findNode(ids[2]).value()->data().translated, true);
+ EXPECT_EQ(graph.findNode(ids[3]).value()->data().translated, true);
+ EXPECT_EQ(graph.findNode(ids[4]).value()->data().translated, true);
+ EXPECT_EQ(graph.findNode(ids[5]).value()->data().translated, true);
+ EXPECT_EQ(*buffers[0], 0x000000d0);
+ EXPECT_EQ(*buffers[1], 0xc0000000);
+ EXPECT_EQ(*buffers[2], 0xc00000d0);
+ EXPECT_EQ(*buffers[3], 0);
+ EXPECT_EQ(*buffers[4], 0xc0000000);
+ EXPECT_EQ(*buffers[5], 0xc00000d0);
+}
\ No newline at end of file
diff --git a/translation_node/test/main.cpp b/translation_node/test/main.cpp
new file mode 100644
index 00000000..9d7e484c
--- /dev/null
+++ b/translation_node/test/main.cpp
@@ -0,0 +1,16 @@
+/****************************************************************************
+ * Copyright (c) 2024 PX4 Development Team.
+ * SPDX-License-Identifier: BSD-3-Clause
+ ****************************************************************************/
+
+#include
+#include
+
+int main(int argc, char ** argv)
+{
+ rclcpp::init(argc, argv);
+ testing::InitGoogleTest(&argc, argv);
+ const int ret = RUN_ALL_TESTS();
+ rclcpp::shutdown();
+ return ret;
+}