Skip to content

Commit

Permalink
Merge pull request #1875 from borglab/feature/fasterDT
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal authored Oct 15, 2024
2 parents e6dfa7b + 935df2b commit db353a5
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 45 deletions.
2 changes: 2 additions & 0 deletions gtsam/discrete/AlgebraicDecisionTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@
#include <gtsam/discrete/DecisionTree-inl.h>

#include <algorithm>
#include <limits>
#include <map>
#include <string>
#include <iomanip>
#include <vector>

namespace gtsam {

/**
Expand Down
114 changes: 76 additions & 38 deletions gtsam/discrete/DecisionTree-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,15 @@
#include <gtsam/discrete/DecisionTree.h>

#include <algorithm>

#include <cmath>
#include <cassert>
#include <fstream>
#include <list>
#include <iterator>
#include <map>
#include <optional>
#include <set>
#include <sstream>
#include <string>
#include <vector>
#include <optional>
#include <cassert>
#include <iterator>

namespace gtsam {

Expand Down Expand Up @@ -251,22 +248,28 @@ namespace gtsam {
label_ = f.label();
size_t count = f.nrChoices();
branches_.reserve(count);
for (size_t i = 0; i < count; i++)
push_back(f.branches_[i]->apply_f_op_g(g, op));
for (size_t i = 0; i < count; i++) {
NodePtr newBranch = f.branches_[i]->apply_f_op_g(g, op);
push_back(std::move(newBranch));
}
} else if (g.label() > f.label()) {
// f lower than g
label_ = g.label();
size_t count = g.nrChoices();
branches_.reserve(count);
for (size_t i = 0; i < count; i++)
push_back(g.branches_[i]->apply_g_op_fC(f, op));
for (size_t i = 0; i < count; i++) {
NodePtr newBranch = g.branches_[i]->apply_g_op_fC(f, op);
push_back(std::move(newBranch));
}
} else {
// f same level as g
label_ = f.label();
size_t count = f.nrChoices();
branches_.reserve(count);
for (size_t i = 0; i < count; i++)
push_back(f.branches_[i]->apply_f_op_g(*g.branches_[i], op));
for (size_t i = 0; i < count; i++) {
NodePtr newBranch = f.branches_[i]->apply_f_op_g(*g.branches_[i], op);
push_back(std::move(newBranch));
}
}
}

Expand All @@ -284,12 +287,12 @@ namespace gtsam {
}

/** add a branch: TODO merge into constructor */
void push_back(const NodePtr& node) {
void push_back(NodePtr&& node) {
// allSame_ is restricted to leaf nodes in a decision tree
if (allSame_ && !branches_.empty()) {
allSame_ = node->sameLeaf(*branches_.back());
}
branches_.push_back(node);
branches_.push_back(std::move(node));
}

/// print (as a tree).
Expand Down Expand Up @@ -497,9 +500,9 @@ namespace gtsam {
DecisionTree<L, Y>::DecisionTree(const L& label, const Y& y1, const Y& y2) {
auto a = std::make_shared<Choice>(label, 2);
NodePtr l1(new Leaf(y1)), l2(new Leaf(y2));
a->push_back(l1);
a->push_back(l2);
root_ = Choice::Unique(a);
a->push_back(std::move(l1));
a->push_back(std::move(l2));
root_ = Choice::Unique(std::move(a));
}

/****************************************************************************/
Expand All @@ -510,11 +513,10 @@ namespace gtsam {
"DecisionTree: binary constructor called with non-binary label");
auto a = std::make_shared<Choice>(labelC.first, 2);
NodePtr l1(new Leaf(y1)), l2(new Leaf(y2));
a->push_back(l1);
a->push_back(l2);
root_ = Choice::Unique(a);
a->push_back(std::move(l1));
a->push_back(std::move(l2));
root_ = Choice::Unique(std::move(a));
}

/****************************************************************************/
template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs,
Expand Down Expand Up @@ -557,9 +559,7 @@ namespace gtsam {
template <typename X, typename Func>
DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other,
Func Y_of_X) {
// Define functor for identity mapping of node label.
auto L_of_L = [](const L& label) { return label; };
root_ = convertFrom<L, X>(other.root_, L_of_L, Y_of_X);
root_ = convertFrom<X>(other.root_, Y_of_X);
}

/****************************************************************************/
Expand All @@ -580,7 +580,7 @@ namespace gtsam {
template <typename L, typename Y>
template <typename Iterator>
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::compose(
Iterator begin, Iterator end, const L& label) const {
Iterator begin, Iterator end, const L& label) {
// find highest label among branches
std::optional<L> highestLabel;
size_t nrChoices = 0;
Expand All @@ -598,8 +598,10 @@ namespace gtsam {
// if label is already in correct order, just put together a choice on label
if (!nrChoices || !highestLabel || label > *highestLabel) {
auto choiceOnLabel = std::make_shared<Choice>(label, end - begin);
for (Iterator it = begin; it != end; it++)
choiceOnLabel->push_back(it->root_);
for (Iterator it = begin; it != end; it++) {
NodePtr root = it->root_;
choiceOnLabel->push_back(std::move(root));
}
// If no reordering, no need to call Choice::Unique
return choiceOnLabel;
} else {
Expand All @@ -618,7 +620,7 @@ namespace gtsam {
}
// We then recurse, for all values of the highest label
NodePtr fi = compose(functions.begin(), functions.end(), label);
choiceOnHighestLabel->push_back(fi);
choiceOnHighestLabel->push_back(std::move(fi));
}
return choiceOnHighestLabel;
}
Expand Down Expand Up @@ -648,7 +650,7 @@ namespace gtsam {
template<typename L, typename Y>
template<typename It, typename ValueIt>
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::build(
It begin, It end, ValueIt beginY, ValueIt endY) const {
It begin, It end, ValueIt beginY, ValueIt endY) {
// get crucial counts
size_t nrChoices = begin->second;
size_t size = endY - beginY;
Expand All @@ -675,6 +677,7 @@ namespace gtsam {
// Creates one tree (i.e., function) for each choice of the current key
// by calling build recursively, and then puts them all together.
std::vector<DecisionTree> functions;
functions.reserve(nrChoices);
size_t split = size / nrChoices;
for (size_t i = 0; i < nrChoices; i++, beginY += split) {
NodePtr f = build<It, ValueIt>(labelC, end, beginY, beginY + split);
Expand All @@ -689,7 +692,7 @@ namespace gtsam {
template<typename L, typename Y>
template<typename It, typename ValueIt>
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::create(
It begin, It end, ValueIt beginY, ValueIt endY) const {
It begin, It end, ValueIt beginY, ValueIt endY) {
auto node = build(begin, end, beginY, endY);
if (auto choice = std::dynamic_pointer_cast<const Choice>(node)) {
return Choice::Unique(choice);
Expand All @@ -698,17 +701,44 @@ namespace gtsam {
}
}

/****************************************************************************/
/**
 * Recursively convert a DecisionTree<L, X> node into a DecisionTree<L, Y>
 * node, keeping every label as-is and mapping only the leaf values.
 *
 * @tparam X Value type stored in the leaves of the source tree.
 * @param f Node of the source tree to convert.
 * @param Y_of_X Functor applied to each source leaf constant to obtain the
 *        value of the corresponding new leaf.
 * @return The converted node: a new Leaf when `f` is a leaf, otherwise the
 *         result of Choice::Unique on the rebuilt choice node.
 * @throws std::invalid_argument if `f` is neither a Leaf nor a Choice of
 *         DecisionTree<L, X>.
 */
template <typename L, typename Y>
template <typename X>
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convertFrom(
    const typename DecisionTree<L, X>::NodePtr& f,
    std::function<Y(const X&)> Y_of_X) {
  using SourceLeaf = typename DecisionTree<L, X>::Leaf;
  using SourceChoice = typename DecisionTree<L, X>::Choice;

  // Base case: a leaf is converted by mapping its constant through Y_of_X.
  if (auto sourceLeaf = std::dynamic_pointer_cast<const SourceLeaf>(f)) {
    return NodePtr(new Leaf(Y_of_X(sourceLeaf->constant())));
  }

  // Anything that is not a leaf has to be a choice node.
  auto sourceChoice = std::dynamic_pointer_cast<const SourceChoice>(f);
  if (!sourceChoice) {
    throw std::invalid_argument("DecisionTree::convertFrom: Invalid NodePtr");
  }

  // Labels are not translated, so the new choice node reuses the source label
  // and its branches are converted one by one, in the same order.
  auto converted = std::make_shared<Choice>(sourceChoice->label(),
                                            sourceChoice->nrChoices());
  for (const auto& child : sourceChoice->branches()) {
    converted->push_back(convertFrom<X>(child, Y_of_X));
  }

  return Choice::Unique(converted);
}

/****************************************************************************/
template <typename L, typename Y>
template <typename M, typename X>
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convertFrom(
const typename DecisionTree<M, X>::NodePtr& f,
std::function<L(const M&)> L_of_M,
std::function<Y(const X&)> Y_of_X) const {
std::function<L(const M&)> L_of_M, std::function<Y(const X&)> Y_of_X) {
using LY = DecisionTree<L, Y>;

// Ugliness below because apparently we can't have templated virtual
// functions.
// If leaf, apply the unary conversion Y_of_X and create a unique leaf.
using MXLeaf = typename DecisionTree<M, X>::Leaf;
if (auto leaf = std::dynamic_pointer_cast<const MXLeaf>(f)) {
Expand All @@ -718,19 +748,27 @@ namespace gtsam {
// Check if Choice
using MXChoice = typename DecisionTree<M, X>::Choice;
auto choice = std::dynamic_pointer_cast<const MXChoice>(f);
if (!choice) throw std::invalid_argument(
"DecisionTree::convertFrom: Invalid NodePtr");
if (!choice)
throw std::invalid_argument("DecisionTree::convertFrom: Invalid NodePtr");

// get new label
const M oldLabel = choice->label();
const L newLabel = L_of_M(oldLabel);

// put together via Shannon expansion otherwise not sorted.
// Shannon expansion in this context involves:
// 1. Creating separate subtrees (functions) for each possible value of the new label.
// 2. Combining these subtrees using the 'compose' method, which implements the expansion.
// This approach guarantees that the resulting tree maintains the correct variable ordering
// based on the new labels (L) after translation from the old labels (M).
// Simply creating a Choice node here would not work because it wouldn't account for the
// potentially new ordering of variables resulting from the label translation,
// which is crucial for maintaining consistency and efficiency in the converted tree.
std::vector<LY> functions;
for (auto&& branch : choice->branches()) {
functions.emplace_back(convertFrom<M, X>(branch, L_of_M, Y_of_X));
}
return Choice::Unique(LY::compose(functions.begin(), functions.end(), newLabel));
return Choice::Unique(
LY::compose(functions.begin(), functions.end(), newLabel));
}

/****************************************************************************/
Expand Down
26 changes: 19 additions & 7 deletions gtsam/discrete/DecisionTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
#include <iostream>
#include <map>
#include <set>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
Expand Down Expand Up @@ -155,15 +154,28 @@ namespace gtsam {
* and Y values
*/
template <typename It, typename ValueIt>
NodePtr build(It begin, It end, ValueIt beginY, ValueIt endY) const;
static NodePtr build(It begin, It end, ValueIt beginY, ValueIt endY);

/** Internal helper function to create from
* keys, cardinalities, and Y values.
* Calls `build` which builds the tree bottom-up,
* before we prune in a top-down fashion.
*/
template <typename It, typename ValueIt>
NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const;
static NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY);

/**
* @brief Convert from a DecisionTree<L, X> to DecisionTree<L, Y>.
* The label type L is unchanged; only the leaf values are converted.
*
* @tparam X The previous value type.
* @param f The node pointer to the root of the previous DecisionTree.
* @param Y_of_X Functor to convert from value type X to type Y.
* @return NodePtr
*/
template <typename X>
static NodePtr convertFrom(const typename DecisionTree<L, X>::NodePtr& f,
std::function<Y(const X&)> Y_of_X);

/**
* @brief Convert from a DecisionTree<M, X> to DecisionTree<L, Y>.
Expand All @@ -176,9 +188,9 @@ namespace gtsam {
* @return NodePtr
*/
template <typename M, typename X>
NodePtr convertFrom(const typename DecisionTree<M, X>::NodePtr& f,
std::function<L(const M&)> L_of_M,
std::function<Y(const X&)> Y_of_X) const;
static NodePtr convertFrom(const typename DecisionTree<M, X>::NodePtr& f,
std::function<L(const M&)> L_of_M,
std::function<Y(const X&)> Y_of_X);

public:
/// @name Standard Constructors
Expand Down Expand Up @@ -402,7 +414,7 @@ namespace gtsam {

// internal use only
template<typename Iterator> NodePtr
compose(Iterator begin, Iterator end, const L& label) const;
static compose(Iterator begin, Iterator end, const L& label);

/// @}

Expand Down

0 comments on commit db353a5

Please sign in to comment.