diff --git a/gtsam/hybrid/HybridConditional.cpp b/gtsam/hybrid/HybridConditional.cpp index 074534b8d5..3cb3bba65d 100644 --- a/gtsam/hybrid/HybridConditional.cpp +++ b/gtsam/hybrid/HybridConditional.cpp @@ -129,6 +129,22 @@ double HybridConditional::error(const HybridValues &values) const { "HybridConditional::error: conditional type not handled"); } +/* ************************************************************************ */ +AlgebraicDecisionTree HybridConditional::errorTree( + const VectorValues &values) const { + if (auto gc = asGaussian()) { + return AlgebraicDecisionTree(gc->error(values)); + } + if (auto gm = asHybrid()) { + return gm->errorTree(values); + } + if (auto dc = asDiscrete()) { + return AlgebraicDecisionTree(0.0); + } + throw std::runtime_error( + "HybridConditional::error: conditional type not handled"); +} + /* ************************************************************************ */ double HybridConditional::logProbability(const HybridValues &values) const { if (auto gc = asGaussian()) { diff --git a/gtsam/hybrid/HybridConditional.h b/gtsam/hybrid/HybridConditional.h index f44ee2bf99..51eeeb5bb6 100644 --- a/gtsam/hybrid/HybridConditional.h +++ b/gtsam/hybrid/HybridConditional.h @@ -179,6 +179,16 @@ class GTSAM_EXPORT HybridConditional /// Return the error of the underlying conditional. double error(const HybridValues& values) const override; + /** + * @brief Compute error of the HybridConditional as a tree. + * + * @param continuousValues The continuous VectorValues. + * @return AlgebraicDecisionTree A decision tree with the same keys + * as the conditionals involved, and leaf values as the error. + */ + AlgebraicDecisionTree errorTree( + const VectorValues& values) const override; + /// Return the log-probability (or density) of the underlying conditional. double logProbability(const HybridValues& values) const override; diff --git a/gtsam/hybrid/HybridFactor.h b/gtsam/hybrid/HybridFactor.h index ad29dfdca9..fc91e08389 100644 --- a/gtsam/hybrid/HybridFactor.h +++ b/gtsam/hybrid/HybridFactor.h @@ -136,6 +136,10 @@ class GTSAM_EXPORT HybridFactor : public Factor { /// Return only the continuous keys for this factor. const KeyVector &continuousKeys() const { return continuousKeys_; } + /// Virtual class to compute tree of linear errors. + virtual AlgebraicDecisionTree errorTree( + const VectorValues &values) const = 0; + /// @} private: diff --git a/gtsam/hybrid/HybridGaussianConditional.cpp b/gtsam/hybrid/HybridGaussianConditional.cpp index 0b1dc53377..fb943366cb 100644 --- a/gtsam/hybrid/HybridGaussianConditional.cpp +++ b/gtsam/hybrid/HybridGaussianConditional.cpp @@ -323,40 +323,6 @@ AlgebraicDecisionTree HybridGaussianConditional::logProbability( return DecisionTree(conditionals_, probFunc); } -/* ************************************************************************* */ -double HybridGaussianConditional::conditionalError( - const GaussianConditional::shared_ptr &conditional, - const VectorValues &continuousValues) const { - // Check if valid pointer - if (conditional) { - return conditional->error(continuousValues) + // - -logConstant_ - conditional->logNormalizationConstant(); - } else { - // If not valid, pointer, it means this conditional was pruned, - // so we return maximum error. - // This way the negative exponential will give - // a probability value close to 0.0. - return std::numeric_limits::max(); - } -} - -/* *******************************************************************************/ -AlgebraicDecisionTree HybridGaussianConditional::errorTree( - const VectorValues &continuousValues) const { - auto errorFunc = [&](const GaussianConditional::shared_ptr &conditional) { - return conditionalError(conditional, continuousValues); - }; - DecisionTree error_tree(conditionals_, errorFunc); - return error_tree; -} - -/* *******************************************************************************/ -double HybridGaussianConditional::error(const HybridValues &values) const { - // Directly index to get the conditional, no need to build the whole tree. - auto conditional = conditionals_(values.discrete()); - return conditionalError(conditional, values.continuous()); -} - /* *******************************************************************************/ double HybridGaussianConditional::logProbability( const HybridValues &values) const { diff --git a/gtsam/hybrid/HybridGaussianConditional.h b/gtsam/hybrid/HybridGaussianConditional.h index 72a9994729..4a5fdcc89e 100644 --- a/gtsam/hybrid/HybridGaussianConditional.h +++ b/gtsam/hybrid/HybridGaussianConditional.h @@ -109,9 +109,9 @@ class GTSAM_EXPORT HybridGaussianConditional const Conditionals &conditionals); /** - * @brief Make a Hybrid Gaussian Conditional from a vector of Gaussian - * conditionals. The DecisionTree-based constructor is preferred over this - * one. + * @brief Make a Hybrid Gaussian Conditional from + * a vector of Gaussian conditionals. + * The DecisionTree-based constructor is preferred over this one. * * @param continuousFrontals The continuous frontal variables * @param continuousParents The continuous parent variables @@ -174,43 +174,6 @@ class GTSAM_EXPORT HybridGaussianConditional AlgebraicDecisionTree logProbability( const VectorValues &continuousValues) const; - /** - * @brief Compute the error of this hybrid Gaussian conditional. - * - * This requires some care, as different components may have - * different normalization constants. Let's consider p(x|y,m), where m is - * discrete. We need the error to satisfy the invariant: - * - * error(x;y,m) = K - log(probability(x;y,m)) - * - * For all x,y,m. But note that K, the (log) normalization constant defined - * in Conditional.h, should not depend on x, y, or m, only on the parameters - * of the density. Hence, we delegate to the underlying Gaussian - * conditionals, indexed by m, which do satisfy: - * - * log(probability_m(x;y)) = K_m - error_m(x;y) - * - * We resolve by having K == max(K_m) and - * - * error(x;y,m) = error_m(x;y) + K - K_m - * - * which also makes error(x;y,m) >= 0 for all x,y,m. - * - * @param values Continuous values and discrete assignment. - * @return double - */ - double error(const HybridValues &values) const override; - - /** - * @brief Compute error of the HybridGaussianConditional as a tree. - * - * @param continuousValues The continuous VectorValues. - * @return AlgebraicDecisionTree A decision tree on the discrete keys - * only, with the leaf values as the error for each assignment. - */ - AlgebraicDecisionTree errorTree( - const VectorValues &continuousValues) const; - /** * @brief Compute the logProbability of this hybrid Gaussian conditional. * @@ -241,10 +204,6 @@ class GTSAM_EXPORT HybridGaussianConditional /// Check whether `given` has values for all frontal keys. bool allFrontalsGiven(const VectorValues &given) const; - /// Helper method to compute the error of a conditional. - double conditionalError(const GaussianConditional::shared_ptr &conditional, - const VectorValues &continuousValues) const; - #ifdef GTSAM_ENABLE_BOOST_SERIALIZATION /** Serialization function */ friend class boost::serialization::access; diff --git a/gtsam/hybrid/HybridGaussianFactor.cpp b/gtsam/hybrid/HybridGaussianFactor.cpp index 88c557672d..d5773590bb 100644 --- a/gtsam/hybrid/HybridGaussianFactor.cpp +++ b/gtsam/hybrid/HybridGaussianFactor.cpp @@ -151,12 +151,26 @@ GaussianFactorGraphTree HybridGaussianFactor::asGaussianFactorGraphTree() return {factors_, wrap}; } +/* *******************************************************************************/ +double HybridGaussianFactor::potentiallyPrunedComponentError( + const sharedFactor &gf, const VectorValues &values) const { + // Check if valid pointer + if (gf) { + return gf->error(values); + } else { + // If not valid, pointer, it means this component was pruned, + // so we return maximum error. + // This way the negative exponential will give + // a probability value close to 0.0. + return std::numeric_limits::max(); + } +} /* *******************************************************************************/ AlgebraicDecisionTree HybridGaussianFactor::errorTree( const VectorValues &continuousValues) const { // functor to convert from sharedFactor to double error value. - auto errorFunc = [&continuousValues](const sharedFactor &gf) { - return gf->error(continuousValues); + auto errorFunc = [this, &continuousValues](const sharedFactor &gf) { + return this->potentiallyPrunedComponentError(gf, continuousValues); }; DecisionTree error_tree(factors_, errorFunc); return error_tree; @@ -164,8 +178,9 @@ AlgebraicDecisionTree HybridGaussianFactor::errorTree( /* *******************************************************************************/ double HybridGaussianFactor::error(const HybridValues &values) const { + // Directly index to get the component, no need to build the whole tree. const sharedFactor gf = factors_(values.discrete()); - return gf->error(values.continuous()); + return potentiallyPrunedComponentError(gf, values.continuous()); } } // namespace gtsam diff --git a/gtsam/hybrid/HybridGaussianFactor.h b/gtsam/hybrid/HybridGaussianFactor.h index 925a37e041..817e54e562 100644 --- a/gtsam/hybrid/HybridGaussianFactor.h +++ b/gtsam/hybrid/HybridGaussianFactor.h @@ -166,7 +166,7 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor { * as the factors involved, and leaf values as the error. */ AlgebraicDecisionTree errorTree( - const VectorValues &continuousValues) const; + const VectorValues &continuousValues) const override; /** * @brief Compute the log-likelihood, including the log-normalizing constant. @@ -186,6 +186,10 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor { /// @} private: + /// Helper method to compute the error of a component. + double potentiallyPrunedComponentError( + const sharedFactor &gf, const VectorValues &continuousValues) const; + #ifdef GTSAM_ENABLE_BOOST_SERIALIZATION /** Serialization function */ friend class boost::serialization::access; diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 28a0c446fd..a6fe955eb3 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -329,8 +329,8 @@ static std::shared_ptr createDiscreteFactor( // Logspace version of: // exp(-factor->error(kEmpty)) / conditional->normalizationConstant(); - // We take negative of the logNormalizationConstant `log(1/k)` - // to get `log(k)`. + // We take negative of the logNormalizationConstant `log(k)` + // to get `log(1/k) = log(\sqrt{|2πΣ|})`. return -factor->error(kEmpty) - conditional->logNormalizationConstant(); }; @@ -539,36 +539,20 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors, AlgebraicDecisionTree HybridGaussianFactorGraph::errorTree( const VectorValues &continuousValues) const { AlgebraicDecisionTree error_tree(0.0); - // Iterate over each factor. for (auto &factor : factors_) { - // TODO(dellaert): just use a virtual method defined in HybridFactor. - AlgebraicDecisionTree factor_error; - - auto f = factor; - if (auto hc = dynamic_pointer_cast(factor)) { - f = hc->inner(); - } - - if (auto hybridGaussianCond = - dynamic_pointer_cast(f)) { - // Compute factor error and add it. - error_tree = error_tree + hybridGaussianCond->errorTree(continuousValues); - } else if (auto gaussian = dynamic_pointer_cast(f)) { - // If continuous only, get the (double) error - // and add it to the error_tree - double error = gaussian->error(continuousValues); - // Add the gaussian factor error to every leaf of the error tree. - error_tree = error_tree.apply( - [error](double leaf_value) { return leaf_value + error; }); - } else if (dynamic_pointer_cast(f)) { - // If factor at `idx` is discrete-only, we skip. + if (auto f = std::dynamic_pointer_cast(factor)) { + // Check for HybridFactor, and call errorTree + error_tree = error_tree + f->errorTree(continuousValues); + } else if (auto f = std::dynamic_pointer_cast(factor)) { + // Skip discrete factors continue; } else { - throwRuntimeError("HybridGaussianFactorGraph::error(VV)", f); + // Everything else is a continuous only factor + HybridValues hv(continuousValues, DiscreteValues()); + error_tree = error_tree + AlgebraicDecisionTree(factor->error(hv)); } } - return error_tree; } diff --git a/gtsam/hybrid/HybridNonlinearFactor.h b/gtsam/hybrid/HybridNonlinearFactor.h index 6da846abe5..9852602de4 100644 --- a/gtsam/hybrid/HybridNonlinearFactor.h +++ b/gtsam/hybrid/HybridNonlinearFactor.h @@ -74,6 +74,13 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor { /// Decision tree of Gaussian factors indexed by discrete keys. Factors factors_; + /// HybridFactor method implementation. Should not be used. + AlgebraicDecisionTree errorTree( + const VectorValues& continuousValues) const override { + throw std::runtime_error( + "HybridNonlinearFactor::error does not take VectorValues."); + } + public: HybridNonlinearFactor() = default; diff --git a/gtsam/hybrid/HybridNonlinearFactorGraph.h b/gtsam/hybrid/HybridNonlinearFactorGraph.h index 54dc9e93fc..5a09f18d45 100644 --- a/gtsam/hybrid/HybridNonlinearFactorGraph.h +++ b/gtsam/hybrid/HybridNonlinearFactorGraph.h @@ -86,6 +86,10 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph { */ std::shared_ptr linearize( const Values& continuousValues) const; + + /// Expose error(const HybridValues&) method. + using Base::error; + /// @} }; diff --git a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp index 66a54f73fe..b1c68adf3e 100644 --- a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp @@ -678,6 +678,55 @@ TEST(HybridGaussianFactorGraph, ErrorTreeWithConditional) { EXPECT(assert_equal(expected, errorTree, 1e-9)); } +/* ****************************************************************************/ +// Test hybrid gaussian factor graph errorTree during +// incremental operation +TEST(HybridGaussianFactorGraph, IncrementalErrorTree) { + Switching s(4); + + HybridGaussianFactorGraph graph; + graph.push_back(s.linearizedFactorGraph.at(0)); // f(X0) + graph.push_back(s.linearizedFactorGraph.at(1)); // f(X0, X1, M0) + graph.push_back(s.linearizedFactorGraph.at(2)); // f(X1, X2, M1) + graph.push_back(s.linearizedFactorGraph.at(4)); // f(X1) + graph.push_back(s.linearizedFactorGraph.at(5)); // f(X2) + graph.push_back(s.linearizedFactorGraph.at(7)); // f(M0) + graph.push_back(s.linearizedFactorGraph.at(8)); // f(M0, M1) + + HybridBayesNet::shared_ptr hybridBayesNet = graph.eliminateSequential(); + EXPECT_LONGS_EQUAL(5, hybridBayesNet->size()); + + HybridValues delta = hybridBayesNet->optimize(); + auto error_tree = graph.errorTree(delta.continuous()); + + std::vector discrete_keys = {{M(0), 2}, {M(1), 2}}; + std::vector leaves = {0.99985581, 0.4902432, 0.51936941, + 0.0097568009}; + AlgebraicDecisionTree expected_error(discrete_keys, leaves); + + // regression + EXPECT(assert_equal(expected_error, error_tree, 1e-7)); + + graph = HybridGaussianFactorGraph(); + graph.push_back(*hybridBayesNet); + graph.push_back(s.linearizedFactorGraph.at(3)); // f(X2, X3, M2) + graph.push_back(s.linearizedFactorGraph.at(6)); // f(X3) + + hybridBayesNet = graph.eliminateSequential(); + EXPECT_LONGS_EQUAL(7, hybridBayesNet->size()); + + delta = hybridBayesNet->optimize(); + auto error_tree2 = graph.errorTree(delta.continuous()); + + discrete_keys = {{M(0), 2}, {M(1), 2}, {M(2), 2}}; + leaves = {0.50985198, 0.0097577296, 0.50009425, 0, + 0.52922138, 0.029127133, 0.50985105, 0.0097567964}; + AlgebraicDecisionTree expected_error2(discrete_keys, leaves); + + // regression + EXPECT(assert_equal(expected_error, error_tree, 1e-7)); +} + /* ****************************************************************************/ // Check that assembleGraphTree assembles Gaussian factor graphs for each // assignment. diff --git a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp index 621c8708ed..347cc5f1fe 100644 --- a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp @@ -51,7 +51,7 @@ using symbol_shorthand::X; * Test that any linearizedFactorGraph gaussian factors are appended to the * existing gaussian factor graph in the hybrid factor graph. */ -TEST(HybridFactorGraph, GaussianFactorGraph) { +TEST(HybridNonlinearFactorGraph, GaussianFactorGraph) { HybridNonlinearFactorGraph fg; // Add a simple prior factor to the nonlinear factor graph @@ -181,7 +181,7 @@ TEST(HybridGaussianFactorGraph, HybridNonlinearFactor) { /***************************************************************************** * Test push_back on HFG makes the correct distinction. */ -TEST(HybridFactorGraph, PushBack) { +TEST(HybridNonlinearFactorGraph, PushBack) { HybridNonlinearFactorGraph fg; auto nonlinearFactor = std::make_shared>(); @@ -240,7 +240,7 @@ TEST(HybridFactorGraph, PushBack) { /**************************************************************************** * Test construction of switching-like hybrid factor graph. */ -TEST(HybridFactorGraph, Switching) { +TEST(HybridNonlinearFactorGraph, Switching) { Switching self(3); EXPECT_LONGS_EQUAL(7, self.nonlinearFactorGraph.size()); @@ -250,7 +250,7 @@ TEST(HybridFactorGraph, Switching) { /**************************************************************************** * Test linearization on a switching-like hybrid factor graph. */ -TEST(HybridFactorGraph, Linearization) { +TEST(HybridNonlinearFactorGraph, Linearization) { Switching self(3); // Linearize here: @@ -263,7 +263,7 @@ TEST(HybridFactorGraph, Linearization) { /**************************************************************************** * Test elimination tree construction */ -TEST(HybridFactorGraph, EliminationTree) { +TEST(HybridNonlinearFactorGraph, EliminationTree) { Switching self(3); // Create ordering. @@ -372,7 +372,7 @@ TEST(HybridGaussianElimination, EliminateHybrid_2_Variable) { /**************************************************************************** * Test partial elimination */ -TEST(HybridFactorGraph, Partial_Elimination) { +TEST(HybridNonlinearFactorGraph, Partial_Elimination) { Switching self(3); auto linearizedFactorGraph = self.linearizedFactorGraph; @@ -401,7 +401,39 @@ TEST(HybridFactorGraph, Partial_Elimination) { EXPECT(remainingFactorGraph->at(2)->keys() == KeyVector({M(0), M(1)})); } -TEST(HybridFactorGraph, PrintErrors) { +/* ****************************************************************************/ +TEST(HybridNonlinearFactorGraph, Error) { + Switching self(3); + HybridNonlinearFactorGraph fg = self.nonlinearFactorGraph; + + { + HybridValues values(VectorValues(), DiscreteValues{{M(0), 0}, {M(1), 0}}, + self.linearizationPoint); + // regression + EXPECT_DOUBLES_EQUAL(152.791759469, fg.error(values), 1e-9); + } + { + HybridValues values(VectorValues(), DiscreteValues{{M(0), 0}, {M(1), 1}}, + self.linearizationPoint); + // regression + EXPECT_DOUBLES_EQUAL(151.598612289, fg.error(values), 1e-9); + } + { + HybridValues values(VectorValues(), DiscreteValues{{M(0), 1}, {M(1), 0}}, + self.linearizationPoint); + // regression + EXPECT_DOUBLES_EQUAL(151.703972804, fg.error(values), 1e-9); + } + { + HybridValues values(VectorValues(), DiscreteValues{{M(0), 1}, {M(1), 1}}, + self.linearizationPoint); + // regression + EXPECT_DOUBLES_EQUAL(151.609437912, fg.error(values), 1e-9); + } +} + +/* ****************************************************************************/ +TEST(HybridNonlinearFactorGraph, PrintErrors) { Switching self(3); // Get nonlinear factor graph and add linear factors to be holistic @@ -424,7 +456,7 @@ TEST(HybridFactorGraph, PrintErrors) { /**************************************************************************** * Test full elimination */ -TEST(HybridFactorGraph, Full_Elimination) { +TEST(HybridNonlinearFactorGraph, Full_Elimination) { Switching self(3); auto linearizedFactorGraph = self.linearizedFactorGraph; @@ -492,7 +524,7 @@ TEST(HybridFactorGraph, Full_Elimination) { /**************************************************************************** * Test printing */ -TEST(HybridFactorGraph, Printing) { +TEST(HybridNonlinearFactorGraph, Printing) { Switching self(3); auto linearizedFactorGraph = self.linearizedFactorGraph; @@ -784,7 +816,7 @@ conditional 2: Hybrid P( x2 | m0 m1) * The issue arises if we eliminate a landmark variable first since it is not * connected to a HybridFactor. */ -TEST(HybridFactorGraph, DefaultDecisionTree) { +TEST(HybridNonlinearFactorGraph, DefaultDecisionTree) { HybridNonlinearFactorGraph fg; // Add a prior on pose x0 at the origin.