Skip to content

Commit

Permalink
Fix perfect forward of lambdas
Browse files Browse the repository at this point in the history
Summary:
We are often using `[f = std::forward<F>(f)]` in lambdas, in the goal of perfect-forwarding a callable `f`.
Unfortunately, this does NOT actually perfectly-forward the value. When f is a reference, this leads to an extra copy, because the initializer expression in a generalized lambda capture does not have effect on the deduced type of the capture.
This is better explained by Vittorio Romeo: https://vittorioromeo.info/index/blog/capturing_perfectly_forwarded_objects_in_lambdas.html

To fix this issue, we can use a helper function `fwd_capture`.
Note that this also allows us to use mutable lambdas.

Reviewed By: arnaudvenet

Differential Revision: D60187566

fbshipit-source-id: f19688f37a61219b5fb902a27816d2c66792b250
  • Loading branch information
arthaud authored and facebook-github-bot committed Jul 25, 2024
1 parent 850b004 commit 2e20749
Show file tree
Hide file tree
Showing 10 changed files with 255 additions and 70 deletions.
25 changes: 12 additions & 13 deletions include/sparta/AbstractEnvironment.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <sparta/AbstractDomain.h>
#include <sparta/AbstractMap.h>
#include <sparta/AbstractMapValue.h>
#include <sparta/PerfectForwardCapture.h>

