Sparsify Spartan #403

Merged: 17 commits from sragss/sparse-spartan merged into main on Jul 19, 2024

Conversation

@sragss (Collaborator) commented Jun 24, 2024

Makes the first Spartan sumcheck sparse (doesn't store or compute zero terms). Reduces peak memory usage by 50%, allowing programs of 2x the length for a fixed amount of RAM. Accelerates the first sumcheck by 25%, for around 2% e2e savings.
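
For context, a minimal sketch of the dense-vs-sparse representation the PR exploits (purely illustrative: F and sparsify here are stand-ins, not Jolt types): only the non-zero Az/Bz/Cz evaluations are kept, stored as (value, dense_index) pairs.

// Illustrative only; `F` stands in for the field type so the sketch runs.
type F = u64;

/// Collect only the non-zero entries of a dense evaluation vector as
/// (value, dense_index) pairs.
fn sparsify(dense: &[F]) -> Vec<(F, usize)> {
    dense
        .iter()
        .enumerate()
        .filter(|(_, v)| **v != 0)
        .map(|(i, v)| (*v, i))
        .collect()
}

fn main() {
    // Most Az/Bz/Cz entries are zero in practice, so the sparse form is much
    // smaller and the sumcheck can skip the zero terms entirely.
    let dense: Vec<F> = vec![0, 3, 0, 0, 7, 0, 0, 0];
    assert_eq!(sparsify(&dense), vec![(3, 1), (7, 4)]);
}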

@sragss requested a review from moodlezoup on June 26, 2024 18:54
@sragss marked this pull request as ready for review on June 26, 2024 18:55
jolt-core/src/r1cs/builder.rs (review thread resolved, outdated)
jolt-core/src/poly/dense_mlpoly.rs (review thread resolved, outdated)
Comment on lines +820 to +824
let (mut az_sparse, mut bz_sparse, cz_sparse) = par_flatten_triple(
uni_constraint_evals,
unsafe_allocate_sparse_zero_vec,
self.uniform_repeat, // Capacity overhead for offset_eq constraints.
);
@moodlezoup (Collaborator) commented

Can we compute uni_constraint_evals for A, B, and C separately, each using a flat_map, so that we avoid this par_flatten_triple?
The only downside I can see is that we wouldn't be able to reuse the dense_output_buffer between A, B, and C, but I think the flat_map approach is cleaner.

@sragss (Collaborator, Author) commented

Definitely agree this is unsanitary.

Unfortunately the plain flat_map approach adds about 60% overhead: 1.9s -> 3.3s on a 40s trace. I believe this is because it doesn't know the shape of the underlying data and the chunks are of very non-uniform sizes. par_flatten_triple is optimized to handle this well.
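
A rough sketch of the kind of shape-aware flatten being described, assuming the general strategy is "compute lengths, carve disjoint output windows, copy chunks in parallel"; flatten_one is a hypothetical name, and this is not the actual par_flatten_triple implementation.

use rayon::prelude::*;

// Flatten one set of variable-length chunks into a single Vec in parallel.
fn flatten_one<T: Clone + Default + Send + Sync>(chunks: &[Vec<T>]) -> Vec<T> {
    let lens: Vec<usize> = chunks.iter().map(|c| c.len()).collect();
    let total: usize = lens.iter().sum();
    // The real code presumably uses a cheap zeroed-allocation helper (e.g.
    // unsafe_allocate_sparse_zero_vec); vec! keeps this sketch safe.
    let mut out: Vec<T> = vec![T::default(); total];

    // Carve the output into disjoint mutable windows, one per input chunk.
    let mut windows: Vec<&mut [T]> = Vec::with_capacity(chunks.len());
    let mut rest: &mut [T] = out.as_mut_slice();
    for &len in &lens {
        let (head, tail) = std::mem::take(&mut rest).split_at_mut(len);
        windows.push(head);
        rest = tail;
    }
    debug_assert!(rest.is_empty());

    // The windows are non-overlapping, so each copy can run in parallel even
    // though the chunk sizes are highly non-uniform.
    windows
        .into_par_iter()
        .zip(chunks.par_iter())
        .for_each(|(dst, src)| dst.clone_from_slice(src));
    out
}

fn main() {
    let az_chunks: Vec<Vec<u64>> = vec![vec![1, 2, 3], vec![], vec![4, 5]];
    assert_eq!(flatten_one(&az_chunks), vec![1, 2, 3, 4, 5]);
    // par_flatten_triple would presumably do the same for (Az, Bz, Cz) triples.
}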

Ideally we'd know which sparse_index to write each non-zero LC eval to in advance, but the nature of sparsity is that we don't know.

To get parallelism at the evaluate_lc_chunk level we compute self.uniform_num_rows LC computations in parallel. If we did this sparsely it would require a .par_iter().map(...).collect() during evaluate_lc_chunk, adding a painful allocation / copy.

The one approach that I have not tried is to chunk the constraints into NUM_THREADS-sized chunks, iterate over each chunk of constraints in parallel, then write directly to a sparse Vec<(F, usize)> for each chunk. We would then have to collect the per-chunk results into a Vec<Vec<(F, usize)>> and flatten that into a single Vec<(F, usize)>.
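
A minimal sketch of that untried chunking approach, with hypothetical names (eval_lc, sparse_evals) and toy types standing in for the real constraint and field types:

use rayon::prelude::*;

type F = u64;
type Constraint = Vec<F>; // toy: a "constraint" is just a row of precomputed values

// Hypothetical per-(constraint, step) LC evaluation; the real one is far more
// involved and much more expensive for some constraints than others.
fn eval_lc(constraint: &Constraint, step: usize) -> F {
    constraint.get(step).copied().unwrap_or(0)
}

// Split constraints into NUM_THREADS-sized chunks, let each chunk push its
// non-zero evals into its own sparse Vec<(F, usize)>, then flatten at the end.
fn sparse_evals(constraints: &[Constraint], num_steps: usize) -> Vec<(F, usize)> {
    let num_threads = rayon::current_num_threads();
    let chunk_size = constraints.len().div_ceil(num_threads).max(1);

    let per_chunk: Vec<Vec<(F, usize)>> = constraints
        .par_chunks(chunk_size)
        .enumerate()
        .map(|(chunk_idx, chunk)| {
            let mut sparse = Vec::new();
            for (i, constraint) in chunk.iter().enumerate() {
                let constraint_index = chunk_idx * chunk_size + i;
                for step in 0..num_steps {
                    let v = eval_lc(constraint, step);
                    if v != 0 {
                        sparse.push((v, constraint_index * num_steps + step));
                    }
                }
            }
            sparse
        })
        .collect();

    per_chunk.into_iter().flatten().collect()
}

fn main() {
    let constraints = vec![vec![0, 7, 0], vec![1, 0, 0]];
    assert_eq!(sparse_evals(&constraints, 3), vec![(7, 1), (1, 3)]);
}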

In practice num_constraints ~= 20, so the chunk size would be ~= 2. The expense of the LC computation is (very) non-uniform across constraints. We would lose parallelism across the LC (in order to append to a mutable vec) and then get worse work stealing across constraints. We cannot account for this in the standard way, by increasing the number of chunks, as the chunk size would approach 1. There is theoretically a way to account for this by assigning a "computation weight" to each constraint / LC. I have not attempted this and imagine the code would get substantially grosser, but it could save around 1% e2e.
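
For reference, the "computation weight" balancing could look like a greedy longest-processing-time assignment. The sketch below is generic and was not attempted or benchmarked in this PR; balance_by_weight is a hypothetical name.

// Assign items (constraint indices with estimated weights) to num_bins bins:
// heaviest first, each into the currently lightest bin, so per-thread work
// totals stay balanced even when individual LC costs are very non-uniform.
fn balance_by_weight(weights: &[u64], num_bins: usize) -> Vec<Vec<usize>> {
    let mut order: Vec<usize> = (0..weights.len()).collect();
    order.sort_by_key(|&i| std::cmp::Reverse(weights[i]));

    let mut bins: Vec<Vec<usize>> = vec![Vec::new(); num_bins];
    let mut totals = vec![0u64; num_bins];
    for i in order {
        // Index of the currently lightest bin.
        let lightest = (0..num_bins).min_by_key(|&b| totals[b]).unwrap();
        bins[lightest].push(i);
        totals[lightest] += weights[i];
    }
    bins
}

fn main() {
    // Hypothetical per-constraint weights, balanced across 4 threads.
    let weights = [9, 1, 1, 1, 8, 1, 1, 7, 1, 1];
    let bins = balance_by_weight(&weights, 4);
    assert_eq!(bins.len(), 4);
}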

Attaching the code for the flat_map version below so we can easily try it again later. Note that the dense_output_buffers are over-allocated in this version, but that's not the driver of the performance decline.

        let evaluate_lc_chunk = |lc: &LC<I>,
                                     dense_output_buffer: &mut [F],
                                     constraint_index: usize|
         -> Vec<(F, usize)> {
            if !lc.terms().is_empty() {
                let inputs = batch_inputs(lc);
                lc.evaluate_batch_mut(&inputs, dense_output_buffer);

                // Take only the non-zero elements and represent them as sparse tuples (eval, dense_index)
                let mut sparse = Vec::with_capacity(self.uniform_repeat); // overshoot
                for (local_index, item) in dense_output_buffer.iter().enumerate() {
                    if !item.is_zero() {
                        let global_index = constraint_index * self.uniform_repeat + local_index;
                        sparse.push((*item, global_index));
                    }
                }
                sparse
            } else {
                vec![]
            }
        };

        // uniform_constraints: Xz[0..uniform_constraint_rows]
        let span = tracing::span!(tracing::Level::DEBUG, "uniform_evals");
        let _enter = span.enter();
        let mut az_sparse: Vec<(F, usize)> = self
            .uniform_builder
            .constraints
            .par_iter()
            .enumerate()
            .flat_map(|(constraint_index, constraint)| {
                let mut dense_output_buffer = unsafe_allocate_zero_vec(self.uniform_repeat);

                evaluate_lc_chunk(&constraint.a, &mut dense_output_buffer, constraint_index)
            })
            .collect();
        let mut bz_sparse: Vec<(F, usize)> = self
            .uniform_builder
            .constraints
            .par_iter()
            .enumerate()
            .flat_map(|(constraint_index, constraint)| {
                let mut dense_output_buffer = unsafe_allocate_zero_vec(self.uniform_repeat);

                evaluate_lc_chunk(&constraint.b, &mut dense_output_buffer, constraint_index)
            })
            .collect();
        let cz_sparse: Vec<(F, usize)> = self
            .uniform_builder
            .constraints
            .par_iter()
            .enumerate()
            .flat_map(|(constraint_index, constraint)| {
                let mut dense_output_buffer = unsafe_allocate_zero_vec(self.uniform_repeat);

                evaluate_lc_chunk(&constraint.c, &mut dense_output_buffer, constraint_index)
            })
            .collect();
        drop(_enter);

@sragss (Collaborator, Author) commented Jul 17, 2024

Another idea: We could potentially nest flat_maps all the way up, but I believe this would have to be allocating / copying under the hood in a less controlled manner.

evaluate_lc_chunk(..) -> FlatMap<...> -> constraints.par_iter().flat_map(|| evaluate_lc_chunk)

Comment on lines 871 to 880
// Sparsify: take only the non-zero elements
for (local_index, (az, bz)) in dense_az_bz.iter().enumerate() {
    let global_index = uniform_constraint_rows + local_index;
    if !az.is_zero() {
        az_sparse.push((*az, global_index));
    }
    if !bz.is_zero() {
        bz_sparse.push((*bz, global_index));
    }
}
@moodlezoup (Collaborator) commented

We can combine this for loop with the dense_az_bz computation above by instead doing a par_extend of az_sparse and bz_sparse. A bit cleaner and more memory-efficient.
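
A small sketch of what the suggested par_extend could look like with rayon, using toy values and a u64 stand-in for the field type; the filtering and the extension happen in one parallel pass instead of a separate sequential sparsify loop.

use rayon::prelude::*;

type F = u64; // stand-in for the field type

fn main() {
    // Dense (az, bz) evaluations for the offset-equality constraints;
    // zeros dominate in practice.
    let dense_az_bz: Vec<(F, F)> = vec![(0, 5), (2, 0), (0, 0), (7, 7)];
    let uniform_constraint_rows = 100;

    let mut az_sparse: Vec<(F, usize)> = Vec::new();
    // par_extend consumes a parallel iterator and appends its items in order.
    az_sparse.par_extend(
        dense_az_bz
            .par_iter()
            .enumerate()
            .filter(|(_, (az, _))| *az != 0)
            .map(|(i, (az, _))| (*az, uniform_constraint_rows + i)),
    );
    assert_eq!(az_sparse, vec![(2, 101), (7, 103)]);
    // bz_sparse would be extended the same way from the bz component.
}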

@sragss (Collaborator, Author) commented

Played around with this a bit using .par_iter().flat_map() -> .par_extend(FlatMap<_>) and it seemed to increase costs about 3x. .iter().flat_map() -> extend(FlatMap<_>) reached 1.25x parity with the existing method. These loops are likely memory-limited given how simple the ops are.

I've updated the PR to inline the sparsity updates at a slight cost, but I think it solves the cleanliness issue.

This is not a cost driver from a RAM or CPU-time perspective, given there are only a handful of these constraints (1 for now, 3 in the future). Around 0.075% e2e.

jolt-core/src/r1cs/spartan.rs (review thread resolved)
jolt-core/src/r1cs/special_polys.rs (review thread resolved, outdated)
}

#[tracing::instrument(skip_all)]
pub fn bound_poly_var_bot(&mut self, r: &F) {
@moodlezoup (Collaborator) commented

when do we use this vs the parallel version?

@sragss (Collaborator, Author) commented

We don't use this today outside of tests.

I left it in as I imagine we might in the future, and it serves as documentation for the messy parallel version.
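
For readers of this thread, a generic dense sketch of what bottom-variable binding computes; this is an assumption about the semantics the serial version documents, not the Jolt implementation itself.

// Adjacent evaluation pairs are folded as low + r * (high - low), halving the
// polynomial's evaluation table. `F` stands in for the field type.
type F = i64;

fn bound_poly_var_bot(evals: &mut Vec<F>, r: F) {
    let n = evals.len() / 2;
    for i in 0..n {
        let low = evals[2 * i];
        let high = evals[2 * i + 1];
        evals[i] = low + r * (high - low);
    }
    evals.truncate(n);
}

fn main() {
    // Evaluations of a 2-variable multilinear poly over {0,1}^2; bind the
    // bottom (interleaved) variable to r = 3.
    let mut evals: Vec<F> = vec![1, 2, 5, 7];
    bound_poly_var_bot(&mut evals, 3);
    assert_eq!(evals, vec![4, 11]); // 1 + 3*(2-1) = 4, 5 + 3*(7-5) = 11
}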

jolt-core/src/r1cs/special_polys.rs (review thread resolved, outdated)
jolt-core/src/r1cs/special_polys.rs (review thread resolved)
Comment on lines +370 to +374
pub fn has_next(&self) -> bool {
    self.dense_index < self.end_index
}

pub fn next_pairs(&mut self) -> (usize, F, F, F, F, F, F) {
@moodlezoup (Collaborator) commented

Can we instead implement the standard Iterator trait?

@sragss (Collaborator, Author) commented Jul 17, 2024

I wanted to avoid the return type Option<Self::Item>, as we have enough conditional / branching logic in the hot loop as is, but I'm not certain it would cause a slowdown. next_pairs() also seemed clearer to me, given that the obvious thing for SparseTripleIterator::next() to return would be 3 F elements.
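
A hypothetical sketch of how the has_next / next_pairs pair could be bridged to the standard Iterator trait if ever wanted (simplified state, placeholder body, not the PR's actual iterator); the Option wrapping is exactly the extra branch being kept out of the hot loop.

type F = u64; // stand-in for the field type

struct SparseTripleIterator {
    dense_index: usize,
    end_index: usize,
    // ... sparse Az / Bz / Cz cursor state elided ...
}

impl SparseTripleIterator {
    fn has_next(&self) -> bool {
        self.dense_index < self.end_index
    }

    fn next_pairs(&mut self) -> (usize, F, F, F, F, F, F) {
        // Placeholder body: the real method walks the three sparse vectors and
        // returns paired low/high evaluations for Az, Bz, and Cz.
        let i = self.dense_index;
        self.dense_index += 2;
        (i, 0, 0, 0, 0, 0, 0)
    }
}

impl Iterator for SparseTripleIterator {
    type Item = (usize, F, F, F, F, F, F);

    fn next(&mut self) -> Option<Self::Item> {
        self.has_next().then(|| self.next_pairs())
    }
}

fn main() {
    let it = SparseTripleIterator { dense_index: 0, end_index: 4 };
    assert_eq!(it.count(), 2); // yields pairs starting at dense indices 0 and 2
}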

@moodlezoup (Collaborator) left a review

one last nit, but otherwise looks good!

jolt-core/src/subprotocols/sumcheck.rs (review thread resolved, outdated)
jolt-core/src/subprotocols/sumcheck.rs (review thread resolved, outdated)
@sragss merged commit a0f8fbb into main on Jul 19, 2024
3 checks passed
@sragss deleted the sragss/sparse-spartan branch on July 19, 2024 00:47