Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf(trie): parallel rlp node updates in sparse trie #13251

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/trie/sparse/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ alloy-primitives.workspace = true
alloy-rlp.workspace = true

# misc
rayon.workspace = true
smallvec = { workspace = true, features = ["const_new"] }
thiserror.workspace = true

Expand Down
4 changes: 2 additions & 2 deletions crates/trie/sparse/src/blinded.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ use reth_trie_common::Nibbles;
/// Factory for instantiating blinded node providers.
pub trait BlindedProviderFactory {
/// Type capable of fetching blinded account nodes.
type AccountNodeProvider: BlindedProvider;
type AccountNodeProvider: BlindedProvider + Send + Sync;
/// Type capable of fetching blinded storage nodes.
type StorageNodeProvider: BlindedProvider;
type StorageNodeProvider: BlindedProvider + Send + Sync;

/// Returns blinded account node provider.
fn account_node_provider(&self) -> Self::AccountNodeProvider;
Expand Down
12 changes: 6 additions & 6 deletions crates/trie/sparse/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,11 +301,6 @@ impl<F: BlindedProviderFactory> SparseStateTrie<F> {
Ok(())
}

/// Calculates the hashes of the nodes below the provided level.
pub fn calculate_below_level(&mut self, level: usize) {
self.state.calculate_below_level(level);
}

/// Returns storage sparse trie root if the trie has been revealed.
pub fn storage_root(&mut self, account: B256) -> Option<B256> {
self.storages.get_mut(&account).and_then(|trie| trie.root())
Expand Down Expand Up @@ -346,7 +341,7 @@ impl<F: BlindedProviderFactory> SparseStateTrie<F> {
}
impl<F> SparseStateTrie<F>
where
F: BlindedProviderFactory,
F: BlindedProviderFactory + Send + Sync,
SparseTrieError: From<<F::AccountNodeProvider as BlindedProvider>::Error>
+ From<<F::StorageNodeProvider as BlindedProvider>::Error>,
{
Expand Down Expand Up @@ -423,6 +418,11 @@ where
storage_trie.remove_leaf(slot)?;
Ok(())
}

/// Calculates the hashes of the nodes below the provided level.
pub fn calculate_below_level(&mut self, level: usize) {
self.state.calculate_below_level(level);
}
}

#[cfg(test)]
Expand Down
160 changes: 113 additions & 47 deletions crates/trie/sparse/src/trie.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use crate::blinded::{BlindedProvider, DefaultBlindedProvider};
use alloy_primitives::{
hex, keccak256,
keccak256,
map::{Entry, HashMap, HashSet},
B256,
};
use alloy_rlp::Decodable;
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use reth_execution_errors::{SparseTrieError, SparseTrieErrorKind, SparseTrieResult};
use reth_tracing::tracing::trace;
use reth_trie_common::{
Expand All @@ -13,7 +14,7 @@ use reth_trie_common::{
TrieNode, CHILD_INDEX_RANGE, EMPTY_ROOT_HASH,
};
use smallvec::SmallVec;
use std::{borrow::Cow, fmt};
use std::{borrow::Cow, fmt, sync::mpsc};

/// Inner representation of the sparse trie.
/// Sparse trie is blind by default until nodes are revealed.
Expand Down Expand Up @@ -115,16 +116,11 @@ impl<P> SparseTrie<P> {
pub fn root(&mut self) -> Option<B256> {
Some(self.as_revealed_mut()?.root())
}

/// Calculates the hashes of the nodes below the provided level.
pub fn calculate_below_level(&mut self, level: usize) {
self.as_revealed_mut().unwrap().update_rlp_node_level(level);
}
}

impl<P> SparseTrie<P>
where
P: BlindedProvider,
P: BlindedProvider + Send + Sync,
SparseTrieError: From<P::Error>,
{
/// Update the leaf node.
Expand All @@ -140,6 +136,11 @@ where
revealed.remove_leaf(path)?;
Ok(())
}

/// Calculates the hashes of the nodes below the provided level.
pub fn calculate_below_level(&mut self, level: usize) {
self.as_revealed_mut().unwrap().update_rlp_node_level(level);
}
}

/// The representation of revealed sparse trie.
Expand All @@ -164,8 +165,6 @@ pub struct RevealedSparseTrie<P = DefaultBlindedProvider> {
prefix_set: PrefixSetMut,
/// Retained trie updates.
updates: Option<SparseTrieUpdates>,
/// Reusable buffer for RLP encoding of nodes.
rlp_buf: Vec<u8>,
}

impl<P> fmt::Debug for RevealedSparseTrie<P> {
Expand All @@ -176,7 +175,6 @@ impl<P> fmt::Debug for RevealedSparseTrie<P> {
.field("values", &self.values)
.field("prefix_set", &self.prefix_set)
.field("updates", &self.updates)
.field("rlp_buf", &hex::encode(&self.rlp_buf))
.finish_non_exhaustive()
}
}
Expand All @@ -190,7 +188,6 @@ impl Default for RevealedSparseTrie {
values: HashMap::default(),
prefix_set: PrefixSetMut::default(),
updates: None,
rlp_buf: Vec::new(),
}
}
}
Expand All @@ -208,7 +205,6 @@ impl RevealedSparseTrie {
branch_node_hash_masks: HashMap::default(),
values: HashMap::default(),
prefix_set: PrefixSetMut::default(),
rlp_buf: Vec::new(),
updates: None,
}
.with_updates(retain_updates);
Expand All @@ -231,7 +227,6 @@ impl<P> RevealedSparseTrie<P> {
branch_node_hash_masks: HashMap::default(),
values: HashMap::default(),
prefix_set: PrefixSetMut::default(),
rlp_buf: Vec::new(),
updates: None,
}
.with_updates(retain_updates);
Expand All @@ -248,7 +243,6 @@ impl<P> RevealedSparseTrie<P> {
values: self.values,
prefix_set: self.prefix_set,
updates: self.updates,
rlp_buf: self.rlp_buf,
}
}

