Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hybrid Timing #1948

Merged
merged 138 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
138 commits
Select commit Hold shift + click to select a range
4437baf
expose GTSAM_ENABLE_TIMING
varunagrawal Dec 18, 2024
53cf49b
code to print timing as CSV
varunagrawal Dec 23, 2024
ad3967e
Merge branch 'develop' into hybrid-timing
varunagrawal Dec 23, 2024
1c80fa1
Merge branch 'develop' into hybrid-timing
varunagrawal Dec 23, 2024
05e01b1
remove extra space after comma
varunagrawal Dec 24, 2024
2b9ab0e
Merge branch 'discrete-improvements' into hybrid-timing
varunagrawal Dec 27, 2024
7f42093
Merge branch 'develop' into hybrid-timing
varunagrawal Dec 27, 2024
4cf0727
Merge branch 'discrete-improvements' into hybrid-timing
varunagrawal Dec 27, 2024
7c9d04f
conditional switch for hybrid timing
varunagrawal Dec 27, 2024
4d96af7
update config.h.in
varunagrawal Dec 27, 2024
02d461e
make a cmake flag
varunagrawal Dec 31, 2024
34fba68
use TableFactor instead of DecisionTreeFactor in discrete elimination
varunagrawal Dec 27, 2024
de652ea
initial DiscreteTableConditional
varunagrawal Dec 30, 2024
b57e448
DiscreteConditional evaluate method for conditionals
varunagrawal Dec 31, 2024
d18f23c
setData method
varunagrawal Dec 31, 2024
4ff7014
use a TableFactor as the underlying data representation for DiscreteT…
varunagrawal Dec 31, 2024
b39b200
fix return type
varunagrawal Dec 31, 2024
d9faa82
add evaluate and getter
varunagrawal Dec 31, 2024
60945c8
add override methods to DiscreteTableConditional
varunagrawal Dec 31, 2024
e46e9d6
use DiscreteTableConditional in EliminateDiscrete
varunagrawal Dec 31, 2024
b7b2734
small cleanup
varunagrawal Dec 31, 2024
214043d
use DiscreteConditional shared_ptr for dynamic dispatch
varunagrawal Dec 31, 2024
dfec840
use TableFactor for discrete elimination
varunagrawal Dec 31, 2024
5019153
small cleanup
varunagrawal Dec 31, 2024
623bd63
fix hybrid tests
varunagrawal Dec 31, 2024
9f85d4c
fix equals
varunagrawal Dec 31, 2024
9cacb98
undo changes to DiscreteFactorGraph
varunagrawal Dec 31, 2024
c6e9bfc
remove unused methods
varunagrawal Dec 31, 2024
f95ae52
Use TableFactor everywhere in hybrid elimination
varunagrawal Dec 31, 2024
71ea8c5
fix tests
varunagrawal Dec 31, 2024
42f8e54
customize discrete elimination in Hybrid
varunagrawal Dec 31, 2024
47e76ff
remove GTSAM_HYBRID_TIMING guards
varunagrawal Dec 31, 2024
0820fcb
fix types
varunagrawal Dec 31, 2024
34eb0fc
Merge branch 'hybrid-custom-discrete' into discrete-table-conditional
varunagrawal Dec 31, 2024
462a5b8
return DiscreteTableConditional from hybrid elimination
varunagrawal Dec 31, 2024
5e1931e
update testGaussianMixture
varunagrawal Dec 31, 2024
3119d13
remove evaluate method
varunagrawal Dec 31, 2024
9e1c0d7
fix constructor and equals
varunagrawal Dec 31, 2024
094b76d
fix bug in TableFactor when trying to convert to DecisionTreeFactor
varunagrawal Dec 31, 2024
bf4c0bd
fix creation of DiscreteConditional
varunagrawal Dec 31, 2024
0e2e8bb
full discrete elimination
varunagrawal Dec 31, 2024
73f5408
normalize
varunagrawal Jan 1, 2025
30670ab
Merge pull request #1954 from borglab/hybrid-with-tablefactor
varunagrawal Jan 1, 2025
a71008d
new helper constructor for DiscreteConditional
varunagrawal Jan 1, 2025
57c426a
simplify discrete conditional computation
varunagrawal Jan 1, 2025
ffa40f7
small fix
varunagrawal Jan 1, 2025
ab47ade
fix empty keys case
varunagrawal Jan 1, 2025
6e4d1fa
rename
varunagrawal Jan 1, 2025
022ed50
move common typedef to top
varunagrawal Jan 1, 2025
782f39a
Merge branch 'hybrid-custom-discrete' into discrete-table-conditional
varunagrawal Jan 1, 2025
e854d15
evaluate needed for correct test results
varunagrawal Jan 1, 2025
ec5d87e
custom discreteMaxProduct
varunagrawal Jan 1, 2025
2a5833b
custom ProductAndNormalize for TableFactor
varunagrawal Jan 1, 2025
6f19ffd
fixed maxProduct
varunagrawal Jan 1, 2025
e49b40b
remove TableFactor check for another day
varunagrawal Jan 2, 2025
bb4ee20
custom path for empty separator
varunagrawal Jan 2, 2025
2894c95
clarify TableProduct function
varunagrawal Jan 2, 2025
d22ba29
remove DiscreteConditional constructor since we no longer use it
varunagrawal Jan 2, 2025
cc237a2
Merge branch 'hybrid-custom-discrete' into discrete-table-conditional
varunagrawal Jan 2, 2025
e56fac2
fix TableProduct name
varunagrawal Jan 2, 2025
26e1f08
fix testGaussianMixture
varunagrawal Jan 2, 2025
c7c42af
undo HybridBayesNet changes
varunagrawal Jan 2, 2025
27e3a04
fix testHybridGaussianFactorGraph
varunagrawal Jan 2, 2025
cafac63
fix to use DiscreteTableConditional
varunagrawal Jan 2, 2025
35502f3
custom max-product for HybridBayesTree
varunagrawal Jan 2, 2025
62a6558
fix discreteMaxProduct declaration
varunagrawal Jan 2, 2025
5d2d879
make asDiscrete a template
varunagrawal Jan 2, 2025
4c5b842
add checks
varunagrawal Jan 2, 2025
e620729
fix testHybridEstimation
varunagrawal Jan 2, 2025
d18569b
fix testGaussianMixture
varunagrawal Jan 2, 2025
769e2c7
fix testHybridMotionModel
varunagrawal Jan 2, 2025
da22055
formatting
varunagrawal Jan 2, 2025
fcc56f5
fix pruning test in testHybridBayesNet
varunagrawal Jan 2, 2025
f80a3a1
fix testHybridGaussianFactorGraph
varunagrawal Jan 2, 2025
b343a80
more helper methods in DiscreteTableConditional
varunagrawal Jan 2, 2025
e6db6d1
cleaner API
varunagrawal Jan 2, 2025
fd2820e
fix testHybridNonlinearFactorGraph
varunagrawal Jan 2, 2025
0518b60
Merge branch 'develop' into hybrid-timing
varunagrawal Jan 2, 2025
02d9959
small fix
varunagrawal Jan 2, 2025
113492f
separate function to collect discrete factors
varunagrawal Jan 2, 2025
32317d4
simplify empty separator check
varunagrawal Jan 2, 2025
49b74af
Merge branch 'develop' into hybrid-timing
varunagrawal Jan 2, 2025
7440c19
Merge branch 'hybrid-timing' into hybrid-custom-discrete
varunagrawal Jan 2, 2025
446263c
Merge branch 'hybrid-custom-discrete' into discrete-table-conditional
varunagrawal Jan 2, 2025
b9293b4
fix testHybridGaussianISAM
varunagrawal Jan 3, 2025
8e36361
fix testHybridNonlinearISAM
varunagrawal Jan 3, 2025
62a35c0
serialize table inside TableDistribution
varunagrawal Jan 3, 2025
73f98d8
Merge pull request #1955 from borglab/hybrid-custom-discrete
varunagrawal Jan 3, 2025
0302075
serialize functions for Eigen::SparseVector
varunagrawal Jan 3, 2025
92b5bb1
add serialization code to TableFactor
varunagrawal Jan 3, 2025
5041468
test for TableFactor serialization
varunagrawal Jan 3, 2025
b28cae2
use string based constructor
varunagrawal Jan 3, 2025
0098112
Merge branch 'hybrid-timing' into discrete-table-conditional
varunagrawal Jan 3, 2025
9b1918c
rename from DiscreteTableConditional to TableDistribution
varunagrawal Jan 3, 2025
e1628e3
rename source files
varunagrawal Jan 3, 2025
83bb404
export TableDistribution for serialization
varunagrawal Jan 3, 2025
2e06954
improved docstring
varunagrawal Jan 3, 2025
35e1e61
kill operator* method
varunagrawal Jan 3, 2025
bc449c1
formatting
varunagrawal Jan 3, 2025
bd30bef
remove constructors that need parents
varunagrawal Jan 4, 2025
f9e3280
add helpful constructors
varunagrawal Jan 4, 2025
3abff90
fix tests
varunagrawal Jan 4, 2025
11a740e
use template
varunagrawal Jan 4, 2025
b7bddde
fix TableDistribution constructor call
varunagrawal Jan 4, 2025
d6bc1e1
pass DiscreteConditional& for pruning instead of shared_ptr
varunagrawal Jan 4, 2025
9a40be6
normalize values in sparse_table so it forms a proper distribution
varunagrawal Jan 4, 2025
7cb8181
fix TableDistribution constructors in tests
varunagrawal Jan 4, 2025
d39641d
get rid of setData and make prune() imperative for non-factors
varunagrawal Jan 4, 2025
d378015
update pruning in BayesNet and BayesTree
varunagrawal Jan 4, 2025
14f3254
update test
varunagrawal Jan 4, 2025
5a8a942
add argmax method to TableDistribution
varunagrawal Jan 4, 2025
2410d4f
use TableDistribution::argmax in discreteMaxProduct
varunagrawal Jan 4, 2025
5e4cf89
max returns DiscreteFactor
varunagrawal Jan 4, 2025
ffc20f8
wrap TableDistribution
varunagrawal Jan 4, 2025
e9abd5c
wrap TableFactor
varunagrawal Jan 4, 2025
9a356f1
typo fix
varunagrawal Jan 4, 2025
aba691d
fix python test
varunagrawal Jan 4, 2025
69b5e7d
return DiscreteValues directly
varunagrawal Jan 4, 2025
07a6829
code cleanup
varunagrawal Jan 4, 2025
bcc52be
emplace then prune
varunagrawal Jan 4, 2025
77f3874
remove deleted constructors
varunagrawal Jan 4, 2025
edef8c8
Merge branch 'develop' into hybrid-timing
varunagrawal Jan 7, 2025
8658f25
Merge branch 'hybrid-timing' into discrete-table-conditional
varunagrawal Jan 7, 2025
5913fd1
updates to get things working
varunagrawal Jan 7, 2025
90825b9
remove hybrid timing flag from DiscreteFactorGraph
varunagrawal Jan 7, 2025
82dba63
new scaledProduct method instead of DiscreteProduct
varunagrawal Jan 7, 2025
9960f2d
kill TableProduct in favor of DiscreteFactorGraph::scaledProduct
varunagrawal Jan 7, 2025
96a136b
override sum and max in TableDistribution
varunagrawal Jan 7, 2025
3fb6f39
override operator/ in TableDistribution
varunagrawal Jan 7, 2025
3d2dd7c
update scaledProduct docs
varunagrawal Jan 7, 2025
9228f0f
fix headers
varunagrawal Jan 7, 2025
b81ab86
make ADT with nullptr in TableDistribution
varunagrawal Jan 7, 2025
3629c33
override sample in TableDistribution
varunagrawal Jan 7, 2025
9dfdf55
add hack to multiply DiscreteConditional with TableDistribution
varunagrawal Jan 7, 2025
9c2ecc3
simplify multiplication
varunagrawal Jan 7, 2025
4fc2387
fix relinearization in HybridNonlinearISAM
varunagrawal Jan 7, 2025
3ecc232
fix tests
varunagrawal Jan 7, 2025
8cf2123
Merge pull request #1953 from borglab/discrete-table-conditional
varunagrawal Jan 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions cmake/GtsamBuildTypes.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,12 @@ foreach(build_type "common" ${GTSAM_CMAKE_CONFIGURATION_TYPES})
append_config_if_not_empty(GTSAM_COMPILE_DEFINITIONS_PUBLIC ${build_type})
endforeach()

