diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 5fb5ae2e61..ff18268b14 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -33,16 +33,13 @@ namespace gtsam { /* ************************************************************************ */ DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys, - const ADT& potentials) - : DiscreteFactor(keys.indices()), - ADT(potentials), - cardinalities_(keys.cardinalities()) {} + const ADT& potentials) + : DiscreteFactor(keys.indices(), keys.cardinalities()), ADT(potentials) {} /* ************************************************************************ */ DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c) - : DiscreteFactor(c.keys()), - AlgebraicDecisionTree(c), - cardinalities_(c.cardinalities_) {} + : DiscreteFactor(c.keys(), c.cardinalities()), + AlgebraicDecisionTree(c) {} /* ************************************************************************ */ bool DecisionTreeFactor::equals(const DiscreteFactor& other, @@ -182,15 +179,12 @@ namespace gtsam { } /* ************************************************************************ */ - DiscreteKeys DecisionTreeFactor::discreteKeys() const { - DiscreteKeys result; - for (auto&& key : keys()) { - DiscreteKey dkey(key, cardinality(key)); - if (std::find(result.begin(), result.end(), dkey) == result.end()) { - result.push_back(dkey); - } + std::vector DecisionTreeFactor::probabilities() const { + std::vector probs; + for (auto&& [key, value] : enumerate()) { + probs.push_back(value); } - return result; + return probs; } /* ************************************************************************ */ @@ -288,17 +282,15 @@ namespace gtsam { /* ************************************************************************ */ DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys, - const vector& table) - : DiscreteFactor(keys.indices()), - AlgebraicDecisionTree(keys, table), - cardinalities_(keys.cardinalities()) {} + const vector& table) + : DiscreteFactor(keys.indices(), keys.cardinalities()), + AlgebraicDecisionTree(keys, table) {} /* ************************************************************************ */ DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys, - const string& table) - : DiscreteFactor(keys.indices()), - AlgebraicDecisionTree(keys, table), - cardinalities_(keys.cardinalities()) {} + const string& table) + : DiscreteFactor(keys.indices(), keys.cardinalities()), + AlgebraicDecisionTree(keys, table) {} /* ************************************************************************ */ DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrAssignments) const { @@ -306,11 +298,10 @@ namespace gtsam { // Get the probabilities in the decision tree so we can threshold. std::vector probabilities; - this->visitLeaf([&](const Leaf& leaf) { - size_t nrAssignments = leaf.nrAssignments(); - double prob = leaf.constant(); - probabilities.insert(probabilities.end(), nrAssignments, prob); - }); + // NOTE(Varun) this is potentially slow due to the cartesian product + for (auto&& [assignment, prob] : this->enumerate()) { + probabilities.push_back(prob); + } // The number of probabilities can be lower than max_leaves if (probabilities.size() <= N) { diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 95054bcdb9..6cce6e5d4d 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -50,10 +50,6 @@ namespace gtsam { typedef std::shared_ptr shared_ptr; typedef AlgebraicDecisionTree ADT; - protected: - std::map cardinalities_; - - public: /// @name Standard Constructors /// @{ @@ -119,8 +115,6 @@ namespace gtsam { static double safe_div(const double& a, const double& b); - size_t cardinality(Key j) const { return cardinalities_.at(j); } - /// divide by factor f (safely) DecisionTreeFactor operator/(const DecisionTreeFactor& f) const { return apply(f, safe_div); @@ -179,8 +173,8 @@ namespace gtsam { /// Enumerate all values into a map from values to double. std::vector> enumerate() const; - /// Return all the discrete keys associated with this factor. - DiscreteKeys discreteKeys() const; + /// Get all the probabilities in order of assignment values + std::vector probabilities() const; /** * @brief Prune the decision tree of discrete variables. @@ -260,7 +254,6 @@ namespace gtsam { void serialize(ARCHIVE& ar, const unsigned int /*version*/) { ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(ADT); - ar& BOOST_SERIALIZATION_NVP(cardinalities_); } #endif }; diff --git a/gtsam/discrete/DiscreteFactor.cpp b/gtsam/discrete/DiscreteFactor.cpp index 2b1bc36a3a..b44d4fce2e 100644 --- a/gtsam/discrete/DiscreteFactor.cpp +++ b/gtsam/discrete/DiscreteFactor.cpp @@ -28,6 +28,18 @@ using namespace std; namespace gtsam { +/* ************************************************************************ */ +DiscreteKeys DiscreteFactor::discreteKeys() const { + DiscreteKeys result; + for (auto&& key : keys()) { + DiscreteKey dkey(key, cardinality(key)); + if (std::find(result.begin(), result.end(), dkey) == result.end()) { + result.push_back(dkey); + } + } + return result; +} + /* ************************************************************************* */ double DiscreteFactor::error(const DiscreteValues& values) const { return -std::log((*this)(values)); diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index f00ebc4993..24b2b55e4c 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -36,28 +36,35 @@ class HybridValues; * @ingroup discrete */ class GTSAM_EXPORT DiscreteFactor: public Factor { - -public: - + public: // typedefs needed to play nice with gtsam - typedef DiscreteFactor This; ///< This class - typedef std::shared_ptr shared_ptr; ///< shared_ptr to this class - typedef Factor Base; ///< Our base class + typedef DiscreteFactor This; ///< This class + typedef std::shared_ptr + shared_ptr; ///< shared_ptr to this class + typedef Factor Base; ///< Our base class - using Values = DiscreteValues; ///< backwards compatibility + using Values = DiscreteValues; ///< backwards compatibility -public: + protected: + /// Map of Keys and their cardinalities. + std::map cardinalities_; + public: /// @name Standard Constructors /// @{ /** Default constructor creates empty factor */ DiscreteFactor() {} - /** Construct from container of keys. This constructor is used internally from derived factor - * constructors, either from a container of keys or from a boost::assign::list_of. */ - template - DiscreteFactor(const CONTAINER& keys) : Base(keys) {} + /** + * Construct from container of keys and map of cardinalities. + * This constructor is used internally from derived factor constructors, + * either from a container of keys or from a boost::assign::list_of. + */ + template + DiscreteFactor(const CONTAINER& keys, + const std::map cardinalities = {}) + : Base(keys), cardinalities_(cardinalities) {} /// @} /// @name Testable @@ -77,6 +84,13 @@ class GTSAM_EXPORT DiscreteFactor: public Factor { /// @name Standard Interface /// @{ + /// Return all the discrete keys associated with this factor. + DiscreteKeys discreteKeys() const; + + std::map cardinalities() const { return cardinalities_; } + + size_t cardinality(Key j) const { return cardinalities_.at(j); } + /// Find value for given assignment of values to variables virtual double operator()(const DiscreteValues&) const = 0; @@ -124,6 +138,17 @@ class GTSAM_EXPORT DiscreteFactor: public Factor { const Names& names = {}) const = 0; /// @} + + private: +#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE& ar, const unsigned int /*version*/) { + ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); + ar& BOOST_SERIALIZATION_NVP(cardinalities_); + } +#endif }; // DiscreteFactor diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index 5fe3cd9d16..74eb3ddb38 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -13,11 +13,12 @@ * @file TableFactor.cpp * @brief discrete factor * @date May 4, 2023 - * @author Yoonwoo Kim + * @author Yoonwoo Kim, Varun Agrawal */ #include #include +#include #include #include @@ -33,8 +34,7 @@ TableFactor::TableFactor() {} /* ************************************************************************ */ TableFactor::TableFactor(const DiscreteKeys& dkeys, const TableFactor& potentials) - : DiscreteFactor(dkeys.indices()), - cardinalities_(potentials.cardinalities_) { + : DiscreteFactor(dkeys.indices(), dkeys.cardinalities()) { sparse_table_ = potentials.sparse_table_; denominators_ = potentials.denominators_; sorted_dkeys_ = discreteKeys(); @@ -44,11 +44,11 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys, /* ************************************************************************ */ TableFactor::TableFactor(const DiscreteKeys& dkeys, const Eigen::SparseVector& table) - : DiscreteFactor(dkeys.indices()), sparse_table_(table.size()) { + : DiscreteFactor(dkeys.indices(), dkeys.cardinalities()), + sparse_table_(table.size()) { sparse_table_ = table; double denom = table.size(); for (const DiscreteKey& dkey : dkeys) { - cardinalities_.insert(dkey); denom /= dkey.second; denominators_.insert(std::pair(dkey.first, denom)); } @@ -56,6 +56,10 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys, sort(sorted_dkeys_.begin(), sorted_dkeys_.end()); } +/* ************************************************************************ */ +TableFactor::TableFactor(const DiscreteConditional& c) + : TableFactor(c.discreteKeys(), c.probabilities()) {} + /* ************************************************************************ */ Eigen::SparseVector TableFactor::Convert( const std::vector& table) { @@ -435,18 +439,6 @@ std::vector> TableFactor::enumerate() const { return result; } -/* ************************************************************************ */ -DiscreteKeys TableFactor::discreteKeys() const { - DiscreteKeys result; - for (auto&& key : keys()) { - DiscreteKey dkey(key, cardinality(key)); - if (std::find(result.begin(), result.end(), dkey) == result.end()) { - result.push_back(dkey); - } - } - return result; -} - // Print out header. /* ************************************************************************ */ string TableFactor::markdown(const KeyFormatter& keyFormatter, diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index 1462180e03..bd637bb7d3 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -12,7 +12,7 @@ /** * @file TableFactor.h * @date May 4, 2023 - * @author Yoonwoo Kim + * @author Yoonwoo Kim, Varun Agrawal */ #pragma once @@ -32,6 +32,7 @@ namespace gtsam { +class DiscreteConditional; class HybridValues; /** @@ -44,8 +45,6 @@ class HybridValues; */ class GTSAM_EXPORT TableFactor : public DiscreteFactor { protected: - /// Map of Keys and their cardinalities. - std::map cardinalities_; /// SparseVector of nonzero probabilities. Eigen::SparseVector sparse_table_; @@ -57,10 +56,10 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { /** * @brief Uses lazy cartesian product to find nth entry in the cartesian - * product of arrays in O(1) - * Example) - * v0 | v1 | val - * 0 | 0 | 10 + * product of arrays in O(1) + * Example) + * v0 | v1 | val + * 0 | 0 | 10 * 0 | 1 | 21 * 1 | 0 | 32 * 1 | 1 | 43 @@ -75,13 +74,13 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { * @brief Return ith key in keys_ as a DiscreteKey * @param i ith key in keys_ * @return DiscreteKey - * */ + */ DiscreteKey discreteKey(size_t i) const { return DiscreteKey(keys_[i], cardinalities_.at(keys_[i])); } /// Convert probability table given as doubles to SparseVector. - /// Example) {0, 1, 1, 0, 0, 1, 0} -> values: {1, 1, 1}, indices: {1, 2, 5} + /// Example) {0, 1, 1, 0, 0, 1, 0} -> values: {1, 1, 1}, indices: {1, 2, 5} static Eigen::SparseVector Convert(const std::vector& table); /// Convert probability table given as string to SparseVector. @@ -142,6 +141,9 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { TableFactor(const DiscreteKey& key, const std::vector& row) : TableFactor(DiscreteKeys{key}, row) {} + /** Construct from a DiscreteConditional type */ + explicit TableFactor(const DiscreteConditional& c); + /// @} /// @name Testable /// @{ @@ -180,8 +182,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { static double safe_div(const double& a, const double& b); - size_t cardinality(Key j) const { return cardinalities_.at(j); } - /// divide by factor f (safely) TableFactor operator/(const TableFactor& f) const { return apply(f, safe_div); @@ -274,9 +274,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { /// Enumerate all values into a map from values to double. std::vector> enumerate() const; - /// Return all the discrete keys associated with this factor. - DiscreteKeys discreteKeys() const; - /** * @brief Prune the decision tree of discrete variables. * diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index 3dbb3e64f3..57584a03b5 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -51,6 +51,11 @@ TEST( DecisionTreeFactor, constructors) // Assert that error = -log(value) EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9); + + // Construct from DiscreteConditional + DiscreteConditional conditional(X | Y = "1/1 2/3 1/4"); + DecisionTreeFactor f4(conditional); + EXPECT_DOUBLES_EQUAL(0.8, f4(values), 1e-9); } /* ************************************************************************* */ diff --git a/gtsam/discrete/tests/testTableFactor.cpp b/gtsam/discrete/tests/testTableFactor.cpp index 3ad7573472..b307d78f6a 100644 --- a/gtsam/discrete/tests/testTableFactor.cpp +++ b/gtsam/discrete/tests/testTableFactor.cpp @@ -93,7 +93,8 @@ void printTime(map> for (auto&& kv : measured_time) { cout << "dropout: " << kv.first << " | TableFactor time: " << kv.second.first.count() - << " | DecisionTreeFactor time: " << kv.second.second.count() << endl; + << " | DecisionTreeFactor time: " << kv.second.second.count() << + endl; } } @@ -124,6 +125,13 @@ TEST(TableFactor, constructors) { // Assert that error = -log(value) EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9); + + // Construct from DiscreteConditional + DiscreteConditional conditional(X | Y = "1/1 2/3 1/4"); + TableFactor f4(conditional); + // Manually constructed via inspection and comparison to DecisionTreeFactor + TableFactor expected(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8"); + EXPECT(assert_equal(expected, f4)); } /* ************************************************************************* */ @@ -156,7 +164,8 @@ TEST(TableFactor, multiplication) { /* ************************************************************************* */ // Benchmark which compares runtime of multiplication of two TableFactors // and two DecisionTreeFactors given sparsity from dense to 90% sparsity. -TEST(TableFactor, benchmark) { +// NOTE: Enable to run. +TEST_DISABLED(TableFactor, benchmark) { DiscreteKey A(0, 5), B(1, 2), C(2, 5), D(3, 2), E(4, 5), F(5, 2), G(6, 3), H(7, 2), I(8, 5), J(9, 7), K(10, 2), L(11, 3); diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index 546d0200b5..659d444238 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -228,19 +228,19 @@ std::set DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) { /** * @brief Helper function to get the pruner functional. * - * @param decisionTree The probability decision tree of only discrete keys. + * @param discreteProbs The probabilities of only discrete keys. * @return std::function &, const GaussianConditional::shared_ptr &)> */ std::function &, const GaussianConditional::shared_ptr &)> -GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) { +GaussianMixture::prunerFunc(const DecisionTreeFactor &discreteProbs) { // Get the discrete keys as sets for the decision tree // and the gaussian mixture. - auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys()); + auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys()); auto gaussianMixtureKeySet = DiscreteKeysAsSet(this->discreteKeys()); - auto pruner = [decisionTree, decisionTreeKeySet, gaussianMixtureKeySet]( + auto pruner = [discreteProbs, discreteProbsKeySet, gaussianMixtureKeySet]( const Assignment &choices, const GaussianConditional::shared_ptr &conditional) -> GaussianConditional::shared_ptr { @@ -249,8 +249,8 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) { // Case where the gaussian mixture has the same // discrete keys as the decision tree. - if (gaussianMixtureKeySet == decisionTreeKeySet) { - if (decisionTree(values) == 0.0) { + if (gaussianMixtureKeySet == discreteProbsKeySet) { + if (discreteProbs(values) == 0.0) { // empty aka null pointer std::shared_ptr null; return null; @@ -259,10 +259,10 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) { } } else { std::vector set_diff; - std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(), - gaussianMixtureKeySet.begin(), - gaussianMixtureKeySet.end(), - std::back_inserter(set_diff)); + std::set_difference( + discreteProbsKeySet.begin(), discreteProbsKeySet.end(), + gaussianMixtureKeySet.begin(), gaussianMixtureKeySet.end(), + std::back_inserter(set_diff)); const std::vector assignments = DiscreteValues::CartesianProduct(set_diff); @@ -272,7 +272,7 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) { // If any one of the sub-branches are non-zero, // we need this conditional. - if (decisionTree(augmented_values) > 0.0) { + if (discreteProbs(augmented_values) > 0.0) { return conditional; } } @@ -285,12 +285,12 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) { } /* *******************************************************************************/ -void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) { - auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys()); +void GaussianMixture::prune(const DecisionTreeFactor &discreteProbs) { + auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys()); auto gmKeySet = DiscreteKeysAsSet(this->discreteKeys()); // Functional which loops over all assignments and create a set of // GaussianConditionals - auto pruner = prunerFunc(decisionTree); + auto pruner = prunerFunc(discreteProbs); auto pruned_conditionals = conditionals_.apply(pruner); conditionals_.root_ = pruned_conditionals.root_; diff --git a/gtsam/hybrid/GaussianMixture.h b/gtsam/hybrid/GaussianMixture.h index 2d715c6e31..0b68fcfd05 100644 --- a/gtsam/hybrid/GaussianMixture.h +++ b/gtsam/hybrid/GaussianMixture.h @@ -74,13 +74,13 @@ class GTSAM_EXPORT GaussianMixture /** * @brief Helper function to get the pruner functor. * - * @param decisionTree The pruned discrete probability decision tree. + * @param discreteProbs The pruned discrete probabilities. * @return std::function &, const GaussianConditional::shared_ptr &)> */ std::function &, const GaussianConditional::shared_ptr &)> - prunerFunc(const DecisionTreeFactor &decisionTree); + prunerFunc(const DecisionTreeFactor &discreteProbs); public: /// @name Constructors @@ -234,12 +234,11 @@ class GTSAM_EXPORT GaussianMixture /** * @brief Prune the decision tree of Gaussian factors as per the discrete - * `decisionTree`. + * `discreteProbs`. * - * @param decisionTree A pruned decision tree of discrete keys where the - * leaves are probabilities. + * @param discreteProbs A pruned set of probabilities for the discrete keys. */ - void prune(const DecisionTreeFactor &decisionTree); + void prune(const DecisionTreeFactor &discreteProbs); /** * @brief Merge the Gaussian Factor Graphs in `this` and `sum` while diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 68129bc27a..266e02b0dd 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -39,41 +39,41 @@ bool HybridBayesNet::equals(const This &bn, double tol) const { /* ************************************************************************* */ DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const { - AlgebraicDecisionTree decisionTree; + AlgebraicDecisionTree discreteProbs; // The canonical decision tree factor which will get // the discrete conditionals added to it. - DecisionTreeFactor dtFactor; + DecisionTreeFactor discreteProbsFactor; for (auto &&conditional : *this) { if (conditional->isDiscrete()) { // Convert to a DecisionTreeFactor and add it to the main factor. DecisionTreeFactor f(*conditional->asDiscrete()); - dtFactor = dtFactor * f; + discreteProbsFactor = discreteProbsFactor * f; } } - return std::make_shared(dtFactor); + return std::make_shared(discreteProbsFactor); } /* ************************************************************************* */ /** * @brief Helper function to get the pruner functional. * - * @param prunedDecisionTree The prob. decision tree of only discrete keys. + * @param prunedDiscreteProbs The prob. decision tree of only discrete keys. * @param conditional Conditional to prune. Used to get full assignment. * @return std::function &, double)> */ std::function &, double)> prunerFunc( - const DecisionTreeFactor &prunedDecisionTree, + const DecisionTreeFactor &prunedDiscreteProbs, const HybridConditional &conditional) { // Get the discrete keys as sets for the decision tree // and the Gaussian mixture. - std::set decisionTreeKeySet = - DiscreteKeysAsSet(prunedDecisionTree.discreteKeys()); + std::set discreteProbsKeySet = + DiscreteKeysAsSet(prunedDiscreteProbs.discreteKeys()); std::set conditionalKeySet = DiscreteKeysAsSet(conditional.discreteKeys()); - auto pruner = [prunedDecisionTree, decisionTreeKeySet, conditionalKeySet]( + auto pruner = [prunedDiscreteProbs, discreteProbsKeySet, conditionalKeySet]( const Assignment &choices, double probability) -> double { // This corresponds to 0 probability @@ -83,8 +83,8 @@ std::function &, double)> prunerFunc( DiscreteValues values(choices); // Case where the Gaussian mixture has the same // discrete keys as the decision tree. - if (conditionalKeySet == decisionTreeKeySet) { - if (prunedDecisionTree(values) == 0) { + if (conditionalKeySet == discreteProbsKeySet) { + if (prunedDiscreteProbs(values) == 0) { return pruned_prob; } else { return probability; @@ -114,11 +114,12 @@ std::function &, double)> prunerFunc( } // Now we generate the full assignment by enumerating - // over all keys in the prunedDecisionTree. + // over all keys in the prunedDiscreteProbs. // First we find the differing keys std::vector set_diff; - std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(), - conditionalKeySet.begin(), conditionalKeySet.end(), + std::set_difference(discreteProbsKeySet.begin(), + discreteProbsKeySet.end(), conditionalKeySet.begin(), + conditionalKeySet.end(), std::back_inserter(set_diff)); // Now enumerate over all assignments of the differing keys @@ -130,7 +131,7 @@ std::function &, double)> prunerFunc( // If any one of the sub-branches are non-zero, // we need this probability. - if (prunedDecisionTree(augmented_values) > 0.0) { + if (prunedDiscreteProbs(augmented_values) > 0.0) { return probability; } } @@ -144,8 +145,8 @@ std::function &, double)> prunerFunc( /* ************************************************************************* */ void HybridBayesNet::updateDiscreteConditionals( - const DecisionTreeFactor &prunedDecisionTree) { - KeyVector prunedTreeKeys = prunedDecisionTree.keys(); + const DecisionTreeFactor &prunedDiscreteProbs) { + KeyVector prunedTreeKeys = prunedDiscreteProbs.keys(); // Loop with index since we need it later. for (size_t i = 0; i < this->size(); i++) { @@ -153,18 +154,21 @@ void HybridBayesNet::updateDiscreteConditionals( if (conditional->isDiscrete()) { auto discrete = conditional->asDiscrete(); - // Apply prunerFunc to the underlying AlgebraicDecisionTree + // Convert pointer from conditional to factor auto discreteTree = std::dynamic_pointer_cast(discrete); + // Apply prunerFunc to the underlying AlgebraicDecisionTree DecisionTreeFactor::ADT prunedDiscreteTree = - discreteTree->apply(prunerFunc(prunedDecisionTree, *conditional)); + discreteTree->apply(prunerFunc(prunedDiscreteProbs, *conditional)); + gttic_(HybridBayesNet_MakeConditional); // Create the new (hybrid) conditional KeyVector frontals(discrete->frontals().begin(), discrete->frontals().end()); auto prunedDiscrete = std::make_shared( frontals.size(), conditional->discreteKeys(), prunedDiscreteTree); conditional = std::make_shared(prunedDiscrete); + gttoc_(HybridBayesNet_MakeConditional); // Add it back to the BayesNet this->at(i) = conditional; @@ -175,10 +179,16 @@ void HybridBayesNet::updateDiscreteConditionals( /* ************************************************************************* */ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) { // Get the decision tree of only the discrete keys - auto discreteConditionals = this->discreteConditionals(); - const auto decisionTree = discreteConditionals->prune(maxNrLeaves); + gttic_(HybridBayesNet_PruneDiscreteConditionals); + DecisionTreeFactor::shared_ptr discreteConditionals = + this->discreteConditionals(); + const DecisionTreeFactor prunedDiscreteProbs = + discreteConditionals->prune(maxNrLeaves); + gttoc_(HybridBayesNet_PruneDiscreteConditionals); - this->updateDiscreteConditionals(decisionTree); + gttic_(HybridBayesNet_UpdateDiscreteConditionals); + this->updateDiscreteConditionals(prunedDiscreteProbs); + gttoc_(HybridBayesNet_UpdateDiscreteConditionals); /* To Prune, we visitWith every leaf in the GaussianMixture. * For each leaf, using the assignment we can check the discrete decision tree @@ -189,13 +199,14 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) { HybridBayesNet prunedBayesNetFragment; + gttic_(HybridBayesNet_PruneMixtures); // Go through all the conditionals in the - // Bayes Net and prune them as per decisionTree. + // Bayes Net and prune them as per prunedDiscreteProbs. for (auto &&conditional : *this) { if (auto gm = conditional->asMixture()) { // Make a copy of the Gaussian mixture and prune it! auto prunedGaussianMixture = std::make_shared(*gm); - prunedGaussianMixture->prune(decisionTree); // imperative :-( + prunedGaussianMixture->prune(prunedDiscreteProbs); // imperative :-( // Type-erase and add to the pruned Bayes Net fragment. prunedBayesNetFragment.push_back(prunedGaussianMixture); @@ -205,6 +216,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) { prunedBayesNetFragment.push_back(conditional); } } + gttoc_(HybridBayesNet_PruneMixtures); return prunedBayesNetFragment; } diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 2b0042b8dd..23fc4d5d30 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -224,9 +224,9 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { /** * @brief Update the discrete conditionals with the pruned versions. * - * @param prunedDecisionTree + * @param prunedDiscreteProbs */ - void updateDiscreteConditionals(const DecisionTreeFactor &prunedDecisionTree); + void updateDiscreteConditionals(const DecisionTreeFactor &prunedDiscreteProbs); #ifdef GTSAM_ENABLE_BOOST_SERIALIZATION /** Serialization function */ diff --git a/gtsam/hybrid/HybridBayesTree.cpp b/gtsam/hybrid/HybridBayesTree.cpp index b252e613e5..ae8fa03781 100644 --- a/gtsam/hybrid/HybridBayesTree.cpp +++ b/gtsam/hybrid/HybridBayesTree.cpp @@ -173,19 +173,18 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const { /* ************************************************************************* */ void HybridBayesTree::prune(const size_t maxNrLeaves) { - auto decisionTree = - this->roots_.at(0)->conditional()->asDiscrete(); + auto discreteProbs = this->roots_.at(0)->conditional()->asDiscrete(); - DecisionTreeFactor prunedDecisionTree = decisionTree->prune(maxNrLeaves); - decisionTree->root_ = prunedDecisionTree.root_; + DecisionTreeFactor prunedDiscreteProbs = discreteProbs->prune(maxNrLeaves); + discreteProbs->root_ = prunedDiscreteProbs.root_; /// Helper struct for pruning the hybrid bayes tree. struct HybridPrunerData { /// The discrete decision tree after pruning. - DecisionTreeFactor prunedDecisionTree; - HybridPrunerData(const DecisionTreeFactor& prunedDecisionTree, + DecisionTreeFactor prunedDiscreteProbs; + HybridPrunerData(const DecisionTreeFactor& prunedDiscreteProbs, const HybridBayesTree::sharedNode& parentClique) - : prunedDecisionTree(prunedDecisionTree) {} + : prunedDiscreteProbs(prunedDiscreteProbs) {} /** * @brief A function used during tree traversal that operates on each node @@ -205,13 +204,13 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) { if (conditional->isHybrid()) { auto gaussianMixture = conditional->asMixture(); - gaussianMixture->prune(parentData.prunedDecisionTree); + gaussianMixture->prune(parentData.prunedDiscreteProbs); } return parentData; } }; - HybridPrunerData rootData(prunedDecisionTree, 0); + HybridPrunerData rootData(prunedDiscreteProbs, 0); { treeTraversal::no_op visitorPost; // Limits OpenMP threads since we're mixing TBB and OpenMP diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index f0d28e9f54..2b23ed4dbf 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -98,7 +98,7 @@ static GaussianFactorGraphTree addGaussian( // TODO(dellaert): it's probably more efficient to first collect the discrete // keys, and then loop over all assignments to populate a vector. GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const { - gttic(assembleGraphTree); + gttic_(assembleGraphTree); GaussianFactorGraphTree result; @@ -131,7 +131,7 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const { } } - gttoc(assembleGraphTree); + gttoc_(assembleGraphTree); return result; } @@ -190,7 +190,8 @@ discreteElimination(const HybridGaussianFactorGraph &factors, /* ************************************************************************ */ // If any GaussianFactorGraph in the decision tree contains a nullptr, convert // that leaf to an empty GaussianFactorGraph. Needed since the DecisionTree will -// otherwise create a GFG with a single (null) factor, which doesn't register as null. +// otherwise create a GFG with a single (null) factor, +// which doesn't register as null. GaussianFactorGraphTree removeEmpty(const GaussianFactorGraphTree &sum) { auto emptyGaussian = [](const GaussianFactorGraph &graph) { bool hasNull = @@ -230,26 +231,14 @@ hybridElimination(const HybridGaussianFactorGraph &factors, return {nullptr, nullptr}; } -#ifdef HYBRID_TIMING - gttic_(hybrid_eliminate); -#endif - auto result = EliminatePreferCholesky(graph, frontalKeys); -#ifdef HYBRID_TIMING - gttoc_(hybrid_eliminate); -#endif - return result; }; // Perform elimination! DecisionTree eliminationResults(factorGraphTree, eliminate); -#ifdef HYBRID_TIMING - tictoc_print_(); -#endif - // Separate out decision tree into conditionals and remaining factors. const auto [conditionals, newFactors] = unzip(eliminationResults); diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index f911b135b0..421e69aa05 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -112,8 +112,8 @@ class GTSAM_EXPORT HybridGaussianFactorGraph public: using Base = HybridFactorGraph; using This = HybridGaussianFactorGraph; ///< this class - using BaseEliminateable = - EliminateableFactorGraph; ///< for elimination + ///< for elimination + using BaseEliminateable = EliminateableFactorGraph; using shared_ptr = std::shared_ptr; ///< shared_ptr to This using Values = gtsam::Values; ///< backwards compatibility @@ -148,7 +148,8 @@ class GTSAM_EXPORT HybridGaussianFactorGraph /// @name Standard Interface /// @{ - using Base::error; // Expose error(const HybridValues&) method.. + /// Expose error(const HybridValues&) method. + using Base::error; /** * @brief Compute error for each discrete assignment,