Skip to content

Commit

Permalink
Merge pull request #1802 from borglab/working-hybrid
Browse files Browse the repository at this point in the history
Working Hybrid
  • Loading branch information
varunagrawal authored Sep 5, 2024
2 parents c3842bb + 8b04d9b commit 232fa02
Show file tree
Hide file tree
Showing 16 changed files with 597 additions and 77 deletions.
43 changes: 38 additions & 5 deletions gtsam/hybrid/GaussianMixture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/Conditional-inst.h>
#include <gtsam/linear/GaussianBayesNet.h>
#include <gtsam/linear/GaussianFactorGraph.h>

namespace gtsam {
Expand Down Expand Up @@ -86,7 +87,22 @@ GaussianFactorGraphTree GaussianMixture::add(

/* *******************************************************************************/
GaussianFactorGraphTree GaussianMixture::asGaussianFactorGraphTree() const {
auto wrap = [](const GaussianConditional::shared_ptr &gc) {
auto wrap = [this](const GaussianConditional::shared_ptr &gc) {
// First check if conditional has not been pruned
if (gc) {
const double Cgm_Kgcm =
this->logConstant_ - gc->logNormalizationConstant();
// If there is a difference in the covariances, we need to account for
// that since the error is dependent on the mode.
if (Cgm_Kgcm > 0.0) {
// We add a constant factor which will be used when computing
// the probability of the discrete variables.
Vector c(1);
c << std::sqrt(2.0 * Cgm_Kgcm);
auto constantFactor = std::make_shared<JacobianFactor>(c);
return GaussianFactorGraph{gc, constantFactor};
}
}
return GaussianFactorGraph{gc};
};
return {conditionals_, wrap};
Expand Down Expand Up @@ -145,6 +161,8 @@ void GaussianMixture::print(const std::string &s,
std::cout << "(" << formatter(dk.first) << ", " << dk.second << "), ";
}
std::cout << "\n";
std::cout << " logNormalizationConstant: " << logConstant_ << "\n"
<< std::endl;
conditionals_.print(
"", [&](Key k) { return formatter(k); },
[&](const GaussianConditional::shared_ptr &gf) -> std::string {
Expand Down Expand Up @@ -312,12 +330,28 @@ AlgebraicDecisionTree<Key> GaussianMixture::logProbability(
return DecisionTree<Key, double>(conditionals_, probFunc);
}

/* ************************************************************************* */
double GaussianMixture::conditionalError(
const GaussianConditional::shared_ptr &conditional,
const VectorValues &continuousValues) const {
// Check if valid pointer
if (conditional) {
return conditional->error(continuousValues) + //
logConstant_ - conditional->logNormalizationConstant();
} else {
// If not valid, pointer, it means this conditional was pruned,
// so we return maximum error.
// This way the negative exponential will give
// a probability value close to 0.0.
return std::numeric_limits<double>::max();
}
}

/* *******************************************************************************/
AlgebraicDecisionTree<Key> GaussianMixture::errorTree(
const VectorValues &continuousValues) const {
auto errorFunc = [&](const GaussianConditional::shared_ptr &conditional) {
return conditional->error(continuousValues) + //
logConstant_ - conditional->logNormalizationConstant();
return conditionalError(conditional, continuousValues);
};
DecisionTree<Key, double> error_tree(conditionals_, errorFunc);
return error_tree;
Expand All @@ -327,8 +361,7 @@ AlgebraicDecisionTree<Key> GaussianMixture::errorTree(
double GaussianMixture::error(const HybridValues &values) const {
// Directly index to get the conditional, no need to build the whole tree.
auto conditional = conditionals_(values.discrete());
return conditional->error(values.continuous()) + //
logConstant_ - conditional->logNormalizationConstant();
return conditionalError(conditional, values.continuous());
}

/* *******************************************************************************/
Expand Down
6 changes: 5 additions & 1 deletion gtsam/hybrid/GaussianMixture.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class GTSAM_EXPORT GaussianMixture
double logConstant_; ///< log of the normalization constant.

/**
* @brief Convert a DecisionTree of factors into
* @brief Convert a GaussianMixture of conditionals into
* a DecisionTree of Gaussian factor graphs.
*/
GaussianFactorGraphTree asGaussianFactorGraphTree() const;
Expand Down Expand Up @@ -256,6 +256,10 @@ class GTSAM_EXPORT GaussianMixture
/// Check whether `given` has values for all frontal keys.
bool allFrontalsGiven(const VectorValues &given) const;

/// Helper method to compute the error of a conditional.
double conditionalError(const GaussianConditional::shared_ptr &conditional,
const VectorValues &continuousValues) const;

#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
/** Serialization function */
friend class boost::serialization::access;
Expand Down
7 changes: 4 additions & 3 deletions gtsam/hybrid/GaussianMixtureFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {
/* *******************************************************************************/
void GaussianMixtureFactor::print(const std::string &s,
const KeyFormatter &formatter) const {
HybridFactor::print(s, formatter);
std::cout << (s.empty() ? "" : s + "\n");
std::cout << "GaussianMixtureFactor" << std::endl;
HybridFactor::print("", formatter);
std::cout << "{\n";
if (factors_.empty()) {
std::cout << " empty" << std::endl;
Expand All @@ -64,7 +66,7 @@ void GaussianMixtureFactor::print(const std::string &s,
[&](const sharedFactor &gf) -> std::string {
RedirectCout rd;
std::cout << ":\n";
if (gf && !gf->empty()) {
if (gf) {
gf->print("", formatter);
return rd.str();
} else {
Expand Down Expand Up @@ -117,6 +119,5 @@ double GaussianMixtureFactor::error(const HybridValues &values) const {
const sharedFactor gf = factors_(values.discrete());
return gf->error(values.continuous());
}
/* *******************************************************************************/

} // namespace gtsam
9 changes: 4 additions & 5 deletions gtsam/hybrid/GaussianMixtureFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
* @param continuousKeys A vector of keys representing continuous variables.
* @param discreteKeys A vector of keys representing discrete variables and
* their cardinalities.
* @param factors The decision tree of Gaussian factors stored as the mixture
* density.
* @param factors The decision tree of Gaussian factors stored
* as the mixture density.
*/
GaussianMixtureFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys,
Expand All @@ -107,9 +107,8 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {

bool equals(const HybridFactor &lf, double tol = 1e-9) const override;

void print(
const std::string &s = "GaussianMixtureFactor\n",
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
void print(const std::string &s = "", const KeyFormatter &formatter =
DefaultKeyFormatter) const override;

/// @}
/// @name Standard API
Expand Down
7 changes: 4 additions & 3 deletions gtsam/hybrid/HybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,15 +220,16 @@ GaussianBayesNet HybridBayesNet::choose(
/* ************************************************************************* */
HybridValues HybridBayesNet::optimize() const {
// Collect all the discrete factors to compute MPE
DiscreteBayesNet discrete_bn;
DiscreteFactorGraph discrete_fg;

for (auto &&conditional : *this) {
if (conditional->isDiscrete()) {
discrete_bn.push_back(conditional->asDiscrete());
discrete_fg.push_back(conditional->asDiscrete());
}
}

// Solve for the MPE
DiscreteValues mpe = DiscreteFactorGraph(discrete_bn).optimize();
DiscreteValues mpe = discrete_fg.optimize();

// Given the MPE, compute the optimal continuous values.
return HybridValues(optimize(mpe), mpe);
Expand Down
4 changes: 2 additions & 2 deletions gtsam/hybrid/HybridConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class GTSAM_EXPORT HybridConditional
public Conditional<HybridFactor, HybridConditional> {
public:
// typedefs needed to play nice with gtsam
typedef HybridConditional This; ///< Typedef to this class
typedef HybridConditional This; ///< Typedef to this class
typedef std::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
typedef HybridFactor BaseFactor; ///< Typedef to our factor base class
typedef Conditional<BaseFactor, This>
Expand Down Expand Up @@ -185,7 +185,7 @@ class GTSAM_EXPORT HybridConditional
* Return the log normalization constant.
* Note this is 0.0 for discrete and hybrid conditionals, but depends
* on the continuous parameters for Gaussian conditionals.
*/
*/
double logNormalizationConstant() const override;

/// Return the probability (or density) of the underlying conditional.
Expand Down
1 change: 1 addition & 0 deletions gtsam/hybrid/HybridFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
* @file HybridFactor.h
* @date Mar 11, 2022
* @author Fan Jiang
* @author Varun Agrawal
*/

#pragma once
Expand Down
9 changes: 0 additions & 9 deletions gtsam/hybrid/HybridFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,6 @@ KeySet HybridFactorGraph::discreteKeySet() const {
return keys;
}

/* ************************************************************************* */
std::unordered_map<Key, DiscreteKey> HybridFactorGraph::discreteKeyMap() const {
std::unordered_map<Key, DiscreteKey> result;
for (const DiscreteKey& k : discreteKeys()) {
result[k.first] = k;
}
return result;
}

/* ************************************************************************* */
const KeySet HybridFactorGraph::continuousKeySet() const {
KeySet keys;
Expand Down
7 changes: 2 additions & 5 deletions gtsam/hybrid/HybridFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ using SharedFactor = std::shared_ptr<Factor>;
class GTSAM_EXPORT HybridFactorGraph : public FactorGraph<Factor> {
public:
using Base = FactorGraph<Factor>;
using This = HybridFactorGraph; ///< this class
using This = HybridFactorGraph; ///< this class
using shared_ptr = std::shared_ptr<This>; ///< shared_ptr to This

using Values = gtsam::Values; ///< backwards compatibility
Expand Down Expand Up @@ -66,12 +66,9 @@ class GTSAM_EXPORT HybridFactorGraph : public FactorGraph<Factor> {
/// Get all the discrete keys in the factor graph.
std::set<DiscreteKey> discreteKeys() const;

/// Get all the discrete keys in the factor graph, as a set.
/// Get all the discrete keys in the factor graph, as a set of Keys.
KeySet discreteKeySet() const;

/// Get a map from Key to corresponding DiscreteKey.
std::unordered_map<Key, DiscreteKey> discreteKeyMap() const;

/// Get all the continuous keys in the factor graph.
const KeySet continuousKeySet() const;

Expand Down
Loading

0 comments on commit 232fa02

Please sign in to comment.