Skip to content

Commit

Permalink
Implement visitor methods on most data structures
Browse files Browse the repository at this point in the history
Summary:
I have noticed that patricia tree iterators internally use a stack, which requires allocating memory on the heap.
This can be detrimental to performance when using heavily nested maps, and iterating on the whole structure. This leads to a lot of allocations and deallocations.
To prevent this problem, this diff implements `visit` methods on most data structures.
`visit` uses recursive functions, which means it never allocate memory.

Reviewed By: arnaudvenet

Differential Revision: D49922681

fbshipit-source-id: 00de545e61e17dfb2b87501d0c9c943c2e5f7cf5
  • Loading branch information
arthaud authored and facebook-github-bot committed Oct 6, 2023
1 parent 0b16299 commit 483d9b8
Show file tree
Hide file tree
Showing 18 changed files with 186 additions and 0 deletions.
7 changes: 7 additions & 0 deletions include/sparta/FlatMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,13 @@ class FlatMap final {
}
}

template <typename Visitor> // void(const value_type&)
void visit(Visitor&& visitor) const {
for (const auto& binding : m_map) {
visitor(binding);
}
}

template <typename Predicate> // bool(const Key&, const ValueType&)
FlatMap& filter(Predicate&& predicate) {
switch (m_map.size()) {
Expand Down
7 changes: 7 additions & 0 deletions include/sparta/FlatSet.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,13 @@ class FlatSet final {
return *this;
}

template <typename Visitor> // void (const Element&)
void visit(Visitor&& visitor) const {
for (const auto& element : m_set) {
visitor(element);
}
}

template <typename Predicate> // bool(const Element&)
FlatSet& filter(Predicate&& predicate) {
auto container = m_set.extract_sequence();
Expand Down
7 changes: 7 additions & 0 deletions include/sparta/HashMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,13 @@ class HashMap final {
return *this;
}

template <typename Visitor> // void(const value_type&)
void visit(Visitor&& visitor) const {
for (const auto& binding : m_map) {
visitor(binding);
}
}

template <typename Predicate> // bool(const Key&, const ValueType&)
HashMap& filter(Predicate&& predicate) {
auto it = m_map.begin(), end = m_map.end();
Expand Down
8 changes: 8 additions & 0 deletions include/sparta/HashedAbstractEnvironment.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,14 @@ class HashedAbstractEnvironment final
return *this;
}

template <typename Visitor> // void(const std::pair<Variable, Domain>&)
void visit(Visitor&& visitor) const {
if (this->is_bottom()) {
return;
}
this->get_value()->m_map.visit(std::forward<Visitor>(visitor));
}

static HashedAbstractEnvironment bottom() {
return HashedAbstractEnvironment(AbstractValueKind::Bottom);
}
Expand Down
8 changes: 8 additions & 0 deletions include/sparta/HashedAbstractPartition.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,14 @@ class HashedAbstractPartition final
m_map.transform(std::forward<Operation>(f));
}

template <typename Visitor> // void(const std::pair<Label, Domain>&)
void visit(Visitor&& visitor) const {
if (is_top()) {
return;
}
m_map.visit(std::forward<Visitor>(visitor));
}

bool is_top() const { return m_is_top; }

bool is_bottom() const { return !m_is_top && m_map.empty(); }
Expand Down
24 changes: 24 additions & 0 deletions include/sparta/PatriciaTreeCore.h
Original file line number Diff line number Diff line change
Expand Up @@ -965,6 +965,21 @@ inline intrusive_ptr<PatriciaTreeNode<IntegerType, Value>> update_all_leafs(
}
}

template <typename IntegerType, typename Value, typename Visitor>
inline void visit_all_leafs(
Visitor&& visitor,
const intrusive_ptr<PatriciaTreeNode<IntegerType, Value>>& tree) {
if (tree == nullptr) {
return;
} else if (const auto* leaf = tree->as_leaf(); leaf != nullptr) {
visitor(leaf->data());
} else {
const auto* branch = tree->as_branch();
visit_all_leafs(visitor, branch->left_tree());
visit_all_leafs(visitor, branch->right_tree());
}
}

template <typename IntegerType, typename Value, typename LeafCombine>
inline intrusive_ptr<PatriciaTreeLeaf<IntegerType, Value>> combine_leafs(
LeafCombine&& leaf_combine,
Expand Down Expand Up @@ -1408,6 +1423,15 @@ class PatriciaTreeCore {
return m_tree != old_tree;
}

template <typename Visitor>
inline void visit_all_leafs(Visitor&& visitor) const {
pt_core::visit_all_leafs(
[visitor = std::forward<Visitor>(visitor)](const auto& data) {
visitor(Codec::decode(data));
},
m_tree);
}

template <typename LeafCombine>
inline void merge(LeafCombine&& leaf_combine, const PatriciaTreeCore& other) {
m_tree = pt_core::merge_trees(std::forward<LeafCombine>(leaf_combine),
Expand Down
12 changes: 12 additions & 0 deletions include/sparta/PatriciaTreeHashMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,18 @@ class PatriciaTreeHashMap final {
});
}

/*
* Visit all key-value pairs.
* This does NOT allocate memory, unlike the iterators.
*/
template <typename Visitor> // void(const value_type&)
void visit(Visitor&& visitor) const {
m_tree.visit([visitor = std::forward<Visitor>(visitor)](
const std::pair<size_t, FlatMapT>& binding) {
binding.second.visit(visitor);
});
}

PatriciaTreeHashMap& remove(const Key& key) {
m_tree.update(
[&key](FlatMapT flat_map) -> FlatMapT {
Expand Down
13 changes: 13 additions & 0 deletions include/sparta/PatriciaTreeHashMapAbstractEnvironment.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,14 @@ class PatriciaTreeHashMapAbstractEnvironment final
return res;
}

template <typename Visitor> // void(const std::pair<Variable, Domain>&)
void visit(Visitor&& visitor) const {
if (this->is_bottom()) {
return;
}
this->get_value()->visit(std::forward<Visitor>(visitor));
}

PatriciaTreeHashMapAbstractEnvironment& clear() {
if (this->is_bottom()) {
return *this;
Expand Down Expand Up @@ -276,6 +284,11 @@ class MapValue final : public AbstractValue<MapValue<Variable, Domain>> {
return m_map.transform(std::forward<Operation>(f));
}

template <typename Visitor> // void(const std::pair<Variable, Domain>&)
void visit(Visitor&& visitor) const {
return m_map.visit(std::forward<Visitor>(visitor));
}

template <typename Operation> // void(Domain*, const Domain&)
AbstractValueKind join_like_operation(const MapValue& other,
Operation&& operation) {
Expand Down
8 changes: 8 additions & 0 deletions include/sparta/PatriciaTreeHashMapAbstractPartition.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,14 @@ class PatriciaTreeHashMapAbstractPartition final
return m_map.transform(std::forward<Operation>(f));
}

template <typename Visitor> // void(const std::pair<Label, Domain>&)
void visit(Visitor&& visitor) const {
if (is_top()) {
return;
}
m_map.visit(std::forward<Visitor>(visitor));
}

bool is_top() const { return m_is_top; }

bool is_bottom() const { return !m_is_top && m_map.empty(); }
Expand Down
9 changes: 9 additions & 0 deletions include/sparta/PatriciaTreeMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,15 @@ class PatriciaTreeMap final {
apply_leafs(std::forward<MappingFunction>(f)));
}

/*
* Visit all key-value pairs.
* This does NOT allocate memory, unlike the iterators.
*/
template <typename Visitor> // void(const value_type&)
void visit(Visitor&& visitor) const {
m_core.visit_all_leafs(std::forward<Visitor>(visitor));
}

PatriciaTreeMap& remove(Key key) {
m_core.remove(key);
return *this;
Expand Down
13 changes: 13 additions & 0 deletions include/sparta/PatriciaTreeMapAbstractEnvironment.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,14 @@ class PatriciaTreeMapAbstractEnvironment final
return res;
}

template <typename Visitor> // void(const std::pair<Variable, Domain>&)
void visit(Visitor&& visitor) const {
if (this->is_bottom()) {
return;
}
this->get_value()->visit(std::forward<Visitor>(visitor));
}

bool erase_all_matching(const Variable& variable_mask) {
if (this->is_bottom()) {
return false;
Expand Down Expand Up @@ -289,6 +297,11 @@ class MapValue final : public AbstractValue<MapValue<Variable, Domain>> {
return m_map.transform(std::forward<Operation>(f));
}

template <typename Visitor> // void(const std::pair<Variable, Domain>&)
void visit(Visitor&& visitor) const {
return m_map.visit(std::forward<Visitor>(visitor));
}

bool erase_all_matching(const Variable& variable_mask) {
return m_map.erase_all_matching(variable_mask);
}
Expand Down
8 changes: 8 additions & 0 deletions include/sparta/PatriciaTreeMapAbstractPartition.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,14 @@ class PatriciaTreeMapAbstractPartition final
return m_map.transform(std::forward<Operation>(f));
}

template <typename Visitor> // void(const std::pair<Label, Domain>&)
void visit(Visitor&& visitor) const {
if (is_top()) {
return;
}
m_map.visit(std::forward<Visitor>(visitor));
}

bool is_top() const { return m_is_top; }

bool is_bottom() const { return !m_is_top && m_map.empty(); }
Expand Down
9 changes: 9 additions & 0 deletions include/sparta/PatriciaTreeSet.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,15 @@ class PatriciaTreeSet final {
return *this;
}

/*
* Visit all elements.
* This does NOT allocate memory, unlike the iterators.
*/
template <typename Visitor> // void(const Element&)
void visit(Visitor&& visitor) const {
m_core.visit_all_leafs(std::forward<Visitor>(visitor));
}

bool erase_all_matching(Element element_mask) {
return m_core.erase_all_matching(element_mask);
}
Expand Down
11 changes: 11 additions & 0 deletions test/FlatMapTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -610,3 +610,14 @@ TEST_F(FlatMapTest, difference) {
IntFlatMap({{2, 1}, {4, 1}, {6, 1}})),
IntFlatMap({{1, 3}, {3, 3}, {5, 3}}));
}

TEST_F(FlatMapTest, visit) {
auto m = IntFlatMap({
{1, 2},
{2, 3},
{4, 5},
});
size_t sum = 0;
m.visit([&sum](const auto& binding) { sum += binding.second; });
EXPECT_EQ(sum, 10);
}
10 changes: 10 additions & 0 deletions test/HashedAbstractEnvironmentTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,13 @@ TEST(HashedAbstractEnvironmentTest, destructiveOperations) {
EXPECT_THAT(e2.get("v3").elements(),
::testing::UnorderedElementsAre("g", "h"));
}

TEST(HashedAbstractEnvironmentTest, visit) {
Environment e1({{"v1", Domain({"a", "b"})},
{"v2", Domain("c")},
{"v3", Domain({"d", "e", "f"})},
{"v4", Domain({"a", "f"})}});
auto all = Domain{};
e1.visit([&all](const auto& binding) { all.join_with(binding.second); });
EXPECT_EQ(all, Domain({"a", "b", "c", "d", "e", "f"}));
}
10 changes: 10 additions & 0 deletions test/HashedAbstractPartitionTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -293,3 +293,13 @@ TEST(HashedAbstractPartitionTest, difference) {
{{"b", Domain{"a"}}, {"d", Domain{"a"}}, {"e", Domain{"c"}}})),
Partition({{"a", Domain{"c"}}, {"c", Domain{"c"}}, {"d", Domain{"c"}}}));
}

TEST(HashedAbstractPartitionTest, visit) {
Partition p1({{"v1", Domain({"a", "b"})},
{"v2", Domain("c")},
{"v3", Domain({"d", "e", "f"})},
{"v4", Domain({"a", "f"})}});
auto all = Domain{};
p1.visit([&all](const auto& binding) { all.join_with(binding.second); });
EXPECT_EQ(all, Domain({"a", "b", "c", "d", "e", "f"}));
}
11 changes: 11 additions & 0 deletions test/PatriciaTreeMapAbstractEnvironmentTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,17 @@ TEST_F(PatriciaTreeMapAbstractEnvironmentTest, transform) {
EXPECT_TRUE(e1.is_top());
}

TEST_F(PatriciaTreeMapAbstractEnvironmentTest, visit) {
Environment e1({
{1, Domain({"a", "b"})},
{2, Domain({"a", "b"})},
{3, Domain({"a", "b"})},
});
size_t sum = 0;
e1.visit([&sum](const auto& binding) { sum += binding.first; });
EXPECT_EQ(sum, 6);
}

TEST_F(PatriciaTreeMapAbstractEnvironmentTest, prettyPrinting) {
using StringEnvironment =
PatriciaTreeMapAbstractEnvironment<std::string*, Domain>;
Expand Down
11 changes: 11 additions & 0 deletions test/PatriciaTreeMapAbstractPartitionTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,3 +258,14 @@ TEST(PatriciaTreeMapAbstractPartitionTest, transform) {
EXPECT_TRUE(any_changes);
EXPECT_TRUE(p1.is_bottom());
}

TEST(PatriciaTreeMapAbstractPartitionTest, visit) {
Partition p1({
{1, Domain({"a", "b"})},
{2, Domain({"a", "b"})},
{3, Domain({"a", "b"})},
});
size_t sum = 0;
p1.visit([&sum](const auto& binding) { sum += binding.first; });
EXPECT_EQ(sum, 6);
}

0 comments on commit 483d9b8

Please sign in to comment.