Skip to content

Commit

Permalink
implement a min-heap to record the top N probabilities for pruning
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal committed Nov 6, 2024
1 parent d21f191 commit 9666725
Showing 1 changed file with 60 additions and 14 deletions.
74 changes: 60 additions & 14 deletions gtsam/discrete/DecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -349,13 +349,67 @@ namespace gtsam {
: DiscreteFactor(keys.indices(), keys.cardinalities()),
AlgebraicDecisionTree<Key>(keys, table) {}

/**
 * @brief Min-Heap class to help with pruning.
 * The `top` element is always the smallest value.
 *
 * Values are stored in a std::vector that is kept arranged as a binary heap
 * ordered by std::greater<double>, so the smallest value is at index 0.
 */
class MinHeap {
  std::vector<double> v_;

 public:
  /// Default constructor
  MinHeap() = default;

  /// Push value onto the heap
  void push(double x) {
    v_.push_back(x);
    // The range before the new element is already a heap, so sifting the
    // last element up is enough; no need to rebuild the whole heap.
    std::push_heap(v_.begin(), v_.end(), std::greater<double>{});
  }

  /// Push value `x`, `n` number of times.
  void push(double x, size_t n) {
    if (n == 0) return;  // nothing appended, heap is unchanged
    v_.insert(v_.end(), n, x);
    // Bulk insert: a single rebuild restores the heap ordering.
    std::make_heap(v_.begin(), v_.end(), std::greater<double>{});
  }

  /**
   * @brief Pop the top (smallest) value of the heap.
   *
   * @return The value that was removed.
   * @throws std::out_of_range (from std::vector::at) if the heap is empty.
   */
  double pop() {
    // Read via at() first so that an empty heap throws instead of reaching
    // back()/pop_back() on an empty vector.
    const double x = v_.at(0);
    std::pop_heap(v_.begin(), v_.end(), std::greater<double>{});
    v_.pop_back();
    return x;
  }

  /**
   * @brief Return the top (smallest) value of the heap without popping it.
   *
   * @throws std::out_of_range (from std::vector::at) if the heap is empty.
   */
  double top() const { return v_.at(0); }

  /**
   * @brief Print the heap as a comma-separated sequence, in storage order.
   *
   * An empty heap prints only the prologue (if any) followed by a newline.
   *
   * @param s A string to prologue the output.
   */
  void print(const std::string& s = "") const {
    std::cout << (s.empty() ? "" : s + " ");
    for (size_t i = 0; i < v_.size(); i++) {
      // Separator goes before every element except the first, which avoids
      // computing size() - 1 on an unsigned value.
      if (i > 0) std::cout << ",";
      std::cout << v_[i];
    }
    std::cout << std::endl;
  }

  /// Return true if heap is empty.
  bool empty() const { return v_.empty(); }

  /// Return the size of the heap.
  size_t size() const { return v_.size(); }
};

/* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrAssignments) const {
const size_t N = maxNrAssignments;

// Set of all keys
std::set<Key> allKeys(keys().begin(), keys().end());
std::vector<double> min_heap;
MinHeap min_heap;

auto op = [&](const Assignment<Key>& a, double p) {
// Get all the keys in the current assignment
Expand All @@ -377,33 +431,25 @@ namespace gtsam {
}

if (min_heap.empty()) {
for (size_t i = 0; i < std::min(nrAssignments, N); ++i) {
min_heap.push_back(p);
}
std::make_heap(min_heap.begin(), min_heap.end(),
std::greater<double>{});
min_heap.push(p, std::min(nrAssignments, N));

} else {
// If p is larger than the smallest element,
// then we insert into the min heap.
if (p > min_heap.at(0)) {
if (p > min_heap.top()) {
for (size_t i = 0; i < std::min(nrAssignments, N); ++i) {
if (min_heap.size() == N) {
std::pop_heap(min_heap.begin(), min_heap.end(),
std::greater<double>{});
min_heap.pop_back();
min_heap.pop();
}
min_heap.push_back(p);
std::make_heap(min_heap.begin(), min_heap.end(),
std::greater<double>{});
min_heap.push(p);
}
}
}
return p;
};
this->visitWith(op);

double threshold = min_heap.at(0);
double threshold = min_heap.top();

// Now threshold the decision tree
size_t total = 0;
Expand Down

0 comments on commit 9666725

Please sign in to comment.