Skip to content

Commit

Permalink
prune joint discrete probability which is faster
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal committed Jul 17, 2023
1 parent a2ed791 commit e47531e
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 66 deletions.
85 changes: 36 additions & 49 deletions gtsam/hybrid/HybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,6 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
return Base::equals(bn, tol);
}

/* ************************************************************************* */
DiscreteConditional::shared_ptr HybridBayesNet::discreteConditionals() const {
// The joint discrete probability.
DiscreteConditional discreteProbs;

for (auto &&conditional : *this) {
if (conditional->isDiscrete()) {
discreteProbs = discreteProbs * (*conditional->asDiscrete());
}
}
return std::make_shared<DiscreteConditional>(discreteProbs);
}

/* ************************************************************************* */
/**
* @brief Helper function to get the pruner functional.
Expand Down Expand Up @@ -139,52 +126,52 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
}

/* ************************************************************************* */
void HybridBayesNet::updateDiscreteConditionals(
const DecisionTreeFactor &prunedDiscreteProbs) {
// TODO(Varun) Should prune the joint conditional, maybe during elimination?
// Loop with index since we need it later.
DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals(
size_t maxNrLeaves) {
// Get the joint distribution of only the discrete keys
gttic_(HybridBayesNet_PruneDiscreteConditionals);
// The joint discrete probability.
DiscreteConditional discreteProbs;

std::vector<size_t> discrete_factor_idxs;
// Record frontal keys so we can maintain ordering
Ordering discrete_frontals;

for (size_t i = 0; i < this->size(); i++) {
HybridConditional::shared_ptr conditional = this->at(i);
auto conditional = this->at(i);
if (conditional->isDiscrete()) {
auto discrete = conditional->asDiscrete();

// Convert pointer from conditional to factor
auto discreteFactor =
std::dynamic_pointer_cast<DecisionTreeFactor>(discrete);
// Apply prunerFunc to the underlying conditional
DecisionTreeFactor::ADT prunedDiscreteFactor =
discreteFactor->apply(prunerFunc(prunedDiscreteProbs, *conditional));

gttic_(HybridBayesNet_MakeConditional);
// Create the new (hybrid) conditional
KeyVector frontals(discrete->frontals().begin(),
discrete->frontals().end());
auto prunedDiscrete = std::make_shared<DiscreteLookupTable>(
frontals.size(), conditional->discreteKeys(), prunedDiscreteFactor);
conditional = std::make_shared<HybridConditional>(prunedDiscrete);
gttoc_(HybridBayesNet_MakeConditional);

// Add it back to the BayesNet
this->at(i) = conditional;
discreteProbs = discreteProbs * (*conditional->asDiscrete());

Ordering conditional_keys(conditional->frontals());
discrete_frontals += conditional_keys;
discrete_factor_idxs.push_back(i);
}
}
}

/* ************************************************************************* */
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
// Get the joint distribution of only the discrete keys
gttic_(HybridBayesNet_PruneDiscreteConditionals);
DiscreteConditional::shared_ptr discreteConditionals =
this->discreteConditionals();
const DecisionTreeFactor prunedDiscreteProbs =
discreteConditionals->prune(maxNrLeaves);
discreteProbs.prune(maxNrLeaves);
gttoc_(HybridBayesNet_PruneDiscreteConditionals);

// Eliminate joint probability back into conditionals
gttic_(HybridBayesNet_UpdateDiscreteConditionals);
this->updateDiscreteConditionals(prunedDiscreteProbs);
DiscreteFactorGraph dfg{prunedDiscreteProbs};
DiscreteBayesNet::shared_ptr dbn = dfg.eliminateSequential(discrete_frontals);

// Assign pruned discrete conditionals back at the correct indices.
for (size_t i = 0; i < discrete_factor_idxs.size(); i++) {
size_t idx = discrete_factor_idxs.at(i);
this->at(idx) = std::make_shared<HybridConditional>(dbn->at(i));
}
gttoc_(HybridBayesNet_UpdateDiscreteConditionals);

/* To Prune, we visitWith every leaf in the GaussianMixture.
return prunedDiscreteProbs;
}

/* ************************************************************************* */
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
DecisionTreeFactor prunedDiscreteProbs =
this->pruneDiscreteConditionals(maxNrLeaves);

/* To prune, we visitWith every leaf in the GaussianMixture.
* For each leaf, using the assignment we can check the discrete decision tree
* for 0.0 probability, then just set the leaf to a nullptr.
*
Expand Down
13 changes: 3 additions & 10 deletions gtsam/hybrid/HybridBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,6 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
*/
VectorValues optimize(const DiscreteValues &assignment) const;

/**
* @brief Get all the discrete conditionals as a decision tree factor.
*
* @return DiscreteConditional::shared_ptr
*/
DiscreteConditional::shared_ptr discreteConditionals() const;

/**
* @brief Sample from an incomplete BayesNet, given missing variables.
*
Expand Down Expand Up @@ -222,11 +215,11 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {

private:
/**
* @brief Update the discrete conditionals with the pruned versions.
* @brief Prune all the discrete conditionals.
*
* @param prunedDiscreteProbs
* @param maxNrLeaves
*/
void updateDiscreteConditionals(const DecisionTreeFactor &prunedDiscreteProbs);
DecisionTreeFactor pruneDiscreteConditionals(size_t maxNrLeaves);

#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
/** Serialization function */
Expand Down
26 changes: 19 additions & 7 deletions gtsam/hybrid/tests/testHybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ TEST(HybridBayesNet, Pruning) {
auto prunedTree = prunedBayesNet.evaluate(delta.continuous());

// Regression test on pruned logProbability tree
std::vector<double> pruned_leaves = {0.0, 20.346113, 0.0, 19.738098};
std::vector<double> pruned_leaves = {0.0, 32.713418, 0.0, 31.735823};
AlgebraicDecisionTree<Key> expected_pruned(discrete_keys, pruned_leaves);
EXPECT(assert_equal(expected_pruned, prunedTree, 1e-6));

Expand All @@ -248,8 +248,10 @@ TEST(HybridBayesNet, Pruning) {
logProbability +=
posterior->at(4)->asDiscrete()->logProbability(hybridValues);

// Regression
double density = exp(logProbability);
EXPECT_DOUBLES_EQUAL(density, actualTree(discrete_values), 1e-9);
EXPECT_DOUBLES_EQUAL(density,
1.6078460548731697 * actualTree(discrete_values), 1e-6);
EXPECT_DOUBLES_EQUAL(density, prunedTree(discrete_values), 1e-9);
EXPECT_DOUBLES_EQUAL(logProbability, posterior->logProbability(hybridValues),
1e-9);
Expand Down Expand Up @@ -283,10 +285,16 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
EXPECT_LONGS_EQUAL(7, posterior->size());

size_t maxNrLeaves = 3;
auto discreteConditionals = posterior->discreteConditionals();
DiscreteConditional discreteConditionals;
for (auto&& conditional : *posterior) {
if (conditional->isDiscrete()) {
discreteConditionals =
discreteConditionals * (*conditional->asDiscrete());
}
}
const DecisionTreeFactor::shared_ptr prunedDecisionTree =
std::make_shared<DecisionTreeFactor>(
discreteConditionals->prune(maxNrLeaves));
discreteConditionals.prune(maxNrLeaves));

#ifdef GTSAM_DT_MERGING
EXPECT_LONGS_EQUAL(maxNrLeaves + 2 /*2 zero leaves*/,
Expand All @@ -295,20 +303,24 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
EXPECT_LONGS_EQUAL(8 /*full tree*/, prunedDecisionTree->nrLeaves());
#endif

auto original_discrete_conditionals = *(posterior->at(4)->asDiscrete());
// regression
DiscreteKeys dkeys{{M(0), 2}, {M(1), 2}, {M(2), 2}};
DecisionTreeFactor::ADT potentials(
dkeys, std::vector<double>{0, 0, 0, 0.505145423, 0, 1, 0, 0.494854577});
DiscreteConditional expected_discrete_conditionals(1, dkeys, potentials);

// Prune!
posterior->prune(maxNrLeaves);

// Functor to verify values against the original_discrete_conditionals
// Functor to verify values against the expected_discrete_conditionals
auto checker = [&](const Assignment<Key>& assignment,
double probability) -> double {
// typecast so we can use this to get probability value
DiscreteValues choices(assignment);
if (prunedDecisionTree->operator()(choices) == 0) {
EXPECT_DOUBLES_EQUAL(0.0, probability, 1e-9);
} else {
EXPECT_DOUBLES_EQUAL(original_discrete_conditionals(choices), probability,
EXPECT_DOUBLES_EQUAL(expected_discrete_conditionals(choices), probability,
1e-9);
}
return 0.0;
Expand Down

0 comments on commit e47531e

Please sign in to comment.