From b5a3f11993d430f6ff26cc319c943f87077ab817 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 13 Feb 2023 15:28:35 -0500 Subject: [PATCH 01/20] Add better hybrid support --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 12 ++-- gtsam/hybrid/HybridNonlinearFactorGraph.cpp | 7 ++ .../tests/testGaussianMixtureFactor.cpp | 2 + .../tests/testHybridGaussianFactorGraph.cpp | 67 +++++++++++++++++-- 4 files changed, 75 insertions(+), 13 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index c912a74fc7..20950d4f9d 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -106,7 +106,9 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const { // TODO(dellaert): just use a virtual method defined in HybridFactor. if (auto gf = dynamic_pointer_cast(f)) { result = addGaussian(result, gf); - } else if (auto gm = dynamic_pointer_cast(f)) { + } else if (auto gmf = dynamic_pointer_cast(f)) { + result = gmf->add(result); + } else if (auto gm = dynamic_pointer_cast(f)) { result = gm->add(result); } else if (auto hc = dynamic_pointer_cast(f)) { if (auto gm = hc->asMixture()) { @@ -283,17 +285,15 @@ hybridElimination(const HybridGaussianFactorGraph &factors, // taking care to correct for conditional constant. // Correct for the normalization constant used up by the conditional - auto correct = [&](const Result &pair) -> GaussianFactor::shared_ptr { + auto correct = [&](const Result &pair) { const auto &factor = pair.second; - if (!factor) return factor; // TODO(dellaert): not loving this. + if (!factor) return; auto hf = boost::dynamic_pointer_cast(factor); if (!hf) throw std::runtime_error("Expected HessianFactor!"); hf->constantTerm() += 2.0 * pair.first->logNormalizationConstant(); - return hf; }; + eliminationResults.visit(correct); - GaussianMixtureFactor::Factors correctedFactors(eliminationResults, - correct); const auto mixtureFactor = boost::make_shared( continuousSeparator, discreteSeparator, newFactors); diff --git a/gtsam/hybrid/HybridNonlinearFactorGraph.cpp b/gtsam/hybrid/HybridNonlinearFactorGraph.cpp index 71b064eb60..ab64b1ec3b 100644 --- a/gtsam/hybrid/HybridNonlinearFactorGraph.cpp +++ b/gtsam/hybrid/HybridNonlinearFactorGraph.cpp @@ -17,6 +17,7 @@ */ #include +#include #include #include #include @@ -69,6 +70,12 @@ HybridGaussianFactorGraph::shared_ptr HybridNonlinearFactorGraph::linearize( } else if (dynamic_pointer_cast(f)) { // If discrete-only: doesn't need linearization. linearFG->push_back(f); + } else if (auto gmf = dynamic_pointer_cast(f)) { + linearFG->push_back(gmf); + } else if (auto gm = dynamic_pointer_cast(f)) { + linearFG->push_back(gm); + } else if (dynamic_pointer_cast(f)) { + linearFG->push_back(f); } else { auto& fr = *f; throw std::invalid_argument( diff --git a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp index 962d238a89..4ef2af471e 100644 --- a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp +++ b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp @@ -93,6 +93,7 @@ TEST(GaussianMixtureFactor, Sum) { EXPECT(actual.at(1) == f22); } +/* ************************************************************************* */ TEST(GaussianMixtureFactor, Printing) { DiscreteKey m1(1, 2); auto A1 = Matrix::Zero(2, 1); @@ -136,6 +137,7 @@ TEST(GaussianMixtureFactor, Printing) { EXPECT(assert_print_equal(expected, mixtureFactor)); } +/* ************************************************************************* */ TEST(GaussianMixtureFactor, GaussianMixture) { KeyVector keys; keys.push_back(X(0)); diff --git a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp index e5c11bf0c2..3302994c0f 100644 --- a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp @@ -612,7 +612,6 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) { // Check that assembleGraphTree assembles Gaussian factor graphs for each // assignment. TEST(HybridGaussianFactorGraph, assembleGraphTree) { - using symbol_shorthand::Z; const int num_measurements = 1; auto fg = tiny::createHybridGaussianFactorGraph( num_measurements, VectorValues{{Z(0), Vector1(5.0)}}); @@ -694,7 +693,6 @@ bool ratioTest(const HybridBayesNet &bn, const VectorValues &measurements, /* ****************************************************************************/ // Check that eliminating tiny net with 1 measurement yields correct result. TEST(HybridGaussianFactorGraph, EliminateTiny1) { - using symbol_shorthand::Z; const int num_measurements = 1; const VectorValues measurements{{Z(0), Vector1(5.0)}}; auto bn = tiny::createHybridBayesNet(num_measurements); @@ -726,11 +724,67 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1) { EXPECT(ratioTest(bn, measurements, *posterior)); } +/* ****************************************************************************/ +// Check that eliminating tiny net with 1 measurement with mode order swapped +// yields correct result. +TEST(HybridGaussianFactorGraph, EliminateTiny1Swapped) { + const VectorValues measurements{{Z(0), Vector1(5.0)}}; + + // Create mode key: 1 is low-noise, 0 is high-noise. + const DiscreteKey mode{M(0), 2}; + HybridBayesNet bn; + + // Create Gaussian mixture z_0 = x0 + noise for each measurement. + bn.emplace_back(new GaussianMixture( + {Z(0)}, {X(0)}, {mode}, + {GaussianConditional::sharedMeanAndStddev(Z(0), I_1x1, X(0), Z_1x1, 3), + GaussianConditional::sharedMeanAndStddev(Z(0), I_1x1, X(0), Z_1x1, + 0.5)})); + + // Create prior on X(0). + bn.push_back( + GaussianConditional::sharedMeanAndStddev(X(0), Vector1(5.0), 0.5)); + + // Add prior on mode. + bn.emplace_back(new DiscreteConditional(mode, "1/1")); + + // bn.print(); + auto fg = bn.toFactorGraph(measurements); + EXPECT_LONGS_EQUAL(3, fg.size()); + + // fg.print(); + + EXPECT(ratioTest(bn, measurements, fg)); + + // Create expected Bayes Net: + HybridBayesNet expectedBayesNet; + + // Create Gaussian mixture on X(0). + // regression, but mean checked to be 5.0 in both cases: + const auto conditional0 = boost::make_shared( + X(0), Vector1(10.1379), I_1x1 * 2.02759), + conditional1 = boost::make_shared( + X(0), Vector1(14.1421), I_1x1 * 2.82843); + expectedBayesNet.emplace_back( + new GaussianMixture({X(0)}, {}, {mode}, {conditional0, conditional1})); + + // Add prior on mode. + expectedBayesNet.emplace_back(new DiscreteConditional(mode, "1/1")); + + // Test elimination + const auto posterior = fg.eliminateSequential(); + // EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01)); + + EXPECT(ratioTest(bn, measurements, *posterior)); + + // posterior->print(); + // posterior->optimize().print(); +} + /* ****************************************************************************/ // Check that eliminating tiny net with 2 measurements yields correct result. TEST(HybridGaussianFactorGraph, EliminateTiny2) { // Create factor graph with 2 measurements such that posterior mean = 5.0. - using symbol_shorthand::Z; const int num_measurements = 2; const VectorValues measurements{{Z(0), Vector1(4.0)}, {Z(1), Vector1(6.0)}}; auto bn = tiny::createHybridBayesNet(num_measurements); @@ -764,7 +818,6 @@ TEST(HybridGaussianFactorGraph, EliminateTiny2) { // Test eliminating tiny net with 1 mode per measurement. TEST(HybridGaussianFactorGraph, EliminateTiny22) { // Create factor graph with 2 measurements such that posterior mean = 5.0. - using symbol_shorthand::Z; const int num_measurements = 2; const bool manyModes = true; @@ -835,12 +888,12 @@ TEST(HybridGaussianFactorGraph, EliminateSwitchingNetwork) { // D D // | | // m1 m2 - // | | + // | | // C-x0-HC-x1-HC-x2 // | | | // HF HF HF // | | | - // n0 n1 n2 + // n0 n1 n2 // | | | // D D D EXPECT_LONGS_EQUAL(11, fg.size()); @@ -853,7 +906,7 @@ TEST(HybridGaussianFactorGraph, EliminateSwitchingNetwork) { EXPECT(ratioTest(bn, measurements, fg1)); // Create ordering that eliminates in time order, then discrete modes: - Ordering ordering {X(2), X(1), X(0), N(0), N(1), N(2), M(1), M(2)}; + Ordering ordering{X(2), X(1), X(0), N(0), N(1), N(2), M(1), M(2)}; // Do elimination: const HybridBayesNet::shared_ptr posterior = fg.eliminateSequential(ordering); From 2714dc562550b5c6f0b774f2b215172900ef6dc3 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 13 Feb 2023 15:59:30 -0500 Subject: [PATCH 02/20] add ordering method for HybridSmoother --- gtsam/hybrid/HybridSmoother.cpp | 32 +++++++++++++++++++++++++++++++- gtsam/hybrid/HybridSmoother.h | 2 ++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/gtsam/hybrid/HybridSmoother.cpp b/gtsam/hybrid/HybridSmoother.cpp index 35dd5f88bc..fcee7833a8 100644 --- a/gtsam/hybrid/HybridSmoother.cpp +++ b/gtsam/hybrid/HybridSmoother.cpp @@ -23,6 +23,37 @@ namespace gtsam { +/* ************************************************************************* */ +Ordering HybridSmoother::getOrdering( + const HybridGaussianFactorGraph &newFactors) { + HybridGaussianFactorGraph factors(hybridBayesNet()); + factors += newFactors; + // Get all the discrete keys from the factors + KeySet allDiscrete = factors.discreteKeySet(); + + // Create KeyVector with continuous keys followed by discrete keys. + KeyVector newKeysDiscreteLast; + const KeySet newFactorKeys = newFactors.keys(); + // Insert continuous keys first. + for (auto &k : newFactorKeys) { + if (!allDiscrete.exists(k)) { + newKeysDiscreteLast.push_back(k); + } + } + + // Insert discrete keys at the end + std::copy(allDiscrete.begin(), allDiscrete.end(), + std::back_inserter(newKeysDiscreteLast)); + + const VariableIndex index(newFactors); + + // Get an ordering where the new keys are eliminated last + Ordering ordering = Ordering::ColamdConstrainedLast( + index, KeyVector(newKeysDiscreteLast.begin(), newKeysDiscreteLast.end()), + true); + return ordering; +} + /* ************************************************************************* */ void HybridSmoother::update(HybridGaussianFactorGraph graph, const Ordering &ordering, @@ -92,7 +123,6 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph, } graph.push_back(newConditionals); - // newConditionals.print("\n\n\nNew Conditionals to add back"); } return {graph, hybridBayesNet}; } diff --git a/gtsam/hybrid/HybridSmoother.h b/gtsam/hybrid/HybridSmoother.h index 7e90f9425d..9f14a70022 100644 --- a/gtsam/hybrid/HybridSmoother.h +++ b/gtsam/hybrid/HybridSmoother.h @@ -50,6 +50,8 @@ class HybridSmoother { void update(HybridGaussianFactorGraph graph, const Ordering& ordering, boost::optional maxNrLeaves = boost::none); + Ordering getOrdering(const HybridGaussianFactorGraph& newFactors); + /** * @brief Add conditionals from previous timestep as part of liquefication. * From febeacd68686ed0b7ced72458eb0b31a196bdab7 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 4 Jun 2023 15:40:02 +0100 Subject: [PATCH 03/20] Improved documentation and tests --- gtsam/discrete/AlgebraicDecisionTree.h | 64 +++++++++++++++++++--- gtsam/discrete/DecisionTree-inl.h | 2 +- gtsam/discrete/DecisionTree.h | 31 +++++++++-- gtsam/discrete/tests/testDecisionTree.cpp | 66 ++++++++++++++++++----- 4 files changed, 137 insertions(+), 26 deletions(-) diff --git a/gtsam/discrete/AlgebraicDecisionTree.h b/gtsam/discrete/AlgebraicDecisionTree.h index b3f0d69b0e..cd77e41f8e 100644 --- a/gtsam/discrete/AlgebraicDecisionTree.h +++ b/gtsam/discrete/AlgebraicDecisionTree.h @@ -28,9 +28,9 @@ namespace gtsam { /** - * Algebraic Decision Trees fix the range to double - * Just has some nice constructors and some syntactic sugar - * TODO: consider eliminating this class altogether? + * An algebraic decision tree fixes the range of a DecisionTree to double. + * Just has some nice constructors and some syntactic sugar. + * TODO(dellaert): consider eliminating this class altogether? * * @ingroup discrete */ @@ -80,20 +80,62 @@ namespace gtsam { AlgebraicDecisionTree(const L& label, double y1, double y2) : Base(label, y1, y2) {} - /** Create a new leaf function splitting on a variable */ + /** + * @brief Create a new leaf function splitting on a variable + * + * @param labelC: The label with cardinality 2 + * @param y1: The value for the first key + * @param y2: The value for the second key + * + * Example: + * @code{.cpp} + * std::pair A {"a", 2}; + * AlgebraicDecisionTree a(A, 0.6, 0.4); + * @endcode + */ AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1, double y2) : Base(labelC, y1, y2) {} - /** Create from keys and vector table */ + /** + * @brief Create from keys with cardinalities and a vector table + * + * @param labelCs: The keys, with cardinalities, given as pairs + * @param ys: The vector table + * + * Example with three keys, A, B, and C, with cardinalities 2, 3, and 2, + * respectively, and a vector table of size 12: + * @code{.cpp} + * DiscreteKey A(0, 2), B(1, 3), C(2, 2); + * const vector cpt{ + * 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, // + * 1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10}; + * AlgebraicDecisionTree expected(A & B & C, cpt); + * @endcode + * The table is given in the following order: + * A=0, B=0, C=0 + * A=0, B=0, C=1 + * ... + * A=1, B=1, C=1 + * Hence, the first line in the table is for A==0, and the second for A==1. + * In each line, the first two entries are for B==0, the next two for B==1, + * and the last two for B==2. Each pair is for a C value of 0 and 1. + */ AlgebraicDecisionTree // (const std::vector& labelCs, - const std::vector& ys) { + const std::vector& ys) { this->root_ = Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); } - /** Create from keys and string table */ + /** + * @brief Create from keys and string table + * + * @param labelCs: The keys, with cardinalities, given as pairs + * @param table: The string table, given as a string of doubles. + * + * @note Table needs to be in same order as the vector table in the other constructor. + */ AlgebraicDecisionTree // (const std::vector& labelCs, const std::string& table) { @@ -108,7 +150,13 @@ namespace gtsam { Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); } - /** Create a new function splitting on a variable */ + /** + * @brief Create a range of decision trees, splitting on a single variable. + * + * @param begin: Iterator to beginning of a range of decision trees + * @param end: Iterator to end of a range of decision trees + * @param label: The label to split on + */ template AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label) : Base(nullptr) { diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 9f3d5e8f95..4d1670bb74 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -622,7 +622,7 @@ namespace gtsam { // B=1 // A=0: 3 // A=1: 4 - // Note, through the magic of "compose", create([A B],[1 2 3 4]) will produce + // Note, through the magic of "compose", create([A B],[1 3 2 4]) will produce // exactly the same tree as above: the highest label is always the root. // However, it will be *way* faster if labels are given highest to lowest. template diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index a8764a98f7..06e945cf9f 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -37,9 +37,23 @@ namespace gtsam { /** - * Decision Tree - * L = label for variables - * Y = function range (any algebra), e.g., bool, int, double + * @brief a decision tree is a function from assignments to values. + * @tparam L label for variables + * @tparam Y function range (any algebra), e.g., bool, int, double + * + * After creating a decision tree on some variables, the tree can be evaluated + * on an assignment to those variables. Example: + * + * @code{.cpp} + * // Create a decision stump one one variable 'a' with values 10 and 20. + * DecisionTree tree('a', 10, 20); + * + * // Evaluate the tree on an assignment to the variable. + * int value0 = tree({{'a', 0}}); // value0 = 10 + * int value1 = tree({{'a', 1}}); // value1 = 20 + * @endcode + * + * More examples can be found in testDecisionTree.cpp * * @ingroup discrete */ @@ -132,7 +146,8 @@ namespace gtsam { NodePtr root_; protected: - /** Internal recursive function to create from keys, cardinalities, + /** + * Internal recursive function to create from keys, cardinalities, * and Y values */ template @@ -163,7 +178,13 @@ namespace gtsam { /** Create a constant */ explicit DecisionTree(const Y& y); - /// Create tree with 2 assignments `y1`, `y2`, splitting on variable `label` + /** + * @brief Create tree with 2 assignments `y1`, `y2`, splitting on variable `label` + * + * @param label The variable to split on. + * @param y1 The value for the first assignment. + * @param y2 The value for the second assignment. + */ DecisionTree(const L& label, const Y& y1, const Y& y2); /** Allow Label+Cardinality for convenience */ diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index 2f385263c1..fbcecb5abb 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -71,6 +71,19 @@ struct traits : public Testable {}; GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree) +/* ************************************************************************** */ +// Test char labels and int range +/* ************************************************************************** */ + +// Create a decision stump one one variable 'a' with values 10 and 20. +TEST(DecisionTree, constructor) { + DecisionTree tree('a', 10, 20); + + // Evaluate the tree on an assignment to the variable. + EXPECT_LONGS_EQUAL(10, tree({{'a', 0}})); + EXPECT_LONGS_EQUAL(20, tree({{'a', 1}})); +} + /* ************************************************************************** */ // Test string labels and int range /* ************************************************************************** */ @@ -114,18 +127,47 @@ struct Ring { static inline int mul(const int& a, const int& b) { return a * b; } }; +/* ************************************************************************** */ +// Check that creating decision trees respects key order. +TEST(DecisionTree, constructor_order) { + // Create labels + string A("A"), B("B"); + + const std::vector ys1 = {1, 2, 3, 4}; + DT tree1({{B, 2}, {A, 2}}, ys1); // faster version, as B is "higher" than A! + + const std::vector ys2 = {1, 3, 2, 4}; + DT tree2({{A, 2}, {B, 2}}, ys2); // slower version ! + + // Both trees will be the same, tree is order from high to low labels. + // Choice(B) + // 0 Choice(A) + // 0 0 Leaf 1 + // 0 1 Leaf 2 + // 1 Choice(A) + // 1 0 Leaf 3 + // 1 1 Leaf 4 + + EXPECT(tree2.equals(tree1)); + + // Check the values are as expected by calling the () operator: + EXPECT_LONGS_EQUAL(1, tree1({{A, 0}, {B, 0}})); + EXPECT_LONGS_EQUAL(3, tree1({{A, 0}, {B, 1}})); + EXPECT_LONGS_EQUAL(2, tree1({{A, 1}, {B, 0}})); + EXPECT_LONGS_EQUAL(4, tree1({{A, 1}, {B, 1}})); +} + /* ************************************************************************** */ // test DT TEST(DecisionTree, example) { // Create labels string A("A"), B("B"), C("C"); - // create a value - Assignment x00, x01, x10, x11; - x00[A] = 0, x00[B] = 0; - x01[A] = 0, x01[B] = 1; - x10[A] = 1, x10[B] = 0; - x11[A] = 1, x11[B] = 1; + // Create assignments using brace initialization: + Assignment x00{{A, 0}, {B, 0}}; + Assignment x01{{A, 0}, {B, 1}}; + Assignment x10{{A, 1}, {B, 0}}; + Assignment x11{{A, 1}, {B, 1}}; // empty DT empty; @@ -237,8 +279,7 @@ TEST(DecisionTree, ConvertValuesOnly) { StringBoolTree f2(f1, bool_of_int); // Check a value - Assignment x00; - x00["A"] = 0, x00["B"] = 0; + Assignment x00 {{A, 0}, {B, 0}}; EXPECT(!f2(x00)); } @@ -262,10 +303,11 @@ TEST(DecisionTree, ConvertBoth) { // Check some values Assignment