diff --git a/translation_node/CMakeLists.txt b/translation_node/CMakeLists.txt index 3d8b06eb..9b7868a1 100644 --- a/translation_node/CMakeLists.txt +++ b/translation_node/CMakeLists.txt @@ -11,14 +11,35 @@ find_package(rclcpp REQUIRED) find_package(px4_msgs REQUIRED) find_package(px4_msgs_old REQUIRED) -add_executable(translation_node - src/main.cpp - src/ros_translations.cpp +add_library(${PROJECT_NAME}_lib + src/pub_sub_graph.cpp src/translations.cpp ) -ament_target_dependencies(translation_node rclcpp px4_msgs px4_msgs_old) +ament_target_dependencies(${PROJECT_NAME}_lib rclcpp px4_msgs px4_msgs_old) +add_executable(${PROJECT_NAME} + src/main.cpp +) +target_link_libraries(${PROJECT_NAME} ${PROJECT_NAME}_lib) +ament_target_dependencies(${PROJECT_NAME} rclcpp px4_msgs px4_msgs_old) install(TARGETS - translation_node + ${PROJECT_NAME} DESTINATION lib/${PROJECT_NAME}) +if(BUILD_TESTING) + find_package(ament_lint_auto REQUIRED) + find_package(ament_cmake_gtest REQUIRED) + ament_lint_auto_find_test_dependencies() + + # Unit tests + ament_add_gtest(${PROJECT_NAME}_unit_tests + test/graph.cpp + test/main.cpp + ) + target_include_directories(${PROJECT_NAME}_unit_tests PRIVATE ${CMAKE_CURRENT_LIST_DIR}) + target_link_libraries(${PROJECT_NAME}_unit_tests ${PROJECT_NAME}_lib) + ament_target_dependencies(${PROJECT_NAME}_unit_tests + rclcpp + ) +endif() + ament_package() diff --git a/translation_node/package.xml b/translation_node/package.xml index e212547b..a34688f3 100644 --- a/translation_node/package.xml +++ b/translation_node/package.xml @@ -11,6 +11,7 @@ ament_lint_auto ament_lint_common + ament_cmake_gtest rclcpp px4_msgs diff --git a/translation_node/src/graph.h b/translation_node/src/graph.h new file mode 100644 index 00000000..2e73841c --- /dev/null +++ b/translation_node/src/graph.h @@ -0,0 +1,317 @@ +/**************************************************************************** + * Copyright (c) 2024 PX4 Development Team. + * SPDX-License-Identifier: BSD-3-Clause + ****************************************************************************/ +#pragma once + +#include "util.h" +#include +#include +#include +#include +#include +#include +#include +#include + +// This implements a directed graph with potential cycles used for translation. +// There are 2 types of nodes: messages (e.g. publication/subscription endpoints) and +// translations. Translation nodes are always in between message nodes, and can have N input messages +// and M output messages. + +struct MessageIdentifier { + std::string topic_name; + MessageVersionType version; + + bool operator==(const MessageIdentifier& other) const { + return topic_name == other.topic_name && version == other.version; + } + bool operator!=(const MessageIdentifier& other) const { + return !(*this == other); + } +}; + +template<> +struct std::hash +{ + std::size_t operator()(const MessageIdentifier& s) const noexcept + { + std::size_t h1 = std::hash{}(s.topic_name); + std::size_t h2 = std::hash{}(s.version); + return h1 ^ (h2 << 1); + } +}; + + +using MessageBuffer = std::shared_ptr; + +template +class MessageNode; +template +class Graph; + +template +using MessageNodePtrT = std::shared_ptr>; + +template +class TranslationNode { +public: + using TranslationCB = std::function&, std::vector&)>; + + TranslationNode(std::vector> inputs, + std::vector> outputs, + TranslationCB translation_db) + : _inputs(std::move(inputs)), _outputs(std::move(outputs)), _translation_cb(std::move(translation_db)) { + assert(_inputs.size() <= kMaxNumInputs); + + _input_buffers.resize(_inputs.size()); + for (unsigned i = 0; i < _inputs.size(); ++i) { + _input_buffers[i] = _inputs[i]->buffer(); + } + + _output_buffers.resize(_outputs.size()); + for (unsigned i = 0; i < _outputs.size(); ++i) { + _output_buffers[i] = _outputs[i]->buffer(); + } + } + + void setInputReady(unsigned index) { + _inputs_ready.set(index); + } + + bool translate() { + if (_inputs_ready.count() == _input_buffers.size()) { + _translation_cb(_input_buffers, _output_buffers); + _inputs_ready.reset(); + return true; + } + return false; + } + + const std::vector>& inputs() const { return _inputs; } + const std::vector>& outputs() const { return _outputs; } + +private: + static constexpr int kMaxNumInputs = 32; + + const std::vector> _inputs; + std::vector _input_buffers; ///< Cached buffers from _inputs.buffer() + const std::vector> _outputs; + std::vector _output_buffers; + const TranslationCB _translation_cb; + + std::bitset _inputs_ready; +}; + +template +using TranslationNodePtrT = std::shared_ptr>; + + +template +class MessageNode { +public: + + explicit MessageNode(NodeData node_data, size_t index, MessageBuffer message_buffer) + : _buffer(std::move(message_buffer)), _data(std::move(node_data)), _index(index) {} + + MessageBuffer& buffer() { return _buffer; } + + void addTranslationInput(TranslationNodePtrT node, unsigned input_index) { + _translations.push_back(Translation{std::move(node), input_index}); + } + + NodeData& data() { return _data; } + + void resetNodes() { + _translations.clear(); + } + +private: + struct Translation { + TranslationNodePtrT node; ///< Counterpart to the TranslationNode::_inputs + unsigned input_index; ///< Index into the TranslationNode::_inputs + }; + MessageBuffer _buffer; + std::vector _translations; + + NodeData _data; + + const size_t _index; + MessageNode* _iterating_previous{nullptr}; + bool _want_translation{false}; + + friend class Graph; +}; + +template +class Graph { +public: + using MessageNodePtr = MessageNodePtrT; + + ~Graph() { + // Explicitly reset the nodes array to break up potential cycles and prevent memory leaks + for (auto& [id, node] : _nodes) { + node->resetNodes(); + } + } + + /** + * @brief Add a message node if it does not exist already + */ + bool addNodeIfNotExists(const IdType& id, NodeData node_data, const MessageBuffer& message_buffer) { + if (_nodes.find(id) != _nodes.end()) { + return false; + } + // Node that we cannot remove nodes due to using the index as an array index + const size_t index = _nodes.size(); + _nodes.insert({id, std::make_shared>(std::move(node_data), index, message_buffer)}); + return true; + } + + /** + * @brief Add a translation edge with N inputs and M output nodes. All nodes must already exist. + */ + void addTranslation(const typename TranslationNode::TranslationCB& translation_cb, + const std::vector& inputs, const std::vector& outputs) { + auto init = [this](const std::vector& from, std::vector>& to) { + for (unsigned i=0; i < from.size(); ++i) { + auto node_iter = _nodes.find(from[i]); + assert(node_iter != _nodes.end()); + to[i] = node_iter->second; + } + }; + std::vector> input_nodes(inputs.size()); + init(inputs, input_nodes); + std::vector> output_nodes(outputs.size()); + init(outputs, output_nodes); + + auto translation_node = std::make_shared>(std::move(input_nodes), std::move(output_nodes), translation_cb); + for (unsigned i=0; i < translation_node->inputs().size(); ++i) { + translation_node->inputs()[i]->addTranslationInput(translation_node, i); + } + } + + + /** + * @brief Translate a message node in the graph. + * + * This function performs a two-pass translation of a message node in the graph. + * First, it finds the required nodes that need the translation results, and then + * it runs the translation on these nodes to prevent unnecessary message conversions. + * + * @param node The message node to translate. + * @param node_requires_translation_result A callback function that determines whether a node requires the translation result. + * @param on_translated A callback function that is called for translated nodes (with an updated message buffer). + */ + void translate(const MessageNodePtr& node, const std::function& node_requires_translation_result, + const std::function& on_translated) { + // Do translation in 2 passes: first, find the required nodes that require the translation results, + // then run the translation on these nodes to prevent unnecessary message conversions + // (the assumption here is that conversions are more expensive than iterating the graph) + prepareTranslation(node, node_requires_translation_result); + runTranslation(node, on_translated); + } + + std::optional findNode(const IdType& id) const { + auto iter = _nodes.find(id); + if (iter == _nodes.end()) { + return std::nullopt; + } + return iter->second; + } + + void iterateNodes(const std::function& cb) const { + for (const auto& [id, node] : _nodes) { + cb(id, node); + } + } + + /** + * Iterate all reachable nodes from a given node using the BFS (shortest path) algorithm + */ + void iterateBFS(const MessageNodePtr& node, const std::function& cb) { + _node_visited.resize(_nodes.size()); + std::fill(_node_visited.begin(), _node_visited.end(), false); + + std::queue queue; + _node_visited[node->_index] = true; + node->_iterating_previous = nullptr; + queue.push(node); + cb(node); + + while (!queue.empty()) { + MessageNodePtr current = queue.front(); + queue.pop(); + for (auto& translation : current->_translations) { + for (auto& next_node : translation.node->outputs()) { + if (_node_visited[next_node->_index]) { + continue; + } + _node_visited[next_node->_index] = true; + next_node->_iterating_previous = current.get(); + queue.push(next_node); + + cb(next_node); + } + } + } + } + + +private: + void prepareTranslation(const MessageNodePtr& node, const std::function& node_requires_translation_result) { + iterateBFS(node, [&](const MessageNodePtr& node) { + if (node_requires_translation_result(node)) { + auto* previous_node = node.get(); + while (previous_node) { + previous_node->_want_translation = true; + previous_node = previous_node->_iterating_previous; + } + } + }); + } + + void runTranslation(const MessageNodePtr& node, const std::function& on_translated) { + _node_had_update.resize(_nodes.size()); + std::fill(_node_had_update.begin(), _node_had_update.end(), false); + _node_had_update[node->_index] = true; + + iterateBFS(node, [&](const MessageNodePtr& node) { + // If there was no update for this node, there's nothing to do (i.e. want_translation is false or the + // message buffer did not change) + if (!_node_had_update[node->_index]) { + return; + } + + on_translated(node); + + if (node->_want_translation) { + node->_want_translation = false; + for (auto &translation : node->_translations) { + // Skip translation if none of the output nodes has _want_translation set or + // if any of the nodes already had an update. + // This also prevents translating 'backwards' by one step (from where we came from) + bool want_translation = false; + bool had_update = false; + for (auto &next_node: translation.node->outputs()) { + want_translation |= next_node->_want_translation; + had_update |= _node_had_update[next_node->_index]; + } + if (!want_translation || had_update) { + continue; + } + translation.node->setInputReady(translation.input_index); + if (translation.node->translate()) { + for (auto &next_node: translation.node->outputs()) { + _node_had_update[next_node->_index] = true; + } + } + } + } + }); + } + + std::unordered_map _nodes; + std::vector _node_visited; ///< Cached, to avoid the need to re-allocate on each iteration + std::vector _node_had_update; ///< Cached, to avoid the need to re-allocate on each iteration +}; \ No newline at end of file diff --git a/translation_node/src/main.cpp b/translation_node/src/main.cpp index 6f2442c2..7a4d65d8 100644 --- a/translation_node/src/main.cpp +++ b/translation_node/src/main.cpp @@ -1,25 +1,29 @@ +/**************************************************************************** + * Copyright (c) 2024 PX4 Development Team. + * SPDX-License-Identifier: BSD-3-Clause + ****************************************************************************/ #include #include #include "vehicle_attitude_v2.h" #include "vehicle_attitude_v3.h" -#include "ros_translations.h" +#include "pub_sub_graph.h" using namespace std::chrono_literals; -class TranslationNode : public rclcpp::Node +class RosTranslationNode : public rclcpp::Node { public: - TranslationNode() : Node("translation_node") + RosTranslationNode() : Node("translation_node") { - _ros_translations = std::make_unique(*this, RegisteredTranslations::instance().translations()); + _pub_sub_graph = std::make_unique(*this, RegisteredTranslations::instance().translations()); // Monitor subscriptions & publishers // TODO: event-based _node_update_timer = create_wall_timer(1s, [this](){ - std::vector topic_info; + std::vector topic_info; const auto topics = get_topic_names_and_types(); for (const auto& [topic_name, topic_types] : topics) { auto publishers = get_publishers_info_by_topic(topic_name); @@ -35,22 +39,23 @@ class TranslationNode : public rclcpp::Node } if (num_subscribers > 0 || num_publishers > 0) { - topic_info.emplace_back(RosTranslations::TopicInfo{topic_name, num_subscribers, num_publishers}); + topic_info.emplace_back(PubSubGraph::TopicInfo{topic_name, num_subscribers, num_publishers}); } } - _ros_translations->updateCurrentTopics(topic_info); + _pub_sub_graph->updateCurrentTopics(topic_info); + }); } private: - std::unique_ptr _ros_translations; + std::unique_ptr _pub_sub_graph; rclcpp::TimerBase::SharedPtr _node_update_timer; }; int main(int argc, char * argv[]) { rclcpp::init(argc, argv); - rclcpp::spin(std::make_shared()); + rclcpp::spin(std::make_shared()); rclcpp::shutdown(); return 0; } \ No newline at end of file diff --git a/translation_node/src/pub_sub_graph.cpp b/translation_node/src/pub_sub_graph.cpp new file mode 100644 index 00000000..7657b592 --- /dev/null +++ b/translation_node/src/pub_sub_graph.cpp @@ -0,0 +1,193 @@ +/**************************************************************************** + * Copyright (c) 2024 PX4 Development Team. + * SPDX-License-Identifier: BSD-3-Clause + ****************************************************************************/ +#include "pub_sub_graph.h" +#include "util.h" + +PubSubGraph::PubSubGraph(rclcpp::Node &node, const Translations &translations) : _node(node) { + + std::unordered_map> known_versions; + + for (const auto& topic : translations.topics()) { + const std::string full_topic_name = getFullTopicName(_node.get_effective_namespace(), topic.topic_name); + _known_topics_warned.insert({full_topic_name, false}); + + const MessageIdentifier id{full_topic_name, topic.version}; + NodeDataPubSub node_data{topic.subscription_factory, topic.publication_factory, id, topic.max_serialized_message_size}; + _pub_sub_graph.addNodeIfNotExists(id, std::move(node_data), topic.message_buffer); + known_versions[full_topic_name].insert(id.version); + } + + for (const auto& translation : translations.translations()) { + std::vector inputs = translation.inputs; + for (auto& input : inputs) { + input.topic_name = getFullTopicName(_node.get_effective_namespace(), input.topic_name); + } + std::vector outputs = translation.outputs; + for (auto& output : outputs) { + output.topic_name = getFullTopicName(_node.get_effective_namespace(), output.topic_name); + } + _pub_sub_graph.addTranslation(translation.cb, inputs, outputs); + } + + printTopicInfo(known_versions); + handleLargestTopic(known_versions); +} + +void PubSubGraph::updateCurrentTopics(const std::vector &topics) { + + _pub_sub_graph.iterateNodes([](const MessageIdentifier& type, const Graph::MessageNodePtr& node) { + node->data().has_external_publisher = false; + node->data().has_external_subscriber = false; + }); + + for (const auto& info : topics) { + const auto [non_versioned_topic_name, version] = getNonVersionedTopicName(info.topic_name); + auto maybe_node = _pub_sub_graph.findNode({non_versioned_topic_name, version}); + if (!maybe_node) { + auto known_topic_iter = _known_topics_warned.find(non_versioned_topic_name); + if (known_topic_iter != _known_topics_warned.end() && !known_topic_iter->second) { + RCLCPP_WARN(_node.get_logger(), "No translation available for version %i of topic %s", version, non_versioned_topic_name.c_str()); + known_topic_iter->second = true; + } + continue; + } + const auto& node = maybe_node.value(); + + if (info.num_publishers > 0) { + node->data().has_external_publisher = true; + } + if (info.num_subscribers > 0) { + node->data().has_external_subscriber = true; + } + } + + // Iterate connected graph segments + _pub_sub_graph.iterateNodes([this](const MessageIdentifier& type, const Graph::MessageNodePtr& node) { + if (node->data().visited) { + return; + } + node->data().visited = true; + + // Count the number of external subscribers and publishers for each connected graph + int num_publishers = 0; + int num_subscribers = 0; + int num_subscribers_without_publisher = 0; + + _pub_sub_graph.iterateBFS(node, [&](const Graph::MessageNodePtr& node) { + if (node->data().has_external_publisher) { + ++num_publishers; + } + if (node->data().has_external_subscriber) { + ++num_subscribers; + if (!node->data().has_external_publisher) { + ++num_subscribers_without_publisher; + } + } + }); + + // We need to instantiate publishers and subscribers if: + // - there are multiple publishers and at least 1 subscriber + // - there is 1 publisher and at least 1 subscriber on another node + // Note that in case of splitting or merging topics, this might create more entities than actually needed + const bool require_translation = (num_publishers >= 2 && num_subscribers >= 1) + || (num_publishers == 1 && num_subscribers_without_publisher >= 1); + if (require_translation) { + _pub_sub_graph.iterateBFS(node, [&](const Graph::MessageNodePtr& node) { + // Has subscriber(s)? + if (node->data().has_external_subscriber && !node->data().publication) { + RCLCPP_INFO(_node.get_logger(), "Found subscriber for topic '%s', version: %i, adding publisher", node->data().topic_name.c_str(), node->data().version); + node->data().publication = node->data().publication_factory(_node); + } else if (!node->data().has_external_subscriber && node->data().publication) { + RCLCPP_INFO(_node.get_logger(), "No subscribers for topic '%s', version: %i, removing publisher", node->data().topic_name.c_str(), node->data().version); + node->data().publication.reset(); + } + // Has publisher(s)? + if (node->data().has_external_publisher && !node->data().subscription) { + RCLCPP_INFO(_node.get_logger(), "Found publisher for topic '%s', version: %i, adding subscriber", node->data().topic_name.c_str(), node->data().version); + node->data().subscription = node->data().subscription_factory(_node, [this, node_cpy=node]() { + onSubscriptionUpdate(node_cpy); + }); + } else if (!node->data().has_external_publisher && node->data().subscription) { + RCLCPP_INFO(_node.get_logger(), "No publishers for topic '%s', version: %i, removing subscriber", node->data().topic_name.c_str(), node->data().version); + node->data().subscription.reset(); + } + }); + + } else { + // Reset any publishers or subscribers + _pub_sub_graph.iterateBFS(node, [&](const Graph::MessageNodePtr& node) { + if (node->data().publication) { + RCLCPP_INFO(_node.get_logger(), "Removing publisher for topic '%s', version: %i", + node->data().topic_name.c_str(), node->data().version); + node->data().publication.reset(); + } + if (node->data().subscription) { + RCLCPP_INFO(_node.get_logger(), "Removing subscriber for topic '%s', version: %i", + node->data().topic_name.c_str(), node->data().version); + node->data().subscription.reset(); + } + }); + } + }); + _pub_sub_graph.iterateNodes([](const MessageIdentifier& type, const Graph::MessageNodePtr& node) { + node->data().visited = false; + }); +} + +void PubSubGraph::onSubscriptionUpdate(const Graph::MessageNodePtr& node) { + _pub_sub_graph.translate( + node, + [](const Graph::MessageNodePtr& node) { + return node->data().publication != nullptr; + }, + [](const Graph::MessageNodePtr& node) { + if (node->data().publication != nullptr) { + rcl_publish(node->data().publication->get_publisher_handle().get(), + node->buffer().get(), nullptr); + } + }); + +} + +void PubSubGraph::printTopicInfo(const std::unordered_map>& known_versions) const { + // Print info about known versions + RCLCPP_INFO(_node.get_logger(), "Registered pub/sub topics and versions:"); + for (const auto& [topic_name, version_set] : known_versions) { + if (version_set.empty()) { + continue; + } + const std::string versions = std::accumulate(std::next(version_set.begin()), version_set.end(), + std::to_string(*version_set.begin()), // start with first element + [](std::string a, auto&& b) { + return std::move(a) + ", " + std::to_string(b); + }); + RCLCPP_INFO(_node.get_logger(), "- %s: %s", topic_name.c_str(), versions.c_str()); + } +} + + +void PubSubGraph::handleLargestTopic(const std::unordered_map> &known_versions) { + // FastDDS caches some type information per DDS participant when first creating a publisher or subscriber for a given + // type. The information that is relevant for us is the maximum serialized message size. + // Since different versions can have different sizes, we need to ensure the first publication or subscription + // happens with the version of the largest size. Otherwise, an out-of-memory exception can be triggered. + // And the type must continue to be in use (so we cannot delete it) + for (const auto& [topic_name, versions] : known_versions) { + size_t max_serialized_message_size = 0; + const PublicationFactoryCB* publication_factory_for_max = nullptr; + for (auto version : versions) { + const auto& node = _pub_sub_graph.findNode(MessageIdentifier{topic_name, version}); + assert(node); + const auto& node_data = node.value()->data(); + if (node_data.max_serialized_message_size > max_serialized_message_size) { + max_serialized_message_size = node_data.max_serialized_message_size; + publication_factory_for_max = &node_data.publication_factory; + } + } + if (publication_factory_for_max) { + _largest_topic_publications.emplace_back((*publication_factory_for_max)(_node)); + } + } +} \ No newline at end of file diff --git a/translation_node/src/pub_sub_graph.h b/translation_node/src/pub_sub_graph.h new file mode 100644 index 00000000..8294baec --- /dev/null +++ b/translation_node/src/pub_sub_graph.h @@ -0,0 +1,58 @@ +/**************************************************************************** + * Copyright (c) 2024 PX4 Development Team. + * SPDX-License-Identifier: BSD-3-Clause + ****************************************************************************/ +#pragma once + +#include +#include +#include "translations.h" +#include "translation_util.h" +#include "graph.h" + +class PubSubGraph { +public: + struct TopicInfo { + std::string topic_name; ///< fully qualified topic name (with namespace) + int num_subscribers; ///< does not include this node's subscribers + int num_publishers; ///< does not include this node's publishers + }; + + PubSubGraph(rclcpp::Node& node, const Translations& translations); + + void updateCurrentTopics(const std::vector& topics); + +private: + struct NodeDataPubSub { + explicit NodeDataPubSub(SubscriptionFactoryCB subscription_factory, PublicationFactoryCB publication_factory, + const MessageIdentifier& id, size_t max_serialized_message_size) + : subscription_factory(std::move(subscription_factory)), publication_factory(std::move(publication_factory)), + topic_name(id.topic_name), version(id.version), max_serialized_message_size(max_serialized_message_size) + { } + + const SubscriptionFactoryCB subscription_factory; + const PublicationFactoryCB publication_factory; + const std::string topic_name; + const MessageVersionType version; + const size_t max_serialized_message_size; + + // Keep track if there's currently a publisher/subscriber + bool has_external_publisher{false}; + bool has_external_subscriber{false}; + + rclcpp::SubscriptionBase::SharedPtr subscription; + rclcpp::PublisherBase::SharedPtr publication; + + bool visited{false}; + }; + + void onSubscriptionUpdate(const Graph::MessageNodePtr& node); + void printTopicInfo(const std::unordered_map>& known_versions) const; + void handleLargestTopic(const std::unordered_map>& known_versions); + + rclcpp::Node& _node; + Graph _pub_sub_graph; + std::unordered_map _known_topics_warned; + + std::vector _largest_topic_publications; +}; \ No newline at end of file diff --git a/translation_node/src/ros_translations.cpp b/translation_node/src/ros_translations.cpp deleted file mode 100644 index c5a055c7..00000000 --- a/translation_node/src/ros_translations.cpp +++ /dev/null @@ -1,189 +0,0 @@ -#include "ros_translations.h" - -static std::string getFullTopicName(const std::string& namespace_name, const std::string& topic_name) { - std::string full_topic_name = topic_name; - if (!full_topic_name.empty() && full_topic_name[0] != '/') { - if (namespace_name.empty() || namespace_name.back() != '/') { - full_topic_name = '/' + full_topic_name; - } - full_topic_name = namespace_name + full_topic_name; - } - return full_topic_name; -} - -RosTranslationForTopic::RosTranslationForTopic(rclcpp::Node& node, std::string topic_name, const TranslationForTopic &translation) - : _node(node), _topic_name(std::move(topic_name)) { - assert(!translation.directTranslations().empty()); - // Find oldest version - auto version_iter = std::min_element(translation.directTranslations().begin(), translation.directTranslations().end(), - [](auto && a, auto&& b) { - return a.newer.version < b.newer.version; - }); - // Build version chain, use the newer element as the value for the node - // So we need to add the older element for the lowest version first - { - VersionEntry entry{version_iter->older}; - _versions.push_back(std::move(entry)); - } - - while (version_iter != translation.directTranslations().end()) { - // Add current version - VersionEntry entry{version_iter->newer}; - entry.translation_cb_to_older = version_iter->translation_cb_to_older; - entry.translation_cb_from_older = version_iter->translation_cb_from_older; - _versions.push_back(std::move(entry)); - - // Find next version (assumes there is no cycle in the definitions, otherwise there's an endless loop) - version_iter = std::find_if(translation.directTranslations().begin(), translation.directTranslations().end(), - [version_iter](auto&& a) { - return version_iter->newer.version == a.older.version; - }); - } - - if (_versions.size() != translation.directTranslations().size() + 1) { - // This means there is a gap in the versions - throw std::runtime_error(std::string("non-continuous versions for topic") + _topic_name); - } - - handleLargestTopic(); -} - -void RosTranslationForTopic::updateSubsAndPubs() { - const bool has_subscriber = std::any_of(_versions.begin(), _versions.end(), [](auto&& a) { - return a.has_external_subscriber; - }); - const bool has_publisher = std::any_of(_versions.begin(), _versions.end(), [](auto&& a) { - return a.has_external_publisher; - }); - if (has_subscriber && has_publisher) { - // TODO: do not add anything if there's only a subscriber & publisher for a specific version - - for (unsigned index = 0; index < _versions.size(); ++index) { - auto& version = _versions[index]; - // Has subscriber(s)? - if (version.has_external_subscriber && !version.publication) { - RCLCPP_INFO(_node.get_logger(), "Found subscriber for topic '%s', version: %i, adding publisher", _topic_name.c_str(), version.version); - version.publication = version.publication_factory(_node); - } else if (!version.has_external_subscriber && version.publication) { - RCLCPP_INFO(_node.get_logger(), "No subscribers for topic '%s', version: %i, removing publisher", _topic_name.c_str(), version.version); - version.publication.reset(); - } - // Has publisher(s)? - if (version.has_external_publisher && !version.subscription) { - RCLCPP_INFO(_node.get_logger(), "Found publisher for topic '%s', version: %i, adding subscriber", _topic_name.c_str(), version.version); - version.subscription = version.subscription_factory(_node, [this, index](void* data) { - onSubscriptionUpdated(index, data); - }); - } else if (!version.has_external_publisher && version.subscription) { - RCLCPP_INFO(_node.get_logger(), "No publishers for topic '%s', version: %i, removing subscriber", _topic_name.c_str(), version.version); - version.subscription.reset(); - } - - } - } else { - // Clear all - for (auto& version : _versions) { - version.publication.reset(); - version.subscription.reset(); - } - } -} - -void RosTranslationForTopic::onSubscriptionUpdated(unsigned version_index, void *data) { - const auto& entry = _versions[version_index]; - const auto lowest_publisher_iter = std::find_if(_versions.begin(), _versions.end(), [](auto&& a) { return a.publication != nullptr; }); - // Convert to lower versions - if (lowest_publisher_iter != _versions.end()) { - const unsigned lowest_index = std::distance(_versions.begin(), lowest_publisher_iter); - void* current_data = data; - for (unsigned index = version_index; index > lowest_index; --index) { - // Convert message - _versions[index].translation_cb_to_older(current_data, _versions[index-1].message_buffer.get()); - current_data = _versions[index-1].message_buffer.get(); - // Publish if there is a publisher - if (_versions[index-1].publication) { - rcl_publish(_versions[index-1].publication->get_publisher_handle().get(), current_data, nullptr); - } - } - } - // Convert to higher versions - const auto highest_publisher_iter = std::find_if(_versions.rbegin(), _versions.rend(), [](auto&& a) { return a.publication != nullptr; }); - if (highest_publisher_iter != _versions.rend()) { - const unsigned highest_index = std::distance(highest_publisher_iter, _versions.rend()) - 1; - void* current_data = data; - for (unsigned index = version_index; index < highest_index; ++index) { - // Convert message - _versions[index + 1].translation_cb_from_older(current_data, _versions[index+1].message_buffer.get()); - current_data = _versions[index+1].message_buffer.get(); - // Publish if there is a publisher - if (_versions[index+1].publication) { - rcl_publish(_versions[index+1].publication->get_publisher_handle().get(), current_data, nullptr); - } - } - } -} - -void RosTranslationForTopic::handleLargestTopic() { - // FastDDS caches some type information per DDS participant when first creating a publisher or subscriber for a given - // type. The information that is relevant for us is the maximum serialized message size. - // Since different versions can have different sizes, we need to ensure the first publication or subscription - // happens with the version of the largest size. Otherwise, an out-of-memory exception can be triggered. - // And the type must continue to be in use (so we cannot delete it) - auto version_iter = std::max_element(_versions.begin(), _versions.end(), - [](auto && a, auto&& b) { - return a.max_serialized_message_size < b.max_serialized_message_size; - }); - _publisher_largest_topic = version_iter->publication_factory(_node); -} - -RosTranslations::RosTranslations(rclcpp::Node &node, const Translations &translations) - : _node(node) { - - for (const auto& [topic_name, translation]: translations.topicTranslations()) { - const std::string full_topic_name = getFullTopicName(node.get_effective_namespace(), topic_name); - auto [iter, _] = _topics.emplace(full_topic_name, RosTranslationForTopic{_node, full_topic_name, translation}); - - // Print versions info - const std::string versions = std::accumulate(std::next(iter->second.versions().begin()), iter->second.versions().end(), - std::to_string(iter->second.versions()[0].version), // start with first element - [](std::string a, auto&& b) { - return std::move(a) + ", " + std::to_string(b.version); - }); - RCLCPP_INFO(_node.get_logger(), "Versions for topic '%s': %s", iter->second.topicName().c_str(), versions.c_str()); - } -} - -void RosTranslations::updateCurrentTopics(const std::vector &topics) { - for (auto& [topic_name, translation] : _topics) { - translation.resetExternalSubPub(); - } - for (const auto& topic_info : topics) { - const auto [non_versioned_topic_name, version] = getNonVersionedTopicName(topic_info.topic_name); - auto iter = _topics.find(non_versioned_topic_name); - if (iter == _topics.end()) { - continue; - } - // It's a topic we're interested in, find the version - bool found_version = false; - for (auto& entry : iter->second.versions()) { - if (entry.version == version) { - found_version = true; - if (topic_info.num_publishers > 0) { - entry.has_external_publisher = true; - } - if (topic_info.num_subscribers > 0) { - entry.has_external_subscriber = true; - } - } - } - if (!found_version && !iter->second.getAndSetErrorPrinted()) { - RCLCPP_WARN(_node.get_logger(), "Unsupported version for topic '%s': %i", non_versioned_topic_name.c_str(), version); - } - } - - // Now update the subscriptions / publishers depending on what we found - for (auto& [topic_name, translation] : _topics) { - translation.updateSubsAndPubs(); - } -} - diff --git a/translation_node/src/ros_translations.h b/translation_node/src/ros_translations.h deleted file mode 100644 index 523e44d8..00000000 --- a/translation_node/src/ros_translations.h +++ /dev/null @@ -1,84 +0,0 @@ -#pragma once - -#include -#include "translations.h" - -class RosTranslationVersion { -public: -private: - rclcpp::SubscriptionBase::SharedPtr _subscriber; - rclcpp::PublisherBase::SharedPtr _publisher; -}; - -class RosTranslationForTopic { -public: - struct VersionEntry { - VersionEntry() = default; - explicit VersionEntry(const DirectTranslationData::Version& version) - : version(version.version), message_buffer(version.message_buffer), - max_serialized_message_size(version.max_serialized_message_size), - subscription_factory(version.subscription_factory), publication_factory(version.publication_factory) {} - - MessageVersionType version; ///< corresponds to the 'newer' version in the translation - DirectTranslationCB translation_cb_from_older; - DirectTranslationCB translation_cb_to_older; - std::shared_ptr message_buffer; - size_t max_serialized_message_size{}; - - SubscriptionFactoryCB subscription_factory; - PublicationFactoryCB publication_factory; - - // Keep track if there's currently a publisher/subscriber - bool has_external_publisher{false}; - bool has_external_subscriber{false}; - - rclcpp::SubscriptionBase::SharedPtr subscription; - rclcpp::PublisherBase::SharedPtr publication; - }; - explicit RosTranslationForTopic(rclcpp::Node& node, std::string topic_name, const TranslationForTopic& translation); - - const std::string& topicName() const { return _topic_name; } - std::vector& versions() { return _versions; } - - void resetExternalSubPub() { - for (auto& entry : _versions) { - entry.has_external_publisher = false; - entry.has_external_subscriber = false; - } - } - - void updateSubsAndPubs(); - - bool getAndSetErrorPrinted() { - const bool error_printed = _error_printed; - _error_printed = true; - return error_printed; - } - -private: - void onSubscriptionUpdated(unsigned version_index, void* data); - void handleLargestTopic(); - - rclcpp::Node& _node; - const std::string _topic_name; - std::vector _versions; - bool _error_printed{false}; ///< ensure errors are printed only once - - std::shared_ptr _publisher_largest_topic; -}; - -class RosTranslations { -public: - struct TopicInfo { - std::string topic_name; ///< fully qualified topic name (with namespace) - int num_subscribers; ///< does not include this node's subscribers - int num_publishers; ///< does not include this node's publishers - }; - - explicit RosTranslations(rclcpp::Node& node, const Translations& translations); - - void updateCurrentTopics(const std::vector& topics); -private: - rclcpp::Node& _node; - std::unordered_map _topics; -}; \ No newline at end of file diff --git a/translation_node/src/translation_util.h b/translation_node/src/translation_util.h index 5a7bcc89..3a896c58 100644 --- a/translation_node/src/translation_util.h +++ b/translation_node/src/translation_util.h @@ -1,3 +1,7 @@ +/**************************************************************************** + * Copyright (c) 2024 PX4 Development Team. + * SPDX-License-Identifier: BSD-3-Clause + ****************************************************************************/ #pragma once #include "translations.h" @@ -19,20 +23,23 @@ class RegisteredTranslations { } template - void registerTranslation(const std::string& topic_name) { - DirectTranslationData data{}; - data.older = getVersionForMessageType(topic_name); - data.newer = getVersionForMessageType(topic_name); + void registerDirectTranslation(const std::string& topic_name) { + _translations.addTopic(getTopicForMessageType(topic_name)); + _translations.addTopic(getTopicForMessageType(topic_name)); // Translation callbacks - data.translation_cb_from_older = [](const void* older_msg, void* newer_msg) { - T::fromOlder(*(const typename T::MessageOlder*)older_msg, *(typename T::MessageNewer*)newer_msg); + auto translation_cb_from_older = [](const std::vector& older_msg, std::vector& newer_msg) { + T::fromOlder(*(const typename T::MessageOlder*)older_msg[0].get(), *(typename T::MessageNewer*)newer_msg[0].get()); }; - data.translation_cb_to_older = [](const void* newer_msg, void* older_msg) { - T::toOlder(*(const typename T::MessageNewer*)newer_msg, *(typename T::MessageOlder*)older_msg); + auto translation_cb_to_older = [](const std::vector& newer_msg, std::vector& older_msg) { + T::toOlder(*(const typename T::MessageNewer*)newer_msg[0].get(), *(typename T::MessageOlder*)older_msg[0].get()); }; - - _translations.registerDirectTranslation(topic_name, std::move(data)); + _translations.addTranslation({translation_cb_from_older, + {MessageIdentifier{topic_name, T::MessageOlder::MESSAGE_VERSION}}, + {MessageIdentifier{topic_name, T::MessageNewer::MESSAGE_VERSION}}}); + _translations.addTranslation({translation_cb_to_older, + {MessageIdentifier{topic_name, T::MessageNewer::MESSAGE_VERSION}}, + {MessageIdentifier{topic_name, T::MessageOlder::MESSAGE_VERSION}}}); } const Translations& translations() const { return _translations; } @@ -41,19 +48,22 @@ class RegisteredTranslations { RegisteredTranslations() = default; template - DirectTranslationData::Version getVersionForMessageType(const std::string& topic_name) { - DirectTranslationData::Version ret{}; + Topic getTopicForMessageType(const std::string& topic_name) { + Topic ret{}; + ret.topic_name = topic_name; ret.version = RosMessageType::MESSAGE_VERSION; - ret.message_buffer = std::static_pointer_cast(std::make_shared()); + auto message_buffer = std::make_shared(); + ret.message_buffer = std::static_pointer_cast(message_buffer); // Subscription/Publication factory methods const std::string topic_name_versioned = getVersionedTopicName(topic_name, ret.version); - ret.subscription_factory = [topic_name_versioned](rclcpp::Node& node, - const std::function& on_topic_cb) -> rclcpp::SubscriptionBase::SharedPtr { + ret.subscription_factory = [topic_name_versioned, message_buffer](rclcpp::Node& node, + const std::function& on_topic_cb) -> rclcpp::SubscriptionBase::SharedPtr { return std::dynamic_pointer_cast( node.create_subscription(topic_name_versioned, rclcpp::QoS(1).best_effort(), - [on_topic_cb=on_topic_cb](typename RosMessageType::UniquePtr msg) -> void { - on_topic_cb(msg.get()); + [on_topic_cb=on_topic_cb, message_buffer](typename RosMessageType::UniquePtr msg) -> void { + *message_buffer = *msg; + on_topic_cb(); })); }; ret.publication_factory = [topic_name_versioned](rclcpp::Node& node) -> rclcpp::PublisherBase::SharedPtr { @@ -76,15 +86,15 @@ class RegisteredTranslations { }; template -class RegistrationHelper { +class RegistrationHelperDirect { public: - explicit RegistrationHelper(const std::string& topic_name) { - RegisteredTranslations::instance().registerTranslation(topic_name); + explicit RegistrationHelperDirect(const std::string& topic_name) { + RegisteredTranslations::instance().registerDirectTranslation(topic_name); } - RegistrationHelper(RegistrationHelper const&) = delete; - void operator=(RegistrationHelper const&) = delete; + RegistrationHelperDirect(RegistrationHelperDirect const&) = delete; + void operator=(RegistrationHelperDirect const&) = delete; private: }; #define REGISTER_MESSAGE_TRANSLATION_DIRECT(topic_name, class_name) \ - static RegistrationHelper class_name##_registration(topic_name); + static RegistrationHelperDirect class_name##_registration(topic_name); diff --git a/translation_node/src/translations.cpp b/translation_node/src/translations.cpp index 01a3552a..0cca2d5a 100644 --- a/translation_node/src/translations.cpp +++ b/translation_node/src/translations.cpp @@ -1,14 +1,6 @@ +/**************************************************************************** + * Copyright (c) 2024 PX4 Development Team. + * SPDX-License-Identifier: BSD-3-Clause + ****************************************************************************/ #include "translations.h" -void TranslationForTopic::registerVersion(DirectTranslationData data) { - _direct_translations.emplace_back(std::move(data)); -} - -void Translations::registerDirectTranslation(const std::string &topic_name, DirectTranslationData data) { - auto iter = _topic_translations.find(topic_name); - if (iter == _topic_translations.end()) { - auto [iter_inserted, _] = _topic_translations.emplace(topic_name, TranslationForTopic(topic_name)); - iter = iter_inserted; - } - iter->second.registerVersion(std::move(data)); -} diff --git a/translation_node/src/translations.h b/translation_node/src/translations.h index 73ada193..5b5d0497 100644 --- a/translation_node/src/translations.h +++ b/translation_node/src/translations.h @@ -1,3 +1,7 @@ +/**************************************************************************** + * Copyright (c) 2024 PX4 Development Team. + * SPDX-License-Identifier: BSD-3-Clause + ****************************************************************************/ #pragma once #include @@ -9,53 +13,42 @@ #include #include "util.h" +#include "graph.h" #include -using DirectTranslationCB = std::function; -using SubscriptionFactoryCB = std::function& on_topic_cb)>; +using TranslationCB = std::function&, std::vector&)>; +using SubscriptionFactoryCB = std::function& on_topic_cb)>; using PublicationFactoryCB = std::function; -struct DirectTranslationData { - struct Version { - MessageVersionType version; - std::shared_ptr message_buffer; - SubscriptionFactoryCB subscription_factory; - PublicationFactoryCB publication_factory; - size_t max_serialized_message_size{}; - }; +struct Topic { + std::string topic_name; + MessageVersionType version{}; - Version older; - Version newer; + SubscriptionFactoryCB subscription_factory; + PublicationFactoryCB publication_factory; - DirectTranslationCB translation_cb_from_older; - DirectTranslationCB translation_cb_to_older; + std::shared_ptr message_buffer; + size_t max_serialized_message_size{}; }; -class TranslationForTopic { -public: - explicit TranslationForTopic(std::string topic_name="") : _topic_name(std::move(topic_name)) {} - - void registerVersion(DirectTranslationData data); - - const std::string& topicName() const { return _topic_name; } - const std::vector& directTranslations() const { return _direct_translations; }; - -private: - const std::string _topic_name; - std::vector _direct_translations; +struct Translation { + TranslationCB cb; + std::vector inputs; + std::vector outputs; }; class Translations { public: - Translations() = default; - void registerDirectTranslation(const std::string& topic_name, DirectTranslationData data); - - const std::unordered_map& topicTranslations() const { return _topic_translations; } + void addTopic(Topic topic) { _topics.push_back(std::move(topic)); } + void addTranslation(Translation translation) { _translations.push_back(std::move(translation)); } + const std::vector& topics() const { return _topics; } + const std::vector& translations() const { return _translations; } private: - std::unordered_map _topic_translations; + std::vector _topics; + std::vector _translations; }; diff --git a/translation_node/src/util.h b/translation_node/src/util.h index a4f61da9..02aad1d3 100644 --- a/translation_node/src/util.h +++ b/translation_node/src/util.h @@ -1,3 +1,7 @@ +/**************************************************************************** + * Copyright (c) 2024 PX4 Development Team. + * SPDX-License-Identifier: BSD-3-Clause + ****************************************************************************/ #pragma once #include @@ -30,3 +34,19 @@ static inline std::pair getNonVersionedTopicNam } return std::make_pair(non_versioned_topic_name, std::stol(version)); } + +/** + * Get the full topic name, including namespace from a topic name. + * namespace_name should be set to Node::get_effective_namespace() + */ +static inline std::string getFullTopicName(const std::string& namespace_name, const std::string& topic_name) { + std::string full_topic_name = topic_name; + if (!full_topic_name.empty() && full_topic_name[0] != '/') { + if (namespace_name.empty() || namespace_name.back() != '/') { + full_topic_name = '/' + full_topic_name; + } + full_topic_name = namespace_name + full_topic_name; + } + return full_topic_name; +} + diff --git a/translation_node/src/vehicle_attitude_v2.h b/translation_node/src/vehicle_attitude_v2.h index 4675d7ce..eefc2d72 100644 --- a/translation_node/src/vehicle_attitude_v2.h +++ b/translation_node/src/vehicle_attitude_v2.h @@ -1,3 +1,7 @@ +/**************************************************************************** + * Copyright (c) 2024 PX4 Development Team. + * SPDX-License-Identifier: BSD-3-Clause + ****************************************************************************/ #pragma once // Translate VehicleAttitude v1 <--> v2 diff --git a/translation_node/src/vehicle_attitude_v3.h b/translation_node/src/vehicle_attitude_v3.h index 40c7d811..53fd06d9 100644 --- a/translation_node/src/vehicle_attitude_v3.h +++ b/translation_node/src/vehicle_attitude_v3.h @@ -1,3 +1,7 @@ +/**************************************************************************** + * Copyright (c) 2024 PX4 Development Team. + * SPDX-License-Identifier: BSD-3-Clause + ****************************************************************************/ #pragma once // Translate VehicleAttitude v2 <--> v3 diff --git a/translation_node/test/graph.cpp b/translation_node/test/graph.cpp new file mode 100644 index 00000000..7a7f25ef --- /dev/null +++ b/translation_node/test/graph.cpp @@ -0,0 +1,355 @@ +/**************************************************************************** + * Copyright (c) 2024 PX4 Development Team. + * SPDX-License-Identifier: BSD-3-Clause + ****************************************************************************/ + +#include +#include + + +TEST(graph, basic) +{ + struct NodeData { + bool iterated{false}; + bool translated{false}; + }; + Graph graph; + + const int32_t message1_value = 3; + const int32_t offset = 4; + + // Add 2 nodes + const MessageIdentifier id1{"topic_name", 1}; + auto buffer1 = std::make_shared(); + *buffer1 = message1_value; + EXPECT_TRUE(graph.addNodeIfNotExists(id1, {}, buffer1)); + EXPECT_FALSE(graph.addNodeIfNotExists(id1, {}, std::make_shared())); + const MessageIdentifier id2{"topic_name", 4}; + auto buffer2 = std::make_shared(); + *buffer2 = 773; + EXPECT_TRUE(graph.addNodeIfNotExists(id2, {}, buffer2)); + + // Search nodes + EXPECT_TRUE(graph.findNode(id1).has_value()); + EXPECT_TRUE(graph.findNode(id2).has_value()); + + // Add 1 translation + auto translation_cb = [&offset](const std::vector& a, std::vector& b) { + auto a_value = static_cast(a[0].get()); + auto b_value = static_cast(b[0].get()); + *b_value = *a_value + offset; + }; + graph.addTranslation(translation_cb, {id1}, {id2}); + + // Iteration from id1 must reach id2 + auto node1 = graph.findNode(id1).value(); + auto node2 = graph.findNode(id2).value(); + auto iterate_cb = [](const Graph::MessageNodePtr& node) { + node->data().iterated = true; + }; + graph.iterateBFS(node1, iterate_cb); + EXPECT_TRUE(node1->data().iterated); + EXPECT_TRUE(node2->data().iterated); + node1->data().iterated = false; + node2->data().iterated = false; + + // Iteration from id2 must not reach id1 + graph.iterateBFS(node2, iterate_cb); + EXPECT_FALSE(node1->data().iterated); + EXPECT_TRUE(node2->data().iterated); + + // Test translation + graph.translate(node1, [](auto&& node) { return true; }, + [](auto&& node) { + node->data().translated = true; + }); + EXPECT_TRUE(node1->data().translated); + EXPECT_TRUE(node2->data().translated); + EXPECT_EQ(*buffer1, message1_value); + EXPECT_EQ(*buffer2, message1_value + offset); +} + + +TEST(graph, multi_path) +{ + // Multiple paths with cycles + struct NodeData { + unsigned iterated_idx{0}; + bool translated{false}; + }; + Graph graph; + + static constexpr unsigned num_nodes = 6; + std::array ids{{ + {"topic_name", 1}, + {"topic_name", 2}, + {"topic_name", 3}, + {"topic_name", 4}, + {"topic_name", 5}, + {"topic_name", 6}, + }}; + + std::array, num_nodes> buffers{{ + std::make_shared(), + std::make_shared(), + std::make_shared(), + std::make_shared(), + std::make_shared(), + std::make_shared(), + }}; + + // Nodes + for (unsigned i = 0; i < num_nodes; ++i) { + EXPECT_TRUE(graph.addNodeIfNotExists(ids[i], {}, buffers[i])); + } + + // Translations + std::bitset<32> translated; + + auto get_translation_cb = [&translated](unsigned bit) { + auto translation_cb = [&translated, bit](const std::vector &a, std::vector &b) { + auto a_value = static_cast(a[0].get()); + auto b_value = static_cast(b[0].get()); + *b_value = *a_value | (1 << bit); + translated.set(bit); + }; + return translation_cb; + }; + + // Graph: + // ___ 2 -- 3 -- 4 + // | | + // 1 _______| + // | + // 5 + // | + // 6 + + unsigned next_bit = 0; + // Connect each node to the previous and next, except the last 3 + for (unsigned i=0; i < num_nodes - 3; ++i) { + graph.addTranslation(get_translation_cb(next_bit++), {ids[i]}, {ids[i+1]}); + graph.addTranslation(get_translation_cb(next_bit++), {ids[i+1]}, {ids[i]}); + } + + // Connect the first to the 3rd as well + graph.addTranslation(get_translation_cb(next_bit++), {ids[0]}, {ids[2]}); + graph.addTranslation(get_translation_cb(next_bit++), {ids[2]}, {ids[0]}); + + // Connect the second last to the first one + graph.addTranslation(get_translation_cb(next_bit++), {ids[0]}, {ids[num_nodes-2]}); + graph.addTranslation(get_translation_cb(next_bit++), {ids[num_nodes-2]}, {ids[0]}); + + // Connect the second last to the last one + graph.addTranslation(get_translation_cb(next_bit++), {ids[num_nodes-1]}, {ids[num_nodes-2]}); + graph.addTranslation(get_translation_cb(next_bit++), {ids[num_nodes-2]}, {ids[num_nodes-1]}); + + unsigned iteration_idx = 1; + graph.iterateBFS(graph.findNode(ids[0]).value(), [&iteration_idx](const Graph::MessageNodePtr& node) { + assert(node->data().iterated_idx == 0); + node->data().iterated_idx = iteration_idx++; + }); + + EXPECT_EQ(graph.findNode(ids[0]).value()->data().iterated_idx, 1); + // We're a bit stricter than we would have to be: ids[1,2,4] would be allowed to have any of the values (2,3,4) + EXPECT_EQ(graph.findNode(ids[1]).value()->data().iterated_idx, 2); + EXPECT_EQ(graph.findNode(ids[2]).value()->data().iterated_idx, 3); + EXPECT_EQ(graph.findNode(ids[4]).value()->data().iterated_idx, 4); + EXPECT_EQ(graph.findNode(ids[3]).value()->data().iterated_idx, 5); + + + // Translation + graph.translate(graph.findNode(ids[0]).value(), + [&](auto&& node) { + // Skip the last 2 nodes + return node.get() != graph.findNode(ids[num_nodes-1]).value().get() && + node.get() != graph.findNode(ids[num_nodes-2]).value().get(); + }, + [](auto&& node) { + node->data().translated = true; + }); + + // Last 2 nodes should not be translated, the rest should be + for (unsigned i =0; i < num_nodes-2; ++i) { + EXPECT_EQ(graph.findNode(ids[i]).value()->data().translated, true); + } + EXPECT_EQ(graph.findNode(ids[num_nodes-2]).value()->data().translated, false); + EXPECT_EQ(graph.findNode(ids[num_nodes-1]).value()->data().translated, false); + + // Ensure the correct edges were used for translations + EXPECT_EQ("00000000000000000000000001010001", translated.to_string()); + + // Ensure correct translation path taken for each node (which is stored in the buffers), + // and translation callback got called + EXPECT_EQ(*buffers[0], 0); + EXPECT_EQ(*buffers[1], 0b1); + EXPECT_EQ(*buffers[2], 0b1000000); + EXPECT_EQ(*buffers[3], 0b1010000); + EXPECT_EQ(*buffers[4], 0); + EXPECT_EQ(*buffers[5], 0); + + for (unsigned i=0; i < num_nodes; ++i) { + printf("node[%i]: translated: %i, buffer: %i\n", i, graph.findNode(ids[i]).value()->data().translated, + *buffers[i]); + } +} + +TEST(graph, multi_links) { + // Multiple topics (merging / splitting) + struct NodeData { + bool translated{false}; + }; + Graph graph; + + static constexpr unsigned num_nodes = 6; + std::array ids{{ + {"topic1", 1}, + {"topic2", 1}, + {"topic1", 2}, + {"topic3", 1}, + {"topic4", 1}, + {"topic1", 3}, + }}; + + std::array, num_nodes> buffers{{ + std::make_shared(), + std::make_shared(), + std::make_shared(), + std::make_shared(), + std::make_shared(), + std::make_shared(), + }}; + + // Nodes + for (unsigned i = 0; i < num_nodes; ++i) { + EXPECT_TRUE(graph.addNodeIfNotExists(ids[i], {}, buffers[i])); + } + + + // Graph + // ___ + // 1 - | | --- + // | | - 3 - | | - 6 + // 2 - | | --- + // | --- + // | ___ + // --- | | - 4 + // | | - 5 + // --- + + // Translations + auto translation_cb_merge = [](const std::vector &a, std::vector &b) { + assert(a.size() == 2); + assert(b.size() == 1); + auto a_value1 = static_cast(a[0].get()); + auto a_value2 = static_cast(a[1].get()); + auto b_value = static_cast(b[0].get()); + *b_value = *a_value1 | *a_value2; + }; + auto translation_cb_split = [](const std::vector &a, std::vector &b) { + assert(a.size() == 1); + assert(b.size() == 2); + auto a_value = static_cast(a[0].get()); + auto b_value1 = static_cast(b[0].get()); + auto b_value2 = static_cast(b[1].get()); + *b_value1 = *a_value & 0x0000ffffu; + *b_value2 = *a_value & 0xffff0000u; + }; + auto translation_cb_direct = [](const std::vector &a, std::vector &b) { + assert(a.size() == 1); + assert(b.size() == 1); + auto a_value = static_cast(a[0].get()); + auto b_value = static_cast(b[0].get()); + *b_value = *a_value; + }; + + auto addTranslation = [&](const std::vector& inputs, const std::vector& outputs) { + assert(inputs.size() <= 2); + assert(outputs.size() <= 2); + if (inputs.size() == 1) { + if (outputs.size() == 1) { + graph.addTranslation(translation_cb_direct, inputs, outputs); + graph.addTranslation(translation_cb_direct, outputs, inputs); + } else { + graph.addTranslation(translation_cb_split, inputs, outputs); + graph.addTranslation(translation_cb_merge, outputs, inputs); + } + } else { + assert(outputs.size() == 1); + graph.addTranslation(translation_cb_merge, inputs, outputs); + graph.addTranslation(translation_cb_split, outputs, inputs); + } + }; + addTranslation({ids[0], ids[1]}, {ids[2]}); + addTranslation({ids[1]}, {ids[3], ids[4]}); + addTranslation({ids[2]}, {ids[5]}); + + auto translate_node = [&](const MessageIdentifier& id) { + graph.translate(graph.findNode(id).value(), + [&](auto&& node) { return true; }, + [](auto&& node) { + node->data().translated = true; + }); + + }; + auto reset_translated = [&]() { + for (const auto& id : ids) { + graph.findNode(id).value()->data().translated = false; + } + }; + + // Updating node 2 should trigger an output for nodes 4+5 (splitting) + *buffers[0] = 0xa00000b0; + *buffers[1] = 0x0f00000f; + translate_node(ids[1]); + EXPECT_EQ(graph.findNode(ids[0]).value()->data().translated, false); + EXPECT_EQ(graph.findNode(ids[1]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[2]).value()->data().translated, false); + EXPECT_EQ(graph.findNode(ids[3]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[4]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[5]).value()->data().translated, false); + EXPECT_EQ(*buffers[3], 0x0000000f); + EXPECT_EQ(*buffers[4], 0x0f000000); + + reset_translated(); + + // Now updating node 1 should update nodes 3+6 (merging, both inputs available now) + translate_node(ids[0]); + EXPECT_EQ(graph.findNode(ids[0]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[1]).value()->data().translated, false); + EXPECT_EQ(graph.findNode(ids[2]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[3]).value()->data().translated, false); + EXPECT_EQ(graph.findNode(ids[4]).value()->data().translated, false); + EXPECT_EQ(graph.findNode(ids[5]).value()->data().translated, true); + EXPECT_EQ(*buffers[2], 0xaf0000bf); + EXPECT_EQ(*buffers[5], 0xaf0000bf); + + reset_translated(); + + // Another update must not trigger any other updates + translate_node(ids[0]); + EXPECT_EQ(graph.findNode(ids[0]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[1]).value()->data().translated, false); + EXPECT_EQ(graph.findNode(ids[2]).value()->data().translated, false); + EXPECT_EQ(graph.findNode(ids[3]).value()->data().translated, false); + EXPECT_EQ(graph.findNode(ids[4]).value()->data().translated, false); + EXPECT_EQ(graph.findNode(ids[5]).value()->data().translated, false); + + reset_translated(); + + // Backwards: updating node 6 should trigger updates for 1+2, but also 4+5 + *buffers[5] = 0xc00000d0; + translate_node(ids[5]); + EXPECT_EQ(graph.findNode(ids[0]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[1]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[2]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[3]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[4]).value()->data().translated, true); + EXPECT_EQ(graph.findNode(ids[5]).value()->data().translated, true); + EXPECT_EQ(*buffers[0], 0x000000d0); + EXPECT_EQ(*buffers[1], 0xc0000000); + EXPECT_EQ(*buffers[2], 0xc00000d0); + EXPECT_EQ(*buffers[3], 0); + EXPECT_EQ(*buffers[4], 0xc0000000); + EXPECT_EQ(*buffers[5], 0xc00000d0); +} \ No newline at end of file diff --git a/translation_node/test/main.cpp b/translation_node/test/main.cpp new file mode 100644 index 00000000..9d7e484c --- /dev/null +++ b/translation_node/test/main.cpp @@ -0,0 +1,16 @@ +/**************************************************************************** + * Copyright (c) 2024 PX4 Development Team. + * SPDX-License-Identifier: BSD-3-Clause + ****************************************************************************/ + +#include +#include + +int main(int argc, char ** argv) +{ + rclcpp::init(argc, argv); + testing::InitGoogleTest(&argc, argv); + const int ret = RUN_ALL_TESTS(); + rclcpp::shutdown(); + return ret; +}