diff --git a/src/SMv3SessionValidationModule.sol b/src/SMv3SessionValidationModule.sol index 3ffadbd..58d4221 100644 --- a/src/SMv3SessionValidationModule.sol +++ b/src/SMv3SessionValidationModule.sol @@ -45,29 +45,14 @@ contract SMv3SessionValidationModule is ISessionValidationModule { /// @dev ensure the function selector is the a valid IEngine selector bytes4 funcSelector = bytes4(_funcCallData[0:4]); - if ( - funcSelector != IEngine.modifyCollateral.selector - && funcSelector != IEngine.commitOrder.selector - && funcSelector != IEngine.invalidateUnorderedNonces.selector - && funcSelector != EIP7412.fulfillOracleQuery.selector - && funcSelector != IEngine.depositEth.selector - && funcSelector != IEngine.withdrawEth.selector - ) { - revert InvalidSMv3Selector(); - } - /// @dev ensure call value is zero unless calling IEngine.depositEth or EIP7412.fulfillOracleQuery - if (funcSelector == IEngine.depositEth.selector) { - if (callValue == 0) { - revert InvalidCallValue(); - } - } else if (funcSelector == EIP7412.fulfillOracleQuery.selector) { - if (callValue == 0) { - revert InvalidCallValue(); - } - } else if (callValue != 0) { - revert InvalidCallValue(); - } + // sanitize the selector; ensure it is a valid selector + // that can be called on the smv3Engine) + _sanitizeSelector(funcSelector); + + // sanitize the call value; ensure it is zero unless calling + // IEngine.depositEth or EIP7412.fulfillOracleQuery + _sanitizeCallValue(funcSelector, callValue); return sessionKey; } @@ -122,35 +107,58 @@ contract SMv3SessionValidationModule is ISessionValidationModule { data = _op.callData[4 + offset + 32:4 + offset + 32 + length]; } - /// @dev ensure the function selector is the a valid IEngine selector + // define the function selector bytes4 funcSelector = bytes4(data[0:4]); + + // sanitize the selector; ensure it is a valid selector + // that can be called on the smv3Engine) + _sanitizeSelector(funcSelector); + + // sanitize the call value; ensure it is zero unless calling + // IEngine.depositEth or EIP7412.fulfillOracleQuery + _sanitizeCallValue(funcSelector, callValue); + + /// @dev this method of signature validation is out-of-date + /// see https://github.com/OpenZeppelin/openzeppelin-sdk/blob/7d96de7248ae2e7e81a743513ccc617a2e6bba21/packages/lib/contracts/cryptography/ECDSA.sol#L6 + return ECDSA.recover( + ECDSA.toEthSignedMessageHash(_userOpHash), _sessionKeySignature + ) == sessionKey; + } + + /// @notice sanitize the selector to ensure it is a + /// valid selector that can be called on the smv3Engine + /// @param _selector the selector to sanitize + /// @dev will revert if the selector is not valid + function _sanitizeSelector(bytes4 _selector) internal pure { if ( - funcSelector != IEngine.modifyCollateral.selector - && funcSelector != IEngine.commitOrder.selector - && funcSelector != IEngine.invalidateUnorderedNonces.selector - && funcSelector != EIP7412.fulfillOracleQuery.selector - && funcSelector != IEngine.depositEth.selector - && funcSelector != IEngine.withdrawEth.selector + _selector != IEngine.modifyCollateral.selector + && _selector != IEngine.commitOrder.selector + && _selector != IEngine.invalidateUnorderedNonces.selector + && _selector != EIP7412.fulfillOracleQuery.selector + && _selector != IEngine.depositEth.selector + && _selector != IEngine.withdrawEth.selector ) { revert InvalidSMv3Selector(); } + } - /// @dev ensure call value is zero unless calling IEngine.depositEth or EIP7412.fulfillOracleQuery + /// @notice sanitize the call value to ensure it is zero unless calling + /// IEngine.depositEth or EIP7412.fulfillOracleQuery + /// @param _selector the selector to sanitize + /// @dev will revert if the call value is not valid + function _sanitizeCallValue(bytes4 _selector, uint256 _callValue) + internal + pure + { if ( - funcSelector == IEngine.depositEth.selector - || funcSelector == EIP7412.fulfillOracleQuery.selector + _selector == IEngine.depositEth.selector + || _selector == EIP7412.fulfillOracleQuery.selector ) { - if (callValue == 0) { + if (_callValue == 0) { revert InvalidCallValue(); } - } else if (callValue != 0) { + } else if (_callValue != 0) { revert InvalidCallValue(); } - - /// @dev this method of signature validation is out-of-date - /// see https://github.com/OpenZeppelin/openzeppelin-sdk/blob/7d96de7248ae2e7e81a743513ccc617a2e6bba21/packages/lib/contracts/cryptography/ECDSA.sol#L6 - return ECDSA.recover( - ECDSA.toEthSignedMessageHash(_userOpHash), _sessionKeySignature - ) == sessionKey; } }