Expand Down Expand Up @@ -523,19 +517,6 @@ impl<P> RevealedSparseTrie<P> {
}
}

/// Update hashes of the nodes that are located at a level deeper than or equal to the provided
/// depth. Root node has a level of 0.
pub fn update_rlp_node_level(&mut self, depth: usize) {
let mut prefix_set = self.prefix_set.clone().freeze();
let mut buffers = RlpNodeBuffers::default();

let targets = self.get_changed_nodes_at_depth(&mut prefix_set, depth);
for target in targets {
buffers.path_stack.push((target, Some(true)));
self.rlp_node(&mut prefix_set, &mut buffers);
}
}

/// Returns a list of paths to the nodes that were changed according to the prefix set and are
/// located at the provided depth when counting from the root node. If there's a leaf at a
/// depth less than the provided depth, it will be included in the result.
Expand Down Expand Up @@ -590,18 +571,29 @@ impl<P> RevealedSparseTrie<P> {

fn rlp_node_allocate(&mut self, path: Nibbles, prefix_set: &mut PrefixSet) -> RlpNode {
let mut buffers = RlpNodeBuffers::new_with_path(path);
self.rlp_node(prefix_set, &mut buffers)
let (root, updates) = self.rlp_node(prefix_set, &mut buffers, &mut Vec::new());
self.apply_rlp_node_updates(updates);
root
}

fn rlp_node(&mut self, prefix_set: &mut PrefixSet, buffers: &mut RlpNodeBuffers) -> RlpNode {
fn rlp_node(
&self,
prefix_set: &mut PrefixSet,
buffers: &mut RlpNodeBuffers,
rlp_buf: &mut Vec<u8>,
) -> (RlpNode, RlpNodeUpdates) {
let mut rlp_node_updates = RlpNodeUpdates::default();

'main: while let Some((path, mut is_in_prefix_set)) = buffers.path_stack.pop() {
// Check if the path is in the prefix set.
// First, check the cached value. If it's `None`, then check the prefix set, and update
// the cached value.
let mut prefix_set_contains =
|path: &Nibbles| *is_in_prefix_set.get_or_insert_with(|| prefix_set.contains(path));

let (rlp_node, calculated, node_type) = match self.nodes.get_mut(&path).unwrap() {
let mut rlp_node_update = RlpNodeUpdate::default();

let (rlp_node, calculated, node_type) = match self.nodes.get(&path).unwrap() {
SparseNode::Empty => {
(RlpNode::word_rlp(&EMPTY_ROOT_HASH), false, SparseNodeType::Empty)
}
Expand All @@ -613,9 +605,9 @@ impl<P> RevealedSparseTrie<P> {
(RlpNode::word_rlp(&hash), false, SparseNodeType::Leaf)
} else {
let value = self.values.get(&path).unwrap();
self.rlp_buf.clear();
let rlp_node = LeafNodeRef { key, value }.rlp(&mut self.rlp_buf);
*hash = rlp_node.as_hash();
rlp_buf.clear();
let rlp_node = LeafNodeRef { key, value }.rlp(rlp_buf);
rlp_node_update.hash = rlp_node.as_hash();
(rlp_node, true, SparseNodeType::Leaf)
}
}
Expand All @@ -630,9 +622,9 @@ impl<P> RevealedSparseTrie<P> {
)
} else if buffers.rlp_node_stack.last().is_some_and(|e| e.0 == child_path) {
let (_, child, _, node_type) = buffers.rlp_node_stack.pop().unwrap();
self.rlp_buf.clear();
let rlp_node = ExtensionNodeRef::new(key, &child).rlp(&mut self.rlp_buf);
*hash = rlp_node.as_hash();
rlp_buf.clear();
let rlp_node = ExtensionNodeRef::new(key, &child).rlp(rlp_buf);
rlp_node_update.hash = rlp_node.as_hash();

(
rlp_node,
Expand Down Expand Up @@ -746,17 +738,14 @@ impl<P> RevealedSparseTrie<P> {
}
}

self.rlp_buf.clear();
rlp_buf.clear();
let branch_node_ref =
BranchNodeRef::new(&buffers.branch_value_stack_buf, *state_mask);
let rlp_node = branch_node_ref.rlp(&mut self.rlp_buf);
*hash = rlp_node.as_hash();
let rlp_node = branch_node_ref.rlp(rlp_buf);

// Save a branch node update only if it's not a root node, and we need to
// persist updates.
let store_in_db_trie_value = if let Some(updates) =
self.updates.as_mut().filter(|_| retain_updates && !path.is_empty())
{
let store_in_db_trie_value = if retain_updates && !path.is_empty() {
let mut tree_mask_values = tree_mask_values.into_iter().rev();
let mut hash_mask_values = hash_mask_values.into_iter().rev();
let mut tree_mask = TrieMask::default();
Expand Down Expand Up @@ -784,14 +773,16 @@ impl<P> RevealedSparseTrie<P> {
hashes,
hash.filter(|_| path.len() == 0),
);
updates.updated_nodes.insert(path.clone(), branch_node);
rlp_node_update.branch_node = Some(branch_node);
}

store_in_db_trie
} else {
false
};
*store_in_db_trie = Some(store_in_db_trie_value);

rlp_node_update.hash = rlp_node.as_hash();
rlp_node_update.store_in_db_trie = Some(store_in_db_trie_value);

(
rlp_node,
Expand All @@ -800,17 +791,45 @@ impl<P> RevealedSparseTrie<P> {
)
}
};

if !rlp_node_update.is_empty() {
rlp_node_updates.insert(path.clone(), rlp_node_update);
}

buffers.rlp_node_stack.push((path, rlp_node, calculated, node_type));
}

debug_assert_eq!(buffers.rlp_node_stack.len(), 1);
buffers.rlp_node_stack.pop().unwrap().1
(buffers.rlp_node_stack.pop().unwrap().1, rlp_node_updates)
}

fn apply_rlp_node_updates(&mut self, rlp_node_updates: RlpNodeUpdates) {
for (path, update) in rlp_node_updates {
if let Some(node) = self.nodes.get_mut(&path) {
match node {
SparseNode::Leaf { hash, .. } | SparseNode::Extension { hash, .. } => {
*hash = update.hash
}
SparseNode::Branch { hash, store_in_db_trie, .. } => {
*hash = update.hash;
*store_in_db_trie = update.store_in_db_trie
}
SparseNode::Empty | SparseNode::Hash(_) => unreachable!(),
}
}

if let Some(branch_node) = update.branch_node {
if let Some(updates) = self.updates.as_mut() {
updates.updated_nodes.insert(path, branch_node);
}
}
}
}
}

