Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TableFactor Improvements #1556

Merged
merged 10 commits into from
Jul 17, 2023
47 changes: 19 additions & 28 deletions gtsam/discrete/DecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,13 @@ namespace gtsam {

/* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
const ADT& potentials)
: DiscreteFactor(keys.indices()),
ADT(potentials),
cardinalities_(keys.cardinalities()) {}
const ADT& potentials)
: DiscreteFactor(keys.indices(), keys.cardinalities()), ADT(potentials) {}

/* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c)
: DiscreteFactor(c.keys()),
AlgebraicDecisionTree<Key>(c),
cardinalities_(c.cardinalities_) {}
: DiscreteFactor(c.keys(), c.cardinalities()),
AlgebraicDecisionTree<Key>(c) {}

/* ************************************************************************ */
bool DecisionTreeFactor::equals(const DiscreteFactor& other,
Expand Down Expand Up @@ -182,15 +179,12 @@ namespace gtsam {
}

/* ************************************************************************ */
DiscreteKeys DecisionTreeFactor::discreteKeys() const {
DiscreteKeys result;
for (auto&& key : keys()) {
DiscreteKey dkey(key, cardinality(key));
if (std::find(result.begin(), result.end(), dkey) == result.end()) {
result.push_back(dkey);
}
std::vector<double> DecisionTreeFactor::probabilities() const {
std::vector<double> probs;
for (auto&& [key, value] : enumerate()) {
probs.push_back(value);
}
return result;
return probs;
}

/* ************************************************************************ */
Expand Down Expand Up @@ -288,29 +282,26 @@ namespace gtsam {

/* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
const vector<double>& table)
: DiscreteFactor(keys.indices()),
AlgebraicDecisionTree<Key>(keys, table),
cardinalities_(keys.cardinalities()) {}
const vector<double>& table)
: DiscreteFactor(keys.indices(), keys.cardinalities()),
AlgebraicDecisionTree<Key>(keys, table) {}

/* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
const string& table)
: DiscreteFactor(keys.indices()),
AlgebraicDecisionTree<Key>(keys, table),
cardinalities_(keys.cardinalities()) {}
const string& table)
: DiscreteFactor(keys.indices(), keys.cardinalities()),
AlgebraicDecisionTree<Key>(keys, table) {}

/* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrAssignments) const {
const size_t N = maxNrAssignments;

// Get the probabilities in the decision tree so we can threshold.
std::vector<double> probabilities;
this->visitLeaf([&](const Leaf& leaf) {
size_t nrAssignments = leaf.nrAssignments();
double prob = leaf.constant();
probabilities.insert(probabilities.end(), nrAssignments, prob);
});
// NOTE(Varun) this is potentially slow due to the cartesian product
for (auto&& [assignment, prob] : this->enumerate()) {
probabilities.push_back(prob);
}

// The number of probabilities can be lower than max_leaves
if (probabilities.size() <= N) {
Expand Down
11 changes: 2 additions & 9 deletions gtsam/discrete/DecisionTreeFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,6 @@ namespace gtsam {
typedef std::shared_ptr<DecisionTreeFactor> shared_ptr;
typedef AlgebraicDecisionTree<Key> ADT;

protected:
std::map<Key, size_t> cardinalities_;

public:
/// @name Standard Constructors
/// @{

Expand Down Expand Up @@ -119,8 +115,6 @@ namespace gtsam {

static double safe_div(const double& a, const double& b);

size_t cardinality(Key j) const { return cardinalities_.at(j); }

/// divide by factor f (safely)
DecisionTreeFactor operator/(const DecisionTreeFactor& f) const {
return apply(f, safe_div);
Expand Down Expand Up @@ -179,8 +173,8 @@ namespace gtsam {
/// Enumerate all values into a map from values to double.
std::vector<std::pair<DiscreteValues, double>> enumerate() const;

/// Return all the discrete keys associated with this factor.
DiscreteKeys discreteKeys() const;
/// Get all the probabilities in order of assignment values
std::vector<double> probabilities() const;

/**
* @brief Prune the decision tree of discrete variables.
Expand Down Expand Up @@ -260,7 +254,6 @@ namespace gtsam {
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(ADT);
ar& BOOST_SERIALIZATION_NVP(cardinalities_);
}
#endif
};
Expand Down
12 changes: 12 additions & 0 deletions gtsam/discrete/DiscreteFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,18 @@ using namespace std;

namespace gtsam {

/* ************************************************************************ */
DiscreteKeys DiscreteFactor::discreteKeys() const {
DiscreteKeys result;
for (auto&& key : keys()) {
DiscreteKey dkey(key, cardinality(key));
if (std::find(result.begin(), result.end(), dkey) == result.end()) {
result.push_back(dkey);
}
}
return result;
}

/* ************************************************************************* */
double DiscreteFactor::error(const DiscreteValues& values) const {
return -std::log((*this)(values));
Expand Down
49 changes: 37 additions & 12 deletions gtsam/discrete/DiscreteFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,28 +36,35 @@ class HybridValues;
* @ingroup discrete
*/
class GTSAM_EXPORT DiscreteFactor: public Factor {

public:

public:
// typedefs needed to play nice with gtsam
typedef DiscreteFactor This; ///< This class
typedef std::shared_ptr<DiscreteFactor> shared_ptr; ///< shared_ptr to this class
typedef Factor Base; ///< Our base class
typedef DiscreteFactor This; ///< This class
typedef std::shared_ptr<DiscreteFactor>
shared_ptr; ///< shared_ptr to this class
typedef Factor Base; ///< Our base class

using Values = DiscreteValues; ///< backwards compatibility
using Values = DiscreteValues; ///< backwards compatibility

public:
protected:
/// Map of Keys and their cardinalities.
std::map<Key, size_t> cardinalities_;

public:
/// @name Standard Constructors
/// @{

/** Default constructor creates empty factor */
DiscreteFactor() {}

/** Construct from container of keys. This constructor is used internally from derived factor
* constructors, either from a container of keys or from a boost::assign::list_of. */
template<typename CONTAINER>
DiscreteFactor(const CONTAINER& keys) : Base(keys) {}
/**
* Construct from container of keys and map of cardinalities.
* This constructor is used internally from derived factor constructors,
* either from a container of keys or from a boost::assign::list_of.
*/
template <typename CONTAINER>
DiscreteFactor(const CONTAINER& keys,
const std::map<Key, size_t> cardinalities = {})
: Base(keys), cardinalities_(cardinalities) {}

/// @}
/// @name Testable
Expand All @@ -77,6 +84,13 @@ class GTSAM_EXPORT DiscreteFactor: public Factor {
/// @name Standard Interface
/// @{

/// Return all the discrete keys associated with this factor.
DiscreteKeys discreteKeys() const;

std::map<Key, size_t> cardinalities() const { return cardinalities_; }

size_t cardinality(Key j) const { return cardinalities_.at(j); }

/// Find value for given assignment of values to variables
virtual double operator()(const DiscreteValues&) const = 0;

Expand Down Expand Up @@ -124,6 +138,17 @@ class GTSAM_EXPORT DiscreteFactor: public Factor {
const Names& names = {}) const = 0;

/// @}

private:
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar& BOOST_SERIALIZATION_NVP(cardinalities_);
}
#endif
};
// DiscreteFactor

Expand Down
26 changes: 9 additions & 17 deletions gtsam/discrete/TableFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@
* @file TableFactor.cpp
* @brief discrete factor
* @date May 4, 2023
* @author Yoonwoo Kim
* @author Yoonwoo Kim, Varun Agrawal
*/

#include <gtsam/base/FastSet.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/TableFactor.h>
#include <gtsam/hybrid/HybridValues.h>

Expand All @@ -33,8 +34,7 @@ TableFactor::TableFactor() {}
/* ************************************************************************ */
TableFactor::TableFactor(const DiscreteKeys& dkeys,
const TableFactor& potentials)
: DiscreteFactor(dkeys.indices()),
cardinalities_(potentials.cardinalities_) {
: DiscreteFactor(dkeys.indices(), dkeys.cardinalities()) {
sparse_table_ = potentials.sparse_table_;
denominators_ = potentials.denominators_;
sorted_dkeys_ = discreteKeys();
Expand All @@ -44,18 +44,22 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys,
/* ************************************************************************ */
TableFactor::TableFactor(const DiscreteKeys& dkeys,
const Eigen::SparseVector<double>& table)
: DiscreteFactor(dkeys.indices()), sparse_table_(table.size()) {
: DiscreteFactor(dkeys.indices(), dkeys.cardinalities()),
sparse_table_(table.size()) {
sparse_table_ = table;
double denom = table.size();
for (const DiscreteKey& dkey : dkeys) {
cardinalities_.insert(dkey);
denom /= dkey.second;
denominators_.insert(std::pair<Key, double>(dkey.first, denom));
}
sorted_dkeys_ = discreteKeys();
sort(sorted_dkeys_.begin(), sorted_dkeys_.end());
}

/* ************************************************************************ */
TableFactor::TableFactor(const DiscreteConditional& c)
: TableFactor(c.discreteKeys(), c.probabilities()) {}

/* ************************************************************************ */
Eigen::SparseVector<double> TableFactor::Convert(
const std::vector<double>& table) {
Expand Down Expand Up @@ -435,18 +439,6 @@ std::vector<std::pair<DiscreteValues, double>> TableFactor::enumerate() const {
return result;
}

/* ************************************************************************ */
DiscreteKeys TableFactor::discreteKeys() const {
DiscreteKeys result;
for (auto&& key : keys()) {
DiscreteKey dkey(key, cardinality(key));
if (std::find(result.begin(), result.end(), dkey) == result.end()) {
result.push_back(dkey);
}
}
return result;
}

// Print out header.
/* ************************************************************************ */
string TableFactor::markdown(const KeyFormatter& keyFormatter,
Expand Down
25 changes: 11 additions & 14 deletions gtsam/discrete/TableFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
/**
* @file TableFactor.h
* @date May 4, 2023
* @author Yoonwoo Kim
* @author Yoonwoo Kim, Varun Agrawal
*/

#pragma once
Expand All @@ -32,6 +32,7 @@

namespace gtsam {

class DiscreteConditional;
class HybridValues;

/**
Expand All @@ -44,8 +45,6 @@ class HybridValues;
*/
class GTSAM_EXPORT TableFactor : public DiscreteFactor {
protected:
/// Map of Keys and their cardinalities.
std::map<Key, size_t> cardinalities_;
/// SparseVector of nonzero probabilities.
Eigen::SparseVector<double> sparse_table_;

Expand All @@ -57,10 +56,10 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {

/**
* @brief Uses lazy cartesian product to find nth entry in the cartesian
* product of arrays in O(1)
* Example)
* v0 | v1 | val
* 0 | 0 | 10
* product of arrays in O(1)
* Example)
* v0 | v1 | val
* 0 | 0 | 10
* 0 | 1 | 21
* 1 | 0 | 32
* 1 | 1 | 43
Expand All @@ -75,13 +74,13 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
* @brief Return ith key in keys_ as a DiscreteKey
* @param i ith key in keys_
* @return DiscreteKey
* */
*/
DiscreteKey discreteKey(size_t i) const {
return DiscreteKey(keys_[i], cardinalities_.at(keys_[i]));
}

/// Convert probability table given as doubles to SparseVector.
/// Example) {0, 1, 1, 0, 0, 1, 0} -> values: {1, 1, 1}, indices: {1, 2, 5}
/// Example) {0, 1, 1, 0, 0, 1, 0} -> values: {1, 1, 1}, indices: {1, 2, 5}
static Eigen::SparseVector<double> Convert(const std::vector<double>& table);

/// Convert probability table given as string to SparseVector.
Expand Down Expand Up @@ -142,6 +141,9 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
TableFactor(const DiscreteKey& key, const std::vector<double>& row)
: TableFactor(DiscreteKeys{key}, row) {}

/** Construct from a DiscreteConditional type */
explicit TableFactor(const DiscreteConditional& c);

/// @}
/// @name Testable
/// @{
Expand Down Expand Up @@ -180,8 +182,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {

static double safe_div(const double& a, const double& b);

size_t cardinality(Key j) const { return cardinalities_.at(j); }

/// divide by factor f (safely)
TableFactor operator/(const TableFactor& f) const {
return apply(f, safe_div);
Expand Down Expand Up @@ -274,9 +274,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
/// Enumerate all values into a map from values to double.
std::vector<std::pair<DiscreteValues, double>> enumerate() const;

/// Return all the discrete keys associated with this factor.
DiscreteKeys discreteKeys() const;

/**
* @brief Prune the decision tree of discrete variables.
*
Expand Down
5 changes: 5 additions & 0 deletions gtsam/discrete/tests/testDecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ TEST( DecisionTreeFactor, constructors)

// Assert that error = -log(value)
EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9);

// Construct from DiscreteConditional
DiscreteConditional conditional(X | Y = "1/1 2/3 1/4");
DecisionTreeFactor f4(conditional);
EXPECT_DOUBLES_EQUAL(0.8, f4(values), 1e-9);
}

/* ************************************************************************* */
Expand Down
Loading