Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Customize discrete elimination in Hybrid #1955

Merged
merged 19 commits into from
Jan 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion gtsam/discrete/DiscreteConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@
#include <gtsam/hybrid/HybridValues.h>

#include <algorithm>
#include <cassert>
#include <random>
#include <set>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
#include <cassert>

using namespace std;
using std::pair;
Expand Down
18 changes: 0 additions & 18 deletions gtsam/discrete/DiscreteFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,7 @@ namespace gtsam {
static DecisionTreeFactor DiscreteProduct(
const DiscreteFactorGraph& factors) {
// PRODUCT: multiply all factors
#if GTSAM_HYBRID_TIMING
gttic_(DiscreteProduct);
#endif
DecisionTreeFactor product = factors.product();
#if GTSAM_HYBRID_TIMING
gttoc_(DiscreteProduct);
#endif

#if GTSAM_HYBRID_TIMING
gttic_(DiscreteNormalize);
Expand Down Expand Up @@ -229,13 +223,7 @@ namespace gtsam {
DecisionTreeFactor product = DiscreteProduct(factors);

// sum out frontals, this is the factor on the separator
#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscreteSum);
#endif
DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys);
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscreteSum);
#endif

// Ordering keys for the conditional so that frontalKeys are really in front
Ordering orderedKeys;
Expand All @@ -245,14 +233,8 @@ namespace gtsam {
sum->keys().end());

// now divide product/sum to get conditional
#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscreteToDiscreteConditional);
#endif
auto conditional =
std::make_shared<DiscreteConditional>(product, *sum, orderedKeys);
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscreteToDiscreteConditional);
#endif

return {conditional, sum};
}
Expand Down
86 changes: 79 additions & 7 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,15 +255,55 @@ static TableFactor::shared_ptr DiscreteFactorFromErrors(
return std::make_shared<TableFactor>(discreteKeys, potentials);
}

/**
* @brief Multiply all the `factors` using the machinery of the TableFactor.
*
* @param factors The factors to multiply as a DiscreteFactorGraph.
* @return TableFactor
*/
static TableFactor TableProduct(const DiscreteFactorGraph &factors) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can be in another PR, but maybe this should be a method of DiscreteFactorGraph

// PRODUCT: multiply all factors
#if GTSAM_HYBRID_TIMING
gttic_(DiscreteProduct);
#endif
TableFactor product;
for (auto &&factor : factors) {
if (factor) {
if (auto f = std::dynamic_pointer_cast<TableFactor>(factor)) {
product = product * (*f);
} else if (auto dtf =
std::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
product = product * TableFactor(*dtf);
}
}
}
#if GTSAM_HYBRID_TIMING
gttoc_(DiscreteProduct);
#endif

#if GTSAM_HYBRID_TIMING
gttic_(DiscreteNormalize);
#endif
// Max over all the potentials by pretending all keys are frontal:
auto denominator = product.max(product.size());
// Normalize the product factor to prevent underflow.
product = product / (*denominator);
#if GTSAM_HYBRID_TIMING
gttoc_(DiscreteNormalize);
#endif

return product;
}

/* ************************************************************************ */
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
discreteElimination(const HybridGaussianFactorGraph &factors,
const Ordering &frontalKeys) {
static DiscreteFactorGraph CollectDiscreteFactors(
const HybridGaussianFactorGraph &factors) {
DiscreteFactorGraph dfg;

for (auto &f : factors) {
if (auto df = dynamic_pointer_cast<DiscreteFactor>(f)) {
dfg.push_back(df);

} else if (auto gmf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
// Case where we have a HybridGaussianFactor with no continuous keys.
// In this case, compute a discrete factor from the remaining error.
Expand Down Expand Up @@ -296,16 +336,48 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
}
}

return dfg;
}

/* ************************************************************************ */
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
discreteElimination(const HybridGaussianFactorGraph &factors,
const Ordering &frontalKeys) {
DiscreteFactorGraph dfg = CollectDiscreteFactors(factors);

#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscrete);
#endif
// NOTE: This does sum-product. For max-product, use EliminateForMPE.
auto result = EliminateDiscrete(dfg, frontalKeys);
// Check if separator is empty.
// This is the same as checking if the number of frontal variables
// is the same as the number of variables in the DiscreteFactorGraph.
// If the separator is empty, we have a clique of all the discrete variables
// so we can use the TableFactor for efficiency.
if (frontalKeys.size() == dfg.keys().size()) {
// Get product factor
TableFactor product = TableProduct(dfg);

#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscreteFormDiscreteConditional);
#endif
auto conditional = std::make_shared<DiscreteConditional>(
frontalKeys.size(), product.toDecisionTreeFactor());
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscrete);
gttoc_(EliminateDiscreteFormDiscreteConditional);
#endif

return {std::make_shared<HybridConditional>(result.first), result.second};
TableFactor::shared_ptr sum = product.sum(frontalKeys);
dellaert marked this conversation as resolved.
Show resolved Hide resolved
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscrete);
#endif

return {std::make_shared<HybridConditional>(conditional), sum};

} else {
// Perform sum-product.
auto result = EliminateDiscrete(dfg, frontalKeys);
return {std::make_shared<HybridConditional>(result.first), result.second};
}
}

/* ************************************************************************ */
Expand Down
1 change: 1 addition & 0 deletions gtsam/hybrid/tests/testGaussianMixture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8);
}
}

/* ************************************************************************* */
int main() {
TestResult tr;
Expand Down
Loading