Sparsify Spartan #403
Conversation
```rust
let (mut az_sparse, mut bz_sparse, cz_sparse) = par_flatten_triple(
    uni_constraint_evals,
    unsafe_allocate_sparse_zero_vec,
    self.uniform_repeat, // Capacity overhead for offset_eq constraints.
);
```
Can we compute `uni_constraint_evals` for A, B, and C separately, each using a `flat_map`, so that we avoid this `par_flatten_triple`?

Only downside I can see is that we wouldn't be able to reuse the `dense_output_buffer` between A, B, and C. But I think that the `flat_map` approach is cleaner.
Definitely agree this is unsanitary.

Unfortunately the plain `flat_map` approach adds about 60% overhead: 1.9s -> 3.3s on a 40s trace. I believe this is because it doesn't know the shape of the underlying data and the chunks have very non-uniform sizes; `par_flatten_triple` is optimized to handle this well.
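For reference, a minimal sketch of the size-aware flatten technique (an illustration of why knowing chunk sizes up front helps, not the actual `par_flatten_triple` implementation): measure each chunk, carve a preallocated output into disjoint slices, then copy every chunk to its known offset in parallel.

```rust
use rayon::prelude::*;

// Illustrative only: flatten non-uniform chunks by precomputing each chunk's
// write offset, so all copies can run in parallel into one preallocated
// buffer. Not the jolt-core implementation.
fn par_flatten<T: Copy + Send + Sync + Default>(chunks: &[Vec<T>]) -> Vec<T> {
    let total: usize = chunks.iter().map(Vec::len).sum();
    let mut out = vec![T::default(); total];

    // Carve the output into disjoint mutable slices, one per chunk.
    let mut slices = Vec::with_capacity(chunks.len());
    let mut rest = out.as_mut_slice();
    for chunk in chunks {
        let (head, tail) = rest.split_at_mut(chunk.len());
        slices.push(head);
        rest = tail;
    }

    // Every chunk now has a known destination; copy them all in parallel.
    slices
        .into_par_iter()
        .zip(chunks.par_iter())
        .for_each(|(dst, src)| dst.copy_from_slice(src));
    out
}
```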
Ideally we'd know which `sparse_index` to write each non-zero LC eval to in advance, but the nature of sparsity is that we don't. To get parallelism at the `evaluate_lc_chunk` level we compute `self.uniform_num_rows` LC computations in parallel. If we did this sparsely it would require a `.par_iter().map(...).collect()` during `evaluate_lc_chunk`, adding a painful allocation / copy.
The one approach that I have not tried is to chunk the constraints into `NUM_THREADS`-sized chunks, iterate over each chunk of constraints in parallel, write directly to a sparse `Vec<(F, usize)>` for each chunk, and then flatten the resulting `Vec<Vec<(F, usize)>>` into a single `Vec<(F, usize)>` (a sketch of this shape follows below).

In practice `num_constraints ~= 20`, thus the chunk size would be ~= 2. The expense of the LC computation is (very) non-uniform across constraints. We would lose parallelism across the LC (in order to append to a mutable vec) and then get worse work stealing across constraints. We cannot account for this in the standard way by increasing the number of chunks, as the chunk size would approach 1. There is theoretically a way to account for it by assigning a "computation weight" to each constraint / LC. I have not attempted this and imagine the code gets substantially grosser, but I imagine it could save 1% e2e.
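For concreteness, a hedged sketch of that untried chunked shape; `C` stands in for the constraint type and `eval` for the sparse per-constraint LC evaluation, both placeholders rather than code from this PR.

```rust
use rayon::prelude::*;

// Sketch only: split constraints into NUM_THREADS-sized chunks, evaluate
// each chunk serially into a thread-local sparse Vec, then flatten.
fn sparsify_chunked<C: Sync, F: Send>(
    constraints: &[C],
    num_threads: usize,
    eval: impl Fn(usize, &C, &mut Vec<(F, usize)>) + Sync,
) -> Vec<(F, usize)> {
    let chunk_size = constraints.len().div_ceil(num_threads).max(1);
    constraints
        .par_chunks(chunk_size)
        .enumerate()
        .map(|(chunk_idx, chunk)| {
            // Serial within a chunk: we lose parallelism across the LC in
            // exchange for appending directly to a local sparse Vec.
            let mut sparse = Vec::new();
            for (i, constraint) in chunk.iter().enumerate() {
                eval(chunk_idx * chunk_size + i, constraint, &mut sparse);
            }
            sparse
        })
        .flatten()
        .collect()
}
```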
Attaching the code for the `flat_map` version below so we can easily try it again later. Note the `dense_output_buffer`s are over-allocated in this version, but that's not the driver of the performance decline.
```rust
let evaluate_lc_chunk = |lc: &LC<I>,
                         dense_output_buffer: &mut [F],
                         constraint_index: usize|
 -> Vec<(F, usize)> {
    if !lc.terms().is_empty() {
        let inputs = batch_inputs(lc);
        lc.evaluate_batch_mut(&inputs, dense_output_buffer);

        // Take only the non-zero elements and represent them as sparse tuples (eval, dense_index)
        let mut sparse = Vec::with_capacity(self.uniform_repeat); // overshoot
        for (local_index, item) in dense_output_buffer.iter().enumerate() {
            if !item.is_zero() {
                let global_index = constraint_index * self.uniform_repeat + local_index;
                sparse.push((*item, global_index));
            }
        }
        sparse
    } else {
        vec![]
    }
};

// uniform_constraints: Xz[0..uniform_constraint_rows]
let span = tracing::span!(tracing::Level::DEBUG, "uniform_evals");
let _enter = span.enter();

let mut az_sparse: Vec<(F, usize)> = self
    .uniform_builder
    .constraints
    .par_iter()
    .enumerate()
    .flat_map(|(constraint_index, constraint)| {
        let mut dense_output_buffer = unsafe_allocate_zero_vec(self.uniform_repeat);
        evaluate_lc_chunk(&constraint.a, &mut dense_output_buffer, constraint_index)
    })
    .collect();

let mut bz_sparse: Vec<(F, usize)> = self
    .uniform_builder
    .constraints
    .par_iter()
    .enumerate()
    .flat_map(|(constraint_index, constraint)| {
        let mut dense_output_buffer = unsafe_allocate_zero_vec(self.uniform_repeat);
        evaluate_lc_chunk(&constraint.b, &mut dense_output_buffer, constraint_index)
    })
    .collect();

let cz_sparse: Vec<(F, usize)> = self
    .uniform_builder
    .constraints
    .par_iter()
    .enumerate()
    .flat_map(|(constraint_index, constraint)| {
        let mut dense_output_buffer = unsafe_allocate_zero_vec(self.uniform_repeat);
        evaluate_lc_chunk(&constraint.c, &mut dense_output_buffer, constraint_index)
    })
    .collect();

drop(_enter);
```
Another idea: we could potentially nest `flat_map`s all the way up (`evaluate_lc_chunk(..) -> FlatMap<...>`, then `constraints.par_iter().flat_map(|| evaluate_lc_chunk)`), but I believe this would have to allocate / copy under the hood in a less controlled manner.
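A hedged sketch of that shape, with `u64` standing in for the field type and a toy stand-in for the LC evaluation, just to show how the inner iterator could stay lazy so the outer `flat_map` fuses with it:

```rust
use rayon::prelude::*;

// Sketch only: evaluate_lc_chunk returns a lazy parallel iterator instead of
// a collected Vec, so rayon controls all intermediate buffering.
fn evaluate_lc_chunk(
    constraint_index: usize,
    uniform_repeat: usize,
) -> impl ParallelIterator<Item = (u64, usize)> {
    (0..uniform_repeat)
        .into_par_iter()
        .filter_map(move |local_index| {
            // Toy stand-in for the LC evaluation; most entries are zero.
            let eval = ((constraint_index + local_index) % 4 == 0) as u64;
            (eval != 0).then_some((eval, constraint_index * uniform_repeat + local_index))
        })
}

fn main() {
    let (num_constraints, uniform_repeat) = (4, 8);
    let az_sparse: Vec<(u64, usize)> = (0..num_constraints)
        .into_par_iter()
        .flat_map(|ci| evaluate_lc_chunk(ci, uniform_repeat))
        .collect();
    println!("{} non-zero entries", az_sparse.len());
}
```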
jolt-core/src/r1cs/builder.rs (outdated)
```rust
// Sparsify: take only the non-zero elements
for (local_index, (az, bz)) in dense_az_bz.iter().enumerate() {
    let global_index = uniform_constraint_rows + local_index;
    if !az.is_zero() {
        az_sparse.push((*az, global_index));
    }
    if !bz.is_zero() {
        bz_sparse.push((*bz, global_index));
    }
}
```
We can combine this for loop with the `dense_az_bz` computation above by instead doing a `par_extend` of `az_sparse` and `bz_sparse`. A bit cleaner and more memory-efficient.
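A minimal sketch of that shape, with `u64` standing in for the field type and a toy eval (placeholders, not the PR's code):

```rust
use rayon::prelude::*;

fn main() {
    let uniform_constraint_rows = 100;
    let num_offset_rows = 8;
    let mut az_sparse: Vec<(u64, usize)> = Vec::new();

    // Fuse the dense evaluation with sparsification: par_extend consumes a
    // parallel iterator directly, so there is no intermediate dense_az_bz
    // buffer and no serial sparsify loop afterwards.
    az_sparse.par_extend((0..num_offset_rows).into_par_iter().filter_map(
        |local_index| {
            let az = (local_index % 3) as u64; // toy stand-in for the eval
            (az != 0).then_some((az, uniform_constraint_rows + local_index))
        },
    ));

    assert!(az_sparse.iter().all(|&(az, _)| az != 0));
}
```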
Played around with this a bit using `.par_iter().flat_map() -> .par_extend(FlatMap<_>)` and it seemed to increase costs about 3x. `.iter().flat_map() -> extend(FlatMap<_>)` reached 1.25x parity with the existing method. These loops are likely memory-limited given how cheap the ops are.

I've updated to inline the sparsity updates at a slight cost, but I think it solves the cleanliness issue. This is not a cost driver from a RAM or CPU-time perspective, given there's only a handful of these constraints (1 for now, 3 in the future): around 0.075% e2e.
```rust
#[tracing::instrument(skip_all)]
pub fn bound_poly_var_bot(&mut self, r: &F) {
```
When do we use this vs the parallel version?
We don't use this today outside of tests. I left it in as I imagine we might in the future, and it serves as documentation for the messy parallel version.
```rust
pub fn has_next(&self) -> bool {
    self.dense_index < self.end_index
}

pub fn next_pairs(&mut self) -> (usize, F, F, F, F, F, F) {
```
Can we instead implement the standard `Iterator` trait?
I wanted to avoid using the return type `Option<Self::Item>`, as we have enough conditional / branching logic in the hot loop as is, though I'm not certain it would cause a slowdown. `next_pairs()` also seemed clearer to me, given that the obvious thing for `SparseTripleIterator::next()` to return would be three `F` elements.
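For comparison, a hedged sketch of what the standard `Iterator` version might look like; the fields and the seven-tuple item shape are inferred from `has_next` / `next_pairs`, and the actual pair extraction is elided.

```rust
use std::marker::PhantomData;

// Sketch only: the real struct holds sparse cursors into az/bz/cz, elided here.
struct SparseTripleIterator<F> {
    dense_index: usize,
    end_index: usize,
    _marker: PhantomData<F>,
}

impl<F: Default> Iterator for SparseTripleIterator<F> {
    // Same shape as next_pairs(): the dense index plus an (even, odd) pair
    // of evals for each of A, B, and C.
    type Item = (usize, F, F, F, F, F, F);

    fn next(&mut self) -> Option<Self::Item> {
        // This is the Option branch the hot loop would pay for on every call.
        if self.dense_index >= self.end_index {
            return None;
        }
        let index = self.dense_index;
        self.dense_index += 2; // next_pairs consumes two dense rows per call (assumption)
        Some((
            index,
            F::default(), F::default(), // az pair (extraction elided)
            F::default(), F::default(), // bz pair (extraction elided)
            F::default(), F::default(), // cz pair (extraction elided)
        ))
    }
}
```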
One last nit, but otherwise looks good!
Makes the first Spartan sumcheck sparse (doesn't store / compute zero terms). Reduces peak memory usage by 50%, allowing programs of 2x length for a fixed amount of RAM. Accelerates the first sumcheck by 25%, for around 2% e2e savings.