From f98b9223e8cbd8a39fee2d91d6459dae1542c3b7 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Tue, 15 Oct 2024 15:43:03 +0900 Subject: [PATCH 1/5] Make compose and convertFrom static --- gtsam/discrete/DecisionTree-inl.h | 28 ++++++++++++++++------------ gtsam/discrete/DecisionTree.h | 8 ++++---- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 27e98fcdec..8be5efaa69 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -580,7 +580,7 @@ namespace gtsam { template template typename DecisionTree::NodePtr DecisionTree::compose( - Iterator begin, Iterator end, const L& label) const { + Iterator begin, Iterator end, const L& label) { // find highest label among branches std::optional highestLabel; size_t nrChoices = 0; @@ -703,12 +703,9 @@ namespace gtsam { template typename DecisionTree::NodePtr DecisionTree::convertFrom( const typename DecisionTree::NodePtr& f, - std::function L_of_M, - std::function Y_of_X) const { + std::function L_of_M, std::function Y_of_X) { using LY = DecisionTree; - // Ugliness below because apparently we can't have templated virtual - // functions. // If leaf, apply unary conversion "op" and create a unique leaf. using MXLeaf = typename DecisionTree::Leaf; if (auto leaf = std::dynamic_pointer_cast(f)) { @@ -718,19 +715,27 @@ namespace gtsam { // Check if Choice using MXChoice = typename DecisionTree::Choice; auto choice = std::dynamic_pointer_cast(f); - if (!choice) throw std::invalid_argument( - "DecisionTree::convertFrom: Invalid NodePtr"); + if (!choice) + throw std::invalid_argument("DecisionTree::convertFrom: Invalid NodePtr"); // get new label const M oldLabel = choice->label(); const L newLabel = L_of_M(oldLabel); - // put together via Shannon expansion otherwise not sorted. + // Shannon expansion in this context involves: + // 1. Creating separate subtrees (functions) for each possible value of the new label. + // 2. Combining these subtrees using the 'compose' method, which implements the expansion. + // This approach guarantees that the resulting tree maintains the correct variable ordering + // based on the new labels (L) after translation from the old labels (M). + // Simply creating a Choice node here would not work because it wouldn't account for the + // potentially new ordering of variables resulting from the label translation, + // which is crucial for maintaining consistency and efficiency in the converted tree. std::vector functions; for (auto&& branch : choice->branches()) { functions.emplace_back(convertFrom(branch, L_of_M, Y_of_X)); } - return Choice::Unique(LY::compose(functions.begin(), functions.end(), newLabel)); + return Choice::Unique( + LY::compose(functions.begin(), functions.end(), newLabel)); } /****************************************************************************/ @@ -740,9 +745,8 @@ namespace gtsam { * * NOTE: We differentiate between leaves and assignments. Concretely, a 3 * binary variable tree will have 2^3=8 assignments, but based on pruning, it - * can have less than 8 leaves. For example, if a tree has all assignment - * values as 1, then pruning will cause the tree to have only 1 leaf yet 8 - * assignments. + * can have <8 leaves. For example, if a tree has all assignment values as 1, + * then pruning will cause the tree to have only 1 leaf yet 8 assignments. */ template struct Visit { diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index 6d6179a7e6..0d9db1fce1 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -176,9 +176,9 @@ namespace gtsam { * @return NodePtr */ template - NodePtr convertFrom(const typename DecisionTree::NodePtr& f, - std::function L_of_M, - std::function Y_of_X) const; + static NodePtr convertFrom(const typename DecisionTree::NodePtr& f, + std::function L_of_M, + std::function Y_of_X); public: /// @name Standard Constructors @@ -402,7 +402,7 @@ namespace gtsam { // internal use only template NodePtr - compose(Iterator begin, Iterator end, const L& label) const; + static compose(Iterator begin, Iterator end, const L& label); /// @} From 6a5dd60d33d1f9ec78fd2e6c4136d4124e33bd9d Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Tue, 15 Oct 2024 15:51:21 +0900 Subject: [PATCH 2/5] Faster version of convertFrom when no label translation needed --- gtsam/discrete/DecisionTree-inl.h | 39 +++++++++++++++++++++++++++---- gtsam/discrete/DecisionTree.h | 13 +++++++++++ 2 files changed, 47 insertions(+), 5 deletions(-) diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 8be5efaa69..6f19574fc8 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -557,9 +557,7 @@ namespace gtsam { template DecisionTree::DecisionTree(const DecisionTree& other, Func Y_of_X) { - // Define functor for identity mapping of node label. - auto L_of_L = [](const L& label) { return label; }; - root_ = convertFrom(other.root_, L_of_L, Y_of_X); + root_ = convertFrom(other.root_, Y_of_X); } /****************************************************************************/ @@ -698,6 +696,36 @@ namespace gtsam { } } + /****************************************************************************/ + template + template + typename DecisionTree::NodePtr DecisionTree::convertFrom( + const typename DecisionTree::NodePtr& f, + std::function Y_of_X) { + + // If leaf, apply unary conversion "op" and create a unique leaf. + using LXLeaf = typename DecisionTree::Leaf; + if (auto leaf = std::dynamic_pointer_cast(f)) { + return NodePtr(new Leaf(Y_of_X(leaf->constant()))); + } + + // Check if Choice + using LXChoice = typename DecisionTree::Choice; + auto choice = std::dynamic_pointer_cast(f); + if (!choice) throw std::invalid_argument( + "DecisionTree::convertFrom: Invalid NodePtr"); + + // Create a new Choice node with the same label + auto newChoice = std::make_shared(choice->label(), choice->nrChoices()); + + // Convert each branch recursively + for (auto&& branch : choice->branches()) { + newChoice->push_back(convertFrom(branch, Y_of_X)); + } + + return Choice::Unique(newChoice); + } + /****************************************************************************/ template template @@ -745,8 +773,9 @@ namespace gtsam { * * NOTE: We differentiate between leaves and assignments. Concretely, a 3 * binary variable tree will have 2^3=8 assignments, but based on pruning, it - * can have <8 leaves. For example, if a tree has all assignment values as 1, - * then pruning will cause the tree to have only 1 leaf yet 8 assignments. + * can have less than 8 leaves. For example, if a tree has all assignment + * values as 1, then pruning will cause the tree to have only 1 leaf yet 8 + * assignments. */ template struct Visit { diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index 0d9db1fce1..c1d7ea05f7 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -165,6 +165,19 @@ namespace gtsam { template NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const; + /** + * @brief Convert from a DecisionTree to DecisionTree. + * + * @tparam M The previous label type. + * @tparam X The previous value type. + * @param f The node pointer to the root of the previous DecisionTree. + * @param Y_of_X Functor to convert from value type X to type Y. + * @return NodePtr + */ + template + static NodePtr convertFrom(const typename DecisionTree::NodePtr& f, + std::function Y_of_X); + /** * @brief Convert from a DecisionTree to DecisionTree. * From ba7674d5fb6f45da39eedd2f034ed7652ee6f308 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Tue, 15 Oct 2024 16:35:50 +0900 Subject: [PATCH 3/5] Move semantics --- gtsam/discrete/DecisionTree-inl.h | 53 +++++++++++++++++-------------- gtsam/discrete/DecisionTree.h | 1 - 2 files changed, 29 insertions(+), 25 deletions(-) diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 6f19574fc8..7cb470c530 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -22,18 +22,15 @@ #include #include - -#include +#include #include -#include +#include #include +#include #include #include #include #include -#include -#include -#include namespace gtsam { @@ -251,22 +248,28 @@ namespace gtsam { label_ = f.label(); size_t count = f.nrChoices(); branches_.reserve(count); - for (size_t i = 0; i < count; i++) - push_back(f.branches_[i]->apply_f_op_g(g, op)); + for (size_t i = 0; i < count; i++) { + NodePtr newBranch = f.branches_[i]->apply_f_op_g(g, op); + push_back(std::move(newBranch)); + } } else if (g.label() > f.label()) { // f lower than g label_ = g.label(); size_t count = g.nrChoices(); branches_.reserve(count); - for (size_t i = 0; i < count; i++) - push_back(g.branches_[i]->apply_g_op_fC(f, op)); + for (size_t i = 0; i < count; i++) { + NodePtr newBranch = g.branches_[i]->apply_g_op_fC(f, op); + push_back(std::move(newBranch)); + } } else { // f same level as g label_ = f.label(); size_t count = f.nrChoices(); branches_.reserve(count); - for (size_t i = 0; i < count; i++) - push_back(f.branches_[i]->apply_f_op_g(*g.branches_[i], op)); + for (size_t i = 0; i < count; i++) { + NodePtr newBranch = f.branches_[i]->apply_f_op_g(*g.branches_[i], op); + push_back(std::move(newBranch)); + } } } @@ -284,12 +287,12 @@ namespace gtsam { } /** add a branch: TODO merge into constructor */ - void push_back(const NodePtr& node) { + void push_back(NodePtr&& node) { // allSame_ is restricted to leaf nodes in a decision tree if (allSame_ && !branches_.empty()) { allSame_ = node->sameLeaf(*branches_.back()); } - branches_.push_back(node); + branches_.push_back(std::move(node)); } /// print (as a tree). @@ -497,9 +500,9 @@ namespace gtsam { DecisionTree::DecisionTree(const L& label, const Y& y1, const Y& y2) { auto a = std::make_shared(label, 2); NodePtr l1(new Leaf(y1)), l2(new Leaf(y2)); - a->push_back(l1); - a->push_back(l2); - root_ = Choice::Unique(a); + a->push_back(std::move(l1)); + a->push_back(std::move(l2)); + root_ = Choice::Unique(std::move(a)); } /****************************************************************************/ @@ -510,11 +513,10 @@ namespace gtsam { "DecisionTree: binary constructor called with non-binary label"); auto a = std::make_shared(labelC.first, 2); NodePtr l1(new Leaf(y1)), l2(new Leaf(y2)); - a->push_back(l1); - a->push_back(l2); - root_ = Choice::Unique(a); + a->push_back(std::move(l1)); + a->push_back(std::move(l2)); + root_ = Choice::Unique(std::move(a)); } - /****************************************************************************/ template DecisionTree::DecisionTree(const std::vector& labelCs, @@ -596,8 +598,10 @@ namespace gtsam { // if label is already in correct order, just put together a choice on label if (!nrChoices || !highestLabel || label > *highestLabel) { auto choiceOnLabel = std::make_shared(label, end - begin); - for (Iterator it = begin; it != end; it++) - choiceOnLabel->push_back(it->root_); + for (Iterator it = begin; it != end; it++) { + NodePtr root = it->root_; + choiceOnLabel->push_back(std::move(root)); + } // If no reordering, no need to call Choice::Unique return choiceOnLabel; } else { @@ -616,7 +620,7 @@ namespace gtsam { } // We then recurse, for all values of the highest label NodePtr fi = compose(functions.begin(), functions.end(), label); - choiceOnHighestLabel->push_back(fi); + choiceOnHighestLabel->push_back(std::move(fi)); } return choiceOnHighestLabel; } @@ -673,6 +677,7 @@ namespace gtsam { // Creates one tree (i.e.,function) for each choice of current key // by calling create recursively, and then puts them all together. std::vector functions; + functions.reserve(nrChoices); size_t split = size / nrChoices; for (size_t i = 0; i < nrChoices; i++, beginY += split) { NodePtr f = build(labelC, end, beginY, beginY + split); diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index c1d7ea05f7..34c916d027 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -31,7 +31,6 @@ #include #include #include -#include #include #include #include From 8cc5171cbc700382f22b1d5461ea99151407f4d8 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Tue, 15 Oct 2024 16:50:28 +0900 Subject: [PATCH 4/5] Two more methods are static --- gtsam/discrete/DecisionTree-inl.h | 4 ++-- gtsam/discrete/DecisionTree.h | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 7cb470c530..4266ace150 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -650,7 +650,7 @@ namespace gtsam { template template typename DecisionTree::NodePtr DecisionTree::build( - It begin, It end, ValueIt beginY, ValueIt endY) const { + It begin, It end, ValueIt beginY, ValueIt endY) { // get crucial counts size_t nrChoices = begin->second; size_t size = endY - beginY; @@ -692,7 +692,7 @@ namespace gtsam { template template typename DecisionTree::NodePtr DecisionTree::create( - It begin, It end, ValueIt beginY, ValueIt endY) const { + It begin, It end, ValueIt beginY, ValueIt endY) { auto node = build(begin, end, beginY, endY); if (auto choice = std::dynamic_pointer_cast(node)) { return Choice::Unique(choice); diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index 34c916d027..6d8d865300 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -154,7 +154,7 @@ namespace gtsam { * and Y values */ template - NodePtr build(It begin, It end, ValueIt beginY, ValueIt endY) const; + static NodePtr build(It begin, It end, ValueIt beginY, ValueIt endY); /** Internal helper function to create from * keys, cardinalities, and Y values. @@ -162,7 +162,7 @@ namespace gtsam { * before we prune in a top-down fashion. */ template - NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const; + static NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY); /** * @brief Convert from a DecisionTree to DecisionTree. From 935df2b90cf35251eaea31f6880c51a04e8c6cf0 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Tue, 15 Oct 2024 22:36:21 +0900 Subject: [PATCH 5/5] Add missing header --- gtsam/discrete/AlgebraicDecisionTree.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/gtsam/discrete/AlgebraicDecisionTree.h b/gtsam/discrete/AlgebraicDecisionTree.h index e582db0ffd..9948b0be63 100644 --- a/gtsam/discrete/AlgebraicDecisionTree.h +++ b/gtsam/discrete/AlgebraicDecisionTree.h @@ -22,10 +22,12 @@ #include #include +#include #include #include #include #include + namespace gtsam { /**