diff --git a/cmake/GtsamBuildTypes.cmake b/cmake/GtsamBuildTypes.cmake index 2aad58abb8..ab6d149fe9 100644 --- a/cmake/GtsamBuildTypes.cmake +++ b/cmake/GtsamBuildTypes.cmake @@ -178,6 +178,12 @@ foreach(build_type "common" ${GTSAM_CMAKE_CONFIGURATION_TYPES}) append_config_if_not_empty(GTSAM_COMPILE_DEFINITIONS_PUBLIC ${build_type}) endforeach() +# Check if timing is enabled and add appropriate definition flag +if(GTSAM_ENABLE_TIMING AND(NOT ${CMAKE_BUILD_TYPE} EQUAL "Timing")) + message(STATUS "Enabling timing for non-timing build") + list_append_cache(GTSAM_COMPILE_DEFINITIONS_PRIVATE "ENABLE_TIMING") +endif() + # Linker flags: set(GTSAM_CMAKE_SHARED_LINKER_FLAGS_TIMING "${CMAKE_SHARED_LINKER_FLAGS_RELEASE}" CACHE STRING "Linker flags during timing builds.") set(GTSAM_CMAKE_MODULE_LINKER_FLAGS_TIMING "${CMAKE_MODULE_LINKER_FLAGS_RELEASE}" CACHE STRING "Linker flags during timing builds.") diff --git a/cmake/HandleGeneralOptions.cmake b/cmake/HandleGeneralOptions.cmake index 3282ea6825..e5e2e194ac 100644 --- a/cmake/HandleGeneralOptions.cmake +++ b/cmake/HandleGeneralOptions.cmake @@ -33,6 +33,8 @@ option(GTSAM_USE_QUATERNIONS "Enable/Disable using an internal Qu option(GTSAM_POSE3_EXPMAP "Enable/Disable using Pose3::EXPMAP as the default mode. If disabled, Pose3::FIRST_ORDER will be used." ON) option(GTSAM_ROT3_EXPMAP "Ignore if GTSAM_USE_QUATERNIONS is OFF (Rot3::EXPMAP by default). Otherwise, enable Rot3::EXPMAP, or if disabled, use Rot3::CAYLEY." ON) option(GTSAM_DT_MERGING "Enable/Disable merging of equal leaf nodes in DecisionTrees. This leads to significant speed up and memory savings." ON) +option(GTSAM_ENABLE_TIMING "Enable the timing tools (gttic/gttoc)" OFF) +option(GTSAM_HYBRID_TIMING "Enable the timing of hybrid factor graph machinery" OFF) option(GTSAM_ENABLE_CONSISTENCY_CHECKS "Enable/Disable expensive consistency checks" OFF) option(GTSAM_ENABLE_MEMORY_SANITIZER "Enable/Disable memory sanitizer" OFF) option(GTSAM_WITH_TBB "Use Intel Threaded Building Blocks (TBB) if available" ON) diff --git a/cmake/HandlePrintConfiguration.cmake b/cmake/HandlePrintConfiguration.cmake index 42fae90f77..ac68be20fe 100644 --- a/cmake/HandlePrintConfiguration.cmake +++ b/cmake/HandlePrintConfiguration.cmake @@ -91,6 +91,7 @@ print_enabled_config(${GTSAM_ENABLE_MEMORY_SANITIZER} "Build with Memory San print_enabled_config(${GTSAM_ROT3_EXPMAP} "Rot3 retract is full ExpMap ") print_enabled_config(${GTSAM_POSE3_EXPMAP} "Pose3 retract is full ExpMap ") print_enabled_config(${GTSAM_DT_MERGING} "Enable branch merging in DecisionTree") +print_enabled_config(${GTSAM_ENABLE_TIMING} "Enable timing machinery") print_enabled_config(${GTSAM_ALLOW_DEPRECATED_SINCE_V43} "Allow features deprecated in GTSAM 4.3") print_enabled_config(${GTSAM_SUPPORT_NESTED_DISSECTION} "Metis-based Nested Dissection ") print_enabled_config(${GTSAM_TANGENT_PREINTEGRATION} "Use tangent-space preintegration") diff --git a/gtsam/base/timing.cpp b/gtsam/base/timing.cpp index 07fb7b61bf..9e9d1d1134 100644 --- a/gtsam/base/timing.cpp +++ b/gtsam/base/timing.cpp @@ -31,7 +31,9 @@ namespace gtsam { namespace internal { - + +using ChildOrder = FastMap>; + // a static shared_ptr to TimingOutline with nullptr as the pointer const static std::shared_ptr nullTimingOutline; @@ -91,7 +93,6 @@ void TimingOutline::print(const std::string& outline) const { << n_ << " times, " << wall() << " wall, " << secs() << " children, min: " << min() << " max: " << max() << ")\n"; // Order children - typedef FastMap > ChildOrder; ChildOrder childOrder; for(const ChildMap::value_type& child: children_) { childOrder[child.second->myOrder_] = child.second; @@ -106,6 +107,54 @@ void TimingOutline::print(const std::string& outline) const { #endif } +/* ************************************************************************* */ +void TimingOutline::printCsvHeader(bool addLineBreak) const { +#ifdef GTSAM_USE_BOOST_FEATURES + // Order is (CPU time, number of times, wall time, time + children in seconds, + // min time, max time) + std::cout << label_ + " cpu time (s)" << "," << label_ + " #calls" << "," + << label_ + " wall time(s)" << "," << label_ + " subtree time (s)" + << "," << label_ + " min time (s)" << "," << label_ + "max time(s)" + << ","; + // Order children + ChildOrder childOrder; + for (const ChildMap::value_type& child : children_) { + childOrder[child.second->myOrder_] = child.second; + } + // Print children + for (const ChildOrder::value_type& order_child : childOrder) { + order_child.second->printCsvHeader(); + } + if (addLineBreak) { + std::cout << std::endl; + } + std::cout.flush(); +#endif +} + +/* ************************************************************************* */ +void TimingOutline::printCsv(bool addLineBreak) const { +#ifdef GTSAM_USE_BOOST_FEATURES + // Order is (CPU time, number of times, wall time, time + children in seconds, + // min time, max time) + std::cout << self() << "," << n_ << "," << wall() << "," << secs() << "," + << min() << "," << max() << ","; + // Order children + ChildOrder childOrder; + for (const ChildMap::value_type& child : children_) { + childOrder[child.second->myOrder_] = child.second; + } + // Print children + for (const ChildOrder::value_type& order_child : childOrder) { + order_child.second->printCsv(false); + } + if (addLineBreak) { + std::cout << std::endl; + } + std::cout.flush(); +#endif +} + void TimingOutline::print2(const std::string& outline, const double parentTotal) const { #if GTSAM_USE_BOOST_FEATURES diff --git a/gtsam/base/timing.h b/gtsam/base/timing.h index cc38f32b88..5a96633db1 100644 --- a/gtsam/base/timing.h +++ b/gtsam/base/timing.h @@ -199,6 +199,29 @@ namespace gtsam { #endif GTSAM_EXPORT void print(const std::string& outline = "") const; GTSAM_EXPORT void print2(const std::string& outline = "", const double parentTotal = -1.0) const; + + /** + * @brief Print the CSV header. + * Order is + * (CPU time, number of times, wall time, time + children in seconds, min + * time, max time) + * + * @param addLineBreak Flag indicating if a line break should be added at + * the end. Only used at the top-leve. + */ + GTSAM_EXPORT void printCsvHeader(bool addLineBreak = false) const; + + /** + * @brief Print the times recursively from parent to child in CSV format. + * For each timing node, the output is + * (CPU time, number of times, wall time, time + children in seconds, min + * time, max time) + * + * @param addLineBreak Flag indicating if a line break should be added at + * the end. Only used at the top-leve. + */ + GTSAM_EXPORT void printCsv(bool addLineBreak = false) const; + GTSAM_EXPORT const std::shared_ptr& child(size_t child, const std::string& label, const std::weak_ptr& thisPtr); GTSAM_EXPORT void tic(); @@ -268,6 +291,14 @@ inline void tictoc_finishedIteration_() { inline void tictoc_print_() { ::gtsam::internal::gTimingRoot->print(); } +// print timing in CSV format +inline void tictoc_printCsv_(bool displayHeader = false) { + if (displayHeader) { + ::gtsam::internal::gTimingRoot->printCsvHeader(true); + } + ::gtsam::internal::gTimingRoot->printCsv(true); +} + // print mean and standard deviation inline void tictoc_print2_() { ::gtsam::internal::gTimingRoot->print2(); } diff --git a/gtsam/config.h.in b/gtsam/config.h.in index 8b4903d3af..58b93ee1ce 100644 --- a/gtsam/config.h.in +++ b/gtsam/config.h.in @@ -42,6 +42,9 @@ // Whether to enable merging of equal leaf nodes in the Discrete Decision Tree. #cmakedefine GTSAM_DT_MERGING +// Whether to enable timing in hybrid factor graph machinery +#cmakedefine01 GTSAM_HYBRID_TIMING + // Whether we are using TBB (if TBB was found and GTSAM_WITH_TBB is enabled in CMake) #cmakedefine GTSAM_USE_TBB diff --git a/gtsam/discrete/AlgebraicDecisionTree.h b/gtsam/discrete/AlgebraicDecisionTree.h index 383346ab19..a8ec66f730 100644 --- a/gtsam/discrete/AlgebraicDecisionTree.h +++ b/gtsam/discrete/AlgebraicDecisionTree.h @@ -57,6 +57,9 @@ namespace gtsam { AlgebraicDecisionTree(double leaf = 1.0) : Base(leaf) {} + /// Constructor which accepts root pointer + AlgebraicDecisionTree(const typename Base::NodePtr root) : Base(root) {} + // Explicitly non-explicit constructor AlgebraicDecisionTree(const Base& add) : Base(add) {} diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 19dcdc7293..f5ad2b98a4 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -77,6 +77,13 @@ DiscreteConditional::DiscreteConditional(const Signature& signature) /* ************************************************************************** */ DiscreteConditional DiscreteConditional::operator*( const DiscreteConditional& other) const { + // If the root is a nullptr, we have a TableDistribution + // TODO(Varun) Revisit this hack after RSS2025 submission + if (!other.root_) { + DiscreteConditional dc(other.nrFrontals(), other.toDecisionTreeFactor()); + return dc * (*this); + } + // Take union of frontal keys std::set newFrontals; for (auto&& key : this->frontals()) newFrontals.insert(key); @@ -479,6 +486,19 @@ double DiscreteConditional::evaluate(const HybridValues& x) const { return this->operator()(x.discrete()); } +/* ************************************************************************* */ +DiscreteFactor::shared_ptr DiscreteConditional::max( + const Ordering& keys) const { + return BaseFactor::max(keys); +} + +/* ************************************************************************* */ +void DiscreteConditional::prune(size_t maxNrAssignments) { + // Get as DiscreteConditional so the probabilities are normalized + DiscreteConditional pruned(nrFrontals(), BaseFactor::prune(maxNrAssignments)); + this->root_ = pruned.root_; +} + /* ************************************************************************* */ double DiscreteConditional::negLogConstant() const { return 0.0; } diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 67f8a0056a..1bca0b09f5 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -199,7 +199,7 @@ class GTSAM_EXPORT DiscreteConditional * @param parentsValues Known values of the parents * @return sample from conditional */ - size_t sample(const DiscreteValues& parentsValues) const; + virtual size_t sample(const DiscreteValues& parentsValues) const; /// Single parent version. size_t sample(size_t parent_value) const; @@ -214,6 +214,15 @@ class GTSAM_EXPORT DiscreteConditional */ size_t argmax(const DiscreteValues& parentsValues = DiscreteValues()) const; + /** + * @brief Create new factor by maximizing over all + * values with the same separator. + * + * @param keys The keys to sum over. + * @return DiscreteFactor::shared_ptr + */ + virtual DiscreteFactor::shared_ptr max(const Ordering& keys) const override; + /// @} /// @name Advanced Interface /// @{ @@ -267,6 +276,9 @@ class GTSAM_EXPORT DiscreteConditional */ double negLogConstant() const override; + /// Prune the conditional + virtual void prune(size_t maxNrAssignments); + /// @} protected: diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index eb32218193..7e059c5e5d 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -118,17 +118,11 @@ namespace gtsam { // } // } - /** - * @brief Multiply all the `factors`. - * - * @param factors The factors to multiply as a DiscreteFactorGraph. - * @return DiscreteFactor::shared_ptr - */ - static DiscreteFactor::shared_ptr DiscreteProduct( - const DiscreteFactorGraph& factors) { + /* ************************************************************************ */ + DiscreteFactor::shared_ptr DiscreteFactorGraph::scaledProduct() const { // PRODUCT: multiply all factors gttic(product); - DiscreteFactor::shared_ptr product = factors.product(); + DiscreteFactor::shared_ptr product = this->product(); gttoc(product); // Max over all the potentials by pretending all keys are frontal: @@ -145,7 +139,7 @@ namespace gtsam { std::pair // EliminateForMPE(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { - DiscreteFactor::shared_ptr product = DiscreteProduct(factors); + DiscreteFactor::shared_ptr product = factors.scaledProduct(); // max out frontals, this is the factor on the separator gttic(max); @@ -223,7 +217,7 @@ namespace gtsam { std::pair // EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { - DiscreteFactor::shared_ptr product = DiscreteProduct(factors); + DiscreteFactor::shared_ptr product = factors.scaledProduct(); // sum out frontals, this is the factor on the separator gttic(sum); diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h index 3d9e86cd17..f4d1a18334 100644 --- a/gtsam/discrete/DiscreteFactorGraph.h +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -150,6 +150,15 @@ class GTSAM_EXPORT DiscreteFactorGraph /** return product of all factors as a single factor */ DiscreteFactor::shared_ptr product() const; + /** + * @brief Return product of all `factors` as a single factor, + * which is scaled by the max value to prevent underflow + * + * @param factors The factors to multiply as a DiscreteFactorGraph. + * @return DiscreteFactor::shared_ptr + */ + DiscreteFactor::shared_ptr scaledProduct() const; + /** * Evaluates the factor graph given values, returns the joint probability of * the factor graph given specific instantiation of values diff --git a/gtsam/discrete/TableDistribution.cpp b/gtsam/discrete/TableDistribution.cpp new file mode 100644 index 0000000000..e8696c5b1b --- /dev/null +++ b/gtsam/discrete/TableDistribution.cpp @@ -0,0 +1,174 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file TableDistribution.cpp + * @date Dec 22, 2024 + * @author Varun Agrawal + */ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace std; +using std::pair; +using std::stringstream; +using std::vector; +namespace gtsam { + +/// Normalize sparse_table +static Eigen::SparseVector normalizeSparseTable( + const Eigen::SparseVector& sparse_table) { + return sparse_table / sparse_table.sum(); +} + +/* ************************************************************************** */ +TableDistribution::TableDistribution(const TableFactor& f) + : BaseConditional(f.keys().size(), f.discreteKeys(), ADT(nullptr)), + table_(f / (*std::dynamic_pointer_cast( + f.sum(f.keys().size())))) {} + +/* ************************************************************************** */ +TableDistribution::TableDistribution(const DiscreteKeys& keys, + const std::vector& potentials) + : BaseConditional(keys.size(), keys, ADT(nullptr)), + table_(TableFactor( + keys, normalizeSparseTable(TableFactor::Convert(keys, potentials)))) { +} + +/* ************************************************************************** */ +TableDistribution::TableDistribution(const DiscreteKeys& keys, + const std::string& potentials) + : BaseConditional(keys.size(), keys, ADT(nullptr)), + table_(TableFactor( + keys, normalizeSparseTable(TableFactor::Convert(keys, potentials)))) { +} + +/* ************************************************************************** */ +void TableDistribution::print(const string& s, + const KeyFormatter& formatter) const { + cout << s << " P( "; + for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) { + cout << formatter(*it) << " "; + } + cout << "):\n"; + table_.print("", formatter); + cout << endl; +} + +/* ************************************************************************** */ +bool TableDistribution::equals(const DiscreteFactor& other, double tol) const { + auto dtc = dynamic_cast(&other); + if (!dtc) { + return false; + } else { + const DiscreteConditional& f( + static_cast(other)); + return table_.equals(dtc->table_, tol) && + DiscreteConditional::BaseConditional::equals(f, tol); + } +} + +/* ****************************************************************************/ +DiscreteFactor::shared_ptr TableDistribution::sum(size_t nrFrontals) const { + return table_.sum(nrFrontals); +} + +/* ****************************************************************************/ +DiscreteFactor::shared_ptr TableDistribution::sum(const Ordering& keys) const { + return table_.sum(keys); +} + +/* ****************************************************************************/ +DiscreteFactor::shared_ptr TableDistribution::max(size_t nrFrontals) const { + return table_.max(nrFrontals); +} + +/* ****************************************************************************/ +DiscreteFactor::shared_ptr TableDistribution::max(const Ordering& keys) const { + return table_.max(keys); +} + +/* ****************************************************************************/ +DiscreteFactor::shared_ptr TableDistribution::operator/( + const DiscreteFactor::shared_ptr& f) const { + return table_ / f; +} + +/* ************************************************************************ */ +DiscreteValues TableDistribution::argmax() const { + uint64_t maxIdx = 0; + double maxValue = 0.0; + + Eigen::SparseVector sparseTable = table_.sparseTable(); + + for (SparseIt it(sparseTable); it; ++it) { + if (it.value() > maxValue) { + maxIdx = it.index(); + maxValue = it.value(); + } + } + + return table_.findAssignments(maxIdx); +} + +/* ****************************************************************************/ +void TableDistribution::prune(size_t maxNrAssignments) { + table_ = table_.prune(maxNrAssignments); +} + +/* ****************************************************************************/ +size_t TableDistribution::sample(const DiscreteValues& parentsValues) const { + static mt19937 rng(2); // random number generator + + DiscreteKeys parentsKeys; + for (auto&& [key, _] : parentsValues) { + parentsKeys.push_back({key, table_.cardinality(key)}); + } + + // Get the correct conditional distribution: P(F|S=parentsValues) + TableFactor pFS = table_.choose(parentsValues, parentsKeys); + + // TODO(Duy): only works for one key now, seems horribly slow this way + if (nrFrontals() != 1) { + throw std::invalid_argument( + "TableDistribution::sample can only be called on single variable " + "conditionals"); + } + Key key = firstFrontalKey(); + size_t nj = cardinality(key); + vector p(nj); + DiscreteValues frontals; + for (size_t value = 0; value < nj; value++) { + frontals[key] = value; + p[value] = pFS(frontals); // P(F=value|S=parentsValues) + if (p[value] == 1.0) { + return value; // shortcut exit + } + } + std::discrete_distribution distribution(p.begin(), p.end()); + return distribution(rng); +} + +} // namespace gtsam diff --git a/gtsam/discrete/TableDistribution.h b/gtsam/discrete/TableDistribution.h new file mode 100644 index 0000000000..72786a515d --- /dev/null +++ b/gtsam/discrete/TableDistribution.h @@ -0,0 +1,177 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file TableDistribution.h + * @date Dec 22, 2024 + * @author Varun Agrawal + */ + +#pragma once + +#include +#include +#include + +#include +#include +#include + +namespace gtsam { + +/** + * Distribution which uses a SparseVector as the internal + * representation, similar to the TableFactor. + * + * This is primarily used in the case when we have a clique in the BayesTree + * which consists of all the discrete variables, e.g. in hybrid elimination. + * + * @ingroup discrete + */ +class GTSAM_EXPORT TableDistribution : public DiscreteConditional { + private: + TableFactor table_; + + typedef Eigen::SparseVector::InnerIterator SparseIt; + + public: + // typedefs needed to play nice with gtsam + typedef TableDistribution This; ///< Typedef to this class + typedef std::shared_ptr shared_ptr; ///< shared_ptr to this class + typedef DiscreteConditional + BaseConditional; ///< Typedef to our conditional base class + + using Values = DiscreteValues; ///< backwards compatibility + + /// @name Standard Constructors + /// @{ + + /// Default constructor needed for serialization. + TableDistribution() {} + + /// Construct from TableFactor. + TableDistribution(const TableFactor& f); + + /** + * Construct from DiscreteKeys and std::vector. + */ + TableDistribution(const DiscreteKeys& keys, + const std::vector& potentials); + + /** + * Construct from single DiscreteKey and std::vector. + */ + TableDistribution(const DiscreteKey& key, + const std::vector& potentials) + : TableDistribution(DiscreteKeys(key), potentials) {} + + /** + * Construct from DiscreteKey and std::string. + */ + TableDistribution(const DiscreteKeys& keys, const std::string& potentials); + + /** + * Construct from single DiscreteKey and std::string. + */ + TableDistribution(const DiscreteKey& key, const std::string& potentials) + : TableDistribution(DiscreteKeys(key), potentials) {} + + /// @} + /// @name Testable + /// @{ + + /// GTSAM-style print + void print( + const std::string& s = "Table Distribution: ", + const KeyFormatter& formatter = DefaultKeyFormatter) const override; + + /// GTSAM-style equals + bool equals(const DiscreteFactor& other, double tol = 1e-9) const override; + + /// @} + /// @name Standard Interface + /// @{ + + /// Return the underlying TableFactor + TableFactor table() const { return table_; } + + using BaseConditional::evaluate; // HybridValues version + + /// Evaluate the conditional given the values. + virtual double evaluate(const Assignment& values) const override { + return table_.evaluate(values); + } + + /// Create new factor by summing all values with the same separator values + DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override; + + /// Create new factor by summing all values with the same separator values + DiscreteFactor::shared_ptr sum(const Ordering& keys) const override; + + /// Create new factor by maximizing over all values with the same separator. + DiscreteFactor::shared_ptr max(size_t nrFrontals) const override; + + /// Create new factor by maximizing over all values with the same separator. + DiscreteFactor::shared_ptr max(const Ordering& keys) const override; + + /// divide by DiscreteFactor::shared_ptr f (safely) + DiscreteFactor::shared_ptr operator/( + const DiscreteFactor::shared_ptr& f) const override; + + /** + * @brief Return assignment that maximizes value. + * + * @return maximizing assignment for the variables. + */ + DiscreteValues argmax() const; + + /** + * sample + * @param parentsValues Known values of the parents + * @return sample from conditional + */ + virtual size_t sample(const DiscreteValues& parentsValues) const override; + + /// @} + /// @name Advanced Interface + /// @{ + + /// Prune the conditional + virtual void prune(size_t maxNrAssignments) override; + + /// Get a DecisionTreeFactor representation. + DecisionTreeFactor toDecisionTreeFactor() const override { + return table_.toDecisionTreeFactor(); + } + + /// Get the number of non-zero values. + uint64_t nrValues() const override { return table_.sparseTable().nonZeros(); } + + /// @} + + private: +#if GTSAM_ENABLE_BOOST_SERIALIZATION + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(Archive& ar, const unsigned int /*version*/) { + ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional); + ar& BOOST_SERIALIZATION_NVP(table_); + } +#endif +}; +// TableDistribution + +// traits +template <> +struct traits : public Testable {}; + +} // namespace gtsam diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index 5a804b6a61..1cb9eda8b9 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -87,6 +87,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { return DiscreteKey(keys_[i], cardinalities_.at(keys_[i])); } + public: /** * Convert probability table given as doubles to SparseVector. * Example: {0, 1, 1, 0, 0, 1, 0} -> values: {1, 1, 1}, indices: {1, 2, 5} @@ -98,7 +99,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { static Eigen::SparseVector Convert(const DiscreteKeys& keys, const std::string& table); - public: // typedefs needed to play nice with gtsam typedef TableFactor This; typedef DiscreteFactor Base; ///< Typedef to base class @@ -211,7 +211,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { DecisionTreeFactor toDecisionTreeFactor() const override; /// Create a TableFactor that is a subset of this TableFactor - TableFactor choose(const DiscreteValues assignments, + TableFactor choose(const DiscreteValues parentAssignments, DiscreteKeys parent_keys) const; /// Create new factor by summing all values with the same separator values diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index b2e2524f8b..40f1822cf2 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -168,6 +168,43 @@ virtual class DiscreteDistribution : gtsam::DiscreteConditional { std::vector pmf() const; }; +#include +virtual class TableFactor : gtsam::DiscreteFactor { + TableFactor(); + TableFactor(const gtsam::DiscreteKeys& keys, + const gtsam::TableFactor& potentials); + TableFactor(const gtsam::DiscreteKeys& keys, std::vector& table); + TableFactor(const gtsam::DiscreteKeys& keys, string spec); + TableFactor(const gtsam::DiscreteKeys& keys, + const gtsam::DecisionTreeFactor& dtf); + TableFactor(const gtsam::DecisionTreeFactor& dtf); + + void print(string s = "TableFactor\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + + double evaluate(const gtsam::DiscreteValues& values) const; + double error(const gtsam::DiscreteValues& values) const; +}; + +#include +virtual class TableDistribution : gtsam::DiscreteConditional { + TableDistribution(); + TableDistribution(const gtsam::TableFactor& f); + TableDistribution(const gtsam::DiscreteKey& key, std::vector spec); + TableDistribution(const gtsam::DiscreteKeys& keys, std::vector spec); + TableDistribution(const gtsam::DiscreteKeys& keys, string spec); + TableDistribution(const gtsam::DiscreteKey& key, string spec); + + void print(string s = "Table Distribution\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + + gtsam::TableFactor table() const; + double evaluate(const gtsam::DiscreteValues& values) const; + size_t nrValues() const; +}; + #include class DiscreteBayesNet { DiscreteBayesNet(); diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 623b82eea7..8668bedd6f 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include @@ -55,12 +56,15 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const { joint = joint * (*conditional); } - // Prune the joint. NOTE: again, possibly quite expensive. - const DecisionTreeFactor pruned = joint.prune(maxNrLeaves); - - // Create a the result starting with the pruned joint. + // Create the result starting with the pruned joint. HybridBayesNet result; - result.emplace_shared(pruned.size(), pruned); + result.emplace_shared(joint); + // Prune the joint. NOTE: imperative and, again, possibly quite expensive. + result.back()->asDiscrete()->prune(maxNrLeaves); + + // Get pruned discrete probabilities so + // we can prune HybridGaussianConditionals. + DiscreteConditional pruned = *result.back()->asDiscrete(); /* To prune, we visitWith every leaf in the HybridGaussianConditional. * For each leaf, using the assignment we can check the discrete decision tree @@ -126,7 +130,14 @@ HybridValues HybridBayesNet::optimize() const { for (auto &&conditional : *this) { if (conditional->isDiscrete()) { - discrete_fg.push_back(conditional->asDiscrete()); + if (auto dtc = conditional->asDiscrete()) { + // The number of keys should be small so should not + // be expensive to convert to DiscreteConditional. + discrete_fg.push_back(DiscreteConditional(dtc->nrFrontals(), + dtc->toDecisionTreeFactor())); + } else { + discrete_fg.push_back(conditional->asDiscrete()); + } } } diff --git a/gtsam/hybrid/HybridBayesTree.cpp b/gtsam/hybrid/HybridBayesTree.cpp index 1b633e024d..31d256d6fd 100644 --- a/gtsam/hybrid/HybridBayesTree.cpp +++ b/gtsam/hybrid/HybridBayesTree.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -41,6 +42,22 @@ bool HybridBayesTree::equals(const This& other, double tol) const { return Base::equals(other, tol); } +/* ************************************************************************* */ +DiscreteValues HybridBayesTree::discreteMaxProduct( + const DiscreteFactorGraph& dfg) const { + DiscreteFactor::shared_ptr product = dfg.scaledProduct(); + + // Check type of product, and get as TableFactor for efficiency. + TableFactor p; + if (auto tf = std::dynamic_pointer_cast(product)) { + p = *tf; + } else { + p = TableFactor(product->toDecisionTreeFactor()); + } + DiscreteValues assignment = TableDistribution(p).argmax(); + return assignment; +} + /* ************************************************************************* */ HybridValues HybridBayesTree::optimize() const { DiscreteFactorGraph discrete_fg; @@ -52,8 +69,9 @@ HybridValues HybridBayesTree::optimize() const { // The root should be discrete only, we compute the MPE if (root_conditional->isDiscrete()) { - discrete_fg.push_back(root_conditional->asDiscrete()); - mpe = discrete_fg.optimize(); + auto discrete = root_conditional->asDiscrete(); + discrete_fg.push_back(discrete); + mpe = discreteMaxProduct(discrete_fg); } else { throw std::runtime_error( "HybridBayesTree root is not discrete-only. Please check elimination " @@ -179,16 +197,17 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const { /* ************************************************************************* */ void HybridBayesTree::prune(const size_t maxNrLeaves) { - auto discreteProbs = this->roots_.at(0)->conditional()->asDiscrete(); + auto prunedDiscreteProbs = + this->roots_.at(0)->conditional()->asDiscrete(); - DecisionTreeFactor prunedDiscreteProbs = discreteProbs->prune(maxNrLeaves); - discreteProbs->root_ = prunedDiscreteProbs.root_; + // Imperative pruning + prunedDiscreteProbs->prune(maxNrLeaves); /// Helper struct for pruning the hybrid bayes tree. struct HybridPrunerData { /// The discrete decision tree after pruning. - DecisionTreeFactor prunedDiscreteProbs; - HybridPrunerData(const DecisionTreeFactor& prunedDiscreteProbs, + DiscreteConditional::shared_ptr prunedDiscreteProbs; + HybridPrunerData(const DiscreteConditional::shared_ptr& prunedDiscreteProbs, const HybridBayesTree::sharedNode& parentClique) : prunedDiscreteProbs(prunedDiscreteProbs) {} @@ -213,7 +232,7 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) { if (!hybridGaussianCond->pruned()) { // Imperative clique->conditional() = std::make_shared( - hybridGaussianCond->prune(parentData.prunedDiscreteProbs)); + hybridGaussianCond->prune(*parentData.prunedDiscreteProbs)); } } return parentData; diff --git a/gtsam/hybrid/HybridBayesTree.h b/gtsam/hybrid/HybridBayesTree.h index 06d880f02a..ec29f7b1ee 100644 --- a/gtsam/hybrid/HybridBayesTree.h +++ b/gtsam/hybrid/HybridBayesTree.h @@ -115,6 +115,10 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree { /// @} private: + /// Helper method to compute the max product assignment + /// given a DiscreteFactorGraph + DiscreteValues discreteMaxProduct(const DiscreteFactorGraph& dfg) const; + #if GTSAM_ENABLE_BOOST_SERIALIZATION /** Serialization function */ friend class boost::serialization::access; diff --git a/gtsam/hybrid/HybridConditional.h b/gtsam/hybrid/HybridConditional.h index a08b3a6ee5..3cf5b80e59 100644 --- a/gtsam/hybrid/HybridConditional.h +++ b/gtsam/hybrid/HybridConditional.h @@ -166,12 +166,13 @@ class GTSAM_EXPORT HybridConditional } /** - * @brief Return conditional as a DiscreteConditional + * @brief Return conditional as a DiscreteConditional or specified type T. * @return nullptr if not a DiscreteConditional * @return DiscreteConditional::shared_ptr */ - DiscreteConditional::shared_ptr asDiscrete() const { - return std::dynamic_pointer_cast(inner_); + template + typename T::shared_ptr asDiscrete() const { + return std::dynamic_pointer_cast(inner_); } /// Get the type-erased pointer to the inner type diff --git a/gtsam/hybrid/HybridGaussianConditional.cpp b/gtsam/hybrid/HybridGaussianConditional.cpp index 54346679ee..78e1f5324e 100644 --- a/gtsam/hybrid/HybridGaussianConditional.cpp +++ b/gtsam/hybrid/HybridGaussianConditional.cpp @@ -304,7 +304,7 @@ std::set DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) { /* *******************************************************************************/ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune( - const DecisionTreeFactor &discreteProbs) const { + const DiscreteConditional &discreteProbs) const { // Find keys in discreteProbs.keys() but not in this->keys(): std::set mine(this->keys().begin(), this->keys().end()); std::set theirs(discreteProbs.keys().begin(), diff --git a/gtsam/hybrid/HybridGaussianConditional.h b/gtsam/hybrid/HybridGaussianConditional.h index e769662ed1..3b95e0277f 100644 --- a/gtsam/hybrid/HybridGaussianConditional.h +++ b/gtsam/hybrid/HybridGaussianConditional.h @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -235,7 +236,7 @@ class GTSAM_EXPORT HybridGaussianConditional * @return Shared pointer to possibly a pruned HybridGaussianConditional */ HybridGaussianConditional::shared_ptr prune( - const DecisionTreeFactor &discreteProbs) const; + const DiscreteConditional &discreteProbs) const; /// Return true if the conditional has already been pruned. bool pruned() const { return pruned_; } diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index b9051554a4..cf56b52ed5 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -20,12 +20,13 @@ #include #include -#include #include #include #include #include #include +#include +#include #include #include #include @@ -241,29 +242,29 @@ continuousElimination(const HybridGaussianFactorGraph &factors, /* ************************************************************************ */ /** * @brief Take negative log-values, shift them so that the minimum value is 0, - * and then exponentiate to create a DecisionTreeFactor (not normalized yet!). + * and then exponentiate to create a TableFactor (not normalized yet!). * * @param errors DecisionTree of (unnormalized) errors. - * @return DecisionTreeFactor::shared_ptr + * @return TableFactor::shared_ptr */ -static DecisionTreeFactor::shared_ptr DiscreteFactorFromErrors( +static TableFactor::shared_ptr DiscreteFactorFromErrors( const DiscreteKeys &discreteKeys, const AlgebraicDecisionTree &errors) { double min_log = errors.min(); AlgebraicDecisionTree potentials( errors, [&min_log](const double x) { return exp(-(x - min_log)); }); - return std::make_shared(discreteKeys, potentials); + return std::make_shared(discreteKeys, potentials); } /* ************************************************************************ */ -static std::pair> -discreteElimination(const HybridGaussianFactorGraph &factors, - const Ordering &frontalKeys) { +static DiscreteFactorGraph CollectDiscreteFactors( + const HybridGaussianFactorGraph &factors) { DiscreteFactorGraph dfg; for (auto &f : factors) { if (auto df = dynamic_pointer_cast(f)) { dfg.push_back(df); + } else if (auto gmf = dynamic_pointer_cast(f)) { // Case where we have a HybridGaussianFactor with no continuous keys. // In this case, compute a discrete factor from the remaining error. @@ -282,16 +283,73 @@ discreteElimination(const HybridGaussianFactorGraph &factors, } else if (auto hc = dynamic_pointer_cast(f)) { auto dc = hc->asDiscrete(); if (!dc) throwRuntimeError("discreteElimination", dc); - dfg.push_back(dc); +#if GTSAM_HYBRID_TIMING + gttic_(ConvertConditionalToTableFactor); +#endif + if (auto dtc = std::dynamic_pointer_cast(dc)) { + /// Get the underlying TableFactor + dfg.push_back(dtc->table()); + } else { + // Convert DiscreteConditional to TableFactor + auto tdc = std::make_shared(*dc); + dfg.push_back(tdc); + } +#if GTSAM_HYBRID_TIMING + gttoc_(ConvertConditionalToTableFactor); +#endif } else { throwRuntimeError("discreteElimination", f); } } - // NOTE: This does sum-product. For max-product, use EliminateForMPE. - auto result = EliminateDiscrete(dfg, frontalKeys); + return dfg; +} - return {std::make_shared(result.first), result.second}; +/* ************************************************************************ */ +static std::pair> +discreteElimination(const HybridGaussianFactorGraph &factors, + const Ordering &frontalKeys) { + DiscreteFactorGraph dfg = CollectDiscreteFactors(factors); + +#if GTSAM_HYBRID_TIMING + gttic_(EliminateDiscrete); +#endif + // Check if separator is empty. + // This is the same as checking if the number of frontal variables + // is the same as the number of variables in the DiscreteFactorGraph. + // If the separator is empty, we have a clique of all the discrete variables + // so we can use the TableFactor for efficiency. + if (frontalKeys.size() == dfg.keys().size()) { + // Get product factor + DiscreteFactor::shared_ptr product = dfg.scaledProduct(); + +#if GTSAM_HYBRID_TIMING + gttic_(EliminateDiscreteFormDiscreteConditional); +#endif + // Check type of product, and get as TableFactor for efficiency. + TableFactor p; + if (auto tf = std::dynamic_pointer_cast(product)) { + p = *tf; + } else { + p = TableFactor(product->toDecisionTreeFactor()); + } + auto conditional = std::make_shared(p); +#if GTSAM_HYBRID_TIMING + gttoc_(EliminateDiscreteFormDiscreteConditional); +#endif + + DiscreteFactor::shared_ptr sum = product->sum(frontalKeys); + + return {std::make_shared(conditional), sum}; + + } else { + // Perform sum-product. + auto result = EliminateDiscrete(dfg, frontalKeys); + return {std::make_shared(result.first), result.second}; + } +#if GTSAM_HYBRID_TIMING + gttoc_(EliminateDiscrete); +#endif } /* ************************************************************************ */ @@ -319,8 +377,19 @@ static std::shared_ptr createDiscreteFactor( } }; +#if GTSAM_HYBRID_TIMING + gttic_(DiscreteBoundaryErrors); +#endif AlgebraicDecisionTree errors(eliminationResults, calculateError); - return DiscreteFactorFromErrors(discreteSeparator, errors); +#if GTSAM_HYBRID_TIMING + gttoc_(DiscreteBoundaryErrors); + gttic_(DiscreteBoundaryResult); +#endif + auto result = DiscreteFactorFromErrors(discreteSeparator, errors); +#if GTSAM_HYBRID_TIMING + gttoc_(DiscreteBoundaryResult); +#endif + return result; } /* *******************************************************************************/ @@ -360,12 +429,18 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const { // the discrete separator will be *all* the discrete keys. DiscreteKeys discreteSeparator = GetDiscreteKeys(*this); +#if GTSAM_HYBRID_TIMING + gttic_(HybridCollectProductFactor); +#endif // Collect all the factors to create a set of Gaussian factor graphs in a // decision tree indexed by all discrete keys involved. Just like any hybrid // factor, every assignment also has a scalar error, in this case the sum of // all errors in the graph. This error is assignment-specific and accounts for // any difference in noise models used. HybridGaussianProductFactor productFactor = collectProductFactor(); +#if GTSAM_HYBRID_TIMING + gttoc_(HybridCollectProductFactor); +#endif // Check if a factor is null auto isNull = [](const GaussianFactor::shared_ptr &ptr) { return !ptr; }; @@ -393,8 +468,14 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const { return {conditional, conditional->negLogConstant(), factor, scalar}; }; +#if GTSAM_HYBRID_TIMING + gttic_(HybridEliminate); +#endif // Perform elimination! const ResultTree eliminationResults(productFactor, eliminate); +#if GTSAM_HYBRID_TIMING + gttoc_(HybridEliminate); +#endif // If there are no more continuous parents we create a DiscreteFactor with the // error for each discrete choice. Otherwise, create a HybridGaussianFactor diff --git a/gtsam/hybrid/HybridGaussianISAM.cpp b/gtsam/hybrid/HybridGaussianISAM.cpp index 28116df45d..f99d95c018 100644 --- a/gtsam/hybrid/HybridGaussianISAM.cpp +++ b/gtsam/hybrid/HybridGaussianISAM.cpp @@ -104,7 +104,13 @@ void HybridGaussianISAM::updateInternal( elimination_ordering, function, std::cref(index)); if (maxNrLeaves) { +#if GTSAM_HYBRID_TIMING + gttic_(HybridBayesTreePrune); +#endif bayesTree->prune(*maxNrLeaves); +#if GTSAM_HYBRID_TIMING + gttoc_(HybridBayesTreePrune); +#endif } // Re-add into Bayes tree data structures diff --git a/gtsam/hybrid/HybridNonlinearISAM.cpp b/gtsam/hybrid/HybridNonlinearISAM.cpp index 29e467d866..3b4856dfbc 100644 --- a/gtsam/hybrid/HybridNonlinearISAM.cpp +++ b/gtsam/hybrid/HybridNonlinearISAM.cpp @@ -15,6 +15,7 @@ * @author Varun Agrawal */ +#include #include #include #include @@ -65,7 +66,14 @@ void HybridNonlinearISAM::reorderRelinearize() { // Obtain the new linearization point const Values newLinPoint = estimate(); - auto discreteProbs = *(isam_.roots().at(0)->conditional()->asDiscrete()); + DiscreteConditional::shared_ptr discreteProbabilities; + + auto discreteRoot = isam_.roots().at(0)->conditional(); + if (discreteRoot->asDiscrete()) { + discreteProbabilities = discreteRoot->asDiscrete(); + } else { + discreteProbabilities = discreteRoot->asDiscrete(); + } isam_.clear(); @@ -73,7 +81,7 @@ void HybridNonlinearISAM::reorderRelinearize() { HybridNonlinearFactorGraph pruned_factors; for (auto&& factor : factors_) { if (auto nf = std::dynamic_pointer_cast(factor)) { - pruned_factors.push_back(nf->prune(discreteProbs)); + pruned_factors.push_back(nf->prune(*discreteProbabilities)); } else { pruned_factors.push_back(factor); } diff --git a/gtsam/hybrid/tests/testGaussianMixture.cpp b/gtsam/hybrid/tests/testGaussianMixture.cpp index 14bef5fbb4..c98485feaa 100644 --- a/gtsam/hybrid/tests/testGaussianMixture.cpp +++ b/gtsam/hybrid/tests/testGaussianMixture.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -79,8 +80,8 @@ TEST(GaussianMixture, GaussianMixtureModel) { double midway = mu1 - mu0; auto eliminationResult = gmm.toFactorGraph({{Z(0), Vector1(midway)}}).eliminateSequential(); - auto pMid = *eliminationResult->at(0)->asDiscrete(); - EXPECT(assert_equal(DiscreteConditional(m, "60/40"), pMid)); + auto pMid = eliminationResult->at(0)->asDiscrete(); + EXPECT(assert_equal(TableDistribution(m, "60 40"), *pMid)); // Everywhere else, the result should be a sigmoid. for (const double shift : {-4, -2, 0, 2, 4}) { @@ -90,7 +91,8 @@ TEST(GaussianMixture, GaussianMixtureModel) { // Workflow 1: convert HBN to HFG and solve auto eliminationResult1 = gmm.toFactorGraph({{Z(0), Vector1(z)}}).eliminateSequential(); - auto posterior1 = *eliminationResult1->at(0)->asDiscrete(); + auto posterior1 = + *eliminationResult1->at(0)->asDiscrete(); EXPECT_DOUBLES_EQUAL(expected, posterior1(m1Assignment), 1e-8); // Workflow 2: directly specify HFG and solve @@ -99,7 +101,8 @@ TEST(GaussianMixture, GaussianMixtureModel) { m, std::vector{Gaussian(mu0, sigma, z), Gaussian(mu1, sigma, z)}); hfg1.push_back(mixing); auto eliminationResult2 = hfg1.eliminateSequential(); - auto posterior2 = *eliminationResult2->at(0)->asDiscrete(); + auto posterior2 = + *eliminationResult2->at(0)->asDiscrete(); EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8); } } @@ -133,13 +136,13 @@ TEST(GaussianMixture, GaussianMixtureModel2) { // Eliminate the graph! auto eliminationResultMax = gfg.eliminateSequential(); - // Equality of posteriors asserts that the elimination is correct (same ratios - // for all modes) + // Equality of posteriors asserts that the elimination is correct + // (same ratios for all modes) EXPECT(assert_equal(expectedDiscretePosterior, eliminationResultMax->discretePosterior(vv))); - auto pMax = *eliminationResultMax->at(0)->asDiscrete(); - EXPECT(assert_equal(DiscreteConditional(m, "42/58"), pMax, 1e-4)); + auto pMax = *eliminationResultMax->at(0)->asDiscrete(); + EXPECT(assert_equal(TableDistribution(m, "42 58"), pMax, 1e-4)); // Everywhere else, the result should be a bell curve like function. for (const double shift : {-4, -2, 0, 2, 4}) { @@ -149,7 +152,8 @@ TEST(GaussianMixture, GaussianMixtureModel2) { // Workflow 1: convert HBN to HFG and solve auto eliminationResult1 = gmm.toFactorGraph({{Z(0), Vector1(z)}}).eliminateSequential(); - auto posterior1 = *eliminationResult1->at(0)->asDiscrete(); + auto posterior1 = + *eliminationResult1->at(0)->asDiscrete(); EXPECT_DOUBLES_EQUAL(expected, posterior1(m1Assignment), 1e-8); // Workflow 2: directly specify HFG and solve @@ -158,10 +162,12 @@ TEST(GaussianMixture, GaussianMixtureModel2) { m, std::vector{Gaussian(mu0, sigma0, z), Gaussian(mu1, sigma1, z)}); hfg.push_back(mixing); auto eliminationResult2 = hfg.eliminateSequential(); - auto posterior2 = *eliminationResult2->at(0)->asDiscrete(); + auto posterior2 = + *eliminationResult2->at(0)->asDiscrete(); EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8); } } + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 135da5dc73..989694b269 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -20,6 +20,7 @@ #include #include +#include #include #include #include @@ -454,7 +455,8 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) { } size_t maxNrLeaves = 3; - auto prunedDecisionTree = joint.prune(maxNrLeaves); + DiscreteConditional prunedDecisionTree(joint); + prunedDecisionTree.prune(maxNrLeaves); #ifdef GTSAM_DT_MERGING EXPECT_LONGS_EQUAL(maxNrLeaves + 2 /*2 zero leaves*/, diff --git a/gtsam/hybrid/tests/testHybridEstimation.cpp b/gtsam/hybrid/tests/testHybridEstimation.cpp index 5b27e2b417..ef2ae9c41b 100644 --- a/gtsam/hybrid/tests/testHybridEstimation.cpp +++ b/gtsam/hybrid/tests/testHybridEstimation.cpp @@ -16,6 +16,7 @@ */ #include +#include #include #include #include @@ -464,14 +465,14 @@ TEST(HybridEstimation, EliminateSequentialRegression) { // Create expected discrete conditional on m0. DiscreteKey m(M(0), 2); - DiscreteConditional expected(m % "0.51341712/1"); // regression + TableDistribution expected(m, "0.51341712 1"); // regression // Eliminate into BN using one ordering const Ordering ordering1{X(0), X(1), M(0)}; HybridBayesNet::shared_ptr bn1 = fg->eliminateSequential(ordering1); // Check that the discrete conditional matches the expected. - auto dc1 = bn1->back()->asDiscrete(); + auto dc1 = bn1->back()->asDiscrete(); EXPECT(assert_equal(expected, *dc1, 1e-9)); // Eliminate into BN using a different ordering @@ -479,7 +480,7 @@ TEST(HybridEstimation, EliminateSequentialRegression) { HybridBayesNet::shared_ptr bn2 = fg->eliminateSequential(ordering2); // Check that the discrete conditional matches the expected. - auto dc2 = bn2->back()->asDiscrete(); + auto dc2 = bn2->back()->asDiscrete(); EXPECT(assert_equal(expected, *dc2, 1e-9)); } diff --git a/gtsam/hybrid/tests/testHybridGaussianConditional.cpp b/gtsam/hybrid/tests/testHybridGaussianConditional.cpp index 350bc91848..8bb83cac48 100644 --- a/gtsam/hybrid/tests/testHybridGaussianConditional.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianConditional.cpp @@ -261,7 +261,8 @@ TEST(HybridGaussianConditional, Prune) { potentials[i] = 1; const DecisionTreeFactor decisionTreeFactor(keys, potentials); // Prune the HybridGaussianConditional - const auto pruned = hgc.prune(decisionTreeFactor); + const auto pruned = + hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor)); // Check that the pruned HybridGaussianConditional has 1 conditional EXPECT_LONGS_EQUAL(1, pruned->nrComponents()); } @@ -271,7 +272,8 @@ TEST(HybridGaussianConditional, Prune) { 0, 0, 0.5, 0}; const DecisionTreeFactor decisionTreeFactor(keys, potentials); - const auto pruned = hgc.prune(decisionTreeFactor); + const auto pruned = + hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor)); // Check that the pruned HybridGaussianConditional has 2 conditionals EXPECT_LONGS_EQUAL(2, pruned->nrComponents()); @@ -286,7 +288,8 @@ TEST(HybridGaussianConditional, Prune) { 0, 0, 0.5, 0}; const DecisionTreeFactor decisionTreeFactor(keys, potentials); - const auto pruned = hgc.prune(decisionTreeFactor); + const auto pruned = + hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor)); // Check that the pruned HybridGaussianConditional has 3 conditionals EXPECT_LONGS_EQUAL(3, pruned->nrComponents()); diff --git a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp index 31e36101b2..c8735c40a9 100644 --- a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -114,10 +115,10 @@ TEST(HybridGaussianFactorGraph, hybridEliminationOneFactor) { EXPECT(HybridConditional::CheckInvariants(*result.first, values)); // Check that factor is discrete and correct - auto factor = std::dynamic_pointer_cast(result.second); + auto factor = std::dynamic_pointer_cast(result.second); CHECK(factor); // regression test - EXPECT(assert_equal(DecisionTreeFactor{m1, "1 1"}, *factor, 1e-5)); + EXPECT(assert_equal(TableFactor{m1, "1 1"}, *factor, 1e-5)); } /* ************************************************************************* */ @@ -329,7 +330,7 @@ TEST(HybridBayesNet, Switching) { // Check the remaining factor for x1 CHECK(factor_x1); - auto phi_x1 = std::dynamic_pointer_cast(factor_x1); + auto phi_x1 = std::dynamic_pointer_cast(factor_x1); CHECK(phi_x1); EXPECT_LONGS_EQUAL(1, phi_x1->keys().size()); // m0 // We can't really check the error of the decision tree factor phi_x1, because @@ -650,7 +651,7 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1) { mode, std::vector{conditional0, conditional1}); // Add prior on mode. - expectedBayesNet.emplace_shared(mode, "74/26"); + expectedBayesNet.emplace_shared(mode, "74 26"); // Test elimination const auto posterior = fg.eliminateSequential(); @@ -700,11 +701,11 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1Swapped) { m1, std::vector{conditional0, conditional1}); // Add prior on m1. - expectedBayesNet.emplace_shared(m1, "1/1"); + expectedBayesNet.emplace_shared(m1, "0.188638 0.811362"); // Test elimination const auto posterior = fg.eliminateSequential(); - // EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01)); + EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01)); EXPECT(ratioTest(bn, measurements, *posterior)); @@ -736,7 +737,9 @@ TEST(HybridGaussianFactorGraph, EliminateTiny2) { mode, std::vector{conditional0, conditional1}); // Add prior on mode. - expectedBayesNet.emplace_shared(mode, "23/77"); + // Since this is the only discrete conditional, it is added as a + // TableDistribution. + expectedBayesNet.emplace_shared(mode, "23 77"); // Test elimination const auto posterior = fg.eliminateSequential(); diff --git a/gtsam/hybrid/tests/testHybridGaussianISAM.cpp b/gtsam/hybrid/tests/testHybridGaussianISAM.cpp index 04b44f9041..54964f6f7c 100644 --- a/gtsam/hybrid/tests/testHybridGaussianISAM.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianISAM.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -141,7 +142,7 @@ TEST(HybridGaussianISAM, IncrementalInference) { expectedRemainingGraph->eliminateMultifrontal(discreteOrdering); // Test the probability values with regression tests. - auto discrete = isam[M(1)]->conditional()->asDiscrete(); + auto discrete = isam[M(1)]->conditional()->asDiscrete(); EXPECT(assert_equal(0.095292, (*discrete)({{M(0), 0}, {M(1), 0}}), 1e-5)); EXPECT(assert_equal(0.282758, (*discrete)({{M(0), 1}, {M(1), 0}}), 1e-5)); EXPECT(assert_equal(0.314175, (*discrete)({{M(0), 0}, {M(1), 1}}), 1e-5)); @@ -221,16 +222,12 @@ TEST(HybridGaussianISAM, ApproxInference) { 1 1 1 Leaf 0.5 */ - auto discreteConditional_m0 = *dynamic_pointer_cast( + auto discreteConditional_m0 = *dynamic_pointer_cast( incrementalHybrid[M(0)]->conditional()->inner()); EXPECT(discreteConditional_m0.keys() == KeyVector({M(0), M(1), M(2)})); - // Get the number of elements which are greater than 0. - auto count = [](const double &value, int count) { - return value > 0 ? count + 1 : count; - }; // Check that the number of leaves after pruning is 5. - EXPECT_LONGS_EQUAL(5, discreteConditional_m0.fold(count, 0)); + EXPECT_LONGS_EQUAL(5, discreteConditional_m0.nrValues()); // Check that the hybrid nodes of the bayes net match those of the pre-pruning // bayes net, at the same positions. @@ -477,7 +474,9 @@ TEST(HybridGaussianISAM, NonTrivial) { // Test if the optimal discrete mode assignment is (1, 1, 1). DiscreteFactorGraph discreteGraph; - discreteGraph.push_back(discreteTree); + // discreteTree is a TableDistribution, so we convert to + // DecisionTreeFactor for the DiscreteFactorGraph + discreteGraph.push_back(discreteTree->toDecisionTreeFactor()); DiscreteValues optimal_assignment = discreteGraph.optimize(); DiscreteValues expected_assignment; diff --git a/gtsam/hybrid/tests/testHybridMotionModel.cpp b/gtsam/hybrid/tests/testHybridMotionModel.cpp index 747a1b6883..a4de6a17bd 100644 --- a/gtsam/hybrid/tests/testHybridMotionModel.cpp +++ b/gtsam/hybrid/tests/testHybridMotionModel.cpp @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -143,8 +144,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel) { // Since no measurement on x1, we hedge our bets // Importance sampling run with 100k samples gives 50.051/49.949 // approximateDiscreteMarginal(hbn, hybridMotionModel, given); - DiscreteConditional expected(m1, "50/50"); - EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()))); + TableDistribution expected(m1, "50 50"); + EXPECT( + assert_equal(expected, *(bn->at(2)->asDiscrete()))); } { @@ -160,8 +162,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel) { // Since we have a measurement on x1, we get a definite result // Values taken from an importance sampling run with 100k samples: // approximateDiscreteMarginal(hbn, hybridMotionModel, given); - DiscreteConditional expected(m1, "44.3854/55.6146"); - EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()), 0.002)); + TableDistribution expected(m1, "44.3854 55.6146"); + EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()), + 0.02)); } } @@ -248,8 +251,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel2) { // Values taken from an importance sampling run with 100k samples: // approximateDiscreteMarginal(hbn, hybridMotionModel, given); - DiscreteConditional expected(m1, "48.3158/51.6842"); - EXPECT(assert_equal(expected, *(eliminated->at(2)->asDiscrete()), 0.002)); + TableDistribution expected(m1, "48.3158 51.6842"); + EXPECT(assert_equal( + expected, *(eliminated->at(2)->asDiscrete()), 0.02)); } { @@ -263,8 +267,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel2) { // Values taken from an importance sampling run with 100k samples: // approximateDiscreteMarginal(hbn, hybridMotionModel, given); - DiscreteConditional expected(m1, "55.396/44.604"); - EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()), 0.002)); + TableDistribution expected(m1, "55.396 44.604"); + EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()), + 0.02)); } } @@ -340,8 +345,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel3) { // Values taken from an importance sampling run with 100k samples: // approximateDiscreteMarginal(hbn, hybridMotionModel, given); - DiscreteConditional expected(m1, "51.7762/48.2238"); - EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()), 0.002)); + TableDistribution expected(m1, "51.7762 48.2238"); + EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()), + 0.02)); } { @@ -355,8 +361,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel3) { // Values taken from an importance sampling run with 100k samples: // approximateDiscreteMarginal(hbn, hybridMotionModel, given); - DiscreteConditional expected(m1, "49.0762/50.9238"); - EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()), 0.005)); + TableDistribution expected(m1, "49.0762 50.9238"); + EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()), + 0.05)); } } @@ -381,8 +388,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel4) { // Values taken from an importance sampling run with 100k samples: // approximateDiscreteMarginal(hbn, hybridMotionModel, given); - DiscreteConditional expected(m1, "8.91527/91.0847"); - EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()), 0.002)); + TableDistribution expected(m1, "8.91527 91.0847"); + EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()), + 0.01)); } /* ************************************************************************* */ @@ -537,8 +545,8 @@ TEST(HybridGaussianFactorGraph, DifferentCovariances) { DiscreteValues dv0{{M(1), 0}}; DiscreteValues dv1{{M(1), 1}}; - DiscreteConditional expected_m1(m1, "0.5/0.5"); - DiscreteConditional actual_m1 = *(hbn->at(2)->asDiscrete()); + TableDistribution expected_m1(m1, "0.5 0.5"); + TableDistribution actual_m1 = *(hbn->at(2)->asDiscrete()); EXPECT(assert_equal(expected_m1, actual_m1)); } diff --git a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp index 07f70e95c4..3df03021b7 100644 --- a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -368,10 +369,9 @@ TEST(HybridGaussianElimination, EliminateHybrid_2_Variable) { EXPECT_LONGS_EQUAL(1, hybridGaussianConditional->nrParents()); // This is now a discreteFactor - auto discreteFactor = dynamic_pointer_cast(factorOnModes); + auto discreteFactor = dynamic_pointer_cast(factorOnModes); CHECK(discreteFactor); EXPECT_LONGS_EQUAL(1, discreteFactor->discreteKeys().size()); - EXPECT(discreteFactor->root_->isLeaf() == false); } /**************************************************************************** @@ -513,9 +513,10 @@ TEST(HybridNonlinearFactorGraph, Full_Elimination) { // P(m1) EXPECT(hybridBayesNet->at(4)->frontals() == KeyVector{M(1)}); EXPECT_LONGS_EQUAL(0, hybridBayesNet->at(4)->nrParents()); - EXPECT( - dynamic_pointer_cast(hybridBayesNet->at(4)->inner()) - ->equals(*discreteBayesNet.at(1))); + TableDistribution dtc = + *hybridBayesNet->at(4)->asDiscrete(); + EXPECT(DiscreteConditional(dtc.nrFrontals(), dtc.toDecisionTreeFactor()) + .equals(*discreteBayesNet.at(1))); } /**************************************************************************** @@ -1062,8 +1063,8 @@ TEST(HybridNonlinearFactorGraph, DifferentCovariances) { DiscreteValues dv0{{M(1), 0}}; DiscreteValues dv1{{M(1), 1}}; - DiscreteConditional expected_m1(m1, "0.5/0.5"); - DiscreteConditional actual_m1 = *(hbn->at(2)->asDiscrete()); + TableDistribution expected_m1(m1, "0.5 0.5"); + TableDistribution actual_m1 = *(hbn->at(2)->asDiscrete()); EXPECT(assert_equal(expected_m1, actual_m1)); } diff --git a/gtsam/hybrid/tests/testHybridNonlinearISAM.cpp b/gtsam/hybrid/tests/testHybridNonlinearISAM.cpp index 67cec83199..b32860cffb 100644 --- a/gtsam/hybrid/tests/testHybridNonlinearISAM.cpp +++ b/gtsam/hybrid/tests/testHybridNonlinearISAM.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -265,16 +266,12 @@ TEST(HybridNonlinearISAM, ApproxInference) { 1 1 1 Leaf 0.5 */ - auto discreteConditional_m0 = *dynamic_pointer_cast( + auto discreteConditional_m0 = *dynamic_pointer_cast( bayesTree[M(0)]->conditional()->inner()); EXPECT(discreteConditional_m0.keys() == KeyVector({M(0), M(1), M(2)})); - // Get the number of elements which are greater than 0. - auto count = [](const double &value, int count) { - return value > 0 ? count + 1 : count; - }; // Check that the number of leaves after pruning is 5. - EXPECT_LONGS_EQUAL(5, discreteConditional_m0.fold(count, 0)); + EXPECT_LONGS_EQUAL(5, discreteConditional_m0.nrValues()); // Check that the hybrid nodes of the bayes net match those of the pre-pruning // bayes net, at the same positions. @@ -520,12 +517,13 @@ TEST(HybridNonlinearISAM, NonTrivial) { // The final discrete graph should not be empty since we have eliminated // all continuous variables. - auto discreteTree = bayesTree[M(3)]->conditional()->asDiscrete(); + auto discreteTree = + bayesTree[M(3)]->conditional()->asDiscrete(); EXPECT_LONGS_EQUAL(3, discreteTree->size()); // Test if the optimal discrete mode assignment is (1, 1, 1). DiscreteFactorGraph discreteGraph; - discreteGraph.push_back(discreteTree); + discreteGraph.push_back(discreteTree->toDecisionTreeFactor()); DiscreteValues optimal_assignment = discreteGraph.optimize(); DiscreteValues expected_assignment; diff --git a/gtsam/hybrid/tests/testSerializationHybrid.cpp b/gtsam/hybrid/tests/testSerializationHybrid.cpp index 3be96b7512..9aabe309bc 100644 --- a/gtsam/hybrid/tests/testSerializationHybrid.cpp +++ b/gtsam/hybrid/tests/testSerializationHybrid.cpp @@ -18,6 +18,7 @@ #include #include +#include #include #include #include @@ -44,6 +45,7 @@ BOOST_CLASS_EXPORT_GUID(HybridFactor, "gtsam_HybridFactor"); BOOST_CLASS_EXPORT_GUID(JacobianFactor, "gtsam_JacobianFactor"); BOOST_CLASS_EXPORT_GUID(GaussianConditional, "gtsam_GaussianConditional"); BOOST_CLASS_EXPORT_GUID(DiscreteConditional, "gtsam_DiscreteConditional"); +BOOST_CLASS_EXPORT_GUID(TableDistribution, "gtsam_TableDistribution"); BOOST_CLASS_EXPORT_GUID(DecisionTreeFactor, "gtsam_DecisionTreeFactor"); using ADT = AlgebraicDecisionTree; diff --git a/python/gtsam/tests/test_HybridFactorGraph.py b/python/gtsam/tests/test_HybridFactorGraph.py index 6d609deb03..6edab34494 100644 --- a/python/gtsam/tests/test_HybridFactorGraph.py +++ b/python/gtsam/tests/test_HybridFactorGraph.py @@ -13,14 +13,14 @@ import unittest import numpy as np -from gtsam.symbol_shorthand import C, M, X, Z -from gtsam.utils.test_case import GtsamTestCase import gtsam -from gtsam import (DiscreteConditional, GaussianConditional, - HybridBayesNet, HybridGaussianConditional, - HybridGaussianFactor, HybridGaussianFactorGraph, - HybridValues, JacobianFactor, noiseModel) +from gtsam import (DiscreteConditional, GaussianConditional, HybridBayesNet, + HybridGaussianConditional, HybridGaussianFactor, + HybridGaussianFactorGraph, HybridValues, JacobianFactor, + TableDistribution, noiseModel) +from gtsam.symbol_shorthand import C, M, X, Z +from gtsam.utils.test_case import GtsamTestCase DEBUG_MARGINALS = False @@ -51,7 +51,7 @@ def test_create(self): self.assertEqual(len(hybridCond.keys()), 2) discrete_conditional = hbn.at(hbn.size() - 1).inner() - self.assertIsInstance(discrete_conditional, DiscreteConditional) + self.assertIsInstance(discrete_conditional, TableDistribution) def test_optimize(self): """Test construction of hybrid factor graph."""