Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Small runtime improvement for hybrid #1844

Merged
merged 7 commits into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions gtsam/discrete/DecisionTree-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ namespace gtsam {
void dot(std::ostream& os, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter,
bool showZero) const override {
std::string value = valueFormatter(constant_);
const std::string value = valueFormatter(constant_);
if (showZero || value.compare("0"))
os << "\"" << this->id() << "\" [label=\"" << value
<< "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n";
Expand Down Expand Up @@ -306,7 +306,8 @@ namespace gtsam {
void dot(std::ostream& os, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter,
bool showZero) const override {
os << "\"" << this->id() << "\" [shape=circle, label=\"" << label_
const std::string label = labelFormatter(label_);
dellaert marked this conversation as resolved.
Show resolved Hide resolved
os << "\"" << this->id() << "\" [shape=circle, label=\"" << label
<< "\"]\n";
size_t B = branches_.size();
for (size_t i = 0; i < B; i++) {
Expand Down
30 changes: 14 additions & 16 deletions gtsam/discrete/DecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,14 +147,14 @@ namespace gtsam {
size_t i;
ADT result(*this);
for (i = 0; i < nrFrontals; i++) {
Key j = keys()[i];
Key j = keys_[i];
result = result.combine(j, cardinality(j), op);
}

// create new factor, note we start keys after nrFrontals
// Create new factor, note we start with keys after nrFrontals:
DiscreteKeys dkeys;
for (; i < keys().size(); i++) {
Key j = keys()[i];
for (; i < keys_.size(); i++) {
Key j = keys_[i];
dkeys.push_back(DiscreteKey(j, cardinality(j)));
}
return std::make_shared<DecisionTreeFactor>(dkeys, result);
Expand All @@ -179,24 +179,22 @@ namespace gtsam {
result = result.combine(j, cardinality(j), op);
}

// create new factor, note we collect keys that are not in frontalKeys
/*
Due to branch merging, the labels in `result` may be missing some keys
Create new factor, note we collect keys that are not in frontalKeys.

Due to branch merging, the labels in `result` may be missing some keys.
E.g. After branch merging, we may get a ADT like:
Leaf [2] 1.0204082

This is missing the key values used for branching.
Hence, code below traverses the original keys and omits those in
frontalKeys. We loop over cardinalities, which is O(n) even for a map, and
then "contains" is a binary search on a small vector.
*/
KeyVector difference, frontalKeys_(frontalKeys), keys_(keys());
// Get the difference of the frontalKeys and the factor keys using set_difference
std::sort(keys_.begin(), keys_.end());
std::sort(frontalKeys_.begin(), frontalKeys_.end());
std::set_difference(keys_.begin(), keys_.end(), frontalKeys_.begin(),
frontalKeys_.end(), back_inserter(difference));

DiscreteKeys dkeys;
for (Key key : difference) {
dkeys.push_back(DiscreteKey(key, cardinality(key)));
for (auto&& [key, cardinality] : cardinalities_) {
dellaert marked this conversation as resolved.
Show resolved Hide resolved
if (!frontalKeys.contains(key)) {
dkeys.push_back(DiscreteKey(key, cardinality));
}
}
return std::make_shared<DecisionTreeFactor>(dkeys, result);
}
Expand Down
121 changes: 28 additions & 93 deletions gtsam/discrete/tests/testAlgebraicDecisionTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,9 @@
#include <gtsam/discrete/DiscreteKey.h> // make sure we have traits
#include <gtsam/discrete/DiscreteValues.h>
// headers first to make sure no missing headers
#include <CppUnitLite/TestHarness.h>
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DecisionTree-inl.h> // for convert only
#define DISABLE_TIMING

#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/timing.h>
#include <gtsam/discrete/Signature.h>

using namespace std;
Expand Down Expand Up @@ -71,16 +68,14 @@ void dot(const T& f, const string& filename) {
// instrumented operators
/* ************************************************************************** */
size_t muls = 0, adds = 0;
double elapsed;
void resetCounts() {
muls = 0;
adds = 0;
}
void printCounts(const string& s) {
#ifndef DISABLE_TIMING
cout << s << ": " << std::setw(3) << muls << " muls, " <<
std::setw(3) << adds << " adds, " << 1000 * elapsed << " ms."
<< endl;
cout << s << ": " << std::setw(3) << muls << " muls, " << std::setw(3) << adds
<< " adds" << endl;
#endif
resetCounts();
}
Expand Down Expand Up @@ -131,37 +126,35 @@ ADT create(const Signature& signature) {
static size_t count = 0;
const DiscreteKey& key = signature.key();
std::stringstream ss;
ss << "CPT-" << std::setw(3) << std::setfill('0') << ++count << "-" << key.first;
ss << "CPT-" << std::setw(3) << std::setfill('0') << ++count << "-"
<< key.first;
string DOTfile = ss.str();
dot(p, DOTfile);
return p;
}

/* ************************************************************************* */
namespace asiaCPTs {
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2),
D(7, 2);

ADT pA = create(A % "99/1");
ADT pS = create(S % "50/50");
ADT pT = create(T | A = "99/1 95/5");
ADT pL = create(L | S = "99/1 90/10");
ADT pB = create(B | S = "70/30 40/60");
ADT pE = create((E | T, L) = "F T T T");
ADT pX = create(X | E = "95/5 2/98");
ADT pD = create((D | E, B) = "9/1 2/8 3/7 1/9");
} // namespace asiaCPTs

/* ************************************************************************* */
// test Asia Joint
TEST(ADT, joint) {
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2),
D(7, 2);

resetCounts();
gttic_(asiaCPTs);
ADT pA = create(A % "99/1");
ADT pS = create(S % "50/50");
ADT pT = create(T | A = "99/1 95/5");
ADT pL = create(L | S = "99/1 90/10");
ADT pB = create(B | S = "70/30 40/60");
ADT pE = create((E | T, L) = "F T T T");
ADT pX = create(X | E = "95/5 2/98");
ADT pD = create((D | E, B) = "9/1 2/8 3/7 1/9");
gttoc_(asiaCPTs);
tictoc_getNode(asiaCPTsNode, asiaCPTs);
elapsed = asiaCPTsNode->secs() + asiaCPTsNode->wall();
tictoc_reset_();
printCounts("Asia CPTs");
using namespace asiaCPTs;

// Create joint
resetCounts();
gttic_(asiaJoint);
ADT joint = pA;
dot(joint, "Asia-A");
joint = apply(joint, pS, &mul);
Expand All @@ -183,11 +176,12 @@ TEST(ADT, joint) {
#else
EXPECT_LONGS_EQUAL(508, muls);
#endif
gttoc_(asiaJoint);
tictoc_getNode(asiaJointNode, asiaJoint);
elapsed = asiaJointNode->secs() + asiaJointNode->wall();
tictoc_reset_();
printCounts("Asia joint");
}

/* ************************************************************************* */
TEST(ADT, combine) {
using namespace asiaCPTs;

// Form P(A,S,T,L) = P(A) P(S) P(T|A) P(L|S)
ADT pASTL = pA;
Expand All @@ -203,13 +197,11 @@ TEST(ADT, joint) {
}

/* ************************************************************************* */
// test Inference with joint
// test Inference with joint, created using different ordering
TEST(ADT, inference) {
DiscreteKey A(0, 2), D(1, 2), //
B(2, 2), L(3, 2), E(4, 2), S(5, 2), T(6, 2), X(7, 2);

resetCounts();
gttic_(infCPTs);
ADT pA = create(A % "99/1");
ADT pS = create(S % "50/50");
ADT pT = create(T | A = "99/1 95/5");
Expand All @@ -218,15 +210,9 @@ TEST(ADT, inference) {
ADT pE = create((E | T, L) = "F T T T");
ADT pX = create(X | E = "95/5 2/98");
ADT pD = create((D | E, B) = "9/1 2/8 3/7 1/9");
gttoc_(infCPTs);
tictoc_getNode(infCPTsNode, infCPTs);
elapsed = infCPTsNode->secs() + infCPTsNode->wall();
tictoc_reset_();
// printCounts("Inference CPTs");

// Create joint
// Create joint, note different ordering than above: different tree!
resetCounts();
gttic_(asiaProd);
ADT joint = pA;
dot(joint, "Joint-Product-A");
joint = apply(joint, pS, &mul);
Expand All @@ -248,14 +234,9 @@ TEST(ADT, inference) {
#else
EXPECT_LONGS_EQUAL(508, (long)muls); // different ordering
#endif
gttoc_(asiaProd);
tictoc_getNode(asiaProdNode, asiaProd);
elapsed = asiaProdNode->secs() + asiaProdNode->wall();
tictoc_reset_();
printCounts("Asia product");

resetCounts();
gttic_(asiaSum);
ADT marginal = joint;
marginal = marginal.combine(X, &add_);
dot(marginal, "Joint-Sum-ADBLEST");
Expand All @@ -270,35 +251,23 @@ TEST(ADT, inference) {
#else
EXPECT_LONGS_EQUAL(240, (long)adds);
#endif
gttoc_(asiaSum);
tictoc_getNode(asiaSumNode, asiaSum);
elapsed = asiaSumNode->secs() + asiaSumNode->wall();
tictoc_reset_();
printCounts("Asia sum");
}

/* ************************************************************************* */
TEST(ADT, factor_graph) {
DiscreteKey B(0, 2), L(1, 2), E(2, 2), S(3, 2), T(4, 2), X(5, 2);

resetCounts();
gttic_(createCPTs);
ADT pS = create(S % "50/50");
ADT pT = create(T % "95/5");
ADT pL = create(L | S = "99/1 90/10");
ADT pE = create((E | T, L) = "F T T T");
ADT pX = create(X | E = "95/5 2/98");
ADT pD = create(B | E = "1/8 7/9");
ADT pB = create(B | S = "70/30 40/60");
gttoc_(createCPTs);
tictoc_getNode(createCPTsNode, createCPTs);
elapsed = createCPTsNode->secs() + createCPTsNode->wall();
tictoc_reset_();
// printCounts("Create CPTs");

// Create joint
resetCounts();
gttic_(asiaFG);
ADT fg = pS;
fg = apply(fg, pT, &mul);
fg = apply(fg, pL, &mul);
Expand All @@ -312,14 +281,9 @@ TEST(ADT, factor_graph) {
#else
EXPECT_LONGS_EQUAL(188, (long)muls);
#endif
gttoc_(asiaFG);
tictoc_getNode(asiaFGNode, asiaFG);
elapsed = asiaFGNode->secs() + asiaFGNode->wall();
tictoc_reset_();
printCounts("Asia FG");

resetCounts();
gttic_(marg);
fg = fg.combine(X, &add_);
dot(fg, "Marginalized-6X");
fg = fg.combine(T, &add_);
Expand All @@ -335,83 +299,54 @@ TEST(ADT, factor_graph) {
#else
LONGS_EQUAL(62, adds);
#endif
gttoc_(marg);
tictoc_getNode(margNode, marg);
elapsed = margNode->secs() + margNode->wall();
tictoc_reset_();
printCounts("marginalize");

// BLESTX

// Eliminate X
resetCounts();
gttic_(elimX);
ADT fE = pX;
dot(fE, "Eliminate-01-fEX");
fE = fE.combine(X, &add_);
dot(fE, "Eliminate-02-fE");
gttoc_(elimX);
tictoc_getNode(elimXNode, elimX);
elapsed = elimXNode->secs() + elimXNode->wall();
tictoc_reset_();
printCounts("Eliminate X");

// Eliminate T
resetCounts();
gttic_(elimT);
ADT fLE = pT;
fLE = apply(fLE, pE, &mul);
dot(fLE, "Eliminate-03-fLET");
fLE = fLE.combine(T, &add_);
dot(fLE, "Eliminate-04-fLE");
gttoc_(elimT);
tictoc_getNode(elimTNode, elimT);
elapsed = elimTNode->secs() + elimTNode->wall();
tictoc_reset_();
printCounts("Eliminate T");

// Eliminate S
resetCounts();
gttic_(elimS);
ADT fBL = pS;
fBL = apply(fBL, pL, &mul);
fBL = apply(fBL, pB, &mul);
dot(fBL, "Eliminate-05-fBLS");
fBL = fBL.combine(S, &add_);
dot(fBL, "Eliminate-06-fBL");
gttoc_(elimS);
tictoc_getNode(elimSNode, elimS);
elapsed = elimSNode->secs() + elimSNode->wall();
tictoc_reset_();
printCounts("Eliminate S");

// Eliminate E
resetCounts();
gttic_(elimE);
ADT fBL2 = fE;
fBL2 = apply(fBL2, fLE, &mul);
fBL2 = apply(fBL2, pD, &mul);
dot(fBL2, "Eliminate-07-fBLE");
fBL2 = fBL2.combine(E, &add_);
dot(fBL2, "Eliminate-08-fBL2");
gttoc_(elimE);
tictoc_getNode(elimENode, elimE);
elapsed = elimENode->secs() + elimENode->wall();
tictoc_reset_();
printCounts("Eliminate E");

// Eliminate L
resetCounts();
gttic_(elimL);
ADT fB = fBL;
fB = apply(fB, fBL2, &mul);
dot(fB, "Eliminate-09-fBL");
fB = fB.combine(L, &add_);
dot(fB, "Eliminate-10-fB");
gttoc_(elimL);
tictoc_getNode(elimLNode, elimL);
elapsed = elimLNode->secs() + elimLNode->wall();
tictoc_reset_();
printCounts("Eliminate L");
}

Expand Down
Loading
Loading