Skip to content

Commit

Permalink
Fix Self Destruct issue in StakingNode.sol
Browse files Browse the repository at this point in the history
  • Loading branch information
xhad committed Mar 29, 2024
1 parent 707f357 commit f87df07
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 56 deletions.
27 changes: 11 additions & 16 deletions src/StakingNode.sol
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import {IBeacon} from "@openzeppelin/contracts/proxy/beacon/IBeacon.sol";
import {IEigenPodManager} from "./external/eigenlayer/v0.1.0/interfaces/IEigenPodManager.sol";
import {IEigenPod} from "./external/eigenlayer/v0.1.0/interfaces/IEigenPod.sol";
import {IDelegationManager} from "./external/eigenlayer/v0.1.0/interfaces/IDelegationManager.sol";
import {IDelayedWithdrawalRouter} from "./external/eigenlayer/v0.1.0/interfaces/IDelayedWithdrawalRouter.sol";
import {IStrategy, IStrategyManager} from "./external/eigenlayer/v0.1.0/interfaces/IStrategyManager.sol";
import {BeaconChainProofs} from "./external/eigenlayer/v0.1.0/BeaconChainProofs.sol";
import {IStakingNodesManager} from "./interfaces/IStakingNodesManager.sol";
Expand Down Expand Up @@ -36,7 +35,6 @@ contract StakingNode is IStakingNode, StakingNodeEvents, ReentrancyGuardUpgradea
error NotStakingNodesAdmin();
error ETHDepositorNotDelayedWithdrawalRouter();
error WithdrawalPrincipalAmountTooHigh(uint256 withdrawnValidatorPrincipal, uint256 allocatedETH);
error ValidatorPrincipalExceedsTotalClaimable(uint256 withdrawnValidatorPrincipal, uint256 claimableAmount);
error ClaimAmountTooLow(uint256 expected, uint256 actual);
error ZeroAddress();
error NotStakingNodesManager();
Expand Down Expand Up @@ -147,27 +145,24 @@ contract StakingNode is IStakingNode, StakingNodeEvents, ReentrancyGuardUpgradea
uint256 expectedETHBalance
) public nonReentrant onlyAdmin {

uint256 claimableAmount = address(this).balance;
uint256 balance = address(this).balance;

if (totalValidatorPrincipal > allocatedETH) {
revert WithdrawalPrincipalAmountTooHigh(totalValidatorPrincipal, allocatedETH);
// check for any race conditions with balances by passing in the expected balance
if (balance != expectedETHBalance) {
revert UnexpectedETHBalance(balance, expectedETHBalance);
}

if (totalValidatorPrincipal > claimableAmount) {
revert ValidatorPrincipalExceedsTotalClaimable(totalValidatorPrincipal, claimableAmount);
// check the desired balance of validator principal is available here
if (balance < totalValidatorPrincipal) {
revert WithdrawalPrincipalAmountTooHigh(totalValidatorPrincipal, balance);
}

// This check ensures that the actual balance of the contract matches the expected balance after withdrawals.
// Ensures that the totalValidatorPrincipal is not out of sync with the address(this).balance
// by the time it reaches the on-chain
if (expectedETHBalance != claimableAmount) {
revert UnexpectedETHBalance(claimableAmount, expectedETHBalance);
}
// substract validator principal
// substract withdrawn validator principal from the allocated balance
allocatedETH -= totalValidatorPrincipal;

stakingNodesManager.processWithdrawnETH{value: claimableAmount}(nodeId, totalValidatorPrincipal);
emit WithdrawalsProcessed(claimableAmount, totalValidatorPrincipal, allocatedETH);
// push the entire balance here to the StakingNodesManager
stakingNodesManager.processWithdrawnETH{value: balance}(nodeId, totalValidatorPrincipal);
emit WithdrawalsProcessed(balance, totalValidatorPrincipal, allocatedETH);
}


Expand Down
2 changes: 1 addition & 1 deletion src/interfaces/IynETH.sol
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import {IERC20} from "@openzeppelin/contracts/token/ERC20/IERC20.sol";
import {IERC20Permit} from "@openzeppelin/contracts/token/ERC20/extensions/IERC20Permit.sol";

interface IynETH is IERC20 {
function withdrawETH(uint ethAmount) external;
function withdrawETH(uint256 ethAmount) external;
function processWithdrawnETH() external payable;
function receiveRewards() external payable;
function updateDepositsPaused(bool paused) external;
Expand Down
34 changes: 1 addition & 33 deletions test/foundry/integration/StakingNode.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ contract StakingNodeEigenPod is StakingNodeTestBase {
rewardsDistributor.processRewards();

uint256 fee = uint256(rewardsDistributor.feesBasisPoints());
uint finalRewardsReceived = rewardsAmount - (rewardsAmount * fee / 10000);
uint256 finalRewardsReceived = rewardsAmount - (rewardsAmount * fee / 10000);

// Assert total assets after claiming delayed withdrawals
uint256 totalAssets = yneth.totalAssets();
Expand Down Expand Up @@ -227,38 +227,6 @@ contract StakingNodeWithdrawWithoutRestaking is StakingNodeTestBase {
assertEq(rewardsAmount, expectedRewards, "Rewards amount does not match expected value");
}

function testValidatorPrincipalExceedsTotalClaimable() public {

uint256 activeValidators = 5;

uint256 depositAmount = activeValidators * 32 ether;
uint256 validatorPrincipal = depositAmount; // Total principal for all validators

(IStakingNode stakingNodeInstance, IEigenPod eigenPodInstance) = setupStakingNode(depositAmount);

// Simulate rewards being sweeped into the StakingNode's balance
uint256 rewardsSweeped = 3 * 32 ether;
address payable eigenPodAddress = payable(address(eigenPodInstance));
vm.deal(eigenPodAddress, rewardsSweeped);

// Trigger withdraw before restaking successfully
vm.prank(actors.STAKING_NODES_ADMIN);
stakingNodeInstance.withdrawBeforeRestaking();

// Simulate time passing for withdrawal delay
IDelayedWithdrawalRouter delayedWithdrawalRouter = stakingNodesManager.delayedWithdrawalRouter();
vm.roll(block.number + delayedWithdrawalRouter.withdrawalDelayBlocks() + 1);

delayedWithdrawalRouter.claimDelayedWithdrawals(address(stakingNodeInstance), type(uint256).max);

uint256 tooLargeValidatorPrincipal = validatorPrincipal;

// Attempt to claim withdrawals with a validator principal that exceeds total claimable amount
vm.prank(actors.STAKING_NODES_ADMIN);
vm.expectRevert(abi.encodeWithSelector(StakingNode.ValidatorPrincipalExceedsTotalClaimable.selector, tooLargeValidatorPrincipal, rewardsSweeped));
stakingNodeInstance.processWithdrawals(tooLargeValidatorPrincipal, address(stakingNodeInstance).balance);
}

function testWithdrawalPrincipalAmountTooHigh() public {

uint256 activeValidators = 5;
Expand Down
91 changes: 85 additions & 6 deletions test/foundry/scenarios/ynETH.spec.sol
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,8 @@ contract YnETHScenarioTest3 is IntegrationBaseTest {
}
}

event LogUint(string message, uint256 value);

contract YnETHScenarioTest8 is IntegrationBaseTest, YnETHScenarioTest3 {

/**
Expand Down Expand Up @@ -348,15 +350,19 @@ contract YnETHScenarioTest10 is IntegrationBaseTest, YnETHScenarioTest3 {
*/

function test_ynETH_Scenario_9_Self_Destruct_Attack() public {

address sender = address(this);

uint256 previousTotalDeposited = yneth.totalDepositedInPool();
uint256 previousTotalShares = yneth.totalSupply();


// Deposit 32 ETH to ynETH and create a Staking Node with a Validator
(IStakingNode stakingNode,) = depositEth_and_createValidator();

// Amount of ether to send via self-destruct
uint256 amountToSend = 1 ether;
// Ensure the test contract has enough ether to send
vm.deal(sender, amountToSend);

// Ensure the test contract has enough ether to send, user1 comes from Test3
vm.deal(user1, amountToSend);

// Address to send ether to - for example, the stakingNode or another address
address payable target = payable(address(stakingNode)); // or any other target address
Expand All @@ -365,9 +371,82 @@ contract YnETHScenarioTest10 is IntegrationBaseTest, YnETHScenarioTest3 {
// The SelfDestructSender contract is created with the amountToSend and immediately self-destructs,
// sending its balance to the target address.
address(new SelfDestructSender{value: amountToSend}(target));

log_balances(stakingNode);

assertEq(address(yneth).balance, 0, "yneth.balance != 0");
assertEq(address(stakingNode).balance, 1 ether, "stakingNode.balance != 1 ether");
assertEq(address(consensusLayerReceiver).balance, 0, "consensusLayerReceiver.balance != 0");
assertEq(address(executionLayerReceiver).balance, 0, "executionLayerReceiver.balance != 0");

vm.startPrank(actors.STAKING_NODES_ADMIN);
withdraw_principal(stakingNode);
stakingNode.processWithdrawals(32 ether, 33 ether + 1 wei);
vm.stopPrank();

vm.prank(actors.STAKING_NODES_ADMIN);
stakingNode.processWithdrawals(32 ether, address(stakingNode).balance);
log_balances(stakingNode);

assertEq(address(yneth).balance, 32 ether, "yneth.balance != 32 ether");
assertEq(address(stakingNode).balance, 0, "stakingNode.balance != 0");
assertEq(address(consensusLayerReceiver).balance, 1 ether + 1 wei, "consensusLayerReceiver.balance != 0");
assertEq(address(executionLayerReceiver).balance, 0, "executionLayerReceiver.balance != 0");

uint256 userAmount = 32 ether;
uint256 userShares = yneth.balanceOf(user1);

runInvariants(
user1,
previousTotalDeposited,
previousTotalShares,
userAmount,
userShares
);


}

function withdraw_principal(IStakingNode stakingNode) public {

// send concensus rewards to eigen pod
uint256 amount = 32 ether + 1 wei;
IEigenPod eigenPod = IEigenPod(stakingNode.eigenPod());
uint256 initialPodBalance = address(eigenPod).balance;
vm.deal(address(eigenPod), amount);
assertEq(address(eigenPod).balance, initialPodBalance + amount);

stakingNode.withdrawBeforeRestaking();

// There should be a delayedWithdraw on the DelayedWithdrawalRouter
IDelayedWithdrawalRouter withdrawalRouter = IDelayedWithdrawalRouter(chainAddresses.eigenlayer.DELAYED_WITHDRAWAL_ROUTER_ADDRESS);
IDelayedWithdrawalRouter.DelayedWithdrawal[] memory delayedWithdrawals = withdrawalRouter.getUserDelayedWithdrawals(address(stakingNode));
assertEq(delayedWithdrawals.length, 1);
assertEq(delayedWithdrawals[0].amount, amount);

// Because of the delay, the delayedWithdrawal should not be claimable yet
IDelayedWithdrawalRouter.DelayedWithdrawal[] memory claimableDelayedWithdrawals = withdrawalRouter.getClaimableUserDelayedWithdrawals(address(stakingNode));
assertEq(claimableDelayedWithdrawals.length, 0);

// Move ahead in time to make the delayedWithdrawal claimable
vm.roll(block.number + withdrawalRouter.withdrawalDelayBlocks() + 1);
IDelayedWithdrawalRouter.DelayedWithdrawal[] memory claimableDelayedWithdrawalsWarp = withdrawalRouter.getClaimableUserDelayedWithdrawals(address(stakingNode));
assertEq(claimableDelayedWithdrawalsWarp.length, 1);
assertEq(claimableDelayedWithdrawalsWarp[0].amount, amount, "claimableDelayedWithdrawalsWarp[0].amount != 3 ether");

withdrawalRouter.claimDelayedWithdrawals(address(stakingNode), type(uint256).max);
}

function log_balances (IStakingNode stakingNode) public {
emit LogUint("yneth.balance", address(yneth).balance);
emit LogUint("stakingNode.balance", address(stakingNode).balance);
emit LogUint("consensusReciever.balance", address(consensusLayerReceiver).balance);
emit LogUint("executionReciever.balance", address(executionLayerReceiver).balance);
}

function runInvariants(address user, uint256 previousTotalDeposited, uint256 previousTotalShares, uint256 userAmount, uint256 userShares) public view {
Invariants.totalDepositIntegrity(yneth.totalDepositedInPool(), previousTotalDeposited, userAmount);
Invariants.totalAssetsIntegrity(yneth.totalAssets(), previousTotalDeposited, userAmount);
Invariants.shareMintIntegrity(yneth.totalSupply(), previousTotalShares, userShares);
Invariants.userSharesIntegrity(yneth.balanceOf(user), 0, userShares);
}
}

Expand Down

0 comments on commit f87df07

Please sign in to comment.