Skip to content

Commit

Permalink
fix: re-entrancy in Gateway (#264)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattstam authored Oct 30, 2023
1 parent c6cd74b commit c19242c
Show file tree
Hide file tree
Showing 4 changed files with 272 additions and 17 deletions.
15 changes: 13 additions & 2 deletions contracts/src/SuccinctGateway.sol
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,17 @@ contract SuccinctGateway is ISuccinctGateway, FunctionRegistry, TimelockedUpgrad
/// @dev A flag that indicates whether the contract is currently making a callback.
bool public isCallback;

/// @dev Protects functions from being re-entered during a fullfil call.
modifier nonReentrant() {
if (
isCallback || verifiedFunctionId != bytes32(0) || verifiedInputHash != bytes32(0)
|| verifiedOutput.length != 0
) {
revert ReentrantFulfill();
}
_;
}

/// @dev Initializes the contract.
/// @param _feeVault The address of the fee vault.
/// @param _timelock The address of the timelock contract.
Expand Down Expand Up @@ -162,7 +173,7 @@ contract SuccinctGateway is ISuccinctGateway, FunctionRegistry, TimelockedUpgrad
bytes memory _context,
bytes memory _output,
bytes memory _proof
) external {
) external nonReentrant {
// Reconstruct the callback hash.
bytes32 contextHash = keccak256(_context);
bytes32 requestHash = _requestHash(
Expand Down Expand Up @@ -218,7 +229,7 @@ contract SuccinctGateway is ISuccinctGateway, FunctionRegistry, TimelockedUpgrad
bytes memory _proof,
address _callbackAddress,
bytes memory _callbackData
) external {
) external nonReentrant {
// Compute the input and output hashes.
bytes32 inputHash = sha256(_input);
bytes32 outputHash = sha256(_output);
Expand Down
1 change: 1 addition & 0 deletions contracts/src/interfaces/ISuccinctGateway.sol
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ interface ISuccinctGatewayErrors {
error InvalidCall(bytes32 functionId, bytes input);
error CallFailed(address callbackAddress, bytes callbackData);
error InvalidProof(address verifier, bytes32 inputHash, bytes32 outputHash, bytes proof);
error ReentrantFulfill();
}

interface ISuccinctGateway is ISuccinctGatewayEvents, ISuccinctGatewayErrors {
Expand Down
141 changes: 130 additions & 11 deletions contracts/test/SuccinctGateway.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import {
ISuccinctGatewayErrors
} from "src/interfaces/ISuccinctGateway.sol";
import {IFunctionRegistry} from "src/interfaces/IFunctionRegistry.sol";
import {TestConsumer, TestFunctionVerifier} from "test/TestUtils.sol";
import {TestConsumer, AttackConsumer, TestFunctionVerifier} from "test/TestUtils.sol";
import {Proxy} from "src/upgrades/Proxy.sol";
import {SuccinctFeeVault} from "src/payments/SuccinctFeeVault.sol";

Expand All @@ -35,7 +35,7 @@ contract SuccinctGatewayTest is Test, ISuccinctGatewayEvents, ISuccinctGatewayEr
address payable internal sender;
address internal owner;

function setUp() public {
function setUp() public virtual {
// Init variables
timelock = makeAddr("timelock");
guardian = makeAddr("guardian");
Expand All @@ -52,6 +52,7 @@ contract SuccinctGatewayTest is Test, ISuccinctGatewayEvents, ISuccinctGatewayEr
gateway = address(new Proxy(gatewayImpl, ""));
SuccinctGateway(gateway).initialize(feeVault, timelock, guardian);

// Deploy Verifier
bytes32 functionId;
vm.prank(sender);
(functionId, verifier) = IFunctionRegistry(gateway).deployAndRegisterFunction(
Expand Down Expand Up @@ -250,14 +251,14 @@ contract SuccinctGatewayTest is Test, ISuccinctGatewayEvents, ISuccinctGatewayEr
bytes memory output = OUTPUT;
bytes memory proof = PROOF;
address callAddress = consumer;
bytes memory callData = abi.encodeWithSelector(TestConsumer.handleCall.selector, OUTPUT, 0);
bytes memory callData = abi.encodeWithSelector(TestConsumer.handleCall.selector);
uint32 callGasLimit = TestConsumer(consumer).CALLBACK_GAS_LIMIT();
uint256 fee = DEFAULT_FEE;

// Request
vm.expectEmit(true, true, true, true, gateway);
emit RequestCall(functionId, input, callAddress, callData, callGasLimit, consumer, fee);
TestConsumer(consumer).requestCall{value: fee}(input, callData);
TestConsumer(consumer).requestCall{value: fee}(input);

assertEq(TestConsumer(consumer).handledRequests(0), false);

Expand All @@ -277,14 +278,14 @@ contract SuccinctGatewayTest is Test, ISuccinctGatewayEvents, ISuccinctGatewayEr
bytes memory output = OUTPUT;
bytes memory proof = PROOF;
address callAddress = consumer;
bytes memory callData = abi.encodeWithSelector(TestConsumer.handleCall.selector, OUTPUT, 0);
bytes memory callData = abi.encodeWithSelector(TestConsumer.handleCall.selector);
uint32 callGasLimit = TestConsumer(consumer).CALLBACK_GAS_LIMIT();
uint256 fee = 0;

// Request
vm.expectEmit(true, true, true, true, gateway);
emit RequestCall(functionId, input, callAddress, callData, callGasLimit, consumer, fee);
TestConsumer(consumer).requestCall{value: fee}(input, callData);
TestConsumer(consumer).requestCall{value: fee}(input);

assertEq(TestConsumer(consumer).handledRequests(0), false);

Expand All @@ -307,14 +308,14 @@ contract SuccinctGatewayTest is Test, ISuccinctGatewayEvents, ISuccinctGatewayEr
bytes memory output = OUTPUT;
bytes memory proof = PROOF;
address callAddress = consumer;
bytes memory callData = abi.encodeWithSelector(TestConsumer.handleCall.selector, OUTPUT, 0);
bytes memory callData = abi.encodeWithSelector(TestConsumer.handleCall.selector);
uint32 callGasLimit = TestConsumer(consumer).CALLBACK_GAS_LIMIT();
uint256 fee = DEFAULT_FEE;

// Request
vm.expectEmit(true, true, true, true, gateway);
emit RequestCall(functionId, input, callAddress, callData, callGasLimit, consumer, fee);
TestConsumer(consumer).requestCall{value: fee}(input, callData);
TestConsumer(consumer).requestCall{value: fee}(input);

assertEq(TestConsumer(consumer).handledRequests(0), false);

Expand All @@ -334,7 +335,7 @@ contract SuccinctGatewayTest is Test, ISuccinctGatewayEvents, ISuccinctGatewayEr
bytes memory output = OUTPUT;
bytes memory proof = PROOF;
address callAddress = consumer;
bytes memory callData = abi.encodeWithSelector(TestConsumer.handleCall.selector, OUTPUT, 0);
bytes memory callData = abi.encodeWithSelector(TestConsumer.handleCall.selector);

// Fulfill
vm.expectRevert();
Expand Down Expand Up @@ -372,7 +373,7 @@ contract SuccinctGatewayTest is Test, ISuccinctGatewayEvents, ISuccinctGatewayEr
bytes32 functionId = TestConsumer(consumer).FUNCTION_ID();
bytes memory input = INPUT;
address callAddress = consumer;
bytes memory callData = abi.encodeWithSelector(TestConsumer.handleCall.selector, OUTPUT, 0);
bytes memory callData = abi.encodeWithSelector(TestConsumer.handleCall.selector);
uint32 callGasLimit = TestConsumer(consumer).CALLBACK_GAS_LIMIT();
uint256 fee = DEFAULT_FEE;
address newFeeVault = address(new SuccinctFeeVault());
Expand All @@ -388,7 +389,7 @@ contract SuccinctGatewayTest is Test, ISuccinctGatewayEvents, ISuccinctGatewayEr
// Request with fee
vm.expectEmit(true, true, true, true, gateway);
emit RequestCall(functionId, input, callAddress, callData, callGasLimit, consumer, fee);
TestConsumer(consumer).requestCall{value: fee}(input, callData);
TestConsumer(consumer).requestCall{value: fee}(input);
}

function test_RevertSetFeeVault_WhenNotGuardian() public {
Expand All @@ -399,3 +400,121 @@ contract SuccinctGatewayTest is Test, ISuccinctGatewayEvents, ISuccinctGatewayEr
SuccinctGateway(gateway).setFeeVault(newFeeVault);
}
}

contract AttackSuccinctGateway is SuccinctGatewayTest {
address payable internal attackConsumer;

function setUp() public override {
super.setUp();

// Deploy Verifier
bytes32 functionId;
vm.prank(sender);
(functionId, verifier) = IFunctionRegistry(gateway).deployAndRegisterFunction(
owner, type(TestFunctionVerifier).creationCode, "attack-verifier"
);

// Deploy AttackConsumer
attackConsumer = payable(address(new AttackConsumer(gateway, functionId)));

vm.deal(attackConsumer, DEFAULT_FEE);
}

function test_RevertCallbackReenterCallback() public {
bytes memory input = INPUT;
bytes memory output = OUTPUT;
bytes memory proof = PROOF;
bytes32 functionId = AttackConsumer(attackConsumer).FUNCTION_ID();
bytes32 inputHash = INPUT_HASH;
address callbackAddress = attackConsumer;
bytes4 callbackSelector = AttackConsumer.handleCallbackReenterCallback.selector;
uint32 callbackGasLimit = AttackConsumer(attackConsumer).CALLBACK_GAS_LIMIT();
uint256 fee = DEFAULT_FEE;

// Request
vm.prank(sender);
AttackConsumer(attackConsumer).requestCallbackReenterCallback{value: fee}(input);

// Fulfill (test fails this doesn't revert with ReentrantFulfill() error)
SuccinctGateway(gateway).fulfillCallback(
0,
functionId,
inputHash,
callbackAddress,
callbackSelector,
callbackGasLimit,
"",
output,
proof
);
}

function test_RevertCallbackReenterCall() public {
bytes memory input = INPUT;
bytes memory output = OUTPUT;
bytes memory proof = PROOF;
bytes32 functionId = AttackConsumer(attackConsumer).FUNCTION_ID();
bytes32 inputHash = INPUT_HASH;
address callbackAddress = attackConsumer;
bytes4 callbackSelector = AttackConsumer.handleCallbackReenterCall.selector;
uint32 callbackGasLimit = AttackConsumer(attackConsumer).CALLBACK_GAS_LIMIT();
uint256 fee = DEFAULT_FEE;

// Request
vm.prank(sender);
AttackConsumer(attackConsumer).requestCallbackReenterCall{value: fee}(input);

// Fulfill (test fails this doesn't revert with ReentrantFulfill() error)
SuccinctGateway(gateway).fulfillCallback(
0,
functionId,
inputHash,
callbackAddress,
callbackSelector,
callbackGasLimit,
"",
output,
proof
);
}

function test_RevertCallReenterCallback() public {
bytes memory input = INPUT;
bytes memory output = OUTPUT;
bytes memory proof = PROOF;
bytes32 functionId = AttackConsumer(attackConsumer).FUNCTION_ID();
address callAddress = attackConsumer;
bytes memory callData =
abi.encodeWithSelector(AttackConsumer.handleCallReenterCallback.selector);
uint256 fee = DEFAULT_FEE;

// Request
vm.prank(sender);
AttackConsumer(attackConsumer).requestCallReenterCallback{value: fee}(input);

// Fulfill (test fails this doesn't revert with ReentrantFulfill() error)
SuccinctGateway(gateway).fulfillCall(
functionId, input, output, proof, callAddress, callData
);
}

function test_RevertCallReenterCall() public {
bytes memory input = INPUT;
bytes memory output = OUTPUT;
bytes memory proof = PROOF;
bytes32 functionId = AttackConsumer(attackConsumer).FUNCTION_ID();
address callAddress = attackConsumer;
bytes memory callData =
abi.encodeWithSelector(AttackConsumer.handleCallReenterCall.selector);
uint256 fee = DEFAULT_FEE;

// Request
vm.prank(sender);
AttackConsumer(attackConsumer).requestCallReenterCall{value: fee}(input);

// Fulfill (test fails this doesn't revert with ReentrantFulfill() error)
SuccinctGateway(gateway).fulfillCall(
functionId, input, output, proof, callAddress, callData
);
}
}
Loading

0 comments on commit c19242c

Please sign in to comment.