Skip to content

Commit

Permalink
Merge pull request #1953 from borglab/discrete-table-conditional
Browse files Browse the repository at this point in the history
TableFactor and TableDistribution
  • Loading branch information
varunagrawal authored Jan 8, 2025
2 parents edef8c8 + 3ecc232 commit 8cf2123
Show file tree
Hide file tree
Showing 28 changed files with 615 additions and 156 deletions.
3 changes: 3 additions & 0 deletions gtsam/discrete/AlgebraicDecisionTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ namespace gtsam {

AlgebraicDecisionTree(double leaf = 1.0) : Base(leaf) {}

/// Constructor which accepts root pointer
AlgebraicDecisionTree(const typename Base::NodePtr root) : Base(root) {}

// Explicitly non-explicit constructor
AlgebraicDecisionTree(const Base& add) : Base(add) {}

Expand Down
20 changes: 20 additions & 0 deletions gtsam/discrete/DiscreteConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,13 @@ DiscreteConditional::DiscreteConditional(const Signature& signature)
/* ************************************************************************** */
DiscreteConditional DiscreteConditional::operator*(
const DiscreteConditional& other) const {
// If the root is a nullptr, we have a TableDistribution
// TODO(Varun) Revisit this hack after RSS2025 submission
if (!other.root_) {
DiscreteConditional dc(other.nrFrontals(), other.toDecisionTreeFactor());
return dc * (*this);
}

// Take union of frontal keys
std::set<Key> newFrontals;
for (auto&& key : this->frontals()) newFrontals.insert(key);
Expand Down Expand Up @@ -479,6 +486,19 @@ double DiscreteConditional::evaluate(const HybridValues& x) const {
return this->operator()(x.discrete());
}

/* ************************************************************************* */
DiscreteFactor::shared_ptr DiscreteConditional::max(
const Ordering& keys) const {
return BaseFactor::max(keys);
}

/* ************************************************************************* */
void DiscreteConditional::prune(size_t maxNrAssignments) {
// Get as DiscreteConditional so the probabilities are normalized
DiscreteConditional pruned(nrFrontals(), BaseFactor::prune(maxNrAssignments));
this->root_ = pruned.root_;
}

/* ************************************************************************* */
double DiscreteConditional::negLogConstant() const { return 0.0; }

Expand Down
14 changes: 13 additions & 1 deletion gtsam/discrete/DiscreteConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ class GTSAM_EXPORT DiscreteConditional
* @param parentsValues Known values of the parents
* @return sample from conditional
*/
size_t sample(const DiscreteValues& parentsValues) const;
virtual size_t sample(const DiscreteValues& parentsValues) const;

/// Single parent version.
size_t sample(size_t parent_value) const;
Expand All @@ -214,6 +214,15 @@ class GTSAM_EXPORT DiscreteConditional
*/
size_t argmax(const DiscreteValues& parentsValues = DiscreteValues()) const;

/**
* @brief Create new factor by maximizing over all
* values with the same separator.
*
* @param keys The keys to sum over.
* @return DiscreteFactor::shared_ptr
*/
virtual DiscreteFactor::shared_ptr max(const Ordering& keys) const override;

/// @}
/// @name Advanced Interface
/// @{
Expand Down Expand Up @@ -267,6 +276,9 @@ class GTSAM_EXPORT DiscreteConditional
*/
double negLogConstant() const override;

/// Prune the conditional
virtual void prune(size_t maxNrAssignments);

/// @}

protected:
Expand Down
22 changes: 5 additions & 17 deletions gtsam/discrete/DiscreteFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,30 +118,18 @@ namespace gtsam {
// }
// }

/**
* @brief Multiply all the `factors`.
*
* @param factors The factors to multiply as a DiscreteFactorGraph.
* @return DiscreteFactor::shared_ptr
*/
static DiscreteFactor::shared_ptr DiscreteProduct(
const DiscreteFactorGraph& factors) {
/* ************************************************************************ */
DiscreteFactor::shared_ptr DiscreteFactorGraph::scaledProduct() const {
// PRODUCT: multiply all factors
gttic(product);
DiscreteFactor::shared_ptr product = factors.product();
DiscreteFactor::shared_ptr product = this->product();
gttoc(product);

#if GTSAM_HYBRID_TIMING
gttic_(DiscreteNormalize);
#endif
// Max over all the potentials by pretending all keys are frontal:
auto denominator = product->max(product->size());

// Normalize the product factor to prevent underflow.
product = product->operator/(denominator);
#if GTSAM_HYBRID_TIMING
gttoc_(DiscreteNormalize);
#endif

return product;
}
Expand All @@ -151,7 +139,7 @@ namespace gtsam {
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
DiscreteFactor::shared_ptr product = DiscreteProduct(factors);
DiscreteFactor::shared_ptr product = factors.scaledProduct();

// max out frontals, this is the factor on the separator
gttic(max);
Expand Down Expand Up @@ -229,7 +217,7 @@ namespace gtsam {
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
EliminateDiscrete(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
DiscreteFactor::shared_ptr product = DiscreteProduct(factors);
DiscreteFactor::shared_ptr product = factors.scaledProduct();

// sum out frontals, this is the factor on the separator
gttic(sum);
Expand Down
9 changes: 9 additions & 0 deletions gtsam/discrete/DiscreteFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,15 @@ class GTSAM_EXPORT DiscreteFactorGraph
/** return product of all factors as a single factor */
DiscreteFactor::shared_ptr product() const;

/**
* @brief Return product of all `factors` as a single factor,
* which is scaled by the max value to prevent underflow
*
* @param factors The factors to multiply as a DiscreteFactorGraph.
* @return DiscreteFactor::shared_ptr
*/
DiscreteFactor::shared_ptr scaledProduct() const;

/**
* Evaluates the factor graph given values, returns the joint probability of
* the factor graph given specific instantiation of values
Expand Down
174 changes: 174 additions & 0 deletions gtsam/discrete/TableDistribution.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */

/**
* @file TableDistribution.cpp
* @date Dec 22, 2024
* @author Varun Agrawal
*/

#include <gtsam/base/Testable.h>
#include <gtsam/base/debug.h>
#include <gtsam/discrete/Ring.h>
#include <gtsam/discrete/Signature.h>
#include <gtsam/discrete/TableDistribution.h>
#include <gtsam/hybrid/HybridValues.h>

#include <algorithm>
#include <cassert>
#include <random>
#include <set>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

using namespace std;
using std::pair;
using std::stringstream;
using std::vector;
namespace gtsam {

/// Normalize sparse_table
static Eigen::SparseVector<double> normalizeSparseTable(
const Eigen::SparseVector<double>& sparse_table) {
return sparse_table / sparse_table.sum();
}

/* ************************************************************************** */
TableDistribution::TableDistribution(const TableFactor& f)
: BaseConditional(f.keys().size(), f.discreteKeys(), ADT(nullptr)),
table_(f / (*std::dynamic_pointer_cast<TableFactor>(
f.sum(f.keys().size())))) {}

/* ************************************************************************** */
TableDistribution::TableDistribution(const DiscreteKeys& keys,
const std::vector<double>& potentials)
: BaseConditional(keys.size(), keys, ADT(nullptr)),
table_(TableFactor(
keys, normalizeSparseTable(TableFactor::Convert(keys, potentials)))) {
}

/* ************************************************************************** */
TableDistribution::TableDistribution(const DiscreteKeys& keys,
const std::string& potentials)
: BaseConditional(keys.size(), keys, ADT(nullptr)),
table_(TableFactor(
keys, normalizeSparseTable(TableFactor::Convert(keys, potentials)))) {
}

/* ************************************************************************** */
void TableDistribution::print(const string& s,
const KeyFormatter& formatter) const {
cout << s << " P( ";
for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) {
cout << formatter(*it) << " ";
}
cout << "):\n";
table_.print("", formatter);
cout << endl;
}

/* ************************************************************************** */
bool TableDistribution::equals(const DiscreteFactor& other, double tol) const {
auto dtc = dynamic_cast<const TableDistribution*>(&other);
if (!dtc) {
return false;
} else {
const DiscreteConditional& f(
static_cast<const DiscreteConditional&>(other));
return table_.equals(dtc->table_, tol) &&
DiscreteConditional::BaseConditional::equals(f, tol);
}
}

/* ****************************************************************************/
DiscreteFactor::shared_ptr TableDistribution::sum(size_t nrFrontals) const {
return table_.sum(nrFrontals);
}

/* ****************************************************************************/
DiscreteFactor::shared_ptr TableDistribution::sum(const Ordering& keys) const {
return table_.sum(keys);
}

/* ****************************************************************************/
DiscreteFactor::shared_ptr TableDistribution::max(size_t nrFrontals) const {
return table_.max(nrFrontals);
}

/* ****************************************************************************/
DiscreteFactor::shared_ptr TableDistribution::max(const Ordering& keys) const {
return table_.max(keys);
}

/* ****************************************************************************/
DiscreteFactor::shared_ptr TableDistribution::operator/(
const DiscreteFactor::shared_ptr& f) const {
return table_ / f;
}

/* ************************************************************************ */
DiscreteValues TableDistribution::argmax() const {
uint64_t maxIdx = 0;
double maxValue = 0.0;

Eigen::SparseVector<double> sparseTable = table_.sparseTable();

for (SparseIt it(sparseTable); it; ++it) {
if (it.value() > maxValue) {
maxIdx = it.index();
maxValue = it.value();
}
}

return table_.findAssignments(maxIdx);
}

/* ****************************************************************************/
void TableDistribution::prune(size_t maxNrAssignments) {
table_ = table_.prune(maxNrAssignments);
}

/* ****************************************************************************/
size_t TableDistribution::sample(const DiscreteValues& parentsValues) const {
static mt19937 rng(2); // random number generator

DiscreteKeys parentsKeys;
for (auto&& [key, _] : parentsValues) {
parentsKeys.push_back({key, table_.cardinality(key)});
}

// Get the correct conditional distribution: P(F|S=parentsValues)
TableFactor pFS = table_.choose(parentsValues, parentsKeys);

// TODO(Duy): only works for one key now, seems horribly slow this way
if (nrFrontals() != 1) {
throw std::invalid_argument(
"TableDistribution::sample can only be called on single variable "
"conditionals");
}
Key key = firstFrontalKey();
size_t nj = cardinality(key);
vector<double> p(nj);
DiscreteValues frontals;
for (size_t value = 0; value < nj; value++) {
frontals[key] = value;
p[value] = pFS(frontals); // P(F=value|S=parentsValues)
if (p[value] == 1.0) {
return value; // shortcut exit
}
}
std::discrete_distribution<size_t> distribution(p.begin(), p.end());
return distribution(rng);
}

} // namespace gtsam
Loading

0 comments on commit 8cf2123

Please sign in to comment.