Skip to content
This repository has been archived by the owner on May 23, 2023. It is now read-only.

Commit

Permalink
Negative test for a reentrant attack on the core relayer forward mech…
Browse files Browse the repository at this point in the history
…anism (#83)

* Modifies the relayer simulation to be easier to use in negative tests.

* Adds negative test for a reentrancy attack on the forward mechanism.

* `forge fmt` run.
  • Loading branch information
scnale committed Feb 10, 2023
1 parent 612b159 commit 1ce645d
Show file tree
Hide file tree
Showing 2 changed files with 193 additions and 10 deletions.
66 changes: 66 additions & 0 deletions ethereum/contracts/mock/AttackForwardIntegration.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.17;

import "@openzeppelin/contracts/token/ERC20/ERC20.sol";

import "../interfaces/IWormhole.sol";
import "../interfaces/IWormholeReceiver.sol";
import "../interfaces/ICoreRelayer.sol";

/**
* This contract is a malicious "integration" that attempts to attack the forward mechanism.
*/
contract AttackForwardIntegration is IWormholeReceiver {
mapping(bytes32 => bool) consumedMessages;
address attackerReward;
IWormhole wormhole;
ICoreRelayer core_relayer;
uint32 nonce = 1;
uint16 targetChainId;

// Capture 30k gas for fees
// This just needs to be enough to pay for the call to the destination address.
uint32 SAFE_DELIVERY_GAS_CAPTURE = 30000;

constructor(IWormhole initWormhole, ICoreRelayer initCoreRelayer, uint16 chainId, address initAttackerReward) {
attackerReward = initAttackerReward;
wormhole = initWormhole;
core_relayer = initCoreRelayer;
targetChainId = chainId;
}

// This is the function which receives all messages from the remote contracts.
function receiveWormholeMessages(bytes[] memory vaas, bytes[] memory additionalData) public payable override {
// Do nothing. The attacker doesn't care about this message; he sends it himself.
}

receive() external payable {
// Request forward from the relayer network
// The core relayer could in principle accept the request due to this being the target of the message at the same time as being the refund address.
// Note that, if succesful, this forward request would be processed after the time for processing forwards is past.
// Thus, the request would "linger" in the forward request cache and be attended to in the next delivery.
requestForward(targetChainId, toWormholeFormat(attackerReward));
}

function requestForward(uint16 targetChain, bytes32 attackerRewardAddress) internal {
uint256 computeBudget = core_relayer.quoteGasDeliveryFee(
targetChain, SAFE_DELIVERY_GAS_CAPTURE, core_relayer.getDefaultRelayProvider()
);

ICoreRelayer.DeliveryRequest memory request = ICoreRelayer.DeliveryRequest({
targetChain: targetChain,
targetAddress: attackerRewardAddress,
// All remaining funds will be returned to the attacker
refundAddress: attackerRewardAddress,
computeBudget: computeBudget,
applicationBudget: 0,
relayParameters: core_relayer.getDefaultRelayParams()
});

core_relayer.requestForward{value: computeBudget}(request, nonce, core_relayer.getDefaultRelayProvider());
}

function toWormholeFormat(address addr) public pure returns (bytes32 whFormat) {
return bytes32(uint256(uint160(addr)));
}
}
137 changes: 127 additions & 10 deletions ethereum/forge-test/CoreRelayer.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import {Wormhole} from "../wormhole/ethereum/contracts/Wormhole.sol";
import {IWormhole} from "../contracts/interfaces/IWormhole.sol";
import {WormholeSimulator} from "./WormholeSimulator.sol";
import {IWormholeReceiver} from "../contracts/interfaces/IWormholeReceiver.sol";
import {AttackForwardIntegration} from "../contracts/mock/AttackForwardIntegration.sol";
import {MockRelayerIntegration} from "../contracts/mock/MockRelayerIntegration.sol";
import "../contracts/libraries/external/BytesLib.sol";

Expand Down Expand Up @@ -495,6 +496,106 @@ contract TestCoreRelayer is Test {
assertTrue(keccak256(setup.source.integration.getMessage()) == keccak256(bytes("received!")));
}

function testAttackForwardRequestCache(GasParameters memory gasParams, FeeParameters memory feeParams) public {
// General idea:
// 1. Attacker sets up a malicious integration contract in the target chain.
// 2. Attacker requests a message send to `target` chain.
// The message destination and the refund address are both the malicious integration contract in the target chain.
// 3. The delivery of the message triggers a refund to the malicious integration contract.
// 4. During the refund, the integration contract activates the forwarding mechanism.
// This is allowed due to the integration contract also being the target of the delivery.
// 5. The forward request is left as is in the `CoreRelayer` state.
// 6. The next message (i.e. the victim's message) delivery on `target` chain, from any relayer, using any `RelayProvider` and any integration contract,
// will see the forward request placed by the malicious integration contract and act on it.
// Caveat: the delivery of the victim's message must not invoke the forwarding mechanism for the attack test to be meaningful.
//
// In essence, this tries to attack the shared forwarding request cache present in the contract state.
// This attack doesn't work thanks to the check inside the `requestForward` function that only allows requesting a forward when there is a delivery being processed.

StandardSetupTwoChains memory setup = standardAssumeAndSetupTwoChains(gasParams, feeParams, 1000000);

// Collected funds from the attack are meant to be sent here.
address attackerSourceAddress =
address(uint160(uint256(keccak256(abi.encodePacked(bytes("attackerAddress"), setup.sourceChainId)))));
assertTrue(attackerSourceAddress.balance == 0);

// Borrowed assumes from testForward. They should help since this test is similar.
vm.assume(
uint256(1) * gasParams.targetGasPrice * feeParams.targetNativePrice
> uint256(1) * gasParams.sourceGasPrice * feeParams.sourceNativePrice
);

vm.assume(
setup.source.coreRelayer.quoteGasDeliveryFee(
setup.targetChainId, gasParams.targetGasLimit, setup.source.relayProvider
) < uint256(2) ** 222
);
vm.assume(
setup.target.coreRelayer.quoteGasDeliveryFee(setup.sourceChainId, 500000, setup.target.relayProvider)
< uint256(2) ** 222 / feeParams.targetNativePrice
);

// Estimate the cost based on the initialized values
uint256 computeBudget = setup.source.coreRelayer.quoteGasDeliveryFee(
setup.targetChainId, gasParams.targetGasLimit, setup.source.relayProvider
);

{
AttackForwardIntegration attackerContract =
new AttackForwardIntegration(setup.target.wormhole, setup.target.coreRelayer, setup.targetChainId, attackerSourceAddress);
bytes memory attackMsg = "attack";

vm.recordLogs();

// The attacker requests the message to be sent to the malicious contract.
// It is critical that the refund and destination (aka integrator) addresses are the same.
setup.source.integration.sendMessage{value: computeBudget + 2 * setup.source.wormhole.messageFee()}(
attackMsg, setup.targetChainId, address(attackerContract), address(attackerContract)
);

// The relayer triggers the call to the malicious contract.
genericRelayer(setup.sourceChainId, 2);

// The message delivery should fail
assertTrue(keccak256(setup.target.integration.getMessage()) != keccak256(attackMsg));
}

{
// Now one victim sends their message. It doesn't need to be from the same source chain.
// What's necessary is that a message is delivered to the chain targeted by the attacker.
bytes memory victimMsg = "relay my message";

uint256 victimBalancePreDelivery = setup.target.refundAddress.balance;

// We will reutilize the compute budget estimated for the attacker to simplify the code here.
// The victim requests their message to be sent.
setup.source.integration.sendMessage{value: computeBudget + 2 * setup.source.wormhole.messageFee()}(
victimMsg, setup.targetChainId, address(setup.target.integration), address(setup.target.refundAddress)
);

// The relayer delivers the victim's message.
// During the delivery process, the forward request injected by the malicious contract is acknowledged.
// The victim's refund address is not called due to this.
genericRelayer(setup.sourceChainId, 2);

// Ensures the message was received.
assertTrue(keccak256(setup.target.integration.getMessage()) == keccak256(victimMsg));
// Here we assert that the victim's refund is safe.
assertTrue(victimBalancePreDelivery < setup.target.refundAddress.balance);
}

Vm.Log[] memory entries = relayerWormholeSimulator.fetchWormholeMessageFromLog(vm.getRecordedLogs());
if (entries.length > 0) {
// There was a wormhole message produced.
// If the attack is successful this is a forward.
// We'll invoke the relay simulation here and later assert that the attack wasn't successful.
// Relay from target chain to source chain.
genericRelayerProcessLogs(setup.targetChainId, entries);
}
// Assert that the attack wasn't successful.
assertTrue(attackerSourceAddress.balance == 0);
}

function testRedelivery(GasParameters memory gasParams, FeeParameters memory feeParams, bytes memory message)
public
{
Expand Down Expand Up @@ -1219,18 +1320,34 @@ contract TestCoreRelayer is Test {
mapping(bytes32 => ICoreRelayer.TargetDeliveryParametersSingle) pastDeliveries;

function genericRelayer(uint16 chainId, uint8 num) internal {
bytes[] memory encodedVMs = new bytes[](num);
{
// Filters all events to just the wormhole messages.
Vm.Log[] memory entries = relayerWormholeSimulator.fetchWormholeMessageFromLog(vm.getRecordedLogs());
assertTrue(entries.length >= num);
for (uint256 i = 0; i < num; i++) {
encodedVMs[i] = relayerWormholeSimulator.fetchSignedMessageFromLogs(
entries[i], chainId, address(uint160(uint256(bytes32(entries[i].topics[1]))))
);
}
Vm.Log[] memory entries = truncateRecordedLogs(chainId, num);
genericRelayerProcessLogs(chainId, entries);
}

/**
* Discards wormhole events beyond `num` events.
* Expects at least `num` wormhole events.
*/
function truncateRecordedLogs(uint16 chainId, uint8 num) internal returns (Vm.Log[] memory) {
// Filters all events to just the wormhole messages.
Vm.Log[] memory entries = relayerWormholeSimulator.fetchWormholeMessageFromLog(vm.getRecordedLogs());
// We expect at least `num` events.
assertTrue(entries.length >= num);

Vm.Log[] memory firstEntries = new Vm.Log[](num);
for (uint256 i = 0; i < num; i++) {
firstEntries[i] = entries[i];
}
return firstEntries;
}

function genericRelayerProcessLogs(uint16 chainId, Vm.Log[] memory entries) internal {
bytes[] memory encodedVMs = new bytes[](entries.length);
for (uint256 i = 0; i < encodedVMs.length; i++) {
encodedVMs[i] = relayerWormholeSimulator.fetchSignedMessageFromLogs(
entries[i], chainId, address(uint160(uint256(bytes32(entries[i].topics[1]))))
);
}
IWormhole.VM[] memory parsed = new IWormhole.VM[](encodedVMs.length);
for (uint16 i = 0; i < encodedVMs.length; i++) {
parsed[i] = relayerWormhole.parseVM(encodedVMs[i]);
Expand Down

0 comments on commit 1ce645d

Please sign in to comment.