diff --git a/contracts/decryptionOracle/DecryptionOracle.sol b/contracts/decryptionOracle/DecryptionOracle.sol index cdca898..d34b203 100644 --- a/contracts/decryptionOracle/DecryptionOracle.sol +++ b/contracts/decryptionOracle/DecryptionOracle.sol @@ -19,8 +19,20 @@ contract DecryptionOracle is UUPSUpgradeable, Ownable2StepUpgradeable { /// @notice Patch version of the contract. uint256 private constant PATCH_VERSION = 0; - event DecryptionRequest(uint256 indexed requestID, uint256[] cts, address contractCaller, bytes4 callbackSelector); + /** + * @dev Event emitted during each decryption request. The off-chain gateway service is listening to it. + */ + event DecryptionRequest( + uint256 indexed counter, + uint256 requestID, + uint256[] cts, + address contractCaller, + bytes4 callbackSelector + ); + /** + * @dev Should revert when msg.sender is not authorized to upgrade the contract. + */ function _authorizeUpgrade(address _newImplementation) internal virtual override onlyOwner {} /// @custom:storage-location erc7201:fhevm.storage.DecryptionOracle @@ -51,22 +63,27 @@ contract DecryptionOracle is UUPSUpgradeable, Ownable2StepUpgradeable { _disableInitializers(); } - function initialize(address _decryptionOracleOwner) external initializer { - __Ownable_init(_decryptionOracleOwner); + /** + * @notice Initializes the contract. + * @param initialOwner Initial owner address. + */ + function initialize(address initialOwner) external initializer { + __Ownable_init(initialOwner); } /** @notice Requests the decryption of n ciphertexts `ctsHandles` with the result returned in a callback. * @notice During callback, msg.sender is called with [callbackSelector,requestID,decrypt(ctsHandles[0]),decrypt(ctsHandles[1]),...,decrypt(ctsHandles[n-1]),signatures] + * @param requestID is the request index generated by the dApp requesting the decryption. * @param ctsHandles is an array of uint256s handles. * @param callbackSelector the callback selector to be called on msg.sender later during fulfilment */ function requestDecryption( + uint256 requestID, uint256[] calldata ctsHandles, bytes4 callbackSelector - ) external virtual returns (uint256 requestID) { + ) external virtual { DecryptionOracleStorage storage $ = _getDecryptionOracleStorage(); - requestID = uint256(keccak256(abi.encodePacked(msg.sender, $.counter))); - emit DecryptionRequest(requestID, ctsHandles, msg.sender, callbackSelector); + emit DecryptionRequest($.counter, requestID, ctsHandles, msg.sender, callbackSelector); $.counter++; } diff --git a/contracts/decryptionOracleLib/DecryptionOracleCaller.sol b/contracts/decryptionOracleLib/DecryptionOracleCaller.sol index bb012ae..a2b537c 100644 --- a/contracts/decryptionOracleLib/DecryptionOracleCaller.sol +++ b/contracts/decryptionOracleLib/DecryptionOracleCaller.sol @@ -15,7 +15,7 @@ interface IKMSVerifier { } interface IDecryptionOracle { - function requestDecryption(uint256[] calldata ctsHandles, bytes4 callbackSelector) external returns (uint256); + function requestDecryption(uint256 requestID, uint256[] calldata ctsHandles, bytes4 callbackSelector) external; } struct DecryptionOracleConfigStruct { @@ -28,6 +28,8 @@ abstract contract DecryptionOracleCaller { error InvalidKMSSignatures(); error UnsupportedHandleType(); + uint256 internal counterRequest; + mapping(uint256 => uint256[]) private requestedHandles; mapping(uint256 => ebool[]) private paramsEBool; mapping(uint256 => euint4[]) private paramsEUint4; mapping(uint256 => euint8[]) private paramsEUint8; @@ -37,9 +39,6 @@ abstract contract DecryptionOracleCaller { mapping(uint256 => eaddress[]) private paramsEAddress; mapping(uint256 => address[]) private paramsAddress; mapping(uint256 => uint256[]) private paramsUint256; - mapping(uint256 => uint256[]) private requestedHandles; - - constructor() {} function addParamsEBool(uint256 requestID, ebool _ebool) internal { paramsEBool[requestID].push(_ebool); @@ -199,11 +198,13 @@ abstract contract DecryptionOracleCaller { uint256[] memory ctsHandles, bytes4 callbackSelector ) internal returns (uint256 requestID) { + requestID = counterRequest; FHEVMConfig.FHEVMConfigStruct storage $ = Impl.getFHEVMConfig(); IACL($.ACLAddress).allowForDecryption(ctsHandles); DecryptionOracleConfigStruct storage $$ = getDecryptionOracleConfig(); - requestID = IDecryptionOracle($$.DecryptionOracleAddress).requestDecryption(ctsHandles, callbackSelector); + IDecryptionOracle($$.DecryptionOracleAddress).requestDecryption(requestID, ctsHandles, callbackSelector); saveRequestedHandles(requestID, ctsHandles); + counterRequest++; } /// @dev this function should be called inside the callback function the dApp contract to verify the signatures diff --git a/contracts/test/asyncDecrypt.ts b/contracts/test/asyncDecrypt.ts index 8f3a460..9288f8c 100644 --- a/contracts/test/asyncDecrypt.ts +++ b/contracts/test/asyncDecrypt.ts @@ -40,7 +40,8 @@ if (networkName === 'hardhat') { relayer = new ethers.Wallet(privKeyRelayer!, ethers.provider); } -const argEvents = '(uint256 indexed requestID, uint256[] cts, address contractCaller, bytes4 callbackSelector)'; +const argEvents = + '(uint256 indexed counter, uint256 requestID, uint256[] cts, address contractCaller, bytes4 callbackSelector)'; const ifaceEventDecryption = new ethers.Interface(['event DecryptionRequest' + argEvents]); let decryptionOracle: DecryptionOracle; @@ -55,10 +56,15 @@ export const initDecryptionOracle = async (): Promise => { } // this function will emit logs for every request and fulfilment of a decryption decryptionOracle = await ethers.getContractAt('DecryptionOracle', parsedEnv.DECRYPTION_ORACLE_ADDRESS); - decryptionOracle.on('DecryptionRequest', async (requestID, cts, contractCaller, callbackSelector, eventData) => { - const blockNumber = eventData.log.blockNumber; - console.log(`${await currentTime()} - Requested decrypt on block ${blockNumber} (requestID ${requestID})`); - }); + decryptionOracle.on( + 'DecryptionRequest', + async (counter, requestID, cts, contractCaller, callbackSelector, eventData) => { + const blockNumber = eventData.log.blockNumber; + console.log( + `${await currentTime()} - Requested decrypt on block ${blockNumber} (counter ${counter} - requestID ${requestID})`, + ); + }, + ); }; export const awaitAllDecryptionResults = async (): Promise => { @@ -92,10 +98,10 @@ const fulfillAllPastRequestsIds = async (mocked: boolean) => { const pastRequests = await ethers.provider.getLogs(filterDecryption); for (const request of pastRequests) { const event = ifaceEventDecryption.parseLog(request); - const requestID = event.args[0]; - const handles = event.args[1]; - const contractCaller = event.args[2]; - const callbackSelector = event.args[3]; + const requestID = event.args[1]; + const handles = event.args[2]; + const contractCaller = event.args[3]; + const callbackSelector = event.args[4]; const typesList = handles.map((handle) => parseInt(handle.toString(16).slice(-4, -2), 16)); // if request is not already fulfilled if (mocked && !toSkip.includes(requestID)) {