Skip to content

Commit

Permalink
Merge branch 'develop' into feature/discrete_wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal committed Jul 17, 2023
2 parents 73eb405 + 016f77b commit e749b6d
Show file tree
Hide file tree
Showing 15 changed files with 179 additions and 155 deletions.
47 changes: 19 additions & 28 deletions gtsam/discrete/DecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,13 @@ namespace gtsam {

/* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
const ADT& potentials)
: DiscreteFactor(keys.indices()),
ADT(potentials),
cardinalities_(keys.cardinalities()) {}
const ADT& potentials)
: DiscreteFactor(keys.indices(), keys.cardinalities()), ADT(potentials) {}

/* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c)
: DiscreteFactor(c.keys()),
AlgebraicDecisionTree<Key>(c),
cardinalities_(c.cardinalities_) {}
: DiscreteFactor(c.keys(), c.cardinalities()),
AlgebraicDecisionTree<Key>(c) {}

/* ************************************************************************ */
bool DecisionTreeFactor::equals(const DiscreteFactor& other,
Expand Down Expand Up @@ -182,15 +179,12 @@ namespace gtsam {
}

/* ************************************************************************ */
DiscreteKeys DecisionTreeFactor::discreteKeys() const {
DiscreteKeys result;
for (auto&& key : keys()) {
DiscreteKey dkey(key, cardinality(key));
if (std::find(result.begin(), result.end(), dkey) == result.end()) {
result.push_back(dkey);
}
std::vector<double> DecisionTreeFactor::probabilities() const {
std::vector<double> probs;
for (auto&& [key, value] : enumerate()) {
probs.push_back(value);
}
return result;
return probs;
}

/* ************************************************************************ */
Expand Down Expand Up @@ -288,29 +282,26 @@ namespace gtsam {

/* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
const vector<double>& table)
: DiscreteFactor(keys.indices()),
AlgebraicDecisionTree<Key>(keys, table),
cardinalities_(keys.cardinalities()) {}
const vector<double>& table)
: DiscreteFactor(keys.indices(), keys.cardinalities()),
AlgebraicDecisionTree<Key>(keys, table) {}

/* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
const string& table)
: DiscreteFactor(keys.indices()),
AlgebraicDecisionTree<Key>(keys, table),
cardinalities_(keys.cardinalities()) {}
const string& table)
: DiscreteFactor(keys.indices(), keys.cardinalities()),
AlgebraicDecisionTree<Key>(keys, table) {}

/* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrAssignments) const {
const size_t N = maxNrAssignments;

// Get the probabilities in the decision tree so we can threshold.
std::vector<double> probabilities;
this->visitLeaf([&](const Leaf& leaf) {
size_t nrAssignments = leaf.nrAssignments();
double prob = leaf.constant();
probabilities.insert(probabilities.end(), nrAssignments, prob);
});
// NOTE(Varun) this is potentially slow due to the cartesian product
for (auto&& [assignment, prob] : this->enumerate()) {
probabilities.push_back(prob);
}

// The number of probabilities can be lower than max_leaves
if (probabilities.size() <= N) {
Expand Down
11 changes: 2 additions & 9 deletions gtsam/discrete/DecisionTreeFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,6 @@ namespace gtsam {
typedef std::shared_ptr<DecisionTreeFactor> shared_ptr;
typedef AlgebraicDecisionTree<Key> ADT;

protected:
std::map<Key, size_t> cardinalities_;

public:
/// @name Standard Constructors
/// @{

Expand Down Expand Up @@ -154,8 +150,6 @@ namespace gtsam {

static double safe_div(const double& a, const double& b);

size_t cardinality(Key j) const { return cardinalities_.at(j); }

/// divide by factor f (safely)
DecisionTreeFactor operator/(const DecisionTreeFactor& f) const {
return apply(f, safe_div);
Expand Down Expand Up @@ -214,8 +208,8 @@ namespace gtsam {
/// Enumerate all values into a map from values to double.
std::vector<std::pair<DiscreteValues, double>> enumerate() const;

/// Return all the discrete keys associated with this factor.
DiscreteKeys discreteKeys() const;
/// Get all the probabilities in order of assignment values
std::vector<double> probabilities() const;

/**
* @brief Prune the decision tree of discrete variables.
Expand Down Expand Up @@ -295,7 +289,6 @@ namespace gtsam {
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(ADT);
ar& BOOST_SERIALIZATION_NVP(cardinalities_);
}
#endif
};
Expand Down
12 changes: 12 additions & 0 deletions gtsam/discrete/DiscreteFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,18 @@ using namespace std;

namespace gtsam {

/* ************************************************************************ */
DiscreteKeys DiscreteFactor::discreteKeys() const {
DiscreteKeys result;
for (auto&& key : keys()) {
DiscreteKey dkey(key, cardinality(key));
if (std::find(result.begin(), result.end(), dkey) == result.end()) {
result.push_back(dkey);
}
}
return result;
}

/* ************************************************************************* */
double DiscreteFactor::error(const DiscreteValues& values) const {
return -std::log((*this)(values));
Expand Down
49 changes: 37 additions & 12 deletions gtsam/discrete/DiscreteFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,28 +36,35 @@ class HybridValues;
* @ingroup discrete
*/
class GTSAM_EXPORT DiscreteFactor: public Factor {

public:

public:
// typedefs needed to play nice with gtsam
typedef DiscreteFactor This; ///< This class
typedef std::shared_ptr<DiscreteFactor> shared_ptr; ///< shared_ptr to this class
typedef Factor Base; ///< Our base class
typedef DiscreteFactor This; ///< This class
typedef std::shared_ptr<DiscreteFactor>
shared_ptr; ///< shared_ptr to this class
typedef Factor Base; ///< Our base class

using Values = DiscreteValues; ///< backwards compatibility
using Values = DiscreteValues; ///< backwards compatibility

public:
protected:
/// Map of Keys and their cardinalities.
std::map<Key, size_t> cardinalities_;

public:
/// @name Standard Constructors
/// @{

/** Default constructor creates empty factor */
DiscreteFactor() {}

/** Construct from container of keys. This constructor is used internally from derived factor
* constructors, either from a container of keys or from a boost::assign::list_of. */
template<typename CONTAINER>
DiscreteFactor(const CONTAINER& keys) : Base(keys) {}
/**
* Construct from container of keys and map of cardinalities.
* This constructor is used internally from derived factor constructors,
* either from a container of keys or from a boost::assign::list_of.
*/
template <typename CONTAINER>
DiscreteFactor(const CONTAINER& keys,
const std::map<Key, size_t> cardinalities = {})
: Base(keys), cardinalities_(cardinalities) {}

/// @}
/// @name Testable
Expand All @@ -77,6 +84,13 @@ class GTSAM_EXPORT DiscreteFactor: public Factor {
/// @name Standard Interface
/// @{

/// Return all the discrete keys associated with this factor.
DiscreteKeys discreteKeys() const;

std::map<Key, size_t> cardinalities() const { return cardinalities_; }

size_t cardinality(Key j) const { return cardinalities_.at(j); }

/// Find value for given assignment of values to variables
virtual double operator()(const DiscreteValues&) const = 0;

Expand Down Expand Up @@ -124,6 +138,17 @@ class GTSAM_EXPORT DiscreteFactor: public Factor {
const Names& names = {}) const = 0;

/// @}

private:
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar& BOOST_SERIALIZATION_NVP(cardinalities_);
}
#endif
};
// DiscreteFactor

Expand Down
26 changes: 9 additions & 17 deletions gtsam/discrete/TableFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@
* @file TableFactor.cpp
* @brief discrete factor
* @date May 4, 2023
* @author Yoonwoo Kim
* @author Yoonwoo Kim, Varun Agrawal
*/

#include <gtsam/base/FastSet.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/TableFactor.h>
#include <gtsam/hybrid/HybridValues.h>

Expand All @@ -33,8 +34,7 @@ TableFactor::TableFactor() {}
/* ************************************************************************ */
TableFactor::TableFactor(const DiscreteKeys& dkeys,
const TableFactor& potentials)
: DiscreteFactor(dkeys.indices()),
cardinalities_(potentials.cardinalities_) {
: DiscreteFactor(dkeys.indices(), dkeys.cardinalities()) {
sparse_table_ = potentials.sparse_table_;
denominators_ = potentials.denominators_;
sorted_dkeys_ = discreteKeys();
Expand All @@ -44,18 +44,22 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys,
/* ************************************************************************ */
TableFactor::TableFactor(const DiscreteKeys& dkeys,
const Eigen::SparseVector<double>& table)
: DiscreteFactor(dkeys.indices()), sparse_table_(table.size()) {
: DiscreteFactor(dkeys.indices(), dkeys.cardinalities()),
sparse_table_(table.size()) {
sparse_table_ = table;
double denom = table.size();
for (const DiscreteKey& dkey : dkeys) {
cardinalities_.insert(dkey);
denom /= dkey.second;
denominators_.insert(std::pair<Key, double>(dkey.first, denom));
}
sorted_dkeys_ = discreteKeys();
sort(sorted_dkeys_.begin(), sorted_dkeys_.end());
}

/* ************************************************************************ */
TableFactor::TableFactor(const DiscreteConditional& c)
: TableFactor(c.discreteKeys(), c.probabilities()) {}

/* ************************************************************************ */
Eigen::SparseVector<double> TableFactor::Convert(
const std::vector<double>& table) {
Expand Down Expand Up @@ -435,18 +439,6 @@ std::vector<std::pair<DiscreteValues, double>> TableFactor::enumerate() const {
return result;
}

/* ************************************************************************ */
DiscreteKeys TableFactor::discreteKeys() const {
DiscreteKeys result;
for (auto&& key : keys()) {
DiscreteKey dkey(key, cardinality(key));
if (std::find(result.begin(), result.end(), dkey) == result.end()) {
result.push_back(dkey);
}
}
return result;
}

// Print out header.
/* ************************************************************************ */
string TableFactor::markdown(const KeyFormatter& keyFormatter,
Expand Down
25 changes: 11 additions & 14 deletions gtsam/discrete/TableFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
/**
* @file TableFactor.h
* @date May 4, 2023
* @author Yoonwoo Kim
* @author Yoonwoo Kim, Varun Agrawal
*/

#pragma once
Expand All @@ -32,6 +32,7 @@

namespace gtsam {

class DiscreteConditional;
class HybridValues;

/**
Expand All @@ -44,8 +45,6 @@ class HybridValues;
*/
class GTSAM_EXPORT TableFactor : public DiscreteFactor {
protected:
/// Map of Keys and their cardinalities.
std::map<Key, size_t> cardinalities_;
/// SparseVector of nonzero probabilities.
Eigen::SparseVector<double> sparse_table_;

Expand All @@ -57,10 +56,10 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {

/**
* @brief Uses lazy cartesian product to find nth entry in the cartesian
* product of arrays in O(1)
* Example)
* v0 | v1 | val
* 0 | 0 | 10
* product of arrays in O(1)
* Example)
* v0 | v1 | val
* 0 | 0 | 10
* 0 | 1 | 21
* 1 | 0 | 32
* 1 | 1 | 43
Expand All @@ -75,13 +74,13 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
* @brief Return ith key in keys_ as a DiscreteKey
* @param i ith key in keys_
* @return DiscreteKey
* */
*/
DiscreteKey discreteKey(size_t i) const {
return DiscreteKey(keys_[i], cardinalities_.at(keys_[i]));
}

/// Convert probability table given as doubles to SparseVector.
/// Example) {0, 1, 1, 0, 0, 1, 0} -> values: {1, 1, 1}, indices: {1, 2, 5}
/// Example) {0, 1, 1, 0, 0, 1, 0} -> values: {1, 1, 1}, indices: {1, 2, 5}
static Eigen::SparseVector<double> Convert(const std::vector<double>& table);

/// Convert probability table given as string to SparseVector.
Expand Down Expand Up @@ -142,6 +141,9 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
TableFactor(const DiscreteKey& key, const std::vector<double>& row)
: TableFactor(DiscreteKeys{key}, row) {}

/** Construct from a DiscreteConditional type */
explicit TableFactor(const DiscreteConditional& c);

/// @}
/// @name Testable
/// @{
Expand Down Expand Up @@ -180,8 +182,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {

static double safe_div(const double& a, const double& b);

size_t cardinality(Key j) const { return cardinalities_.at(j); }

/// divide by factor f (safely)
TableFactor operator/(const TableFactor& f) const {
return apply(f, safe_div);
Expand Down Expand Up @@ -274,9 +274,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
/// Enumerate all values into a map from values to double.
std::vector<std::pair<DiscreteValues, double>> enumerate() const;

/// Return all the discrete keys associated with this factor.
DiscreteKeys discreteKeys() const;

/**
* @brief Prune the decision tree of discrete variables.
*
Expand Down
5 changes: 5 additions & 0 deletions gtsam/discrete/tests/testDecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ TEST( DecisionTreeFactor, constructors)

// Assert that error = -log(value)
EXPECT_DOUBLES_EQUAL(-log(f1(x121)), f1.error(x121), 1e-9);

// Construct from DiscreteConditional
DiscreteConditional conditional(X | Y = "1/1 2/3 1/4");
DecisionTreeFactor f4(conditional);
EXPECT_DOUBLES_EQUAL(0.8, f4(x121), 1e-9);
}

/* ************************************************************************* */
Expand Down
Loading

0 comments on commit e749b6d

Please sign in to comment.