Skip to content

Commit

Permalink
replace error with errorTree
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal committed Jan 5, 2024
1 parent ee5bda9 commit 7ea1bbc
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 12 deletions.
4 changes: 2 additions & 2 deletions gtsam/hybrid/HybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -282,15 +282,15 @@ HybridValues HybridBayesNet::sample() const {
}

/* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::error(
AlgebraicDecisionTree<Key> HybridBayesNet::errorTree(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> result(0.0);

// Iterate over each conditional.
for (auto &&conditional : *this) {
if (auto gm = conditional->asMixture()) {
// If conditional is hybrid, compute error for all assignments.
result = result + gm->error(continuousValues);
result = result + gm->errorTree(continuousValues);

} else if (auto gc = conditional->asGaussian()) {
// If continuous, get the error and add it to the result
Expand Down
3 changes: 2 additions & 1 deletion gtsam/hybrid/HybridBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,8 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
* @param continuousValues Continuous values at which to compute the error.
* @return AlgebraicDecisionTree<Key>
*/
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
AlgebraicDecisionTree<Key> errorTree(
const VectorValues &continuousValues) const;

/**
* @brief Error method using HybridValues which returns specific error for
Expand Down
8 changes: 4 additions & 4 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ void HybridGaussianFactorGraph::printErrors(
} else {
factor->print(ss.str(), keyFormatter);
std::cout << "error = ";
gmf->error(values.continuous()).print("", keyFormatter);
gmf->errorTree(values.continuous()).print("", keyFormatter);
std::cout << std::endl;
}
} else if (auto hc = std::dynamic_pointer_cast<HybridConditional>(factor)) {
Expand All @@ -113,12 +113,12 @@ void HybridGaussianFactorGraph::printErrors(
std::cout << "error = " << hc->asGaussian()->error(values) << "\n";
} else if (hc->isDiscrete()) {
std::cout << "error = ";
hc->asDiscrete()->error().print("", keyFormatter);
hc->asDiscrete()->errorTree().print("", keyFormatter);
std::cout << "\n";
} else {
// Is hybrid
std::cout << "error = ";
hc->asMixture()->error(values.continuous()).print();
hc->asMixture()->errorTree(values.continuous()).print();
std::cout << "\n";
}
}
Expand All @@ -141,7 +141,7 @@ void HybridGaussianFactorGraph::printErrors(
} else {
factor->print(ss.str(), keyFormatter);
std::cout << "error = ";
df->error().print("", keyFormatter);
df->errorTree().print("", keyFormatter);
}

} else {
Expand Down
8 changes: 4 additions & 4 deletions gtsam/hybrid/HybridNonlinearFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ void HybridNonlinearFactorGraph::printErrors(
} else {
factor->print(ss.str(), keyFormatter);
std::cout << "error = ";
mf->error(values.nonlinear()).print("", keyFormatter);
mf->errorTree(values.nonlinear()).print("", keyFormatter);
std::cout << std::endl;
}
} else if (auto gmf =
Expand All @@ -77,7 +77,7 @@ void HybridNonlinearFactorGraph::printErrors(
} else {
factor->print(ss.str(), keyFormatter);
std::cout << "error = ";
gmf->error(values.continuous()).print("", keyFormatter);
gmf->errorTree(values.continuous()).print("", keyFormatter);
std::cout << std::endl;
}
} else if (auto gm = std::dynamic_pointer_cast<GaussianMixture>(factor)) {
Expand All @@ -87,7 +87,7 @@ void HybridNonlinearFactorGraph::printErrors(
} else {
factor->print(ss.str(), keyFormatter);
std::cout << "error = ";
gm->error(values.continuous()).print("", keyFormatter);
gm->errorTree(values.continuous()).print("", keyFormatter);
std::cout << std::endl;
}
} else if (auto nf = std::dynamic_pointer_cast<NonlinearFactor>(factor)) {
Expand Down Expand Up @@ -121,7 +121,7 @@ void HybridNonlinearFactorGraph::printErrors(
} else {
factor->print(ss.str(), keyFormatter);
std::cout << "error = ";
df->error().print("", keyFormatter);
df->errorTree().print("", keyFormatter);
std::cout << std::endl;
}

Expand Down
2 changes: 1 addition & 1 deletion gtsam/hybrid/tests/testHybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ TEST(HybridBayesNet, Error) {
values.insert(X(1), Vector1(1));

AlgebraicDecisionTree<Key> actual_errors =
bayesNet.error(values.continuous());
bayesNet.errorTree(values.continuous());

// Regression.
// Manually added all the error values from the 3 conditional types.
Expand Down

0 comments on commit 7ea1bbc

Please sign in to comment.