From 1cd2587bf67143832f76f90c25aecca1a46b1284 Mon Sep 17 00:00:00 2001 From: jfecher Date: Thu, 10 Oct 2024 12:20:16 +0100 Subject: [PATCH] fix!: Integer division is not the inverse of integer multiplication (#6243) # Description ## Problem\* Resolves https://github.com/noir-lang/noir/issues/6242 ## Summary\* Making this PR for visibility to show that fixing this issue breaks our serialize test. Notably this fix allows rounding arithmetic generics using the `/ N * N` pattern which was simplified away previously. ## Additional Context I initially removed the cancellation of / and * entirely but found that this breaks even more tests. So I added `approx_inverse` in a few cases that only involve constants that I thought may be correct but we still fail `serialize` even with these cases. ## Documentation\* Check one: - [x] No documentation needed. - [ ] Documentation included in this PR. - [ ] **[For Experimental Features]** Documentation to be submitted in a separate PR. # PR Checklist\* - [x] I have tested the changes locally. - [x] I have formatted the changes with [Prettier](https://prettier.io/) and/or `cargo fmt` on default settings. --- compiler/noirc_frontend/src/hir_def/types.rs | 13 +++++++- .../src/hir_def/types/arithmetic.rs | 22 ++++++------- compiler/noirc_frontend/src/tests.rs | 31 +++++++++++++++++++ .../arithmetic_generics/src/main.nr | 13 +++++--- 4 files changed, 63 insertions(+), 16 deletions(-) diff --git a/compiler/noirc_frontend/src/hir_def/types.rs b/compiler/noirc_frontend/src/hir_def/types.rs index 69e5066c596..fa2a455c06d 100644 --- a/compiler/noirc_frontend/src/hir_def/types.rs +++ b/compiler/noirc_frontend/src/hir_def/types.rs @@ -1761,7 +1761,7 @@ impl Type { Err(UnificationError) } } else if let InfixExpr(lhs, op, rhs) = other { - if let Some(inverse) = op.inverse() { + if let Some(inverse) = op.approx_inverse() { // Handle cases like `4 = a + b` by trying to solve to `a = 4 - b` let new_type = InfixExpr( Box::new(Constant(*value, kind.clone())), @@ -2569,6 +2569,17 @@ impl BinaryTypeOperator { /// Return the operator that will "undo" this operation if applied to the rhs fn inverse(self) -> Option { + match self { + BinaryTypeOperator::Addition => Some(BinaryTypeOperator::Subtraction), + BinaryTypeOperator::Subtraction => Some(BinaryTypeOperator::Addition), + BinaryTypeOperator::Multiplication => None, + BinaryTypeOperator::Division => None, + BinaryTypeOperator::Modulo => None, + } + } + + /// Return the operator that will "undo" this operation if applied to the rhs + fn approx_inverse(self) -> Option { match self { BinaryTypeOperator::Addition => Some(BinaryTypeOperator::Subtraction), BinaryTypeOperator::Subtraction => Some(BinaryTypeOperator::Addition), diff --git a/compiler/noirc_frontend/src/hir_def/types/arithmetic.rs b/compiler/noirc_frontend/src/hir_def/types/arithmetic.rs index 0eee7dbf824..81f638ebca4 100644 --- a/compiler/noirc_frontend/src/hir_def/types/arithmetic.rs +++ b/compiler/noirc_frontend/src/hir_def/types/arithmetic.rs @@ -120,7 +120,7 @@ impl Type { /// Precondition: `lhs & rhs are in canonical form` /// /// - Simplifies `(N +/- M) -/+ M` to `N` - /// - Simplifies `(N */÷ M) ÷/* M` to `N` + /// - Simplifies `(N * M) ÷ M` to `N` fn try_simplify_non_constants_in_lhs( lhs: &Type, op: BinaryTypeOperator, @@ -132,7 +132,10 @@ impl Type { // Note that this is exact, syntactic equality, not unification. // `rhs` is expected to already be in canonical form. - if l_op.inverse() != Some(op) || l_rhs.canonicalize() != *rhs { + if l_op.approx_inverse() != Some(op) + || l_op == BinaryTypeOperator::Division + || l_rhs.canonicalize() != *rhs + { return None; } @@ -199,7 +202,8 @@ impl Type { /// Precondition: `lhs & rhs are in canonical form` /// /// - Simplifies `(N +/- C1) +/- C2` to `N +/- (C1 +/- C2)` if C1 and C2 are constants. - /// - Simplifies `(N */÷ C1) */÷ C2` to `N */÷ (C1 */÷ C2)` if C1 and C2 are constants. + /// - Simplifies `(N * C1) ÷ C2` to `N * (C1 ÷ C2)` if C1 and C2 are constants which divide + /// without a remainder. fn try_simplify_partial_constants( lhs: &Type, mut op: BinaryTypeOperator, @@ -218,12 +222,8 @@ impl Type { let constant = Type::Constant(result, lhs.infix_kind(rhs)); Some(Type::InfixExpr(l_type, l_op, Box::new(constant))) } - (Multiplication | Division, Multiplication | Division) => { - // If l_op is a division we want to inverse the rhs operator. - if l_op == Division { - op = op.inverse()?; - } - + (Multiplication, Division) => { + // We need to ensure the result divides evenly to preserve integer division semantics let divides_evenly = !lhs.infix_kind(rhs).is_type_level_field_element() && l_const.to_i128().checked_rem(r_const.to_i128()) == Some(0); @@ -248,7 +248,7 @@ impl Type { bindings: &mut TypeBindings, ) -> Result<(), UnificationError> { if let Type::InfixExpr(lhs_a, op_a, rhs_a) = self { - if let Some(inverse) = op_a.inverse() { + if let Some(inverse) = op_a.approx_inverse() { let kind = lhs_a.infix_kind(rhs_a); if let Some(rhs_a_value) = rhs_a.evaluate_to_field_element(&kind) { let rhs_a = Box::new(Type::Constant(rhs_a_value, kind)); @@ -264,7 +264,7 @@ impl Type { } if let Type::InfixExpr(lhs_b, op_b, rhs_b) = other { - if let Some(inverse) = op_b.inverse() { + if let Some(inverse) = op_b.approx_inverse() { let kind = lhs_b.infix_kind(rhs_b); if let Some(rhs_b_value) = rhs_b.evaluate_to_field_element(&kind) { let rhs_b = Box::new(Type::Constant(rhs_b_value, kind)); diff --git a/compiler/noirc_frontend/src/tests.rs b/compiler/noirc_frontend/src/tests.rs index f190ef38bab..8b54095973c 100644 --- a/compiler/noirc_frontend/src/tests.rs +++ b/compiler/noirc_frontend/src/tests.rs @@ -3357,3 +3357,34 @@ fn error_if_attribute_not_in_scope() { CompilationError::ResolverError(ResolverError::AttributeFunctionNotInScope { .. }) )); } + +#[test] +fn arithmetic_generics_rounding_pass() { + let src = r#" + fn main() { + // 3/2*2 = 2 + round::<3, 2>([1, 2]); + } + + fn round(_x: [Field; N / M * M]) {} + "#; + + let errors = get_program_errors(src); + assert_eq!(errors.len(), 0); +} + +#[test] +fn arithmetic_generics_rounding_fail() { + let src = r#" + fn main() { + // Do not simplify N/M*M to just N + // This should be 3/2*2 = 2, not 3 + round::<3, 2>([1, 2, 3]); + } + + fn round(_x: [Field; N / M * M]) {} + "#; + + let errors = get_program_errors(src); + assert_eq!(errors.len(), 1); +} diff --git a/test_programs/compile_success_empty/arithmetic_generics/src/main.nr b/test_programs/compile_success_empty/arithmetic_generics/src/main.nr index 9a002356144..4a057a75e43 100644 --- a/test_programs/compile_success_empty/arithmetic_generics/src/main.nr +++ b/test_programs/compile_success_empty/arithmetic_generics/src/main.nr @@ -117,9 +117,12 @@ fn test_constant_folding() { // N * C1 / C2 = N * (C1 / C2) let _: W = W:: {}; - + // This case is invalid due to integer division + // If N does not divide evenly with 10 then we cannot simplify it. + // e.g. 15 / 10 * 2 = 2 versus 15 / 5 = 3 + // // N / C1 * C2 = N / (C1 / C2) - let _: W = W:: {}; + // let _: W = W:: {}; } fn test_non_constant_folding() { @@ -131,7 +134,9 @@ fn test_non_constant_folding() { // N * M / M = N let _: W = W:: {}; - + // This case is not true due to integer division rounding! + // Consider 5 / 2 * 2 which should equal 4, not 5 + // // N / M * M = N - let _: W = W:: {}; + // let _: W = W:: {}; }