Skip to content

Commit

Permalink
Merge pull request #1670 from borglab/hybrid-printerrors
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal authored Jan 7, 2024
2 parents 42b5218 + 7ea1bbc commit 7cd46f8
Show file tree
Hide file tree
Showing 9 changed files with 307 additions and 11 deletions.
30 changes: 30 additions & 0 deletions gtsam/hybrid/HybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,36 @@ HybridValues HybridBayesNet::sample() const {
return sample(&kRandomNumberGenerator);
}

/* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::errorTree(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> result(0.0);

// Iterate over each conditional.
for (auto &&conditional : *this) {
if (auto gm = conditional->asMixture()) {
// If conditional is hybrid, compute error for all assignments.
result = result + gm->errorTree(continuousValues);

} else if (auto gc = conditional->asGaussian()) {
// If continuous, get the error and add it to the result
double error = gc->error(continuousValues);
// Add the computed error to every leaf of the result tree.
result = result.apply(
[error](double leaf_value) { return leaf_value + error; });

} else if (auto dc = conditional->asDiscrete()) {
// If discrete, add the discrete error in the right branch
result = result.apply(
[dc](const Assignment<Key> &assignment, double leaf_value) {
return leaf_value + dc->error(DiscreteValues(assignment));
});
}
}

return result;
}

/* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::logProbability(
const VectorValues &continuousValues) const {
Expand Down
17 changes: 17 additions & 0 deletions gtsam/hybrid/HybridBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,23 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
* @param continuousValues Continuous values at which to compute the error.
* @return AlgebraicDecisionTree<Key>
*/
AlgebraicDecisionTree<Key> errorTree(
const VectorValues &continuousValues) const;

/**
* @brief Error method using HybridValues which returns specific error for
* assignment.
*/
using Base::error;

/**
* @brief Compute log probability for each discrete assignment,
* and return as a tree.
*
* @param continuousValues Continuous values at which
* to compute the log probability.
* @return AlgebraicDecisionTree<Key>
*/
AlgebraicDecisionTree<Key> logProbability(
const VectorValues &continuousValues) const;

Expand Down
79 changes: 79 additions & 0 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,85 @@ const Ordering HybridOrdering(const HybridGaussianFactorGraph &graph) {
index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true);
}

/* ************************************************************************ */
void HybridGaussianFactorGraph::printErrors(
const HybridValues &values, const std::string &str,
const KeyFormatter &keyFormatter,
const std::function<bool(const Factor * /*factor*/,
double /*whitenedError*/, size_t /*index*/)>
&printCondition) const {
std::cout << str << "size: " << size() << std::endl << std::endl;

std::stringstream ss;

for (size_t i = 0; i < factors_.size(); i++) {
auto &&factor = factors_[i];
std::cout << "Factor " << i << ": ";

// Clear the stringstream
ss.str(std::string());

if (auto gmf = std::dynamic_pointer_cast<GaussianMixtureFactor>(factor)) {
if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
factor->print(ss.str(), keyFormatter);
std::cout << "error = ";
gmf->errorTree(values.continuous()).print("", keyFormatter);
std::cout << std::endl;
}
} else if (auto hc = std::dynamic_pointer_cast<HybridConditional>(factor)) {
if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
factor->print(ss.str(), keyFormatter);

if (hc->isContinuous()) {
std::cout << "error = " << hc->asGaussian()->error(values) << "\n";
} else if (hc->isDiscrete()) {
std::cout << "error = ";
hc->asDiscrete()->errorTree().print("", keyFormatter);
std::cout << "\n";
} else {
// Is hybrid
std::cout << "error = ";
hc->asMixture()->errorTree(values.continuous()).print();
std::cout << "\n";
}
}
} else if (auto gf = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
const double errorValue = (factor != nullptr ? gf->error(values) : .0);
if (!printCondition(factor.get(), errorValue, i))
continue; // User-provided filter did not pass

if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
factor->print(ss.str(), keyFormatter);
std::cout << "error = " << errorValue << "\n";
}
} else if (auto df = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
factor->print(ss.str(), keyFormatter);
std::cout << "error = ";
df->errorTree().print("", keyFormatter);
}

} else {
continue;
}

std::cout << "\n";
}
std::cout.flush();
}

/* ************************************************************************ */
static GaussianFactorGraphTree addGaussian(
const GaussianFactorGraphTree &gfgTree,
Expand Down
16 changes: 13 additions & 3 deletions gtsam/hybrid/HybridGaussianFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,19 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
/// @{

// TODO(dellaert): customize print and equals.
// void print(const std::string& s = "HybridGaussianFactorGraph",
// const KeyFormatter& keyFormatter = DefaultKeyFormatter) const
// override;
// void print(
// const std::string& s = "HybridGaussianFactorGraph",
// const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override;

void printErrors(
const HybridValues& values,
const std::string& str = "HybridGaussianFactorGraph: ",
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const std::function<bool(const Factor* /*factor*/,
double /*whitenedError*/, size_t /*index*/)>&
printCondition =
[](const Factor*, double, size_t) { return true; }) const;

// bool equals(const This& fg, double tol = 1e-9) const override;

/// @}
Expand Down
92 changes: 92 additions & 0 deletions gtsam/hybrid/HybridNonlinearFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,98 @@ void HybridNonlinearFactorGraph::print(const std::string& s,
}
}

