Skip to content

Commit

Permalink
Merge pull request #1853 from borglab/feature/bougie_constructors
Browse files Browse the repository at this point in the history
More refactoring
  • Loading branch information
dellaert authored Sep 29, 2024
2 parents 7a9300c + 2cf2100 commit ffb2829
Show file tree
Hide file tree
Showing 15 changed files with 314 additions and 273 deletions.
7 changes: 6 additions & 1 deletion gtsam/hybrid/HybridBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,14 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
/// @name Standard Constructors
/// @{

/** Construct empty Bayes net */
/// Construct empty Bayes net.
HybridBayesNet() = default;

/// Constructor that takes an initializer list of shared pointers.
HybridBayesNet(
std::initializer_list<HybridConditional::shared_ptr> conditionals)
: Base(conditionals) {}

/// @}
/// @name Testable
/// @{
Expand Down
88 changes: 67 additions & 21 deletions gtsam/hybrid/HybridGaussianConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,49 @@
#include <gtsam/inference/Conditional-inst.h>
#include <gtsam/linear/GaussianBayesNet.h>
#include <gtsam/linear/GaussianFactorGraph.h>
#include <gtsam/linear/JacobianFactor.h>

#include <cstddef>

namespace gtsam {
/* *******************************************************************************/
struct HybridGaussianConditional::ConstructorHelper {
struct HybridGaussianConditional::Helper {
std::optional<size_t> nrFrontals;
HybridGaussianFactor::FactorValuePairs pairs;
FactorValuePairs pairs;
Conditionals conditionals;
double minNegLogConstant;

/// Compute all variables needed for the private constructor below.
ConstructorHelper(const Conditionals &conditionals)
: minNegLogConstant(std::numeric_limits<double>::infinity()) {
auto func = [this](const GaussianConditional::shared_ptr &c)
-> GaussianFactorValuePair {
using GC = GaussianConditional;
using P = std::vector<std::pair<Vector, double>>;

/// Construct from a vector of mean and sigma pairs, plus extra args.
template <typename... Args>
explicit Helper(const DiscreteKey &mode, const P &p, Args &&...args) {
nrFrontals = 1;
minNegLogConstant = std::numeric_limits<double>::infinity();

std::vector<GaussianFactorValuePair> fvs;
std::vector<GC::shared_ptr> gcs;
fvs.reserve(p.size());
gcs.reserve(p.size());
for (auto &&[mean, sigma] : p) {
auto gaussianConditional =
GC::sharedMeanAndStddev(std::forward<Args>(args)..., mean, sigma);
double value = gaussianConditional->negLogConstant();
minNegLogConstant = std::min(minNegLogConstant, value);
fvs.emplace_back(gaussianConditional, value);
gcs.push_back(gaussianConditional);
}

conditionals = Conditionals({mode}, gcs);
pairs = FactorValuePairs({mode}, fvs);
}

/// Construct from tree of GaussianConditionals.
explicit Helper(const Conditionals &conditionals)
: conditionals(conditionals),
minNegLogConstant(std::numeric_limits<double>::infinity()) {
auto func = [this](const GC::shared_ptr &c) -> GaussianFactorValuePair {
double value = 0.0;
if (c) {
if (!nrFrontals.has_value()) {
Expand All @@ -51,38 +79,56 @@ struct HybridGaussianConditional::ConstructorHelper {
}
return {std::dynamic_pointer_cast<GaussianFactor>(c), value};
};
pairs = HybridGaussianFactor::FactorValuePairs(conditionals, func);
pairs = FactorValuePairs(conditionals, func);
if (!nrFrontals.has_value()) {
throw std::runtime_error(
"HybridGaussianConditional: need at least one frontal variable.");
"HybridGaussianConditional: need at least one frontal variable. "
"Provided conditionals do not contain any frontal variables.");
}
}
};

/* *******************************************************************************/
HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKeys &discreteParents,
const HybridGaussianConditional::Conditionals &conditionals,
const ConstructorHelper &helper)
const DiscreteKeys &discreteParents, const Helper &helper)
: BaseFactor(discreteParents, helper.pairs),
BaseConditional(*helper.nrFrontals),
conditionals_(conditionals),
conditionals_(helper.conditionals),
negLogConstant_(helper.minNegLogConstant) {}

/* *******************************************************************************/
HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKeys &discreteParents,
const HybridGaussianConditional::Conditionals &conditionals)
: HybridGaussianConditional(discreteParents, conditionals,
ConstructorHelper(conditionals)) {}

/* *******************************************************************************/
HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKey &discreteParent,
const std::vector<GaussianConditional::shared_ptr> &conditionals)
: HybridGaussianConditional(DiscreteKeys{discreteParent},
Conditionals({discreteParent}, conditionals)) {}

HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKey &discreteParent, Key key, //
const std::vector<std::pair<Vector, double>> &parameters)
: HybridGaussianConditional(DiscreteKeys{discreteParent},
Helper(discreteParent, parameters, key)) {}

HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKey &discreteParent, Key key, //
const Matrix &A, Key parent,
const std::vector<std::pair<Vector, double>> &parameters)
: HybridGaussianConditional(
DiscreteKeys{discreteParent},
Helper(discreteParent, parameters, key, A, parent)) {}

HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKey &discreteParent, Key key, //
const Matrix &A1, Key parent1, const Matrix &A2, Key parent2,
const std::vector<std::pair<Vector, double>> &parameters)
: HybridGaussianConditional(
DiscreteKeys{discreteParent},
Helper(discreteParent, parameters, key, A1, parent1, A2, parent2)) {}

HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKeys &discreteParents,
const HybridGaussianConditional::Conditionals &conditionals)
: HybridGaussianConditional(discreteParents, Helper(conditionals)) {}

/* *******************************************************************************/
const HybridGaussianConditional::Conditionals &
HybridGaussianConditional::conditionals() const {
Expand Down
51 changes: 46 additions & 5 deletions gtsam/hybrid/HybridGaussianConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,49 @@ class GTSAM_EXPORT HybridGaussianConditional
const DiscreteKey &discreteParent,
const std::vector<GaussianConditional::shared_ptr> &conditionals);

/**
* @brief Constructs a HybridGaussianConditional with means mu_i and
* standard deviations sigma_i.
*
* @param discreteParent The discrete parent or "mode" key.
* @param key The key for this conditional variable.
* @param parameters A vector of pairs (mu_i, sigma_i).
*/
HybridGaussianConditional(
const DiscreteKey &discreteParent, Key key,
const std::vector<std::pair<Vector, double>> &parameters);

/**
* @brief Constructs a HybridGaussianConditional with conditional means
* A × parent + b_i and standard deviations sigma_i.
*
* @param discreteParent The discrete parent or "mode" key.
* @param key The key for this conditional variable.
* @param A The matrix A.
* @param parent The key of the parent variable.
* @param parameters A vector of pairs (b_i, sigma_i).
*/
HybridGaussianConditional(
const DiscreteKey &discreteParent, Key key, const Matrix &A, Key parent,
const std::vector<std::pair<Vector, double>> &parameters);

/**
* @brief Constructs a HybridGaussianConditional with conditional means
* A1 × parent1 + A2 × parent2 + b_i and standard deviations sigma_i.
*
* @param discreteParent The discrete parent or "mode" key.
* @param key The key for this conditional variable.
* @param A1 The first matrix.
* @param parent1 The key of the first parent variable.
* @param A2 The second matrix.
* @param parent2 The key of the second parent variable.
* @param parameters A vector of pairs (b_i, sigma_i).
*/
HybridGaussianConditional(
const DiscreteKey &discreteParent, Key key, //
const Matrix &A1, Key parent1, const Matrix &A2, Key parent2,
const std::vector<std::pair<Vector, double>> &parameters);

/**
* @brief Construct from multiple discrete keys and conditional tree.
*
Expand Down Expand Up @@ -183,13 +226,11 @@ class GTSAM_EXPORT HybridGaussianConditional

private:
/// Helper struct for private constructor.
struct ConstructorHelper;
struct Helper;

/// Private constructor that uses helper struct above.
HybridGaussianConditional(
const DiscreteKeys &discreteParents,
const HybridGaussianConditional::Conditionals &conditionals,
const ConstructorHelper &helper);
HybridGaussianConditional(const DiscreteKeys &discreteParents,
const Helper &helper);

/// Convert to a DecisionTree of Gaussian factor graphs.
GaussianFactorGraphTree asGaussianFactorGraphTree() const;
Expand Down
35 changes: 14 additions & 21 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -366,18 +366,18 @@ static std::shared_ptr<Factor> createHybridGaussianFactor(
return std::make_shared<HybridGaussianFactor>(discreteSeparator, newFactors);
}

static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
hybridElimination(const HybridGaussianFactorGraph &factors,
const Ordering &frontalKeys,
const std::set<DiscreteKey> &discreteSeparatorSet) {
// NOTE: since we use the special JunctionTree,
// only possibility is continuous conditioned on discrete.
DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(),
discreteSeparatorSet.end());
/* *******************************************************************************/
std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
// Since we eliminate all continuous variables first,
// the discrete separator will be *all* the discrete keys.
const std::set<DiscreteKey> keysForDiscreteVariables = discreteKeys();
DiscreteKeys discreteSeparator(keysForDiscreteVariables.begin(),
keysForDiscreteVariables.end());