impl<P> RevealedSparseTrie<P>
where
P: BlindedProvider,
P: BlindedProvider + Send + Sync,
SparseTrieError: From<P::Error>,
{
/// Update the leaf node with provided value.
Expand Down Expand Up @@ -1110,6 +1129,53 @@ where

Ok(())
}

/// Update hashes of the nodes that are located at a level deeper than or equal to the provided
/// depth. Root node has a level of 0.
pub fn update_rlp_node_level(&mut self, depth: usize) {
let mut prefix_set = self.prefix_set.clone().freeze();

let targets = self.get_changed_nodes_at_depth(&mut prefix_set, depth);
let (tx, rx) = mpsc::channel();
targets
.into_par_iter()
.map_init(
|| (prefix_set.clone(), RlpNodeBuffers::default(), Vec::new()),
|(prefix_set, buffers, rlp_node), target| {
buffers.path_stack.push((target, Some(true)));
let (_, updates) = self.rlp_node(prefix_set, buffers, rlp_node);
updates
},
)
.for_each_init(
|| tx.clone(),
|tx, updates| {
tx.send(updates).unwrap();
},
);
drop(tx);

for updates in rx {
self.apply_rlp_node_updates(updates);
}
}
}

/// Updates that [`RevealedSparseTrie::rlp_node`] produced.
type RlpNodeUpdates = HashMap<Nibbles, RlpNodeUpdate>;

