Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More refactoring #1853

Merged
merged 14 commits into from
Sep 29, 2024
7 changes: 6 additions & 1 deletion gtsam/hybrid/HybridBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,14 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
/// @name Standard Constructors
/// @{

/** Construct empty Bayes net */
/// Construct empty Bayes net.
HybridBayesNet() = default;

/// Constructor that takes an initializer list of shared pointers.
HybridBayesNet(
std::initializer_list<HybridConditional::shared_ptr> conditionals)
: Base(conditionals) {}

/// @}
/// @name Testable
/// @{
Expand Down
88 changes: 67 additions & 21 deletions gtsam/hybrid/HybridGaussianConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,49 @@
#include <gtsam/inference/Conditional-inst.h>
#include <gtsam/linear/GaussianBayesNet.h>
#include <gtsam/linear/GaussianFactorGraph.h>
#include <gtsam/linear/JacobianFactor.h>

#include <cstddef>

namespace gtsam {
/* *******************************************************************************/
struct HybridGaussianConditional::ConstructorHelper {
struct HybridGaussianConditional::Helper {
std::optional<size_t> nrFrontals;
HybridGaussianFactor::FactorValuePairs pairs;
FactorValuePairs pairs;
Conditionals conditionals;
double minNegLogConstant;

/// Compute all variables needed for the private constructor below.
ConstructorHelper(const Conditionals &conditionals)
: minNegLogConstant(std::numeric_limits<double>::infinity()) {
auto func = [this](const GaussianConditional::shared_ptr &c)
-> GaussianFactorValuePair {
using GC = GaussianConditional;
using P = std::vector<std::pair<Vector, double>>;

/// Construct from a vector of mean and sigma pairs, plus extra args.
template <typename... Args>
explicit Helper(const DiscreteKey &mode, const P &p, Args &&...args) {
nrFrontals = 1;
minNegLogConstant = std::numeric_limits<double>::infinity();

std::vector<GaussianFactorValuePair> fvs;
std::vector<GC::shared_ptr> gcs;
fvs.reserve(p.size());
gcs.reserve(p.size());
for (auto &&[mean, sigma] : p) {
auto gaussianConditional =
GC::sharedMeanAndStddev(std::forward<Args>(args)..., mean, sigma);
double value = gaussianConditional->negLogConstant();
minNegLogConstant = std::min(minNegLogConstant, value);
fvs.emplace_back(gaussianConditional, value);
gcs.push_back(gaussianConditional);
}

conditionals = Conditionals({mode}, gcs);
pairs = FactorValuePairs({mode}, fvs);
}

/// Construct from tree of GaussianConditionals.
explicit Helper(const Conditionals &conditionals)
: conditionals(conditionals),
minNegLogConstant(std::numeric_limits<double>::infinity()) {
auto func = [this](const GC::shared_ptr &c) -> GaussianFactorValuePair {
double value = 0.0;
if (c) {
if (!nrFrontals.has_value()) {
Expand All @@ -51,38 +79,56 @@ struct HybridGaussianConditional::ConstructorHelper {
}
return {std::dynamic_pointer_cast<GaussianFactor>(c), value};
};
pairs = HybridGaussianFactor::FactorValuePairs(conditionals, func);
pairs = FactorValuePairs(conditionals, func);
if (!nrFrontals.has_value()) {
throw std::runtime_error(
"HybridGaussianConditional: need at least one frontal variable.");
"HybridGaussianConditional: need at least one frontal variable. "
"Provided conditionals do not contain any frontal variables.");
}
}
};

/* *******************************************************************************/
HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKeys &discreteParents,
const HybridGaussianConditional::Conditionals &conditionals,
const ConstructorHelper &helper)
const DiscreteKeys &discreteParents, const Helper &helper)
: BaseFactor(discreteParents, helper.pairs),
BaseConditional(*helper.nrFrontals),
conditionals_(conditionals),
conditionals_(helper.conditionals),
negLogConstant_(helper.minNegLogConstant) {}

/* *******************************************************************************/
varunagrawal marked this conversation as resolved.
Show resolved Hide resolved
HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKeys &discreteParents,
const HybridGaussianConditional::Conditionals &conditionals)
: HybridGaussianConditional(discreteParents, conditionals,
ConstructorHelper(conditionals)) {}

/* *******************************************************************************/
HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKey &discreteParent,
const std::vector<GaussianConditional::shared_ptr> &conditionals)
: HybridGaussianConditional(DiscreteKeys{discreteParent},
Conditionals({discreteParent}, conditionals)) {}

HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKey &discreteParent, Key key, //
const std::vector<std::pair<Vector, double>> &parameters)
: HybridGaussianConditional(DiscreteKeys{discreteParent},
Helper(discreteParent, parameters, key)) {}

HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKey &discreteParent, Key key, //
const Matrix &A, Key parent,
const std::vector<std::pair<Vector, double>> &parameters)
: HybridGaussianConditional(
DiscreteKeys{discreteParent},
Helper(discreteParent, parameters, key, A, parent)) {}

HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKey &discreteParent, Key key, //
const Matrix &A1, Key parent1, const Matrix &A2, Key parent2,
const std::vector<std::pair<Vector, double>> &parameters)
: HybridGaussianConditional(
DiscreteKeys{discreteParent},
Helper(discreteParent, parameters, key, A1, parent1, A2, parent2)) {}

HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKeys &discreteParents,
const HybridGaussianConditional::Conditionals &conditionals)
: HybridGaussianConditional(discreteParents, Helper(conditionals)) {}

/* *******************************************************************************/
const HybridGaussianConditional::Conditionals &
HybridGaussianConditional::conditionals() const {
Expand Down
51 changes: 46 additions & 5 deletions gtsam/hybrid/HybridGaussianConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,49 @@ class GTSAM_EXPORT HybridGaussianConditional
const DiscreteKey &discreteParent,
const std::vector<GaussianConditional::shared_ptr> &conditionals);

/**
* @brief Constructs a HybridGaussianConditional with means mu_i and
* standard deviations sigma_i.
*
* @param discreteParent The discrete parent or "mode" key.
* @param key The key for this conditional variable.
* @param parameters A vector of pairs (mu_i, sigma_i).
*/
HybridGaussianConditional(
const DiscreteKey &discreteParent, Key key,
const std::vector<std::pair<Vector, double>> &parameters);

/**
* @brief Constructs a HybridGaussianConditional with conditional means
* A × parent + b_i and standard deviations sigma_i.
*
* @param discreteParent The discrete parent or "mode" key.
* @param key The key for this conditional variable.
* @param A The matrix A.
* @param parent The key of the parent variable.
* @param parameters A vector of pairs (b_i, sigma_i).
*/
HybridGaussianConditional(
const DiscreteKey &discreteParent, Key key, const Matrix &A, Key parent,
const std::vector<std::pair<Vector, double>> &parameters);

/**
* @brief Constructs a HybridGaussianConditional with conditional means
* A1 × parent1 + A2 × parent2 + b_i and standard deviations sigma_i.
*
* @param discreteParent The discrete parent or "mode" key.
* @param key The key for this conditional variable.
* @param A1 The first matrix.
* @param parent1 The key of the first parent variable.
* @param A2 The second matrix.
* @param parent2 The key of the second parent variable.
* @param parameters A vector of pairs (b_i, sigma_i).
*/
HybridGaussianConditional(
const DiscreteKey &discreteParent, Key key, //
const Matrix &A1, Key parent1, const Matrix &A2, Key parent2,
const std::vector<std::pair<Vector, double>> &parameters);

/**
* @brief Construct from multiple discrete keys and conditional tree.
*
Expand Down Expand Up @@ -183,13 +226,11 @@ class GTSAM_EXPORT HybridGaussianConditional

private:
/// Helper struct for private constructor.
struct ConstructorHelper;
struct Helper;

/// Private constructor that uses helper struct above.
HybridGaussianConditional(
const DiscreteKeys &discreteParents,
const HybridGaussianConditional::Conditionals &conditionals,
const ConstructorHelper &helper);
HybridGaussianConditional(const DiscreteKeys &discreteParents,
const Helper &helper);

/// Convert to a DecisionTree of Gaussian factor graphs.
GaussianFactorGraphTree asGaussianFactorGraphTree() const;
Expand Down
35 changes: 14 additions & 21 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -366,18 +366,18 @@ static std::shared_ptr<Factor> createHybridGaussianFactor(
return std::make_shared<HybridGaussianFactor>(discreteSeparator, newFactors);
}

static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
hybridElimination(const HybridGaussianFactorGraph &factors,
const Ordering &frontalKeys,
const std::set<DiscreteKey> &discreteSeparatorSet) {
// NOTE: since we use the special JunctionTree,
// only possibility is continuous conditioned on discrete.
DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(),
discreteSeparatorSet.end());
/* *******************************************************************************/
std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
// Since we eliminate all continuous variables first,
// the discrete separator will be *all* the discrete keys.
const std::set<DiscreteKey> keysForDiscreteVariables = discreteKeys();
DiscreteKeys discreteSeparator(keysForDiscreteVariables.begin(),
keysForDiscreteVariables.end());

// Collect all the factors to create a set of Gaussian factor graphs in a
// decision tree indexed by all discrete keys involved.
GaussianFactorGraphTree factorGraphTree = factors.assembleGraphTree();
GaussianFactorGraphTree factorGraphTree = assembleGraphTree();

