Skip to content

Commit

Permalink
Raunak/channel handler fixes (#86)
Browse files Browse the repository at this point in the history
* update channel handlers to be more aligned with ibc-spec

* change channelIdentifier -> channelId
  • Loading branch information
RnkSngh authored Apr 15, 2024
1 parent 06a042f commit 8938907
Show file tree
Hide file tree
Showing 7 changed files with 112 additions and 58 deletions.
25 changes: 20 additions & 5 deletions contracts/core/Dispatcher.sol
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,10 @@ contract Dispatcher is OwnableUpgradeable, UUPSUpgradeable, IDispatcher {
revert IBCErrors.invalidCounterPartyPortId();
}

// Have to encode here to avoid stack-too-deep error
bytes memory chanOpenInitArgs = abi.encode(ordering, connectionHops, counterpartyPortId, version);
(bool success, bytes memory data) =
_callIfContract(msg.sender, abi.encodeWithSelector(IbcChannelReceiver.onChanOpenInit.selector, version));
_callIfContract(msg.sender, bytes.concat(IbcChannelReceiver.onChanOpenInit.selector, chanOpenInitArgs));

if (success) {
emit ChannelOpenInit(
Expand Down Expand Up @@ -150,7 +152,16 @@ contract Dispatcher is OwnableUpgradeable, UUPSUpgradeable, IDispatcher {

address receiver = _getAddressFromPort(local.portId);
(bool success, bytes memory data) = _callIfContract(
receiver, abi.encodeWithSelector(IbcChannelReceiver.onChanOpenTry.selector, counterparty.version)
receiver,
abi.encodeWithSelector(
IbcChannelReceiver.onChanOpenTry.selector,
ordering,
connectionHops,
local.channelId,
counterparty.portId,
counterparty.channelId,
counterparty.version
)
);

if (success) {
Expand Down Expand Up @@ -189,7 +200,9 @@ contract Dispatcher is OwnableUpgradeable, UUPSUpgradeable, IDispatcher {
address receiver = _getAddressFromPort(local.portId);
(bool success, bytes memory data) = _callIfContract(
receiver,
abi.encodeWithSelector(IbcChannelReceiver.onChanOpenAck.selector, local.channelId, counterparty.version)
abi.encodeWithSelector(
IbcChannelReceiver.onChanOpenAck.selector, local.channelId, counterparty.channelId, counterparty.version
)
);

if (success) {
Expand Down Expand Up @@ -282,7 +295,8 @@ contract Dispatcher is OwnableUpgradeable, UUPSUpgradeable, IDispatcher {
//
// // confirm with dApp by calling its callback
// IbcChannelReceiver reciever = IbcChannelReceiver(portAddress);
// reciever.onCloseIbcChannel(channelId, channel.counterpartyPortId, channel.counterpartyChannelId);
// reciever.onCloseIbcChannel(channelId, channel.counterpartyPortId,
// channel.counterpartyChannelId);
// delete _portChannelMap[portAddress][channelId];
// emit CloseIbcChannel(portAddress, channelId);
// }
Expand Down Expand Up @@ -541,7 +555,8 @@ contract Dispatcher is OwnableUpgradeable, UUPSUpgradeable, IDispatcher {
_nextSequenceSend[address(portAddress)][local.channelId] = 1;
_nextSequenceRecv[address(portAddress)][local.channelId] = 1;
_nextSequenceAck[address(portAddress)][local.channelId] = 1;
_channelIdToConnection[local.channelId] = connectionHops[0]; // Set channel to connection mapping for finding
_channelIdToConnection[local.channelId] = connectionHops[0]; // Set channel to connection mapping for
// finding
}

// Returns the result of the call if no revert, otherwise returns the error if thrown.
Expand Down
44 changes: 26 additions & 18 deletions contracts/core/UniversalChannelHandler.sol
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ contract UniversalChannelHandler is IbcReceiverBase, IbcUniversalChannelMW {
ChannelOrder ordering,
bool feeEnabled,
string[] calldata connectionHops,
string calldata counterpartyPortId
string calldata counterpartyPortIdentifier
) external onlyOwner {
dispatcher.channelOpenInit(version, ordering, feeEnabled, connectionHops, counterpartyPortId);
dispatcher.channelOpenInit(version, ordering, feeEnabled, connectionHops, counterpartyPortIdentifier);
}

function sendUniversalPacket(
Expand Down Expand Up @@ -151,16 +151,7 @@ contract UniversalChannelHandler is IbcReceiverBase, IbcUniversalChannelMW {
mwStackAddrs[mwBitmap] = mwAddrs;
}

// IBC callback functions
function onChanOpenAck(bytes32 channelId, string calldata counterpartyVersion) external onlyIbcDispatcher {
_connectChannel(channelId, counterpartyVersion);
}

function onChanOpenConfirm(bytes32 channelId, string calldata counterpartyVersion) external onlyIbcDispatcher {
_connectChannel(channelId, counterpartyVersion);
}

function onChanOpenInit(string calldata version)
function onChanOpenInit(ChannelOrder, string[] calldata, string calldata, string calldata version)
external
view
onlyIbcDispatcher
Expand All @@ -169,23 +160,40 @@ contract UniversalChannelHandler is IbcReceiverBase, IbcUniversalChannelMW {
return _openChannel(version);
}

function onChanOpenTry(string calldata counterpartyVersion)
// solhint-disable-next-line ordering
function onChanOpenTry(
ChannelOrder,
string[] memory,
bytes32 channelId,
string memory,
bytes32,
string calldata counterpartyVersion
) external onlyIbcDispatcher returns (string memory selectedVersion) {
return _connectChannel(channelId, counterpartyVersion);
}

// IBC callback functions
function onChanOpenAck(bytes32 channelId, bytes32, string calldata counterpartyVersion)
external
view
onlyIbcDispatcher
returns (string memory selectedVersion)
{
return _openChannel(counterpartyVersion);
_connectChannel(channelId, counterpartyVersion);
}

function _connectChannel(bytes32 channelId, string calldata version) private {
function onChanOpenConfirm(bytes32 channelId) external onlyIbcDispatcher {}

function _connectChannel(bytes32 channelId, string calldata version)
internal
returns (string memory checkedVersion)
{
if (keccak256(abi.encodePacked(version)) != keccak256(abi.encodePacked(VERSION))) {
revert UnsupportedVersion();
}
connectedChannels.push(channelId);
checkedVersion = version;
}

function _openChannel(string calldata version) private pure returns (string memory selectedVersion) {
function _openChannel(string calldata version) internal pure returns (string memory selectedVersion) {
if (keccak256(abi.encodePacked(version)) != keccak256(abi.encodePacked(VERSION))) {
revert UnsupportedVersion();
}
Expand Down
48 changes: 31 additions & 17 deletions contracts/examples/Mars.sol
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,7 @@ contract Mars is IbcReceiverBase, IbcReceiver {
dispatcher.sendPacket(channelId, bytes(message), timeoutTimestamp);
}

function onChanOpenAck(bytes32 channelId, string calldata counterpartyVersion) external virtual onlyIbcDispatcher {
_connectChannel(channelId, counterpartyVersion);
}

function onChanOpenConfirm(bytes32 channelId, string calldata counterpartyVersion) external onlyIbcDispatcher {
_connectChannel(channelId, counterpartyVersion);
}

function onChanOpenInit(string calldata version)
function onChanOpenInit(ChannelOrder, string[] calldata, string calldata, string calldata version)
external
view
virtual
Expand All @@ -98,22 +90,37 @@ contract Mars is IbcReceiverBase, IbcReceiver {
return _openChannel(version);
}

function onChanOpenTry(string calldata counterpartyVersion)
// solhint-disable-next-line ordering
function onChanOpenTry(
ChannelOrder,
string[] memory,
bytes32 channelId,
string memory,
bytes32,
string calldata counterpartyVersion
) external virtual onlyIbcDispatcher returns (string memory selectedVersion) {
return _connectChannel(channelId, counterpartyVersion);
}

function onChanOpenAck(bytes32 channelId, bytes32, string calldata counterpartyVersion)
external
view
virtual
onlyIbcDispatcher
returns (string memory selectedVersion)
{
return _openChannel(counterpartyVersion);
_connectChannel(channelId, counterpartyVersion);
}

function _connectChannel(bytes32 channelId, string calldata counterpartyVersion) private {
function onChanOpenConfirm(bytes32 channelId) external onlyIbcDispatcher {}

function _connectChannel(bytes32 channelId, string calldata counterpartyVersion)
private
returns (string memory version)
{
// ensure negotiated version is supported
for (uint256 i = 0; i < supportedVersions.length; i++) {
if (keccak256(abi.encodePacked(counterpartyVersion)) == keccak256(abi.encodePacked(supportedVersions[i]))) {
connectedChannels.push(channelId);
return;
return counterpartyVersion;
}
}
revert UnsupportedVersion();
Expand All @@ -137,7 +144,14 @@ contract RevertingStringMars is Mars {
constructor(IbcDispatcher _dispatcher) Mars(_dispatcher) {}

// solhint-disable-next-line
function onChanOpenInit(string calldata) external view override onlyIbcDispatcher returns (string memory) {
function onChanOpenInit(ChannelOrder, string[] calldata, string calldata, string calldata)
external
view
virtual
override
onlyIbcDispatcher
returns (string memory selectedVersion)
{
// solhint-disable-next-line
require(false, "open ibc channel is reverting");
return "";
Expand All @@ -151,7 +165,7 @@ contract RevertingStringMars is Mars {
}

// solhint-disable-next-line
function onChanOpenAck(bytes32, string calldata) external view override onlyIbcDispatcher {
function onChanOpenAck(bytes32, bytes32, string calldata) external view override onlyIbcDispatcher {
// solhint-disable-next-line
require(false, "connect ibc channel is reverting");
}
Expand Down
33 changes: 24 additions & 9 deletions contracts/interfaces/IbcReceiver.sol
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,31 @@ import {ChannelOrder, ChannelEnd, IbcPacket, AckPacket} from "../libs/Ibc.sol";
* handshake callbacks.
*/
interface IbcChannelReceiver {
function onChanOpenInit(string calldata version) external returns (string memory selectedVersion);

function onChanOpenTry(string calldata counterpartyVersion) external returns (string memory selectedVersion);

function onChanOpenAck(bytes32 channelId, string calldata counterpartyVersion) external;

function onChanOpenConfirm(bytes32 channelId, string calldata counterpartyVersion) external;

function onCloseIbcChannel(bytes32 channelId, string calldata counterpartyPortId, bytes32 counterpartyChannelId)
function onChanOpenInit(
ChannelOrder order,
string[] calldata connectionHops,
string calldata counterpartyPortIdentifier,
string calldata version
) external returns (string memory selectedVersion);

function onChanOpenTry(
ChannelOrder order,
string[] memory connectionHops,
bytes32 channelId,
string memory counterpartyPortIdentifier,
bytes32 counterpartychannelId,
string memory counterpartyVersion
) external returns (string memory selectedVersion);

function onChanOpenAck(bytes32 channelId, bytes32 counterpartychannelId, string calldata counterpartyVersion)
external;

function onChanOpenConfirm(bytes32 channelId) external;
function onCloseIbcChannel(
bytes32 channelId,
string calldata counterpartyPortIdentifier,
bytes32 counterpartyChannelId
) external;
}

/**
Expand Down
1 change: 1 addition & 0 deletions test/Ibc.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ pragma solidity ^0.8.0;

import "../contracts/libs/Ibc.sol";
import "forge-std/Test.sol";
import {IbcChannelReceiver} from "../contracts/interfaces/IbcReceiver.sol";

contract IbcTest is Test {
function test_packet_commitment_proof_key() public {
Expand Down
15 changes: 8 additions & 7 deletions test/VirtualChain.sol
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,11 @@ contract VirtualChain is Test, IbcEventsEmitter, TestUtilsTest {
vm.prank(address(remoteChain));
remoteChain.channelOpenTry(remoteEnd, this, localEnd, setting, true); // step-2

vm.prank(address(this));
this.channelOpenConfirm(localEnd, remoteChain, remoteEnd, setting, true); // step-3

vm.prank(address(remoteChain));
remoteChain.channelOpenAck(remoteEnd, this, localEnd, setting, true); // step-4
this.channelOpenAck(localEnd, remoteChain, remoteEnd, setting, true); // step-4

vm.prank(address(this));
remoteChain.channelOpenConfirm(remoteEnd, this, localEnd, setting, true); // step-3
}

function channelOpenInit(
Expand Down Expand Up @@ -178,6 +178,9 @@ contract VirtualChain is Test, IbcEventsEmitter, TestUtilsTest {
ChannelSetting memory setting,
bool expPass
) external {
bytes32 chanId = channelIds[address(localEnd)][address(remoteEnd)];
require(chanId != bytes32(0), "channelOpenTry: channel does not exist");

bytes32 cpChanId = remoteChain.channelIds(address(remoteEnd), address(localEnd));
require(cpChanId != bytes32(0), "channelOpenTry: channel does not exist");

Expand All @@ -201,9 +204,7 @@ contract VirtualChain is Test, IbcEventsEmitter, TestUtilsTest {
}
dispatcherProxy.channelOpenTry(
ChannelEnd(
IbcUtils.addressToPortId(dispatcherProxy.portPrefix(), address(localEnd)),
setting.channelId,
setting.version
IbcUtils.addressToPortId(dispatcherProxy.portPrefix(), address(localEnd)), chanId, setting.version
),
setting.ordering,
setting.feeEnabled,
Expand Down
4 changes: 2 additions & 2 deletions test/universal.channel.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ contract UniversalChannelTest is Base {
function assert_channel(VirtualChain vc1, VirtualChain vc2, ChannelSetting memory setting) internal {
bytes32 channelId1 = vc1.channelIds(address(vc1.ucHandler()), address(vc2.ucHandler()));
bytes32 channelId2 = vc2.channelIds(address(vc2.ucHandler()), address(vc1.ucHandler()));
assertEq(vc1.ucHandler().connectedChannels(0), channelId1);
assertEq(vc2.ucHandler().connectedChannels(0), channelId2);
assertEq(vc1.ucHandler().connectedChannels(0), channelId1, "channels not equal 1");
assertEq(vc2.ucHandler().connectedChannels(0), channelId2, "channels not equal 2");

Channel memory channel1 = vc1.dispatcherProxy().getChannel(address(vc1.ucHandler()), channelId1);
Channel memory channel2 = vc2.dispatcherProxy().getChannel(address(vc2.ucHandler()), channelId2);
Expand Down

0 comments on commit 8938907

Please sign in to comment.