Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Efficient Discrete Elimination #1590

Merged
merged 22 commits into from
Jul 27, 2023
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
47f3908
small improvements
varunagrawal Jul 3, 2023
baf25de
initial changes
varunagrawal Jul 8, 2023
9c88e3e
Use TableFactor in hybrid elimination
varunagrawal Jul 13, 2023
f238ba5
TableFactor constructor from DecisionTreeFactor and AlgebraicDecision…
varunagrawal Jul 17, 2023
a581788
simplify return
varunagrawal Jul 17, 2023
8462624
update HybridFactorGraph wrapper
varunagrawal Jul 3, 2023
c8e9a57
unary apply methods for TableFactor
varunagrawal Jul 12, 2023
2b85cfe
DecisionTreeFactor apply methods
varunagrawal Jul 22, 2023
3d24d01
efficient probabilities method
varunagrawal Jul 22, 2023
5f83464
use existing cardinalities
varunagrawal Jul 22, 2023
ad84163
use discrete base class in getting discrete factors
varunagrawal Jul 23, 2023
52f26e3
update TableFactor to use new version of DT probabilities
varunagrawal Jul 23, 2023
2df3cc8
undo previous changes
varunagrawal Jul 23, 2023
a4462a0
undo some more
varunagrawal Jul 23, 2023
381c33c
Merge branch 'develop' into hybrid-tablefactor-3
varunagrawal Jul 23, 2023
62d020a
remove duplicate definition
varunagrawal Jul 23, 2023
df0c5d7
remove timers
varunagrawal Jul 24, 2023
cb3c35b
refactor and better document prune method
varunagrawal Jul 25, 2023
3a78499
undo TableFactor changes
varunagrawal Jul 25, 2023
8c9fad8
undo more changes in TableFactor
varunagrawal Jul 25, 2023
ff39946
add new TableFactor constructors
varunagrawal Jul 25, 2023
e649fc6
Merge pull request #1592 from borglab/tablefactor-improvements
varunagrawal Jul 27, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 55 additions & 16 deletions gtsam/discrete/DecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,22 @@ namespace gtsam {
ADT::print("", formatter);
}

/* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::apply(ADT::Unary op) const {
  // Transform every leaf of the underlying decision tree with `op`.
  const ADT transformed = ADT::apply(op);
  // Re-wrap the transformed tree as a factor over the same discrete keys.
  return DecisionTreeFactor(discreteKeys(), transformed);
}

/* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::apply(ADT::UnaryAssignment op) const {
  // Apply `op` — which receives both the assignment and the leaf value —
  // to every leaf, and wrap the result over the same discrete keys.
  return DecisionTreeFactor(discreteKeys(), ADT::apply(op));
}

/* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::apply(const DecisionTreeFactor& f,
ADT::Binary op) const {
Expand All @@ -101,14 +117,6 @@ namespace gtsam {
return DecisionTreeFactor(keys, result);
}

/* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::apply(ADT::UnaryAssignment op) const {
  // Apply the operator to every leaf; op sees the assignment and the value.
  ADT result = ADT::apply(op);
  // Make a new factor over the same discrete keys from the transformed tree.
  return DecisionTreeFactor(discreteKeys(), result);
}

/* ************************************************************************ */
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
size_t nrFrontals, ADT::Binary op) const {
Expand Down Expand Up @@ -188,10 +196,45 @@ namespace gtsam {

/* ************************************************************************ */
std::vector<double> DecisionTreeFactor::probabilities() const {
  // Set of all keys in this factor, used below to detect pruned subtrees.
  std::set<Key> allKeys(keys().begin(), keys().end());

  std::vector<double> probs;

  /* An operation that takes each leaf probability, and computes the
   * nrAssignments by checking the difference between the keys in the factor
   * and the keys in the assignment: a pruned subtree is reported as a single
   * leaf standing in for every assignment of its missing keys.
   * The nrAssignments is then used to append the correct number of leaf
   * probability values to the `probs` vector defined above.
   */
  auto op = [&](const Assignment<Key>& a, double p) {
    // Get all the keys in the current assignment.
    std::set<Key> assignment_keys;
    for (auto&& [k, _] : a) {
      assignment_keys.insert(k);
    }

    // Find the keys missing in the assignment.
    std::vector<Key> diff;
    std::set_difference(allKeys.begin(), allKeys.end(),
                        assignment_keys.begin(), assignment_keys.end(),
                        std::back_inserter(diff));

    // Compute the total number of assignments in the (pruned) subtree.
    size_t nrAssignments = 1;
    for (auto&& k : diff) {
      nrAssignments *= cardinalities_.at(k);
    }
    // Add p `nrAssignments` times to the probs vector.
    probs.insert(probs.end(), nrAssignments, p);

    // Return the leaf unchanged; this traversal is only for its side effect.
    return p;
  };

  // Go through the tree once; `op` fills `probs` as it visits each leaf.
  this->apply(op);

  return probs;
}

Expand Down Expand Up @@ -305,11 +348,7 @@ namespace gtsam {
const size_t N = maxNrAssignments;

// Get the probabilities in the decision tree so we can threshold.
std::vector<double> probabilities;
// NOTE(Varun) this is potentially slow due to the cartesian product
for (auto&& [assignment, prob] : this->enumerate()) {
probabilities.push_back(prob);
}
std::vector<double> probabilities = this->probabilities();

// The number of probabilities can be lower than max_leaves
if (probabilities.size() <= N) {
Expand Down
7 changes: 7 additions & 0 deletions gtsam/discrete/DecisionTreeFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,13 @@ namespace gtsam {
* Apply unary operator (*this) "op" f
* @param op a unary operator that operates on AlgebraicDecisionTree
*/
DecisionTreeFactor apply(ADT::Unary op) const;

/**
 * Apply unary operator `op` to every leaf value of this factor.
 * @param op a unary operator that operates on AlgebraicDecisionTree. Takes
 * both the assignment and the value.
 */
DecisionTreeFactor apply(ADT::UnaryAssignment op) const;

/**
Expand Down
2 changes: 0 additions & 2 deletions gtsam/hybrid/GaussianMixture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,8 +286,6 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &discreteProbs) {

/* *******************************************************************************/
void GaussianMixture::prune(const DecisionTreeFactor &discreteProbs) {
auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys());
auto gmKeySet = DiscreteKeysAsSet(this->discreteKeys());
// Functional which loops over all assignments and create a set of
// GaussianConditionals
auto pruner = prunerFunc(discreteProbs);
Expand Down
7 changes: 1 addition & 6 deletions gtsam/hybrid/HybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,6 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals(
size_t maxNrLeaves) {
// Get the joint distribution of only the discrete keys
gttic_(HybridBayesNet_PruneDiscreteConditionals);
// The joint discrete probability.
DiscreteConditional discreteProbs;

Expand All @@ -147,12 +146,11 @@ DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals(
discrete_factor_idxs.push_back(i);
}
}

const DecisionTreeFactor prunedDiscreteProbs =
discreteProbs.prune(maxNrLeaves);
gttoc_(HybridBayesNet_PruneDiscreteConditionals);

// Eliminate joint probability back into conditionals
gttic_(HybridBayesNet_UpdateDiscreteConditionals);
DiscreteFactorGraph dfg{prunedDiscreteProbs};
DiscreteBayesNet::shared_ptr dbn = dfg.eliminateSequential(discrete_frontals);

Expand All @@ -161,7 +159,6 @@ DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals(
size_t idx = discrete_factor_idxs.at(i);
this->at(idx) = std::make_shared<HybridConditional>(dbn->at(i));
}
gttoc_(HybridBayesNet_UpdateDiscreteConditionals);

return prunedDiscreteProbs;
}
Expand All @@ -180,7 +177,6 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {

HybridBayesNet prunedBayesNetFragment;

gttic_(HybridBayesNet_PruneMixtures);
// Go through all the conditionals in the
// Bayes Net and prune them as per prunedDiscreteProbs.
for (auto &&conditional : *this) {
Expand All @@ -197,7 +193,6 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
prunedBayesNetFragment.push_back(conditional);
}
}
gttoc_(HybridBayesNet_PruneMixtures);

return prunedBayesNetFragment;
}
Expand Down
3 changes: 0 additions & 3 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ static GaussianFactorGraphTree addGaussian(
// TODO(dellaert): it's probably more efficient to first collect the discrete
// keys, and then loop over all assignments to populate a vector.
GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
gttic_(assembleGraphTree);

GaussianFactorGraphTree result;

Expand Down Expand Up @@ -129,8 +128,6 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
}
}

gttoc_(assembleGraphTree);

return result;
}

Expand Down
2 changes: 1 addition & 1 deletion gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ TEST(HybridFactorGraph, Full_Elimination) {
DiscreteFactorGraph discrete_fg;
// TODO(Varun) Make this a function of HybridGaussianFactorGraph?
for (auto& factor : (*remainingFactorGraph_partial)) {
auto df = dynamic_pointer_cast<DecisionTreeFactor>(factor);
auto df = dynamic_pointer_cast<DiscreteFactor>(factor);
assert(df);
discrete_fg.push_back(df);
}
Expand Down