From a9096c39b0f90a0bac02a24a1478ab79f0f35037 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicol=C3=A1s=20Pernas=20Maradei?= Date: Mon, 4 Mar 2024 13:55:56 +0100 Subject: [PATCH] channel handshake: split open-init and open-try --- contracts/core/Dispatcher.sol | 65 ++++++++++++---------- contracts/core/UniversalChannelHandler.sol | 40 +++++++------ contracts/examples/Mars.sol | 53 ++++++++---------- contracts/interfaces/IbcDispatcher.sol | 18 ++++-- contracts/interfaces/IbcReceiver.sol | 11 +--- test/Dispatcher.base.t.sol | 37 ++++++++---- test/Dispatcher.proof.t.sol | 11 ++-- test/Dispatcher.t.sol | 44 ++++++--------- test/VirtualChain.sol | 20 ++----- 9 files changed, 148 insertions(+), 151 deletions(-) diff --git a/contracts/core/Dispatcher.sol b/contracts/core/Dispatcher.sol index 9ed41c5b..37d7a0c3 100644 --- a/contracts/core/Dispatcher.sol +++ b/contracts/core/Dispatcher.sol @@ -68,13 +68,35 @@ contract Dispatcher is IbcDispatcher, IbcEventsEmitter, Ownable, Ibc { } /** - * This func is called by a 'relayer' on behalf of a dApp. The dApp should be implements IbcChannelHandler. - * The dApp should implement the onOpenIbcChannel method to handle one of the first two channel handshake methods, - * ie. ChanOpenInit or ChanOpenTry. - * If callback succeeds, the dApp should return the selected version, and an emitted event will be relayed to the - * IBC/VIBC hub chain. + * This function is called by a 'relayer' on behalf of a dApp. The dApp should implement IbcChannelHandler's + * onChanOpenInit. If the callback succeeds, the dApp should return the selected version and the emitted event + * will be relayed to the IBC/VIBC hub chain. */ - function openIbcChannel( + function channelOpenInit( + IbcChannelReceiver portAddress, + string calldata version, + ChannelOrder ordering, + bool feeEnabled, + string[] calldata connectionHops, + string calldata counterpartyPortId + ) external { + if (bytes(counterpartyPortId).length == 0) { + revert IBCErrors.invalidCounterPartyPortId(); + } + + string memory selectedVersion = portAddress.onChanOpenInit(version); + + emit ChannelOpenInit( + address(portAddress), selectedVersion, ordering, feeEnabled, connectionHops, counterpartyPortId + ); + } + + /** + * This function is called by a 'relayer' on behalf of a dApp. The dApp should implement IbcChannelHandler's + * onChanOpenTry. If the callback succeeds, the dApp should return the selected version and the emitted event + * will be relayed to the IBC/VIBC hub chain. + */ + function channelOpenTry( IbcChannelReceiver portAddress, CounterParty calldata local, ChannelOrder ordering, @@ -87,18 +109,15 @@ contract Dispatcher is IbcDispatcher, IbcEventsEmitter, Ownable, Ibc { revert IBCErrors.invalidCounterPartyPortId(); } - if (_isChannelOpenTry(counterparty)) { - consensusStateManager.verifyMembership( - proof, - channelProofKey(local.portId, local.channelId), - channelProofValue(ChannelState.TRY_PENDING, ordering, local.version, connectionHops, counterparty) - ); - } + consensusStateManager.verifyMembership( + proof, + channelProofKey(local.portId, local.channelId), + channelProofValue(ChannelState.TRY_PENDING, ordering, local.version, connectionHops, counterparty) + ); - string memory selectedVersion = - portAddress.onOpenIbcChannel(local.version, ordering, feeEnabled, connectionHops, counterparty); + string memory selectedVersion = portAddress.onChanOpenTry(counterparty.version); - emit OpenIbcChannel( + emit ChannelOpenTry( address(portAddress), selectedVersion, ordering, @@ -515,18 +534,4 @@ contract Dispatcher is IbcDispatcher, IbcEventsEmitter, Ownable, Ibc { addr := mload(add(addrBytes, 20)) } } - - // For XXXX => vIBC direction, SC needs to verify the proof of membership of TRY_PENDING - // For vIBC initiated channel, SC doesn't need to verify any proof, and these should be all empty - function _isChannelOpenTry(CounterParty calldata counterparty) internal pure returns (bool open) { - if (counterparty.channelId == bytes32(0) && bytes(counterparty.version).length == 0) { - open = false; - // ChanOpenInit with unknow conterparty - } else if (counterparty.channelId != bytes32(0) && bytes(counterparty.version).length != 0) { - // this is the ChanOpenTry; counterparty must not be zero-value - open = true; - } else { - revert IBCErrors.invalidCounterParty(); - } - } } diff --git a/contracts/core/UniversalChannelHandler.sol b/contracts/core/UniversalChannelHandler.sol index 058f6cb0..25e7f39b 100644 --- a/contracts/core/UniversalChannelHandler.sol +++ b/contracts/core/UniversalChannelHandler.sol @@ -33,7 +33,6 @@ contract UniversalChannelHandler is IbcReceiverBase, IbcUniversalChannelMW { dispatcher.closeIbcChannel(channelId); } - // IBC callback functions function onConnectIbcChannel(bytes32 channelId, bytes32, string calldata counterpartyVersion) external onlyIbcDispatcher @@ -149,23 +148,28 @@ contract UniversalChannelHandler is IbcReceiverBase, IbcUniversalChannelMW { mwStackAddrs[mwBitmap] = mwAddrs; } - function onOpenIbcChannel( - string calldata version, - ChannelOrder, - bool, - string[] calldata, - CounterParty calldata counterparty - ) external view onlyIbcDispatcher returns (string memory selectedVersion) { - if (counterparty.channelId == bytes32(0)) { - // ChanOpenInit - if (keccak256(abi.encodePacked(version)) != keccak256(abi.encodePacked(VERSION))) { - revert UnsupportedVersion(); - } - } else { - // ChanOpenTry - if (keccak256(abi.encodePacked(counterparty.version)) != keccak256(abi.encodePacked(VERSION))) { - revert UnsupportedVersion(); - } + // IBC callback functions + function onChanOpenInit(string calldata version) + external + view + onlyIbcDispatcher + returns (string memory selectedVersion) + { + return _openChannel(version); + } + + function onChanOpenTry(string calldata counterpartyVersion) + external + view + onlyIbcDispatcher + returns (string memory selectedVersion) + { + return _openChannel(counterpartyVersion); + } + + function _openChannel(string calldata version) private pure returns (string memory selectedVersion) { + if (keccak256(abi.encodePacked(version)) != keccak256(abi.encodePacked(VERSION))) { + revert UnsupportedVersion(); } return VERSION; } diff --git a/contracts/examples/Mars.sol b/contracts/examples/Mars.sol index 41e2f70b..771aa48d 100644 --- a/contracts/examples/Mars.sol +++ b/contracts/examples/Mars.sol @@ -81,39 +81,30 @@ contract Mars is IbcReceiverBase, IbcReceiver { dispatcher.sendPacket(channelId, bytes(message), timeoutTimestamp); } - function onOpenIbcChannel( - string calldata version, - ChannelOrder, - bool, - string[] calldata, - CounterParty calldata counterparty - ) external view onlyIbcDispatcher returns (string memory selectedVersion) { - if (bytes(counterparty.portId).length <= 8) { - revert IBCErrors.invalidCounterPartyPortId(); - } - /** - * Version selection is determined by if the callback is invoked on behalf of ChanOpenInit or ChanOpenTry. - * ChanOpenInit: self version should be provided whereas the counterparty version is empty. - * ChanOpenTry: counterparty version should be provided whereas the self version is empty. - * In both cases, the selected version should be in the supported versions list. - */ - bool foundVersion = false; - selectedVersion = - keccak256(abi.encodePacked(version)) == keccak256(abi.encodePacked("")) ? counterparty.version : version; + function onChanOpenInit(string calldata version) + external + view + onlyIbcDispatcher + returns (string memory selectedVersion) + { + return _openChannel(version); + } + + function onChanOpenTry(string calldata counterpartyVersion) + external + view + onlyIbcDispatcher + returns (string memory selectedVersion) + { + return _openChannel(counterpartyVersion); + } + + function _openChannel(string calldata version) private view returns (string memory selectedVersion) { for (uint256 i = 0; i < supportedVersions.length; i++) { - if (keccak256(abi.encodePacked(selectedVersion)) == keccak256(abi.encodePacked(supportedVersions[i]))) { - foundVersion = true; - break; + if (keccak256(abi.encodePacked(version)) == keccak256(abi.encodePacked(supportedVersions[i]))) { + return version; } } - if (!foundVersion) revert UnsupportedVersion(); - // if counterpartyVersion is not empty, then it must be the same foundVersion - if (keccak256(abi.encodePacked(counterparty.version)) != keccak256(abi.encodePacked(""))) { - if (keccak256(abi.encodePacked(counterparty.version)) != keccak256(abi.encodePacked(selectedVersion))) { - revert VersionMismatch(); - } - } - - return selectedVersion; + revert UnsupportedVersion(); } } diff --git a/contracts/interfaces/IbcDispatcher.sol b/contracts/interfaces/IbcDispatcher.sol index 3c30210e..a13fcd6a 100644 --- a/contracts/interfaces/IbcDispatcher.sol +++ b/contracts/interfaces/IbcDispatcher.sol @@ -23,14 +23,13 @@ interface IbcPacketSender { * Other features are implemented as callback methods in the IbcReceiver interface. */ interface IbcDispatcher is IbcPacketSender { - function openIbcChannel( + function channelOpenInit( IbcChannelReceiver portAddress, - CounterParty calldata self, + string calldata version, ChannelOrder ordering, bool feeEnabled, string[] calldata connectionHops, - CounterParty calldata counterparty, - Ics23Proof calldata proof + string calldata counterpartyPortId ) external; function closeIbcChannel(bytes32 channelId) external; @@ -46,7 +45,16 @@ interface IbcEventsEmitter { // // channel events // - event OpenIbcChannel( + event ChannelOpenInit( + address indexed portAddress, + string version, + ChannelOrder ordering, + bool feeEnabled, + string[] connectionHops, + string counterpartyPortId + ); + + event ChannelOpenTry( address indexed portAddress, string version, ChannelOrder ordering, diff --git a/contracts/interfaces/IbcReceiver.sol b/contracts/interfaces/IbcReceiver.sol index 5c655f93..3191ae1f 100644 --- a/contracts/interfaces/IbcReceiver.sol +++ b/contracts/interfaces/IbcReceiver.sol @@ -12,13 +12,9 @@ import {ChannelOrder, CounterParty, IbcPacket, AckPacket} from "../libs/Ibc.sol" * handshake callbacks. */ interface IbcChannelReceiver { - function onOpenIbcChannel( - string calldata version, - ChannelOrder ordering, - bool feeEnabled, - string[] calldata connectionHops, - CounterParty calldata counterparty - ) external returns (string memory selectedVersion); + function onChanOpenInit(string calldata version) external returns (string memory selectedVersion); + + function onChanOpenTry(string calldata counterpartyVersion) external returns (string memory selectedVersion); function onConnectIbcChannel(bytes32 channelId, bytes32 counterpartyChannelId, string calldata counterpartyVersion) external; @@ -53,7 +49,6 @@ contract IbcReceiverBase is Ownable { error notIbcDispatcher(); error UnsupportedVersion(); - error VersionMismatch(); error ChannelNotFound(); /** diff --git a/test/Dispatcher.base.t.sol b/test/Dispatcher.base.t.sol index 9d143199..9a6c5ea4 100644 --- a/test/Dispatcher.base.t.sol +++ b/test/Dispatcher.base.t.sol @@ -65,35 +65,50 @@ contract Base is IbcEventsEmitter, ProofBase { // ⬇️ IBC functions for testing /** - * @dev Step-1/2 of the 4-step handshake to open an IBC channel. + * @dev Step-1 of the 4-step handshake to open an IBC channel. * @param le Local end settings for the channel. * @param re Remote end settings for the channel. * @param s Channel handshake settings. * @param expPass Expected pass status of the operation. * If expPass is false, `vm.expectRevert` should be called before this function. */ - function openChannel(LocalEnd memory le, CounterParty memory re, ChannelHandshakeSetting memory s, bool expPass) + function channelOpenInit(LocalEnd memory le, CounterParty memory re, ChannelHandshakeSetting memory s, bool expPass) public { - CounterParty memory cp; - cp.portId = re.portId; - if (!s.localInitiate) { - cp.channelId = re.channelId; - cp.version = re.version; + if (expPass) { + vm.expectEmit(true, true, true, true); + emit ChannelOpenInit( + address(le.receiver), le.versionExpected, s.ordering, s.feeEnabled, le.connectionHops, re.portId + ); } + dispatcher.channelOpenInit(le.receiver, le.versionCall, s.ordering, s.feeEnabled, le.connectionHops, re.portId); + } + + /** + * @dev Step-2 of the 4-step handshake to open an IBC channel. + * @param le Local end settings for the channel. + * @param re Remote end settings for the channel. + * @param s Channel handshake settings. + * @param expPass Expected pass status of the operation. + * If expPass is false, `vm.expectRevert` should be called before this function. + */ + function channelOpenTry(LocalEnd memory le, CounterParty memory re, ChannelHandshakeSetting memory s, bool expPass) + public + { if (expPass) { vm.expectEmit(true, true, true, true); - emit OpenIbcChannel( + emit ChannelOpenTry( address(le.receiver), le.versionExpected, s.ordering, s.feeEnabled, le.connectionHops, - cp.portId, - cp.channelId + re.portId, + re.channelId ); } - dispatcher.openIbcChannel( + CounterParty memory cp = CounterParty(re.portId, re.channelId, re.version); + dispatcher.channelOpenTry( le.receiver, CounterParty(le.portId, le.channelId, le.versionCall), s.ordering, diff --git a/test/Dispatcher.proof.t.sol b/test/Dispatcher.proof.t.sol index b1a73ee5..84b6a89a 100644 --- a/test/Dispatcher.proof.t.sol +++ b/test/Dispatcher.proof.t.sol @@ -31,21 +31,20 @@ contract DispatcherIbcWithRealProofs is IbcEventsEmitter, ProofBase { } function test_ibc_channel_open_init() public { - CounterParty memory counterparty = CounterParty(ch1.portId, bytes32(0), ""); - vm.expectEmit(true, true, true, true); - emit OpenIbcChannel(address(mars), "1.0", ChannelOrder.NONE, false, connectionHops1, ch1.portId, bytes32(0)); + emit ChannelOpenInit(address(mars), "1.0", ChannelOrder.NONE, false, connectionHops1, ch1.portId); + // since this is open chann init, the proof is not used. so use an invalid one - dispatcher.openIbcChannel(mars, ch1, ChannelOrder.NONE, false, connectionHops1, counterparty, invalidProof); + dispatcher.channelOpenInit(mars, ch1.version, ChannelOrder.NONE, false, connectionHops1, ch1.portId); } function test_ibc_channel_open_try() public { Ics23Proof memory proof = load_proof("/test/payload/channel_try_pending_proof.hex"); vm.expectEmit(true, true, true, true); - emit OpenIbcChannel(address(mars), "1.0", ChannelOrder.NONE, false, connectionHops1, ch0.portId, ch0.channelId); + emit ChannelOpenTry(address(mars), "1.0", ChannelOrder.NONE, false, connectionHops1, ch0.portId, ch0.channelId); - dispatcher.openIbcChannel(mars, ch1, ChannelOrder.NONE, false, connectionHops1, ch0, proof); + dispatcher.channelOpenTry(mars, ch1, ChannelOrder.NONE, false, connectionHops1, ch0, proof); } function test_ibc_channel_ack() public { diff --git a/test/Dispatcher.t.sol b/test/Dispatcher.t.sol index c15d5cf8..8fbcaa2c 100644 --- a/test/Dispatcher.t.sol +++ b/test/Dispatcher.t.sol @@ -32,7 +32,7 @@ contract ChannelHandshakeTest is Base { le.versionCall = versions[j]; le.versionExpected = versions[j]; // remoteEnd has no channelId or version if localEnd is the initiator - openChannel(le, re, settings[i], true); + channelOpenInit(le, re, settings[i], true); } } } @@ -49,11 +49,11 @@ contract ChannelHandshakeTest is Base { le.versionCall = versions[j]; le.versionExpected = versions[j]; // remoteEnd version is used - openChannel(le, re, settings[i], true); + channelOpenInit(le, re, settings[i], true); // auto version selection le.versionCall = ""; - openChannel(le, re, settings[i], true); + channelOpenTry(le, re, settings[i], true); } } } @@ -69,30 +69,14 @@ contract ChannelHandshakeTest is Base { le.versionCall = versions[j]; le.versionExpected = versions[j]; re.version = versions[j]; - openChannel(le, re, settings[i], true); + channelOpenInit(le, re, settings[i], true); + channelOpenTry(le, re, settings[i], true); connectChannel(le, re, settings[i], false, true); } } } - function test_openChannel_receiver_fail_versionMismatch() public { - ChannelHandshakeSetting[4] memory settings = createSettings(false, true); - string[2] memory versions = ["1.0", "2.0"]; - for (uint256 i = 0; i < settings.length; i++) { - for (uint256 j = 0; j < versions.length; j++) { - LocalEnd memory le = _local; - CounterParty memory re = _remote; - re.version = versions[j]; - // always select the wrong version - bool isVersionOne = keccak256(abi.encodePacked(versions[j])) == keccak256(abi.encodePacked("1.0")); - le.versionCall = isVersionOne ? "2.0" : "1.0"; - vm.expectRevert(IbcReceiverBase.VersionMismatch.selector); - openChannel(le, re, settings[i], false); - } - } - } - - function test_openChannel_initiator_fail_unsupportedVersion() public { + function test_channelOpenInit_fail_unsupportedVersion() public { ChannelHandshakeSetting[4] memory settings = createSettings(true, true); string[2] memory versions = ["", "xxxxxxx"]; for (uint256 i = 0; i < settings.length; i++) { @@ -102,13 +86,13 @@ contract ChannelHandshakeTest is Base { le.versionCall = versions[j]; le.versionExpected = versions[j]; vm.expectRevert(IbcReceiverBase.UnsupportedVersion.selector); - openChannel(le, re, settings[i], false); + channelOpenInit(le, re, settings[i], false); } } } function test_openChannel_receiver_fail_invalidProof() public { - // When localEnd initiates, no proof verification is done in openIbcChannel + // When localEnd initiates, no proof verification is done in channelOpenTry ChannelHandshakeSetting[4] memory settings = createSettings(false, false); string[1] memory versions = ["1.0"]; for (uint256 i = 0; i < settings.length; i++) { @@ -117,8 +101,9 @@ contract ChannelHandshakeTest is Base { CounterParty memory re = _remote; le.versionCall = versions[j]; le.versionExpected = versions[j]; + vm.expectRevert(DummyConsensusStateManager.InvalidDummyMembershipProof.selector); - openChannel(le, re, settings[i], false); + channelOpenTry(le, re, settings[i], false); } } } @@ -132,7 +117,8 @@ contract ChannelHandshakeTest is Base { LocalEnd memory le = _local; CounterParty memory re = _remote; // no remote version applied in openChannel - openChannel(le, re, settings[i], true); + channelOpenInit(le, re, settings[i], true); + channelOpenTry(le, re, settings[i], true); re.version = versions[j]; vm.expectRevert(IbcReceiverBase.UnsupportedVersion.selector); connectChannel(le, re, settings[i], false, false); @@ -149,7 +135,8 @@ contract ChannelHandshakeTest is Base { LocalEnd memory le = _local; CounterParty memory re = _remote; // no remote version applied in openChannel - openChannel(le, re, settings[i], true); + channelOpenInit(le, re, settings[i], true); + channelOpenTry(le, re, settings[i], true); re.version = versions[j]; settings[i].proof = invalidProof; vm.expectRevert(DummyConsensusStateManager.InvalidDummyMembershipProof.selector); @@ -211,7 +198,8 @@ contract ChannelOpenTestBase is Base { _local = LocalEnd(mars, portId, channelId, connectionHops, "1.0", "1.0"); _remote = CounterParty("eth2.7E5F4552091A69125d5DfCb7b8C2659029395Bdf", "channel-2", "1.0"); - openChannel(_local, _remote, setting, true); + channelOpenInit(_local, _remote, setting, true); + channelOpenTry(_local, _remote, setting, true); connectChannel(_local, _remote, setting, false, true); } } diff --git a/test/VirtualChain.sol b/test/VirtualChain.sol index 70ca30a9..c022412f 100644 --- a/test/VirtualChain.sol +++ b/test/VirtualChain.sol @@ -153,25 +153,17 @@ contract VirtualChain is Test, IbcEventsEmitter { if (expPass) { vm.expectEmit(true, true, true, true); - emit OpenIbcChannel( + emit ChannelOpenInit( address(localEnd), setting.version, setting.ordering, setting.feeEnabled, connectionHops, - remoteChain.portIds(address(remoteEnd)), - bytes32(0) + remoteChain.portIds(address(remoteEnd)) ); } - dispatcher.openIbcChannel( - localEnd, - CounterParty(setting.portId, setting.channelId, setting.version), - setting.ordering, - setting.feeEnabled, - connectionHops, - // counterparty channelId and version are not known at this point - CounterParty(cpPortId, bytes32(0), ""), - setting.proof + dispatcher.channelOpenInit( + localEnd, setting.version, setting.ordering, setting.feeEnabled, connectionHops, cpPortId ); } @@ -193,7 +185,7 @@ contract VirtualChain is Test, IbcEventsEmitter { if (expPass) { vm.expectEmit(true, true, true, true); - emit OpenIbcChannel( + emit ChannelOpenTry( address(localEnd), setting.version, setting.ordering, @@ -203,7 +195,7 @@ contract VirtualChain is Test, IbcEventsEmitter { cpChanId ); } - dispatcher.openIbcChannel( + dispatcher.channelOpenTry( localEnd, CounterParty(setting.portId, setting.channelId, setting.version), setting.ordering,