From 2e20749c23f8b8a680516b81279e7098a3dc63a9 Mon Sep 17 00:00:00 2001 From: Maxime Arthaud Date: Thu, 25 Jul 2024 09:46:38 -0700 Subject: [PATCH] Fix perfect forward of lambdas Summary: We are often using `[f = std::forward(f)]` in lambdas, in the goal of perfect-forwarding a callable `f`. Unfortunately, this does NOT actually perfectly-forward the value. When f is a reference, this leads to an extra copy, because the initializer expression in a generalized lambda capture does not have effect on the deduced type of the capture. This is better explained by Vittorio Romeo: https://vittorioromeo.info/index/blog/capturing_perfectly_forwarded_objects_in_lambdas.html To fix this issue, we can use a helper function `fwd_capture`. Note that this also allows us to use mutable lambdas. Reviewed By: arnaudvenet Differential Revision: D60187566 fbshipit-source-id: f19688f37a61219b5fb902a27816d2c66792b250 --- include/sparta/AbstractEnvironment.h | 25 +++---- include/sparta/DirectProductAbstractDomain.h | 13 ++-- include/sparta/FlatMap.h | 16 ++-- include/sparta/FlatSet.h | 7 +- include/sparta/PatriciaTreeCore.h | 30 ++++---- include/sparta/PatriciaTreeHashMap.h | 54 +++++++------- include/sparta/PatriciaTreeMap.h | 6 +- include/sparta/PerfectForwardCapture.h | 78 ++++++++++++++++++++ test/PatriciaTreeHashMapTest.cpp | 47 ++++++++++++ test/PatriciaTreeMapTest.cpp | 49 ++++++++++++ 10 files changed, 255 insertions(+), 70 deletions(-) create mode 100644 include/sparta/PerfectForwardCapture.h diff --git a/include/sparta/AbstractEnvironment.h b/include/sparta/AbstractEnvironment.h index c0a7ab3..d102995 100644 --- a/include/sparta/AbstractEnvironment.h +++ b/include/sparta/AbstractEnvironment.h @@ -16,6 +16,7 @@ #include #include #include +#include namespace sparta { @@ -135,9 +136,9 @@ class AbstractEnvironment final try { if constexpr (Map::mutability == AbstractMapMutability::Immutable) { this->get_value()->m_map.update( - [operation = std::forward(operation)]( - const Domain& value) -> Domain { - Domain result = operation(value); + [operation = fwd_capture(std::forward(operation))]( + const Domain& value) mutable -> Domain { + Domain result = operation.get()(value); if (result.is_bottom()) { throw environment_impl::value_is_bottom(); } @@ -146,9 +147,9 @@ class AbstractEnvironment final variable); } else if constexpr (Map::mutability == AbstractMapMutability::Mutable) { this->get_value()->m_map.update( - [operation = - std::forward(operation)](Domain* value) -> void { - operation(value); + [operation = fwd_capture(std::forward(operation))]( + Domain* value) mutable -> void { + operation.get()(value); if (value->is_bottom()) { throw environment_impl::value_is_bottom(); } @@ -338,19 +339,18 @@ class MapValue final : public AbstractValue> { template AbstractValueKind join_like_operation(const MapValue& other, - Operation&& operation) { - m_map.intersection_with(std::forward(operation), other.m_map); + const Operation& operation) { + m_map.intersection_with(operation, other.m_map); return kind(); } template AbstractValueKind meet_like_operation(const MapValue& other, - Operation&& operation) { + const Operation& operation) { try { if constexpr (Map::mutability == AbstractMapMutability::Immutable) { m_map.union_with( - [operation = std::forward(operation)]( - const Domain& x, const Domain& y) -> Domain { + [&operation](const Domain& x, const Domain& y) -> Domain { Domain result = operation(x, y); if (result.is_bottom()) { throw value_is_bottom(); @@ -360,8 +360,7 @@ class MapValue final : public AbstractValue> { other.m_map); } else if constexpr (Map::mutability == AbstractMapMutability::Mutable) { m_map.union_with( - [operation = std::forward(operation)]( - Domain* x, const Domain& y) -> void { + [&operation](Domain* x, const Domain& y) -> void { operation(x, y); if (x->is_bottom()) { throw value_is_bottom(); diff --git a/include/sparta/DirectProductAbstractDomain.h b/include/sparta/DirectProductAbstractDomain.h index ed2b38a..34561a9 100644 --- a/include/sparta/DirectProductAbstractDomain.h +++ b/include/sparta/DirectProductAbstractDomain.h @@ -16,6 +16,7 @@ #include #include +#include // Forward declarations. namespace sparta { @@ -190,10 +191,10 @@ class DirectProductAbstractDomain : public AbstractDomain { template bool all_of(Predicate&& predicate) const { return tuple_apply( - [predicate = - std::forward(predicate)](const Domains&... component) { + [predicate = fwd_capture(std::forward(predicate))]( + const Domains&... component) mutable { bool result = true; - discard({(result &= predicate(component))...}); + discard({(result &= predicate.get()(component))...}); return result; }, m_product); @@ -202,10 +203,10 @@ class DirectProductAbstractDomain : public AbstractDomain { template bool any_of(Predicate&& predicate) const { return tuple_apply( - [predicate = - std::forward(predicate)](const Domains&... component) { + [predicate = fwd_capture(std::forward(predicate))]( + const Domains&... component) mutable { bool result = false; - discard({(result |= predicate(component))...}); + discard({(result |= predicate.get()(component))...}); return result; }, m_product); diff --git a/include/sparta/FlatMap.h b/include/sparta/FlatMap.h index 0e25e6d..af64712 100644 --- a/include/sparta/FlatMap.h +++ b/include/sparta/FlatMap.h @@ -17,6 +17,7 @@ #include #include #include +#include namespace sparta { namespace fm_impl { @@ -282,13 +283,14 @@ class FlatMap final // Use boost `flat_map` API to get the underlying container and // apply a remove_if + erase. This allows to perform a filter in O(n). auto container = m_map.extract_sequence(); - container.erase(std::remove_if(container.begin(), - container.end(), - [predicate = std::forward( - predicate)](const auto& p) { - return !predicate(p.first, p.second); - }), - container.end()); + container.erase( + std::remove_if(container.begin(), + container.end(), + [predicate = fwd_capture(std::forward( + predicate))](const auto& p) mutable { + return !predicate.get()(p.first, p.second); + }), + container.end()); m_map.adopt_sequence(boost::container::ordered_unique_range, std::move(container)); break; diff --git a/include/sparta/FlatSet.h b/include/sparta/FlatSet.h index 2e3323b..1935102 100644 --- a/include/sparta/FlatSet.h +++ b/include/sparta/FlatSet.h @@ -16,6 +16,7 @@ #include #include +#include namespace sparta { @@ -131,8 +132,10 @@ class FlatSet final FlatSet& filter(Predicate&& predicate) { auto container = m_set.extract_sequence(); container.erase( - std::remove_if(container.begin(), container.end(), - [&](const Element& e) { return !predicate(e); }), + std::remove_if( + container.begin(), container.end(), + [predicate = fwd_capture(std::forward(predicate))]( + const Element& e) mutable { return !predicate.get()(e); }), container.end()); m_set.adopt_sequence(boost::container::ordered_unique_range, std::move(container)); diff --git a/include/sparta/PatriciaTreeCore.h b/include/sparta/PatriciaTreeCore.h index 8a40893..f2337e9 100644 --- a/include/sparta/PatriciaTreeCore.h +++ b/include/sparta/PatriciaTreeCore.h @@ -25,6 +25,7 @@ #include #include #include +#include namespace sparta { @@ -864,7 +865,7 @@ inline intrusive_ptr> update_leaf_by_key( LeafOperation&& leaf_operation, IntegerType key, const intrusive_ptr>& tree) { - const auto make_new_leaf = [&] { + const auto make_new_leaf = [key, &leaf_operation]() { return update_new_leaf( std::forward(leaf_operation), key); }; @@ -936,7 +937,9 @@ inline intrusive_ptr> upsert_leaf_by_key( ValueOrLeaf value_or_leaf, const intrusive_ptr>& tree) { return update_leaf_by_key( - [&](const auto&) { return std::move(value_or_leaf); }, key, tree); + [&value_or_leaf](const auto&) { return std::move(value_or_leaf); }, + key, + tree); } template @@ -986,9 +989,9 @@ inline intrusive_ptr> combine_leafs( intrusive_ptr> other, intrusive_ptr> leaf) { return update_leaf( - [leaf_combine = std::forward(leaf_combine), - &other](auto leaf) { - return leaf_combine(std::move(leaf), std::move(other)); + [leaf_combine = fwd_capture(std::forward(leaf_combine)), + &other](auto leaf) mutable { + return leaf_combine.get()(std::move(leaf), std::move(other)); }, std::move(leaf)); } @@ -1000,9 +1003,9 @@ inline intrusive_ptr> combine_leafs_by_key( IntegerType key, const intrusive_ptr>& tree) { return update_leaf_by_key( - [leaf_combine = std::forward(leaf_combine), - &other](auto leaf) { - return leaf_combine(std::move(leaf), std::move(other)); + [leaf_combine = fwd_capture(std::forward(leaf_combine)), + &other](auto leaf) mutable { + return leaf_combine.get()(std::move(leaf), std::move(other)); }, key, tree); @@ -1440,9 +1443,8 @@ class PatriciaTreeCore { template inline void visit_all_leafs(Visitor&& visitor) const { pt_core::visit_all_leafs( - [visitor = std::forward(visitor)](const auto& data) { - visitor(Codec::decode(data)); - }, + [visitor = fwd_capture(std::forward(visitor))]( + const auto& data) mutable { visitor.get()(Codec::decode(data)); }, m_tree); } @@ -1472,9 +1474,9 @@ class PatriciaTreeCore { template inline void filter(Predicate&& predicate) { m_tree = pt_core::filter_tree( - [predicate = std::forward(predicate)]( - IntegerType key, const ValueType& value) { - return predicate(Codec::decode(key), value); + [predicate = fwd_capture(std::forward(predicate))]( + IntegerType key, const ValueType& value) mutable { + return predicate.get()(Codec::decode(key), value); }, m_tree); } diff --git a/include/sparta/PatriciaTreeHashMap.h b/include/sparta/PatriciaTreeHashMap.h index e96e167..4049916 100644 --- a/include/sparta/PatriciaTreeHashMap.h +++ b/include/sparta/PatriciaTreeHashMap.h @@ -21,6 +21,7 @@ #include #include #include +#include namespace sparta { namespace pthm_impl { @@ -163,12 +164,9 @@ class PatriciaTreeHashMap final template // void(mapped_type*) PatriciaTreeHashMap& update(Operation&& operation, const Key& key) { m_tree.update( - [operation = std::forward(operation), - &key](FlatMapT flat_map) -> FlatMapT { - // We should be using `std::forward` here but we would get a compiler - // error, see: - // https://vittorioromeo.info/index/blog/capturing_perfectly_forwarded_objects_in_lambdas.html - flat_map.update(operation, key); + [operation = fwd_capture(std::forward(operation)), + &key](FlatMapT flat_map) mutable -> FlatMapT { + flat_map.update(std::forward(operation.get()), key); return flat_map; }, KeyHash()(key)); @@ -177,11 +175,11 @@ class PatriciaTreeHashMap final template // void(mapped_type*) void transform(MappingFunction&& f) { - m_tree.transform( - [f = std::forward(f)](FlatMapT flat_map) -> FlatMapT { - flat_map.transform(f); - return flat_map; - }); + m_tree.transform([f = fwd_capture(std::forward(f))]( + FlatMapT flat_map) mutable -> FlatMapT { + flat_map.transform(f.get()); + return flat_map; + }); } /* @@ -190,9 +188,9 @@ class PatriciaTreeHashMap final */ template // void(const value_type&) void visit(Visitor&& visitor) const { - m_tree.visit([visitor = std::forward(visitor)]( - const std::pair& binding) { - binding.second.visit(visitor); + m_tree.visit([visitor = fwd_capture(std::forward(visitor))]( + const std::pair& binding) mutable { + binding.second.visit(visitor.get()); }); } @@ -207,9 +205,10 @@ class PatriciaTreeHashMap final } template // bool(const Key&, const mapped_type&) - PatriciaTreeHashMap& filter(const Predicate& predicate) { - m_tree.transform([&predicate](FlatMapT flat_map) -> FlatMapT { - flat_map.filter(predicate); + PatriciaTreeHashMap& filter(Predicate&& predicate) { + m_tree.transform([predicate = fwd_capture(std::forward( + predicate))](FlatMapT flat_map) mutable -> FlatMapT { + flat_map.filter(predicate.get()); return flat_map; }); return *this; @@ -222,11 +221,12 @@ class PatriciaTreeHashMap final // Requires CombiningFunction to coerce to // std::function template - PatriciaTreeHashMap& union_with(const CombiningFunction& combine, + PatriciaTreeHashMap& union_with(CombiningFunction&& combine, const PatriciaTreeHashMap& other) { m_tree.union_with( - [&combine](FlatMapT left, const FlatMapT& right) -> FlatMapT { - left.union_with(combine, right); + [combine = fwd_capture(std::forward(combine))]( + FlatMapT left, const FlatMapT& right) mutable -> FlatMapT { + left.union_with(combine.get(), right); return left; }, other.m_tree); @@ -236,11 +236,12 @@ class PatriciaTreeHashMap final // Requires CombiningFunction to coerce to // std::function template - PatriciaTreeHashMap& intersection_with(const CombiningFunction& combine, + PatriciaTreeHashMap& intersection_with(CombiningFunction&& combine, const PatriciaTreeHashMap& other) { m_tree.intersection_with( - [&combine](FlatMapT left, const FlatMapT& right) -> FlatMapT { - left.intersection_with(combine, right); + [combine = fwd_capture(std::forward(combine))]( + FlatMapT left, const FlatMapT& right) mutable -> FlatMapT { + left.intersection_with(combine.get(), right); return left; }, other.m_tree); @@ -249,11 +250,12 @@ class PatriciaTreeHashMap final // Requires that `combine(bottom, ...) = bottom`. template - PatriciaTreeHashMap& difference_with(const CombiningFunction& combine, + PatriciaTreeHashMap& difference_with(CombiningFunction&& combine, const PatriciaTreeHashMap& other) { m_tree.difference_with( - [&combine](FlatMapT left, const FlatMapT& right) -> FlatMapT { - left.difference_with(combine, right); + [combine = fwd_capture(std::forward(combine))]( + FlatMapT left, const FlatMapT& right) mutable -> FlatMapT { + left.difference_with(combine.get(), right); return left; }, other.m_tree); diff --git a/include/sparta/PatriciaTreeMap.h b/include/sparta/PatriciaTreeMap.h index a89c73f..9630a63 100644 --- a/include/sparta/PatriciaTreeMap.h +++ b/include/sparta/PatriciaTreeMap.h @@ -21,6 +21,7 @@ #include #include #include +#include namespace sparta { namespace ptm_impl { @@ -237,10 +238,11 @@ class PatriciaTreeMap final // This wraps the given function to apply these transformations. template inline static auto apply_leafs(Func&& func) { - return [func = std::forward(func)](const auto&... leaf_ptrs) { + return [func = fwd_capture(std::forward(func))]( + const auto&... leaf_ptrs) mutable { auto default_value = ValueInterface::default_value(); auto return_value = - func((leaf_ptrs ? leaf_ptrs->value() : default_value)...); + func.get()((leaf_ptrs ? leaf_ptrs->value() : default_value)...); return keep_if_non_default(std::move(return_value)); }; diff --git a/include/sparta/PerfectForwardCapture.h b/include/sparta/PerfectForwardCapture.h new file mode 100644 index 0000000..5a17a95 --- /dev/null +++ b/include/sparta/PerfectForwardCapture.h @@ -0,0 +1,78 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace sparta { +namespace fwd_impl { + +template +class by_value { + public: + template + explicit by_value(U&& x) : m_x{std::forward(x)} {} + + auto& get() & { return m_x; } + const auto& get() const& { return m_x; } + auto get() && { return std::move(m_x); } + + private: + T m_x; +}; + +template +class by_ref { + public: + explicit by_ref(T& x) : m_x{x} {} + + auto& get() & { return m_x.get(); } + const auto& get() const& { return m_x.get(); } + auto get() && { return std::move(m_x.get()); } + + private: + std::reference_wrapper m_x; +}; + +} // namespace fwd_impl + +// Unspecialized version: stores a `T` instance by value. +template +struct fwd_capture_wrapper : fwd_impl::by_value { + using fwd_impl::by_value::by_value; +}; + +// Specialized version: stores a `T` reference. +template +struct fwd_capture_wrapper : fwd_impl::by_ref { + using fwd_impl::by_ref::by_ref; +}; + +/** + * Utility function to perfectly forward a universal reference captured in a + * lambda. See + * https://vittorioromeo.info/index/blog/capturing_perfectly_forwarded_objects_in_lambdas.html + * + * Example usage: + * ``` + * template + * auto f(T x) { + * return [x = fwd_capture(std::forward(x))](...) mutable { + * auto v = x.get(); + * // ... + * }; + * } + * ``` + */ +template +auto fwd_capture(T&& x) { + return fwd_capture_wrapper(std::forward(x)); +} + +} // namespace sparta diff --git a/test/PatriciaTreeHashMapTest.cpp b/test/PatriciaTreeHashMapTest.cpp index 2eb6b46..c03e5bd 100644 --- a/test/PatriciaTreeHashMapTest.cpp +++ b/test/PatriciaTreeHashMapTest.cpp @@ -181,3 +181,50 @@ TEST(PatriciaTreeHashMapTest, difference) { create_pth_map({{2, 1}, {4, 1}, {6, 1}})), create_pth_map({{1, 3}, {3, 3}, {5, 3}})); } + +TEST(PatriciaTreeHashMapTest, movableOperators) { + pth_map p = create_pth_map({{0, 1}, {1, 2}}); + + // lambda passed by rvalue reference, holding a non-copyable value. + auto movable = std::make_unique(3); + p.update( + [movable = std::move(movable)](uint32_t* value) mutable { + auto tmp = std::move(movable); + *value = *value + *tmp; + }, + 0); + EXPECT_EQ(p.at(0), 4); + + movable = std::make_unique(4); + auto updater = [movable = std::move(movable)](uint32_t* value) mutable { + auto tmp = std::move(movable); + *value = *value + *tmp; + }; + p.update(updater, 0); + EXPECT_EQ(p.at(0), 8); + + // lambda passed by rvalue reference, holding a non-copyable value. + movable = std::make_unique(10); + p.transform([movable = std::move(movable)](uint32_t* value) mutable { + auto tmp = std::move(movable); + (*tmp)++; + auto new_value = *tmp; + movable = std::move(tmp); + *value = new_value; + }); + EXPECT_EQ(p.at(0), 11); + EXPECT_EQ(p.at(1), 12); + + // lambda passed by lvalue reference, holding a non-copyable value. + movable = std::make_unique(20); + auto transformer = [movable = std::move(movable)](uint32_t* value) mutable { + auto tmp = std::move(movable); + (*tmp)++; + auto new_value = *tmp; + movable = std::move(tmp); + *value = new_value; + }; + p.transform(transformer); + EXPECT_EQ(p.at(0), 21); + EXPECT_EQ(p.at(1), 22); +} diff --git a/test/PatriciaTreeMapTest.cpp b/test/PatriciaTreeMapTest.cpp index da0a0a2..997edd3 100644 --- a/test/PatriciaTreeMapTest.cpp +++ b/test/PatriciaTreeMapTest.cpp @@ -210,3 +210,52 @@ TEST(PatriciaTreeMapTest, difference) { create_pt_map({{2, 1}, {4, 1}, {6, 1}})), create_pt_map({{1, 3}, {3, 3}, {5, 3}})); } + +TEST(PatriciaTreeMapTest, movableOperators) { + pt_map p = create_pt_map({{0, 1}, {1, 2}}); + + // lambda passed by rvalue reference, holding a non-copyable value. + auto movable = std::make_unique(3); + p.update( + [movable = std::move(movable)](uint32_t value) mutable { + auto tmp = std::move(movable); + return value + *tmp; + }, + 0); + EXPECT_EQ(p.at(0), 4); + + // lambda passed by lvalue reference, holding a non-copyable value. + movable = std::make_unique(4); + auto updater = [movable = std::move(movable)](uint32_t value) mutable { + auto tmp = std::move(movable); + return value + *tmp; + }; + p.update(updater, 0); + EXPECT_EQ(p.at(0), 8); + + // lambda passed by rvalue reference, holding a non-copyable value. + movable = std::make_unique(10); + p.transform([movable = std::move(movable)](uint32_t /* value */) mutable { + auto tmp = std::move(movable); + (*tmp)++; + auto new_value = *tmp; + movable = std::move(tmp); + return new_value; + }); + EXPECT_EQ(p.at(0), 11); + EXPECT_EQ(p.at(1), 12); + + // lambda passed by lvalue reference, holding a non-copyable value. + movable = std::make_unique(20); + auto transformer = [movable = + std::move(movable)](uint32_t /* value */) mutable { + auto tmp = std::move(movable); + (*tmp)++; + auto new_value = *tmp; + movable = std::move(tmp); + return new_value; + }; + p.transform(transformer); + EXPECT_EQ(p.at(0), 21); + EXPECT_EQ(p.at(1), 22); +}