Skip to content

Commit

Permalink
Merge pull request #1669 from borglab/discrete-error
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal authored Jan 7, 2024
2 parents 6b098c7 + bc3b96a commit 42b5218
Show file tree
Hide file tree
Showing 21 changed files with 97 additions and 25 deletions.
16 changes: 16 additions & 0 deletions gtsam/discrete/DecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,22 @@ namespace gtsam {
return error(values.discrete());
}

/* ************************************************************************ */
AlgebraicDecisionTree<Key> DecisionTreeFactor::errorTree() const {
// Get all possible assignments
DiscreteKeys dkeys = discreteKeys();
// Reverse to make cartesian product output a more natural ordering.
DiscreteKeys rdkeys(dkeys.rbegin(), dkeys.rend());
const auto assignments = DiscreteValues::CartesianProduct(rdkeys);

// Construct vector with error values
std::vector<double> errors;
for (const auto& assignment : assignments) {
errors.push_back(error(assignment));
}
return AlgebraicDecisionTree<Key>(dkeys, errors);
}

/* ************************************************************************ */
double DecisionTreeFactor::safe_div(const double& a, const double& b) {
// The use for safe_div is when we divide the product factor by the sum
Expand Down
3 changes: 3 additions & 0 deletions gtsam/discrete/DecisionTreeFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,9 @@ namespace gtsam {
*/
double error(const HybridValues& values) const override;

/// Compute error for each assignment and return as a tree
AlgebraicDecisionTree<Key> errorTree() const override;

/// @}

private:
Expand Down
15 changes: 10 additions & 5 deletions gtsam/discrete/DiscreteFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@

#pragma once

#include <gtsam/base/Testable.h>
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/inference/Factor.h>
#include <gtsam/base/Testable.h>

#include <string>
namespace gtsam {
Expand All @@ -35,7 +36,7 @@ class HybridValues;
*
* @ingroup discrete
*/
class GTSAM_EXPORT DiscreteFactor: public Factor {
class GTSAM_EXPORT DiscreteFactor : public Factor {
public:
// typedefs needed to play nice with gtsam
typedef DiscreteFactor This; ///< This class
Expand Down Expand Up @@ -103,15 +104,19 @@ class GTSAM_EXPORT DiscreteFactor: public Factor {
*/
double error(const HybridValues& c) const override;

/// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor
/// Compute error for each assignment and return as a tree
virtual AlgebraicDecisionTree<Key> errorTree() const = 0;

/// Multiply in a DecisionTreeFactor and return the result as
/// DecisionTreeFactor
virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;

virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;

/// @}
/// @name Wrapper support
/// @{

/// Translation table from values to strings.
using Names = DiscreteValues::Names;

Expand Down Expand Up @@ -175,4 +180,4 @@ template<> struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {};
std::vector<double> expNormalize(const std::vector<double> &logProbs);


}// namespace gtsam
} // namespace gtsam
5 changes: 5 additions & 0 deletions gtsam/discrete/TableFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,11 @@ double TableFactor::error(const HybridValues& values) const {
return error(values.discrete());
}

/* ************************************************************************ */
AlgebraicDecisionTree<Key> TableFactor::errorTree() const {
return toDecisionTreeFactor().errorTree();
}

/* ************************************************************************ */
DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
return toDecisionTreeFactor() * f;
Expand Down
3 changes: 3 additions & 0 deletions gtsam/discrete/TableFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,9 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
*/
double error(const HybridValues& values) const override;

/// Compute error for each assignment and return as a tree
AlgebraicDecisionTree<Key> errorTree() const override;

/// @}
};

Expand Down
18 changes: 18 additions & 0 deletions gtsam/discrete/tests/testDecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,24 @@ TEST( DecisionTreeFactor, constructors)
EXPECT_DOUBLES_EQUAL(0.8, f4(x121), 1e-9);
}

/* ************************************************************************* */
TEST(DecisionTreeFactor, Error) {
// Declare a bunch of keys
DiscreteKey X(0,2), Y(1,3), Z(2,2);

// Create factors
DecisionTreeFactor f(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75");

auto errors = f.errorTree();
// regression
AlgebraicDecisionTree<Key> expected(
{X, Y, Z},
vector<double>{-0.69314718, -1.6094379, -1.0986123, -1.7917595,
-1.3862944, -1.9459101, -3.2188758, -4.0073332, -3.5553481,
-4.1743873, -3.8066625, -4.3174881});
EXPECT(assert_equal(expected, errors, 1e-6));
}

/* ************************************************************************* */
TEST(DecisionTreeFactor, multiplication) {
DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2);
Expand Down
6 changes: 3 additions & 3 deletions gtsam/hybrid/GaussianMixture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -313,14 +313,14 @@ AlgebraicDecisionTree<Key> GaussianMixture::logProbability(
}

/* *******************************************************************************/
AlgebraicDecisionTree<Key> GaussianMixture::error(
AlgebraicDecisionTree<Key> GaussianMixture::errorTree(
const VectorValues &continuousValues) const {
auto errorFunc = [&](const GaussianConditional::shared_ptr &conditional) {
return conditional->error(continuousValues) + //
logConstant_ - conditional->logNormalizationConstant();
};
DecisionTree<Key, double> errorTree(conditionals_, errorFunc);
return errorTree;
DecisionTree<Key, double> error_tree(conditionals_, errorFunc);
return error_tree;
}

/* *******************************************************************************/
Expand Down
2 changes: 1 addition & 1 deletion gtsam/hybrid/GaussianMixture.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ class GTSAM_EXPORT GaussianMixture
* @return AlgebraicDecisionTree<Key> A decision tree on the discrete keys
* only, with the leaf values as the error for each assignment.
*/
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
AlgebraicDecisionTree<Key> errorTree(const VectorValues &continuousValues) const;

/**
* @brief Compute the logProbability of this Gaussian Mixture.
Expand Down
6 changes: 3 additions & 3 deletions gtsam/hybrid/GaussianMixtureFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,14 @@ GaussianFactorGraphTree GaussianMixtureFactor::asGaussianFactorGraphTree()
}

/* *******************************************************************************/
AlgebraicDecisionTree<Key> GaussianMixtureFactor::error(
AlgebraicDecisionTree<Key> GaussianMixtureFactor::errorTree(
const VectorValues &continuousValues) const {
// functor to convert from sharedFactor to double error value.
auto errorFunc = [&continuousValues](const sharedFactor &gf) {
return gf->error(continuousValues);
};
DecisionTree<Key, double> errorTree(factors_, errorFunc);
return errorTree;
DecisionTree<Key, double> error_tree(factors_, errorFunc);
return error_tree;
}

/* *******************************************************************************/
Expand Down
2 changes: 1 addition & 1 deletion gtsam/hybrid/GaussianMixtureFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys
* as the factors involved, and leaf values as the error.
*/
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
AlgebraicDecisionTree<Key> errorTree(const VectorValues &continuousValues) const;

/**
* @brief Compute the log-likelihood, including the log-normalizing constant.
Expand Down
6 changes: 3 additions & 3 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
}

/* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::errorTree(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree(0.0);

Expand All @@ -431,7 +431,7 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(

if (auto gaussianMixture = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
// Compute factor error and add it.
error_tree = error_tree + gaussianMixture->error(continuousValues);
error_tree = error_tree + gaussianMixture->errorTree(continuousValues);
} else if (auto gaussian = dynamic_pointer_cast<GaussianFactor>(f)) {
// If continuous only, get the (double) error
// and add it to the error_tree
Expand Down Expand Up @@ -460,7 +460,7 @@ double HybridGaussianFactorGraph::probPrime(const HybridValues &values) const {
/* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::probPrime(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree = this->error(continuousValues);
AlgebraicDecisionTree<Key> error_tree = this->errorTree(continuousValues);
AlgebraicDecisionTree<Key> prob_tree = error_tree.apply([](double error) {
// NOTE: The 0.5 term is handled by each factor
return exp(-error);
Expand Down
3 changes: 2 additions & 1 deletion gtsam/hybrid/HybridGaussianFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
* @param continuousValues Continuous values at which to compute the error.
* @return AlgebraicDecisionTree<Key>
*/
AlgebraicDecisionTree<Key> error(const VectorValues& continuousValues) const;
AlgebraicDecisionTree<Key> errorTree(
const VectorValues& continuousValues) const;

/**
* @brief Compute unnormalized probability \f$ P(X | M, Z) \f$
Expand Down
6 changes: 3 additions & 3 deletions gtsam/hybrid/MixtureFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,13 @@ class MixtureFactor : public HybridFactor {
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys
* as the factor, and leaf values as the error.
*/
AlgebraicDecisionTree<Key> error(const Values& continuousValues) const {
AlgebraicDecisionTree<Key> errorTree(const Values& continuousValues) const {
// functor to convert from sharedFactor to double error value.
auto errorFunc = [continuousValues](const sharedFactor& factor) {
return factor->error(continuousValues);
};
DecisionTree<Key, double> errorTree(factors_, errorFunc);
return errorTree;
DecisionTree<Key, double> result(factors_, errorFunc);
return result;
}

/**
Expand Down
4 changes: 2 additions & 2 deletions gtsam/hybrid/tests/testGaussianMixture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ TEST(GaussianMixture, LogProbability) {
/// Check error.
TEST(GaussianMixture, Error) {
using namespace equal_constants;
auto actual = mixture.error(vv);
auto actual = mixture.errorTree(vv);

// Check result.
std::vector<DiscreteKey> discrete_keys = {mode};
Expand Down Expand Up @@ -134,7 +134,7 @@ TEST(GaussianMixture, Likelihood) {
std::vector<double> leaves = {conditionals[0]->likelihood(vv)->error(vv),
conditionals[1]->likelihood(vv)->error(vv)};
AlgebraicDecisionTree<Key> expected(discrete_keys, leaves);
EXPECT(assert_equal(expected, likelihood->error(vv), 1e-6));
EXPECT(assert_equal(expected, likelihood->errorTree(vv), 1e-6));

// Check that the ratio of probPrime to evaluate is the same for all modes.
std::vector<double> ratio(2);
Expand Down
2 changes: 1 addition & 1 deletion gtsam/hybrid/tests/testGaussianMixtureFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ TEST(GaussianMixtureFactor, Error) {
continuousValues.insert(X(2), Vector2(1, 1));

// error should return a tree of errors, with nodes for each discrete value.
AlgebraicDecisionTree<Key> error_tree = mixtureFactor.error(continuousValues);
AlgebraicDecisionTree<Key> error_tree = mixtureFactor.errorTree(continuousValues);

std::vector<DiscreteKey> discrete_keys = {m1};
// Error values for regression test
Expand Down
2 changes: 1 addition & 1 deletion gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,7 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) {
HybridBayesNet::shared_ptr hybridBayesNet = graph.eliminateSequential();

HybridValues delta = hybridBayesNet->optimize();
auto error_tree = graph.error(delta.continuous());
auto error_tree = graph.errorTree(delta.continuous());

std::vector<DiscreteKey> discrete_keys = {{M(0), 2}, {M(1), 2}};
std::vector<double> leaves = {0.9998558, 0.4902432, 0.5193694, 0.0097568};
Expand Down
3 changes: 2 additions & 1 deletion gtsam/hybrid/tests/testMixtureFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ TEST(MixtureFactor, Error) {
continuousValues.insert<double>(X(1), 0);
continuousValues.insert<double>(X(2), 1);

AlgebraicDecisionTree<Key> error_tree = mixtureFactor.error(continuousValues);
AlgebraicDecisionTree<Key> error_tree =
mixtureFactor.errorTree(continuousValues);

DiscreteKey m1(1, 2);
std::vector<DiscreteKey> discrete_keys = {m1};
Expand Down
5 changes: 5 additions & 0 deletions gtsam_unstable/discrete/AllDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint {
/// Multiply into a decisiontree
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;

/// Compute error for each assignment and return as a tree
AlgebraicDecisionTree<Key> errorTree() const override {
throw std::runtime_error("AllDiff::error not implemented");
}

/*
* Ensure Arc-consistency by checking every possible value of domain j.
* @param j domain to be checked
Expand Down
5 changes: 5 additions & 0 deletions gtsam_unstable/discrete/BinaryAllDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ class BinaryAllDiff : public Constraint {
const Domains&) const override {
throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented");
}

/// Compute error for each assignment and return as a tree
AlgebraicDecisionTree<Key> errorTree() const override {
throw std::runtime_error("BinaryAllDiff::error not implemented");
}
};

} // namespace gtsam
5 changes: 5 additions & 0 deletions gtsam_unstable/discrete/Domain.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
}
}

/// Compute error for each assignment and return as a tree
AlgebraicDecisionTree<Key> errorTree() const override {
throw std::runtime_error("Domain::error not implemented");
}

// Return concise string representation, mostly to debug arc consistency.
// Converts from base 0 to base1.
std::string base1Str() const;
Expand Down
5 changes: 5 additions & 0 deletions gtsam_unstable/discrete/SingleValue.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint {
}
}

/// Compute error for each assignment and return as a tree
AlgebraicDecisionTree<Key> errorTree() const override {
throw std::runtime_error("SingleValue::error not implemented");
}

/// Calculate value
double operator()(const DiscreteValues& values) const override;

Expand Down

0 comments on commit 42b5218

Please sign in to comment.