From 5c63ac833c56d52b1558d0be876065939394d2c6 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 3 Nov 2024 15:32:21 -0500 Subject: [PATCH] use optional DiscreteValues --- gtsam/hybrid/HybridBayesNet.cpp | 33 ++++++++++++++------------------- gtsam/hybrid/HybridBayesNet.h | 15 ++++----------- 2 files changed, 18 insertions(+), 30 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index e5748366c3..623b82eea7 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -198,29 +198,24 @@ AlgebraicDecisionTree HybridBayesNet::errorTree( } /* ************************************************************************* */ -double HybridBayesNet::negLogConstant() const { +double HybridBayesNet::negLogConstant( + const std::optional &discrete) const { double negLogNormConst = 0.0; // Iterate over each conditional. for (auto &&conditional : *this) { - negLogNormConst += conditional->negLogConstant(); - } - return negLogNormConst; -} - -/* ************************************************************************* */ -double HybridBayesNet::negLogConstant(const DiscreteValues &discrete) const { - double negLogNormConst = 0.0; - // Iterate over each conditional. - for (auto &&conditional : *this) { - if (auto gm = conditional->asHybrid()) { - negLogNormConst += gm->choose(discrete)->negLogConstant(); - } else if (auto gc = conditional->asGaussian()) { - negLogNormConst += gc->negLogConstant(); - } else if (auto dc = conditional->asDiscrete()) { - negLogNormConst += dc->choose(discrete)->negLogConstant(); + if (discrete.has_value()) { + if (auto gm = conditional->asHybrid()) { + negLogNormConst += gm->choose(*discrete)->negLogConstant(); + } else if (auto gc = conditional->asGaussian()) { + negLogNormConst += gc->negLogConstant(); + } else if (auto dc = conditional->asDiscrete()) { + negLogNormConst += dc->choose(*discrete)->negLogConstant(); + } else { + throw std::runtime_error( + "Unknown conditional type when computing negLogConstant"); + } } else { - throw std::runtime_error( - "Unknown conditional type when computing negLogConstant"); + negLogNormConst += conditional->negLogConstant(); } } return negLogNormConst; diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 451f7f6757..96afb87d6d 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -237,22 +237,15 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { using BayesNet::logProbability; // expose HybridValues version - /** - * @brief Get the negative log of the normalization constant corresponding - * to the joint density represented by this Bayes net. - * - * @return double - */ - double negLogConstant() const; - /** * @brief Get the negative log of the normalization constant - * corresponding to the joint Gaussian density represented by - * this Bayes net indexed by `discrete`. + * corresponding to the joint density represented by this Bayes net. + * Optionally index by `discrete`. * + * @param discrete Optional DiscreteValues * @return double */ - double negLogConstant(const DiscreteValues &discrete) const; + double negLogConstant(const std::optional &discrete) const; /** * @brief Compute normalized posterior P(M|X=x) and return as a tree.