Skip to content

Commit

Permalink
Rename DiscretePrior -> DiscreteDistribution
Browse files Browse the repository at this point in the history
  • Loading branch information
dellaert committed Jan 16, 2022
1 parent 0b11b12 commit 4235334
Show file tree
Hide file tree
Showing 10 changed files with 70 additions and 67 deletions.
6 changes: 3 additions & 3 deletions gtsam/discrete/DiscreteBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#pragma once

#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/DiscretePrior.h>
#include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/inference/BayesNet.h>
#include <gtsam/inference/FactorGraph.h>

Expand Down Expand Up @@ -79,9 +79,9 @@ namespace gtsam {
// Add inherited versions of add.
using Base::add;

/** Add a DiscretePrior using a table or a string */
/** Add a DiscreteDistribution using a table or a string */
void add(const DiscreteKey& key, const std::string& spec) {
emplace_shared<DiscretePrior>(key, spec);
emplace_shared<DiscreteDistribution>(key, spec);
}

/** Add a DiscreteCondtional */
Expand Down
2 changes: 1 addition & 1 deletion gtsam/discrete/DiscreteConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class GTSAM_EXPORT DiscreteConditional
const std::string& spec)
: DiscreteConditional(Signature(key, parents, spec)) {}

/// No-parent specialization; can also use DiscretePrior.
/// No-parent specialization; can also use DiscreteDistribution.
DiscreteConditional(const DiscreteKey& key, const std::string& spec)
: DiscreteConditional(Signature(key, {}, spec)) {}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,23 @@
* -------------------------------------------------------------------------- */

/**
* @file DiscretePrior.cpp
* @file DiscreteDistribution.cpp
* @date December 2021
* @author Frank Dellaert
*/

#include <gtsam/discrete/DiscretePrior.h>
#include <gtsam/discrete/DiscreteDistribution.h>

#include <vector>

