From a9ffbf5299c36f8a8e92d2335ed8e6f73fc30b12 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 8 Oct 2024 15:09:21 -0400 Subject: [PATCH 1/8] new AlgebraicDecisionTree constructor --- gtsam/discrete/AlgebraicDecisionTree.h | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/gtsam/discrete/AlgebraicDecisionTree.h b/gtsam/discrete/AlgebraicDecisionTree.h index 6001b1983d..45b949d3c4 100644 --- a/gtsam/discrete/AlgebraicDecisionTree.h +++ b/gtsam/discrete/AlgebraicDecisionTree.h @@ -182,6 +182,21 @@ namespace gtsam { this->root_ = DecisionTree::convertFrom(other.root_, L_of_M, op); } + /** + * @brief Create from an arbitrary DecisionTree by operating on it + * with a functional `f`. + * + * @tparam X The type of the leaf of the original DecisionTree + * @tparam Func Type signature of functional `f`. + * @param other The original DecisionTree from which the + * AlgbraicDecisionTree is constructed. + * @param f Functional used to operate on + * the leaves of the input DecisionTree. + */ + template + AlgebraicDecisionTree(const DecisionTree& other, Func f) + : Base(other, f) {} + /** sum */ AlgebraicDecisionTree operator+(const AlgebraicDecisionTree& g) const { return this->apply(g, &Ring::add); From 024e50f9f77bf6ba1b8f2bf14d7f9b6939a9eac6 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 8 Oct 2024 15:09:33 -0400 Subject: [PATCH 2/8] normalize potentials --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 5c83fe5146..49f3be776b 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -229,8 +229,13 @@ continuousElimination(const HybridGaussianFactorGraph &factors, } /* ************************************************************************ */ -/// Take negative log-values, shift them so that the minimum value is 0, and -/// then exponentiate to create a DecisionTreeFactor (not normalized yet!). +/** + * @brief Take negative log-values, shift them so that the minimum value is 0, + * and then exponentiate to create a DecisionTreeFactor (not normalized yet!). + * + * @param errors DecisionTree of (unnormalized) errors. + * @return AlgebraicDecisionTree + */ static DecisionTreeFactor::shared_ptr DiscreteFactorFromErrors( const DiscreteKeys &discreteKeys, const AlgebraicDecisionTree &errors) { @@ -258,7 +263,7 @@ discreteElimination(const HybridGaussianFactorGraph &factors, if (!factor) return std::numeric_limits::infinity(); return scalar + factor->error(kEmpty); }; - DecisionTree errors(gmf->factors(), calculateError); + AlgebraicDecisionTree errors(gmf->factors(), calculateError); dfg.push_back(DiscreteFactorFromErrors(gmf->discreteKeys(), errors)); } else if (auto orphan = dynamic_pointer_cast(f)) { @@ -307,7 +312,7 @@ static std::shared_ptr createDiscreteFactor( } }; - DecisionTree errors(eliminationResults, calculateError); + AlgebraicDecisionTree errors(eliminationResults, calculateError); return DiscreteFactorFromErrors(discreteSeparator, errors); } From ac52be9cf07141e530c7dc909df55018ab4b6ed2 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 8 Oct 2024 15:09:38 -0400 Subject: [PATCH 3/8] update test --- gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp index 01294b28c6..4b91d091d8 100644 --- a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp @@ -333,7 +333,7 @@ TEST(HybridBayesNet, Switching) { CHECK(phi_x1); EXPECT_LONGS_EQUAL(1, phi_x1->keys().size()); // m0 // We can't really check the error of the decision tree factor phi_x1, because - // the continuous factor whose error(kEmpty) we need is not available.. + // the continuous factor whose error(kEmpty) we need is not available. // Now test full elimination of the graph: auto hybridBayesNet = graph.eliminateSequential(); From b79c69b4084ff4cf292103a56e0c5b3acfe7428d Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 9 Oct 2024 15:10:13 -0400 Subject: [PATCH 4/8] small improvements --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 49f3be776b..8e6123f10e 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -234,13 +234,13 @@ continuousElimination(const HybridGaussianFactorGraph &factors, * and then exponentiate to create a DecisionTreeFactor (not normalized yet!). * * @param errors DecisionTree of (unnormalized) errors. - * @return AlgebraicDecisionTree + * @return DecisionTreeFactor::shared_ptr */ static DecisionTreeFactor::shared_ptr DiscreteFactorFromErrors( const DiscreteKeys &discreteKeys, const AlgebraicDecisionTree &errors) { double min_log = errors.min(); - AlgebraicDecisionTree potentials = DecisionTree( + AlgebraicDecisionTree potentials( errors, [&min_log](const double x) { return exp(-(x - min_log)); }); return std::make_shared(discreteKeys, potentials); } From e9bf802d788815e9e6bccaea66059fdf18bc0502 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 9 Oct 2024 15:24:03 -0400 Subject: [PATCH 5/8] more simplification --- gtsam/hybrid/HybridGaussianFactor.cpp | 3 +-- gtsam/hybrid/HybridNonlinearFactor.cpp | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactor.cpp b/gtsam/hybrid/HybridGaussianFactor.cpp index b31fdca200..fd9bd2fd4d 100644 --- a/gtsam/hybrid/HybridGaussianFactor.cpp +++ b/gtsam/hybrid/HybridGaussianFactor.cpp @@ -189,8 +189,7 @@ AlgebraicDecisionTree HybridGaussianFactor::errorTree( auto errorFunc = [&continuousValues](const GaussianFactorValuePair& pair) { return PotentiallyPrunedComponentError(pair, continuousValues); }; - DecisionTree error_tree(factors_, errorFunc); - return error_tree; + return {factors_, errorFunc}; } /* *******************************************************************************/ diff --git a/gtsam/hybrid/HybridNonlinearFactor.cpp b/gtsam/hybrid/HybridNonlinearFactor.cpp index 9378d07fe2..6ffb955117 100644 --- a/gtsam/hybrid/HybridNonlinearFactor.cpp +++ b/gtsam/hybrid/HybridNonlinearFactor.cpp @@ -100,8 +100,7 @@ AlgebraicDecisionTree HybridNonlinearFactor::errorTree( auto [factor, val] = f; return factor->error(continuousValues) + val; }; - DecisionTree result(factors_, errorFunc); - return result; + return {factors_, errorFunc}; } /* *******************************************************************************/ From acda56d67b55eda015636ae1d164906768f0d844 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 9 Oct 2024 18:34:56 -0400 Subject: [PATCH 6/8] normalize no longer takes explicit sum, so that it normalizes correctly --- gtsam/discrete/AlgebraicDecisionTree.h | 5 +---- gtsam/discrete/tests/testAlgebraicDecisionTree.cpp | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/gtsam/discrete/AlgebraicDecisionTree.h b/gtsam/discrete/AlgebraicDecisionTree.h index 45b949d3c4..e582db0ffd 100644 --- a/gtsam/discrete/AlgebraicDecisionTree.h +++ b/gtsam/discrete/AlgebraicDecisionTree.h @@ -234,12 +234,9 @@ namespace gtsam { * @brief Helper method to perform normalization such that all leaves in the * tree sum to 1 * - * @param sum * @return AlgebraicDecisionTree */ - AlgebraicDecisionTree normalize(double sum) const { - return this->apply([&sum](const double& x) { return x / sum; }); - } + AlgebraicDecisionTree normalize() const { return (*this) / this->sum(); } /// Find the minimum values amongst all leaves double min() const { diff --git a/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp b/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp index bf728695c9..a5e46d538f 100644 --- a/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp +++ b/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp @@ -562,7 +562,7 @@ TEST(ADT, Sum) { TEST(ADT, Normalize) { ADT a = exampleADT(); double sum = a.sum(); - auto actual = a.normalize(sum); + auto actual = a.normalize(); DiscreteKey A(0, 2), B(1, 3), C(2, 2); DiscreteKeys keys = DiscreteKeys{A, B, C}; From 878c626a96cbabdeff0fe2b5665548bd65576357 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 10 Oct 2024 09:32:55 -0400 Subject: [PATCH 7/8] fix test TODOs --- .../tests/testHybridGaussianProductFactor.cpp | 15 ++++++++++----- .../tests/testHybridNonlinearFactorGraph.cpp | 2 -- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/gtsam/hybrid/tests/testHybridGaussianProductFactor.cpp b/gtsam/hybrid/tests/testHybridGaussianProductFactor.cpp index f41c5f0aa2..3a4a6c1f41 100644 --- a/gtsam/hybrid/tests/testHybridGaussianProductFactor.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianProductFactor.cpp @@ -128,7 +128,10 @@ TEST(HybridGaussianProductFactor, AsProductFactor) { EXPECT(actual.first.at(0) == f10); EXPECT_DOUBLES_EQUAL(10, actual.second, 1e-9); - // TODO(Frank): when killed hiding, f11 should also be there + mode[m1.first] = 1; + actual = product(mode); + EXPECT(actual.first.at(0) == f11); + EXPECT_DOUBLES_EQUAL(11, actual.second, 1e-9); } /* ************************************************************************* */ @@ -145,7 +148,10 @@ TEST(HybridGaussianProductFactor, AddOne) { EXPECT(actual.first.at(0) == f10); EXPECT_DOUBLES_EQUAL(10, actual.second, 1e-9); - // TODO(Frank): when killed hiding, f11 should also be there + mode[m1.first] = 1; + actual = product(mode); + EXPECT(actual.first.at(0) == f11); + EXPECT_DOUBLES_EQUAL(11, actual.second, 1e-9); } /* ************************************************************************* */ @@ -166,9 +172,8 @@ TEST(HybridGaussianProductFactor, AddTwo) { EXPECT_DOUBLES_EQUAL(10 + 20, actual00.second, 1e-9); auto actual12 = product({{M(1), 1}, {M(2), 2}}); - // TODO(Frank): when killed hiding, these should also equal: - // EXPECT(actual12.first.at(0) == f11); - // EXPECT(actual12.first.at(1) == f22); + EXPECT(actual12.first.at(0) == f11); + EXPECT(actual12.first.at(1) == f22); EXPECT_DOUBLES_EQUAL(11 + 22, actual12.second, 1e-9); } diff --git a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp index bbf427ecb2..e77476e258 100644 --- a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp @@ -973,8 +973,6 @@ TEST(HybridNonlinearFactorGraph, DifferentMeans) { VectorValues cont0 = bn->optimize(dv0); double error0 = bn->error(HybridValues(cont0, dv0)); - // TODO(Varun) Perform importance sampling to estimate error? - // regression EXPECT_DOUBLES_EQUAL(0.69314718056, error0, 1e-9); From 32108ab8ecdf9e6a85617166824486b681ad5b87 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 10 Oct 2024 09:33:07 -0400 Subject: [PATCH 8/8] remove unnecessary TODOs --- gtsam/discrete/CMakeLists.txt | 1 - gtsam/hybrid/CMakeLists.txt | 1 - 2 files changed, 2 deletions(-) diff --git a/gtsam/discrete/CMakeLists.txt b/gtsam/discrete/CMakeLists.txt index d78dff34f6..1c6aa97472 100644 --- a/gtsam/discrete/CMakeLists.txt +++ b/gtsam/discrete/CMakeLists.txt @@ -1,7 +1,6 @@ # Install headers set(subdir discrete) file(GLOB discrete_headers "*.h") -# FIXME: exclude headers install(FILES ${discrete_headers} DESTINATION include/gtsam/discrete) # Add all tests diff --git a/gtsam/hybrid/CMakeLists.txt b/gtsam/hybrid/CMakeLists.txt index f1cfcd5c4b..cdada00dde 100644 --- a/gtsam/hybrid/CMakeLists.txt +++ b/gtsam/hybrid/CMakeLists.txt @@ -1,7 +1,6 @@ # Install headers set(subdir hybrid) file(GLOB hybrid_headers "*.h") -# FIXME: exclude headers install(FILES ${hybrid_headers} DESTINATION include/gtsam/hybrid) # Add all tests