Skip to content

Commit

Permalink
Merge pull request #1894 from borglab/check-isam
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal authored Nov 5, 2024
2 parents 2087d3f + e306240 commit 2bd2d82
Show file tree
Hide file tree
Showing 14 changed files with 306 additions and 255 deletions.
1 change: 1 addition & 0 deletions gtsam/hybrid/HybridGaussianFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {

/// Getter for GaussianFactor decision tree
const FactorValuePairs &factors() const { return factors_; }

/**
* @brief Helper function to return factors and functional to create a
* DecisionTree of Gaussian Factor Graphs.
Expand Down
11 changes: 11 additions & 0 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -581,4 +581,15 @@ GaussianFactorGraph HybridGaussianFactorGraph::choose(
return gfg;
}

/* ************************************************************************ */
DiscreteFactorGraph HybridGaussianFactorGraph::discreteFactors() const {
DiscreteFactorGraph dfg;
for (auto &&f : factors_) {
if (auto discreteFactor = std::dynamic_pointer_cast<DiscreteFactor>(f)) {
dfg.push_back(discreteFactor);
}
}
return dfg;
}

} // namespace gtsam
9 changes: 9 additions & 0 deletions gtsam/hybrid/HybridGaussianFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#pragma once

#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/hybrid/HybridFactorGraph.h>
Expand Down Expand Up @@ -254,6 +255,14 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
GaussianFactorGraph operator()(const DiscreteValues& assignment) const {
return choose(assignment);
}

/**
* @brief Helper method to get all the discrete factors
* as a DiscreteFactorGraph.
*
* @return DiscreteFactorGraph
*/
DiscreteFactorGraph discreteFactors() const;
};

// traits
Expand Down
38 changes: 37 additions & 1 deletion gtsam/hybrid/HybridNonlinearFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
* @date Sep 12, 2024
*/

