Skip to content

Commit

Permalink
Merge branch 'main' into feat/integration_changes
Browse files Browse the repository at this point in the history
  • Loading branch information
PatStiles authored Jul 19, 2024
2 parents c66784b + 89c3815 commit 093bccc
Show file tree
Hide file tree
Showing 32 changed files with 1,604 additions and 380 deletions.
6 changes: 6 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,15 @@ clap = { version = "4.5.4", features = ["derive"] }
eyre = "0.6.12"
rand = "0.8.5"
sysinfo = "0.30.8"
syn = { version = "1.0.0", features = ["full"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0.68"
rmp-serde = "1.3.0"
toml_edit = "0.22.14"

jolt-sdk = { path = "./jolt-sdk" }
jolt-core = { path = "./jolt-core" }
common = { path = "./common" }

[profile.test]
opt-level = 3
Expand Down
3 changes: 3 additions & 0 deletions book/src/SUMMARY.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,6 @@
- [Roadmap](./tasks.md)
- [Optimizations](./future/opts.md)
- [Zero Knowledge](./future/zk.md)
- [Groth16 Recursion](./future/groth-16.md)
- [Precompiles](./future/precompiles.md)
- [Continuations](./future/continuations.md)
10 changes: 10 additions & 0 deletions book/src/future/continuations.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Continuations via Chunking
Today Jolt is a monolithic SNARK. RISC-V traces cannot be broken up, they must be proven monolithically or not at all. As a result, Jolt has a fixed maximum trace length that can be proved which is a function of the available RAM on the prover machine. Long term we'd like to solve this by implementing a streaming version of Jolt's prover such that prover RAM usage is tunable with minimal performance loss. *TODO(sragss): Streaming prover link*

Short term we're going to solve this via monolithic chunking. The plan: Take a trace of length $N$ split it into $M$ chunks of size $N/M$ and prove each independently. $N/M$ is a function of the max RAM available to the prover.

For the direct on-chain verifier there will be a cost linear in $M$ to verify. If we wrap this in Groth16 our cost will become constant but the Groth16 prover time will be linear in $M$. We believe this short-term trade-off is worthwhile for usability until we can implement the (more complicated) streaming algorithms.

## Specifics
A generic config parameter will be added to the `Jolt` struct called `ContinuationConfig`. At the highest level, before calling `Jolt::prove` the trace will be split into `M` chunks. `Jolt::prove` will be called on each and return `RAM_final` which can be fed into `RAM_init` during the next iteration of `Jolt::prove`. The [output zerocheck](https://jolt.a16zcrypto.com/how/read_write_memory.html#ouputs-and-panic) will only be run for the final chunk.

60 changes: 60 additions & 0 deletions book/src/future/groth-16.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Groth16 Recursion
Jolt's verifier today is expensive. We estimate 1-5 million gas to verify within the EVM. Better batching techniques on opening proofs can bring this down 5-10x, but it will remain expensive. Further, the short-term [continuations plan](https://jolt.a16zcrypto.com/future/continuations.html) causes a linear blowup depending on the count of monolithic trace chunks proven.

To solve these two issues we're aiming to add a configuration option to the Jolt prover with a post processing step, which creates a Groth16 proof of the Jolt verifier for constant proof size / cost (~280k gas on EVM) regardless of continuation chunk count or opening proof cost. This technique is industry standard.

## Strategy
The easiest way to understand the workload of the verifier circuit is to jump through the codebase starting at `vm/mod.rs Jolt::verify(...)`. Verification can be split into 4 logical modules: [instruction lookups](https://jolt.a16zcrypto.com/how/instruction_lookups.html), [read-write memory](https://jolt.a16zcrypto.com/how/read_write_memory.html), [bytecode](https://jolt.a16zcrypto.com/how/bytecode.html), [r1cs](https://jolt.a16zcrypto.com/how/r1cs_constraints.html).

Each of the modules do some combination of the following:
- [Sumcheck verification](https://jolt.a16zcrypto.com/background/sumcheck.html)
- Polynomial opening proof verification
- Multi-linear extension evaluations

After recursively verifying sumcheck, the verifier needs to compare the claimed evaluation of the sumcheck operand at a random point $r$: $S(r)$ to their own evaluation of the polynomial at $r$. Jolt does this with a combination of opening proofs over the constituent polynomials of $S$ and direct evaluations of the multi-linear extensions of those polynomials if they have sufficient structure.

## Specifics
### Polynomial opening proof verification
HyperKZG is currently the optimal commitment scheme for recursion due to the requirement of only 2-pairing operations per opening proof. Unfortunately non-native field arithmetic will always be expensive within a circuit.

There are two options:
- Sumcheck and MLE evaluations using native arithmetic, pairing operations using non-native arithmetic
- Sumcheck and MLE evaluations using non-native arithmetic, pairing operations using native arithmetic

We believe the latter is more efficient albeit unergonomic. Some of the details are worked out in this paper [here](https://eprint.iacr.org/2023/961.pdf).

### Polynomial opening proof batching
Jolt requires tens of opening proofs across all constituent polynomials in all sumchecks. If we did these independently the verifier would be prohibitively expensive. Instead we [batch](https://jolt.a16zcrypto.com/background/batched-openings.html) all opening proofs for polynomials which share an evaluation point $r$.

### verify_instruction_lookups
Instruction lookups does two sumchecks described in more detail [here](https://jolt.a16zcrypto.com/how/instruction_lookups.html). The first contains some complexity. The evaluation of the MLE of each of the instructions at the point $r$ spit out by sumcheck is computed directly by the verifier. The verifier is able to do this thanks to the property from Lasso that each table is SOS (decomposable).

The `LassoSubtable` trait is implemented for all subtables. `LassoSubtable::evaluate_mle(r)` computes the MLE of each subtable. The `JoltInstruction` trait combines a series of underlying subtables. The MLEs of these subtables are combined to an instruction MLE via `JoltInstruction::combine_lookups(vals: &[F])`. Finally each of the instruction MLEs are combined into a VM-wide lookup MLE via `InstructionLookupsProof::combine_lookups(...)`.

The Groth16 verifier circuit would have to mimic this pattern. Implementing the MLE evaluation logic for each of the subtables, combination logic for each of the instructions, and combination logic to aggregate all instructions. It's possible that subtables / instructions will be added / removed in the future.

### verify_r1cs
[R1CS](https://jolt.a16zcrypto.com/how/r1cs_constraints.html) is a modified Spartan instance which runs two sumchecks and a single opening proof.

There are two difficult MLEs to evaluate:
- $\widetilde{A}, \widetilde{B}, \widetilde{C}$ – evaluations of the R1CS coefficient
- $\widetilde{z}$ – evaluation of the witness vector

> The sections below are under-described in the wiki. We'll flush these out shortly. Assume this step comes last.
For $\widetilde{A}, \widetilde{B}, \widetilde{C}$ we must leverage the uniformity to efficiently evaluate. This is under-described in the wiki, but we'll get to it ASAP.

The witness vector $z$ is comprised of all of the inputs to the R1CS circuit concatenated together in `trace_length`-sized chunks. All of these are committed independently and are checked via a batched opening proof.


# Engineering Suggestions
The Jolt codebase is rapidly undergoing improvements to reduce prover and verifier costs as well as simplify abstractions. As a result, it's recommended that each section above be built in modules that are convenient to rewire. Each part should be incrementally testable and adjustable.

A concrete example of changes to expect: we currently require 5-10 opening proofs per Jolt execution. Even for HyperKZG this requires 10-20 pairings which is prohibitively expensive. We are working on fixing this via better [batching](https://jolt.a16zcrypto.com/background/batched-openings.html).

Suggested plan of attack:
1. Circuit for HyperKZG verification
2. Circuit for batched HyperKZG verification
3. Circuit for sumcheck verification
4. Circuit for instruction/subtable MLE evaluations
5. Circuit for Spartan verification
15 changes: 15 additions & 0 deletions book/src/future/precompiles.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Precompiles
Precompiles are highly optimized SNARK gadgets which can be invoked from the high-level programming language of the VM user. These gadgets can be much more efficient for the prover than compiling down to the underlying ISA by exploiting the structure of the workload. In practice zkVMs use these for heavy cryptographic operations such as hash functions, signatures and other elliptic curve arithmetic.

By popular demand, Jolt will support these gadgets as well. The short term plan is to optimize for minimizing Jolt-core development resources rather than optimal prover speed.

Precompile support plan:
1. RV32 library wrapping syscalls of supported libraries
2. Tracer picks up syscalls, sets relevant flag bits and loads memory accordingly
3. Individual (uniform) Spartan instance for each precompile, repeated over `trace_length` steps
4. Jolt config includes which precompiles are supported (there is some non-zero prover / verifier cost to including an unused precompile)
5. Survey existing hash / elliptic curve arithmetic R1CS arithmetizations. Prioritize efficiency and audits.
6. Use [circom-scotia](https://github.com/lurk-lab/circom-scotia) to convert $A, B, C$ matrices into static files in the Jolt codebase
7. Write a converter to uniformly repeat the constraints `trace_length` steps

*TODO(sragss): How do we deal with memory and loading more than 64-bits of inputs to precompiles.*
1 change: 1 addition & 0 deletions common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ ark-serialize = { version = "0.4.2", features = ["derive"] }
serde = { version = "1.0.193", features = ["derive"] }
serde_json = "1.0.108"
strum_macros = "0.25.3"
syn = { version = "1.0", features = ["full"] }
61 changes: 61 additions & 0 deletions common/src/attributes.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
use std::collections::HashMap;
use syn::{Lit, Meta, MetaNameValue, NestedMeta};

use crate::constants::{
DEFAULT_MAX_INPUT_SIZE, DEFAULT_MAX_OUTPUT_SIZE, DEFAULT_MEMORY_SIZE, DEFAULT_STACK_SIZE,
};

pub struct Attributes {
pub wasm: bool,
pub memory_size: u64,
pub stack_size: u64,
pub max_input_size: u64,
pub max_output_size: u64,
}

pub fn parse_attributes(attr: &Vec<NestedMeta>) -> Attributes {
let mut attributes = HashMap::<_, u64>::new();
let mut wasm = false;

for attr in attr {
match attr {
NestedMeta::Meta(Meta::NameValue(MetaNameValue { path, lit, .. })) => {
let value: u64 = match lit {
Lit::Int(lit) => lit.base10_parse().unwrap(),
_ => panic!("expected integer literal"),
};
let ident = &path.get_ident().expect("Expected identifier");
match ident.to_string().as_str() {
"memory_size" => attributes.insert("memory_size", value),
"stack_size" => attributes.insert("stack_size", value),
"max_input_size" => attributes.insert("max_input_size", value),
"max_output_size" => attributes.insert("max_output_size", value),
_ => panic!("invalid attribute"),
};
}
NestedMeta::Meta(Meta::Path(path)) if path.is_ident("wasm") => {
wasm = true;
}
_ => panic!("expected integer literal"),
}
}

let memory_size = *attributes
.get("memory_size")
.unwrap_or(&DEFAULT_MEMORY_SIZE);
let stack_size = *attributes.get("stack_size").unwrap_or(&DEFAULT_STACK_SIZE);
let max_input_size = *attributes
.get("max_input_size")
.unwrap_or(&DEFAULT_MAX_INPUT_SIZE);
let max_output_size = *attributes
.get("max_output_size")
.unwrap_or(&DEFAULT_MAX_OUTPUT_SIZE);

Attributes {
wasm,
memory_size,
stack_size,
max_input_size,
max_output_size,
}
}
1 change: 1 addition & 0 deletions common/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ pub fn to_ram_address(index: usize) -> usize {
index * constants::BYTES_PER_INSTRUCTION + constants::RAM_START_ADDRESS as usize
}

pub mod attributes;
pub mod constants;
pub mod parallel;
pub mod rv_trace;
Expand Down
13 changes: 10 additions & 3 deletions jolt-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,12 @@ reqwest = { version = "0.12.3", features = [
dirs = "5.0.1"
eyre = "0.6.12"
indicatif = "0.17.8"
memory-stats = "1.0.0"
tokio = { version = "1.38.0", optional = true, features = ["rt-multi-thread"] }

common = { path = "../common" }
tracer = { path = "../tracer" }
bincode = "1.3.3"
bytemuck = "1.15.0"
tokio = { version = "1.38.0", optional = true }


[dev-dependencies]
criterion = { version = "0.5.1", features = ["html_reports"] }
Expand All @@ -96,3 +95,11 @@ default = [
"rayon",
]
host = ["dep:reqwest", "dep:tokio"]

[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
memory-stats = "1.0.0"
tokio = { version = "1.38.0", optional = true, features = ["rt-multi-thread"] }


[target.'cfg(target_arch = "wasm32")'.dependencies]
getrandom = { version = "0.2", features = ["js"] }
10 changes: 6 additions & 4 deletions jolt-core/src/host/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use common::{
rv_trace::{JoltDevice, NUM_CIRCUIT_FLAGS},
};
use strum::EnumCount;
use tracer::ELFInstruction;
pub use tracer::ELFInstruction;

use crate::{
field::JoltField,
Expand All @@ -32,9 +32,12 @@ use crate::{
utils::thread::unsafe_allocate_zero_vec,
};

use self::{analyze::ProgramSummary, toolchain::install_toolchain};
use self::analyze::ProgramSummary;
#[cfg(not(target_arch = "wasm32"))]
use self::toolchain::install_toolchain;

pub mod analyze;
#[cfg(not(target_arch = "wasm32"))]
pub mod toolchain;

#[derive(Clone)]
Expand Down Expand Up @@ -97,6 +100,7 @@ impl Program {
#[tracing::instrument(skip_all, name = "Program::build")]
pub fn build(&mut self) {
if self.elf.is_none() {
#[cfg(not(target_arch = "wasm32"))]
install_toolchain().unwrap();
self.save_linker();

Expand Down Expand Up @@ -138,8 +142,6 @@ impl Program {
&target,
"--target",
toolchain,
"--bin",
"guest",
])
.output()
.expect("failed to build guest");
Expand Down
4 changes: 4 additions & 0 deletions jolt-core/src/host/toolchain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@ use dirs::home_dir;
use eyre::{bail, eyre, Result};
use indicatif::{ProgressBar, ProgressStyle};
use reqwest::Client;
#[cfg(not(target_arch = "wasm32"))]
use tokio::runtime::Runtime;

const TOOLCHAIN_TAG: &str = include_str!("../../../.jolt.rust.toolchain-tag");
const DOWNLOAD_RETRIES: usize = 5;
const DELAY_BASE_MS: u64 = 500;

#[cfg(not(target_arch = "wasm32"))]
/// Installs the toolchain if it is not already
pub fn install_toolchain() -> Result<()> {
if !has_toolchain() {
Expand All @@ -31,6 +33,7 @@ pub fn install_toolchain() -> Result<()> {
link_toolchain()
}

#[cfg(not(target_arch = "wasm32"))]
async fn retry_times<F, T, E>(times: usize, base_ms: u64, f: F) -> Result<T>
where
F: Fn() -> E,
Expand Down Expand Up @@ -93,6 +96,7 @@ fn unpack_toolchain() -> Result<()> {
Ok(())
}

#[cfg(not(target_arch = "wasm32"))]
async fn download_toolchain(client: &Client, url: &str) -> Result<()> {
let jolt_dir = jolt_dir();
let output_path = jolt_dir.join("rust-toolchain.tar.gz");
Expand Down
4 changes: 3 additions & 1 deletion jolt-core/src/jolt/instruction/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ pub trait JoltInstructionSet:
JoltInstruction + IntoEnumIterator + EnumCount + for<'a> TryFrom<&'a ELFInstruction> + Send + Sync
{
fn enum_index(instruction: &Self) -> usize {
unsafe { *<*const _>::from(instruction).cast::<u8>() as usize }
// Discriminant: https://doc.rust-lang.org/reference/items/enumerations.html#pointer-casting
let byte = unsafe { *(instruction as *const Self as *const u8) };
byte as usize
}
}

Expand Down
2 changes: 1 addition & 1 deletion jolt-core/src/jolt/vm/instruction_lookups.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1261,7 +1261,7 @@ where
round_uni_poly: UniPoly<F>,
transcript: &mut ProofTranscript,
) -> F {
round_uni_poly.append_to_transcript(transcript);
round_uni_poly.compress().append_to_transcript(transcript);

transcript.challenge_scalar::<F>()
}
Expand Down
6 changes: 4 additions & 2 deletions jolt-core/src/jolt/vm/rv32i_vm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ macro_rules! instruction_set {
macro_rules! subtable_enum {
($enum_name:ident, $($alias:ident: $struct:ty),+) => {
#[allow(non_camel_case_types)]
#[repr(usize)]
#[repr(u8)]
#[enum_dispatch(LassoSubtable<F>)]
#[derive(EnumCountMacro, EnumIter)]
pub enum $enum_name<F: JoltField> { $($alias($struct)),+ }
Expand All @@ -77,7 +77,9 @@ macro_rules! subtable_enum {

impl<F: JoltField> From<$enum_name<F>> for usize {
fn from(subtable: $enum_name<F>) -> usize {
unsafe { *<*const _>::from(&subtable).cast::<usize>() }
// Discriminant: https://doc.rust-lang.org/reference/items/enumerations.html#pointer-casting
let byte = unsafe { *(&subtable as *const $enum_name<F> as *const u8) };
byte as usize
}
}
impl<F: JoltField> JoltSubtableSet<F> for $enum_name<F> {}
Expand Down
1 change: 1 addition & 0 deletions jolt-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#![feature(generic_const_exprs)]
#![feature(iter_next_chunk)]
#![allow(long_running_const_eval)]
#![allow(clippy::len_without_is_empty)]

#[cfg(feature = "host")]
pub mod benches;
Expand Down
4 changes: 1 addition & 3 deletions jolt-core/src/poly/commitment/mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,7 @@ impl<F: JoltField> CommitmentScheme for MockCommitScheme<F> {
type Proof = MockProof<F>;
type BatchedProof = MockProof<F>;

fn setup(_shapes: &[CommitShape]) -> Self::Setup {
()
}
fn setup(_shapes: &[CommitShape]) -> Self::Setup {}
fn commit(poly: &DensePolynomial<Self::Field>, _setup: &Self::Setup) -> Self::Commitment {
MockCommitment {
poly: poly.to_owned(),
Expand Down
26 changes: 25 additions & 1 deletion jolt-core/src/poly/dense_mlpoly.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#![allow(clippy::too_many_arguments)]
use crate::poly::eq_poly::EqPolynomial;
use crate::utils::thread::unsafe_allocate_zero_vec;
use crate::utils::thread::{drop_in_background_thread, unsafe_allocate_zero_vec};
use crate::utils::{self, compute_dotproduct, compute_dotproduct_low_optimized};

use crate::field::JoltField;
Expand Down Expand Up @@ -201,11 +201,35 @@ impl<F: JoltField> DensePolynomial<F> {
}
}

/// Note: does not truncate
#[tracing::instrument(skip_all)]
pub fn bound_poly_var_bot(&mut self, r: &F) {
let n = self.len() / 2;
for i in 0..n {
self.Z[i] = self.Z[2 * i] + *r * (self.Z[2 * i + 1] - self.Z[2 * i]);
}

self.num_vars -= 1;
self.len = n;
}

pub fn bound_poly_var_bot_01_optimized(&mut self, r: &F) {
let n = self.len() / 2;
let mut new_z = unsafe_allocate_zero_vec(n);
new_z.par_iter_mut().enumerate().for_each(|(i, z)| {
let m = self.Z[2 * i + 1] - self.Z[2 * i];
*z = if m.is_zero() {
self.Z[2 * i]
} else if m.is_one() {
self.Z[2 * i] + r
} else {
self.Z[2 * i] + *r * m
}
});

let old_Z = std::mem::replace(&mut self.Z, new_z);
drop_in_background_thread(old_Z);

self.num_vars -= 1;
self.len = n;
}
Expand Down
Loading

0 comments on commit 093bccc

Please sign in to comment.