/// An update that [`RevealedSparseTrie::rlp_node`] produced after processing one node.
#[derive(Debug, Default)]
struct RlpNodeUpdate {
hash: Option<B256>,
store_in_db_trie: Option<bool>,
branch_node: Option<BranchNodeCompact>,
}

impl RlpNodeUpdate {
const fn is_empty(&self) -> bool {
self.hash.is_none() && self.store_in_db_trie.is_none() && self.branch_node.is_none()
}
}

/// Enum representing sparse trie node type.
Expand Down
6 changes: 3 additions & 3 deletions crates/trie/trie/src/witness.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,9 @@ impl<F> WitnessBlindedProviderFactory<F> {

impl<F> BlindedProviderFactory for WitnessBlindedProviderFactory<F>
where
F: BlindedProviderFactory,
F::AccountNodeProvider: BlindedProvider<Error = SparseTrieError>,
F::StorageNodeProvider: BlindedProvider<Error = SparseTrieError>,
F: BlindedProviderFactory + Send + Sync,
F::AccountNodeProvider: BlindedProvider<Error = SparseTrieError> + Send + Sync,
F::StorageNodeProvider: BlindedProvider<Error = SparseTrieError> + Send + Sync,
{
type AccountNodeProvider = WitnessBlindedProvider<F::AccountNodeProvider>;
type StorageNodeProvider = WitnessBlindedProvider<F::StorageNodeProvider>;
Expand Down
Loading