/* ************************************************************************* */
void HybridNonlinearFactorGraph::printErrors(
const HybridValues& values, const std::string& str,
const KeyFormatter& keyFormatter,
const std::function<bool(const Factor* /*factor*/, double /*whitenedError*/,
size_t /*index*/)>& printCondition) const {
std::cout << str << "size: " << size() << std::endl << std::endl;

std::stringstream ss;

for (size_t i = 0; i < factors_.size(); i++) {
auto&& factor = factors_[i];
std::cout << "Factor " << i << ": ";

// Clear the stringstream
ss.str(std::string());

if (auto mf = std::dynamic_pointer_cast<MixtureFactor>(factor)) {
if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
factor->print(ss.str(), keyFormatter);
std::cout << "error = ";
mf->errorTree(values.nonlinear()).print("", keyFormatter);
std::cout << std::endl;
}
} else if (auto gmf =
std::dynamic_pointer_cast<GaussianMixtureFactor>(factor)) {
if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
factor->print(ss.str(), keyFormatter);
std::cout << "error = ";
gmf->errorTree(values.continuous()).print("", keyFormatter);
std::cout << std::endl;
}
} else if (auto gm = std::dynamic_pointer_cast<GaussianMixture>(factor)) {
if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
factor->print(ss.str(), keyFormatter);
std::cout << "error = ";
gm->errorTree(values.continuous()).print("", keyFormatter);
std::cout << std::endl;
}
} else if (auto nf = std::dynamic_pointer_cast<NonlinearFactor>(factor)) {
const double errorValue = (factor != nullptr ? nf->error(values) : .0);
if (!printCondition(factor.get(), errorValue, i))
continue; // User-provided filter did not pass

if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
factor->print(ss.str(), keyFormatter);
std::cout << "error = " << errorValue << "\n";
}
} else if (auto gf = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
const double errorValue = (factor != nullptr ? gf->error(values) : .0);
if (!printCondition(factor.get(), errorValue, i))
continue; // User-provided filter did not pass

if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
factor->print(ss.str(), keyFormatter);
std::cout << "error = " << errorValue << "\n";
}
} else if (auto df = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
factor->print(ss.str(), keyFormatter);
std::cout << "error = ";
df->errorTree().print("", keyFormatter);
std::cout << std::endl;
}

} else {
continue;
}

std::cout << "\n";
}
std::cout.flush();
}

/* ************************************************************************* */
HybridGaussianFactorGraph::shared_ptr HybridNonlinearFactorGraph::linearize(
const Values& continuousValues) const {
Expand Down
12 changes: 11 additions & 1 deletion gtsam/hybrid/HybridNonlinearFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph {
protected:
public:
using Base = HybridFactorGraph;
using This = HybridNonlinearFactorGraph; ///< this class
using This = HybridNonlinearFactorGraph; ///< this class
using shared_ptr = std::shared_ptr<This>; ///< shared_ptr to This

using Values = gtsam::Values; ///< backwards compatibility
Expand Down Expand Up @@ -63,6 +63,16 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph {
const std::string& s = "HybridNonlinearFactorGraph",
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override;

/** print errors along with factors*/
void printErrors(
const HybridValues& values,
const std::string& str = "HybridNonlinearFactorGraph: ",
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const std::function<bool(const Factor* /*factor*/,
double /*whitenedError*/, size_t /*index*/)>&
printCondition =
[](const Factor*, double, size_t) { return true; }) const;

/// @}
/// @name Standard Interface
/// @{
Expand Down
39 changes: 39 additions & 0 deletions gtsam/hybrid/tests/testHybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,45 @@ TEST(HybridBayesNet, Choose) {
*gbn.at(3)));
}

/* ****************************************************************************/
// Test error for a hybrid Bayes net P(X0|X1) P(X1|Asia) P(Asia).
TEST(HybridBayesNet, Error) {
const auto continuousConditional = GaussianConditional::sharedMeanAndStddev(
X(0), 2 * I_1x1, X(1), Vector1(-4.0), 5.0);

const SharedDiagonal model0 = noiseModel::Diagonal::Sigmas(Vector1(2.0)),
model1 = noiseModel::Diagonal::Sigmas(Vector1(3.0));

const auto conditional0 = std::make_shared<GaussianConditional>(
X(1), Vector1::Constant(5), I_1x1, model0),
conditional1 = std::make_shared<GaussianConditional>(
X(1), Vector1::Constant(2), I_1x1, model1);

auto gm =
new GaussianMixture({X(1)}, {}, {Asia}, {conditional0, conditional1});
// Create hybrid Bayes net.
HybridBayesNet bayesNet;
bayesNet.push_back(continuousConditional);
bayesNet.emplace_back(gm);
bayesNet.emplace_back(new DiscreteConditional(Asia, "99/1"));

// Create values at which to evaluate.
HybridValues values;
values.insert(asiaKey, 0);
values.insert(X(0), Vector1(-6));
values.insert(X(1), Vector1(1));

AlgebraicDecisionTree<Key> actual_errors =
bayesNet.errorTree(values.continuous());

// Regression.
// Manually added all the error values from the 3 conditional types.
AlgebraicDecisionTree<Key> expected_errors(
{Asia}, std::vector<double>{2.33005033585, 5.38619084965});

EXPECT(assert_equal(expected_errors, actual_errors));
}

/* ****************************************************************************/
// Test Bayes net optimize
TEST(HybridBayesNet, OptimizeAssignment) {
Expand Down
2 changes: 1 addition & 1 deletion gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,7 @@ bool ratioTest(const HybridBayesNet &bn, const VectorValues &measurements,
}

/* ****************************************************************************/
// Check that the factor graph unnormalized probability is proportional to the
// Check that the bayes net unnormalized probability is proportional to the
// Bayes net probability for the given measurements.
bool ratioTest(const HybridBayesNet &bn, const VectorValues &measurements,
const HybridBayesNet &posterior, size_t num_samples = 100) {
Expand Down
Loading

0 comments on commit 7cd46f8

Please sign in to comment.