Skip to content

Commit

Permalink
Merge pull request #1848 from borglab/feature/simpler_constructors
Browse files Browse the repository at this point in the history
Simpler HybridGaussianFoo constructors
  • Loading branch information
dellaert authored Sep 27, 2024
2 parents de23ebb + ce45bb6 commit 31eade5
Show file tree
Hide file tree
Showing 26 changed files with 353 additions and 363 deletions.
3 changes: 1 addition & 2 deletions gtsam/hybrid/HybridFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ std::set<DiscreteKey> HybridFactorGraph::discreteKeys() const {
for (const DiscreteKey& key : p->discreteKeys()) {
keys.insert(key);
}
}
if (auto p = std::dynamic_pointer_cast<HybridFactor>(factor)) {
} else if (auto p = std::dynamic_pointer_cast<HybridFactor>(factor)) {
for (const DiscreteKey& key : p->discreteKeys()) {
keys.insert(key);
}
Expand Down
89 changes: 49 additions & 40 deletions gtsam/hybrid/HybridGaussianConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,57 +27,66 @@
#include <gtsam/linear/GaussianBayesNet.h>
#include <gtsam/linear/GaussianFactorGraph.h>

#include <cstddef>

namespace gtsam {
HybridGaussianFactor::FactorValuePairs GetFactorValuePairs(
const HybridGaussianConditional::Conditionals &conditionals) {
auto func = [](const GaussianConditional::shared_ptr &conditional)
-> GaussianFactorValuePair {
double value = 0.0;
// Check if conditional is pruned
if (conditional) {
// Assign log(\sqrt(|2πΣ|)) = -log(1 / sqrt(|2πΣ|))
value = conditional->negLogConstant();
/* *******************************************************************************/
struct HybridGaussianConditional::ConstructorHelper {
std::optional<size_t> nrFrontals;
HybridGaussianFactor::FactorValuePairs pairs;
double minNegLogConstant;

/// Compute all variables needed for the private constructor below.
ConstructorHelper(const Conditionals &conditionals)
: minNegLogConstant(std::numeric_limits<double>::infinity()) {
auto func = [this](const GaussianConditional::shared_ptr &c)
-> GaussianFactorValuePair {
double value = 0.0;
if (c) {
if (!nrFrontals.has_value()) {
nrFrontals = c->nrFrontals();
}
value = c->negLogConstant();
minNegLogConstant = std::min(minNegLogConstant, value);
}
return {std::dynamic_pointer_cast<GaussianFactor>(c), value};
};
pairs = HybridGaussianFactor::FactorValuePairs(conditionals, func);
if (!nrFrontals.has_value()) {
throw std::runtime_error(
"HybridGaussianConditional: need at least one frontal variable.");
}
return {std::dynamic_pointer_cast<GaussianFactor>(conditional), value};
};
return HybridGaussianFactor::FactorValuePairs(conditionals, func);
}
}
};

/* *******************************************************************************/
HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKeys &discreteParents,
const HybridGaussianConditional::Conditionals &conditionals,
const ConstructorHelper &helper)
: BaseFactor(discreteParents, helper.pairs),
BaseConditional(*helper.nrFrontals),
conditionals_(conditionals),
negLogConstant_(helper.minNegLogConstant) {}

HybridGaussianConditional::HybridGaussianConditional(
const KeyVector &continuousFrontals, const KeyVector &continuousParents,
const DiscreteKeys &discreteParents,
const HybridGaussianConditional::Conditionals &conditionals)
: BaseFactor(CollectKeys(continuousFrontals, continuousParents),
discreteParents, GetFactorValuePairs(conditionals)),
BaseConditional(continuousFrontals.size()),
conditionals_(conditionals) {
// Calculate negLogConstant_ as the minimum of the negative-log normalizers of
// the conditionals, by visiting the decision tree:
negLogConstant_ = std::numeric_limits<double>::infinity();
conditionals_.visit(
[this](const GaussianConditional::shared_ptr &conditional) {
if (conditional) {
this->negLogConstant_ =
std::min(this->negLogConstant_, conditional->negLogConstant());
}
});
}
: HybridGaussianConditional(discreteParents, conditionals,
ConstructorHelper(conditionals)) {}

HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKey &discreteParent,
const std::vector<GaussianConditional::shared_ptr> &conditionals)
: HybridGaussianConditional(DiscreteKeys{discreteParent},
Conditionals({discreteParent}, conditionals)) {}

/* *******************************************************************************/
const HybridGaussianConditional::Conditionals &
HybridGaussianConditional::conditionals() const {
return conditionals_;
}

/* *******************************************************************************/
HybridGaussianConditional::HybridGaussianConditional(
const KeyVector &continuousFrontals, const KeyVector &continuousParents,
const DiscreteKey &discreteParent,
const std::vector<GaussianConditional::shared_ptr> &conditionals)
: HybridGaussianConditional(continuousFrontals, continuousParents,
DiscreteKeys{discreteParent},
Conditionals({discreteParent}, conditionals)) {}

/* *******************************************************************************/
GaussianFactorGraphTree HybridGaussianConditional::asGaussianFactorGraphTree()
const {
Expand Down Expand Up @@ -222,8 +231,8 @@ std::shared_ptr<HybridGaussianFactor> HybridGaussianConditional::likelihood(
return {likelihood_m, Cgm_Kgcm};
}
});
return std::make_shared<HybridGaussianFactor>(
continuousParentKeys, discreteParentKeys, likelihoods);
return std::make_shared<HybridGaussianFactor>(discreteParentKeys,
likelihoods);
}