#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/hybrid/HybridNonlinearFactor.h>
#include <gtsam/linear/NoiseModel.h>
#include <gtsam/nonlinear/NonlinearFactor.h>
Expand Down Expand Up @@ -184,6 +185,11 @@ std::shared_ptr<HybridGaussianFactor> HybridNonlinearFactor::linearize(
[continuousValues](
const std::pair<sharedFactor, double>& f) -> GaussianFactorValuePair {
auto [factor, val] = f;
// Check if valid factor. If not, return null and infinite error.
if (!factor) {
return {nullptr, std::numeric_limits<double>::infinity()};
}

if (auto gaussian = std::dynamic_pointer_cast<noiseModel::Gaussian>(
factor->noiseModel())) {
return {factor->linearize(continuousValues),
Expand All @@ -202,4 +208,34 @@ std::shared_ptr<HybridGaussianFactor> HybridNonlinearFactor::linearize(
linearized_factors);
}

} // namespace gtsam
/* *******************************************************************************/
HybridNonlinearFactor::shared_ptr HybridNonlinearFactor::prune(
const DecisionTreeFactor& discreteProbs) const {
// Find keys in discreteProbs.keys() but not in this->keys():
std::set<Key> mine(this->keys().begin(), this->keys().end());
std::set<Key> theirs(discreteProbs.keys().begin(),
discreteProbs.keys().end());
std::vector<Key> diff;
std::set_difference(theirs.begin(), theirs.end(), mine.begin(), mine.end(),
std::back_inserter(diff));

// Find maximum probability value for every combination of our keys.
Ordering keys(diff);
auto max = discreteProbs.max(keys);

// Check the max value for every combination of our keys.
// If the max value is 0.0, we can prune the corresponding conditional.
auto pruner =
[&](const Assignment<Key>& choices,
const NonlinearFactorValuePair& pair) -> NonlinearFactorValuePair {
if (max->evaluate(choices) == 0.0)
return {nullptr, std::numeric_limits<double>::infinity()};
else
return pair;
};

FactorValuePairs prunedFactors = factors().apply(pruner);
return std::make_shared<HybridNonlinearFactor>(discreteKeys(), prunedFactors);
}

} // namespace gtsam
7 changes: 7 additions & 0 deletions gtsam/hybrid/HybridNonlinearFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,9 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor {

/// @}

/// Getter for NonlinearFactor decision tree
const FactorValuePairs& factors() const { return factors_; }

/// Linearize specific nonlinear factors based on the assignment in
/// discreteValues.
GaussianFactor::shared_ptr linearize(
Expand All @@ -176,6 +179,10 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor {
std::shared_ptr<HybridGaussianFactor> linearize(
const Values& continuousValues) const;

/// Prune this factor based on the discrete probabilities.
HybridNonlinearFactor::shared_ptr prune(
const DecisionTreeFactor& discreteProbs) const;

private:
/// Helper struct to assist private constructor below.
struct ConstructorHelper;
Expand Down
15 changes: 14 additions & 1 deletion gtsam/hybrid/HybridNonlinearISAM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/

#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
#include <gtsam/hybrid/HybridNonlinearFactor.h>
#include <gtsam/hybrid/HybridNonlinearISAM.h>
#include <gtsam/inference/Ordering.h>

Expand All @@ -39,7 +40,6 @@ void HybridNonlinearISAM::update(const HybridNonlinearFactorGraph& newFactors,
if (newFactors.size() > 0) {
// Reorder and relinearize every reorderInterval updates
if (reorderInterval_ > 0 && ++reorderCounter_ >= reorderInterval_) {
// TODO(Varun) Re-linearization doesn't take into account pruning
reorderRelinearize();
reorderCounter_ = 0;
}
Expand All @@ -65,8 +65,21 @@ void HybridNonlinearISAM::reorderRelinearize() {
// Obtain the new linearization point
const Values newLinPoint = estimate();

auto discreteProbs = *(isam_.roots().at(0)->conditional()->asDiscrete());

isam_.clear();

// Prune nonlinear factors based on discrete conditional probabilities
HybridNonlinearFactorGraph pruned_factors;
for (auto&& factor : factors_) {
if (auto nf = std::dynamic_pointer_cast<HybridNonlinearFactor>(factor)) {
pruned_factors.push_back(nf->prune(discreteProbs));
} else {
pruned_factors.push_back(factor);
}
}
factors_ = pruned_factors;

// Just recreate the whole BayesTree
// TODO: allow for constrained ordering here
// TODO: decouple re-linearization and reordering to avoid
Expand Down
6 changes: 4 additions & 2 deletions gtsam/hybrid/tests/Switching.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ using symbol_shorthand::X;
* @return HybridGaussianFactorGraph::shared_ptr
*/
inline HybridGaussianFactorGraph::shared_ptr makeSwitchingChain(
size_t K, std::function<Key(int)> x = X, std::function<Key(int)> m = M) {
size_t K, std::function<Key(int)> x = X, std::function<Key(int)> m = M,
const std::string &transitionProbabilityTable = "0 1 1 3") {
HybridGaussianFactorGraph hfg;

hfg.add(JacobianFactor(x(1), I_3x3, Z_3x1));
Expand All @@ -68,7 +69,8 @@ inline HybridGaussianFactorGraph::shared_ptr makeSwitchingChain(
hfg.add(HybridGaussianFactor({m(k), 2}, components));

if (k > 1) {
hfg.add(DecisionTreeFactor({{m(k - 1), 2}, {m(k), 2}}, "0 1 1 3"));
hfg.add(DecisionTreeFactor({{m(k - 1), 2}, {m(k), 2}},
transitionProbabilityTable));
}
}

Expand Down
32 changes: 6 additions & 26 deletions gtsam/hybrid/tests/testHybridBayesTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,7 @@ TEST(HybridGaussianFactorGraph,
hfg.add(HybridGaussianFactor(m1, two::components(X(1))));

hfg.add(DecisionTreeFactor(m1, {2, 8}));
// TODO(Varun) Adding extra discrete variable not connected to continuous
// variable throws segfault
// hfg.add(DecisionTreeFactor({m1, m2, "1 2 3 4"));
hfg.add(DecisionTreeFactor({m1, m2}, "1 2 3 4"));

HybridBayesTree::shared_ptr result = hfg.eliminateMultifrontal();

Expand Down Expand Up @@ -176,7 +174,7 @@ TEST(HybridGaussianFactorGraph, Switching) {
// Ordering(KeyVector{X(1), X(4), X(2), X(6), X(9), X(8), X(3), X(7),
// X(5),
// M(1), M(4), M(2), M(6), M(8), M(3), M(7), M(5)});
KeyVector ordering;
Ordering ordering;

{
std::vector<int> naturalX(N);
Expand All @@ -187,10 +185,6 @@ TEST(HybridGaussianFactorGraph, Switching) {

auto [ndX, lvls] = makeBinaryOrdering(ordX);
std::copy(ndX.begin(), ndX.end(), std::back_inserter(ordering));
// TODO(dellaert): this has no effect!
for (auto& l : lvls) {
l = -l;
}
}
{
std::vector<int> naturalC(N - 1);
Expand All @@ -199,14 +193,11 @@ TEST(HybridGaussianFactorGraph, Switching) {
std::transform(naturalC.begin(), naturalC.end(), std::back_inserter(ordC),
[](int x) { return M(x); });

// std::copy(ordC.begin(), ordC.end(), std::back_inserter(ordering));
const auto [ndC, lvls] = makeBinaryOrdering(ordC);
std::copy(ndC.begin(), ndC.end(), std::back_inserter(ordering));
}
auto ordering_full = Ordering(ordering);

const auto [hbt, remaining] =
hfg->eliminatePartialMultifrontal(ordering_full);
const auto [hbt, remaining] = hfg->eliminatePartialMultifrontal(ordering);

// 12 cliques in the bayes tree and 0 remaining variables to eliminate.
EXPECT_LONGS_EQUAL(12, hbt->size());
Expand All @@ -230,7 +221,7 @@ TEST(HybridGaussianFactorGraph, SwitchingISAM) {
// Ordering(KeyVector{X(1), X(4), X(2), X(6), X(9), X(8), X(3), X(7),
// X(5),
// M(1), M(4), M(2), M(6), M(8), M(3), M(7), M(5)});
KeyVector ordering;
Ordering ordering;

{
std::vector<int> naturalX(N);
Expand All @@ -241,10 +232,6 @@ TEST(HybridGaussianFactorGraph, SwitchingISAM) {

auto [ndX, lvls] = makeBinaryOrdering(ordX);
std::copy(ndX.begin(), ndX.end(), std::back_inserter(ordering));
// TODO(dellaert): this has no effect!
for (auto& l : lvls) {
l = -l;
}
}
{
std::vector<int> naturalC(N - 1);
Expand All @@ -257,10 +244,8 @@ TEST(HybridGaussianFactorGraph, SwitchingISAM) {
const auto [ndC, lvls] = makeBinaryOrdering(ordC);
std::copy(ndC.begin(), ndC.end(), std::back_inserter(ordering));
}
auto ordering_full = Ordering(ordering);

const auto [hbt, remaining] =
hfg->eliminatePartialMultifrontal(ordering_full);
const auto [hbt, remaining] = hfg->eliminatePartialMultifrontal(ordering);

auto new_fg = makeSwitchingChain(12);
auto isam = HybridGaussianISAM(*hbt);
Expand Down Expand Up @@ -460,12 +445,7 @@ TEST(HybridBayesTree, Optimize) {
const auto [hybridBayesNet, remainingFactorGraph] =
s.linearizedFactorGraph().eliminatePartialSequential(ordering);

DiscreteFactorGraph dfg;
for (auto&& f : *remainingFactorGraph) {
auto discreteFactor = dynamic_pointer_cast<DiscreteFactor>(f);
assert(discreteFactor);
dfg.push_back(discreteFactor);
}
DiscreteFactorGraph dfg = remainingFactorGraph->discreteFactors();

// Add the probabilities for each branch
DiscreteKeys discrete_keys = {{M(0), 2}, {M(1), 2}, {M(2), 2}};
Expand Down
Loading

0 comments on commit 2bd2d82

Please sign in to comment.