Skip to content

Commit

Permalink
Merge pull request #1501 from borglab/fix-1496
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal authored Nov 11, 2023
2 parents ba23e45 + fe81362 commit 1121ece
Show file tree
Hide file tree
Showing 15 changed files with 289 additions and 178 deletions.
1 change: 1 addition & 0 deletions cmake/HandleGeneralOptions.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ option(GTSAM_FORCE_STATIC_LIB "Force gtsam to be a static library,
option(GTSAM_USE_QUATERNIONS "Enable/Disable using an internal Quaternion representation for rotations instead of rotation matrices. If enable, Rot3::EXPMAP is enforced by default." OFF)
option(GTSAM_POSE3_EXPMAP "Enable/Disable using Pose3::EXPMAP as the default mode. If disabled, Pose3::FIRST_ORDER will be used." ON)
option(GTSAM_ROT3_EXPMAP "Ignore if GTSAM_USE_QUATERNIONS is OFF (Rot3::EXPMAP by default). Otherwise, enable Rot3::EXPMAP, or if disabled, use Rot3::CAYLEY." ON)
option(GTSAM_DT_MERGING "Enable/Disable merging of equal leaf nodes in DecisionTrees. This leads to significant speed up and memory savings." ON)
option(GTSAM_ENABLE_CONSISTENCY_CHECKS "Enable/Disable expensive consistency checks" OFF)
option(GTSAM_ENABLE_MEMORY_SANITIZER "Enable/Disable memory sanitizer" OFF)
option(GTSAM_WITH_TBB "Use Intel Threaded Building Blocks (TBB) if available" ON)
Expand Down
1 change: 1 addition & 0 deletions cmake/HandlePrintConfiguration.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ print_enabled_config(${GTSAM_ENABLE_CONSISTENCY_CHECKS} "Runtime consistency c
print_enabled_config(${GTSAM_ENABLE_MEMORY_SANITIZER} "Build with Memory Sanitizer ")
print_enabled_config(${GTSAM_ROT3_EXPMAP} "Rot3 retract is full ExpMap ")
print_enabled_config(${GTSAM_POSE3_EXPMAP} "Pose3 retract is full ExpMap ")
print_enabled_config(${GTSAM_DT_MERGING} "Enable branch merging in DecisionTree")
print_enabled_config(${GTSAM_ALLOW_DEPRECATED_SINCE_V43} "Allow features deprecated in GTSAM 4.3")
print_enabled_config(${GTSAM_SUPPORT_NESTED_DISSECTION} "Metis-based Nested Dissection ")
print_enabled_config(${GTSAM_TANGENT_PREINTEGRATION} "Use tangent-space preintegration")
Expand Down
3 changes: 3 additions & 0 deletions gtsam/config.h.in
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@
#cmakedefine GTSAM_ROT3_EXPMAP
#endif

// Whether to enable merging of equal leaf nodes in the Discrete Decision Tree.
#cmakedefine GTSAM_DT_MERGING

// Whether we are using TBB (if TBB was found and GTSAM_WITH_TBB is enabled in CMake)
#cmakedefine GTSAM_USE_TBB

Expand Down
142 changes: 81 additions & 61 deletions gtsam/discrete/DecisionTree-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,26 +53,17 @@ namespace gtsam {
/** constant stored in this leaf */
Y constant_;

/** The number of assignments contained within this leaf.
* Particularly useful when leaves have been pruned.
*/
size_t nrAssignments_;

/// Default constructor for serialization.
Leaf() {}

/// Constructor from constant
Leaf(const Y& constant, size_t nrAssignments = 1)
: constant_(constant), nrAssignments_(nrAssignments) {}
Leaf(const Y& constant) : constant_(constant) {}

/// Return the constant
const Y& constant() const {
return constant_;
}

/// Return the number of assignments contained within this leaf.
size_t nrAssignments() const { return nrAssignments_; }

/// Leaf-Leaf equality
bool sameLeaf(const Leaf& q) const override {
return constant_ == q.constant_;
Expand All @@ -93,8 +84,7 @@ namespace gtsam {
/// print
void print(const std::string& s, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter) const override {
std::cout << s << " Leaf [" << nrAssignments() << "] "
<< valueFormatter(constant_) << std::endl;
std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl;
}

/** Write graphviz format to stream `os`. */
Expand All @@ -114,14 +104,14 @@ namespace gtsam {

/** apply unary operator */
NodePtr apply(const Unary& op) const override {
NodePtr f(new Leaf(op(constant_), nrAssignments_));
NodePtr f(new Leaf(op(constant_)));
return f;
}

/// Apply unary operator with assignment
NodePtr apply(const UnaryAssignment& op,
const Assignment<L>& assignment) const override {
NodePtr f(new Leaf(op(assignment, constant_), nrAssignments_));
NodePtr f(new Leaf(op(assignment, constant_)));
return f;
}

Expand All @@ -137,7 +127,7 @@ namespace gtsam {
// Applying binary operator to two leaves results in a leaf
NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override {
// fL op gL
NodePtr h(new Leaf(op(fL.constant_, constant_), nrAssignments_));
NodePtr h(new Leaf(op(fL.constant_, constant_)));
return h;
}

Expand All @@ -148,7 +138,7 @@ namespace gtsam {

/** choose a branch, create new memory ! */
NodePtr choose(const L& label, size_t index) const override {
return NodePtr(new Leaf(constant(), nrAssignments()));
return NodePtr(new Leaf(constant()));
}

bool isLeaf() const override { return true; }
Expand All @@ -163,7 +153,6 @@ namespace gtsam {
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar& BOOST_SERIALIZATION_NVP(constant_);
ar& BOOST_SERIALIZATION_NVP(nrAssignments_);
}
#endif
}; // Leaf
Expand Down Expand Up @@ -199,26 +188,50 @@ namespace gtsam {
#endif
}

/// If all branches of a choice node f are the same, just return a branch.
static NodePtr Unique(const ChoicePtr& f) {
#ifndef GTSAM_DT_NO_PRUNING
if (f->allSame_) {
assert(f->branches().size() > 0);
NodePtr f0 = f->branches_[0];
/**
* @brief Merge branches with equal leaf values for every choice node in a
* decision tree. If all branches are the same (i.e. have the same leaf
* value), replace the choice node with the equivalent leaf node.
*
* This function applies the branch merging (if enabled) recursively on the
* decision tree represented by the root node passed in as the argument. It
* recurses to the leaf nodes and merges branches with equal leaf values in
* a bottom-up fashion.
*
* Thus, if all branches of a choice node `f` are the same,
* just return a single branch at each recursion step.
*
* @param node The root node of the decision tree.
* @return NodePtr
*/
static NodePtr Unique(const NodePtr& node) {
if (auto choice = std::dynamic_pointer_cast<const Choice>(node)) {
// Choice node, we recurse!
// Make non-const copy so we can update
auto f = std::make_shared<Choice>(choice->label(), choice->nrChoices());

// Iterate over all the branches
for (size_t i = 0; i < choice->nrChoices(); i++) {
auto branch = choice->branches_[i];
f->push_back(Unique(branch));
}

size_t nrAssignments = 0;
for(auto branch: f->branches()) {
assert(branch->isLeaf());
nrAssignments +=
std::dynamic_pointer_cast<const Leaf>(branch)->nrAssignments();
#ifdef GTSAM_DT_MERGING
// If all the branches are the same, we can merge them into one
if (f->allSame_) {
assert(f->branches().size() > 0);
NodePtr f0 = f->branches_[0];

NodePtr newLeaf(
new Leaf(std::dynamic_pointer_cast<const Leaf>(f0)->constant()));
return newLeaf;
}
NodePtr newLeaf(
new Leaf(std::dynamic_pointer_cast<const Leaf>(f0)->constant(),
nrAssignments));
return newLeaf;
} else
#endif
return f;
} else {
// Leaf node, return as is
return node;
}
}

bool isLeaf() const override { return false; }
Expand Down Expand Up @@ -439,8 +452,10 @@ namespace gtsam {

// second case, not label of interest, just recurse
auto r = std::make_shared<Choice>(label_, branches_.size());
for (auto&& branch : branches_)
for (auto&& branch : branches_) {
r->push_back(branch->choose(label, index));
}

return Unique(r);
}

Expand All @@ -464,13 +479,11 @@ namespace gtsam {
// DecisionTree
/****************************************************************************/
template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree() {
}
DecisionTree<L, Y>::DecisionTree() {}

template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const NodePtr& root) :
root_(root) {
}
root_(root) {}

/****************************************************************************/
template<typename L, typename Y>
Expand Down Expand Up @@ -586,7 +599,8 @@ namespace gtsam {
auto choiceOnLabel = std::make_shared<Choice>(label, end - begin);
for (Iterator it = begin; it != end; it++)
choiceOnLabel->push_back(it->root_);
return Choice::Unique(choiceOnLabel);
// If no reordering, no need to call Choice::Unique
return choiceOnLabel;
} else {
// Set up a new choice on the highest label
auto choiceOnHighestLabel =
Expand All @@ -605,21 +619,21 @@ namespace gtsam {
NodePtr fi = compose(functions.begin(), functions.end(), label);
choiceOnHighestLabel->push_back(fi);
}
return Choice::Unique(choiceOnHighestLabel);
return choiceOnHighestLabel;
}
}

/****************************************************************************/
// "create" is a bit of a complicated thing, but very useful.
// "build" is a bit of a complicated thing, but very useful.
// It takes a range of labels and a corresponding range of values,
// and creates a decision tree, as follows:
// and builds a decision tree, as follows:
// - if there is only one label, creates a choice node with values in leaves
// - otherwise, it evenly splits up the range of values and creates a tree for
// each sub-range, and assigns that tree to first label's choices
// Example:
// create([B A],[1 2 3 4]) would call
// create([A],[1 2])
// create([A],[3 4])
// build([B A],[1 2 3 4]) would call
// build([A],[1 2])
// build([A],[3 4])
// and produce
// B=0
// A=0: 1
Expand All @@ -632,7 +646,7 @@ namespace gtsam {
// However, it will be *way* faster if labels are given highest to lowest.
template<typename L, typename Y>
template<typename It, typename ValueIt>
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::create(
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::build(
It begin, It end, ValueIt beginY, ValueIt endY) const {
// get crucial counts
size_t nrChoices = begin->second;
Expand All @@ -650,9 +664,10 @@ namespace gtsam {
throw std::invalid_argument("DecisionTree::create invalid argument");
}
auto choice = std::make_shared<Choice>(begin->first, endY - beginY);
for (ValueIt y = beginY; y != endY; y++)
for (ValueIt y = beginY; y != endY; y++) {
choice->push_back(NodePtr(new Leaf(*y)));
return Choice::Unique(choice);
}
return choice;
}

// Recursive case: perform "Shannon expansion"
Expand All @@ -661,12 +676,27 @@ namespace gtsam {
std::vector<DecisionTree> functions;
size_t split = size / nrChoices;
for (size_t i = 0; i < nrChoices; i++, beginY += split) {
NodePtr f = create<It, ValueIt>(labelC, end, beginY, beginY + split);
NodePtr f = build<It, ValueIt>(labelC, end, beginY, beginY + split);
functions.emplace_back(f);
}
return compose(functions.begin(), functions.end(), begin->first);
}

/****************************************************************************/
// Top-level factory method, which takes a range of labels and a corresponding
// range of values, and creates a decision tree.
template<typename L, typename Y>
template<typename It, typename ValueIt>
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::create(
It begin, It end, ValueIt beginY, ValueIt endY) const {
auto node = build(begin, end, beginY, endY);
if (auto choice = std::dynamic_pointer_cast<const Choice>(node)) {
return Choice::Unique(choice);
} else {
return node;
}
}

/****************************************************************************/
template <typename L, typename Y>
template <typename M, typename X>
Expand All @@ -681,7 +711,7 @@ namespace gtsam {
// If leaf, apply unary conversion "op" and create a unique leaf.
using MXLeaf = typename DecisionTree<M, X>::Leaf;
if (auto leaf = std::dynamic_pointer_cast<const MXLeaf>(f)) {
return NodePtr(new Leaf(Y_of_X(leaf->constant()), leaf->nrAssignments()));
return NodePtr(new Leaf(Y_of_X(leaf->constant())));
}

// Check if Choice
Expand All @@ -699,7 +729,7 @@ namespace gtsam {
for (auto&& branch : choice->branches()) {
functions.emplace_back(convertFrom<M, X>(branch, L_of_M, Y_of_X));
}
return LY::compose(functions.begin(), functions.end(), newLabel);
return Choice::Unique(LY::compose(functions.begin(), functions.end(), newLabel));
}

/****************************************************************************/
Expand Down Expand Up @@ -828,16 +858,6 @@ namespace gtsam {
return total;
}

/****************************************************************************/
template <typename L, typename Y>
size_t DecisionTree<L, Y>::nrAssignments() const {
size_t n = 0;
this->visitLeaf([&n](const DecisionTree<L, Y>::Leaf& leaf) {
n += leaf.nrAssignments();
});
return n;
}

/****************************************************************************/
// fold is just done with a visit
template <typename L, typename Y>
Expand Down
48 changes: 10 additions & 38 deletions gtsam/discrete/DecisionTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,11 +150,19 @@ namespace gtsam {
NodePtr root_;

protected:
/**
/**
* Internal recursive function to create from keys, cardinalities,
* and Y values
*/
template<typename It, typename ValueIt>
template <typename It, typename ValueIt>
NodePtr build(It begin, It end, ValueIt beginY, ValueIt endY) const;

/** Internal helper function to create from
* keys, cardinalities, and Y values.
* Calls `build` which builds thetree bottom-up,
* before we prune in a top-down fashion.
*/
template <typename It, typename ValueIt>
NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const;

/**
Expand Down Expand Up @@ -320,42 +328,6 @@ namespace gtsam {
/// Return the number of leaves in the tree.
size_t nrLeaves() const;

/**
* @brief This is a convenience function which returns the total number of
* leaf assignments in the decision tree.
* This function is not used for anymajor operations within the discrete
* factor graph framework.
*
* Leaf assignments represent the cardinality of each leaf node, e.g. in a
* binary tree each leaf has 2 assignments. This includes counts removed
* from implicit pruning hence, it will always be >= nrLeaves().
*
* E.g. we have a decision tree as below, where each node has 2 branches:
*
* Choice(m1)
* 0 Choice(m0)
* 0 0 Leaf 0.0
* 0 1 Leaf 0.0
* 1 Choice(m0)
* 1 0 Leaf 1.0
* 1 1 Leaf 2.0
*
* In the unpruned form, the tree will have 4 assignments, 2 for each key,
* and 4 leaves.
*
* In the pruned form, the number of assignments is still 4 but the number
* of leaves is now 3, as below:
*
* Choice(m1)
* 0 Leaf 0.0
* 1 Choice(m0)
* 1 0 Leaf 1.0
* 1 1 Leaf 2.0
*
* @return size_t
*/
size_t nrAssignments() const;

/**
* @brief Fold a binary function over the tree, returning accumulator.
*
Expand Down
Loading

0 comments on commit 1121ece

Please sign in to comment.