Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Decision Tree Improvements #1542

Merged
merged 16 commits into from
Jun 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cmake/HandleGeneralOptions.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ option(GTSAM_FORCE_STATIC_LIB "Force gtsam to be a static library,
option(GTSAM_USE_QUATERNIONS "Enable/Disable using an internal Quaternion representation for rotations instead of rotation matrices. If enable, Rot3::EXPMAP is enforced by default." OFF)
option(GTSAM_POSE3_EXPMAP "Enable/Disable using Pose3::EXPMAP as the default mode. If disabled, Pose3::FIRST_ORDER will be used." ON)
option(GTSAM_ROT3_EXPMAP "Ignore if GTSAM_USE_QUATERNIONS is OFF (Rot3::EXPMAP by default). Otherwise, enable Rot3::EXPMAP, or if disabled, use Rot3::CAYLEY." ON)
option(GTSAM_DT_MERGING "Enable/Disable merging of equal leaf nodes in DecisionTrees. This leads to significant speed up and memory savings." ON)
option(GTSAM_ENABLE_CONSISTENCY_CHECKS "Enable/Disable expensive consistency checks" OFF)
option(GTSAM_ENABLE_MEMORY_SANITIZER "Enable/Disable memory sanitizer" OFF)
option(GTSAM_WITH_TBB "Use Intel Threaded Building Blocks (TBB) if available" ON)
Expand Down
1 change: 1 addition & 0 deletions cmake/HandlePrintConfiguration.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ print_enabled_config(${GTSAM_ENABLE_CONSISTENCY_CHECKS} "Runtime consistency c
print_enabled_config(${GTSAM_ENABLE_MEMORY_SANITIZER} "Build with Memory Sanitizer ")
print_enabled_config(${GTSAM_ROT3_EXPMAP} "Rot3 retract is full ExpMap ")
print_enabled_config(${GTSAM_POSE3_EXPMAP} "Pose3 retract is full ExpMap ")
print_enabled_config(${GTSAM_DT_MERGING} "Enable branch merging in DecisionTree")
print_enabled_config(${GTSAM_ALLOW_DEPRECATED_SINCE_V43} "Allow features deprecated in GTSAM 4.3")
print_enabled_config(${GTSAM_SUPPORT_NESTED_DISSECTION} "Metis-based Nested Dissection ")
print_enabled_config(${GTSAM_TANGENT_PREINTEGRATION} "Use tangent-space preintegration")
Expand Down
3 changes: 3 additions & 0 deletions gtsam/config.h.in
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@
#cmakedefine GTSAM_ROT3_EXPMAP
#endif

// Whether to enable merging of equal leaf nodes in the Discrete Decision Tree.
#cmakedefine GTSAM_DT_MERGING

// Whether we are using TBB (if TBB was found and GTSAM_WITH_TBB is enabled in CMake)
#cmakedefine GTSAM_USE_TBB

Expand Down
103 changes: 57 additions & 46 deletions gtsam/discrete/DecisionTree-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ namespace gtsam {
/// print
void print(const std::string& s, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter) const override {
std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl;
std::cout << s << " Leaf [" << nrAssignments() << "]"
<< valueFormatter(constant_) << std::endl;
}

/** Write graphviz format to stream `os`. */
Expand Down Expand Up @@ -136,7 +137,9 @@ namespace gtsam {
// Applying binary operator to two leaves results in a leaf
NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override {
// fL op gL
NodePtr h(new Leaf(op(fL.constant_, constant_), nrAssignments_));
// TODO(Varun) nrAssignments setting is not correct.
// Depending on f and g, the nrAssignments can be different. This is a bug!
NodePtr h(new Leaf(op(fL.constant_, constant_), fL.nrAssignments()));
return h;
}

Expand Down Expand Up @@ -198,48 +201,57 @@ namespace gtsam {
#endif
}

/// If all branches of a choice node f are the same, just return a branch.
static NodePtr Unique(const ChoicePtr& f) {
#ifndef GTSAM_DT_NO_PRUNING
// If all the branches are the same, we can merge them into one
if (f->allSame_) {
assert(f->branches().size() > 0);
NodePtr f0 = f->branches_[0];

size_t nrAssignments = 0;
for(auto branch: f->branches()) {
if (auto leaf = std::dynamic_pointer_cast<const Leaf>(branch)) {
nrAssignments += leaf->nrAssignments();
}
}
NodePtr newLeaf(
new Leaf(std::dynamic_pointer_cast<const Leaf>(f0)->constant(),
nrAssignments));
return newLeaf;

} else
// Else we recurse
#endif
{

// Make non-const copy
auto ff = std::make_shared<Choice>(f->label(), f->nrChoices());
/**
* @brief Merge branches with equal leaf values for every choice node in a
* decision tree. If all branches are the same (i.e. have the same leaf
* value), replace the choice node with the equivalent leaf node.
*
* This function applies the branch merging (if enabled) recursively on the
* decision tree represented by the root node passed in as the argument. It
* recurses to the leaf nodes and merges branches with equal leaf values in
* a bottom-up fashion.
*
* Thus, if all branches of a choice node `f` are the same,
* just return a single branch at each recursion step.
*
* @param node The root node of the decision tree.
* @return NodePtr
*/
static NodePtr Unique(const NodePtr& node) {
if (auto choice = std::dynamic_pointer_cast<const Choice>(node)) {
// Choice node, we recurse!
// Make non-const copy so we can update
auto f = std::make_shared<Choice>(choice->label(), choice->nrChoices());

// Iterate over all the branches
for (size_t i = 0; i < f->nrChoices(); i++) {
auto branch = f->branches_[i];
if (auto leaf = std::dynamic_pointer_cast<const Leaf>(branch)) {
// Leaf node, simply assign
ff->push_back(branch);

} else if (auto choice =
std::dynamic_pointer_cast<const Choice>(branch)) {
// Choice node, we recurse
ff->push_back(Unique(choice));
}
for (size_t i = 0; i < choice->nrChoices(); i++) {
auto branch = choice->branches_[i];
f->push_back(Unique(branch));
}

return ff;
#ifdef GTSAM_DT_MERGING
// If all the branches are the same, we can merge them into one
if (f->allSame_) {
assert(f->branches().size() > 0);
NodePtr f0 = f->branches_[0];

// Compute total number of assignments
size_t nrAssignments = 0;
for (auto branch : f->branches()) {
if (auto leaf = std::dynamic_pointer_cast<const Leaf>(branch)) {
nrAssignments += leaf->nrAssignments();
}
}
NodePtr newLeaf(
new Leaf(std::dynamic_pointer_cast<const Leaf>(f0)->constant(),
nrAssignments));
return newLeaf;
}
#endif
return f;
} else {
// Leaf node, return as is
return node;
}
}

Expand Down Expand Up @@ -486,13 +498,11 @@ namespace gtsam {
// DecisionTree
/****************************************************************************/
template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree() {
}
DecisionTree<L, Y>::DecisionTree() {}

template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const NodePtr& root) :
root_(root) {
}
root_(root) {}

/****************************************************************************/
template<typename L, typename Y>
Expand Down Expand Up @@ -608,7 +618,8 @@ namespace gtsam {
auto choiceOnLabel = std::make_shared<Choice>(label, end - begin);
for (Iterator it = begin; it != end; it++)
choiceOnLabel->push_back(it->root_);
return Choice::Unique(choiceOnLabel);
// If no reordering, no need to call Choice::Unique
return choiceOnLabel;
} else {
// Set up a new choice on the highest label
auto choiceOnHighestLabel =
Expand Down Expand Up @@ -737,7 +748,7 @@ namespace gtsam {
for (auto&& branch : choice->branches()) {
functions.emplace_back(convertFrom<M, X>(branch, L_of_M, Y_of_X));
}
return LY::compose(functions.begin(), functions.end(), newLabel);
return Choice::Unique(LY::compose(functions.begin(), functions.end(), newLabel));
}

/****************************************************************************/
Expand Down
2 changes: 1 addition & 1 deletion gtsam/discrete/DecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ namespace gtsam {
// Get the probabilities in the decision tree so we can threshold.
std::vector<double> probabilities;
this->visitLeaf([&](const Leaf& leaf) {
size_t nrAssignments = leaf.nrAssignments();
const size_t nrAssignments = leaf.nrAssignments();
double prob = leaf.constant();
probabilities.insert(probabilities.end(), nrAssignments, prob);
});
Expand Down
21 changes: 20 additions & 1 deletion gtsam/discrete/tests/testAlgebraicDecisionTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
#include <gtsam/discrete/DiscreteKey.h> // make sure we have traits
#include <gtsam/discrete/DiscreteValues.h>
// headers first to make sure no missing headers
//#define GTSAM_DT_NO_PRUNING
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DecisionTree-inl.h> // for convert only
#define DISABLE_TIMING
Expand Down Expand Up @@ -179,7 +178,11 @@ TEST(ADT, joint) {
dot(joint, "Asia-ASTLBEX");
joint = apply(joint, pD, &mul);
dot(joint, "Asia-ASTLBEXD");
#ifdef GTSAM_DT_MERGING
EXPECT_LONGS_EQUAL(346, muls);
#else
EXPECT_LONGS_EQUAL(508, muls);
#endif
gttoc_(asiaJoint);
tictoc_getNode(asiaJointNode, asiaJoint);
elapsed = asiaJointNode->secs() + asiaJointNode->wall();
Expand Down Expand Up @@ -240,7 +243,11 @@ TEST(ADT, inference) {
dot(joint, "Joint-Product-ASTLBEX");
joint = apply(joint, pD, &mul);
dot(joint, "Joint-Product-ASTLBEXD");
#ifdef GTSAM_DT_MERGING
EXPECT_LONGS_EQUAL(370, (long)muls); // different ordering
#else
EXPECT_LONGS_EQUAL(508, (long)muls); // different ordering
#endif
gttoc_(asiaProd);
tictoc_getNode(asiaProdNode, asiaProd);
elapsed = asiaProdNode->secs() + asiaProdNode->wall();
Expand All @@ -258,7 +265,11 @@ TEST(ADT, inference) {
dot(marginal, "Joint-Sum-ADBLE");
marginal = marginal.combine(E, &add_);
dot(marginal, "Joint-Sum-ADBL");
#ifdef GTSAM_DT_MERGING
EXPECT_LONGS_EQUAL(161, (long)adds);
#else
EXPECT_LONGS_EQUAL(240, (long)adds);
#endif
gttoc_(asiaSum);
tictoc_getNode(asiaSumNode, asiaSum);
elapsed = asiaSumNode->secs() + asiaSumNode->wall();
Expand Down Expand Up @@ -296,7 +307,11 @@ TEST(ADT, factor_graph) {
fg = apply(fg, pX, &mul);
fg = apply(fg, pD, &mul);
dot(fg, "FactorGraph");
#ifdef GTSAM_DT_MERGING
EXPECT_LONGS_EQUAL(158, (long)muls);
#else
EXPECT_LONGS_EQUAL(188, (long)muls);
#endif
gttoc_(asiaFG);
tictoc_getNode(asiaFGNode, asiaFG);
elapsed = asiaFGNode->secs() + asiaFGNode->wall();
Expand All @@ -315,7 +330,11 @@ TEST(ADT, factor_graph) {
dot(fg, "Marginalized-3E");
fg = fg.combine(L, &add_);
dot(fg, "Marginalized-2L");
#ifdef GTSAM_DT_MERGING
LONGS_EQUAL(49, adds);
#else
LONGS_EQUAL(62, adds);
#endif
gttoc_(marg);
tictoc_getNode(margNode, marg);
elapsed = margNode->secs() + margNode->wall();
Expand Down
Loading