/* ************************************************************************* */
Expand Down
70 changes: 31 additions & 39 deletions gtsam/hybrid/HybridGaussianConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,27 +64,11 @@ class GTSAM_EXPORT HybridGaussianConditional

private:
Conditionals conditionals_; ///< a decision tree of Gaussian conditionals.

///< Negative-log of the normalization constant (log(\sqrt(|2πΣ|))).
///< Take advantage of the neg-log space so everything is a minimization
double negLogConstant_;

/**
* @brief Convert a HybridGaussianConditional of conditionals into
* a DecisionTree of Gaussian factor graphs.
*/
GaussianFactorGraphTree asGaussianFactorGraphTree() const;

/**
* @brief Helper function to get the pruner functor.
*
* @param discreteProbs The pruned discrete probabilities.
* @return std::function<GaussianConditional::shared_ptr(
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
*/
std::function<GaussianConditional::shared_ptr(
const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
prunerFunc(const DecisionTreeFactor &discreteProbs);

public:
/// @name Constructors
/// @{
Expand All @@ -93,37 +77,28 @@ class GTSAM_EXPORT HybridGaussianConditional
HybridGaussianConditional() = default;

/**
* @brief Construct a new HybridGaussianConditional object.
* @brief Construct from one discrete key and vector of conditionals.
*
* @param continuousFrontals the continuous frontals.
* @param continuousParents the continuous parents.
* @param discreteParents the discrete parents. Will be placed last.
* @param conditionals a decision tree of GaussianConditionals. The number of
* conditionals should be C^(number of discrete parents), where C is the
* cardinality of the DiscreteKeys in discreteParents, since the
* discreteParents will be used as the labels in the decision tree.
*/
HybridGaussianConditional(const KeyVector &continuousFrontals,
const KeyVector &continuousParents,
const DiscreteKeys &discreteParents,
const Conditionals &conditionals);

/**
* @brief Make a Hybrid Gaussian Conditional from
* a vector of Gaussian conditionals.
* The DecisionTree-based constructor is preferred over this one.
*
* @param continuousFrontals The continuous frontal variables
* @param continuousParents The continuous parent variables
* @param discreteParent Single discrete parent variable
* @param conditionals Vector of conditionals with the same size as the
* cardinality of the discrete parent.
*/
HybridGaussianConditional(
const KeyVector &continuousFrontals, const KeyVector &continuousParents,
const DiscreteKey &discreteParent,
const std::vector<GaussianConditional::shared_ptr> &conditionals);

/**
* @brief Construct from multiple discrete keys and conditional tree.
*
* @param discreteParents the discrete parents. Will be placed last.
* @param conditionals a decision tree of GaussianConditionals. The number of
* conditionals should be C^(number of discrete parents), where C is the
* cardinality of the DiscreteKeys in discreteParents, since the
* discreteParents will be used as the labels in the decision tree.
*/
HybridGaussianConditional(const DiscreteKeys &discreteParents,
const Conditionals &conditionals);

/// @}
/// @name Testable
/// @{
Expand Down Expand Up @@ -207,6 +182,23 @@ class GTSAM_EXPORT HybridGaussianConditional
/// @}

private:
/// Helper struct for private constructor.
struct ConstructorHelper;

/// Private constructor that uses helper struct above.
HybridGaussianConditional(
const DiscreteKeys &discreteParents,
const HybridGaussianConditional::Conditionals &conditionals,
const ConstructorHelper &helper);

/// Convert to a DecisionTree of Gaussian factor graphs.
GaussianFactorGraphTree asGaussianFactorGraphTree() const;

//// Get the pruner functor from pruned discrete probabilities.
std::function<GaussianConditional::shared_ptr(
const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
prunerFunc(const DecisionTreeFactor &prunedProbabilities);

/// Check whether `given` has values for all frontal keys.
bool allFrontalsGiven(const VectorValues &given) const;

Expand Down
96 changes: 77 additions & 19 deletions gtsam/hybrid/HybridGaussianFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,28 +21,20 @@
#include <gtsam/base/utilities.h>
#include <gtsam/discrete/DecisionTree-inl.h>
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/linear/GaussianFactor.h>
#include <gtsam/linear/GaussianFactorGraph.h>

namespace gtsam {

/**
* @brief Helper function to augment the [A|b] matrices in the factor components
* with the additional scalar values.
* This is done by storing the value in
* the `b` vector as an additional row.
*
* @param factors DecisionTree of GaussianFactors and arbitrary scalars.
* Gaussian factor in factors.
* @return HybridGaussianFactor::Factors
*/
static HybridGaussianFactor::Factors augment(
const HybridGaussianFactor::FactorValuePairs &factors) {
/* *******************************************************************************/
HybridGaussianFactor::Factors HybridGaussianFactor::augment(
const FactorValuePairs &factors) {
// Find the minimum value so we can "proselytize" to positive values.
// Done because we can't have sqrt of negative numbers.
HybridGaussianFactor::Factors gaussianFactors;
Factors gaussianFactors;
AlgebraicDecisionTree<Key> valueTree;
std::tie(gaussianFactors, valueTree) = unzip(factors);

Expand Down Expand Up @@ -73,22 +65,88 @@ static HybridGaussianFactor::Factors augment(
return std::dynamic_pointer_cast<GaussianFactor>(
std::make_shared<JacobianFactor>(gfg));
};
return HybridGaussianFactor::Factors(factors, update);
return Factors(factors, update);
}

/* *******************************************************************************/
HybridGaussianFactor::HybridGaussianFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys,
struct HybridGaussianFactor::ConstructorHelper {
KeyVector continuousKeys; // Continuous keys extracted from factors
DiscreteKeys discreteKeys; // Discrete keys provided to the constructors
FactorValuePairs pairs; // Used only if factorsTree is empty
Factors factorsTree;

ConstructorHelper(const DiscreteKey &discreteKey,
const std::vector<GaussianFactor::shared_ptr> &factors)
: discreteKeys({discreteKey}) {
// Extract continuous keys from the first non-null factor
for (const auto &factor : factors) {
if (factor && continuousKeys.empty()) {
continuousKeys = factor->keys();
break;
}
}

// Build the DecisionTree from the factor vector
factorsTree = Factors(discreteKeys, factors);
}

ConstructorHelper(const DiscreteKey &discreteKey,
const std::vector<GaussianFactorValuePair> &factorPairs)
: discreteKeys({discreteKey}) {
// Extract continuous keys from the first non-null factor
for (const auto &pair : factorPairs) {
if (pair.first && continuousKeys.empty()) {
continuousKeys = pair.first->keys();
break;
}
}

// Build the FactorValuePairs DecisionTree
pairs = FactorValuePairs(discreteKeys, factorPairs);
}

ConstructorHelper(const DiscreteKeys &discreteKeys,
const FactorValuePairs &factorPairs)
: discreteKeys(discreteKeys) {
// Extract continuous keys from the first non-null factor
factorPairs.visit([&](const GaussianFactorValuePair &pair) {
if (pair.first && continuousKeys.empty()) {
continuousKeys = pair.first->keys();
}
});

// Build the FactorValuePairs DecisionTree
pairs = factorPairs;
}
};

/* *******************************************************************************/
HybridGaussianFactor::HybridGaussianFactor(const ConstructorHelper &helper)
: Base(helper.continuousKeys, helper.discreteKeys),
factors_(helper.factorsTree.empty() ? augment(helper.pairs)
: helper.factorsTree) {}

HybridGaussianFactor::HybridGaussianFactor(
const DiscreteKey &discreteKey,
const std::vector<GaussianFactor::shared_ptr> &factors)
: HybridGaussianFactor(ConstructorHelper(discreteKey, factors)) {}

HybridGaussianFactor::HybridGaussianFactor(
const DiscreteKey &discreteKey,
const std::vector<GaussianFactorValuePair> &factorPairs)
: HybridGaussianFactor(ConstructorHelper(discreteKey, factorPairs)) {}

HybridGaussianFactor::HybridGaussianFactor(const DiscreteKeys &discreteKeys,
const FactorValuePairs &factors)
: Base(continuousKeys, discreteKeys), factors_(augment(factors)) {}
: HybridGaussianFactor(ConstructorHelper(discreteKeys, factors)) {}

/* *******************************************************************************/
bool HybridGaussianFactor::equals(const HybridFactor &lf, double tol) const {
const This *e = dynamic_cast<const This *>(&lf);
if (e == nullptr) return false;

// This will return false if either factors_ is empty or e->factors_ is empty,
// but not if both are empty or both are not empty:
// This will return false if either factors_ is empty or e->factors_ is
// empty, but not if both are empty or both are not empty:
if (factors_.empty() ^ e->factors_.empty()) return false;

// Check the base and the factors:
Expand Down
Loading

0 comments on commit 31eade5

Please sign in to comment.