namespace gtsam {

void DiscretePrior::print(const std::string& s,
const KeyFormatter& formatter) const {
void DiscreteDistribution::print(const std::string& s,
const KeyFormatter& formatter) const {
Base::print(s, formatter);
}

double DiscretePrior::operator()(size_t value) const {
double DiscreteDistribution::operator()(size_t value) const {
if (nrFrontals() != 1)
throw std::invalid_argument(
"Single value operator can only be invoked on single-variable "
Expand All @@ -34,10 +36,10 @@ double DiscretePrior::operator()(size_t value) const {
return Base::operator()(values);
}

std::vector<double> DiscretePrior::pmf() const {
std::vector<double> DiscreteDistribution::pmf() const {
if (nrFrontals() != 1)
throw std::invalid_argument(
"DiscretePrior::pmf only defined for single-variable priors");
"DiscreteDistribution::pmf only defined for single-variable priors");
const size_t nrValues = cardinalities_.at(keys_[0]);
std::vector<double> array;
array.reserve(nrValues);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
* -------------------------------------------------------------------------- */

/**
* @file DiscretePrior.h
* @file DiscreteDistribution.h
* @date December 2021
* @author Frank Dellaert
*/
Expand All @@ -20,50 +20,52 @@
#include <gtsam/discrete/DiscreteConditional.h>

#include <string>
#include <vector>

namespace gtsam {

/**
* A prior probability on a set of discrete variables.
* Derives from DiscreteConditional
*/
class GTSAM_EXPORT DiscretePrior : public DiscreteConditional {
class GTSAM_EXPORT DiscreteDistribution : public DiscreteConditional {
public:
using Base = DiscreteConditional;

/// @name Standard Constructors
/// @{

/// Default constructor needed for serialization.
DiscretePrior() {}
DiscreteDistribution() {}

/// Constructor from factor.
DiscretePrior(const DecisionTreeFactor& f) : Base(f.size(), f) {}
explicit DiscreteDistribution(const DecisionTreeFactor& f)
: Base(f.size(), f) {}

/**
* Construct from a Signature.
*
* Example: DiscretePrior P(D % "3/2");
* Example: DiscreteDistribution P(D % "3/2");
*/
DiscretePrior(const Signature& s) : Base(s) {}
explicit DiscreteDistribution(const Signature& s) : Base(s) {}

/**
* Construct from key and a vector of floats specifying the probability mass
* function (PMF).
*
* Example: DiscretePrior P(D, {0.4, 0.6});
* Example: DiscreteDistribution P(D, {0.4, 0.6});
*/
DiscretePrior(const DiscreteKey& key, const std::vector<double>& spec)
: DiscretePrior(Signature(key, {}, Signature::Table{spec})) {}
DiscreteDistribution(const DiscreteKey& key, const std::vector<double>& spec)
: DiscreteDistribution(Signature(key, {}, Signature::Table{spec})) {}

/**
* Construct from key and a string specifying the probability mass function
* (PMF).
*
* Example: DiscretePrior P(D, "9/1 2/8 3/7 1/9");
* Example: DiscreteDistribution P(D, "9/1 2/8 3/7 1/9");
*/
DiscretePrior(const DiscreteKey& key, const std::string& spec)
: DiscretePrior(Signature(key, {}, spec)) {}
DiscreteDistribution(const DiscreteKey& key, const std::string& spec)
: DiscreteDistribution(Signature(key, {}, spec)) {}

/// @}
/// @name Testable
Expand Down Expand Up @@ -102,10 +104,10 @@ class GTSAM_EXPORT DiscretePrior : public DiscreteConditional {

/// @}
};
// DiscretePrior
// DiscreteDistribution

// traits
template <>
struct traits<DiscretePrior> : public Testable<DiscretePrior> {};
struct traits<DiscreteDistribution> : public Testable<DiscreteDistribution> {};

} // namespace gtsam
12 changes: 6 additions & 6 deletions gtsam/discrete/discrete.i
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,12 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
std::map<gtsam::Key, std::vector<std::string>> names) const;
};

#include <gtsam/discrete/DiscretePrior.h>
virtual class DiscretePrior : gtsam::DiscreteConditional {
DiscretePrior();
DiscretePrior(const gtsam::DecisionTreeFactor& f);
DiscretePrior(const gtsam::DiscreteKey& key, string spec);
DiscretePrior(const gtsam::DiscreteKey& key, std::vector<double> spec);
#include <gtsam/discrete/DiscreteDistribution.h>
virtual class DiscreteDistribution : gtsam::DiscreteConditional {
DiscreteDistribution();
DiscreteDistribution(const gtsam::DecisionTreeFactor& f);
DiscreteDistribution(const gtsam::DiscreteKey& key, string spec);
DiscreteDistribution(const gtsam::DiscreteKey& key, std::vector<double> spec);
void print(string s = "Discrete Prior\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
Expand Down
6 changes: 3 additions & 3 deletions gtsam/discrete/tests/testDecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscretePrior.h>
#include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/discrete/Signature.h>

#include <boost/assign/std/map.hpp>
Expand Down Expand Up @@ -56,8 +56,8 @@ TEST( DecisionTreeFactor, constructors)
TEST(DecisionTreeFactor, multiplication) {
DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2);

// Multiply with a DiscretePrior, i.e., Bayes Law!
DiscretePrior prior(v1 % "1/3");
// Multiply with a DiscreteDistribution, i.e., Bayes Law!
DiscreteDistribution prior(v1 % "1/3");
DecisionTreeFactor f1(v0 & v1, "1 2 3 4");
DecisionTreeFactor expected(v0 & v1, "0.25 1.5 0.75 3");
CHECK(assert_equal(expected, static_cast<DecisionTreeFactor>(prior) * f1));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,42 +11,41 @@

/*
* @file testDiscretePrior.cpp
* @brief unit tests for DiscretePrior
* @brief unit tests for DiscreteDistribution
* @author Frank dellaert
* @date December 2021
*/

#include <CppUnitLite/TestHarness.h>
#include <gtsam/discrete/DiscretePrior.h>
#include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/discrete/Signature.h>

using namespace std;
using namespace gtsam;

static const DiscreteKey X(0, 2);

/* ************************************************************************* */
TEST(DiscretePrior, constructors) {
TEST(DiscreteDistribution, constructors) {
DecisionTreeFactor f(X, "0.4 0.6");
DiscretePrior expected(f);
DiscreteDistribution expected(f);

DiscretePrior actual(X % "2/3");
DiscreteDistribution actual(X % "2/3");
EXPECT_LONGS_EQUAL(1, actual.nrFrontals());
EXPECT_LONGS_EQUAL(0, actual.nrParents());
EXPECT(assert_equal(expected, actual, 1e-9));

const vector<double> pmf{0.4, 0.6};
DiscretePrior actual2(X, pmf);
const std::vector<double> pmf{0.4, 0.6};
DiscreteDistribution actual2(X, pmf);
EXPECT_LONGS_EQUAL(1, actual2.nrFrontals());
EXPECT_LONGS_EQUAL(0, actual2.nrParents());
EXPECT(assert_equal(expected, actual2, 1e-9));
}

/* ************************************************************************* */
TEST(DiscretePrior, Multiply) {
TEST(DiscreteDistribution, Multiply) {
DiscreteKey A(0, 2), B(1, 2);
DiscreteConditional conditional(A | B = "1/2 2/1");
DiscretePrior prior(B, "1/2");
DiscreteDistribution prior(B, "1/2");
DiscreteConditional actual = prior * conditional; // P(A|B) * P(B)

EXPECT_LONGS_EQUAL(2, actual.nrFrontals()); // = P(A,B)
Expand All @@ -56,22 +55,22 @@ TEST(DiscretePrior, Multiply) {
}

/* ************************************************************************* */
TEST(DiscretePrior, operator) {
DiscretePrior prior(X % "2/3");
TEST(DiscreteDistribution, operator) {
DiscreteDistribution prior(X % "2/3");
EXPECT_DOUBLES_EQUAL(prior(0), 0.4, 1e-9);
EXPECT_DOUBLES_EQUAL(prior(1), 0.6, 1e-9);
}

/* ************************************************************************* */
TEST(DiscretePrior, pmf) {
DiscretePrior prior(X % "2/3");
vector<double> expected {0.4, 0.6};
EXPECT(prior.pmf() == expected);
TEST(DiscreteDistribution, pmf) {
DiscreteDistribution prior(X % "2/3");
std::vector<double> expected{0.4, 0.6};
EXPECT(prior.pmf() == expected);
}

/* ************************************************************************* */
TEST(DiscretePrior, sample) {
DiscretePrior prior(X % "2/3");
TEST(DiscreteDistribution, sample) {
DiscreteDistribution prior(X % "2/3");
prior.sample();
}

Expand Down
6 changes: 3 additions & 3 deletions python/gtsam/tests/test_DecisionTreeFactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import unittest

from gtsam import DecisionTreeFactor, DiscreteValues, DiscretePrior, Ordering
from gtsam import DecisionTreeFactor, DiscreteValues, DiscreteDistribution, Ordering
from gtsam.utils.test_case import GtsamTestCase


Expand All @@ -36,8 +36,8 @@ def test_multiplication(self):
v1 = (1, 2)
v2 = (2, 2)

# Multiply with a DiscretePrior, i.e., Bayes Law!
prior = DiscretePrior(v1, [1, 3])
# Multiply with a DiscreteDistribution, i.e., Bayes Law!
prior = DiscreteDistribution(v1, [1, 3])
f1 = DecisionTreeFactor([v0, v1], "1 2 3 4")
expected = DecisionTreeFactor([v0, v1], "0.25 1.5 0.75 3")
self.gtsamAssertEquals(DecisionTreeFactor(prior) * f1, expected)
Expand Down
4 changes: 2 additions & 2 deletions python/gtsam/tests/test_DiscreteBayesNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import unittest

from gtsam import (DiscreteBayesNet, DiscreteConditional, DiscreteFactorGraph,
DiscreteKeys, DiscretePrior, DiscreteValues, Ordering)
DiscreteKeys, DiscreteDistribution, DiscreteValues, Ordering)
from gtsam.utils.test_case import GtsamTestCase


Expand Down Expand Up @@ -74,7 +74,7 @@ def test_Asia(self):
for j in range(8):
ordering.push_back(j)
chordal = fg.eliminateSequential(ordering)
expected2 = DiscretePrior(Bronchitis, "11/9")
expected2 = DiscreteDistribution(Bronchitis, "11/9")
self.gtsamAssertEquals(chordal.at(7), expected2)

# solve
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import unittest

import numpy as np
from gtsam import DecisionTreeFactor, DiscreteKeys, DiscretePrior
from gtsam import DecisionTreeFactor, DiscreteKeys, DiscreteDistribution
from gtsam.utils.test_case import GtsamTestCase

X = 0, 2
Expand All @@ -28,33 +28,33 @@ def test_constructor(self):
keys = DiscreteKeys()
keys.push_back(X)
f = DecisionTreeFactor(keys, "0.4 0.6")
expected = DiscretePrior(f)
actual = DiscretePrior(X, "2/3")
expected = DiscreteDistribution(f)

actual = DiscreteDistribution(X, "2/3")
self.gtsamAssertEquals(actual, expected)
actual2 = DiscretePrior(X, [0.4, 0.6])

actual2 = DiscreteDistribution(X, [0.4, 0.6])
self.gtsamAssertEquals(actual2, expected)

def test_operator(self):
prior = DiscretePrior(X, "2/3")
prior = DiscreteDistribution(X, "2/3")
self.assertAlmostEqual(prior(0), 0.4)
self.assertAlmostEqual(prior(1), 0.6)

def test_pmf(self):
prior = DiscretePrior(X, "2/3")
prior = DiscreteDistribution(X, "2/3")
expected = np.array([0.4, 0.6])
np.testing.assert_allclose(expected, prior.pmf())

def test_sample(self):
prior = DiscretePrior(X, "2/3")
prior = DiscreteDistribution(X, "2/3")
actual = prior.sample()
self.assertIsInstance(actual, int)

def test_markdown(self):
"""Test the _repr_markdown_ method."""

prior = DiscretePrior(X, "2/3")
prior = DiscreteDistribution(X, "2/3")
expected = " *P(0):*\n\n" \
"|0|value|\n" \
"|:-:|:-:|\n" \
Expand Down

0 comments on commit 4235334

Please sign in to comment.