GTSAM_DT_MERGING Flag #1501

Merged
merged 53 commits on Nov 11, 2023

Commits
e114e9f
add nrAssignments method for DecisionTree
varunagrawal Mar 26, 2023
6aa7d66
add unit test showing issue with nrAssignments
varunagrawal Mar 26, 2023
1818695
updated docs to better describe nrAssignments
varunagrawal Mar 29, 2023
0cd36db
Merge branch 'develop' into fix-1496
varunagrawal Jun 7, 2023
73b563a
WIP for debugging nrAssignments issue
varunagrawal Jun 8, 2023
8a8f146
update Unique to be recursive
varunagrawal Jun 8, 2023
ff1ea32
remove unnecessary code
varunagrawal Jun 8, 2023
dbd0a7d
re-enable DecisionTree tests
varunagrawal Jun 8, 2023
68cb724
add new build method to replace create, and let create call Unique
varunagrawal Jun 8, 2023
be70ffc
remove excessive Unique call to improve efficiency
varunagrawal Jun 8, 2023
c3090f0
cleanup
varunagrawal Jun 8, 2023
70ffbf3
mark nrAssignments as const
varunagrawal Jun 8, 2023
2352043
rename GTSAM_DT_NO_PRUNING to GTSAM_DT_NO_MERGING to help with disamb…
varunagrawal Jun 8, 2023
2998820
bottom-up Unique method that works much, much better
varunagrawal Jun 8, 2023
a66e270
print nrAssignments when printing decision trees
varunagrawal Jun 8, 2023
d74e41a
Merge branch 'develop' into decisiontree-improvements
varunagrawal Jun 9, 2023
39cf348
Merge branch 'develop' into fix-1496
varunagrawal Jun 9, 2023
0cb1316
Merge branch 'fix-1496' into decisiontree-improvements
varunagrawal Jun 9, 2023
76568f2
formatting
varunagrawal Jun 9, 2023
29c1816
change to GTSAM_DT_MERGING and expose via CMake
varunagrawal Jun 10, 2023
8959982
remove extra calls to Unique
varunagrawal Jun 14, 2023
88ab371
Merge branch 'develop' into decisiontree-improvements
varunagrawal Jun 22, 2023
7af8e66
Merge branch 'develop' into decisiontree-improvements
varunagrawal Jun 22, 2023
c605a5b
Merge branch 'develop' into fix-1496
varunagrawal Jun 26, 2023
b37fc3f
update DecisionTree::nrAssignments docstring
varunagrawal Jun 26, 2023
3d7163a
Merge branch 'fix-1496' into decisiontree-improvements
varunagrawal Jun 26, 2023
b24f20a
fix tests to work when GTSAM_DT_MERGING=OFF
varunagrawal Jun 26, 2023
8ffddc4
print GTSAM_DT_MERGING cmake config
varunagrawal Jun 26, 2023
e5fea0d
update docstring
varunagrawal Jun 26, 2023
9b7f4b3
fix test case
varunagrawal Jun 28, 2023
8c38e45
enumerate all assignments for computing probabilities to prune
varunagrawal Jun 28, 2023
b86696a
Merge pull request #1542 from borglab/decisiontree-improvements
varunagrawal Jun 28, 2023
647d3c0
remove nrAssignments from the DecisionTree
varunagrawal Jun 28, 2023
2db0828
Revert "remove nrAssignments from the DecisionTree"
varunagrawal Jul 10, 2023
b7deefd
Revert "enumerate all assignments for computing probabilities to prune"
varunagrawal Jul 10, 2023
e5a7bac
Merge pull request #1555 from borglab/remove-nrAssignments
varunagrawal Jul 10, 2023
3fe9f1a
Merge branch 'develop' into fix-1496
varunagrawal Jul 18, 2023
ff7c368
Merge branch 'hybrid-tablefactor-2' into fix-1496
varunagrawal Jul 19, 2023
cf6c1ca
fix tests
varunagrawal Jul 19, 2023
372e703
Merge branch 'develop' into fix-1496
varunagrawal Jul 19, 2023
1dfb388
fix odd behavior in nrAssignments
varunagrawal Jul 20, 2023
ea24a2c
park changes so I can come back to them later
varunagrawal Jul 20, 2023
369d08b
Merge branch 'develop' into fix-1496
varunagrawal Jul 28, 2023
b35fb0f
update tests
varunagrawal Jul 28, 2023
4e9d849
remove prints
varunagrawal Jul 28, 2023
4580c51
undo change
varunagrawal Jul 28, 2023
8cb33dd
remove make_unique flag
varunagrawal Jul 29, 2023
94d737e
remove printing
varunagrawal Jul 29, 2023
4386c51
remove nrAssignments from DecisionTree
varunagrawal Nov 6, 2023
ecd6450
Merge branch 'develop' into fix-1496
varunagrawal Nov 6, 2023
c4d11c4
fix unittest assertion deprecation
varunagrawal Nov 6, 2023
9b67c3a
Merge branch 'develop' into remove-nrAssignments
varunagrawal Nov 6, 2023
fe81362
Merge branch 'fix-1496' into remove-nrAssignments
varunagrawal Nov 6, 2023
1 change: 1 addition & 0 deletions cmake/HandleGeneralOptions.cmake
@@ -31,6 +31,7 @@ option(GTSAM_FORCE_STATIC_LIB "Force gtsam to be a static library,
option(GTSAM_USE_QUATERNIONS "Enable/Disable using an internal Quaternion representation for rotations instead of rotation matrices. If enable, Rot3::EXPMAP is enforced by default." OFF)
option(GTSAM_POSE3_EXPMAP "Enable/Disable using Pose3::EXPMAP as the default mode. If disabled, Pose3::FIRST_ORDER will be used." ON)
option(GTSAM_ROT3_EXPMAP "Ignore if GTSAM_USE_QUATERNIONS is OFF (Rot3::EXPMAP by default). Otherwise, enable Rot3::EXPMAP, or if disabled, use Rot3::CAYLEY." ON)
option(GTSAM_DT_MERGING "Enable/Disable merging of equal leaf nodes in DecisionTrees. This leads to significant speed up and memory savings." ON)
option(GTSAM_ENABLE_CONSISTENCY_CHECKS "Enable/Disable expensive consistency checks" OFF)
option(GTSAM_ENABLE_MEMORY_SANITIZER "Enable/Disable memory sanitizer" OFF)
option(GTSAM_WITH_TBB "Use Intel Threaded Building Blocks (TBB) if available" ON)
1 change: 1 addition & 0 deletions cmake/HandlePrintConfiguration.cmake
@@ -90,6 +90,7 @@ print_enabled_config(${GTSAM_ENABLE_CONSISTENCY_CHECKS} "Runtime consistency c
print_enabled_config(${GTSAM_ENABLE_MEMORY_SANITIZER} "Build with Memory Sanitizer ")
print_enabled_config(${GTSAM_ROT3_EXPMAP} "Rot3 retract is full ExpMap ")
print_enabled_config(${GTSAM_POSE3_EXPMAP} "Pose3 retract is full ExpMap ")
print_enabled_config(${GTSAM_DT_MERGING} "Enable branch merging in DecisionTree")
print_enabled_config(${GTSAM_ALLOW_DEPRECATED_SINCE_V43} "Allow features deprecated in GTSAM 4.3")
print_enabled_config(${GTSAM_SUPPORT_NESTED_DISSECTION} "Metis-based Nested Dissection ")
print_enabled_config(${GTSAM_TANGENT_PREINTEGRATION} "Use tangent-space preintegration")
3 changes: 3 additions & 0 deletions gtsam/config.h.in
@@ -39,6 +39,9 @@
#cmakedefine GTSAM_ROT3_EXPMAP
#endif

// Whether to enable merging of equal leaf nodes in the Discrete Decision Tree.
#cmakedefine GTSAM_DT_MERGING

// Whether we are using TBB (if TBB was found and GTSAM_WITH_TBB is enabled in CMake)
#cmakedefine GTSAM_USE_TBB

142 changes: 81 additions & 61 deletions gtsam/discrete/DecisionTree-inl.h
@@ -53,26 +53,17 @@ namespace gtsam {
/** constant stored in this leaf */
Y constant_;

/** The number of assignments contained within this leaf.
* Particularly useful when leaves have been pruned.
*/
size_t nrAssignments_;

/// Default constructor for serialization.
Leaf() {}

/// Constructor from constant
Leaf(const Y& constant, size_t nrAssignments = 1)
: constant_(constant), nrAssignments_(nrAssignments) {}
Leaf(const Y& constant) : constant_(constant) {}

/// Return the constant
const Y& constant() const {
return constant_;
}

/// Return the number of assignments contained within this leaf.
size_t nrAssignments() const { return nrAssignments_; }

/// Leaf-Leaf equality
bool sameLeaf(const Leaf& q) const override {
return constant_ == q.constant_;
@@ -93,8 +84,7 @@ namespace gtsam {
/// print
void print(const std::string& s, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter) const override {
std::cout << s << " Leaf [" << nrAssignments() << "] "
<< valueFormatter(constant_) << std::endl;
std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl;
}

/** Write graphviz format to stream `os`. */
@@ -114,14 +104,14 @@ namespace gtsam {

/** apply unary operator */
NodePtr apply(const Unary& op) const override {
NodePtr f(new Leaf(op(constant_), nrAssignments_));
NodePtr f(new Leaf(op(constant_)));
return f;
}

/// Apply unary operator with assignment
NodePtr apply(const UnaryAssignment& op,
const Assignment<L>& assignment) const override {
NodePtr f(new Leaf(op(assignment, constant_), nrAssignments_));
NodePtr f(new Leaf(op(assignment, constant_)));
return f;
}

@@ -137,7 +127,7 @@ namespace gtsam {
// Applying binary operator to two leaves results in a leaf
NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override {
// fL op gL
NodePtr h(new Leaf(op(fL.constant_, constant_), nrAssignments_));
NodePtr h(new Leaf(op(fL.constant_, constant_)));
return h;
}

@@ -148,7 +138,7 @@

/** choose a branch, create new memory ! */
NodePtr choose(const L& label, size_t index) const override {
return NodePtr(new Leaf(constant(), nrAssignments()));
return NodePtr(new Leaf(constant()));
}

bool isLeaf() const override { return true; }
@@ -163,7 +153,6 @@
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar& BOOST_SERIALIZATION_NVP(constant_);
ar& BOOST_SERIALIZATION_NVP(nrAssignments_);
}
#endif
}; // Leaf
@@ -199,26 +188,50 @@ namespace gtsam {
#endif
}

/// If all branches of a choice node f are the same, just return a branch.
static NodePtr Unique(const ChoicePtr& f) {
#ifndef GTSAM_DT_NO_PRUNING
if (f->allSame_) {
assert(f->branches().size() > 0);
NodePtr f0 = f->branches_[0];
/**
* @brief Merge branches with equal leaf values for every choice node in a
* decision tree. If all branches are the same (i.e. have the same leaf
* value), replace the choice node with the equivalent leaf node.
*
* This function applies the branch merging (if enabled) recursively on the
* decision tree represented by the root node passed in as the argument. It
* recurses to the leaf nodes and merges branches with equal leaf values in
* a bottom-up fashion.
*
* Thus, if all branches of a choice node `f` are the same,
* just return a single branch at each recursion step.
*
* @param node The root node of the decision tree.
* @return NodePtr
*/
static NodePtr Unique(const NodePtr& node) {
if (auto choice = std::dynamic_pointer_cast<const Choice>(node)) {
// Choice node, we recurse!
// Make non-const copy so we can update
auto f = std::make_shared<Choice>(choice->label(), choice->nrChoices());

// Iterate over all the branches
for (size_t i = 0; i < choice->nrChoices(); i++) {
auto branch = choice->branches_[i];
f->push_back(Unique(branch));
}

size_t nrAssignments = 0;
for(auto branch: f->branches()) {
assert(branch->isLeaf());
nrAssignments +=
std::dynamic_pointer_cast<const Leaf>(branch)->nrAssignments();
#ifdef GTSAM_DT_MERGING
// If all the branches are the same, we can merge them into one
if (f->allSame_) {
assert(f->branches().size() > 0);
NodePtr f0 = f->branches_[0];

NodePtr newLeaf(
new Leaf(std::dynamic_pointer_cast<const Leaf>(f0)->constant()));
return newLeaf;
}
NodePtr newLeaf(
new Leaf(std::dynamic_pointer_cast<const Leaf>(f0)->constant(),
nrAssignments));
return newLeaf;
} else
#endif
return f;
} else {
// Leaf node, return as is
return node;
}
}

bool isLeaf() const override { return false; }
@@ -439,8 +452,10 @@

// second case, not label of interest, just recurse
auto r = std::make_shared<Choice>(label_, branches_.size());
for (auto&& branch : branches_)
for (auto&& branch : branches_) {
r->push_back(branch->choose(label, index));
}

return Unique(r);
}

@@ -464,13 +479,11 @@
// DecisionTree
/****************************************************************************/
template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree() {
}
DecisionTree<L, Y>::DecisionTree() {}

template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const NodePtr& root) :
root_(root) {
}
root_(root) {}

/****************************************************************************/
template<typename L, typename Y>
@@ -586,7 +599,8 @@
auto choiceOnLabel = std::make_shared<Choice>(label, end - begin);
for (Iterator it = begin; it != end; it++)
choiceOnLabel->push_back(it->root_);
return Choice::Unique(choiceOnLabel);
// If no reordering, no need to call Choice::Unique
return choiceOnLabel;
} else {
// Set up a new choice on the highest label
auto choiceOnHighestLabel =
@@ -605,21 +619,21 @@
NodePtr fi = compose(functions.begin(), functions.end(), label);
choiceOnHighestLabel->push_back(fi);
}
return Choice::Unique(choiceOnHighestLabel);
return choiceOnHighestLabel;
}
}

/****************************************************************************/
// "create" is a bit of a complicated thing, but very useful.
// "build" is a bit of a complicated thing, but very useful.
// It takes a range of labels and a corresponding range of values,
// and creates a decision tree, as follows:
// and builds a decision tree, as follows:
// - if there is only one label, creates a choice node with values in leaves
// - otherwise, it evenly splits up the range of values and creates a tree for
// each sub-range, and assigns that tree to first label's choices
// Example:
// create([B A],[1 2 3 4]) would call
// create([A],[1 2])
// create([A],[3 4])
// build([B A],[1 2 3 4]) would call
// build([A],[1 2])
// build([A],[3 4])
// and produce
// B=0
// A=0: 1
Expand All @@ -632,7 +646,7 @@ namespace gtsam {
// However, it will be *way* faster if labels are given highest to lowest.
template<typename L, typename Y>
template<typename It, typename ValueIt>
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::create(
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::build(
It begin, It end, ValueIt beginY, ValueIt endY) const {
// get crucial counts
size_t nrChoices = begin->second;
@@ -650,9 +664,10 @@
throw std::invalid_argument("DecisionTree::create invalid argument");
}
auto choice = std::make_shared<Choice>(begin->first, endY - beginY);
for (ValueIt y = beginY; y != endY; y++)
for (ValueIt y = beginY; y != endY; y++) {
choice->push_back(NodePtr(new Leaf(*y)));
return Choice::Unique(choice);
}
return choice;
}

// Recursive case: perform "Shannon expansion"
@@ -661,12 +676,27 @@
std::vector<DecisionTree> functions;
size_t split = size / nrChoices;
for (size_t i = 0; i < nrChoices; i++, beginY += split) {
NodePtr f = create<It, ValueIt>(labelC, end, beginY, beginY + split);
NodePtr f = build<It, ValueIt>(labelC, end, beginY, beginY + split);
functions.emplace_back(f);
}
return compose(functions.begin(), functions.end(), begin->first);
}

/****************************************************************************/
// Top-level factory method, which takes a range of labels and a corresponding
// range of values, and creates a decision tree.
template<typename L, typename Y>
template<typename It, typename ValueIt>
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::create(
It begin, It end, ValueIt beginY, ValueIt endY) const {
auto node = build(begin, end, beginY, endY);
if (auto choice = std::dynamic_pointer_cast<const Choice>(node)) {
return Choice::Unique(choice);
} else {
return node;
}
}

/****************************************************************************/
template <typename L, typename Y>
template <typename M, typename X>
@@ -681,7 +711,7 @@
// If leaf, apply unary conversion "op" and create a unique leaf.
using MXLeaf = typename DecisionTree<M, X>::Leaf;
if (auto leaf = std::dynamic_pointer_cast<const MXLeaf>(f)) {
return NodePtr(new Leaf(Y_of_X(leaf->constant()), leaf->nrAssignments()));
return NodePtr(new Leaf(Y_of_X(leaf->constant())));
}

// Check if Choice
@@ -699,7 +729,7 @@
for (auto&& branch : choice->branches()) {
functions.emplace_back(convertFrom<M, X>(branch, L_of_M, Y_of_X));
}
return LY::compose(functions.begin(), functions.end(), newLabel);
return Choice::Unique(LY::compose(functions.begin(), functions.end(), newLabel));
}

/****************************************************************************/
@@ -828,16 +858,6 @@
return total;
}

