Skip to content

Commit

Permalink
Merge pull request #1837 from borglab/improved-api-2
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal authored Sep 22, 2024
2 parents 33c4482 + 9f9032f commit e52973b
Show file tree
Hide file tree
Showing 12 changed files with 168 additions and 118 deletions.
16 changes: 16 additions & 0 deletions gtsam/hybrid/HybridConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,22 @@ double HybridConditional::error(const HybridValues &values) const {
"HybridConditional::error: conditional type not handled");
}

/* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridConditional::errorTree(
const VectorValues &values) const {
if (auto gc = asGaussian()) {
return AlgebraicDecisionTree<Key>(gc->error(values));
}
if (auto gm = asHybrid()) {
return gm->errorTree(values);
}
if (auto dc = asDiscrete()) {
return AlgebraicDecisionTree<Key>(0.0);
}
throw std::runtime_error(
"HybridConditional::error: conditional type not handled");
}

/* ************************************************************************ */
double HybridConditional::logProbability(const HybridValues &values) const {
if (auto gc = asGaussian()) {
Expand Down
10 changes: 10 additions & 0 deletions gtsam/hybrid/HybridConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,16 @@ class GTSAM_EXPORT HybridConditional
/// Return the error of the underlying conditional.
double error(const HybridValues& values) const override;

/**
* @brief Compute error of the HybridConditional as a tree.
*
* @param continuousValues The continuous VectorValues.
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys
* as the conditionals involved, and leaf values as the error.
*/
AlgebraicDecisionTree<Key> errorTree(
const VectorValues& values) const override;

/// Return the log-probability (or density) of the underlying conditional.
double logProbability(const HybridValues& values) const override;

Expand Down
4 changes: 4 additions & 0 deletions gtsam/hybrid/HybridFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,10 @@ class GTSAM_EXPORT HybridFactor : public Factor {
/// Return only the continuous keys for this factor.
const KeyVector &continuousKeys() const { return continuousKeys_; }

/// Virtual class to compute tree of linear errors.
virtual AlgebraicDecisionTree<Key> errorTree(
const VectorValues &values) const = 0;

/// @}

private:
Expand Down
34 changes: 0 additions & 34 deletions gtsam/hybrid/HybridGaussianConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,40 +323,6 @@ AlgebraicDecisionTree<Key> HybridGaussianConditional::logProbability(
return DecisionTree<Key, double>(conditionals_, probFunc);
}

/* ************************************************************************* */
double HybridGaussianConditional::conditionalError(
const GaussianConditional::shared_ptr &conditional,
const VectorValues &continuousValues) const {
// Check if valid pointer
if (conditional) {
return conditional->error(continuousValues) + //
-logConstant_ - conditional->logNormalizationConstant();
} else {
// If not valid, pointer, it means this conditional was pruned,
// so we return maximum error.
// This way the negative exponential will give
// a probability value close to 0.0.
return std::numeric_limits<double>::max();
}
}

/* *******************************************************************************/
AlgebraicDecisionTree<Key> HybridGaussianConditional::errorTree(
const VectorValues &continuousValues) const {
auto errorFunc = [&](const GaussianConditional::shared_ptr &conditional) {
return conditionalError(conditional, continuousValues);
};
DecisionTree<Key, double> error_tree(conditionals_, errorFunc);
return error_tree;
}

/* *******************************************************************************/
double HybridGaussianConditional::error(const HybridValues &values) const {
// Directly index to get the conditional, no need to build the whole tree.
auto conditional = conditionals_(values.discrete());
return conditionalError(conditional, values.continuous());
}

/* *******************************************************************************/
double HybridGaussianConditional::logProbability(
const HybridValues &values) const {
Expand Down
47 changes: 3 additions & 44 deletions gtsam/hybrid/HybridGaussianConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,9 @@ class GTSAM_EXPORT HybridGaussianConditional
const Conditionals &conditionals);

/**
* @brief Make a Hybrid Gaussian Conditional from a vector of Gaussian
* conditionals. The DecisionTree-based constructor is preferred over this
* one.
* @brief Make a Hybrid Gaussian Conditional from
* a vector of Gaussian conditionals.
* The DecisionTree-based constructor is preferred over this one.
*
* @param continuousFrontals The continuous frontal variables
* @param continuousParents The continuous parent variables
Expand Down Expand Up @@ -174,43 +174,6 @@ class GTSAM_EXPORT HybridGaussianConditional
AlgebraicDecisionTree<Key> logProbability(
const VectorValues &continuousValues) const;

/**
* @brief Compute the error of this hybrid Gaussian conditional.
*
* This requires some care, as different components may have
* different normalization constants. Let's consider p(x|y,m), where m is
* discrete. We need the error to satisfy the invariant:
*
* error(x;y,m) = K - log(probability(x;y,m))
*
* For all x,y,m. But note that K, the (log) normalization constant defined
* in Conditional.h, should not depend on x, y, or m, only on the parameters
* of the density. Hence, we delegate to the underlying Gaussian
* conditionals, indexed by m, which do satisfy:
*
* log(probability_m(x;y)) = K_m - error_m(x;y)
*
* We resolve by having K == max(K_m) and
*
* error(x;y,m) = error_m(x;y) + K - K_m
*
* which also makes error(x;y,m) >= 0 for all x,y,m.
*
* @param values Continuous values and discrete assignment.
* @return double
*/
double error(const HybridValues &values) const override;

/**
* @brief Compute error of the HybridGaussianConditional as a tree.
*
* @param continuousValues The continuous VectorValues.
* @return AlgebraicDecisionTree<Key> A decision tree on the discrete keys
* only, with the leaf values as the error for each assignment.
*/
AlgebraicDecisionTree<Key> errorTree(
const VectorValues &continuousValues) const;

/**
* @brief Compute the logProbability of this hybrid Gaussian conditional.
*
Expand Down Expand Up @@ -241,10 +204,6 @@ class GTSAM_EXPORT HybridGaussianConditional
/// Check whether `given` has values for all frontal keys.
bool allFrontalsGiven(const VectorValues &given) const;

/// Helper method to compute the error of a conditional.
double conditionalError(const GaussianConditional::shared_ptr &conditional,
const VectorValues &continuousValues) const;

#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
/** Serialization function */
friend class boost::serialization::access;
Expand Down
21 changes: 18 additions & 3 deletions gtsam/hybrid/HybridGaussianFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,21 +151,36 @@ GaussianFactorGraphTree HybridGaussianFactor::asGaussianFactorGraphTree()
return {factors_, wrap};
}

/* *******************************************************************************/
double HybridGaussianFactor::potentiallyPrunedComponentError(
const sharedFactor &gf, const VectorValues &values) const {
// Check if valid pointer
if (gf) {
return gf->error(values);
} else {
// If not valid, pointer, it means this component was pruned,
// so we return maximum error.
// This way the negative exponential will give
// a probability value close to 0.0.
return std::numeric_limits<double>::max();
}
}
/* *******************************************************************************/
AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
const VectorValues &continuousValues) const {
// functor to convert from sharedFactor to double error value.
auto errorFunc = [&continuousValues](const sharedFactor &gf) {
return gf->error(continuousValues);
auto errorFunc = [this, &continuousValues](const sharedFactor &gf) {
return this->potentiallyPrunedComponentError(gf, continuousValues);
};
DecisionTree<Key, double> error_tree(factors_, errorFunc);
return error_tree;
}

/* *******************************************************************************/
double HybridGaussianFactor::error(const HybridValues &values) const {
// Directly index to get the component, no need to build the whole tree.
const sharedFactor gf = factors_(values.discrete());
return gf->error(values.continuous());
return potentiallyPrunedComponentError(gf, values.continuous());
}

} // namespace gtsam
6 changes: 5 additions & 1 deletion gtsam/hybrid/HybridGaussianFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
* as the factors involved, and leaf values as the error.
*/
AlgebraicDecisionTree<Key> errorTree(
const VectorValues &continuousValues) const;
const VectorValues &continuousValues) const override;