namespace sparta {

Expand Down Expand Up @@ -135,9 +136,9 @@ class AbstractEnvironment final
try {
if constexpr (Map::mutability == AbstractMapMutability::Immutable) {
this->get_value()->m_map.update(
[operation = std::forward<Operation>(operation)](
const Domain& value) -> Domain {
Domain result = operation(value);
[operation = fwd_capture(std::forward<Operation>(operation))](
const Domain& value) mutable -> Domain {
Domain result = operation.get()(value);
if (result.is_bottom()) {
throw environment_impl::value_is_bottom();
}
Expand All @@ -146,9 +147,9 @@ class AbstractEnvironment final
variable);
} else if constexpr (Map::mutability == AbstractMapMutability::Mutable) {
this->get_value()->m_map.update(
[operation =
std::forward<Operation>(operation)](Domain* value) -> void {
operation(value);
[operation = fwd_capture(std::forward<Operation>(operation))](
Domain* value) mutable -> void {
operation.get()(value);
if (value->is_bottom()) {
throw environment_impl::value_is_bottom();
}
Expand Down Expand Up @@ -338,19 +339,18 @@ class MapValue final : public AbstractValue<MapValue<Map>> {

template <typename Operation>
AbstractValueKind join_like_operation(const MapValue& other,
Operation&& operation) {
m_map.intersection_with(std::forward<Operation>(operation), other.m_map);
const Operation& operation) {
m_map.intersection_with(operation, other.m_map);
return kind();
}

template <typename Operation>
AbstractValueKind meet_like_operation(const MapValue& other,
Operation&& operation) {
const Operation& operation) {
try {
if constexpr (Map::mutability == AbstractMapMutability::Immutable) {
m_map.union_with(
[operation = std::forward<Operation>(operation)](
const Domain& x, const Domain& y) -> Domain {
[&operation](const Domain& x, const Domain& y) -> Domain {
Domain result = operation(x, y);
if (result.is_bottom()) {
throw value_is_bottom();
Expand All @@ -360,8 +360,7 @@ class MapValue final : public AbstractValue<MapValue<Map>> {
other.m_map);
} else if constexpr (Map::mutability == AbstractMapMutability::Mutable) {
m_map.union_with(
[operation = std::forward<Operation>(operation)](
Domain* x, const Domain& y) -> void {
[&operation](Domain* x, const Domain& y) -> void {
operation(x, y);
if (x->is_bottom()) {
throw value_is_bottom();
Expand Down
13 changes: 7 additions & 6 deletions include/sparta/DirectProductAbstractDomain.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <type_traits>

#include <sparta/AbstractDomain.h>
#include <sparta/PerfectForwardCapture.h>

// Forward declarations.
namespace sparta {
Expand Down Expand Up @@ -190,10 +191,10 @@ class DirectProductAbstractDomain : public AbstractDomain<Derived> {
template <class Predicate>
bool all_of(Predicate&& predicate) const {
return tuple_apply(
[predicate =
std::forward<Predicate>(predicate)](const Domains&... component) {
[predicate = fwd_capture(std::forward<Predicate>(predicate))](
const Domains&... component) mutable {
bool result = true;
discard({(result &= predicate(component))...});
discard({(result &= predicate.get()(component))...});
return result;
},
m_product);
Expand All @@ -202,10 +203,10 @@ class DirectProductAbstractDomain : public AbstractDomain<Derived> {
template <class Predicate>
bool any_of(Predicate&& predicate) const {
return tuple_apply(
[predicate =
std::forward<Predicate>(predicate)](const Domains&... component) {
[predicate = fwd_capture(std::forward<Predicate>(predicate))](
const Domains&... component) mutable {
bool result = false;
discard({(result |= predicate(component))...});
discard({(result |= predicate.get()(component))...});
return result;
},
m_product);
Expand Down
16 changes: 9 additions & 7 deletions include/sparta/FlatMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <sparta/AbstractMap.h>
#include <sparta/AbstractMapValue.h>
#include <sparta/PatriciaTreeCore.h>
#include <sparta/PerfectForwardCapture.h>

namespace sparta {
namespace fm_impl {
Expand Down Expand Up @@ -282,13 +283,14 @@ class FlatMap final
// Use boost `flat_map` API to get the underlying container and
// apply a remove_if + erase. This allows to perform a filter in O(n).
auto container = m_map.extract_sequence();
container.erase(std::remove_if(container.begin(),
container.end(),
[predicate = std::forward<Predicate>(
predicate)](const auto& p) {
return !predicate(p.first, p.second);
}),
container.end());
container.erase(
std::remove_if(container.begin(),
container.end(),
[predicate = fwd_capture(std::forward<Predicate>(
predicate))](const auto& p) mutable {
return !predicate.get()(p.first, p.second);
}),
container.end());
m_map.adopt_sequence(boost::container::ordered_unique_range,
std::move(container));
break;
Expand Down
7 changes: 5 additions & 2 deletions include/sparta/FlatSet.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include <sparta/AbstractSet.h>
#include <sparta/PatriciaTreeUtil.h>
#include <sparta/PerfectForwardCapture.h>

namespace sparta {

Expand Down Expand Up @@ -131,8 +132,10 @@ class FlatSet final
FlatSet& filter(Predicate&& predicate) {
auto container = m_set.extract_sequence();
container.erase(
std::remove_if(container.begin(), container.end(),
[&](const Element& e) { return !predicate(e); }),
std::remove_if(
container.begin(), container.end(),
[predicate = fwd_capture(std::forward<Predicate>(predicate))](
const Element& e) mutable { return !predicate.get()(e); }),
container.end());
m_set.adopt_sequence(boost::container::ordered_unique_range,
std::move(container));
Expand Down
30 changes: 16 additions & 14 deletions include/sparta/PatriciaTreeCore.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <sparta/AbstractMapValue.h>
#include <sparta/Exceptions.h>
#include <sparta/PatriciaTreeUtil.h>
#include <sparta/PerfectForwardCapture.h>

namespace sparta {

Expand Down Expand Up @@ -864,7 +865,7 @@ inline intrusive_ptr<PatriciaTreeNode<IntegerType, Value>> update_leaf_by_key(
LeafOperation&& leaf_operation,
IntegerType key,
const intrusive_ptr<PatriciaTreeNode<IntegerType, Value>>& tree) {
const auto make_new_leaf = [&] {
const auto make_new_leaf = [key, &leaf_operation]() {
return update_new_leaf<IntegerType, Value>(
std::forward<LeafOperation>(leaf_operation), key);
};
Expand Down Expand Up @@ -936,7 +937,9 @@ inline intrusive_ptr<PatriciaTreeNode<IntegerType, Value>> upsert_leaf_by_key(
ValueOrLeaf value_or_leaf,
const intrusive_ptr<PatriciaTreeNode<IntegerType, Value>>& tree) {
return update_leaf_by_key(
[&](const auto&) { return std::move(value_or_leaf); }, key, tree);
[&value_or_leaf](const auto&) { return std::move(value_or_leaf); },
key,
tree);
}

template <typename IntegerType, typename Value, typename LeafOperation>
Expand Down Expand Up @@ -986,9 +989,9 @@ inline intrusive_ptr<PatriciaTreeLeaf<IntegerType, Value>> combine_leafs(
intrusive_ptr<PatriciaTreeLeaf<IntegerType, Value>> other,
intrusive_ptr<PatriciaTreeLeaf<IntegerType, Value>> leaf) {
return update_leaf(
[leaf_combine = std::forward<LeafCombine>(leaf_combine),
&other](auto leaf) {
return leaf_combine(std::move(leaf), std::move(other));
[leaf_combine = fwd_capture(std::forward<LeafCombine>(leaf_combine)),
&other](auto leaf) mutable {
return leaf_combine.get()(std::move(leaf), std::move(other));
},
std::move(leaf));
}
Expand All @@ -1000,9 +1003,9 @@ inline intrusive_ptr<PatriciaTreeNode<IntegerType, Value>> combine_leafs_by_key(
IntegerType key,
const intrusive_ptr<PatriciaTreeNode<IntegerType, Value>>& tree) {
return update_leaf_by_key(
[leaf_combine = std::forward<LeafCombine>(leaf_combine),
&other](auto leaf) {
return leaf_combine(std::move(leaf), std::move(other));
[leaf_combine = fwd_capture(std::forward<LeafCombine>(leaf_combine)),
&other](auto leaf) mutable {
return leaf_combine.get()(std::move(leaf), std::move(other));
},
key,
tree);
Expand Down Expand Up @@ -1440,9 +1443,8 @@ class PatriciaTreeCore {
template <typename Visitor>
inline void visit_all_leafs(Visitor&& visitor) const {
pt_core::visit_all_leafs(
[visitor = std::forward<Visitor>(visitor)](const auto& data) {
visitor(Codec::decode(data));
},
[visitor = fwd_capture(std::forward<Visitor>(visitor))](
const auto& data) mutable { visitor.get()(Codec::decode(data)); },
m_tree);
}

Expand Down Expand Up @@ -1472,9 +1474,9 @@ class PatriciaTreeCore {
template <typename Predicate>
inline void filter(Predicate&& predicate) {
m_tree = pt_core::filter_tree(
[predicate = std::forward<Predicate>(predicate)](
IntegerType key, const ValueType& value) {
return predicate(Codec::decode(key), value);
[predicate = fwd_capture(std::forward<Predicate>(predicate))](
IntegerType key, const ValueType& value) mutable {
return predicate.get()(Codec::decode(key), value);
},
m_tree);
}
Expand Down
54 changes: 28 additions & 26 deletions include/sparta/PatriciaTreeHashMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <sparta/FlatMap.h>
#include <sparta/FlattenIterator.h>
#include <sparta/PatriciaTreeMap.h>
#include <sparta/PerfectForwardCapture.h>

namespace sparta {
namespace pthm_impl {
Expand Down Expand Up @@ -163,12 +164,9 @@ class PatriciaTreeHashMap final
template <typename Operation> // void(mapped_type*)
PatriciaTreeHashMap& update(Operation&& operation, const Key& key) {
m_tree.update(
[operation = std::forward<Operation>(operation),
&key](FlatMapT flat_map) -> FlatMapT {
// We should be using `std::forward` here but we would get a compiler
// error, see:
// https://vittorioromeo.info/index/blog/capturing_perfectly_forwarded_objects_in_lambdas.html
flat_map.update(operation, key);
[operation = fwd_capture(std::forward<Operation>(operation)),
&key](FlatMapT flat_map) mutable -> FlatMapT {
flat_map.update(std::forward<Operation>(operation.get()), key);
return flat_map;
},
KeyHash()(key));
Expand All @@ -177,11 +175,11 @@ class PatriciaTreeHashMap final

template <typename MappingFunction> // void(mapped_type*)
void transform(MappingFunction&& f) {
m_tree.transform(
[f = std::forward<MappingFunction>(f)](FlatMapT flat_map) -> FlatMapT {
flat_map.transform(f);
return flat_map;
});
m_tree.transform([f = fwd_capture(std::forward<MappingFunction>(f))](
FlatMapT flat_map) mutable -> FlatMapT {
flat_map.transform(f.get());
return flat_map;
});
}

/*
Expand All @@ -190,9 +188,9 @@ class PatriciaTreeHashMap final
*/
template <typename Visitor> // void(const value_type&)
void visit(Visitor&& visitor) const {
m_tree.visit([visitor = std::forward<Visitor>(visitor)](
const std::pair<size_t, FlatMapT>& binding) {
binding.second.visit(visitor);
m_tree.visit([visitor = fwd_capture(std::forward<Visitor>(visitor))](
const std::pair<size_t, FlatMapT>& binding) mutable {
binding.second.visit(visitor.get());
});
}

Expand All @@ -207,9 +205,10 @@ class PatriciaTreeHashMap final
}

template <typename Predicate> // bool(const Key&, const mapped_type&)
PatriciaTreeHashMap& filter(const Predicate& predicate) {
m_tree.transform([&predicate](FlatMapT flat_map) -> FlatMapT {
flat_map.filter(predicate);
PatriciaTreeHashMap& filter(Predicate&& predicate) {
m_tree.transform([predicate = fwd_capture(std::forward<Predicate>(
predicate))](FlatMapT flat_map) mutable -> FlatMapT {
flat_map.filter(predicate.get());
return flat_map;
});
return *this;
Expand All @@ -222,11 +221,12 @@ class PatriciaTreeHashMap final
// Requires CombiningFunction to coerce to
// std::function<void(mapped_type*, const mapped_type&)>
template <typename CombiningFunction>
PatriciaTreeHashMap& union_with(const CombiningFunction& combine,
PatriciaTreeHashMap& union_with(CombiningFunction&& combine,
const PatriciaTreeHashMap& other) {
m_tree.union_with(
[&combine](FlatMapT left, const FlatMapT& right) -> FlatMapT {
left.union_with(combine, right);
[combine = fwd_capture(std::forward<CombiningFunction>(combine))](
FlatMapT left, const FlatMapT& right) mutable -> FlatMapT {
left.union_with(combine.get(), right);
return left;
},
other.m_tree);
Expand All @@ -236,11 +236,12 @@ class PatriciaTreeHashMap final
// Requires CombiningFunction to coerce to
// std::function<void(mapped_type*, const mapped_type&)>
template <typename CombiningFunction>
PatriciaTreeHashMap& intersection_with(const CombiningFunction& combine,
PatriciaTreeHashMap& intersection_with(CombiningFunction&& combine,
const PatriciaTreeHashMap& other) {
m_tree.intersection_with(
[&combine](FlatMapT left, const FlatMapT& right) -> FlatMapT {
left.intersection_with(combine, right);
[combine = fwd_capture(std::forward<CombiningFunction>(combine))](
FlatMapT left, const FlatMapT& right) mutable -> FlatMapT {
left.intersection_with(combine.get(), right);
return left;
},
other.m_tree);
Expand All @@ -249,11 +250,12 @@ class PatriciaTreeHashMap final

// Requires that `combine(bottom, ...) = bottom`.
template <typename CombiningFunction>
PatriciaTreeHashMap& difference_with(const CombiningFunction& combine,
PatriciaTreeHashMap& difference_with(CombiningFunction&& combine,
const PatriciaTreeHashMap& other) {
m_tree.difference_with(
[&combine](FlatMapT left, const FlatMapT& right) -> FlatMapT {
left.difference_with(combine, right);
[combine = fwd_capture(std::forward<CombiningFunction>(combine))](
FlatMapT left, const FlatMapT& right) mutable -> FlatMapT {
left.difference_with(combine.get(), right);
return left;
},
other.m_tree);
Expand Down
6 changes: 4 additions & 2 deletions include/sparta/PatriciaTreeMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <sparta/Exceptions.h>
#include <sparta/PatriciaTreeCore.h>
#include <sparta/PatriciaTreeUtil.h>
#include <sparta/PerfectForwardCapture.h>

namespace sparta {
namespace ptm_impl {
Expand Down Expand Up @@ -237,10 +238,11 @@ class PatriciaTreeMap final
// This wraps the given function to apply these transformations.
template <typename Func>
inline static auto apply_leafs(Func&& func) {
return [func = std::forward<Func>(func)](const auto&... leaf_ptrs) {
return [func = fwd_capture(std::forward<Func>(func))](
const auto&... leaf_ptrs) mutable {
auto default_value = ValueInterface::default_value();
auto return_value =
func((leaf_ptrs ? leaf_ptrs->value() : default_value)...);
func.get()((leaf_ptrs ? leaf_ptrs->value() : default_value)...);

return keep_if_non_default(std::move(return_value));
};
Expand Down
Loading

0 comments on commit 2e20749

Please sign in to comment.