Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use a library for reward collection #3

Merged
merged 3 commits into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 53 additions & 18 deletions src/FeeRewardsManager.sol
Original file line number Diff line number Diff line change
Expand Up @@ -3,36 +3,29 @@ pragma solidity ^0.8.13;

import "@openzeppelin/contracts/access/Ownable.sol";

contract RewardsCollector is Ownable {
library CalculateAndSendRewards {
enriquefynn marked this conversation as resolved.
Show resolved Hide resolved
// Fee denominator, if `feeNominator = 500`,
// the tax is 500/10000 = 5/100 = 5%.
uint32 public constant FEE_DENOMINATOR = 10000;
event CollectedReward(
address withdrawalCredential,
uint256 withdrawnAmount,
address owner,
uint256 ownerFee
);

// 1 - fee % will go to the user in this address.
address public withdrawalCredential;

// Fee's numerator.
uint32 public feeNumerator;

// Fee denominator, if `feeNumerator = 500`,
// the tax is 500/10000 = 5/100 = 5%.
uint32 public constant FEE_DENOMINATOR = 10000;

// Allow receiving MEV and other rewards.
receive() external payable {}

function collectRewards() public payable {
uint256 ownerAmount = (address(this).balance * feeNumerator) /
function calculateRewards(
uint32 feeNominator,
address owner,
address withdrawalCredential
) public {
uint256 ownerAmount = (address(this).balance * feeNominator) /
FEE_DENOMINATOR;
uint256 returnedAmount = address(this).balance - ownerAmount;
require(
ownerAmount != 0 || returnedAmount != 0,
"Nothing to distribute"
);
address owner = owner();
emit CollectedReward(
withdrawalCredential,
returnedAmount,
Expand All @@ -47,13 +40,55 @@ contract RewardsCollector is Ownable {
}("");
require(sent, "Failed to send Ether back to withdrawal credential");
}
}

contract RewardsCollector {
event CollectedReward(
address withdrawalCredential,
uint256 withdrawalFee,
address owner,
uint256 ownerFee
);

// 1 - fee % will go to the user in this address.
address public withdrawalCredential;

// Fee's numerator.
uint32 public feeNumerator;

// This is the contract that created the `RewardsCollector`.
// Do not use owner here because this contract is going to be
// created multiple times for each `withdrawal credential` and
// we don't need any function for the ownership except when changing
// the fee.
address public parentContract;

// Fee denominator, if `feeNumerator = 500`,
// the tax is 500/10000 = 5/100 = 5%.
uint32 public constant FEE_DENOMINATOR = 10000;

// Allow receiving MEV and other rewards.
receive() external payable {}

function collectRewards() public payable {
CalculateAndSendRewards.calculateRewards(
feeNumerator,
parentContract,
withdrawalCredential
);
}

constructor(address _withdrawalCredential, uint32 _feeNumerator) {
withdrawalCredential = _withdrawalCredential;
feeNumerator = _feeNumerator;
parentContract = msg.sender;
}

function changeFeeNumerator(uint32 _newFeeNumerator) public onlyOwner {
function changeFeeNumerator(uint32 _newFeeNumerator) public {
require(
msg.sender == parentContract,
"ChangeFee not called from parent contract"
);
feeNumerator = _newFeeNumerator;
}
}
Expand Down
9 changes: 6 additions & 3 deletions test/FeeRewardsManager.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ contract ReentrantAttack {

contract ChangeOwnerContract {
fallback() external payable {
RewardsCollector(payable(msg.sender)).transferOwnership(address(0x200));
FeeRewardsManager(payable(msg.sender)).transferOwnership(
address(0x200)
);
}
}

Expand Down Expand Up @@ -47,7 +49,7 @@ contract FeeRewardsTest is Test {
// derived address has the parent's as owner
assertEq(
address(feeRewardsManager),
RewardsCollector(payable(addr)).owner()
RewardsCollector(payable(addr)).parentContract()
);

uint256 amountInContract = address(addr).balance;
Expand Down Expand Up @@ -155,7 +157,8 @@ contract FeeRewardsTest is Test {
rewards
);
uint256 chorusAmount = (address(collector).balance *
uint256(collector.feeNumerator())) / collector.FEE_DENOMINATOR();
uint256(collector.feeNumerator())) /
CalculateAndSendRewards.FEE_DENOMINATOR;
uint256 withdrawalCredentialsAmount = address(collector).balance -
chorusAmount;
uint256 chorusBalanceBefore = address(feeRewardsManager).balance;
Expand Down