diff --git a/contracts/core/Dispatcher.sol b/contracts/core/Dispatcher.sol index aa22b5cd..0979cb7d 100644 --- a/contracts/core/Dispatcher.sol +++ b/contracts/core/Dispatcher.sol @@ -113,8 +113,10 @@ contract Dispatcher is OwnableUpgradeable, UUPSUpgradeable, IDispatcher { revert IBCErrors.invalidCounterPartyPortId(); } + // Have to encode here to avoid stack-too-deep error + bytes memory chanOpenInitArgs = abi.encode(ordering, connectionHops, counterpartyPortId, version); (bool success, bytes memory data) = - _callIfContract(msg.sender, abi.encodeWithSelector(IbcChannelReceiver.onChanOpenInit.selector, version)); + _callIfContract(msg.sender, bytes.concat(IbcChannelReceiver.onChanOpenInit.selector, chanOpenInitArgs)); if (success) { emit ChannelOpenInit( @@ -150,7 +152,16 @@ contract Dispatcher is OwnableUpgradeable, UUPSUpgradeable, IDispatcher { address receiver = _getAddressFromPort(local.portId); (bool success, bytes memory data) = _callIfContract( - receiver, abi.encodeWithSelector(IbcChannelReceiver.onChanOpenTry.selector, counterparty.version) + receiver, + abi.encodeWithSelector( + IbcChannelReceiver.onChanOpenTry.selector, + ordering, + connectionHops, + local.channelId, + counterparty.portId, + counterparty.channelId, + counterparty.version + ) ); if (success) { @@ -189,7 +200,9 @@ contract Dispatcher is OwnableUpgradeable, UUPSUpgradeable, IDispatcher { address receiver = _getAddressFromPort(local.portId); (bool success, bytes memory data) = _callIfContract( receiver, - abi.encodeWithSelector(IbcChannelReceiver.onChanOpenAck.selector, local.channelId, counterparty.version) + abi.encodeWithSelector( + IbcChannelReceiver.onChanOpenAck.selector, local.channelId, counterparty.channelId, counterparty.version + ) ); if (success) { @@ -282,7 +295,8 @@ contract Dispatcher is OwnableUpgradeable, UUPSUpgradeable, IDispatcher { // // // confirm with dApp by calling its callback // IbcChannelReceiver reciever = IbcChannelReceiver(portAddress); - // reciever.onCloseIbcChannel(channelId, channel.counterpartyPortId, channel.counterpartyChannelId); + // reciever.onCloseIbcChannel(channelId, channel.counterpartyPortId, + // channel.counterpartyChannelId); // delete _portChannelMap[portAddress][channelId]; // emit CloseIbcChannel(portAddress, channelId); // } @@ -541,7 +555,8 @@ contract Dispatcher is OwnableUpgradeable, UUPSUpgradeable, IDispatcher { _nextSequenceSend[address(portAddress)][local.channelId] = 1; _nextSequenceRecv[address(portAddress)][local.channelId] = 1; _nextSequenceAck[address(portAddress)][local.channelId] = 1; - _channelIdToConnection[local.channelId] = connectionHops[0]; // Set channel to connection mapping for finding + _channelIdToConnection[local.channelId] = connectionHops[0]; // Set channel to connection mapping for + // finding } // Returns the result of the call if no revert, otherwise returns the error if thrown. diff --git a/contracts/core/UniversalChannelHandler.sol b/contracts/core/UniversalChannelHandler.sol index c99c527a..ac236963 100644 --- a/contracts/core/UniversalChannelHandler.sol +++ b/contracts/core/UniversalChannelHandler.sol @@ -53,9 +53,9 @@ contract UniversalChannelHandler is IbcReceiverBase, IbcUniversalChannelMW { ChannelOrder ordering, bool feeEnabled, string[] calldata connectionHops, - string calldata counterpartyPortId + string calldata counterpartyPortIdentifier ) external onlyOwner { - dispatcher.channelOpenInit(version, ordering, feeEnabled, connectionHops, counterpartyPortId); + dispatcher.channelOpenInit(version, ordering, feeEnabled, connectionHops, counterpartyPortIdentifier); } function sendUniversalPacket( @@ -151,16 +151,7 @@ contract UniversalChannelHandler is IbcReceiverBase, IbcUniversalChannelMW { mwStackAddrs[mwBitmap] = mwAddrs; } - // IBC callback functions - function onChanOpenAck(bytes32 channelId, string calldata counterpartyVersion) external onlyIbcDispatcher { - _connectChannel(channelId, counterpartyVersion); - } - - function onChanOpenConfirm(bytes32 channelId, string calldata counterpartyVersion) external onlyIbcDispatcher { - _connectChannel(channelId, counterpartyVersion); - } - - function onChanOpenInit(string calldata version) + function onChanOpenInit(ChannelOrder, string[] calldata, string calldata, string calldata version) external view onlyIbcDispatcher @@ -169,23 +160,40 @@ contract UniversalChannelHandler is IbcReceiverBase, IbcUniversalChannelMW { return _openChannel(version); } - function onChanOpenTry(string calldata counterpartyVersion) + // solhint-disable-next-line ordering + function onChanOpenTry( + ChannelOrder, + string[] memory, + bytes32 channelId, + string memory, + bytes32, + string calldata counterpartyVersion + ) external onlyIbcDispatcher returns (string memory selectedVersion) { + return _connectChannel(channelId, counterpartyVersion); + } + + // IBC callback functions + function onChanOpenAck(bytes32 channelId, bytes32, string calldata counterpartyVersion) external - view onlyIbcDispatcher - returns (string memory selectedVersion) { - return _openChannel(counterpartyVersion); + _connectChannel(channelId, counterpartyVersion); } - function _connectChannel(bytes32 channelId, string calldata version) private { + function onChanOpenConfirm(bytes32 channelId) external onlyIbcDispatcher {} + + function _connectChannel(bytes32 channelId, string calldata version) + internal + returns (string memory checkedVersion) + { if (keccak256(abi.encodePacked(version)) != keccak256(abi.encodePacked(VERSION))) { revert UnsupportedVersion(); } connectedChannels.push(channelId); + checkedVersion = version; } - function _openChannel(string calldata version) private pure returns (string memory selectedVersion) { + function _openChannel(string calldata version) internal pure returns (string memory selectedVersion) { if (keccak256(abi.encodePacked(version)) != keccak256(abi.encodePacked(VERSION))) { revert UnsupportedVersion(); } diff --git a/contracts/examples/Mars.sol b/contracts/examples/Mars.sol index 442c4a57..695d6b95 100644 --- a/contracts/examples/Mars.sol +++ b/contracts/examples/Mars.sol @@ -80,15 +80,7 @@ contract Mars is IbcReceiverBase, IbcReceiver { dispatcher.sendPacket(channelId, bytes(message), timeoutTimestamp); } - function onChanOpenAck(bytes32 channelId, string calldata counterpartyVersion) external virtual onlyIbcDispatcher { - _connectChannel(channelId, counterpartyVersion); - } - - function onChanOpenConfirm(bytes32 channelId, string calldata counterpartyVersion) external onlyIbcDispatcher { - _connectChannel(channelId, counterpartyVersion); - } - - function onChanOpenInit(string calldata version) + function onChanOpenInit(ChannelOrder, string[] calldata, string calldata, string calldata version) external view virtual @@ -98,22 +90,37 @@ contract Mars is IbcReceiverBase, IbcReceiver { return _openChannel(version); } - function onChanOpenTry(string calldata counterpartyVersion) + // solhint-disable-next-line ordering + function onChanOpenTry( + ChannelOrder, + string[] memory, + bytes32 channelId, + string memory, + bytes32, + string calldata counterpartyVersion + ) external virtual onlyIbcDispatcher returns (string memory selectedVersion) { + return _connectChannel(channelId, counterpartyVersion); + } + + function onChanOpenAck(bytes32 channelId, bytes32, string calldata counterpartyVersion) external - view virtual onlyIbcDispatcher - returns (string memory selectedVersion) { - return _openChannel(counterpartyVersion); + _connectChannel(channelId, counterpartyVersion); } - function _connectChannel(bytes32 channelId, string calldata counterpartyVersion) private { + function onChanOpenConfirm(bytes32 channelId) external onlyIbcDispatcher {} + + function _connectChannel(bytes32 channelId, string calldata counterpartyVersion) + private + returns (string memory version) + { // ensure negotiated version is supported for (uint256 i = 0; i < supportedVersions.length; i++) { if (keccak256(abi.encodePacked(counterpartyVersion)) == keccak256(abi.encodePacked(supportedVersions[i]))) { connectedChannels.push(channelId); - return; + return counterpartyVersion; } } revert UnsupportedVersion(); @@ -137,7 +144,14 @@ contract RevertingStringMars is Mars { constructor(IbcDispatcher _dispatcher) Mars(_dispatcher) {} // solhint-disable-next-line - function onChanOpenInit(string calldata) external view override onlyIbcDispatcher returns (string memory) { + function onChanOpenInit(ChannelOrder, string[] calldata, string calldata, string calldata) + external + view + virtual + override + onlyIbcDispatcher + returns (string memory selectedVersion) + { // solhint-disable-next-line require(false, "open ibc channel is reverting"); return ""; @@ -151,7 +165,7 @@ contract RevertingStringMars is Mars { } // solhint-disable-next-line - function onChanOpenAck(bytes32, string calldata) external view override onlyIbcDispatcher { + function onChanOpenAck(bytes32, bytes32, string calldata) external view override onlyIbcDispatcher { // solhint-disable-next-line require(false, "connect ibc channel is reverting"); } diff --git a/contracts/interfaces/IbcReceiver.sol b/contracts/interfaces/IbcReceiver.sol index 2aa30af9..a3de6c46 100644 --- a/contracts/interfaces/IbcReceiver.sol +++ b/contracts/interfaces/IbcReceiver.sol @@ -12,16 +12,31 @@ import {ChannelOrder, ChannelEnd, IbcPacket, AckPacket} from "../libs/Ibc.sol"; * handshake callbacks. */ interface IbcChannelReceiver { - function onChanOpenInit(string calldata version) external returns (string memory selectedVersion); - - function onChanOpenTry(string calldata counterpartyVersion) external returns (string memory selectedVersion); - - function onChanOpenAck(bytes32 channelId, string calldata counterpartyVersion) external; - - function onChanOpenConfirm(bytes32 channelId, string calldata counterpartyVersion) external; - - function onCloseIbcChannel(bytes32 channelId, string calldata counterpartyPortId, bytes32 counterpartyChannelId) + function onChanOpenInit( + ChannelOrder order, + string[] calldata connectionHops, + string calldata counterpartyPortIdentifier, + string calldata version + ) external returns (string memory selectedVersion); + + function onChanOpenTry( + ChannelOrder order, + string[] memory connectionHops, + bytes32 channelId, + string memory counterpartyPortIdentifier, + bytes32 counterpartychannelId, + string memory counterpartyVersion + ) external returns (string memory selectedVersion); + + function onChanOpenAck(bytes32 channelId, bytes32 counterpartychannelId, string calldata counterpartyVersion) external; + + function onChanOpenConfirm(bytes32 channelId) external; + function onCloseIbcChannel( + bytes32 channelId, + string calldata counterpartyPortIdentifier, + bytes32 counterpartyChannelId + ) external; } /** diff --git a/test/Ibc.t.sol b/test/Ibc.t.sol index af494b96..8adb1e99 100644 --- a/test/Ibc.t.sol +++ b/test/Ibc.t.sol @@ -3,6 +3,7 @@ pragma solidity ^0.8.0; import "../contracts/libs/Ibc.sol"; import "forge-std/Test.sol"; +import {IbcChannelReceiver} from "../contracts/interfaces/IbcReceiver.sol"; contract IbcTest is Test { function test_packet_commitment_proof_key() public { diff --git a/test/VirtualChain.sol b/test/VirtualChain.sol index cc85572f..da76f70b 100644 --- a/test/VirtualChain.sol +++ b/test/VirtualChain.sol @@ -136,11 +136,11 @@ contract VirtualChain is Test, IbcEventsEmitter, TestUtilsTest { vm.prank(address(remoteChain)); remoteChain.channelOpenTry(remoteEnd, this, localEnd, setting, true); // step-2 - vm.prank(address(this)); - this.channelOpenConfirm(localEnd, remoteChain, remoteEnd, setting, true); // step-3 - vm.prank(address(remoteChain)); - remoteChain.channelOpenAck(remoteEnd, this, localEnd, setting, true); // step-4 + this.channelOpenAck(localEnd, remoteChain, remoteEnd, setting, true); // step-4 + + vm.prank(address(this)); + remoteChain.channelOpenConfirm(remoteEnd, this, localEnd, setting, true); // step-3 } function channelOpenInit( @@ -178,6 +178,9 @@ contract VirtualChain is Test, IbcEventsEmitter, TestUtilsTest { ChannelSetting memory setting, bool expPass ) external { + bytes32 chanId = channelIds[address(localEnd)][address(remoteEnd)]; + require(chanId != bytes32(0), "channelOpenTry: channel does not exist"); + bytes32 cpChanId = remoteChain.channelIds(address(remoteEnd), address(localEnd)); require(cpChanId != bytes32(0), "channelOpenTry: channel does not exist"); @@ -201,9 +204,7 @@ contract VirtualChain is Test, IbcEventsEmitter, TestUtilsTest { } dispatcherProxy.channelOpenTry( ChannelEnd( - IbcUtils.addressToPortId(dispatcherProxy.portPrefix(), address(localEnd)), - setting.channelId, - setting.version + IbcUtils.addressToPortId(dispatcherProxy.portPrefix(), address(localEnd)), chanId, setting.version ), setting.ordering, setting.feeEnabled, diff --git a/test/universal.channel.t.sol b/test/universal.channel.t.sol index b2199a06..abd2b736 100644 --- a/test/universal.channel.t.sol +++ b/test/universal.channel.t.sol @@ -39,8 +39,8 @@ contract UniversalChannelTest is Base { function assert_channel(VirtualChain vc1, VirtualChain vc2, ChannelSetting memory setting) internal { bytes32 channelId1 = vc1.channelIds(address(vc1.ucHandler()), address(vc2.ucHandler())); bytes32 channelId2 = vc2.channelIds(address(vc2.ucHandler()), address(vc1.ucHandler())); - assertEq(vc1.ucHandler().connectedChannels(0), channelId1); - assertEq(vc2.ucHandler().connectedChannels(0), channelId2); + assertEq(vc1.ucHandler().connectedChannels(0), channelId1, "channels not equal 1"); + assertEq(vc2.ucHandler().connectedChannels(0), channelId2, "channels not equal 2"); Channel memory channel1 = vc1.dispatcherProxy().getChannel(address(vc1.ucHandler()), channelId1); Channel memory channel2 = vc2.dispatcherProxy().getChannel(address(vc2.ucHandler()), channelId2);