/**
* @brief Compute the log-likelihood, including the log-normalizing constant.
Expand All @@ -186,6 +186,10 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
/// @}

private:
/// Helper method to compute the error of a component.
double potentiallyPrunedComponentError(
const sharedFactor &gf, const VectorValues &continuousValues) const;

#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
/** Serialization function */
friend class boost::serialization::access;
Expand Down
36 changes: 10 additions & 26 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -329,8 +329,8 @@ static std::shared_ptr<Factor> createDiscreteFactor(

// Logspace version of:
// exp(-factor->error(kEmpty)) / conditional->normalizationConstant();
// We take negative of the logNormalizationConstant `log(1/k)`
// to get `log(k)`.
// We take negative of the logNormalizationConstant `log(k)`
// to get `log(1/k) = log(\sqrt{|2πΣ|})`.
return -factor->error(kEmpty) - conditional->logNormalizationConstant();
};

Expand Down Expand Up @@ -539,36 +539,20 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::errorTree(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree(0.0);

// Iterate over each factor.
for (auto &factor : factors_) {
// TODO(dellaert): just use a virtual method defined in HybridFactor.
AlgebraicDecisionTree<Key> factor_error;

auto f = factor;
if (auto hc = dynamic_pointer_cast<HybridConditional>(factor)) {
f = hc->inner();
}

if (auto hybridGaussianCond =
dynamic_pointer_cast<HybridGaussianFactor>(f)) {
// Compute factor error and add it.
error_tree = error_tree + hybridGaussianCond->errorTree(continuousValues);
} else if (auto gaussian = dynamic_pointer_cast<GaussianFactor>(f)) {
// If continuous only, get the (double) error
// and add it to the error_tree
double error = gaussian->error(continuousValues);
// Add the gaussian factor error to every leaf of the error tree.
error_tree = error_tree.apply(
[error](double leaf_value) { return leaf_value + error; });
} else if (dynamic_pointer_cast<DiscreteFactor>(f)) {
// If factor at `idx` is discrete-only, we skip.
if (auto f = std::dynamic_pointer_cast<HybridFactor>(factor)) {
// Check for HybridFactor, and call errorTree
error_tree = error_tree + f->errorTree(continuousValues);
} else if (auto f = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
// Skip discrete factors
continue;
} else {
throwRuntimeError("HybridGaussianFactorGraph::error(VV)", f);
// Everything else is a continuous only factor
HybridValues hv(continuousValues, DiscreteValues());
error_tree = error_tree + AlgebraicDecisionTree<Key>(factor->error(hv));
}
}

return error_tree;
}

Expand Down
7 changes: 7 additions & 0 deletions gtsam/hybrid/HybridNonlinearFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,13 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor {
/// Decision tree of Gaussian factors indexed by discrete keys.
Factors factors_;

/// HybridFactor method implementation. Should not be used.
AlgebraicDecisionTree<Key> errorTree(
const VectorValues& continuousValues) const override {
throw std::runtime_error(
"HybridNonlinearFactor::error does not take VectorValues.");
}

public:
HybridNonlinearFactor() = default;

Expand Down
4 changes: 4 additions & 0 deletions gtsam/hybrid/HybridNonlinearFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph {
*/
std::shared_ptr<HybridGaussianFactorGraph> linearize(
const Values& continuousValues) const;

/// Expose error(const HybridValues&) method.
using Base::error;

/// @}
};

Expand Down
49 changes: 49 additions & 0 deletions gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,55 @@ TEST(HybridGaussianFactorGraph, ErrorTreeWithConditional) {
EXPECT(assert_equal(expected, errorTree, 1e-9));
}

/* ****************************************************************************/
// Test hybrid gaussian factor graph errorTree during
// incremental operation
TEST(HybridGaussianFactorGraph, IncrementalErrorTree) {
Switching s(4);

HybridGaussianFactorGraph graph;
graph.push_back(s.linearizedFactorGraph.at(0)); // f(X0)
graph.push_back(s.linearizedFactorGraph.at(1)); // f(X0, X1, M0)
graph.push_back(s.linearizedFactorGraph.at(2)); // f(X1, X2, M1)
graph.push_back(s.linearizedFactorGraph.at(4)); // f(X1)
graph.push_back(s.linearizedFactorGraph.at(5)); // f(X2)
graph.push_back(s.linearizedFactorGraph.at(7)); // f(M0)
graph.push_back(s.linearizedFactorGraph.at(8)); // f(M0, M1)

HybridBayesNet::shared_ptr hybridBayesNet = graph.eliminateSequential();
EXPECT_LONGS_EQUAL(5, hybridBayesNet->size());

HybridValues delta = hybridBayesNet->optimize();
auto error_tree = graph.errorTree(delta.continuous());

std::vector<DiscreteKey> discrete_keys = {{M(0), 2}, {M(1), 2}};
std::vector<double> leaves = {0.99985581, 0.4902432, 0.51936941,
0.0097568009};
AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves);

// regression
EXPECT(assert_equal(expected_error, error_tree, 1e-7));

graph = HybridGaussianFactorGraph();
graph.push_back(*hybridBayesNet);
graph.push_back(s.linearizedFactorGraph.at(3)); // f(X2, X3, M2)
graph.push_back(s.linearizedFactorGraph.at(6)); // f(X3)

hybridBayesNet = graph.eliminateSequential();
EXPECT_LONGS_EQUAL(7, hybridBayesNet->size());

delta = hybridBayesNet->optimize();
auto error_tree2 = graph.errorTree(delta.continuous());

discrete_keys = {{M(0), 2}, {M(1), 2}, {M(2), 2}};
leaves = {0.50985198, 0.0097577296, 0.50009425, 0,
0.52922138, 0.029127133, 0.50985105, 0.0097567964};
AlgebraicDecisionTree<Key> expected_error2(discrete_keys, leaves);

// regression
EXPECT(assert_equal(expected_error, error_tree, 1e-7));
}

/* ****************************************************************************/
// Check that assembleGraphTree assembles Gaussian factor graphs for each
// assignment.
Expand Down
Loading

0 comments on commit e52973b

Please sign in to comment.