/****************************************************************************/
template <typename L, typename Y>
size_t DecisionTree<L, Y>::nrAssignments() const {
size_t n = 0;
this->visitLeaf([&n](const DecisionTree<L, Y>::Leaf& leaf) {
n += leaf.nrAssignments();
});
return n;
}

/****************************************************************************/
// fold is just done with a visit
template <typename L, typename Y>
48 changes: 10 additions & 38 deletions gtsam/discrete/DecisionTree.h
@@ -150,11 +150,19 @@ namespace gtsam {
NodePtr root_;

protected:
/**
/**
* Internal recursive function to create from keys, cardinalities,
* and Y values
*/
template<typename It, typename ValueIt>
template <typename It, typename ValueIt>
NodePtr build(It begin, It end, ValueIt beginY, ValueIt endY) const;

/** Internal helper function to create from
* keys, cardinalities, and Y values.
* Calls `build` which builds the tree bottom-up,
* before we prune in a top-down fashion.
*/
template <typename It, typename ValueIt>
NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const;

/**
@@ -320,42 +328,6 @@
/// Return the number of leaves in the tree.
size_t nrLeaves() const;

/**
* @brief This is a convenience function which returns the total number of
* leaf assignments in the decision tree.
* This function is not used for anymajor operations within the discrete
* factor graph framework.
*
* Leaf assignments represent the cardinality of each leaf node, e.g. in a
* binary tree each leaf has 2 assignments. This includes counts removed
* from implicit pruning hence, it will always be >= nrLeaves().
*
* E.g. we have a decision tree as below, where each node has 2 branches:
*
* Choice(m1)
* 0 Choice(m0)
* 0 0 Leaf 0.0
* 0 1 Leaf 0.0
* 1 Choice(m0)
* 1 0 Leaf 1.0
* 1 1 Leaf 2.0
*
* In the unpruned form, the tree will have 4 assignments, 2 for each key,
* and 4 leaves.
*
* In the pruned form, the number of assignments is still 4 but the number
* of leaves is now 3, as below:
*
* Choice(m1)
* 0 Leaf 0.0
* 1 Choice(m0)
* 1 0 Leaf 1.0
* 1 1 Leaf 2.0
*
* @return size_t
*/
size_t nrAssignments() const;

/**
* @brief Fold a binary function over the tree, returning accumulator.
*