// Convert factor graphs with a nullptr to an empty factor graph.
// This is done after assembly since it is non-trivial to keep track of which
Expand All @@ -392,7 +392,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
}

// Expensive elimination of product factor.
auto result = EliminatePreferCholesky(graph, frontalKeys);
auto result = EliminatePreferCholesky(graph, keys);

// Record whether there any continuous variables left
someContinuousLeft |= !result.second->empty();
Expand Down Expand Up @@ -436,7 +436,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
*/
std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>> //
EliminateHybrid(const HybridGaussianFactorGraph &factors,
const Ordering &frontalKeys) {
const Ordering &keys) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This deviates from the naming scheme of GaussianFactorGraph and other graphs, where it is frontalKeys. Should we strive for consistency?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I’ll fix

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, no,I went back and looked at GFG and it's always just "keys". So I'll leave this to be consistent. frontal keys is a concept in conditionals, not in elimination. The keys we eliminate will become frontal keys.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh okay. I must have been mistaken then.

// NOTE: Because we are in the Conditional Gaussian regime there are only
// a few cases:
// 1. continuous variable, make a hybrid Gaussian conditional if there are
Expand Down Expand Up @@ -510,20 +510,13 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,

if (only_discrete) {
// Case 1: we are only dealing with discrete
return discreteElimination(factors, frontalKeys);
return discreteElimination(factors, keys);
} else if (only_continuous) {
// Case 2: we are only dealing with continuous
return continuousElimination(factors, frontalKeys);
return continuousElimination(factors, keys);
} else {
// Case 3: We are now in the hybrid land!
KeySet frontalKeysSet(frontalKeys.begin(), frontalKeys.end());

// Find all discrete keys.
// Since we eliminate all continuous variables first,
// the discrete separator will be *all* the discrete keys.
std::set<DiscreteKey> discreteSeparator = factors.discreteKeys();

return hybridElimination(factors, frontalKeys, discreteSeparator);
return factors.eliminate(keys);
dellaert marked this conversation as resolved.
Show resolved Hide resolved
}
}

Expand Down
8 changes: 8 additions & 0 deletions gtsam/hybrid/HybridGaussianFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,14 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
*/
GaussianFactorGraphTree assembleGraphTree() const;

/**
* @brief Eliminate the given continuous keys.
*
* @param keys The continuous keys to eliminate.
* @return The conditional on the keys and a factor on the separator.
*/
std::pair<std::shared_ptr<HybridConditional>, std::shared_ptr<Factor>>
eliminate(const Ordering& keys) const;
/// @}

/// Get the GaussianFactorGraph at a given discrete assignment.
Expand Down
7 changes: 3 additions & 4 deletions gtsam/hybrid/tests/Switching.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,16 @@
#include <gtsam/hybrid/HybridNonlinearFactor.h>
#include <gtsam/hybrid/HybridNonlinearFactorGraph.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/linear/GaussianFactor.h>
#include <gtsam/linear/GaussianFactorGraph.h>
#include <gtsam/linear/JacobianFactor.h>
#include <gtsam/linear/NoiseModel.h>
#include <gtsam/nonlinear/NonlinearFactor.h>
#include <gtsam/nonlinear/PriorFactor.h>
#include <gtsam/slam/BetweenFactor.h>

#include <vector>

#include "gtsam/linear/GaussianFactor.h"
#include "gtsam/linear/GaussianFactorGraph.h"
#include "gtsam/nonlinear/NonlinearFactor.h"

#pragma once

namespace gtsam {
Expand Down
8 changes: 4 additions & 4 deletions gtsam/hybrid/tests/TinyHybridExample.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ inline HybridBayesNet createHybridBayesNet(size_t num_measurements = 1,
HybridBayesNet bayesNet;

// Create hybrid Gaussian factor z_i = x0 + noise for each measurement.
std::vector<std::pair<Vector, double>> measurementModels{{Z_1x1, 0.5},
{Z_1x1, 3.0}};
for (size_t i = 0; i < num_measurements; i++) {
const auto mode_i = manyModes ? DiscreteKey{M(i), 2} : mode;
std::vector<GaussianConditional::shared_ptr> conditionals{
GaussianConditional::sharedMeanAndStddev(Z(i), I_1x1, X(0), Z_1x1, 0.5),
GaussianConditional::sharedMeanAndStddev(Z(i), I_1x1, X(0), Z_1x1, 3)};
bayesNet.emplace_shared<HybridGaussianConditional>(mode_i, conditionals);
bayesNet.emplace_shared<HybridGaussianConditional>(mode_i, Z(i), I_1x1,
X(0), measurementModels);
}

// Create prior on X(0).
Expand Down
Loading
Loading