# Check if timing is enabled and add appropriate definition flag
if(GTSAM_ENABLE_TIMING AND(NOT ${CMAKE_BUILD_TYPE} EQUAL "Timing"))
message(STATUS "Enabling timing for non-timing build")
list_append_cache(GTSAM_COMPILE_DEFINITIONS_PRIVATE "ENABLE_TIMING")
endif()

# Linker flags:
set(GTSAM_CMAKE_SHARED_LINKER_FLAGS_TIMING "${CMAKE_SHARED_LINKER_FLAGS_RELEASE}" CACHE STRING "Linker flags during timing builds.")
set(GTSAM_CMAKE_MODULE_LINKER_FLAGS_TIMING "${CMAKE_MODULE_LINKER_FLAGS_RELEASE}" CACHE STRING "Linker flags during timing builds.")
Expand Down
2 changes: 2 additions & 0 deletions cmake/HandleGeneralOptions.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ option(GTSAM_USE_QUATERNIONS "Enable/Disable using an internal Qu
option(GTSAM_POSE3_EXPMAP "Enable/Disable using Pose3::EXPMAP as the default mode. If disabled, Pose3::FIRST_ORDER will be used." ON)
option(GTSAM_ROT3_EXPMAP "Ignore if GTSAM_USE_QUATERNIONS is OFF (Rot3::EXPMAP by default). Otherwise, enable Rot3::EXPMAP, or if disabled, use Rot3::CAYLEY." ON)
option(GTSAM_DT_MERGING "Enable/Disable merging of equal leaf nodes in DecisionTrees. This leads to significant speed up and memory savings." ON)
option(GTSAM_ENABLE_TIMING "Enable the timing tools (gttic/gttoc)" OFF)
option(GTSAM_HYBRID_TIMING "Enable the timing of hybrid factor graph machinery" OFF)
option(GTSAM_ENABLE_CONSISTENCY_CHECKS "Enable/Disable expensive consistency checks" OFF)
option(GTSAM_ENABLE_MEMORY_SANITIZER "Enable/Disable memory sanitizer" OFF)
option(GTSAM_WITH_TBB "Use Intel Threaded Building Blocks (TBB) if available" ON)
Expand Down
1 change: 1 addition & 0 deletions cmake/HandlePrintConfiguration.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ print_enabled_config(${GTSAM_ENABLE_MEMORY_SANITIZER} "Build with Memory San
print_enabled_config(${GTSAM_ROT3_EXPMAP} "Rot3 retract is full ExpMap ")
print_enabled_config(${GTSAM_POSE3_EXPMAP} "Pose3 retract is full ExpMap ")
print_enabled_config(${GTSAM_DT_MERGING} "Enable branch merging in DecisionTree")
print_enabled_config(${GTSAM_ENABLE_TIMING} "Enable timing machinery")
print_enabled_config(${GTSAM_ALLOW_DEPRECATED_SINCE_V43} "Allow features deprecated in GTSAM 4.3")
print_enabled_config(${GTSAM_SUPPORT_NESTED_DISSECTION} "Metis-based Nested Dissection ")
print_enabled_config(${GTSAM_TANGENT_PREINTEGRATION} "Use tangent-space preintegration")
Expand Down
53 changes: 51 additions & 2 deletions gtsam/base/timing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@

namespace gtsam {
namespace internal {


using ChildOrder = FastMap<size_t, std::shared_ptr<TimingOutline>>;

// a static shared_ptr to TimingOutline with nullptr as the pointer
const static std::shared_ptr<TimingOutline> nullTimingOutline;

Expand Down Expand Up @@ -91,7 +93,6 @@ void TimingOutline::print(const std::string& outline) const {
<< n_ << " times, " << wall() << " wall, " << secs() << " children, min: "
<< min() << " max: " << max() << ")\n";
// Order children
typedef FastMap<size_t, std::shared_ptr<TimingOutline> > ChildOrder;
ChildOrder childOrder;
for(const ChildMap::value_type& child: children_) {
childOrder[child.second->myOrder_] = child.second;
Expand All @@ -106,6 +107,54 @@ void TimingOutline::print(const std::string& outline) const {
#endif
}

/* ************************************************************************* */
void TimingOutline::printCsvHeader(bool addLineBreak) const {
#ifdef GTSAM_USE_BOOST_FEATURES
dellaert marked this conversation as resolved.
Show resolved Hide resolved
// Order is (CPU time, number of times, wall time, time + children in seconds,
// min time, max time)
std::cout << label_ + " cpu time (s)" << "," << label_ + " #calls" << ","
<< label_ + " wall time(s)" << "," << label_ + " subtree time (s)"
<< "," << label_ + " min time (s)" << "," << label_ + "max time(s)"
<< ",";
// Order children
ChildOrder childOrder;
for (const ChildMap::value_type& child : children_) {
childOrder[child.second->myOrder_] = child.second;
}
// Print children
for (const ChildOrder::value_type& order_child : childOrder) {
order_child.second->printCsvHeader();
}
if (addLineBreak) {
std::cout << std::endl;
}
std::cout.flush();
#endif
}

/* ************************************************************************* */
void TimingOutline::printCsv(bool addLineBreak) const {
#ifdef GTSAM_USE_BOOST_FEATURES
// Order is (CPU time, number of times, wall time, time + children in seconds,
// min time, max time)
std::cout << self() << "," << n_ << "," << wall() << "," << secs() << ","
<< min() << "," << max() << ",";
// Order children
ChildOrder childOrder;
for (const ChildMap::value_type& child : children_) {
childOrder[child.second->myOrder_] = child.second;
}
// Print children
for (const ChildOrder::value_type& order_child : childOrder) {
order_child.second->printCsv(false);
}
if (addLineBreak) {
std::cout << std::endl;
}
std::cout.flush();
#endif
}

void TimingOutline::print2(const std::string& outline,
const double parentTotal) const {
#ifdef GTSAM_USE_BOOST_FEATURES
Expand Down
31 changes: 31 additions & 0 deletions gtsam/base/timing.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,29 @@ namespace gtsam {
#endif
GTSAM_EXPORT void print(const std::string& outline = "") const;
GTSAM_EXPORT void print2(const std::string& outline = "", const double parentTotal = -1.0) const;

/**
* @brief Print the CSV header.
* Order is
* (CPU time, number of times, wall time, time + children in seconds, min
* time, max time)
*
* @param addLineBreak Flag indicating if a line break should be added at
* the end. Only used at the top-leve.
*/
GTSAM_EXPORT void printCsvHeader(bool addLineBreak = false) const;

/**
* @brief Print the times recursively from parent to child in CSV format.
* For each timing node, the output is
* (CPU time, number of times, wall time, time + children in seconds, min
* time, max time)
*
* @param addLineBreak Flag indicating if a line break should be added at
* the end. Only used at the top-leve.
*/
GTSAM_EXPORT void printCsv(bool addLineBreak = false) const;

GTSAM_EXPORT const std::shared_ptr<TimingOutline>&
child(size_t child, const std::string& label, const std::weak_ptr<TimingOutline>& thisPtr);
GTSAM_EXPORT void tic();
Expand Down Expand Up @@ -268,6 +291,14 @@ inline void tictoc_finishedIteration_() {
inline void tictoc_print_() {
::gtsam::internal::gTimingRoot->print(); }

// print timing in CSV format
inline void tictoc_printCsv_(bool displayHeader = false) {
if (displayHeader) {
::gtsam::internal::gTimingRoot->printCsvHeader(true);
}
::gtsam::internal::gTimingRoot->printCsv(true);
}

// print mean and standard deviation
inline void tictoc_print2_() {
::gtsam::internal::gTimingRoot->print2(); }
Expand Down
3 changes: 3 additions & 0 deletions gtsam/config.h.in
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@
// Whether to enable merging of equal leaf nodes in the Discrete Decision Tree.
#cmakedefine GTSAM_DT_MERGING

// Whether to enable timing in hybrid factor graph machinery
#cmakedefine01 GTSAM_HYBRID_TIMING

// Whether we are using TBB (if TBB was found and GTSAM_WITH_TBB is enabled in CMake)
#cmakedefine GTSAM_USE_TBB

Expand Down
30 changes: 24 additions & 6 deletions gtsam/discrete/DiscreteFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,15 +121,25 @@ namespace gtsam {
static DecisionTreeFactor ProductAndNormalize(
const DiscreteFactorGraph& factors) {
// PRODUCT: multiply all factors
gttic(product);
#if GTSAM_HYBRID_TIMING
dellaert marked this conversation as resolved.
Show resolved Hide resolved
gttic_(DiscreteProduct);
#endif
DecisionTreeFactor product = factors.product();
gttoc(product);
#if GTSAM_HYBRID_TIMING
gttoc_(DiscreteProduct);
#endif

// Max over all the potentials by pretending all keys are frontal:
auto normalizer = product.max(product.size());

#if GTSAM_HYBRID_TIMING
gttic_(DiscreteNormalize);
#endif
// Normalize the product factor to prevent underflow.
product = product / (*normalizer);
#if GTSAM_HYBRID_TIMING
gttoc_(DiscreteNormalize);
#endif

return product;
}
Expand Down Expand Up @@ -220,9 +230,13 @@ namespace gtsam {
DecisionTreeFactor product = ProductAndNormalize(factors);

// sum out frontals, this is the factor on the separator
gttic(sum);
#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscreteSum);
#endif
DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys);
gttoc(sum);
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscreteSum);
#endif

// Ordering keys for the conditional so that frontalKeys are really in front
Ordering orderedKeys;
Expand All @@ -232,10 +246,14 @@ namespace gtsam {
sum->keys().end());

// now divide product/sum to get conditional
gttic(divide);
#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscreteToDiscreteConditional);
#endif
auto conditional =
std::make_shared<DiscreteConditional>(product, *sum, orderedKeys);
gttoc(divide);
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscreteToDiscreteConditional);
#endif

return {conditional, sum};
}
Expand Down
51 changes: 44 additions & 7 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@

#include <gtsam/base/utilities.h>
#include <gtsam/discrete/Assignment.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteEliminationTree.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteJunctionTree.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/discrete/TableFactor.h>
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridEliminationTree.h>
#include <gtsam/hybrid/HybridFactor.h>
Expand Down Expand Up @@ -241,18 +241,18 @@ continuousElimination(const HybridGaussianFactorGraph &factors,
/* ************************************************************************ */
/**
* @brief Take negative log-values, shift them so that the minimum value is 0,
* and then exponentiate to create a DecisionTreeFactor (not normalized yet!).
* and then exponentiate to create a TableFactor (not normalized yet!).
*
* @param errors DecisionTree of (unnormalized) errors.
* @return DecisionTreeFactor::shared_ptr
* @return TableFactor::shared_ptr
*/
static DecisionTreeFactor::shared_ptr DiscreteFactorFromErrors(
static TableFactor::shared_ptr DiscreteFactorFromErrors(
const DiscreteKeys &discreteKeys,
const AlgebraicDecisionTree<Key> &errors) {
double min_log = errors.min();
AlgebraicDecisionTree<Key> potentials(
errors, [&min_log](const double x) { return exp(-(x - min_log)); });
return std::make_shared<DecisionTreeFactor>(discreteKeys, potentials);
return std::make_shared<TableFactor>(discreteKeys, potentials);
}

/* ************************************************************************ */
Expand Down Expand Up @@ -282,14 +282,28 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
} else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
auto dc = hc->asDiscrete();
if (!dc) throwRuntimeError("discreteElimination", dc);
dfg.push_back(dc);
#if GTSAM_HYBRID_TIMING
gttic_(ConvertConditionalToTableFactor);
#endif
// Convert DiscreteConditional to TableFactor
auto tdc = std::make_shared<TableFactor>(*dc);
#if GTSAM_HYBRID_TIMING
gttoc_(ConvertConditionalToTableFactor);
#endif
dfg.push_back(tdc);
} else {
throwRuntimeError("discreteElimination", f);
}
}

#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscrete);
#endif
// NOTE: This does sum-product. For max-product, use EliminateForMPE.
auto result = EliminateDiscrete(dfg, frontalKeys);
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscrete);
#endif

return {std::make_shared<HybridConditional>(result.first), result.second};
}
Expand Down Expand Up @@ -319,8 +333,19 @@ static std::shared_ptr<Factor> createDiscreteFactor(
}
};

#if GTSAM_HYBRID_TIMING
gttic_(DiscreteBoundaryErrors);
#endif
AlgebraicDecisionTree<Key> errors(eliminationResults, calculateError);
return DiscreteFactorFromErrors(discreteSeparator, errors);
#if GTSAM_HYBRID_TIMING
gttoc_(DiscreteBoundaryErrors);
gttic_(DiscreteBoundaryResult);
#endif
auto result = DiscreteFactorFromErrors(discreteSeparator, errors);
#if GTSAM_HYBRID_TIMING
gttoc_(DiscreteBoundaryResult);
#endif
return result;
}

/* *******************************************************************************/
Expand Down Expand Up @@ -360,12 +385,18 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
// the discrete separator will be *all* the discrete keys.
DiscreteKeys discreteSeparator = GetDiscreteKeys(*this);

#if GTSAM_HYBRID_TIMING
gttic_(HybridCollectProductFactor);
#endif
// Collect all the factors to create a set of Gaussian factor graphs in a
// decision tree indexed by all discrete keys involved. Just like any hybrid
// factor, every assignment also has a scalar error, in this case the sum of
// all errors in the graph. This error is assignment-specific and accounts for
// any difference in noise models used.
HybridGaussianProductFactor productFactor = collectProductFactor();
#if GTSAM_HYBRID_TIMING
gttoc_(HybridCollectProductFactor);
#endif

// Check if a factor is null
auto isNull = [](const GaussianFactor::shared_ptr &ptr) { return !ptr; };
Expand Down Expand Up @@ -393,8 +424,14 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
return {conditional, conditional->negLogConstant(), factor, scalar};
};

#if GTSAM_HYBRID_TIMING
gttic_(HybridEliminate);
#endif
// Perform elimination!
const ResultTree eliminationResults(productFactor, eliminate);
#if GTSAM_HYBRID_TIMING
gttoc_(HybridEliminate);
#endif

// If there are no more continuous parents we create a DiscreteFactor with the
// error for each discrete choice. Otherwise, create a HybridGaussianFactor
Expand Down
6 changes: 6 additions & 0 deletions gtsam/hybrid/HybridGaussianISAM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,13 @@ void HybridGaussianISAM::updateInternal(
elimination_ordering, function, std::cref(index));

if (maxNrLeaves) {
#if GTSAM_HYBRID_TIMING
varunagrawal marked this conversation as resolved.
Show resolved Hide resolved
gttic_(HybridBayesTreePrune);
#endif
bayesTree->prune(*maxNrLeaves);
#if GTSAM_HYBRID_TIMING
gttoc_(HybridBayesTreePrune);
#endif
}

// Re-add into Bayes tree data structures
Expand Down
6 changes: 3 additions & 3 deletions gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,10 @@ TEST(HybridGaussianFactorGraph, hybridEliminationOneFactor) {
EXPECT(HybridConditional::CheckInvariants(*result.first, values));

// Check that factor is discrete and correct
auto factor = std::dynamic_pointer_cast<DecisionTreeFactor>(result.second);
auto factor = std::dynamic_pointer_cast<TableFactor>(result.second);
CHECK(factor);
// regression test
EXPECT(assert_equal(DecisionTreeFactor{m1, "1 1"}, *factor, 1e-5));
EXPECT(assert_equal(TableFactor{m1, "1 1"}, *factor, 1e-5));
}

/* ************************************************************************* */
Expand Down Expand Up @@ -329,7 +329,7 @@ TEST(HybridBayesNet, Switching) {

// Check the remaining factor for x1
CHECK(factor_x1);
auto phi_x1 = std::dynamic_pointer_cast<DecisionTreeFactor>(factor_x1);
auto phi_x1 = std::dynamic_pointer_cast<TableFactor>(factor_x1);
CHECK(phi_x1);
EXPECT_LONGS_EQUAL(1, phi_x1->keys().size()); // m0
// We can't really check the error of the decision tree factor phi_x1, because
Expand Down
3 changes: 1 addition & 2 deletions gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -368,10 +368,9 @@ TEST(HybridGaussianElimination, EliminateHybrid_2_Variable) {
EXPECT_LONGS_EQUAL(1, hybridGaussianConditional->nrParents());

// This is now a discreteFactor
auto discreteFactor = dynamic_pointer_cast<DecisionTreeFactor>(factorOnModes);
auto discreteFactor = dynamic_pointer_cast<TableFactor>(factorOnModes);
CHECK(discreteFactor);
EXPECT_LONGS_EQUAL(1, discreteFactor->discreteKeys().size());
EXPECT(discreteFactor->root_->isLeaf() == false);
}

/****************************************************************************
Expand Down
Loading