Skip to content

Commit

Permalink
Merge pull request #1925 from borglab/common-discrete-evaluate
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal authored Dec 10, 2024
2 parents 3add91d + 8145086 commit 49cac97
Show file tree
Hide file tree
Showing 18 changed files with 44 additions and 33 deletions.
2 changes: 1 addition & 1 deletion gtsam/discrete/DecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ namespace gtsam {
// Construct unordered_map with values
std::vector<std::pair<DiscreteValues, double>> result;
for (const auto& assignment : assignments) {
result.emplace_back(assignment, operator()(assignment));
result.emplace_back(assignment, evaluate(assignment));
}
return result;
}
Expand Down
10 changes: 4 additions & 6 deletions gtsam/discrete/DecisionTreeFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,16 +130,14 @@ namespace gtsam {
/// @name Standard Interface
/// @{

/// Calculate probability for given values `x`,
/// Calculate probability for given values,
/// is just look up in AlgebraicDecisionTree.
double evaluate(const Assignment<Key>& values) const {
virtual double evaluate(const Assignment<Key>& values) const override {
return ADT::operator()(values);
}

/// Evaluate probability distribution, sugar.
double operator()(const DiscreteValues& values) const override {
return ADT::operator()(values);
}
/// Disambiguate to use DiscreteFactor version. Mainly for wrapper
using DiscreteFactor::operator();

/// Calculate error for DiscreteValues `x`, is -log(probability).
double error(const DiscreteValues& values) const override;
Expand Down
4 changes: 2 additions & 2 deletions gtsam/discrete/DiscreteConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,12 +169,12 @@ class GTSAM_EXPORT DiscreteConditional
}

/// Evaluate, just look up in AlgebraicDecisionTree
double evaluate(const DiscreteValues& values) const {
virtual double evaluate(const Assignment<Key>& values) const override {
return ADT::operator()(values);
}

using DecisionTreeFactor::error; ///< DiscreteValues version
using DecisionTreeFactor::operator(); ///< DiscreteValues version
using DiscreteFactor::operator(); ///< DiscreteValues version

/**
* @brief restrict to given *parent* values.
Expand Down
15 changes: 14 additions & 1 deletion gtsam/discrete/DiscreteFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,21 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {

size_t cardinality(Key j) const { return cardinalities_.at(j); }

/**
* @brief Calculate probability for given values.
* Calls specialized evaluation under the hood.
*
* Note: Uses Assignment<Key> as it is the base class of DiscreteValues.
*
* @param values Discrete assignment.
* @return double
*/
virtual double evaluate(const Assignment<Key>& values) const = 0;

/// Find value for given assignment of values to variables
virtual double operator()(const DiscreteValues&) const = 0;
double operator()(const DiscreteValues& values) const {
return evaluate(values);
}

/// Error is just -log(value)
virtual double error(const DiscreteValues& values) const;
Expand Down
3 changes: 2 additions & 1 deletion gtsam/discrete/TableFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ bool TableFactor::equals(const DiscreteFactor& other, double tol) const {
}

/* ************************************************************************ */
double TableFactor::operator()(const DiscreteValues& values) const {
double TableFactor::evaluate(const Assignment<Key>& values) const {
// a b c d => D * (C * (B * (a) + b) + c) + d
uint64_t idx = 0, card = 1;
for (auto it = sorted_dkeys_.rbegin(); it != sorted_dkeys_.rend(); ++it) {
Expand Down Expand Up @@ -180,6 +180,7 @@ DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
for (auto i = 0; i < sparse_table_.size(); i++) {
table.push_back(sparse_table_.coeff(i));
}
// NOTE(Varun): This constructor is really expensive!!
DecisionTreeFactor f(dkeys, table);
return f;
}
Expand Down
10 changes: 2 additions & 8 deletions gtsam/discrete/TableFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,14 +155,8 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
// /// @name Standard Interface
// /// @{

/// Calculate probability for given values `x`,
/// is just look up in TableFactor.
double evaluate(const DiscreteValues& values) const {
return operator()(values);
}

/// Evaluate probability distribution, sugar.
double operator()(const DiscreteValues& values) const override;
/// Evaluate probability distribution, is just look up in TableFactor.
double evaluate(const Assignment<Key>& values) const override;

/// Calculate error for DiscreteValues `x`, is -log(probability).
double error(const DiscreteValues& values) const override;
Expand Down
4 changes: 2 additions & 2 deletions gtsam/discrete/discrete.i
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,14 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
DecisionTreeFactor(const std::vector<gtsam::DiscreteKey>& keys, string table);

DecisionTreeFactor(const gtsam::DiscreteConditional& c);

void print(string s = "DecisionTreeFactor\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const;

size_t cardinality(gtsam::Key j) const;

double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DecisionTreeFactor operator*(const gtsam::DecisionTreeFactor& f) const;
size_t cardinality(gtsam::Key j) const;
Expand Down
2 changes: 1 addition & 1 deletion gtsam_unstable/discrete/AllDiff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ void AllDiff::print(const std::string& s, const KeyFormatter& formatter) const {
}

/* ************************************************************************* */
double AllDiff::operator()(const DiscreteValues& values) const {
double AllDiff::evaluate(const Assignment<Key>& values) const {
std::set<size_t> taken; // record values taken by keys
for (Key dkey : keys_) {
size_t value = values.at(dkey); // get the value for that key
Expand Down
2 changes: 1 addition & 1 deletion gtsam_unstable/discrete/AllDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint {
}

/// Calculate value = expensive !
double operator()(const DiscreteValues& values) const override;
double evaluate(const Assignment<Key>& values) const override;

/// Convert into a decisiontree, can be *very* expensive !
DecisionTreeFactor toDecisionTreeFactor() const override;
Expand Down
2 changes: 1 addition & 1 deletion gtsam_unstable/discrete/BinaryAllDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class BinaryAllDiff : public Constraint {
}

/// Calculate value
double operator()(const DiscreteValues& values) const override {
double evaluate(const Assignment<Key>& values) const override {
return (double)(values.at(keys_[0]) != values.at(keys_[1]));
}

Expand Down
2 changes: 1 addition & 1 deletion gtsam_unstable/discrete/Domain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ string Domain::base1Str() const {
}

/* ************************************************************************* */
double Domain::operator()(const DiscreteValues& values) const {
double Domain::evaluate(const Assignment<Key>& values) const {
return contains(values.at(key()));
}

Expand Down
2 changes: 1 addition & 1 deletion gtsam_unstable/discrete/Domain.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
bool contains(size_t value) const { return values_.count(value) > 0; }

/// Calculate value
double operator()(const DiscreteValues& values) const override;
double evaluate(const Assignment<Key>& values) const override;

/// Convert into a decisiontree
DecisionTreeFactor toDecisionTreeFactor() const override;
Expand Down
2 changes: 1 addition & 1 deletion gtsam_unstable/discrete/SingleValue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ void SingleValue::print(const string& s, const KeyFormatter& formatter) const {
}

/* ************************************************************************* */
double SingleValue::operator()(const DiscreteValues& values) const {
double SingleValue::evaluate(const Assignment<Key>& values) const {
return (double)(values.at(keys_[0]) == value_);
}

Expand Down
2 changes: 1 addition & 1 deletion gtsam_unstable/discrete/SingleValue.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint {
}

/// Calculate value
double operator()(const DiscreteValues& values) const override;
double evaluate(const Assignment<Key>& values) const override;

/// Convert into a decisiontree
DecisionTreeFactor toDecisionTreeFactor() const override;
Expand Down
3 changes: 2 additions & 1 deletion python/gtsam/tests/test_DecisionTreeFactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@

import unittest

from gtsam.utils.test_case import GtsamTestCase

from gtsam import (DecisionTreeFactor, DiscreteDistribution, DiscreteValues,
Ordering)
from gtsam.utils.test_case import GtsamTestCase


class TestDecisionTreeFactor(GtsamTestCase):
Expand Down
4 changes: 2 additions & 2 deletions python/gtsam/tests/test_DiscreteBayesTree.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@

import gtsam
from gtsam import (DiscreteBayesNet, DiscreteBayesTreeClique,
DiscreteConditional, DiscreteFactorGraph,
DiscreteValues, Ordering)
DiscreteConditional, DiscreteFactorGraph, DiscreteValues,
Ordering)


class TestDiscreteBayesNet(GtsamTestCase):
Expand Down
3 changes: 2 additions & 1 deletion python/gtsam/tests/test_DiscreteConditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@

import unittest

from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteKeys
from gtsam.utils.test_case import GtsamTestCase

from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteKeys

# Some DiscreteKeys for binary variables:
A = 0, 2
B = 1, 2
Expand Down
5 changes: 4 additions & 1 deletion python/gtsam/tests/test_DiscreteFactorGraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@
import unittest

import numpy as np
from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering, Symbol
from gtsam.utils.test_case import GtsamTestCase

from gtsam import (DecisionTreeFactor, DiscreteConditional,
DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering,
Symbol)

OrderingType = Ordering.OrderingType


Expand Down

0 comments on commit 49cac97

Please sign in to comment.