// Collect all the factors to create a set of Gaussian factor graphs in a
// decision tree indexed by all discrete keys involved.
GaussianFactorGraphTree factorGraphTree = factors.assembleGraphTree();
GaussianFactorGraphTree factorGraphTree = assembleGraphTree();

// Convert factor graphs with a nullptr to an empty factor graph.
// This is done after assembly since it is non-trivial to keep track of which
Expand All @@ -392,7 +392,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
}

// Expensive elimination of product factor.
auto result = EliminatePreferCholesky(graph, frontalKeys);
auto result = EliminatePreferCholesky(graph, keys);

// Record whether there any continuous variables left
someContinuousLeft |= !result.second->empty();
Expand Down Expand Up @@ -436,7 +436,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
*/
std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>> //
EliminateHybrid(const HybridGaussianFactorGraph &factors,
const Ordering &frontalKeys) {
const Ordering &keys) {
// NOTE: Because we are in the Conditional Gaussian regime there are only
// a few cases:
// 1. continuous variable, make a hybrid Gaussian conditional if there are
Expand Down Expand Up @@ -510,20 +510,13 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,

if (only_discrete) {
// Case 1: we are only dealing with discrete
return discreteElimination(factors, frontalKeys);
return discreteElimination(factors, keys);
} else if (only_continuous) {
// Case 2: we are only dealing with continuous
return continuousElimination(factors, frontalKeys);
return continuousElimination(factors, keys);
} else {
// Case 3: We are now in the hybrid land!
KeySet frontalKeysSet(frontalKeys.begin(), frontalKeys.end());

// Find all discrete keys.
// Since we eliminate all continuous variables first,
// the discrete separator will be *all* the discrete keys.
std::set<DiscreteKey> discreteSeparator = factors.discreteKeys();

return hybridElimination(factors, frontalKeys, discreteSeparator);
return factors.eliminate(keys);
}
}

Expand Down
8 changes: 8 additions & 0 deletions gtsam/hybrid/HybridGaussianFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,14 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
*/
GaussianFactorGraphTree assembleGraphTree() const;

/**
* @brief Eliminate the given continuous keys.
*
* @param keys The continuous keys to eliminate.
* @return The conditional on the keys and a factor on the separator.
*/
std::pair<std::shared_ptr<HybridConditional>, std::shared_ptr<Factor>>
eliminate(const Ordering& keys) const;
/// @}

/// Get the GaussianFactorGraph at a given discrete assignment.
Expand Down
7 changes: 3 additions & 4 deletions gtsam/hybrid/tests/Switching.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,16 @@
#include <gtsam/hybrid/HybridNonlinearFactor.h>
#include <gtsam/hybrid/HybridNonlinearFactorGraph.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/linear/GaussianFactor.h>
#include <gtsam/linear/GaussianFactorGraph.h>
#include <gtsam/linear/JacobianFactor.h>
#include <gtsam/linear/NoiseModel.h>
#include <gtsam/nonlinear/NonlinearFactor.h>
#include <gtsam/nonlinear/PriorFactor.h>
#include <gtsam/slam/BetweenFactor.h>

#include <vector>

#include "gtsam/linear/GaussianFactor.h"
#include "gtsam/linear/GaussianFactorGraph.h"
#include "gtsam/nonlinear/NonlinearFactor.h"

#pragma once

namespace gtsam {
Expand Down
8 changes: 4 additions & 4 deletions gtsam/hybrid/tests/TinyHybridExample.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ inline HybridBayesNet createHybridBayesNet(size_t num_measurements = 1,
HybridBayesNet bayesNet;

// Create hybrid Gaussian factor z_i = x0 + noise for each measurement.
std::vector<std::pair<Vector, double>> measurementModels{{Z_1x1, 0.5},
{Z_1x1, 3.0}};
for (size_t i = 0; i < num_measurements; i++) {
const auto mode_i = manyModes ? DiscreteKey{M(i), 2} : mode;
std::vector<GaussianConditional::shared_ptr> conditionals{
GaussianConditional::sharedMeanAndStddev(Z(i), I_1x1, X(0), Z_1x1, 0.5),
GaussianConditional::sharedMeanAndStddev(Z(i), I_1x1, X(0), Z_1x1, 3)};
bayesNet.emplace_shared<HybridGaussianConditional>(mode_i, conditionals);
bayesNet.emplace_shared<HybridGaussianConditional>(mode_i, Z(i), I_1x1,
X(0), measurementModels);
}

// Create prior on X(0).
Expand Down
Loading

0 comments on commit ffb2829

Please sign in to comment.