diff --git a/gtsam/discrete/AlgebraicDecisionTree.h b/gtsam/discrete/AlgebraicDecisionTree.h index b3f0d69b0e..cd77e41f8e 100644 --- a/gtsam/discrete/AlgebraicDecisionTree.h +++ b/gtsam/discrete/AlgebraicDecisionTree.h @@ -28,9 +28,9 @@ namespace gtsam { /** - * Algebraic Decision Trees fix the range to double - * Just has some nice constructors and some syntactic sugar - * TODO: consider eliminating this class altogether? + * An algebraic decision tree fixes the range of a DecisionTree to double. + * Just has some nice constructors and some syntactic sugar. + * TODO(dellaert): consider eliminating this class altogether? * * @ingroup discrete */ @@ -80,20 +80,62 @@ namespace gtsam { AlgebraicDecisionTree(const L& label, double y1, double y2) : Base(label, y1, y2) {} - /** Create a new leaf function splitting on a variable */ + /** + * @brief Create a new leaf function splitting on a variable + * + * @param labelC: The label with cardinality 2 + * @param y1: The value for the first key + * @param y2: The value for the second key + * + * Example: + * @code{.cpp} + * std::pair A {"a", 2}; + * AlgebraicDecisionTree a(A, 0.6, 0.4); + * @endcode + */ AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1, double y2) : Base(labelC, y1, y2) {} - /** Create from keys and vector table */ + /** + * @brief Create from keys with cardinalities and a vector table + * + * @param labelCs: The keys, with cardinalities, given as pairs + * @param ys: The vector table + * + * Example with three keys, A, B, and C, with cardinalities 2, 3, and 2, + * respectively, and a vector table of size 12: + * @code{.cpp} + * DiscreteKey A(0, 2), B(1, 3), C(2, 2); + * const vector cpt{ + * 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, // + * 1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10}; + * AlgebraicDecisionTree expected(A & B & C, cpt); + * @endcode + * The table is given in the following order: + * A=0, B=0, C=0 + * A=0, B=0, C=1 + * ... + * A=1, B=1, C=1 + * Hence, the first line in the table is for A==0, and the second for A==1. + * In each line, the first two entries are for B==0, the next two for B==1, + * and the last two for B==2. Each pair is for a C value of 0 and 1. + */ AlgebraicDecisionTree // (const std::vector& labelCs, - const std::vector& ys) { + const std::vector& ys) { this->root_ = Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); } - /** Create from keys and string table */ + /** + * @brief Create from keys and string table + * + * @param labelCs: The keys, with cardinalities, given as pairs + * @param table: The string table, given as a string of doubles. + * + * @note Table needs to be in same order as the vector table in the other constructor. + */ AlgebraicDecisionTree // (const std::vector& labelCs, const std::string& table) { @@ -108,7 +150,13 @@ namespace gtsam { Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); } - /** Create a new function splitting on a variable */ + /** + * @brief Create a range of decision trees, splitting on a single variable. + * + * @param begin: Iterator to beginning of a range of decision trees + * @param end: Iterator to end of a range of decision trees + * @param label: The label to split on + */ template AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label) : Base(nullptr) { diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 9f3d5e8f95..4d1670bb74 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -622,7 +622,7 @@ namespace gtsam { // B=1 // A=0: 3 // A=1: 4 - // Note, through the magic of "compose", create([A B],[1 2 3 4]) will produce + // Note, through the magic of "compose", create([A B],[1 3 2 4]) will produce // exactly the same tree as above: the highest label is always the root. // However, it will be *way* faster if labels are given highest to lowest. template diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index a8764a98f7..06e945cf9f 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -37,9 +37,23 @@ namespace gtsam { /** - * Decision Tree - * L = label for variables - * Y = function range (any algebra), e.g., bool, int, double + * @brief a decision tree is a function from assignments to values. + * @tparam L label for variables + * @tparam Y function range (any algebra), e.g., bool, int, double + * + * After creating a decision tree on some variables, the tree can be evaluated + * on an assignment to those variables. Example: + * + * @code{.cpp} + * // Create a decision stump one one variable 'a' with values 10 and 20. + * DecisionTree tree('a', 10, 20); + * + * // Evaluate the tree on an assignment to the variable. + * int value0 = tree({{'a', 0}}); // value0 = 10 + * int value1 = tree({{'a', 1}}); // value1 = 20 + * @endcode + * + * More examples can be found in testDecisionTree.cpp * * @ingroup discrete */ @@ -132,7 +146,8 @@ namespace gtsam { NodePtr root_; protected: - /** Internal recursive function to create from keys, cardinalities, + /** + * Internal recursive function to create from keys, cardinalities, * and Y values */ template @@ -163,7 +178,13 @@ namespace gtsam { /** Create a constant */ explicit DecisionTree(const Y& y); - /// Create tree with 2 assignments `y1`, `y2`, splitting on variable `label` + /** + * @brief Create tree with 2 assignments `y1`, `y2`, splitting on variable `label` + * + * @param label The variable to split on. + * @param y1 The value for the first assignment. + * @param y2 The value for the second assignment. + */ DecisionTree(const L& label, const Y& y1, const Y& y2); /** Allow Label+Cardinality for convenience */ diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index dd292cae8a..3764a01c49 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -63,11 +63,46 @@ namespace gtsam { /** Constructor from DiscreteKeys and AlgebraicDecisionTree */ DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials); - /** Constructor from doubles */ + /** + * @brief Constructor from doubles + * + * @param keys The discrete keys. + * @param table The table of values. + * + * @throw std::invalid_argument if the size of `table` does not match the + * number of assignments. + * + * Example: + * @code{.cpp} + * DiscreteKey X(0,2), Y(1,3); + * const std::vector table {2, 5, 3, 6, 4, 7}; + * DecisionTreeFactor f1({X, Y}, table); + * @endcode + * + * The values in the table should be laid out so that the first key varies + * the slowest, and the last key the fastest. + */ DecisionTreeFactor(const DiscreteKeys& keys, - const std::vector& table); + const std::vector& table); - /** Constructor from string */ + /** + * @brief Constructor from string + * + * @param keys The discrete keys. + * @param table The table of values. + * + * @throw std::invalid_argument if the size of `table` does not match the + * number of assignments. + * + * Example: + * @code{.cpp} + * DiscreteKey X(0,2), Y(1,3); + * DecisionTreeFactor factor({X, Y}, "2 5 3 6 4 7"); + * @endcode + * + * The values in the table should be laid out so that the first key varies + * the slowest, and the last key the fastest. + */ DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table); /// Single-key specialization diff --git a/gtsam/discrete/DiscreteBayesTree.h b/gtsam/discrete/DiscreteBayesTree.h index 03a9a2fc75..b37b649087 100644 --- a/gtsam/discrete/DiscreteBayesTree.h +++ b/gtsam/discrete/DiscreteBayesTree.h @@ -59,6 +59,11 @@ class GTSAM_EXPORT DiscreteBayesTreeClique //** evaluate conditional probability of subtree for given DiscreteValues */ double evaluate(const DiscreteValues& values) const; + + //** (Preferred) sugar for the above for given DiscreteValues */ + double operator()(const DiscreteValues& values) const { + return evaluate(values); + } }; /* ************************************************************************* */ diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h index 3fbb64d506..68b7a85a7a 100644 --- a/gtsam/discrete/DiscreteFactorGraph.h +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -42,16 +42,30 @@ class DiscreteJunctionTree; /** * @brief Main elimination function for DiscreteFactorGraph. - * - * @param factors - * @param keys - * @return GTSAM_EXPORT + * + * @param factors The factor graph to eliminate. + * @param frontalKeys An ordering for which variables to eliminate. + * @return A pair of the resulting conditional and the separator factor. * @ingroup discrete */ -GTSAM_EXPORT std::pair, DecisionTreeFactor::shared_ptr> -EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& keys); +GTSAM_EXPORT +std::pair +EliminateDiscrete(const DiscreteFactorGraph& factors, + const Ordering& frontalKeys); + +/** + * @brief Alternate elimination function for that creates non-normalized lookup tables. + * + * @param factors The factor graph to eliminate. + * @param frontalKeys An ordering for which variables to eliminate. + * @return A pair of the resulting lookup table and the separator factor. + * @ingroup discrete + */ +GTSAM_EXPORT +std::pair +EliminateForMPE(const DiscreteFactorGraph& factors, + const Ordering& frontalKeys); -/* ************************************************************************* */ template<> struct EliminationTraits { typedef DiscreteFactor FactorType; ///< Type of factors in factor graph @@ -61,12 +75,14 @@ template<> struct EliminationTraits typedef DiscreteEliminationTree EliminationTreeType; ///< Type of elimination tree typedef DiscreteBayesTree BayesTreeType; ///< Type of Bayes tree typedef DiscreteJunctionTree JunctionTreeType; ///< Type of Junction tree + /// The default dense elimination function static std::pair, boost::shared_ptr > DefaultEliminate(const FactorGraphType& factors, const Ordering& keys) { return EliminateDiscrete(factors, keys); } + /// The default ordering generation function static Ordering DefaultOrderingFunc( const FactorGraphType& graph, @@ -75,7 +91,6 @@ template<> struct EliminationTraits } }; -/* ************************************************************************* */ /** * A Discrete Factor Graph is a factor graph where all factors are Discrete, i.e. * Factor == DiscreteFactor @@ -109,8 +124,8 @@ class GTSAM_EXPORT DiscreteFactorGraph /** Implicit copy/downcast constructor to override explicit template container * constructor */ - template - DiscreteFactorGraph(const FactorGraph& graph) : Base(graph) {} + template + DiscreteFactorGraph(const FactorGraph& graph) : Base(graph) {} /// Destructor virtual ~DiscreteFactorGraph() {} @@ -231,10 +246,6 @@ class GTSAM_EXPORT DiscreteFactorGraph /// @} }; // \ DiscreteFactorGraph -std::pair // -EliminateForMPE(const DiscreteFactorGraph& factors, - const Ordering& frontalKeys); - /// traits template <> struct traits : public Testable {}; diff --git a/gtsam/discrete/DiscreteJunctionTree.h b/gtsam/discrete/DiscreteJunctionTree.h index f417cf6fa9..6b70f444b3 100644 --- a/gtsam/discrete/DiscreteJunctionTree.h +++ b/gtsam/discrete/DiscreteJunctionTree.h @@ -66,4 +66,6 @@ namespace gtsam { DiscreteJunctionTree(const DiscreteEliminationTree& eliminationTree); }; + /// typedef for wrapper: + using DiscreteCluster = DiscreteJunctionTree::Cluster; } diff --git a/gtsam/discrete/DiscreteValues.h b/gtsam/discrete/DiscreteValues.h index 8a6d6f9303..9ec08302bd 100644 --- a/gtsam/discrete/DiscreteValues.h +++ b/gtsam/discrete/DiscreteValues.h @@ -120,6 +120,11 @@ class GTSAM_EXPORT DiscreteValues : public Assignment { /// @} }; +/// Free version of CartesianProduct. +inline std::vector cartesianProduct(const DiscreteKeys& keys) { + return DiscreteValues::CartesianProduct(keys); +} + /// Free version of markdown. std::string markdown(const DiscreteValues& values, const KeyFormatter& keyFormatter = DefaultKeyFormatter, diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 78efd57e28..fe8cbc7f3c 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -17,6 +17,8 @@ class DiscreteKeys { }; // DiscreteValues is added in specializations/discrete.h as a std::map +std::vector cartesianProduct( + const gtsam::DiscreteKeys& keys); string markdown( const gtsam::DiscreteValues& values, const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter); @@ -31,27 +33,30 @@ string html(const gtsam::DiscreteValues& values, std::map> names); #include -class DiscreteFactor { +virtual class DiscreteFactor : gtsam::Factor { void print(string s = "DiscreteFactor\n", const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; bool equals(const gtsam::DiscreteFactor& other, double tol = 1e-9) const; - bool empty() const; - size_t size() const; double operator()(const gtsam::DiscreteValues& values) const; }; #include virtual class DecisionTreeFactor : gtsam::DiscreteFactor { DecisionTreeFactor(); - + DecisionTreeFactor(const gtsam::DiscreteKey& key, const std::vector& spec); DecisionTreeFactor(const gtsam::DiscreteKey& key, string table); - + + DecisionTreeFactor(const gtsam::DiscreteKeys& keys, + const std::vector& table); DecisionTreeFactor(const gtsam::DiscreteKeys& keys, string table); + + DecisionTreeFactor(const std::vector& keys, + const std::vector& table); DecisionTreeFactor(const std::vector& keys, string table); - + DecisionTreeFactor(const gtsam::DiscreteConditional& c); void print(string s = "DecisionTreeFactor\n", @@ -59,6 +64,8 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor { gtsam::DefaultKeyFormatter) const; bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const; + size_t cardinality(gtsam::Key j) const; + double operator()(const gtsam::DiscreteValues& values) const; gtsam::DecisionTreeFactor operator*(const gtsam::DecisionTreeFactor& f) const; size_t cardinality(gtsam::Key j) const; @@ -66,6 +73,7 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor { gtsam::DecisionTreeFactor* sum(size_t nrFrontals) const; gtsam::DecisionTreeFactor* sum(const gtsam::Ordering& keys) const; gtsam::DecisionTreeFactor* max(size_t nrFrontals) const; + gtsam::DecisionTreeFactor* max(const gtsam::Ordering& keys) const; string dot( const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, @@ -203,10 +211,16 @@ class DiscreteBayesTreeClique { DiscreteBayesTreeClique(const gtsam::DiscreteConditional* conditional); const gtsam::DiscreteConditional* conditional() const; bool isRoot() const; + size_t nrChildren() const; + const gtsam::DiscreteBayesTreeClique* operator[](size_t i) const; + void print(string s = "DiscreteBayesTreeClique", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; void printSignature( const string& s = "Clique: ", const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const; double evaluate(const gtsam::DiscreteValues& values) const; + double operator()(const gtsam::DiscreteValues& values) const; }; class DiscreteBayesTree { @@ -220,6 +234,9 @@ class DiscreteBayesTree { bool empty() const; const DiscreteBayesTreeClique* operator[](size_t j) const; + double evaluate(const gtsam::DiscreteValues& values) const; + double operator()(const gtsam::DiscreteValues& values) const; + string dot(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; void saveGraph(string s, @@ -242,9 +259,9 @@ class DiscreteBayesTree { class DiscreteLookupTable : gtsam::DiscreteConditional{ DiscreteLookupTable(size_t nFrontals, const gtsam::DiscreteKeys& keys, const gtsam::DecisionTreeFactor::ADT& potentials); - void print( - const std::string& s = "Discrete Lookup Table: ", - const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; + void print(string s = "Discrete Lookup Table: ", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; size_t argmax(const gtsam::DiscreteValues& parentsValues) const; }; @@ -263,6 +280,14 @@ class DiscreteLookupDAG { }; #include +std::pair +EliminateDiscrete(const gtsam::DiscreteFactorGraph& factors, + const gtsam::Ordering& frontalKeys); + +std::pair +EliminateForMPE(const gtsam::DiscreteFactorGraph& factors, + const gtsam::Ordering& frontalKeys); + class DiscreteFactorGraph { DiscreteFactorGraph(); DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet); @@ -277,6 +302,7 @@ class DiscreteFactorGraph { void add(const gtsam::DiscreteKey& j, const std::vector& spec); void add(const gtsam::DiscreteKeys& keys, string spec); void add(const std::vector& keys, string spec); + void add(const std::vector& keys, const std::vector& spec); bool empty() const; size_t size() const; @@ -290,25 +316,46 @@ class DiscreteFactorGraph { double operator()(const gtsam::DiscreteValues& values) const; gtsam::DiscreteValues optimize() const; - gtsam::DiscreteBayesNet sumProduct(); - gtsam::DiscreteBayesNet sumProduct(gtsam::Ordering::OrderingType type); + gtsam::DiscreteBayesNet sumProduct( + gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD); gtsam::DiscreteBayesNet sumProduct(const gtsam::Ordering& ordering); - gtsam::DiscreteLookupDAG maxProduct(); - gtsam::DiscreteLookupDAG maxProduct(gtsam::Ordering::OrderingType type); + gtsam::DiscreteLookupDAG maxProduct( + gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD); gtsam::DiscreteLookupDAG maxProduct(const gtsam::Ordering& ordering); - gtsam::DiscreteBayesNet* eliminateSequential(); - gtsam::DiscreteBayesNet* eliminateSequential(gtsam::Ordering::OrderingType type); + gtsam::DiscreteBayesNet* eliminateSequential( + gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD); + gtsam::DiscreteBayesNet* eliminateSequential( + gtsam::Ordering::OrderingType type, + const gtsam::DiscreteFactorGraph::Eliminate& function); gtsam::DiscreteBayesNet* eliminateSequential(const gtsam::Ordering& ordering); + gtsam::DiscreteBayesNet* eliminateSequential( + const gtsam::Ordering& ordering, + const gtsam::DiscreteFactorGraph::Eliminate& function); pair - eliminatePartialSequential(const gtsam::Ordering& ordering); - - gtsam::DiscreteBayesTree* eliminateMultifrontal(); - gtsam::DiscreteBayesTree* eliminateMultifrontal(gtsam::Ordering::OrderingType type); - gtsam::DiscreteBayesTree* eliminateMultifrontal(const gtsam::Ordering& ordering); + eliminatePartialSequential(const gtsam::Ordering& ordering); + pair + eliminatePartialSequential( + const gtsam::Ordering& ordering, + const gtsam::DiscreteFactorGraph::Eliminate& function); + + gtsam::DiscreteBayesTree* eliminateMultifrontal( + gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD); + gtsam::DiscreteBayesTree* eliminateMultifrontal( + gtsam::Ordering::OrderingType type, + const gtsam::DiscreteFactorGraph::Eliminate& function); + gtsam::DiscreteBayesTree* eliminateMultifrontal( + const gtsam::Ordering& ordering); + gtsam::DiscreteBayesTree* eliminateMultifrontal( + const gtsam::Ordering& ordering, + const gtsam::DiscreteFactorGraph::Eliminate& function); + pair + eliminatePartialMultifrontal(const gtsam::Ordering& ordering); pair - eliminatePartialMultifrontal(const gtsam::Ordering& ordering); + eliminatePartialMultifrontal( + const gtsam::Ordering& ordering, + const gtsam::DiscreteFactorGraph::Eliminate& function); string dot( const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, @@ -328,4 +375,41 @@ class DiscreteFactorGraph { std::map> names) const; }; +#include + +class DiscreteEliminationTree { + DiscreteEliminationTree(const gtsam::DiscreteFactorGraph& factorGraph, + const gtsam::VariableIndex& structure, + const gtsam::Ordering& order); + + DiscreteEliminationTree(const gtsam::DiscreteFactorGraph& factorGraph, + const gtsam::Ordering& order); + + void print( + string name = "EliminationTree: ", + const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::DiscreteEliminationTree& other, + double tol = 1e-9) const; +}; + +#include + +class DiscreteCluster { + gtsam::Ordering orderedFrontalKeys; + gtsam::DiscreteFactorGraph factors; + const gtsam::DiscreteCluster& operator[](size_t i) const; + size_t nrChildren() const; + void print(string s = "", const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; +}; + +class DiscreteJunctionTree { + DiscreteJunctionTree(const gtsam::DiscreteEliminationTree& eliminationTree); + void print( + string name = "JunctionTree: ", + const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const; + size_t nrRoots() const; + const gtsam::DiscreteCluster& operator[](size_t i) const; +}; + } // namespace gtsam diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index 2f385263c1..8876cc4379 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -71,6 +71,19 @@ struct traits : public Testable {}; GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree) +/* ************************************************************************** */ +// Test char labels and int range +/* ************************************************************************** */ + +// Create a decision stump one one variable 'a' with values 10 and 20. +TEST(DecisionTree, Constructor) { + DecisionTree tree('a', 10, 20); + + // Evaluate the tree on an assignment to the variable. + EXPECT_LONGS_EQUAL(10, tree({{'a', 0}})); + EXPECT_LONGS_EQUAL(20, tree({{'a', 1}})); +} + /* ************************************************************************** */ // Test string labels and int range /* ************************************************************************** */ @@ -114,18 +127,47 @@ struct Ring { static inline int mul(const int& a, const int& b) { return a * b; } }; +/* ************************************************************************** */ +// Check that creating decision trees respects key order. +TEST(DecisionTree, ConstructorOrder) { + // Create labels + string A("A"), B("B"); + + const std::vector ys1 = {1, 2, 3, 4}; + DT tree1({{B, 2}, {A, 2}}, ys1); // faster version, as B is "higher" than A! + + const std::vector ys2 = {1, 3, 2, 4}; + DT tree2({{A, 2}, {B, 2}}, ys2); // slower version ! + + // Both trees will be the same, tree is order from high to low labels. + // Choice(B) + // 0 Choice(A) + // 0 0 Leaf 1 + // 0 1 Leaf 2 + // 1 Choice(A) + // 1 0 Leaf 3 + // 1 1 Leaf 4 + + EXPECT(tree2.equals(tree1)); + + // Check the values are as expected by calling the () operator: + EXPECT_LONGS_EQUAL(1, tree1({{A, 0}, {B, 0}})); + EXPECT_LONGS_EQUAL(3, tree1({{A, 0}, {B, 1}})); + EXPECT_LONGS_EQUAL(2, tree1({{A, 1}, {B, 0}})); + EXPECT_LONGS_EQUAL(4, tree1({{A, 1}, {B, 1}})); +} + /* ************************************************************************** */ // test DT -TEST(DecisionTree, example) { +TEST(DecisionTree, Example) { // Create labels string A("A"), B("B"), C("C"); - // create a value - Assignment x00, x01, x10, x11; - x00[A] = 0, x00[B] = 0; - x01[A] = 0, x01[B] = 1; - x10[A] = 1, x10[B] = 0; - x11[A] = 1, x11[B] = 1; + // Create assignments using brace initialization: + Assignment x00{{A, 0}, {B, 0}}; + Assignment x01{{A, 0}, {B, 1}}; + Assignment x10{{A, 1}, {B, 0}}; + Assignment x11{{A, 1}, {B, 1}}; // empty DT empty; @@ -237,8 +279,7 @@ TEST(DecisionTree, ConvertValuesOnly) { StringBoolTree f2(f1, bool_of_int); // Check a value - Assignment x00; - x00["A"] = 0, x00["B"] = 0; + Assignment x00 {{A, 0}, {B, 0}}; EXPECT(!f2(x00)); } @@ -262,10 +303,11 @@ TEST(DecisionTree, ConvertBoth) { // Check some values Assignment