Skip to content

Commit

Permalink
feat: avoid unnecessary ssa passes while loop unrolling (#6509)
Browse files Browse the repository at this point in the history
  • Loading branch information
TomAFrench authored Nov 13, 2024
1 parent cfc22cb commit f81c649
Showing 1 changed file with 53 additions and 45 deletions.
98 changes: 53 additions & 45 deletions compiler/noirc_evaluator/src/ssa/opt/unrolling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,54 +40,47 @@ use fxhash::FxHashMap as HashMap;
impl Ssa {
/// Loop unrolling can return errors, since ACIR functions need to be fully unrolled.
/// This meta-pass will keep trying to unroll loops and simplifying the SSA until no more errors are found.
#[tracing::instrument(level = "trace", skip(ssa))]
pub(crate) fn unroll_loops_iteratively(mut ssa: Ssa) -> Result<Ssa, RuntimeError> {
// Try to unroll loops first:
let mut unroll_errors;
(ssa, unroll_errors) = ssa.try_to_unroll_loops();

// Keep unrolling until no more errors are found
while !unroll_errors.is_empty() {
let prev_unroll_err_count = unroll_errors.len();

// Simplify the SSA before retrying

// Do a mem2reg after the last unroll to aid simplify_cfg
ssa = ssa.mem2reg();
ssa = ssa.simplify_cfg();
// Do another mem2reg after simplify_cfg to aid the next unroll
ssa = ssa.mem2reg();

// Unroll again
(ssa, unroll_errors) = ssa.try_to_unroll_loops();
// If we didn't manage to unroll any more loops, exit
if unroll_errors.len() >= prev_unroll_err_count {
return Err(unroll_errors.swap_remove(0));
let acir_functions = ssa.functions.iter_mut().filter(|(_, func)| {
// Loop unrolling in brillig can lead to a code explosion currently. This can
// also be true for ACIR, but we have no alternative to unrolling in ACIR.
// Brillig also generally prefers smaller code rather than faster code.
!matches!(func.runtime(), RuntimeType::Brillig(_))
});

for (_, function) in acir_functions {
// Try to unroll loops first:
let mut unroll_errors = function.try_to_unroll_loops();

// Keep unrolling until no more errors are found
while !unroll_errors.is_empty() {
let prev_unroll_err_count = unroll_errors.len();

// Simplify the SSA before retrying

// Do a mem2reg after the last unroll to aid simplify_cfg
function.mem2reg();
function.simplify_function();
// Do another mem2reg after simplify_cfg to aid the next unroll
function.mem2reg();

// Unroll again
unroll_errors = function.try_to_unroll_loops();
// If we didn't manage to unroll any more loops, exit
if unroll_errors.len() >= prev_unroll_err_count {
return Err(unroll_errors.swap_remove(0));
}
}
}
Ok(ssa)
}

/// Tries to unroll all loops in each SSA function.
/// If any loop cannot be unrolled, it is left as-is or in a partially unrolled state.
/// Returns the ssa along with all unrolling errors encountered
#[tracing::instrument(level = "trace", skip(self))]
pub(crate) fn try_to_unroll_loops(mut self) -> (Ssa, Vec<RuntimeError>) {
let mut errors = vec![];
for function in self.functions.values_mut() {
function.try_to_unroll_loops(&mut errors);
}
(self, errors)
Ok(ssa)
}
}

impl Function {
pub(crate) fn try_to_unroll_loops(&mut self, errors: &mut Vec<RuntimeError>) {
// Loop unrolling in brillig can lead to a code explosion currently. This can
// also be true for ACIR, but we have no alternative to unrolling in ACIR.
// Brillig also generally prefers smaller code rather than faster code.
if !matches!(self.runtime(), RuntimeType::Brillig(_)) {
errors.extend(find_all_loops(self).unroll_each_loop(self));
}
fn try_to_unroll_loops(&mut self) -> Vec<RuntimeError> {
find_all_loops(self).unroll_each_loop(self)
}
}

Expand Down Expand Up @@ -507,11 +500,26 @@ impl<'f> LoopIteration<'f> {

#[cfg(test)]
mod tests {
use crate::ssa::{
function_builder::FunctionBuilder,
ir::{instruction::BinaryOp, map::Id, types::Type},
use crate::{
errors::RuntimeError,
ssa::{
function_builder::FunctionBuilder,
ir::{instruction::BinaryOp, map::Id, types::Type},
},
};

use super::Ssa;

/// Tries to unroll all loops in each SSA function.
/// If any loop cannot be unrolled, it is left as-is or in a partially unrolled state.
fn try_to_unroll_loops(mut ssa: Ssa) -> (Ssa, Vec<RuntimeError>) {
let mut errors = vec![];
for function in ssa.functions.values_mut() {
errors.extend(function.try_to_unroll_loops());
}
(ssa, errors)
}

#[test]
fn unroll_nested_loops() {
// fn main() {
Expand Down Expand Up @@ -630,7 +638,7 @@ mod tests {
// }
// The final block count is not 1 because unrolling creates some unnecessary jmps.
// If a simplify cfg pass is ran afterward, the expected block count will be 1.
let (ssa, errors) = ssa.try_to_unroll_loops();
let (ssa, errors) = try_to_unroll_loops(ssa);
assert_eq!(errors.len(), 0, "All loops should be unrolled");
assert_eq!(ssa.main().reachable_blocks().len(), 5);
}
Expand Down Expand Up @@ -680,7 +688,7 @@ mod tests {
assert_eq!(ssa.main().reachable_blocks().len(), 4);

// Expected that we failed to unroll the loop
let (_, errors) = ssa.try_to_unroll_loops();
let (_, errors) = try_to_unroll_loops(ssa);
assert_eq!(errors.len(), 1, "Expected to fail to unroll loop");
}
}

0 comments on commit f81c649

Please sign in to comment.