diff --git a/.github/workflows/build-linux.yml b/.github/workflows/build-linux.yml index f52e5eec36..7b13b66460 100644 --- a/.github/workflows/build-linux.yml +++ b/.github/workflows/build-linux.yml @@ -15,7 +15,7 @@ jobs: BOOST_VERSION: 1.67.0 strategy: - fail-fast: false + fail-fast: true matrix: # Github Actions requires a single row to be added to the build matrix. # See https://help.github.com/en/articles/workflow-syntax-for-github-actions. diff --git a/CMakeLists.txt b/CMakeLists.txt index d040f9e82a..7c37099a45 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,7 +11,7 @@ endif() set (GTSAM_VERSION_MAJOR 4) set (GTSAM_VERSION_MINOR 2) set (GTSAM_VERSION_PATCH 0) -set (GTSAM_PRERELEASE_VERSION "a2") +set (GTSAM_PRERELEASE_VERSION "a3") math (EXPR GTSAM_VERSION_NUMERIC "10000 * ${GTSAM_VERSION_MAJOR} + 100 * ${GTSAM_VERSION_MINOR} + ${GTSAM_VERSION_PATCH}") if (${GTSAM_VERSION_PATCH} EQUAL 0) diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index b5f6c0c4af..8beeb4c4a0 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -57,7 +57,7 @@ namespace gtsam { /** Default constructor for I/O */ DecisionTreeFactor(); - /** Constructor from Indices, Ordering, and AlgebraicDecisionDiagram */ + /** Constructor from DiscreteKeys and AlgebraicDecisionTree */ DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials); /** Constructor from doubles */ @@ -139,14 +139,14 @@ namespace gtsam { /** * Apply binary operator (*this) "op" f * @param f the second argument for op - * @param op a binary operator that operates on AlgebraicDecisionDiagram potentials + * @param op a binary operator that operates on AlgebraicDecisionTree */ DecisionTreeFactor apply(const DecisionTreeFactor& f, ADT::Binary op) const; /** * Combine frontal variables using binary operator "op" * @param nrFrontals nr. of frontal to combine variables in this factor - * @param op a binary operator that operates on AlgebraicDecisionDiagram potentials + * @param op a binary operator that operates on AlgebraicDecisionTree * @return shared pointer to newly created DecisionTreeFactor */ shared_ptr combine(size_t nrFrontals, ADT::Binary op) const; @@ -154,7 +154,7 @@ namespace gtsam { /** * Combine frontal variables in an Ordering using binary operator "op" * @param nrFrontals nr. of frontal to combine variables in this factor - * @param op a binary operator that operates on AlgebraicDecisionDiagram potentials + * @param op a binary operator that operates on AlgebraicDecisionTree * @return shared pointer to newly created DecisionTreeFactor */ shared_ptr combine(const Ordering& keys, ADT::Binary op) const; diff --git a/gtsam/discrete/DiscreteBayesNet.h b/gtsam/discrete/DiscreteBayesNet.h index 17dfe2c5ff..db20e7223a 100644 --- a/gtsam/discrete/DiscreteBayesNet.h +++ b/gtsam/discrete/DiscreteBayesNet.h @@ -19,7 +19,7 @@ #pragma once #include -#include +#include #include #include @@ -79,9 +79,9 @@ namespace gtsam { // Add inherited versions of add. using Base::add; - /** Add a DiscretePrior using a table or a string */ + /** Add a DiscreteDistribution using a table or a string */ void add(const DiscreteKey& key, const std::string& spec) { - emplace_shared(key, spec); + emplace_shared(key, spec); } /** Add a DiscreteCondtional */ diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 0bdc7d7b5a..e8aa4511d8 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -30,6 +30,7 @@ #include #include #include +#include using namespace std; using std::stringstream; @@ -38,38 +39,97 @@ using std::pair; namespace gtsam { // Instantiate base class -template class GTSAM_EXPORT Conditional ; +template class GTSAM_EXPORT + Conditional; -/* ******************************************************************************** */ +/* ************************************************************************** */ DiscreteConditional::DiscreteConditional(const size_t nrFrontals, - const DecisionTreeFactor& f) : - BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) { -} + const DecisionTreeFactor& f) + : BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {} -/* ******************************************************************************** */ +/* ************************************************************************** */ +DiscreteConditional::DiscreteConditional(size_t nrFrontals, + const DiscreteKeys& keys, + const ADT& potentials) + : BaseFactor(keys, potentials), BaseConditional(nrFrontals) {} + +/* ************************************************************************** */ DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint, - const DecisionTreeFactor& marginal) : - BaseFactor( - ISDEBUG("DiscreteConditional::COUNT") ? joint : joint / marginal), BaseConditional( - joint.size()-marginal.size()) { - if (ISDEBUG("DiscreteConditional::DiscreteConditional")) - cout << (firstFrontalKey()) << endl; //TODO Print all keys -} + const DecisionTreeFactor& marginal) + : BaseFactor(joint / marginal), + BaseConditional(joint.size() - marginal.size()) {} -/* ******************************************************************************** */ +/* ************************************************************************** */ DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint, - const DecisionTreeFactor& marginal, const Ordering& orderedKeys) : - DiscreteConditional(joint, marginal) { + const DecisionTreeFactor& marginal, + const Ordering& orderedKeys) + : DiscreteConditional(joint, marginal) { keys_.clear(); keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end()); } -/* ******************************************************************************** */ +/* ************************************************************************** */ DiscreteConditional::DiscreteConditional(const Signature& signature) : BaseFactor(signature.discreteKeys(), signature.cpt()), BaseConditional(1) {} -/* ******************************************************************************** */ +/* ************************************************************************** */ +DiscreteConditional DiscreteConditional::operator*( + const DiscreteConditional& other) const { + // Take union of frontal keys + std::set newFrontals; + for (auto&& key : this->frontals()) newFrontals.insert(key); + for (auto&& key : other.frontals()) newFrontals.insert(key); + + // Check if frontals overlapped + if (nrFrontals() + other.nrFrontals() > newFrontals.size()) + throw std::invalid_argument( + "DiscreteConditional::operator* called with overlapping frontal keys."); + + // Now, add cardinalities. + DiscreteKeys discreteKeys; + for (auto&& key : frontals()) + discreteKeys.emplace_back(key, cardinality(key)); + for (auto&& key : other.frontals()) + discreteKeys.emplace_back(key, other.cardinality(key)); + + // Sort + std::sort(discreteKeys.begin(), discreteKeys.end()); + + // Add parents to set, to make them unique + std::set parents; + for (auto&& key : this->parents()) + if (!newFrontals.count(key)) parents.emplace(key, cardinality(key)); + for (auto&& key : other.parents()) + if (!newFrontals.count(key)) parents.emplace(key, other.cardinality(key)); + + // Finally, add parents to keys, in order + for (auto&& dk : parents) discreteKeys.push_back(dk); + + ADT product = ADT::apply(other, ADT::Ring::mul); + return DiscreteConditional(newFrontals.size(), discreteKeys, product); +} + +/* ************************************************************************** */ +DiscreteConditional DiscreteConditional::marginal(Key key) const { + if (nrParents() > 0) + throw std::invalid_argument( + "DiscreteConditional::marginal: single argument version only valid for " + "fully specified joint distributions (i.e., no parents)."); + + // Calculate the keys as the frontal keys without the given key. + DiscreteKeys discreteKeys{{key, cardinality(key)}}; + + // Calculate sum + ADT adt(*this); + for (auto&& k : frontals()) + if (k != key) adt = adt.sum(k, cardinality(k)); + + // Return new factor + return DiscreteConditional(1, discreteKeys, adt); +} + +/* ************************************************************************** */ void DiscreteConditional::print(const string& s, const KeyFormatter& formatter) const { cout << s << " P( "; @@ -82,7 +142,7 @@ void DiscreteConditional::print(const string& s, cout << formatter(*it) << " "; } } - cout << ")"; + cout << "):\n"; ADT::print(""); cout << endl; } diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 4a83ff83a0..c3c8a66def 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -49,14 +49,21 @@ class GTSAM_EXPORT DiscreteConditional /// @name Standard Constructors /// @{ - /** default constructor needed for serialization */ + /// Default constructor needed for serialization. DiscreteConditional() {} - /** constructor from factor */ + /// Construct from factor, taking the first `nFrontals` keys as frontals. DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f); + /** + * Construct from DiscreteKeys and AlgebraicDecisionTree, taking the first + * `nFrontals` keys as frontals, in the order given. + */ + DiscreteConditional(size_t nFrontals, const DiscreteKeys& keys, + const ADT& potentials); + /** Construct from signature */ - DiscreteConditional(const Signature& signature); + explicit DiscreteConditional(const Signature& signature); /** * Construct from key, parents, and a Signature::Table specifying the @@ -82,31 +89,45 @@ class GTSAM_EXPORT DiscreteConditional const std::string& spec) : DiscreteConditional(Signature(key, parents, spec)) {} - /// No-parent specialization; can also use DiscretePrior. + /// No-parent specialization; can also use DiscreteDistribution. DiscreteConditional(const DiscreteKey& key, const std::string& spec) : DiscreteConditional(Signature(key, {}, spec)) {} - /** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */ + /** + * @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y) + * Assumes but *does not check* that f(Y)=sum_X f(X,Y). + */ DiscreteConditional(const DecisionTreeFactor& joint, const DecisionTreeFactor& marginal); - /** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */ + /** + * @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y) + * Assumes but *does not check* that f(Y)=sum_X f(X,Y). + * Makes sure the keys are ordered as given. Does not check orderedKeys. + */ DiscreteConditional(const DecisionTreeFactor& joint, const DecisionTreeFactor& marginal, const Ordering& orderedKeys); /** - * Combine several conditional into a single one. - * The conditionals must be given in increasing order, meaning that the - * parents of any conditional may not include a conditional coming before it. - * @param firstConditional Iterator to the first conditional to combine, must - * dereference to a shared_ptr. - * @param lastConditional Iterator to after the last conditional to combine, - * must dereference to a shared_ptr. - * */ - template - static shared_ptr Combine(ITERATOR firstConditional, - ITERATOR lastConditional); + * @brief Combine two conditionals, yielding a new conditional with the union + * of the frontal keys, ordered by gtsam::Key. + * + * The two conditionals must make a valid Bayes net fragment, i.e., + * the frontal variables cannot overlap, and must be acyclic: + * Example of correct use: + * P(A,B) = P(A|B) * P(B) + * P(A,B|C) = P(A|B) * P(B|C) + * P(A,B,C) = P(A,B|C) * P(C) + * Example of incorrect use: + * P(A|B) * P(A|C) = ? + * P(A|B) * P(B|A) = ? + * We check for overlapping frontals, but do *not* check for cyclic. + */ + DiscreteConditional operator*(const DiscreteConditional& other) const; + + /** Calculate marginal on given key, no parent case. */ + DiscreteConditional marginal(Key key) const; /// @} /// @name Testable @@ -136,11 +157,6 @@ class GTSAM_EXPORT DiscreteConditional return ADT::operator()(values); } - /** Convert to a factor */ - DecisionTreeFactor::shared_ptr toFactor() const { - return DecisionTreeFactor::shared_ptr(new DecisionTreeFactor(*this)); - } - /** Restrict to given parent values, returns DecisionTreeFactor */ DecisionTreeFactor::shared_ptr choose( const DiscreteValues& parentsValues) const; @@ -208,23 +224,4 @@ class GTSAM_EXPORT DiscreteConditional template <> struct traits : public Testable {}; -/* ************************************************************************* */ -template -DiscreteConditional::shared_ptr DiscreteConditional::Combine( - ITERATOR firstConditional, ITERATOR lastConditional) { - // TODO: check for being a clique - - // multiply all the potentials of the given conditionals - size_t nrFrontals = 0; - DecisionTreeFactor product; - for (ITERATOR it = firstConditional; it != lastConditional; - ++it, ++nrFrontals) { - DiscreteConditional::shared_ptr c = *it; - DecisionTreeFactor::shared_ptr factor = c->toFactor(); - product = (*factor) * product; - } - // and then create a new multi-frontal conditional - return boost::make_shared(nrFrontals, product); -} - } // namespace gtsam diff --git a/gtsam/discrete/DiscretePrior.cpp b/gtsam/discrete/DiscreteDistribution.cpp similarity index 71% rename from gtsam/discrete/DiscretePrior.cpp rename to gtsam/discrete/DiscreteDistribution.cpp index 3941e0199e..7397714709 100644 --- a/gtsam/discrete/DiscretePrior.cpp +++ b/gtsam/discrete/DiscreteDistribution.cpp @@ -10,21 +10,23 @@ * -------------------------------------------------------------------------- */ /** - * @file DiscretePrior.cpp + * @file DiscreteDistribution.cpp * @date December 2021 * @author Frank Dellaert */ -#include +#include + +#include namespace gtsam { -void DiscretePrior::print(const std::string& s, - const KeyFormatter& formatter) const { +void DiscreteDistribution::print(const std::string& s, + const KeyFormatter& formatter) const { Base::print(s, formatter); } -double DiscretePrior::operator()(size_t value) const { +double DiscreteDistribution::operator()(size_t value) const { if (nrFrontals() != 1) throw std::invalid_argument( "Single value operator can only be invoked on single-variable " @@ -34,10 +36,10 @@ double DiscretePrior::operator()(size_t value) const { return Base::operator()(values); } -std::vector DiscretePrior::pmf() const { +std::vector DiscreteDistribution::pmf() const { if (nrFrontals() != 1) throw std::invalid_argument( - "DiscretePrior::pmf only defined for single-variable priors"); + "DiscreteDistribution::pmf only defined for single-variable priors"); const size_t nrValues = cardinalities_.at(keys_[0]); std::vector array; array.reserve(nrValues); diff --git a/gtsam/discrete/DiscretePrior.h b/gtsam/discrete/DiscreteDistribution.h similarity index 63% rename from gtsam/discrete/DiscretePrior.h rename to gtsam/discrete/DiscreteDistribution.h index 9ac8acb17a..fae6e355bd 100644 --- a/gtsam/discrete/DiscretePrior.h +++ b/gtsam/discrete/DiscreteDistribution.h @@ -10,7 +10,7 @@ * -------------------------------------------------------------------------- */ /** - * @file DiscretePrior.h + * @file DiscreteDistribution.h * @date December 2021 * @author Frank Dellaert */ @@ -20,6 +20,7 @@ #include #include +#include namespace gtsam { @@ -27,7 +28,7 @@ namespace gtsam { * A prior probability on a set of discrete variables. * Derives from DiscreteConditional */ -class GTSAM_EXPORT DiscretePrior : public DiscreteConditional { +class GTSAM_EXPORT DiscreteDistribution : public DiscreteConditional { public: using Base = DiscreteConditional; @@ -35,35 +36,36 @@ class GTSAM_EXPORT DiscretePrior : public DiscreteConditional { /// @{ /// Default constructor needed for serialization. - DiscretePrior() {} + DiscreteDistribution() {} /// Constructor from factor. - DiscretePrior(const DecisionTreeFactor& f) : Base(f.size(), f) {} + explicit DiscreteDistribution(const DecisionTreeFactor& f) + : Base(f.size(), f) {} /** * Construct from a Signature. * - * Example: DiscretePrior P(D % "3/2"); + * Example: DiscreteDistribution P(D % "3/2"); */ - DiscretePrior(const Signature& s) : Base(s) {} + explicit DiscreteDistribution(const Signature& s) : Base(s) {} /** - * Construct from key and a Signature::Table specifying the - * conditional probability table (CPT). + * Construct from key and a vector of floats specifying the probability mass + * function (PMF). * - * Example: DiscretePrior P(D, table); + * Example: DiscreteDistribution P(D, {0.4, 0.6}); */ - DiscretePrior(const DiscreteKey& key, const Signature::Table& table) - : Base(Signature(key, {}, table)) {} + DiscreteDistribution(const DiscreteKey& key, const std::vector& spec) + : DiscreteDistribution(Signature(key, {}, Signature::Table{spec})) {} /** - * Construct from key and a string specifying the conditional - * probability table (CPT). + * Construct from key and a string specifying the probability mass function + * (PMF). * - * Example: DiscretePrior P(D, "9/1 2/8 3/7 1/9"); + * Example: DiscreteDistribution P(D, "9/1 2/8 3/7 1/9"); */ - DiscretePrior(const DiscreteKey& key, const std::string& spec) - : DiscretePrior(Signature(key, {}, spec)) {} + DiscreteDistribution(const DiscreteKey& key, const std::string& spec) + : DiscreteDistribution(Signature(key, {}, spec)) {} /// @} /// @name Testable @@ -102,10 +104,10 @@ class GTSAM_EXPORT DiscretePrior : public DiscreteConditional { /// @} }; -// DiscretePrior +// DiscreteDistribution // traits template <> -struct traits : public Testable {}; +struct traits : public Testable {}; } // namespace gtsam diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 218b790e88..7ce4bd9021 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -58,6 +58,15 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor { const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const; + + double operator()(const gtsam::DiscreteValues& values) const; + gtsam::DecisionTreeFactor operator*(const gtsam::DecisionTreeFactor& f) const; + size_t cardinality(gtsam::Key j) const; + gtsam::DecisionTreeFactor operator/(const gtsam::DecisionTreeFactor& f) const; + gtsam::DecisionTreeFactor* sum(size_t nrFrontals) const; + gtsam::DecisionTreeFactor* sum(const gtsam::Ordering& keys) const; + gtsam::DecisionTreeFactor* max(size_t nrFrontals) const; + string dot( const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, bool showZero = true) const; @@ -86,14 +95,18 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor { DiscreteConditional(const gtsam::DecisionTreeFactor& joint, const gtsam::DecisionTreeFactor& marginal, const gtsam::Ordering& orderedKeys); + gtsam::DiscreteConditional operator*( + const gtsam::DiscreteConditional& other) const; + DiscreteConditional marginal(gtsam::Key key) const; void print(string s = "Discrete Conditional\n", const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; bool equals(const gtsam::DiscreteConditional& other, double tol = 1e-9) const; + size_t nrFrontals() const; + size_t nrParents() const; void printSignature( string s = "Discrete Conditional: ", const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const; - gtsam::DecisionTreeFactor* toFactor() const; gtsam::DecisionTreeFactor* choose( const gtsam::DiscreteValues& parentsValues) const; gtsam::DecisionTreeFactor* likelihood( @@ -115,11 +128,12 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor { std::map> names) const; }; -#include -virtual class DiscretePrior : gtsam::DiscreteConditional { - DiscretePrior(); - DiscretePrior(const gtsam::DecisionTreeFactor& f); - DiscretePrior(const gtsam::DiscreteKey& key, string spec); +#include +virtual class DiscreteDistribution : gtsam::DiscreteConditional { + DiscreteDistribution(); + DiscreteDistribution(const gtsam::DecisionTreeFactor& f); + DiscreteDistribution(const gtsam::DiscreteKey& key, string spec); + DiscreteDistribution(const gtsam::DiscreteKey& key, std::vector spec); void print(string s = "Discrete Prior\n", const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index 594134edf7..92145b8b76 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -17,10 +17,12 @@ * @author Duy-Nguyen Ta */ -#include -#include -#include #include +#include +#include +#include +#include + #include using namespace boost::assign; @@ -51,17 +53,21 @@ TEST( DecisionTreeFactor, constructors) } /* ************************************************************************* */ -TEST_UNSAFE( DecisionTreeFactor, multiplication) -{ - DiscreteKey v0(0,2), v1(1,2), v2(2,2); +TEST(DecisionTreeFactor, multiplication) { + DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2); + // Multiply with a DiscreteDistribution, i.e., Bayes Law! + DiscreteDistribution prior(v1 % "1/3"); DecisionTreeFactor f1(v0 & v1, "1 2 3 4"); - DecisionTreeFactor f2(v1 & v2, "5 6 7 8"); - - DecisionTreeFactor expected(v0 & v1 & v2, "5 6 14 16 15 18 28 32"); + DecisionTreeFactor expected(v0 & v1, "0.25 1.5 0.75 3"); + CHECK(assert_equal(expected, static_cast(prior) * f1)); + CHECK(assert_equal(expected, f1 * prior)); + // Multiply two factors + DecisionTreeFactor f2(v1 & v2, "5 6 7 8"); DecisionTreeFactor actual = f1 * f2; - CHECK(assert_equal(expected, actual)); + DecisionTreeFactor expected2(v0 & v1 & v2, "5 6 14 16 15 18 28 32"); + CHECK(assert_equal(expected2, actual)); } /* ************************************************************************* */ diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp index 3fb67a615c..1256595170 100644 --- a/gtsam/discrete/tests/testDiscreteConditional.cpp +++ b/gtsam/discrete/tests/testDiscreteConditional.cpp @@ -34,20 +34,21 @@ using namespace gtsam; TEST(DiscreteConditional, constructors) { DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering ! - DiscreteConditional expected(X | Y = "1/1 2/3 1/4"); - EXPECT_LONGS_EQUAL(0, *(expected.beginFrontals())); - EXPECT_LONGS_EQUAL(2, *(expected.beginParents())); - EXPECT(expected.endParents() == expected.end()); - EXPECT(expected.endFrontals() == expected.beginParents()); + DiscreteConditional actual(X | Y = "1/1 2/3 1/4"); + EXPECT_LONGS_EQUAL(0, *(actual.beginFrontals())); + EXPECT_LONGS_EQUAL(2, *(actual.beginParents())); + EXPECT(actual.endParents() == actual.end()); + EXPECT(actual.endFrontals() == actual.beginParents()); DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8"); - DiscreteConditional actual1(1, f1); - EXPECT(assert_equal(expected, actual1, 1e-9)); + DiscreteConditional expected1(1, f1); + EXPECT(assert_equal(expected1, actual, 1e-9)); DecisionTreeFactor f2( X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); DiscreteConditional actual2(1, f2); - EXPECT(assert_equal(f2 / *f2.sum(1), *actual2.toFactor(), 1e-9)); + DecisionTreeFactor expected2 = f2 / *f2.sum(1); + EXPECT(assert_equal(expected2, static_cast(actual2))); } /* ************************************************************************* */ @@ -61,6 +62,7 @@ TEST(DiscreteConditional, constructors_alt_interface) { r3 += 1.0, 4.0; table += r1, r2, r3; DiscreteConditional actual1(X, {Y}, table); + DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8"); DiscreteConditional expected1(1, f1); EXPECT(assert_equal(expected1, actual1, 1e-9)); @@ -68,41 +70,141 @@ TEST(DiscreteConditional, constructors_alt_interface) { DecisionTreeFactor f2( X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); DiscreteConditional actual2(1, f2); - EXPECT(assert_equal(f2 / *f2.sum(1), *actual2.toFactor(), 1e-9)); + DecisionTreeFactor expected2 = f2 / *f2.sum(1); + EXPECT(assert_equal(expected2, static_cast(actual2))); } /* ************************************************************************* */ TEST(DiscreteConditional, constructors2) { - // Declare keys and ordering DiscreteKey C(0, 2), B(1, 2); - DecisionTreeFactor actual(C & B, "0.8 0.75 0.2 0.25"); Signature signature((C | B) = "4/1 3/1"); - DiscreteConditional expected(signature); - DecisionTreeFactor::shared_ptr expectedFactor = expected.toFactor(); - EXPECT(assert_equal(*expectedFactor, actual)); + DiscreteConditional actual(signature); + + DecisionTreeFactor expected(C & B, "0.8 0.75 0.2 0.25"); + EXPECT(assert_equal(expected, static_cast(actual))); } /* ************************************************************************* */ TEST(DiscreteConditional, constructors3) { - // Declare keys and ordering DiscreteKey C(0, 2), B(1, 2), A(2, 2); - DecisionTreeFactor actual(C & B & A, "0.8 0.5 0.5 0.2 0.2 0.5 0.5 0.8"); Signature signature((C | B, A) = "4/1 1/1 1/1 1/4"); - DiscreteConditional expected(signature); - DecisionTreeFactor::shared_ptr expectedFactor = expected.toFactor(); - EXPECT(assert_equal(*expectedFactor, actual)); + DiscreteConditional actual(signature); + + DecisionTreeFactor expected(C & B & A, "0.8 0.5 0.5 0.2 0.2 0.5 0.5 0.8"); + EXPECT(assert_equal(expected, static_cast(actual))); +} + +/* ************************************************************************* */ +// Check calculation of joint P(A,B) +TEST(DiscreteConditional, Multiply) { + DiscreteKey A(1, 2), B(0, 2); + DiscreteConditional conditional(A | B = "1/2 2/1"); + DiscreteConditional prior(B % "1/2"); + + // The expected factor + DecisionTreeFactor f(A & B, "1 4 2 2"); + DiscreteConditional expected(2, f); + + // P(A,B) = P(A|B) * P(B) = P(B) * P(A|B) + for (auto&& actual : {prior * conditional, conditional * prior}) { + EXPECT_LONGS_EQUAL(2, actual.nrFrontals()); + KeyVector frontals(actual.beginFrontals(), actual.endFrontals()); + EXPECT((frontals == KeyVector{0, 1})); + for (auto&& it : actual.enumerate()) { + const DiscreteValues& v = it.first; + EXPECT_DOUBLES_EQUAL(actual(v), conditional(v) * prior(v), 1e-9); + } + // And for good measure: + EXPECT(assert_equal(expected, actual)); + } } /* ************************************************************************* */ -TEST(DiscreteConditional, Combine) { - DiscreteKey A(0, 2), B(1, 2); - vector c; - c.push_back(boost::make_shared(A | B = "1/2 2/1")); - c.push_back(boost::make_shared(B % "1/2")); - DecisionTreeFactor factor(A & B, "0.111111 0.444444 0.222222 0.222222"); - DiscreteConditional expected(2, factor); - auto actual = DiscreteConditional::Combine(c.begin(), c.end()); - EXPECT(assert_equal(expected, *actual, 1e-5)); +// Check calculation of conditional joint P(A,B|C) +TEST(DiscreteConditional, Multiply2) { + DiscreteKey A(0, 2), B(1, 2), C(2, 2); + DiscreteConditional A_given_B(A | B = "1/3 3/1"); + DiscreteConditional B_given_C(B | C = "1/3 3/1"); + + // P(A,B|C) = P(A|B)P(B|C) = P(B|C)P(A|B) + for (auto&& actual : {A_given_B * B_given_C, B_given_C * A_given_B}) { + EXPECT_LONGS_EQUAL(2, actual.nrFrontals()); + EXPECT_LONGS_EQUAL(1, actual.nrParents()); + KeyVector frontals(actual.beginFrontals(), actual.endFrontals()); + EXPECT((frontals == KeyVector{0, 1})); + for (auto&& it : actual.enumerate()) { + const DiscreteValues& v = it.first; + EXPECT_DOUBLES_EQUAL(actual(v), A_given_B(v) * B_given_C(v), 1e-9); + } + } +} + +/* ************************************************************************* */ +// Check calculation of conditional joint P(A,B|C), double check keys +TEST(DiscreteConditional, Multiply3) { + DiscreteKey A(1, 2), B(2, 2), C(0, 2); // different keys!!! + DiscreteConditional A_given_B(A | B = "1/3 3/1"); + DiscreteConditional B_given_C(B | C = "1/3 3/1"); + + // P(A,B|C) = P(A|B)P(B|C) = P(B|C)P(A|B) + for (auto&& actual : {A_given_B * B_given_C, B_given_C * A_given_B}) { + EXPECT_LONGS_EQUAL(2, actual.nrFrontals()); + EXPECT_LONGS_EQUAL(1, actual.nrParents()); + KeyVector frontals(actual.beginFrontals(), actual.endFrontals()); + EXPECT((frontals == KeyVector{1, 2})); + for (auto&& it : actual.enumerate()) { + const DiscreteValues& v = it.first; + EXPECT_DOUBLES_EQUAL(actual(v), A_given_B(v) * B_given_C(v), 1e-9); + } + } +} + +/* ************************************************************************* */ +// Check calculation of conditional joint P(A,B,C|D,E) = P(A,B|D) P(C|D,E) +TEST(DiscreteConditional, Multiply4) { + DiscreteKey A(0, 2), B(1, 2), C(2, 2), D(4, 2), E(3, 2); + DiscreteConditional A_given_B(A | B = "1/3 3/1"); + DiscreteConditional B_given_D(B | D = "1/3 3/1"); + DiscreteConditional AB_given_D = A_given_B * B_given_D; + DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4"); + + // P(A,B,C|D,E) = P(A,B|D) P(C|D,E) = P(C|D,E) P(A,B|D) + for (auto&& actual : {AB_given_D * C_given_DE, C_given_DE * AB_given_D}) { + EXPECT_LONGS_EQUAL(3, actual.nrFrontals()); + EXPECT_LONGS_EQUAL(2, actual.nrParents()); + KeyVector frontals(actual.beginFrontals(), actual.endFrontals()); + EXPECT((frontals == KeyVector{0, 1, 2})); + KeyVector parents(actual.beginParents(), actual.endParents()); + EXPECT((parents == KeyVector{3, 4})); + for (auto&& it : actual.enumerate()) { + const DiscreteValues& v = it.first; + EXPECT_DOUBLES_EQUAL(actual(v), AB_given_D(v) * C_given_DE(v), 1e-9); + } + } +} + +/* ************************************************************************* */ +// Check calculation of marginals for joint P(A,B) +TEST(DiscreteConditional, marginals) { + DiscreteKey A(1, 2), B(0, 2); + DiscreteConditional conditional(A | B = "1/2 2/1"); + DiscreteConditional prior(B % "1/2"); + DiscreteConditional pAB = prior * conditional; + + DiscreteConditional actualA = pAB.marginal(A.first); + DiscreteConditional pA(A % "5/4"); + EXPECT(assert_equal(pA, actualA)); + EXPECT_LONGS_EQUAL(1, actualA.nrFrontals()); + EXPECT_LONGS_EQUAL(0, actualA.nrParents()); + KeyVector frontalsA(actualA.beginFrontals(), actualA.endFrontals()); + EXPECT((frontalsA == KeyVector{1})); + + DiscreteConditional actualB = pAB.marginal(B.first); + EXPECT(assert_equal(prior, actualB)); + EXPECT_LONGS_EQUAL(1, actualB.nrFrontals()); + EXPECT_LONGS_EQUAL(0, actualB.nrParents()); + KeyVector frontalsB(actualB.beginFrontals(), actualB.endFrontals()); + EXPECT((frontalsB == KeyVector{0})); } /* ************************************************************************* */ diff --git a/gtsam/discrete/tests/testDiscretePrior.cpp b/gtsam/discrete/tests/testDiscreteDistribution.cpp similarity index 53% rename from gtsam/discrete/tests/testDiscretePrior.cpp rename to gtsam/discrete/tests/testDiscreteDistribution.cpp index 23f093b229..5c0c42e737 100644 --- a/gtsam/discrete/tests/testDiscretePrior.cpp +++ b/gtsam/discrete/tests/testDiscreteDistribution.cpp @@ -11,47 +11,66 @@ /* * @file testDiscretePrior.cpp - * @brief unit tests for DiscretePrior + * @brief unit tests for DiscreteDistribution * @author Frank dellaert * @date December 2021 */ #include -#include +#include #include -using namespace std; using namespace gtsam; static const DiscreteKey X(0, 2); /* ************************************************************************* */ -TEST(DiscretePrior, constructors) { - DiscretePrior actual(X % "2/3"); +TEST(DiscreteDistribution, constructors) { + DecisionTreeFactor f(X, "0.4 0.6"); + DiscreteDistribution expected(f); + + DiscreteDistribution actual(X % "2/3"); EXPECT_LONGS_EQUAL(1, actual.nrFrontals()); EXPECT_LONGS_EQUAL(0, actual.nrParents()); - DecisionTreeFactor f(X, "0.4 0.6"); - DiscretePrior expected(f); EXPECT(assert_equal(expected, actual, 1e-9)); + + const std::vector pmf{0.4, 0.6}; + DiscreteDistribution actual2(X, pmf); + EXPECT_LONGS_EQUAL(1, actual2.nrFrontals()); + EXPECT_LONGS_EQUAL(0, actual2.nrParents()); + EXPECT(assert_equal(expected, actual2, 1e-9)); +} + +/* ************************************************************************* */ +TEST(DiscreteDistribution, Multiply) { + DiscreteKey A(0, 2), B(1, 2); + DiscreteConditional conditional(A | B = "1/2 2/1"); + DiscreteDistribution prior(B, "1/2"); + DiscreteConditional actual = prior * conditional; // P(A|B) * P(B) + + EXPECT_LONGS_EQUAL(2, actual.nrFrontals()); // = P(A,B) + DecisionTreeFactor factor(A & B, "1 4 2 2"); + DiscreteConditional expected(2, factor); + EXPECT(assert_equal(expected, actual, 1e-5)); } /* ************************************************************************* */ -TEST(DiscretePrior, operator) { - DiscretePrior prior(X % "2/3"); +TEST(DiscreteDistribution, operator) { + DiscreteDistribution prior(X % "2/3"); EXPECT_DOUBLES_EQUAL(prior(0), 0.4, 1e-9); EXPECT_DOUBLES_EQUAL(prior(1), 0.6, 1e-9); } /* ************************************************************************* */ -TEST(DiscretePrior, pmf) { - DiscretePrior prior(X % "2/3"); - vector expected {0.4, 0.6}; - EXPECT(prior.pmf() == expected); +TEST(DiscreteDistribution, pmf) { + DiscreteDistribution prior(X % "2/3"); + std::vector expected{0.4, 0.6}; + EXPECT(prior.pmf() == expected); } /* ************************************************************************* */ -TEST(DiscretePrior, sample) { - DiscretePrior prior(X % "2/3"); +TEST(DiscreteDistribution, sample) { + DiscreteDistribution prior(X % "2/3"); prior.sample(); } diff --git a/gtsam/linear/GaussianFactorGraph.h b/gtsam/linear/GaussianFactorGraph.h index 7bee4c9fb5..f392221222 100644 --- a/gtsam/linear/GaussianFactorGraph.h +++ b/gtsam/linear/GaussianFactorGraph.h @@ -154,7 +154,8 @@ namespace gtsam { /** Unnormalized probability. O(n) */ double probPrime(const VectorValues& c) const { - return exp(-0.5 * error(c)); + // NOTE the 0.5 constant is handled by the factor error. + return exp(-error(c)); } /** diff --git a/gtsam/linear/tests/testGaussianFactorGraph.cpp b/gtsam/linear/tests/testGaussianFactorGraph.cpp index bb07a36aae..41464a1109 100644 --- a/gtsam/linear/tests/testGaussianFactorGraph.cpp +++ b/gtsam/linear/tests/testGaussianFactorGraph.cpp @@ -426,6 +426,7 @@ TEST(GaussianFactorGraph, hessianDiagonal) { EXPECT(assert_equal(expected, actual)); } +/* ************************************************************************* */ TEST(GaussianFactorGraph, DenseSolve) { GaussianFactorGraph fg = createSimpleGaussianFactorGraph(); VectorValues expected = fg.optimize(); @@ -433,6 +434,28 @@ TEST(GaussianFactorGraph, DenseSolve) { EXPECT(assert_equal(expected, actual)); } +/* ************************************************************************* */ +TEST(GaussianFactorGraph, ProbPrime) { + GaussianFactorGraph gfg; + gfg.emplace_shared(1, I_1x1, Z_1x1, + noiseModel::Isotropic::Sigma(1, 1.0)); + + VectorValues values; + values.insert(1, I_1x1); + + // We are testing the normal distribution PDF where info matrix Σ = 1, + // mean mu = 0 and x = 1. + // Therefore factor squared error: y = 0.5 * (Σ*x - mu)^2 = + // 0.5 * (1.0 - 0)^2 = 0.5 + // NOTE the 0.5 constant is a part of the factor error. + EXPECT_DOUBLES_EQUAL(0.5, gfg.error(values), 1e-12); + + // The gaussian PDF value is: exp^(-0.5 * (Σ*x - mu)^2) / sqrt(2 * PI) + // Ignore the denominator and we get: exp^(-0.5 * (1.0)^2) = exp^(-0.5) + double expected = exp(-0.5); + EXPECT_DOUBLES_EQUAL(expected, gfg.probPrime(values), 1e-12); +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/nonlinear/NonlinearFactorGraph.cpp b/gtsam/nonlinear/NonlinearFactorGraph.cpp index 0d1ed31487..89236ea878 100644 --- a/gtsam/nonlinear/NonlinearFactorGraph.cpp +++ b/gtsam/nonlinear/NonlinearFactorGraph.cpp @@ -45,7 +45,8 @@ template class FactorGraph; /* ************************************************************************* */ double NonlinearFactorGraph::probPrime(const Values& values) const { - return exp(-0.5 * error(values)); + // NOTE the 0.5 constant is handled by the factor error. + return exp(-error(values)); } /* ************************************************************************* */ @@ -54,9 +55,14 @@ void NonlinearFactorGraph::print(const std::string& str, const KeyFormatter& key for (size_t i = 0; i < factors_.size(); i++) { stringstream ss; ss << "Factor " << i << ": "; - if (factors_[i] != nullptr) factors_[i]->print(ss.str(), keyFormatter); - cout << endl; + if (factors_[i] != nullptr) { + factors_[i]->print(ss.str(), keyFormatter); + cout << "\n"; + } else { + cout << ss.str() << "nullptr\n"; + } } + std::cout.flush(); } /* ************************************************************************* */ @@ -80,8 +86,9 @@ void NonlinearFactorGraph::printErrors(const Values& values, const std::string& factor->print(ss.str(), keyFormatter); cout << "error = " << errorValue << "\n"; } - cout << endl; // only one "endl" at end might be faster, \n for each factor + cout << "\n"; } + std::cout.flush(); } /* ************************************************************************* */ diff --git a/gtsam/nonlinear/NonlinearFactorGraph.h b/gtsam/nonlinear/NonlinearFactorGraph.h index 160e469241..ea8748f63b 100644 --- a/gtsam/nonlinear/NonlinearFactorGraph.h +++ b/gtsam/nonlinear/NonlinearFactorGraph.h @@ -90,7 +90,7 @@ namespace gtsam { /** Test equality */ bool equals(const NonlinearFactorGraph& other, double tol = 1e-9) const; - /** unnormalized error, \f$ 0.5 \sum_i (h_i(X_i)-z)^2/\sigma^2 \f$ in the most common case */ + /** unnormalized error, \f$ \sum_i 0.5 (h_i(X_i)-z)^2 / \sigma^2 \f$ in the most common case */ double error(const Values& values) const; /** Unnormalized probability. O(n) */ diff --git a/gtsam/slam/slam.i b/gtsam/slam/slam.i index a0a7329dd8..602b2afe3c 100644 --- a/gtsam/slam/slam.i +++ b/gtsam/slam/slam.i @@ -11,7 +11,7 @@ namespace gtsam { // ###### #include -template virtual class BetweenFactor : gtsam::NoiseModelFactor { diff --git a/python/gtsam/tests/test_DecisionTreeFactor.py b/python/gtsam/tests/test_DecisionTreeFactor.py index 12a60d5cb1..0499e72154 100644 --- a/python/gtsam/tests/test_DecisionTreeFactor.py +++ b/python/gtsam/tests/test_DecisionTreeFactor.py @@ -13,7 +13,7 @@ import unittest -from gtsam import DecisionTreeFactor, DecisionTreeFactor, DiscreteKeys +from gtsam import DecisionTreeFactor, DiscreteValues, DiscreteDistribution, Ordering from gtsam.utils.test_case import GtsamTestCase @@ -21,15 +21,59 @@ class TestDecisionTreeFactor(GtsamTestCase): """Tests for DecisionTreeFactors.""" def setUp(self): - A = (12, 3) - B = (5, 2) - self.factor = DecisionTreeFactor([A, B], "1 2 3 4 5 6") + self.A = (12, 3) + self.B = (5, 2) + self.factor = DecisionTreeFactor([self.A, self.B], "1 2 3 4 5 6") def test_enumerate(self): actual = self.factor.enumerate() _, values = zip(*actual) self.assertEqual(list(values), [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + def test_multiplication(self): + """Test whether multiplication works with overloading.""" + v0 = (0, 2) + v1 = (1, 2) + v2 = (2, 2) + + # Multiply with a DiscreteDistribution, i.e., Bayes Law! + prior = DiscreteDistribution(v1, [1, 3]) + f1 = DecisionTreeFactor([v0, v1], "1 2 3 4") + expected = DecisionTreeFactor([v0, v1], "0.25 1.5 0.75 3") + self.gtsamAssertEquals(DecisionTreeFactor(prior) * f1, expected) + self.gtsamAssertEquals(f1 * prior, expected) + + # Multiply two factors + f2 = DecisionTreeFactor([v1, v2], "5 6 7 8") + actual = f1 * f2 + expected2 = DecisionTreeFactor([v0, v1, v2], "5 6 14 16 15 18 28 32") + self.gtsamAssertEquals(actual, expected2) + + def test_methods(self): + """Test whether we can call methods in python.""" + # double operator()(const DiscreteValues& values) const; + values = DiscreteValues() + values[self.A[0]] = 0 + values[self.B[0]] = 0 + self.assertIsInstance(self.factor(values), float) + + # size_t cardinality(Key j) const; + self.assertIsInstance(self.factor.cardinality(self.A[0]), int) + + # DecisionTreeFactor operator/(const DecisionTreeFactor& f) const; + self.assertIsInstance(self.factor / self.factor, DecisionTreeFactor) + + # DecisionTreeFactor* sum(size_t nrFrontals) const; + self.assertIsInstance(self.factor.sum(1), DecisionTreeFactor) + + # DecisionTreeFactor* sum(const Ordering& keys) const; + ordering = Ordering() + ordering.push_back(self.A[0]) + self.assertIsInstance(self.factor.sum(ordering), DecisionTreeFactor) + + # DecisionTreeFactor* max(size_t nrFrontals) const; + self.assertIsInstance(self.factor.max(1), DecisionTreeFactor) + def test_markdown(self): """Test whether the _repr_markdown_ method.""" diff --git a/python/gtsam/tests/test_DiscreteBayesNet.py b/python/gtsam/tests/test_DiscreteBayesNet.py index bdd5a05464..36f0d153d9 100644 --- a/python/gtsam/tests/test_DiscreteBayesNet.py +++ b/python/gtsam/tests/test_DiscreteBayesNet.py @@ -14,7 +14,7 @@ import unittest from gtsam import (DiscreteBayesNet, DiscreteConditional, DiscreteFactorGraph, - DiscreteKeys, DiscretePrior, DiscreteValues, Ordering) + DiscreteKeys, DiscreteDistribution, DiscreteValues, Ordering) from gtsam.utils.test_case import GtsamTestCase @@ -74,7 +74,7 @@ def test_Asia(self): for j in range(8): ordering.push_back(j) chordal = fg.eliminateSequential(ordering) - expected2 = DiscretePrior(Bronchitis, "11/9") + expected2 = DiscreteDistribution(Bronchitis, "11/9") self.gtsamAssertEquals(chordal.at(7), expected2) # solve diff --git a/python/gtsam/tests/test_DiscreteConditional.py b/python/gtsam/tests/test_DiscreteConditional.py index 0ae66c2d40..241a5f0be9 100644 --- a/python/gtsam/tests/test_DiscreteConditional.py +++ b/python/gtsam/tests/test_DiscreteConditional.py @@ -16,6 +16,13 @@ from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteKeys from gtsam.utils.test_case import GtsamTestCase +# Some DiscreteKeys for binary variables: +A = 0, 2 +B = 1, 2 +C = 2, 2 +D = 4, 2 +E = 3, 2 + class TestDiscreteConditional(GtsamTestCase): """Tests for Discrete Conditionals.""" @@ -36,6 +43,53 @@ def test_single_value_versions(self): actual = conditional.sample(2) self.assertIsInstance(actual, int) + def test_multiply(self): + """Check calculation of joint P(A,B)""" + conditional = DiscreteConditional(A, [B], "1/2 2/1") + prior = DiscreteConditional(B, "1/2") + + # P(A,B) = P(A|B) * P(B) = P(B) * P(A|B) + for actual in [prior * conditional, conditional * prior]: + self.assertEqual(2, actual.nrFrontals()) + for v, value in actual.enumerate(): + self.assertAlmostEqual(actual(v), conditional(v) * prior(v)) + + def test_multiply2(self): + """Check calculation of conditional joint P(A,B|C)""" + A_given_B = DiscreteConditional(A, [B], "1/3 3/1") + B_given_C = DiscreteConditional(B, [C], "1/3 3/1") + + # P(A,B|C) = P(A|B)P(B|C) = P(B|C)P(A|B) + for actual in [A_given_B * B_given_C, B_given_C * A_given_B]: + self.assertEqual(2, actual.nrFrontals()) + self.assertEqual(1, actual.nrParents()) + for v, value in actual.enumerate(): + self.assertAlmostEqual(actual(v), A_given_B(v) * B_given_C(v)) + + def test_multiply4(self): + """Check calculation of joint P(A,B,C|D,E) = P(A,B|D) P(C|D,E)""" + A_given_B = DiscreteConditional(A, [B], "1/3 3/1") + B_given_D = DiscreteConditional(B, [D], "1/3 3/1") + AB_given_D = A_given_B * B_given_D + C_given_DE = DiscreteConditional(C, [D, E], "4/1 1/1 1/1 1/4") + + # P(A,B,C|D,E) = P(A,B|D) P(C|D,E) = P(C|D,E) P(A,B|D) + for actual in [AB_given_D * C_given_DE, C_given_DE * AB_given_D]: + self.assertEqual(3, actual.nrFrontals()) + self.assertEqual(2, actual.nrParents()) + for v, value in actual.enumerate(): + self.assertAlmostEqual( + actual(v), AB_given_D(v) * C_given_DE(v)) + + def test_marginals(self): + conditional = DiscreteConditional(A, [B], "1/2 2/1") + prior = DiscreteConditional(B, "1/2") + pAB = prior * conditional + self.gtsamAssertEquals(prior, pAB.marginal(B[0])) + + pA = DiscreteConditional(A, "5/4") + self.gtsamAssertEquals(pA, pAB.marginal(A[0])) + def test_markdown(self): """Test whether the _repr_markdown_ method.""" @@ -48,8 +102,7 @@ def test_markdown(self): conditional = DiscreteConditional(A, parents, "0/1 1/3 1/1 3/1 0/1 1/0") - expected = \ - " *P(A|B,C):*\n\n" \ + expected = " *P(A|B,C):*\n\n" \ "|*B*|*C*|0|1|\n" \ "|:-:|:-:|:-:|:-:|\n" \ "|0|0|0|1|\n" \ diff --git a/python/gtsam/tests/test_DiscretePrior.py b/python/gtsam/tests/test_DiscreteDistribution.py similarity index 74% rename from python/gtsam/tests/test_DiscretePrior.py rename to python/gtsam/tests/test_DiscreteDistribution.py index 2c923589ce..fa999fd6b5 100644 --- a/python/gtsam/tests/test_DiscretePrior.py +++ b/python/gtsam/tests/test_DiscreteDistribution.py @@ -14,7 +14,7 @@ import unittest import numpy as np -from gtsam import DecisionTreeFactor, DiscreteKeys, DiscretePrior +from gtsam import DecisionTreeFactor, DiscreteKeys, DiscreteDistribution from gtsam.utils.test_case import GtsamTestCase X = 0, 2 @@ -25,32 +25,36 @@ class TestDiscretePrior(GtsamTestCase): def test_constructor(self): """Test various constructors.""" - actual = DiscretePrior(X, "2/3") keys = DiscreteKeys() keys.push_back(X) f = DecisionTreeFactor(keys, "0.4 0.6") - expected = DiscretePrior(f) + expected = DiscreteDistribution(f) + + actual = DiscreteDistribution(X, "2/3") self.gtsamAssertEquals(actual, expected) + actual2 = DiscreteDistribution(X, [0.4, 0.6]) + self.gtsamAssertEquals(actual2, expected) + def test_operator(self): - prior = DiscretePrior(X, "2/3") + prior = DiscreteDistribution(X, "2/3") self.assertAlmostEqual(prior(0), 0.4) self.assertAlmostEqual(prior(1), 0.6) def test_pmf(self): - prior = DiscretePrior(X, "2/3") + prior = DiscreteDistribution(X, "2/3") expected = np.array([0.4, 0.6]) np.testing.assert_allclose(expected, prior.pmf()) def test_sample(self): - prior = DiscretePrior(X, "2/3") + prior = DiscreteDistribution(X, "2/3") actual = prior.sample() self.assertIsInstance(actual, int) def test_markdown(self): """Test the _repr_markdown_ method.""" - prior = DiscretePrior(X, "2/3") + prior = DiscreteDistribution(X, "2/3") expected = " *P(0):*\n\n" \ "|0|value|\n" \ "|:-:|:-:|\n" \ diff --git a/tests/testNonlinearFactorGraph.cpp b/tests/testNonlinearFactorGraph.cpp index 4dec08f45c..8a360e4542 100644 --- a/tests/testNonlinearFactorGraph.cpp +++ b/tests/testNonlinearFactorGraph.cpp @@ -107,6 +107,24 @@ TEST( NonlinearFactorGraph, probPrime ) DOUBLES_EQUAL(expected,actual,0); } +/* ************************************************************************* */ +TEST(NonlinearFactorGraph, ProbPrime2) { + NonlinearFactorGraph fg; + fg.emplace_shared>(1, 0.0, + noiseModel::Isotropic::Sigma(1, 1.0)); + + Values values; + values.insert(1, 1.0); + + // The prior factor squared error is: 0.5. + EXPECT_DOUBLES_EQUAL(0.5, fg.error(values), 1e-12); + + // The probability value is: exp^(-factor_error) / sqrt(2 * PI) + // Ignore the denominator and we get: exp^(-factor_error) = exp^(-0.5) + double expected = exp(-0.5); + EXPECT_DOUBLES_EQUAL(expected, fg.probPrime(values), 1e-12); +} + /* ************************************************************************* */ TEST( NonlinearFactorGraph, linearize ) {