From aa1d9968364380b30f8264c578b90e71cecf9abb Mon Sep 17 00:00:00 2001 From: Fynn Date: Thu, 23 Nov 2023 12:51:36 -0300 Subject: [PATCH] Use a library for reward collection --- src/FeeRewardsManager.sol | 83 +++++++++++++++++++++++++++--------- test/FeeRewardsManager.t.sol | 13 ++---- 2 files changed, 66 insertions(+), 30 deletions(-) diff --git a/src/FeeRewardsManager.sol b/src/FeeRewardsManager.sol index ca48a06..ba4fc09 100644 --- a/src/FeeRewardsManager.sol +++ b/src/FeeRewardsManager.sol @@ -3,36 +3,29 @@ pragma solidity ^0.8.13; import "@openzeppelin/contracts/access/Ownable.sol"; -contract RewardsCollector is Ownable { +library CalculateAndSendRewards { + // Fee denominator, if `feeNominator = 500`, + // the tax is 500/10000 = 5/100 = 5%. + uint32 public constant FEE_DENOMINATOR = 10000; event CollectedReward( address withdrawalCredential, - uint256 withdrawalFee, + uint256 withdrawnAmount, address owner, uint256 ownerFee ); - // 1 - fee % will go to the user in this address. - address public withdrawalCredential; - - // Fee's numerator. - uint32 public feeNumerator; - - // Fee denominator, if `feeNumerator = 500`, - // the tax is 500/10000 = 5/100 = 5%. - uint32 public constant FEE_DENOMINATOR = 10000; - - // Allow receiving MEV and other rewards. - receive() external payable {} - - function collectRewards() public payable { - uint256 ownerAmount = (address(this).balance * feeNumerator) / + function calculateRewards( + uint32 feeNominator, + address owner, + address withdrawalCredential + ) public { + uint256 ownerAmount = (address(this).balance * feeNominator) / FEE_DENOMINATOR; uint256 returnedAmount = address(this).balance - ownerAmount; require( ownerAmount != 0 || returnedAmount != 0, "Nothing to distribute" ); - address owner = owner(); emit CollectedReward( withdrawalCredential, returnedAmount, @@ -47,13 +40,55 @@ contract RewardsCollector is Ownable { }(""); require(sent, "Failed to send Ether back to withdrawal credential"); } +} + +contract RewardsCollector { + event CollectedReward( + address withdrawalCredential, + uint256 withdrawalFee, + address owner, + uint256 ownerFee + ); + + // 1 - fee % will go to the user in this address. + address public withdrawalCredential; + + // Fee's numerator. + uint32 public feeNumerator; + + // This is the contract that created the `RewardsCollector`. + // Do not use owner here because this contract is going to be + // created multiple times for each `withdrawal credential` and + // we don't need any function for the ownership except when changing + // the fee. + address public parentContract; + + // Fee denominator, if `feeNumerator = 500`, + // the tax is 500/10000 = 5/100 = 5%. + uint32 public constant FEE_DENOMINATOR = 10000; + + // Allow receiving MEV and other rewards. + receive() external payable {} + + function collectRewards() public payable { + CalculateAndSendRewards.calculateRewards( + feeNumerator, + parentContract, + withdrawalCredential + ); + } constructor(address _withdrawalCredential, uint32 _feeNumerator) { withdrawalCredential = _withdrawalCredential; feeNumerator = _feeNumerator; + parentContract = msg.sender; } - function changeFee(uint32 _newFeeNumerator) public onlyOwner { + function changeFeeNumerator(uint32 _newFeeNumerator) public { + require( + msg.sender == parentContract, + "ChangeFee not called from parent contract" + ); feeNumerator = _newFeeNumerator; } } @@ -88,6 +123,12 @@ contract FeeRewardsManager is Ownable { return payable(addr); } + // Predicts the address of a new contract that will be a `fee_recipient` of + // an Ethereum validator. + // Given the `_withdrawalCredential` we can instantiate a contract that will + // be deployed at a deterministic address, calculated given the + // `_withdrawalCredential`, the current contract address and the current + // contract's bytecode. function predictFeeContractAddress( address _withdrawalCredential ) public view returns (address) { @@ -110,11 +151,11 @@ contract FeeRewardsManager is Ownable { return address(uint160(uint(hash))); } - function changeFee( + function changeFeeNumerator( address payable _feeContract, uint32 _newFee ) public onlyOwner { - RewardsCollector(_feeContract).changeFee(_newFee); + RewardsCollector(_feeContract).changeFeeNumerator(_newFee); } function batchCollectRewards( diff --git a/test/FeeRewardsManager.t.sol b/test/FeeRewardsManager.t.sol index a8da436..a70324c 100644 --- a/test/FeeRewardsManager.t.sol +++ b/test/FeeRewardsManager.t.sol @@ -11,12 +11,6 @@ contract ReentrantAttack { } } -contract ChangeOwnerContract { - fallback() external payable { - RewardsCollector(payable(msg.sender)).transferOwnership(address(0x200)); - } -} - contract WithdrawalContract { fallback() external payable {} } @@ -47,7 +41,7 @@ contract FeeRewardsTest is Test { // derived address has the parent's as owner assertEq( address(feeRewardsManager), - RewardsCollector(payable(addr)).owner() + RewardsCollector(payable(addr)).parentContract() ); uint256 amountInContract = address(addr).balance; @@ -114,7 +108,7 @@ contract FeeRewardsTest is Test { address addr = address( createWithdrawalSimulateRewards(address(100), 10 ether) ); - feeRewardsManager.changeFee(payable(addr), 10000); + feeRewardsManager.changeFeeNumerator(payable(addr), 10000); RewardsCollector(payable(addr)).collectRewards(); assertEq(address(100).balance, 0 ether); // We receive 100%. @@ -145,7 +139,8 @@ contract FeeRewardsTest is Test { rewards ); uint256 chorusAmount = (address(collector).balance * - uint256(collector.feeNumerator())) / collector.FEE_DENOMINATOR(); + uint256(collector.feeNumerator())) / + CalculateAndSendRewards.FEE_DENOMINATOR; uint256 withdrawalCredentialsAmount = address(collector).balance - chorusAmount; uint256 chorusBalanceBefore = address(feeRewardsManager).balance;