diff --git a/gtsam/discrete/AlgebraicDecisionTree.h b/gtsam/discrete/AlgebraicDecisionTree.h index e582db0ffd..9948b0be63 100644 --- a/gtsam/discrete/AlgebraicDecisionTree.h +++ b/gtsam/discrete/AlgebraicDecisionTree.h @@ -22,10 +22,12 @@ #include #include +#include #include #include #include #include + namespace gtsam { /** diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 27e98fcdec..4266ace150 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -22,18 +22,15 @@ #include #include - -#include +#include #include -#include +#include #include +#include #include #include #include #include -#include -#include -#include namespace gtsam { @@ -251,22 +248,28 @@ namespace gtsam { label_ = f.label(); size_t count = f.nrChoices(); branches_.reserve(count); - for (size_t i = 0; i < count; i++) - push_back(f.branches_[i]->apply_f_op_g(g, op)); + for (size_t i = 0; i < count; i++) { + NodePtr newBranch = f.branches_[i]->apply_f_op_g(g, op); + push_back(std::move(newBranch)); + } } else if (g.label() > f.label()) { // f lower than g label_ = g.label(); size_t count = g.nrChoices(); branches_.reserve(count); - for (size_t i = 0; i < count; i++) - push_back(g.branches_[i]->apply_g_op_fC(f, op)); + for (size_t i = 0; i < count; i++) { + NodePtr newBranch = g.branches_[i]->apply_g_op_fC(f, op); + push_back(std::move(newBranch)); + } } else { // f same level as g label_ = f.label(); size_t count = f.nrChoices(); branches_.reserve(count); - for (size_t i = 0; i < count; i++) - push_back(f.branches_[i]->apply_f_op_g(*g.branches_[i], op)); + for (size_t i = 0; i < count; i++) { + NodePtr newBranch = f.branches_[i]->apply_f_op_g(*g.branches_[i], op); + push_back(std::move(newBranch)); + } } } @@ -284,12 +287,12 @@ namespace gtsam { } /** add a branch: TODO merge into constructor */ - void push_back(const NodePtr& node) { + void push_back(NodePtr&& node) { // allSame_ is restricted to leaf nodes in a decision tree if (allSame_ && !branches_.empty()) { allSame_ = node->sameLeaf(*branches_.back()); } - branches_.push_back(node); + branches_.push_back(std::move(node)); } /// print (as a tree). @@ -497,9 +500,9 @@ namespace gtsam { DecisionTree::DecisionTree(const L& label, const Y& y1, const Y& y2) { auto a = std::make_shared(label, 2); NodePtr l1(new Leaf(y1)), l2(new Leaf(y2)); - a->push_back(l1); - a->push_back(l2); - root_ = Choice::Unique(a); + a->push_back(std::move(l1)); + a->push_back(std::move(l2)); + root_ = Choice::Unique(std::move(a)); } /****************************************************************************/ @@ -510,11 +513,10 @@ namespace gtsam { "DecisionTree: binary constructor called with non-binary label"); auto a = std::make_shared(labelC.first, 2); NodePtr l1(new Leaf(y1)), l2(new Leaf(y2)); - a->push_back(l1); - a->push_back(l2); - root_ = Choice::Unique(a); + a->push_back(std::move(l1)); + a->push_back(std::move(l2)); + root_ = Choice::Unique(std::move(a)); } - /****************************************************************************/ template DecisionTree::DecisionTree(const std::vector& labelCs, @@ -557,9 +559,7 @@ namespace gtsam { template DecisionTree::DecisionTree(const DecisionTree& other, Func Y_of_X) { - // Define functor for identity mapping of node label. - auto L_of_L = [](const L& label) { return label; }; - root_ = convertFrom(other.root_, L_of_L, Y_of_X); + root_ = convertFrom(other.root_, Y_of_X); } /****************************************************************************/ @@ -580,7 +580,7 @@ namespace gtsam { template template typename DecisionTree::NodePtr DecisionTree::compose( - Iterator begin, Iterator end, const L& label) const { + Iterator begin, Iterator end, const L& label) { // find highest label among branches std::optional highestLabel; size_t nrChoices = 0; @@ -598,8 +598,10 @@ namespace gtsam { // if label is already in correct order, just put together a choice on label if (!nrChoices || !highestLabel || label > *highestLabel) { auto choiceOnLabel = std::make_shared(label, end - begin); - for (Iterator it = begin; it != end; it++) - choiceOnLabel->push_back(it->root_); + for (Iterator it = begin; it != end; it++) { + NodePtr root = it->root_; + choiceOnLabel->push_back(std::move(root)); + } // If no reordering, no need to call Choice::Unique return choiceOnLabel; } else { @@ -618,7 +620,7 @@ namespace gtsam { } // We then recurse, for all values of the highest label NodePtr fi = compose(functions.begin(), functions.end(), label); - choiceOnHighestLabel->push_back(fi); + choiceOnHighestLabel->push_back(std::move(fi)); } return choiceOnHighestLabel; } @@ -648,7 +650,7 @@ namespace gtsam { template template typename DecisionTree::NodePtr DecisionTree::build( - It begin, It end, ValueIt beginY, ValueIt endY) const { + It begin, It end, ValueIt beginY, ValueIt endY) { // get crucial counts size_t nrChoices = begin->second; size_t size = endY - beginY; @@ -675,6 +677,7 @@ namespace gtsam { // Creates one tree (i.e.,function) for each choice of current key // by calling create recursively, and then puts them all together. std::vector functions; + functions.reserve(nrChoices); size_t split = size / nrChoices; for (size_t i = 0; i < nrChoices; i++, beginY += split) { NodePtr f = build(labelC, end, beginY, beginY + split); @@ -689,7 +692,7 @@ namespace gtsam { template template typename DecisionTree::NodePtr DecisionTree::create( - It begin, It end, ValueIt beginY, ValueIt endY) const { + It begin, It end, ValueIt beginY, ValueIt endY) { auto node = build(begin, end, beginY, endY); if (auto choice = std::dynamic_pointer_cast(node)) { return Choice::Unique(choice); @@ -698,17 +701,44 @@ namespace gtsam { } } + /****************************************************************************/ + template + template + typename DecisionTree::NodePtr DecisionTree::convertFrom( + const typename DecisionTree::NodePtr& f, + std::function Y_of_X) { + + // If leaf, apply unary conversion "op" and create a unique leaf. + using LXLeaf = typename DecisionTree::Leaf; + if (auto leaf = std::dynamic_pointer_cast(f)) { + return NodePtr(new Leaf(Y_of_X(leaf->constant()))); + } + + // Check if Choice + using LXChoice = typename DecisionTree::Choice; + auto choice = std::dynamic_pointer_cast(f); + if (!choice) throw std::invalid_argument( + "DecisionTree::convertFrom: Invalid NodePtr"); + + // Create a new Choice node with the same label + auto newChoice = std::make_shared(choice->label(), choice->nrChoices()); + + // Convert each branch recursively + for (auto&& branch : choice->branches()) { + newChoice->push_back(convertFrom(branch, Y_of_X)); + } + + return Choice::Unique(newChoice); + } + /****************************************************************************/ template template typename DecisionTree::NodePtr DecisionTree::convertFrom( const typename DecisionTree::NodePtr& f, - std::function L_of_M, - std::function Y_of_X) const { + std::function L_of_M, std::function Y_of_X) { using LY = DecisionTree; - // Ugliness below because apparently we can't have templated virtual - // functions. // If leaf, apply unary conversion "op" and create a unique leaf. using MXLeaf = typename DecisionTree::Leaf; if (auto leaf = std::dynamic_pointer_cast(f)) { @@ -718,19 +748,27 @@ namespace gtsam { // Check if Choice using MXChoice = typename DecisionTree::Choice; auto choice = std::dynamic_pointer_cast(f); - if (!choice) throw std::invalid_argument( - "DecisionTree::convertFrom: Invalid NodePtr"); + if (!choice) + throw std::invalid_argument("DecisionTree::convertFrom: Invalid NodePtr"); // get new label const M oldLabel = choice->label(); const L newLabel = L_of_M(oldLabel); - // put together via Shannon expansion otherwise not sorted. + // Shannon expansion in this context involves: + // 1. Creating separate subtrees (functions) for each possible value of the new label. + // 2. Combining these subtrees using the 'compose' method, which implements the expansion. + // This approach guarantees that the resulting tree maintains the correct variable ordering + // based on the new labels (L) after translation from the old labels (M). + // Simply creating a Choice node here would not work because it wouldn't account for the + // potentially new ordering of variables resulting from the label translation, + // which is crucial for maintaining consistency and efficiency in the converted tree. std::vector functions; for (auto&& branch : choice->branches()) { functions.emplace_back(convertFrom(branch, L_of_M, Y_of_X)); } - return Choice::Unique(LY::compose(functions.begin(), functions.end(), newLabel)); + return Choice::Unique( + LY::compose(functions.begin(), functions.end(), newLabel)); } /****************************************************************************/ diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index 6d6179a7e6..6d8d865300 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -31,7 +31,6 @@ #include #include #include -#include #include #include #include @@ -155,7 +154,7 @@ namespace gtsam { * and Y values */ template - NodePtr build(It begin, It end, ValueIt beginY, ValueIt endY) const; + static NodePtr build(It begin, It end, ValueIt beginY, ValueIt endY); /** Internal helper function to create from * keys, cardinalities, and Y values. @@ -163,7 +162,20 @@ namespace gtsam { * before we prune in a top-down fashion. */ template - NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const; + static NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY); + + /** + * @brief Convert from a DecisionTree to DecisionTree. + * + * @tparam M The previous label type. + * @tparam X The previous value type. + * @param f The node pointer to the root of the previous DecisionTree. + * @param Y_of_X Functor to convert from value type X to type Y. + * @return NodePtr + */ + template + static NodePtr convertFrom(const typename DecisionTree::NodePtr& f, + std::function Y_of_X); /** * @brief Convert from a DecisionTree to DecisionTree. @@ -176,9 +188,9 @@ namespace gtsam { * @return NodePtr */ template - NodePtr convertFrom(const typename DecisionTree::NodePtr& f, - std::function L_of_M, - std::function Y_of_X) const; + static NodePtr convertFrom(const typename DecisionTree::NodePtr& f, + std::function L_of_M, + std::function Y_of_X); public: /// @name Standard Constructors @@ -402,7 +414,7 @@ namespace gtsam { // internal use only template NodePtr - compose(Iterator begin, Iterator end, const L& label) const; + static compose(Iterator begin, Iterator end, const L& label); /// @}