Skip to content

Commit

Permalink
feat: Sha256 refactoring and benchmark with longer input (#6318)
Browse files Browse the repository at this point in the history
# Description

## Problem\*

Preparation for #6304

## Summary\*

Preparation for changing the message block type in `sha256.nr`:
* Added some type aliases and extra comments, rearranged some functions
* Added a new benchmark program with longer input so that we exercise
the iteration and the last partial block as well
* Running the criterion benchmarks with and without the
`--force-brillig` option, to cover what the AVM would do
* Added an option to the `stdlib-tests.rs` to pass a filter for test
names

This is purely to make it a bit easier to see what is going on and to
establish some baseline before trying to make changes.

Tried to rationalise the code a bit:
* Removed `pad_msg_block`: based on the constraints put on its results
it looked like it's forbidden from doing anything. This allows the
removal of some constraints because for example `msg_block` and
`last_block` are equal by definition. Here's just that
[diff](https://github.com/noir-lang/noir/compare/2052451..6304-sha-msg-block-size).
* Moved the verification of padding with zeroes after the input into the
`verify_block_msg_padding` function. This is only called for the last
(partially filled) block.

According to the Circuit Size report below 👇 there is a 33% reduction in
the number of ACIR opcodes in some of the SHA256 benchmarks.

### Testing

```shell
cargo test -p nargo_cli --test stdlib-tests -- run_stdlib_tests sha256
cargo test -p nargo_cli --test stdlib-props fuzz_sha256
```

### Benchmarking

```shell
cargo bench -p nargo_cli --bench criterion sha256_long
```

The baseline benchmarks on my machine as of
c600000
were as follows:
```console
❯ cargo bench -p nargo_cli --bench criterion sha256_long
...
bench_sha256_long_execute
                        time:   [1.3613 ms 1.3688 ms 1.3782 ms]
bench_sha256_long_execute_brillig
                        time:   [286.64 µs 287.67 µs 288.96 µs]
```

For some reason after merging `master` into the PR the performance is
worse in
636c9e9

```console
❯ cargo bench -p nargo_cli --bench criterion sha256_long
...
bench_sha256_long_execute
                        time:   [1.7297 ms 1.7918 ms 1.8675 ms]
                        change: [+27.365% +29.911% +32.673%] (p = 0.00 < 0.05)
                        Performance has regressed.
Found 2 outliers among 20 measurements (10.00%)
  2 (10.00%) high severe

bench_sha256_long_execute_brillig
                        time:   [354.12 µs 360.31 µs 368.45 µs]
                        change: [+22.390% +24.264% +27.161%] (p = 0.00 < 0.05)
                        Performance has regressed.
Found 1 outliers among 20 measurements (5.00%)
  1 (5.00%) high severe

```

> __Maxim__: We did just sync aztec packages which now conditionally
inlines functions (we previously used to always inline functions). This
may be the cause of some execution time increases.

Here's the
[diff](https://github.com/noir-lang/noir/compare/c600000..636c9e9)
between those commits.

## Additional Context

In a follow-up PR I'll try to change the type of `msg_block` from `[u8;
64]` to `[u32; 16]` to avoid having to call `msg_u8_to_u32`. This should
at least have the benefit of copying the array fewer times: at the
moment an array copy is made every time an item in it is written to;
with 16 items instead of 64 we get up to 4x less copies.

## Documentation\*

Check one:
- [x] No documentation needed.
- [ ] Documentation included in this PR.
- [ ] **[For Experimental Features]** Documentation to be submitted in a
separate PR.

# PR Checklist\*

- [x] I have tested the changes locally.
- [ ] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.
  • Loading branch information
aakoshh authored Oct 23, 2024
1 parent 801c718 commit d606491
Show file tree
Hide file tree
Showing 7 changed files with 441 additions and 183 deletions.
278 changes: 146 additions & 132 deletions noir_stdlib/src/hash/sha256.nr
Original file line number Diff line number Diff line change
Expand Up @@ -3,101 +3,63 @@ use crate::runtime::is_unconstrained;
// Implementation of SHA-256 mapping a byte array of variable length to
// 32 bytes.

// A message block is up to 64 bytes taken from the input.
global BLOCK_SIZE = 64;

// The first index in the block where the 8 byte message size will be written.
global MSG_SIZE_PTR = 56;

// Size of the message block when packed as 4-byte integer array.
global INT_BLOCK_SIZE = 16;

// Index of a byte in a 64 byte block; ie. 0..=63
type BLOCK_BYTE_PTR = u32;

// The foreign function to compress blocks works on 16 pieces of 4-byte integers, instead of 64 bytes.
type INT_BLOCK = [u32; INT_BLOCK_SIZE];

// A message block is a slice of the original message of a fixed size,
// potentially padded with zeroes.
type MSG_BLOCK = [u8; BLOCK_SIZE];

// The hash is 32 bytes.
type HASH = [u8; 32];

// The state accumulates the blocks.
// Its overall size is the same as the `HASH`.
type STATE = [u32; 8];

// Deprecated in favour of `sha256_var`
// docs:start:sha256
pub fn sha256<let N: u32>(input: [u8; N]) -> [u8; 32]
pub fn sha256<let N: u32>(input: [u8; N]) -> HASH
// docs:end:sha256
{
digest(input)
}

#[foreign(sha256_compression)]
pub fn sha256_compression(_input: [u32; 16], _state: [u32; 8]) -> [u32; 8] {}
pub fn sha256_compression(_input: INT_BLOCK, _state: STATE) -> STATE {}

// SHA-256 hash function
#[no_predicates]
pub fn digest<let N: u32>(msg: [u8; N]) -> [u8; 32] {
pub fn digest<let N: u32>(msg: [u8; N]) -> HASH {
sha256_var(msg, N as u64)
}

// Convert 64-byte array to array of 16 u32s
fn msg_u8_to_u32(msg: [u8; 64]) -> [u32; 16] {
let mut msg32: [u32; 16] = [0; 16];

for i in 0..16 {
let mut msg_field: Field = 0;
for j in 0..4 {
msg_field = msg_field * 256 + msg[64 - 4 * (i + 1) + j] as Field;
}
msg32[15 - i] = msg_field as u32;
}

msg32
}

unconstrained fn build_msg_block_iter<let N: u32>(
msg: [u8; N],
message_size: u32,
msg_start: u32,
) -> ([u8; 64], u32) {
let mut msg_block: [u8; BLOCK_SIZE] = [0; BLOCK_SIZE];
// We insert `BLOCK_SIZE` bytes (or up to the end of the message)
let block_input = if msg_start + BLOCK_SIZE > message_size {
if message_size < msg_start {
// This function is sometimes called with `msg_start` past the end of the message.
// In this case we return an empty block and zero pointer to signal that the result should be ignored.
0
} else {
message_size - msg_start
}
} else {
BLOCK_SIZE
};
for k in 0..block_input {
msg_block[k] = msg[msg_start + k];
}
(msg_block, block_input)
}

// Verify the block we are compressing was appropriately constructed
fn verify_msg_block<let N: u32>(
msg: [u8; N],
message_size: u32,
msg_block: [u8; 64],
msg_start: u32,
) -> u32 {
let mut msg_byte_ptr: u32 = 0; // Message byte pointer
let mut msg_end = msg_start + BLOCK_SIZE;
if msg_end > N {
msg_end = N;
}

for k in msg_start..msg_end {
if k < message_size {
assert_eq(msg_block[msg_byte_ptr], msg[k]);
msg_byte_ptr = msg_byte_ptr + 1;
}
}

msg_byte_ptr
}

global BLOCK_SIZE = 64;

// Variable size SHA-256 hash
pub fn sha256_var<let N: u32>(msg: [u8; N], message_size: u64) -> [u8; 32] {
pub fn sha256_var<let N: u32>(msg: [u8; N], message_size: u64) -> HASH {
let message_size = message_size as u32;
let num_blocks = N / BLOCK_SIZE;
let mut msg_block: [u8; BLOCK_SIZE] = [0; BLOCK_SIZE];
let mut h: [u32; 8] = [
let mut msg_block: MSG_BLOCK = [0; BLOCK_SIZE];
let mut h: STATE = [
1779033703, 3144134277, 1013904242, 2773480762, 1359893119, 2600822924, 528734635,
1541459225,
]; // Intermediate hash, starting with the canonical initial value
let mut msg_byte_ptr = 0; // Pointer into msg_block
for i in 0..num_blocks {
let msg_start = BLOCK_SIZE * i;
let (new_msg_block, new_msg_byte_ptr) =
unsafe { build_msg_block_iter(msg, message_size, msg_start) };
unsafe { build_msg_block(msg, message_size, msg_start) };
if msg_start < message_size {
msg_block = new_msg_block;
}
Expand Down Expand Up @@ -126,7 +88,7 @@ pub fn sha256_var<let N: u32>(msg: [u8; N], message_size: u64) -> [u8; 32] {
if modulo != 0 {
let msg_start = BLOCK_SIZE * num_blocks;
let (new_msg_block, new_msg_byte_ptr) =
unsafe { build_msg_block_iter(msg, message_size, msg_start) };
unsafe { build_msg_block(msg, message_size, msg_start) };

if msg_start < message_size {
msg_block = new_msg_block;
Expand All @@ -136,116 +98,168 @@ pub fn sha256_var<let N: u32>(msg: [u8; N], message_size: u64) -> [u8; 32] {
let new_msg_byte_ptr = verify_msg_block(msg, message_size, msg_block, msg_start);
if msg_start < message_size {
msg_byte_ptr = new_msg_byte_ptr;
verify_msg_block_padding(msg_block, msg_byte_ptr);
}
} else if msg_start < message_size {
msg_byte_ptr = new_msg_byte_ptr;
}
}

// If we had modulo == 0 then it means the last block was full,
// and we can reset the pointer to zero to overwrite it.
if msg_byte_ptr == BLOCK_SIZE {
msg_byte_ptr = 0;
}

// This variable is used to get around the compiler under-constrained check giving a warning.
// We want to check against a constant zero, but if it does not come from the circuit inputs
// or return values the compiler check will issue a warning.
let zero = msg_block[0] - msg_block[0];

// Pad the rest such that we have a [u32; 2] block at the end representing the length
// of the message, and a block of 1 0 ... 0 following the message (i.e. [1 << 7, 0, ..., 0]).
// Here we rely on the fact that everything beyond the available input is set to 0.
msg_block[msg_byte_ptr] = 1 << 7;
let last_block = msg_block;
msg_byte_ptr = msg_byte_ptr + 1;

unsafe {
let (new_msg_block, new_msg_byte_ptr) = pad_msg_block(msg_block, msg_byte_ptr);
msg_block = new_msg_block;
if crate::runtime::is_unconstrained() {
msg_byte_ptr = new_msg_byte_ptr;
}
// If we don't have room to write the size, compress the block and reset it.
if msg_byte_ptr > MSG_SIZE_PTR {
h = sha256_compression(msg_u8_to_u32(msg_block), h);
// `attach_len_to_msg_block` will zero out everything after the `msg_byte_ptr`.
msg_byte_ptr = 0;
}

if !crate::runtime::is_unconstrained() {
for i in 0..BLOCK_SIZE {
assert_eq(msg_block[i], last_block[i]);
}
msg_block = unsafe { attach_len_to_msg_block(msg_block, msg_byte_ptr, message_size) };

// If i >= 57, there aren't enough bits in the current message block to accomplish this, so
// the 1 and 0s fill up the current block, which we then compress accordingly.
// Not enough bits (64) to store length. Fill up with zeros.
for _i in 57..BLOCK_SIZE {
if msg_byte_ptr <= 63 & msg_byte_ptr >= 57 {
assert_eq(msg_block[msg_byte_ptr], zero);
msg_byte_ptr += 1;
}
}
if !crate::runtime::is_unconstrained() {
verify_msg_len(msg_block, last_block, msg_byte_ptr, message_size);
}

if msg_byte_ptr >= 57 {
h = sha256_compression(msg_u8_to_u32(msg_block), h);
hash_final_block(msg_block, h)
}

msg_byte_ptr = 0;
// Convert 64-byte array to array of 16 u32s
fn msg_u8_to_u32(msg: MSG_BLOCK) -> INT_BLOCK {
let mut msg32: INT_BLOCK = [0; INT_BLOCK_SIZE];

for i in 0..INT_BLOCK_SIZE {
let mut msg_field: Field = 0;
for j in 0..4 {
msg_field = msg_field * 256 + msg[64 - 4 * (i + 1) + j] as Field;
}
msg32[15 - i] = msg_field as u32;
}

msg_block = unsafe { attach_len_to_msg_block(msg_block, msg_byte_ptr, message_size) };
msg32
}

if !crate::runtime::is_unconstrained() {
for i in 0..56 {
let predicate = (i < msg_byte_ptr) as u8;
let expected_byte = predicate * last_block[i];
assert_eq(msg_block[i], expected_byte);
// Take `BLOCK_SIZE` number of bytes from `msg` starting at `msg_start`.
// Returns the block and the length that has been copied rather than padded with zeroes.
unconstrained fn build_msg_block<let N: u32>(
msg: [u8; N],
message_size: u32,
msg_start: u32,
) -> (MSG_BLOCK, BLOCK_BYTE_PTR) {
let mut msg_block: MSG_BLOCK = [0; BLOCK_SIZE];
// We insert `BLOCK_SIZE` bytes (or up to the end of the message)
let block_input = if msg_start + BLOCK_SIZE > message_size {
if message_size < msg_start {
// This function is sometimes called with `msg_start` past the end of the message.
// In this case we return an empty block and zero pointer to signal that the result should be ignored.
0
} else {
message_size - msg_start
}
} else {
BLOCK_SIZE
};
for k in 0..block_input {
msg_block[k] = msg[msg_start + k];
}
(msg_block, block_input)
}

// Verify the block we are compressing was appropriately constructed by `build_msg_block`
// and matches the input data. Returns the index of the first unset item.
fn verify_msg_block<let N: u32>(
msg: [u8; N],
message_size: u32,
msg_block: MSG_BLOCK,
msg_start: u32,
) -> BLOCK_BYTE_PTR {
let mut msg_byte_ptr: u32 = 0; // Message byte pointer
let mut msg_end = msg_start + BLOCK_SIZE;
if msg_end > N {
msg_end = N;
}

// We verify the message length was inserted correctly by reversing the byte decomposition.
let len = 8 * message_size;
let mut reconstructed_len: Field = 0;
for i in 56..64 {
reconstructed_len = 256 * reconstructed_len + msg_block[i] as Field;
for k in msg_start..msg_end {
if k < message_size {
assert_eq(msg_block[msg_byte_ptr], msg[k]);
msg_byte_ptr = msg_byte_ptr + 1;
}
assert_eq(reconstructed_len, len as Field);
}

hash_final_block(msg_block, h)
msg_byte_ptr
}

unconstrained fn pad_msg_block(
mut msg_block: [u8; 64],
mut msg_byte_ptr: u32,
) -> ([u8; BLOCK_SIZE], u32) {
// If i >= 57, there aren't enough bits in the current message block to accomplish this, so
// the 1 and 0s fill up the current block, which we then compress accordingly.
if msg_byte_ptr >= 57 {
// Not enough bits (64) to store length. Fill up with zeros.
for i in msg_byte_ptr..BLOCK_SIZE {
msg_block[i] = 0;
// Verify the block we are compressing was appropriately padded with zeroes by `build_msg_block`.
// This is only relevant for the last, potentially partially filled block.
fn verify_msg_block_padding(msg_block: MSG_BLOCK, msg_byte_ptr: BLOCK_BYTE_PTR) {
// This variable is used to get around the compiler under-constrained check giving a warning.
// We want to check against a constant zero, but if it does not come from the circuit inputs
// or return values the compiler check will issue a warning.
let zero = msg_block[0] - msg_block[0];

for i in 0..BLOCK_SIZE {
if i >= msg_byte_ptr {
assert_eq(msg_block[i], zero);
}
(msg_block, BLOCK_SIZE)
} else {
(msg_block, msg_byte_ptr)
}
}

// Zero out all bytes between the end of the message and where the length is appended,
// then write the length into the last 8 bytes of the block.
unconstrained fn attach_len_to_msg_block(
mut msg_block: [u8; BLOCK_SIZE],
msg_byte_ptr: u32,
mut msg_block: MSG_BLOCK,
msg_byte_ptr: BLOCK_BYTE_PTR,
message_size: u32,
) -> [u8; BLOCK_SIZE] {
) -> MSG_BLOCK {
// We assume that `msg_byte_ptr` is less than 57 because if not then it is reset to zero before calling this function.
// In any case, fill blocks up with zeros until the last 64 (i.e. until msg_byte_ptr = 56).
for i in msg_byte_ptr..56 {
for i in msg_byte_ptr..MSG_SIZE_PTR {
msg_block[i] = 0;
}

let len = 8 * message_size;
let len_bytes: [u8; 8] = (len as Field).to_be_bytes();
for i in 0..8 {
msg_block[56 + i] = len_bytes[i];
msg_block[MSG_SIZE_PTR + i] = len_bytes[i];
}
msg_block
}

fn hash_final_block(msg_block: [u8; BLOCK_SIZE], mut state: [u32; 8]) -> [u8; 32] {
let mut out_h: [u8; 32] = [0; 32]; // Digest as sequence of bytes
// Verify that the message length was correctly written by `attach_len_to_msg_block`.
fn verify_msg_len(
msg_block: MSG_BLOCK,
last_block: MSG_BLOCK,
msg_byte_ptr: BLOCK_BYTE_PTR,
message_size: u32,
) {
for i in 0..MSG_SIZE_PTR {
let predicate = (i < msg_byte_ptr) as u8;
let expected_byte = predicate * last_block[i];
assert_eq(msg_block[i], expected_byte);
}

// We verify the message length was inserted correctly by reversing the byte decomposition.
let len = 8 * message_size;
let mut reconstructed_len: Field = 0;
for i in MSG_SIZE_PTR..BLOCK_SIZE {
reconstructed_len = 256 * reconstructed_len + msg_block[i] as Field;
}
assert_eq(reconstructed_len, len as Field);
}

// Perform the final compression, then transform the `STATE` into `HASH`.
fn hash_final_block(msg_block: MSG_BLOCK, mut state: STATE) -> HASH {
let mut out_h: HASH = [0; 32]; // Digest as sequence of bytes
// Hash final padded block
state = sha256_compression(msg_u8_to_u32(msg_block), state);

Expand Down
7 changes: 7 additions & 0 deletions test_programs/benchmarks/bench_sha256_long/Nargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[package]
name = "bench_sha256_long"
version = "0.1.0"
type = "bin"
authors = [""]

[dependencies]
Loading

0 comments on commit d606491

Please sign in to comment.