diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 06161c2e15..27e98fcdec 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -91,7 +91,7 @@ namespace gtsam { void dot(std::ostream& os, const LabelFormatter& labelFormatter, const ValueFormatter& valueFormatter, bool showZero) const override { - std::string value = valueFormatter(constant_); + const std::string value = valueFormatter(constant_); if (showZero || value.compare("0")) os << "\"" << this->id() << "\" [label=\"" << value << "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n"; @@ -306,7 +306,8 @@ namespace gtsam { void dot(std::ostream& os, const LabelFormatter& labelFormatter, const ValueFormatter& valueFormatter, bool showZero) const override { - os << "\"" << this->id() << "\" [shape=circle, label=\"" << label_ + const std::string label = labelFormatter(label_); + os << "\"" << this->id() << "\" [shape=circle, label=\"" << label << "\"]\n"; size_t B = branches_.size(); for (size_t i = 0; i < B; i++) { diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index c56818448a..d1b68f4bfb 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -147,14 +147,14 @@ namespace gtsam { size_t i; ADT result(*this); for (i = 0; i < nrFrontals; i++) { - Key j = keys()[i]; + Key j = keys_[i]; result = result.combine(j, cardinality(j), op); } - // create new factor, note we start keys after nrFrontals + // Create new factor, note we start with keys after nrFrontals: DiscreteKeys dkeys; - for (; i < keys().size(); i++) { - Key j = keys()[i]; + for (; i < keys_.size(); i++) { + Key j = keys_[i]; dkeys.push_back(DiscreteKey(j, cardinality(j))); } return std::make_shared(dkeys, result); @@ -179,24 +179,22 @@ namespace gtsam { result = result.combine(j, cardinality(j), op); } - // create new factor, note we collect keys that are not in frontalKeys /* - Due to branch merging, the labels in `result` may be missing some keys + Create new factor, note we collect keys that are not in frontalKeys. + + Due to branch merging, the labels in `result` may be missing some keys. E.g. After branch merging, we may get a ADT like: Leaf [2] 1.0204082 - This is missing the key values used for branching. + Hence, code below traverses the original keys and omits those in + frontalKeys. We loop over cardinalities, which is O(n) even for a map, and + then "contains" is a binary search on a small vector. */ - KeyVector difference, frontalKeys_(frontalKeys), keys_(keys()); - // Get the difference of the frontalKeys and the factor keys using set_difference - std::sort(keys_.begin(), keys_.end()); - std::sort(frontalKeys_.begin(), frontalKeys_.end()); - std::set_difference(keys_.begin(), keys_.end(), frontalKeys_.begin(), - frontalKeys_.end(), back_inserter(difference)); - DiscreteKeys dkeys; - for (Key key : difference) { - dkeys.push_back(DiscreteKey(key, cardinality(key))); + for (auto&& [key, cardinality] : cardinalities_) { + if (!frontalKeys.contains(key)) { + dkeys.push_back(DiscreteKey(key, cardinality)); + } } return std::make_shared(dkeys, result); } diff --git a/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp b/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp index d65a9c82b7..ffb1f0b5ac 100644 --- a/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp +++ b/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp @@ -20,12 +20,9 @@ #include // make sure we have traits #include // headers first to make sure no missing headers +#include #include #include // for convert only -#define DISABLE_TIMING - -#include -#include #include using namespace std; @@ -71,16 +68,14 @@ void dot(const T& f, const string& filename) { // instrumented operators /* ************************************************************************** */ size_t muls = 0, adds = 0; -double elapsed; void resetCounts() { muls = 0; adds = 0; } void printCounts(const string& s) { #ifndef DISABLE_TIMING -cout << s << ": " << std::setw(3) << muls << " muls, " << - std::setw(3) << adds << " adds, " << 1000 * elapsed << " ms." - << endl; + cout << s << ": " << std::setw(3) << muls << " muls, " << std::setw(3) << adds + << " adds" << endl; #endif resetCounts(); } @@ -131,37 +126,35 @@ ADT create(const Signature& signature) { static size_t count = 0; const DiscreteKey& key = signature.key(); std::stringstream ss; - ss << "CPT-" << std::setw(3) << std::setfill('0') << ++count << "-" << key.first; + ss << "CPT-" << std::setw(3) << std::setfill('0') << ++count << "-" + << key.first; string DOTfile = ss.str(); dot(p, DOTfile); return p; } +/* ************************************************************************* */ +namespace asiaCPTs { +DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2), + D(7, 2); + +ADT pA = create(A % "99/1"); +ADT pS = create(S % "50/50"); +ADT pT = create(T | A = "99/1 95/5"); +ADT pL = create(L | S = "99/1 90/10"); +ADT pB = create(B | S = "70/30 40/60"); +ADT pE = create((E | T, L) = "F T T T"); +ADT pX = create(X | E = "95/5 2/98"); +ADT pD = create((D | E, B) = "9/1 2/8 3/7 1/9"); +} // namespace asiaCPTs + /* ************************************************************************* */ // test Asia Joint TEST(ADT, joint) { - DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2), - D(7, 2); - - resetCounts(); - gttic_(asiaCPTs); - ADT pA = create(A % "99/1"); - ADT pS = create(S % "50/50"); - ADT pT = create(T | A = "99/1 95/5"); - ADT pL = create(L | S = "99/1 90/10"); - ADT pB = create(B | S = "70/30 40/60"); - ADT pE = create((E | T, L) = "F T T T"); - ADT pX = create(X | E = "95/5 2/98"); - ADT pD = create((D | E, B) = "9/1 2/8 3/7 1/9"); - gttoc_(asiaCPTs); - tictoc_getNode(asiaCPTsNode, asiaCPTs); - elapsed = asiaCPTsNode->secs() + asiaCPTsNode->wall(); - tictoc_reset_(); - printCounts("Asia CPTs"); + using namespace asiaCPTs; // Create joint resetCounts(); - gttic_(asiaJoint); ADT joint = pA; dot(joint, "Asia-A"); joint = apply(joint, pS, &mul); @@ -183,11 +176,12 @@ TEST(ADT, joint) { #else EXPECT_LONGS_EQUAL(508, muls); #endif - gttoc_(asiaJoint); - tictoc_getNode(asiaJointNode, asiaJoint); - elapsed = asiaJointNode->secs() + asiaJointNode->wall(); - tictoc_reset_(); printCounts("Asia joint"); +} + +/* ************************************************************************* */ +TEST(ADT, combine) { + using namespace asiaCPTs; // Form P(A,S,T,L) = P(A) P(S) P(T|A) P(L|S) ADT pASTL = pA; @@ -203,13 +197,11 @@ TEST(ADT, joint) { } /* ************************************************************************* */ -// test Inference with joint +// test Inference with joint, created using different ordering TEST(ADT, inference) { DiscreteKey A(0, 2), D(1, 2), // B(2, 2), L(3, 2), E(4, 2), S(5, 2), T(6, 2), X(7, 2); - resetCounts(); - gttic_(infCPTs); ADT pA = create(A % "99/1"); ADT pS = create(S % "50/50"); ADT pT = create(T | A = "99/1 95/5"); @@ -218,15 +210,9 @@ TEST(ADT, inference) { ADT pE = create((E | T, L) = "F T T T"); ADT pX = create(X | E = "95/5 2/98"); ADT pD = create((D | E, B) = "9/1 2/8 3/7 1/9"); - gttoc_(infCPTs); - tictoc_getNode(infCPTsNode, infCPTs); - elapsed = infCPTsNode->secs() + infCPTsNode->wall(); - tictoc_reset_(); - // printCounts("Inference CPTs"); - // Create joint + // Create joint, note different ordering than above: different tree! resetCounts(); - gttic_(asiaProd); ADT joint = pA; dot(joint, "Joint-Product-A"); joint = apply(joint, pS, &mul); @@ -248,14 +234,9 @@ TEST(ADT, inference) { #else EXPECT_LONGS_EQUAL(508, (long)muls); // different ordering #endif - gttoc_(asiaProd); - tictoc_getNode(asiaProdNode, asiaProd); - elapsed = asiaProdNode->secs() + asiaProdNode->wall(); - tictoc_reset_(); printCounts("Asia product"); resetCounts(); - gttic_(asiaSum); ADT marginal = joint; marginal = marginal.combine(X, &add_); dot(marginal, "Joint-Sum-ADBLEST"); @@ -270,10 +251,6 @@ TEST(ADT, inference) { #else EXPECT_LONGS_EQUAL(240, (long)adds); #endif - gttoc_(asiaSum); - tictoc_getNode(asiaSumNode, asiaSum); - elapsed = asiaSumNode->secs() + asiaSumNode->wall(); - tictoc_reset_(); printCounts("Asia sum"); } @@ -281,8 +258,6 @@ TEST(ADT, inference) { TEST(ADT, factor_graph) { DiscreteKey B(0, 2), L(1, 2), E(2, 2), S(3, 2), T(4, 2), X(5, 2); - resetCounts(); - gttic_(createCPTs); ADT pS = create(S % "50/50"); ADT pT = create(T % "95/5"); ADT pL = create(L | S = "99/1 90/10"); @@ -290,15 +265,9 @@ TEST(ADT, factor_graph) { ADT pX = create(X | E = "95/5 2/98"); ADT pD = create(B | E = "1/8 7/9"); ADT pB = create(B | S = "70/30 40/60"); - gttoc_(createCPTs); - tictoc_getNode(createCPTsNode, createCPTs); - elapsed = createCPTsNode->secs() + createCPTsNode->wall(); - tictoc_reset_(); - // printCounts("Create CPTs"); // Create joint resetCounts(); - gttic_(asiaFG); ADT fg = pS; fg = apply(fg, pT, &mul); fg = apply(fg, pL, &mul); @@ -312,14 +281,9 @@ TEST(ADT, factor_graph) { #else EXPECT_LONGS_EQUAL(188, (long)muls); #endif - gttoc_(asiaFG); - tictoc_getNode(asiaFGNode, asiaFG); - elapsed = asiaFGNode->secs() + asiaFGNode->wall(); - tictoc_reset_(); printCounts("Asia FG"); resetCounts(); - gttic_(marg); fg = fg.combine(X, &add_); dot(fg, "Marginalized-6X"); fg = fg.combine(T, &add_); @@ -335,83 +299,54 @@ TEST(ADT, factor_graph) { #else LONGS_EQUAL(62, adds); #endif - gttoc_(marg); - tictoc_getNode(margNode, marg); - elapsed = margNode->secs() + margNode->wall(); - tictoc_reset_(); printCounts("marginalize"); // BLESTX // Eliminate X resetCounts(); - gttic_(elimX); ADT fE = pX; dot(fE, "Eliminate-01-fEX"); fE = fE.combine(X, &add_); dot(fE, "Eliminate-02-fE"); - gttoc_(elimX); - tictoc_getNode(elimXNode, elimX); - elapsed = elimXNode->secs() + elimXNode->wall(); - tictoc_reset_(); printCounts("Eliminate X"); // Eliminate T resetCounts(); - gttic_(elimT); ADT fLE = pT; fLE = apply(fLE, pE, &mul); dot(fLE, "Eliminate-03-fLET"); fLE = fLE.combine(T, &add_); dot(fLE, "Eliminate-04-fLE"); - gttoc_(elimT); - tictoc_getNode(elimTNode, elimT); - elapsed = elimTNode->secs() + elimTNode->wall(); - tictoc_reset_(); printCounts("Eliminate T"); // Eliminate S resetCounts(); - gttic_(elimS); ADT fBL = pS; fBL = apply(fBL, pL, &mul); fBL = apply(fBL, pB, &mul); dot(fBL, "Eliminate-05-fBLS"); fBL = fBL.combine(S, &add_); dot(fBL, "Eliminate-06-fBL"); - gttoc_(elimS); - tictoc_getNode(elimSNode, elimS); - elapsed = elimSNode->secs() + elimSNode->wall(); - tictoc_reset_(); printCounts("Eliminate S"); // Eliminate E resetCounts(); - gttic_(elimE); ADT fBL2 = fE; fBL2 = apply(fBL2, fLE, &mul); fBL2 = apply(fBL2, pD, &mul); dot(fBL2, "Eliminate-07-fBLE"); fBL2 = fBL2.combine(E, &add_); dot(fBL2, "Eliminate-08-fBL2"); - gttoc_(elimE); - tictoc_getNode(elimENode, elimE); - elapsed = elimENode->secs() + elimENode->wall(); - tictoc_reset_(); printCounts("Eliminate E"); // Eliminate L resetCounts(); - gttic_(elimL); ADT fB = fBL; fB = apply(fB, fBL2, &mul); dot(fB, "Eliminate-09-fBL"); fB = fB.combine(L, &add_); dot(fB, "Eliminate-10-fB"); - gttoc_(elimL); - tictoc_getNode(elimLNode, elimL); - elapsed = elimLNode->secs() + elimLNode->wall(); - tictoc_reset_(); printCounts("Eliminate L"); } diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index d764da7bfe..a41d06c2b3 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -22,7 +22,10 @@ #include #include #include +#include #include +#include +#include using namespace std; using namespace gtsam; @@ -33,25 +36,24 @@ TEST(DecisionTreeFactor, ConstructorsMatch) { DiscreteKey X(0, 2), Y(1, 3); // Create with vector and with string - const std::vector table {2, 5, 3, 6, 4, 7}; + const std::vector table{2, 5, 3, 6, 4, 7}; DecisionTreeFactor f1({X, Y}, table); DecisionTreeFactor f2({X, Y}, "2 5 3 6 4 7"); EXPECT(assert_equal(f1, f2)); } /* ************************************************************************* */ -TEST( DecisionTreeFactor, constructors) -{ +TEST(DecisionTreeFactor, constructors) { // Declare a bunch of keys - DiscreteKey X(0,2), Y(1,3), Z(2,2); + DiscreteKey X(0, 2), Y(1, 3), Z(2, 2); // Create factors DecisionTreeFactor f1(X, {2, 8}); DecisionTreeFactor f2(X & Y, "2 5 3 6 4 7"); DecisionTreeFactor f3(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75"); - EXPECT_LONGS_EQUAL(1,f1.size()); - EXPECT_LONGS_EQUAL(2,f2.size()); - EXPECT_LONGS_EQUAL(3,f3.size()); + EXPECT_LONGS_EQUAL(1, f1.size()); + EXPECT_LONGS_EQUAL(2, f2.size()); + EXPECT_LONGS_EQUAL(3, f3.size()); DiscreteValues x121{{0, 1}, {1, 2}, {2, 1}}; EXPECT_DOUBLES_EQUAL(8, f1(x121), 1e-9); @@ -70,7 +72,7 @@ TEST( DecisionTreeFactor, constructors) /* ************************************************************************* */ TEST(DecisionTreeFactor, Error) { // Declare a bunch of keys - DiscreteKey X(0,2), Y(1,3), Z(2,2); + DiscreteKey X(0, 2), Y(1, 3), Z(2, 2); // Create factors DecisionTreeFactor f(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75"); @@ -104,9 +106,8 @@ TEST(DecisionTreeFactor, multiplication) { } /* ************************************************************************* */ -TEST( DecisionTreeFactor, sum_max) -{ - DiscreteKey v0(0,3), v1(1,2); +TEST(DecisionTreeFactor, sum_max) { + DiscreteKey v0(0, 3), v1(1, 2); DecisionTreeFactor f1(v0 & v1, "1 2 3 4 5 6"); DecisionTreeFactor expected(v1, "9 12"); @@ -165,22 +166,85 @@ TEST(DecisionTreeFactor, Prune) { "0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 " "0.0 0.0 0.99995287 1.0 1.0 1.0 1.0"); - DecisionTreeFactor expected3( - D & C & B & A, - "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 " - "0.999952870000 1.0 1.0 1.0 1.0"); + DecisionTreeFactor expected3(D & C & B & A, + "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 " + "0.999952870000 1.0 1.0 1.0 1.0"); maxNrAssignments = 5; auto pruned3 = factor.prune(maxNrAssignments); EXPECT(assert_equal(expected3, pruned3)); } +/* ************************************************************************** */ +// Asia Bayes Network +/* ************************************************************************** */ + +#define DISABLE_DOT + +void maybeSaveDotFile(const DecisionTreeFactor& f, const string& filename) { +#ifndef DISABLE_DOT + std::vector names = {"A", "S", "T", "L", "B", "E", "X", "D"}; + auto formatter = [&](Key key) { return names[key]; }; + f.dot(filename, formatter, true); +#endif +} + +/** Convert Signature into CPT */ +DecisionTreeFactor create(const Signature& signature) { + DecisionTreeFactor p(signature.discreteKeys(), signature.cpt()); + return p; +} + +/* ************************************************************************* */ +// test Asia Joint +TEST(DecisionTreeFactor, joint) { + DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2), + D(7, 2); + + gttic_(asiaCPTs); + DecisionTreeFactor pA = create(A % "99/1"); + DecisionTreeFactor pS = create(S % "50/50"); + DecisionTreeFactor pT = create(T | A = "99/1 95/5"); + DecisionTreeFactor pL = create(L | S = "99/1 90/10"); + DecisionTreeFactor pB = create(B | S = "70/30 40/60"); + DecisionTreeFactor pE = create((E | T, L) = "F T T T"); + DecisionTreeFactor pX = create(X | E = "95/5 2/98"); + DecisionTreeFactor pD = create((D | E, B) = "9/1 2/8 3/7 1/9"); + + // Create joint + gttic_(asiaJoint); + DecisionTreeFactor joint = pA; + maybeSaveDotFile(joint, "Asia-A"); + joint = joint * pS; + maybeSaveDotFile(joint, "Asia-AS"); + joint = joint * pT; + maybeSaveDotFile(joint, "Asia-AST"); + joint = joint * pL; + maybeSaveDotFile(joint, "Asia-ASTL"); + joint = joint * pB; + maybeSaveDotFile(joint, "Asia-ASTLB"); + joint = joint * pE; + maybeSaveDotFile(joint, "Asia-ASTLBE"); + joint = joint * pX; + maybeSaveDotFile(joint, "Asia-ASTLBEX"); + joint = joint * pD; + maybeSaveDotFile(joint, "Asia-ASTLBEXD"); + + // Check that discrete keys are as expected + EXPECT(assert_equal(joint.discreteKeys(), {A, S, T, L, B, E, X, D})); + + // Check that summing out variables maintains the keys even if merged, as is + // the case with S. + auto noAB = joint.sum(Ordering{A.first, B.first}); + EXPECT(assert_equal(noAB->discreteKeys(), {S, T, L, E, X, D})); +} + /* ************************************************************************* */ TEST(DecisionTreeFactor, DotWithNames) { DiscreteKey A(12, 3), B(5, 2); DecisionTreeFactor f(A & B, "1 2 3 4 5 6"); auto formatter = [](Key key) { return key == 12 ? "A" : "B"; }; - for (bool showZero:{true, false}) { + for (bool showZero : {true, false}) { string actual = f.dot(formatter, showZero); // pretty weak test, as ids are pointers and not stable across platforms. string expected = "digraph G {"; diff --git a/gtsam/hybrid/HybridNonlinearFactor.cpp b/gtsam/hybrid/HybridNonlinearFactor.cpp index a541722c44..f5be6ded87 100644 --- a/gtsam/hybrid/HybridNonlinearFactor.cpp +++ b/gtsam/hybrid/HybridNonlinearFactor.cpp @@ -22,7 +22,7 @@ namespace gtsam { /* *******************************************************************************/ static void checkKeys(const KeyVector& continuousKeys, - std::vector& pairs) { + const std::vector& pairs) { KeySet factor_keys_set; for (const auto& pair : pairs) { auto f = pair.first; @@ -55,14 +55,9 @@ HybridNonlinearFactor::HybridNonlinearFactor( /* *******************************************************************************/ HybridNonlinearFactor::HybridNonlinearFactor( const KeyVector& continuousKeys, const DiscreteKey& discreteKey, - const std::vector& factors) + const std::vector& pairs) : Base(continuousKeys, {discreteKey}) { - std::vector pairs; KeySet continuous_keys_set(continuousKeys.begin(), continuousKeys.end()); - KeySet factor_keys_set; - for (auto&& [f, val] : factors) { - pairs.emplace_back(f, val); - } checkKeys(continuousKeys, pairs); factors_ = FactorValuePairs({discreteKey}, pairs); } diff --git a/gtsam/hybrid/HybridNonlinearFactor.h b/gtsam/hybrid/HybridNonlinearFactor.h index 766467518b..7843afc836 100644 --- a/gtsam/hybrid/HybridNonlinearFactor.h +++ b/gtsam/hybrid/HybridNonlinearFactor.h @@ -106,11 +106,11 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor { * * @param continuousKeys Vector of keys for continuous factors. * @param discreteKey The discrete key for the "mode", indexing components. - * @param factors Vector of gaussian factor-scalar pairs, one per mode. + * @param pairs Vector of gaussian factor-scalar pairs, one per mode. */ HybridNonlinearFactor(const KeyVector& continuousKeys, const DiscreteKey& discreteKey, - const std::vector& factors); + const std::vector& pairs); /** * @brief Construct a new HybridNonlinearFactor on a several discrete keys M,