diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index eeb5dca3f2..0eea8b4bd6 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -24,13 +24,13 @@ #include #include +#include #include #include #include #include #include #include -#include using namespace std; using std::pair; diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index 11bea26137..b48e09b03f 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -120,13 +120,7 @@ namespace gtsam { static DecisionTreeFactor DiscreteProduct( const DiscreteFactorGraph& factors) { // PRODUCT: multiply all factors -#if GTSAM_HYBRID_TIMING - gttic_(DiscreteProduct); -#endif DecisionTreeFactor product = factors.product(); -#if GTSAM_HYBRID_TIMING - gttoc_(DiscreteProduct); -#endif #if GTSAM_HYBRID_TIMING gttic_(DiscreteNormalize); @@ -229,13 +223,7 @@ namespace gtsam { DecisionTreeFactor product = DiscreteProduct(factors); // sum out frontals, this is the factor on the separator -#if GTSAM_HYBRID_TIMING - gttic_(EliminateDiscreteSum); -#endif DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys); -#if GTSAM_HYBRID_TIMING - gttoc_(EliminateDiscreteSum); -#endif // Ordering keys for the conditional so that frontalKeys are really in front Ordering orderedKeys; @@ -245,14 +233,8 @@ namespace gtsam { sum->keys().end()); // now divide product/sum to get conditional -#if GTSAM_HYBRID_TIMING - gttic_(EliminateDiscreteToDiscreteConditional); -#endif auto conditional = std::make_shared(product, *sum, orderedKeys); -#if GTSAM_HYBRID_TIMING - gttoc_(EliminateDiscreteToDiscreteConditional); -#endif return {conditional, sum}; } diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 4fcd420b19..8be5a8af43 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -255,15 +255,55 @@ static TableFactor::shared_ptr DiscreteFactorFromErrors( return std::make_shared(discreteKeys, potentials); } +/** + * @brief Multiply all the `factors` using the machinery of the TableFactor. + * + * @param factors The factors to multiply as a DiscreteFactorGraph. + * @return TableFactor + */ +static TableFactor TableProduct(const DiscreteFactorGraph &factors) { + // PRODUCT: multiply all factors +#if GTSAM_HYBRID_TIMING + gttic_(DiscreteProduct); +#endif + TableFactor product; + for (auto &&factor : factors) { + if (factor) { + if (auto f = std::dynamic_pointer_cast(factor)) { + product = product * (*f); + } else if (auto dtf = + std::dynamic_pointer_cast(factor)) { + product = product * TableFactor(*dtf); + } + } + } +#if GTSAM_HYBRID_TIMING + gttoc_(DiscreteProduct); +#endif + +#if GTSAM_HYBRID_TIMING + gttic_(DiscreteNormalize); +#endif + // Max over all the potentials by pretending all keys are frontal: + auto denominator = product.max(product.size()); + // Normalize the product factor to prevent underflow. + product = product / (*denominator); +#if GTSAM_HYBRID_TIMING + gttoc_(DiscreteNormalize); +#endif + + return product; +} + /* ************************************************************************ */ -static std::pair> -discreteElimination(const HybridGaussianFactorGraph &factors, - const Ordering &frontalKeys) { +static DiscreteFactorGraph CollectDiscreteFactors( + const HybridGaussianFactorGraph &factors) { DiscreteFactorGraph dfg; for (auto &f : factors) { if (auto df = dynamic_pointer_cast(f)) { dfg.push_back(df); + } else if (auto gmf = dynamic_pointer_cast(f)) { // Case where we have a HybridGaussianFactor with no continuous keys. // In this case, compute a discrete factor from the remaining error. @@ -296,16 +336,48 @@ discreteElimination(const HybridGaussianFactorGraph &factors, } } + return dfg; +} + +/* ************************************************************************ */ +static std::pair> +discreteElimination(const HybridGaussianFactorGraph &factors, + const Ordering &frontalKeys) { + DiscreteFactorGraph dfg = CollectDiscreteFactors(factors); + #if GTSAM_HYBRID_TIMING gttic_(EliminateDiscrete); #endif - // NOTE: This does sum-product. For max-product, use EliminateForMPE. - auto result = EliminateDiscrete(dfg, frontalKeys); + // Check if separator is empty. + // This is the same as checking if the number of frontal variables + // is the same as the number of variables in the DiscreteFactorGraph. + // If the separator is empty, we have a clique of all the discrete variables + // so we can use the TableFactor for efficiency. + if (frontalKeys.size() == dfg.keys().size()) { + // Get product factor + TableFactor product = TableProduct(dfg); + +#if GTSAM_HYBRID_TIMING + gttic_(EliminateDiscreteFormDiscreteConditional); +#endif + auto conditional = std::make_shared( + frontalKeys.size(), product.toDecisionTreeFactor()); #if GTSAM_HYBRID_TIMING - gttoc_(EliminateDiscrete); + gttoc_(EliminateDiscreteFormDiscreteConditional); #endif - return {std::make_shared(result.first), result.second}; + TableFactor::shared_ptr sum = product.sum(frontalKeys); +#if GTSAM_HYBRID_TIMING + gttoc_(EliminateDiscrete); +#endif + + return {std::make_shared(conditional), sum}; + + } else { + // Perform sum-product. + auto result = EliminateDiscrete(dfg, frontalKeys); + return {std::make_shared(result.first), result.second}; + } } /* ************************************************************************ */ diff --git a/gtsam/hybrid/tests/testGaussianMixture.cpp b/gtsam/hybrid/tests/testGaussianMixture.cpp index 14bef5fbb4..698c1bbf6c 100644 --- a/gtsam/hybrid/tests/testGaussianMixture.cpp +++ b/gtsam/hybrid/tests/testGaussianMixture.cpp @@ -162,6 +162,7 @@ TEST(GaussianMixture, GaussianMixtureModel2) { EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8); } } + /* ************************************************************************* */ int main() { TestResult tr;