diff --git a/.changeset/shiny-poets-whisper.md b/.changeset/shiny-poets-whisper.md index cdef2391417..92497033acf 100644 --- a/.changeset/shiny-poets-whisper.md +++ b/.changeset/shiny-poets-whisper.md @@ -2,4 +2,4 @@ 'openzeppelin-solidity': minor --- -`Math`: Add `modExp` function that exposes the `EIP-198` precompile. +`Math`: Add `modExp` function that exposes the `EIP-198` precompile. Includes `uint256` and `bytes memory` versions. diff --git a/contracts/utils/math/Math.sol b/contracts/utils/math/Math.sol index be05506ec53..0f9dce94f44 100644 --- a/contracts/utils/math/Math.sol +++ b/contracts/utils/math/Math.sol @@ -3,7 +3,6 @@ pragma solidity ^0.8.20; -import {Address} from "../Address.sol"; import {Panic} from "../Panic.sol"; import {SafeCast} from "./SafeCast.sol"; @@ -289,11 +288,7 @@ library Math { function modExp(uint256 b, uint256 e, uint256 m) internal view returns (uint256) { (bool success, uint256 result) = tryModExp(b, e, m); if (!success) { - if (m == 0) { - Panic.panic(Panic.DIVISION_BY_ZERO); - } else { - revert Address.FailedInnerCall(); - } + Panic.panic(Panic.DIVISION_BY_ZERO); } return result; } @@ -335,6 +330,57 @@ library Math { } } + /** + * @dev Variant of {modExp} that supports inputs of arbitrary length. + */ + function modExp(bytes memory b, bytes memory e, bytes memory m) internal view returns (bytes memory) { + (bool success, bytes memory result) = tryModExp(b, e, m); + if (!success) { + Panic.panic(Panic.DIVISION_BY_ZERO); + } + return result; + } + + /** + * @dev Variant of {tryModExp} that supports inputs of arbitrary length. + */ + function tryModExp( + bytes memory b, + bytes memory e, + bytes memory m + ) internal view returns (bool success, bytes memory result) { + if (_zeroBytes(m)) return (false, new bytes(0)); + + uint256 mLen = m.length; + + // Encode call args in result and move the free memory pointer + result = abi.encodePacked(b.length, e.length, mLen, b, e, m); + + /// @solidity memory-safe-assembly + assembly { + let dataPtr := add(result, 0x20) + // Write result on top of args to avoid allocating extra memory. + success := staticcall(gas(), 0x05, dataPtr, mload(result), dataPtr, mLen) + // Overwrite the length. + // result.length > returndatasize() is guaranteed because returndatasize() == m.length + mstore(result, mLen) + // Set the memory pointer after the returned data. + mstore(0x40, add(dataPtr, mLen)) + } + } + + /** + * @dev Returns whether the provided byte array is zero. + */ + function _zeroBytes(bytes memory byteArray) private pure returns (bool) { + for (uint256 i = 0; i < byteArray.length; ++i) { + if (byteArray[i] != 0) { + return false; + } + } + return true; + } + /** * @dev Returns the square root of a number. If the number is not a perfect square, the value is rounded * towards zero. diff --git a/test/helpers/math.js b/test/helpers/math.js index f8a1520ee1a..133254aecc2 100644 --- a/test/helpers/math.js +++ b/test/helpers/math.js @@ -3,8 +3,31 @@ const max = (...values) => values.slice(1).reduce((x, y) => (x > y ? x : y), val const min = (...values) => values.slice(1).reduce((x, y) => (x < y ? x : y), values.at(0)); const sum = (...values) => values.slice(1).reduce((x, y) => x + y, values.at(0)); +// Computes modexp without BigInt overflow for large numbers +function modExp(b, e, m) { + let result = 1n; + + // If e is a power of two, modexp can be calculated as: + // for (let result = b, i = 0; i < log2(e); i++) result = modexp(result, 2, m) + // + // Given any natural number can be written in terms of powers of 2 (i.e. binary) + // then modexp can be calculated for any e, by multiplying b**i for all i where + // binary(e)[i] is 1 (i.e. a power of two). + for (let base = b % m; e > 0n; base = base ** 2n % m) { + // Least significant bit is 1 + if (e % 2n == 1n) { + result = (result * base) % m; + } + + e /= 2n; // Binary pop + } + + return result; +} + module.exports = { min, max, sum, + modExp, }; diff --git a/test/utils/math/Math.t.sol b/test/utils/math/Math.t.sol index 7b49e8a8868..40f5986f4f6 100644 --- a/test/utils/math/Math.t.sol +++ b/test/utils/math/Math.t.sol @@ -226,6 +226,33 @@ contract MathTest is Test { } } + function testModExpMemory(uint256 b, uint256 e, uint256 m) public { + if (m == 0) { + vm.expectRevert(stdError.divisionError); + } + bytes memory result = Math.modExp(abi.encodePacked(b), abi.encodePacked(e), abi.encodePacked(m)); + assertEq(result.length, 0x20); + uint256 res = abi.decode(result, (uint256)); + assertLt(res, m); + assertEq(res, _nativeModExp(b, e, m)); + } + + function testTryModExpMemory(uint256 b, uint256 e, uint256 m) public { + (bool success, bytes memory result) = Math.tryModExp( + abi.encodePacked(b), + abi.encodePacked(e), + abi.encodePacked(m) + ); + if (success) { + assertEq(result.length, 0x20); // m is a uint256, so abi.encodePacked(m).length is 0x20 + uint256 res = abi.decode(result, (uint256)); + assertLt(res, m); + assertEq(res, _nativeModExp(b, e, m)); + } else { + assertEq(result.length, 0); + } + } + function _nativeModExp(uint256 b, uint256 e, uint256 m) private pure returns (uint256) { if (m == 1) return 0; uint256 r = 1; diff --git a/test/utils/math/Math.test.js b/test/utils/math/Math.test.js index b75d3b58858..bce02610f97 100644 --- a/test/utils/math/Math.test.js +++ b/test/utils/math/Math.test.js @@ -4,12 +4,19 @@ const { loadFixture } = require('@nomicfoundation/hardhat-network-helpers'); const { PANIC_CODES } = require('@nomicfoundation/hardhat-chai-matchers/panic'); const { Rounding } = require('../../helpers/enums'); -const { min, max } = require('../../helpers/math'); +const { min, max, modExp } = require('../../helpers/math'); const { generators } = require('../../helpers/random'); +const { range } = require('../../../scripts/helpers'); +const { product } = require('../../helpers/iterate'); const RoundingDown = [Rounding.Floor, Rounding.Trunc]; const RoundingUp = [Rounding.Ceil, Rounding.Expand]; +const bytes = (value, width = undefined) => ethers.Typed.bytes(ethers.toBeHex(value, width)); +const uint256 = value => ethers.Typed.uint256(value); +bytes.zero = '0x'; +uint256.zero = 0n; + async function testCommutative(fn, lhs, rhs, expected, ...extra) { expect(await fn(lhs, rhs, ...extra)).to.deep.equal(expected); expect(await fn(rhs, lhs, ...extra)).to.deep.equal(expected); @@ -141,24 +148,6 @@ describe('Math', function () { }); }); - describe('tryModExp', function () { - it('is correctly returning true and calculating modulus', async function () { - const base = 3n; - const exponent = 200n; - const modulus = 50n; - - expect(await this.mock.$tryModExp(base, exponent, modulus)).to.deep.equal([true, base ** exponent % modulus]); - }); - - it('is correctly returning false when modulus is 0', async function () { - const base = 3n; - const exponent = 200n; - const modulus = 0n; - - expect(await this.mock.$tryModExp(base, exponent, modulus)).to.deep.equal([false, 0n]); - }); - }); - describe('max', function () { it('is correctly detected in both position', async function () { await testCommutative(this.mock.$max, 1234n, 5678n, max(1234n, 5678n)); @@ -354,20 +343,79 @@ describe('Math', function () { }); describe('modExp', function () { - it('is correctly calculating modulus', async function () { - const base = 3n; - const exponent = 200n; - const modulus = 50n; + for (const [name, type] of Object.entries({ uint256, bytes })) { + describe(`with ${name} inputs`, function () { + it('is correctly calculating modulus', async function () { + const b = 3n; + const e = 200n; + const m = 50n; + + expect(await this.mock.$modExp(type(b), type(e), type(m))).to.equal(type(b ** e % m).value); + }); - expect(await this.mock.$modExp(base, exponent, modulus)).to.equal(base ** exponent % modulus); + it('is correctly reverting when modulus is zero', async function () { + const b = 3n; + const e = 200n; + const m = 0n; + + await expect(this.mock.$modExp(type(b), type(e), type(m))).to.be.revertedWithPanic( + PANIC_CODES.DIVISION_BY_ZERO, + ); + }); + }); + } + + describe('with large bytes inputs', function () { + for (const [[b, log2b], [e, log2e], [m, log2m]] of product( + range(320, 512, 64).map(e => [2n ** BigInt(e) + 1n, e]), + range(320, 512, 64).map(e => [2n ** BigInt(e) + 1n, e]), + range(320, 512, 64).map(e => [2n ** BigInt(e) + 1n, e]), + )) { + it(`calculates b ** e % m (b=2**${log2b}+1) (e=2**${log2e}+1) (m=2**${log2m}+1)`, async function () { + const mLength = ethers.dataLength(ethers.toBeHex(m)); + + expect(await this.mock.$modExp(bytes(b), bytes(e), bytes(m))).to.equal(bytes(modExp(b, e, m), mLength).value); + }); + } }); + }); + + describe('tryModExp', function () { + for (const [name, type] of Object.entries({ uint256, bytes })) { + describe(`with ${name} inputs`, function () { + it('is correctly calculating modulus', async function () { + const b = 3n; + const e = 200n; + const m = 50n; + + expect(await this.mock.$tryModExp(type(b), type(e), type(m))).to.deep.equal([true, type(b ** e % m).value]); + }); - it('is correctly reverting when modulus is zero', async function () { - const base = 3n; - const exponent = 200n; - const modulus = 0n; + it('is correctly reverting when modulus is zero', async function () { + const b = 3n; + const e = 200n; + const m = 0n; - await expect(this.mock.$modExp(base, exponent, modulus)).to.be.revertedWithPanic(PANIC_CODES.DIVISION_BY_ZERO); + expect(await this.mock.$tryModExp(type(b), type(e), type(m))).to.deep.equal([false, type.zero]); + }); + }); + } + + describe('with large bytes inputs', function () { + for (const [[b, log2b], [e, log2e], [m, log2m]] of product( + range(320, 513, 64).map(e => [2n ** BigInt(e) + 1n, e]), + range(320, 513, 64).map(e => [2n ** BigInt(e) + 1n, e]), + range(320, 513, 64).map(e => [2n ** BigInt(e) + 1n, e]), + )) { + it(`calculates b ** e % m (b=2**${log2b}+1) (e=2**${log2e}+1) (m=2**${log2m}+1)`, async function () { + const mLength = ethers.dataLength(ethers.toBeHex(m)); + + expect(await this.mock.$tryModExp(bytes(b), bytes(e), bytes(m))).to.deep.equal([ + true, + bytes(modExp(b, e, m), mLength).value, + ]); + }); + } }); });