diff --git a/src/SMv3SessionValidationModule.sol b/src/SMv3SessionValidationModule.sol index af5a6e9..0d92961 100644 --- a/src/SMv3SessionValidationModule.sol +++ b/src/SMv3SessionValidationModule.sol @@ -4,7 +4,10 @@ pragma solidity 0.8.18; import {ECDSA} from "src/openzeppelin/ECDSA.sol"; import {IEngine} from "src/kwenta/smv3/IEngine.sol"; import {EIP7412} from "src/kwenta/smv3/EIP7412.sol"; -import {ISessionValidationModule, UserOperation} from "src/biconomy/interfaces/ISessionValidationModule.sol"; +import { + ISessionValidationModule, + UserOperation +} from "src/biconomy/interfaces/ISessionValidationModule.sol"; /** * @title Kwenta Smart Margin v3 Session Validation Module for Biconomy Smart Accounts @@ -32,10 +35,8 @@ contract SMv3SessionValidationModule is ISessionValidationModule { bytes calldata _sessionKeyData, bytes calldata /*_callSpecificData*/ ) external pure override returns (address) { - (address sessionKey, address smv3Engine) = abi.decode( - _sessionKeyData, - (address, address) - ); + (address sessionKey, address smv3Engine) = + abi.decode(_sessionKeyData, (address, address)); /// @dev ensure destinationContract is the smv3Engine if (destinationContract != smv3Engine) { @@ -45,12 +46,12 @@ contract SMv3SessionValidationModule is ISessionValidationModule { /// @dev ensure the function selector is the a valid IEngine selector bytes4 funcSelector = bytes4(_funcCallData[0:4]); if ( - funcSelector != IEngine.modifyCollateral.selector && - funcSelector != IEngine.commitOrder.selector && - funcSelector != IEngine.invalidateUnorderedNonces.selector && - funcSelector != EIP7412.fulfillOracleQuery.selector && - funcSelector != IEngine.depositEth.selector && - funcSelector != IEngine.withdrawEth.selector + funcSelector != IEngine.modifyCollateral.selector + && funcSelector != IEngine.commitOrder.selector + && funcSelector != IEngine.invalidateUnorderedNonces.selector + && funcSelector != EIP7412.fulfillOracleQuery.selector + && funcSelector != IEngine.depositEth.selector + && funcSelector != IEngine.withdrawEth.selector ) { revert InvalidSMv3Selector(); } @@ -92,18 +93,16 @@ contract SMv3SessionValidationModule is ISessionValidationModule { /// or /// `execute_ncC(address,uint256,bytes)` if ( - bytes4(_op.callData[0:4]) != EXECUTE_SELECTOR && - bytes4(_op.callData[0:4]) != EXECUTE_OPTIMIZED_SELECTOR + bytes4(_op.callData[0:4]) != EXECUTE_SELECTOR + && bytes4(_op.callData[0:4]) != EXECUTE_OPTIMIZED_SELECTOR ) { revert InvalidSelector(); } - (address sessionKey, address smv3Engine) = abi.decode( - _sessionKeyData, - (address, address) - ); + (address sessionKey, address smv3Engine) = + abi.decode(_sessionKeyData, (address, address)); - (address destinationContract, uint256 callValue, ) = abi.decode( + (address destinationContract, uint256 callValue,) = abi.decode( _op.callData[4:], // skip selector; already checked (address, uint256, bytes) ); @@ -118,21 +117,20 @@ contract SMv3SessionValidationModule is ISessionValidationModule { bytes calldata data; { uint256 offset = uint256(bytes32(_op.callData[4 + 64:4 + 96])); - uint256 length = uint256( - bytes32(_op.callData[4 + offset:4 + offset + 32]) - ); + uint256 length = + uint256(bytes32(_op.callData[4 + offset:4 + offset + 32])); data = _op.callData[4 + offset + 32:4 + offset + 32 + length]; } /// @dev ensure the function selector is the a valid IEngine selector bytes4 funcSelector = bytes4(data[0:4]); if ( - funcSelector != IEngine.modifyCollateral.selector && - funcSelector != IEngine.commitOrder.selector && - funcSelector != IEngine.invalidateUnorderedNonces.selector && - funcSelector != EIP7412.fulfillOracleQuery.selector && - funcSelector != IEngine.depositEth.selector && - funcSelector != IEngine.withdrawEth.selector + funcSelector != IEngine.modifyCollateral.selector + && funcSelector != IEngine.commitOrder.selector + && funcSelector != IEngine.invalidateUnorderedNonces.selector + && funcSelector != EIP7412.fulfillOracleQuery.selector + && funcSelector != IEngine.depositEth.selector + && funcSelector != IEngine.withdrawEth.selector ) { revert InvalidSMv3Selector(); } @@ -152,10 +150,8 @@ contract SMv3SessionValidationModule is ISessionValidationModule { /// @dev this method of signature validation is out-of-date /// see https://github.com/OpenZeppelin/openzeppelin-sdk/blob/7d96de7248ae2e7e81a743513ccc617a2e6bba21/packages/lib/contracts/cryptography/ECDSA.sol#L6 - return - ECDSA.recover( - ECDSA.toEthSignedMessageHash(_userOpHash), - _sessionKeySignature - ) == sessionKey; + return ECDSA.recover( + ECDSA.toEthSignedMessageHash(_userOpHash), _sessionKeySignature + ) == sessionKey; } } diff --git a/test/SMv3SessionValidationModule.t.sol b/test/SMv3SessionValidationModule.t.sol index b1a75dc..be19c1c 100644 --- a/test/SMv3SessionValidationModule.t.sol +++ b/test/SMv3SessionValidationModule.t.sol @@ -1,8 +1,14 @@ // SPDX-License-Identifier: UNLICENSED pragma solidity 0.8.18; -import {Bootstrap, SMv3SessionValidationModule} from "test/utils/Bootstrap.sol"; -import {UserOperationSignature, UserOperation, UserOperationLib} from "test/utils/UserOperationSignature.sol"; +import { + Bootstrap, SMv3SessionValidationModule +} from "test/utils/Bootstrap.sol"; +import { + UserOperationSignature, + UserOperation, + UserOperationLib +} from "test/utils/UserOperationSignature.sol"; import {IEngine} from "src/kwenta/smv3/IEngine.sol"; import {EIP7412} from "src/kwenta/smv3/EIP7412.sol"; @@ -50,25 +56,18 @@ contract SMv3SessionValidationModuleTest is Bootstrap { destinationContract = smv3Engine; callValue = 0; /// @notice a valid selector for IEngine - funcCallData = abi.encode( - IEngine.modifyCollateral.selector, - bytes32("") - ); + funcCallData = + abi.encode(IEngine.modifyCollateral.selector, bytes32("")); sessionKeyData = abi.encode(sessionKey, smv3Engine); callSpecificData = ""; // validateSessionUserOp params op.callData = abi.encodeWithSelector( - EXECUTE_SELECTOR, - destinationContract, - callValue, - funcCallData + EXECUTE_SELECTOR, destinationContract, callValue, funcCallData ); userOpHash = userOpSignature.hashUserOperation(op); - sessionKeySignature = userOpSignature.getUserOperationSignature( - op, - signerPrivateKey - ); + sessionKeySignature = + userOpSignature.getUserOperationSignature(op, signerPrivateKey); // define array of valid selectors validSelectors.push(IEngine.modifyCollateral.selector); @@ -88,9 +87,8 @@ contract ValidateSessionParams is SMv3SessionValidationModuleTest { if (validSelectors[i] == IEngine.depositEth.selector) { callValue = 1; // valid for depositEth - } else if ( - validSelectors[i] == EIP7412.fulfillOracleQuery.selector - ) { + } else if (validSelectors[i] == EIP7412.fulfillOracleQuery.selector) + { callValue = 1; // valid for fulfillOracleQuery } else { callValue = 0; // invalid for depositEth @@ -98,12 +96,12 @@ contract ValidateSessionParams is SMv3SessionValidationModuleTest { address retSessionKey = smv3SessionValidationModule .validateSessionParams( - destinationContract, - callValue, - funcCallData, - sessionKeyData, - callSpecificData - ); + destinationContract, + callValue, + funcCallData, + sessionKeyData, + callSpecificData + ); assertEq(sessionKey, retSessionKey); } @@ -140,9 +138,8 @@ contract ValidateSessionParams is SMv3SessionValidationModuleTest { if (validSelectors[i] == IEngine.depositEth.selector) { callValue = 0; // i.e. invalid for depositEth - } else if ( - validSelectors[i] == EIP7412.fulfillOracleQuery.selector - ) { + } else if (validSelectors[i] == EIP7412.fulfillOracleQuery.selector) + { callValue = 0; // valid for fulfillOracleQuery } else { callValue = invalid_callValue; @@ -190,28 +187,24 @@ contract ValidateSessionParams is SMv3SessionValidationModuleTest { ) public { vm.assume(invalid_sessionKey != sessionKey); - bytes memory invalid_sessionKeyData = abi.encode( - invalid_sessionKey, - destinationContract - ); + bytes memory invalid_sessionKeyData = + abi.encode(invalid_sessionKey, destinationContract); address retSessionKey = smv3SessionValidationModule .validateSessionParams( - destinationContract, - callValue, - funcCallData, - invalid_sessionKeyData, - callSpecificData - ); + destinationContract, + callValue, + funcCallData, + invalid_sessionKeyData, + callSpecificData + ); assertFalse(retSessionKey == sessionKey); vm.assume(invalid_destinationContract != destinationContract); - invalid_sessionKeyData = abi.encode( - sessionKey, - invalid_destinationContract - ); + invalid_sessionKeyData = + abi.encode(sessionKey, invalid_destinationContract); vm.expectRevert( abi.encodeWithSelector( @@ -237,33 +230,24 @@ contract ValidateSessionUserOp is SMv3SessionValidationModuleTest { if (validSelectors[i] == IEngine.depositEth.selector) { callValue = 1; // valid for depositEth - } else if ( - validSelectors[i] == EIP7412.fulfillOracleQuery.selector - ) { + } else if (validSelectors[i] == EIP7412.fulfillOracleQuery.selector) + { callValue = 1; // valid for fulfillOracleQuery } else { callValue = 0; // invalid for depositEth } op.callData = abi.encodeWithSelector( - EXECUTE_SELECTOR, - destinationContract, - callValue, - funcCallData + EXECUTE_SELECTOR, destinationContract, callValue, funcCallData ); userOpHash = userOpSignature.hashUserOperation(op); - sessionKeySignature = userOpSignature.getUserOperationSignature( - op, - signerPrivateKey - ); + sessionKeySignature = + userOpSignature.getUserOperationSignature(op, signerPrivateKey); bool ret = smv3SessionValidationModule.validateSessionUserOp( - op, - userOpHash, - sessionKeyData, - sessionKeySignature + op, userOpHash, sessionKeyData, sessionKeySignature ); assertTrue(ret); @@ -277,10 +261,7 @@ contract ValidateSessionUserOp is SMv3SessionValidationModuleTest { bytes4 invalid_selector = 0x12345678; op.callData = abi.encodeWithSelector( - invalid_selector, - destinationContract, - 1, - funcCallData + invalid_selector, destinationContract, 1, funcCallData ); vm.expectRevert( @@ -290,10 +271,7 @@ contract ValidateSessionUserOp is SMv3SessionValidationModuleTest { ); smv3SessionValidationModule.validateSessionUserOp( - op, - userOpHash, - sessionKeyData, - sessionKeySignature + op, userOpHash, sessionKeyData, sessionKeySignature ); vm.assume(invalid_destinationContract != destinationContract); @@ -312,10 +290,7 @@ contract ValidateSessionUserOp is SMv3SessionValidationModuleTest { ); smv3SessionValidationModule.validateSessionUserOp( - op, - userOpHash, - sessionKeyData, - sessionKeySignature + op, userOpHash, sessionKeyData, sessionKeySignature ); vm.assume(invalid_callValue != callValue); @@ -326,9 +301,8 @@ contract ValidateSessionUserOp is SMv3SessionValidationModuleTest { if (validSelectors[i] == IEngine.depositEth.selector) { callValue = 0; // i.e. invalid for depositEth - } else if ( - validSelectors[i] == EIP7412.fulfillOracleQuery.selector - ) { + } else if (validSelectors[i] == EIP7412.fulfillOracleQuery.selector) + { callValue = 0; // valid for fulfillOracleQuery } else { callValue = invalid_callValue; @@ -348,17 +322,12 @@ contract ValidateSessionUserOp is SMv3SessionValidationModuleTest { ); smv3SessionValidationModule.validateSessionUserOp( - op, - userOpHash, - sessionKeyData, - sessionKeySignature + op, userOpHash, sessionKeyData, sessionKeySignature ); } - bytes memory invalid_funcCallData = abi.encode( - invalid_selector, - bytes32("") - ); + bytes memory invalid_funcCallData = + abi.encode(invalid_selector, bytes32("")); op.callData = abi.encodeWithSelector( EXECUTE_SELECTOR, @@ -374,10 +343,7 @@ contract ValidateSessionUserOp is SMv3SessionValidationModuleTest { ); smv3SessionValidationModule.validateSessionUserOp( - op, - userOpHash, - sessionKeyData, - sessionKeySignature + op, userOpHash, sessionKeyData, sessionKeySignature ); } @@ -387,10 +353,7 @@ contract ValidateSessionUserOp is SMv3SessionValidationModuleTest { vm.assume(invalid_userOpHash != userOpHash); bool ret = smv3SessionValidationModule.validateSessionUserOp( - op, - invalid_userOpHash, - sessionKeyData, - sessionKeySignature + op, invalid_userOpHash, sessionKeyData, sessionKeySignature ); assertFalse(ret); @@ -402,16 +365,11 @@ contract ValidateSessionUserOp is SMv3SessionValidationModuleTest { ) public { vm.assume(invalid_sessionKey != sessionKey); - bytes memory invalid_sessionKeyData = abi.encode( - invalid_sessionKey, - smv3Engine - ); + bytes memory invalid_sessionKeyData = + abi.encode(invalid_sessionKey, smv3Engine); bool isValid = smv3SessionValidationModule.validateSessionUserOp( - op, - userOpHash, - invalid_sessionKeyData, - sessionKeySignature + op, userOpHash, invalid_sessionKeyData, sessionKeySignature ); assertFalse(isValid); @@ -427,10 +385,7 @@ contract ValidateSessionUserOp is SMv3SessionValidationModuleTest { ); smv3SessionValidationModule.validateSessionUserOp( - op, - userOpHash, - sessionKeyData, - sessionKeySignature + op, userOpHash, sessionKeyData, sessionKeySignature ); } @@ -444,14 +399,11 @@ contract ValidateSessionUserOp is SMv3SessionValidationModuleTest { // test specific vm.assume(invalid_privateKey != signerPrivateKey); - bytes memory invalidSessionKeySignature = userOpSignature - .getUserOperationSignature(op, invalid_privateKey); + bytes memory invalidSessionKeySignature = + userOpSignature.getUserOperationSignature(op, invalid_privateKey); bool isValid = smv3SessionValidationModule.validateSessionUserOp( - op, - userOpHash, - sessionKeyData, - invalidSessionKeySignature + op, userOpHash, sessionKeyData, invalidSessionKeySignature ); assertFalse(isValid);