Skip to content

Commit

Permalink
Merge pull request #220 from zama-ai/counterInEvent
Browse files Browse the repository at this point in the history
feat: requestID is generated by dApp
  • Loading branch information
jatZama authored Dec 27, 2024
2 parents 2224167 + 09bbe79 commit cdf2525
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 20 deletions.
29 changes: 23 additions & 6 deletions contracts/decryptionOracle/DecryptionOracle.sol
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,20 @@ contract DecryptionOracle is UUPSUpgradeable, Ownable2StepUpgradeable {
/// @notice Patch version of the contract.
uint256 private constant PATCH_VERSION = 0;

event DecryptionRequest(uint256 indexed requestID, uint256[] cts, address contractCaller, bytes4 callbackSelector);
/**
* @dev Event emitted during each decryption request. The off-chain gateway service is listening to it.
*/
event DecryptionRequest(
uint256 indexed counter,
uint256 requestID,
uint256[] cts,
address contractCaller,
bytes4 callbackSelector
);

/**
* @dev Should revert when msg.sender is not authorized to upgrade the contract.
*/
function _authorizeUpgrade(address _newImplementation) internal virtual override onlyOwner {}

/// @custom:storage-location erc7201:fhevm.storage.DecryptionOracle
Expand Down Expand Up @@ -51,22 +63,27 @@ contract DecryptionOracle is UUPSUpgradeable, Ownable2StepUpgradeable {
_disableInitializers();
}

function initialize(address _decryptionOracleOwner) external initializer {
__Ownable_init(_decryptionOracleOwner);
/**
* @notice Initializes the contract.
* @param initialOwner Initial owner address.
*/
function initialize(address initialOwner) external initializer {
__Ownable_init(initialOwner);
}

/** @notice Requests the decryption of n ciphertexts `ctsHandles` with the result returned in a callback.
* @notice During callback, msg.sender is called with [callbackSelector,requestID,decrypt(ctsHandles[0]),decrypt(ctsHandles[1]),...,decrypt(ctsHandles[n-1]),signatures]
* @param requestID is the request index generated by the dApp requesting the decryption.
* @param ctsHandles is an array of uint256s handles.
* @param callbackSelector the callback selector to be called on msg.sender later during fulfilment
*/
function requestDecryption(
uint256 requestID,
uint256[] calldata ctsHandles,
bytes4 callbackSelector
) external virtual returns (uint256 requestID) {
) external virtual {
DecryptionOracleStorage storage $ = _getDecryptionOracleStorage();
requestID = uint256(keccak256(abi.encodePacked(msg.sender, $.counter)));
emit DecryptionRequest(requestID, ctsHandles, msg.sender, callbackSelector);
emit DecryptionRequest($.counter, requestID, ctsHandles, msg.sender, callbackSelector);
$.counter++;
}

Expand Down
11 changes: 6 additions & 5 deletions contracts/decryptionOracleLib/DecryptionOracleCaller.sol
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ interface IKMSVerifier {
}

interface IDecryptionOracle {
function requestDecryption(uint256[] calldata ctsHandles, bytes4 callbackSelector) external returns (uint256);
function requestDecryption(uint256 requestID, uint256[] calldata ctsHandles, bytes4 callbackSelector) external;
}

struct DecryptionOracleConfigStruct {
Expand All @@ -28,6 +28,8 @@ abstract contract DecryptionOracleCaller {
error InvalidKMSSignatures();
error UnsupportedHandleType();

uint256 internal counterRequest;
mapping(uint256 => uint256[]) private requestedHandles;
mapping(uint256 => ebool[]) private paramsEBool;
mapping(uint256 => euint4[]) private paramsEUint4;
mapping(uint256 => euint8[]) private paramsEUint8;
Expand All @@ -37,9 +39,6 @@ abstract contract DecryptionOracleCaller {
mapping(uint256 => eaddress[]) private paramsEAddress;
mapping(uint256 => address[]) private paramsAddress;
mapping(uint256 => uint256[]) private paramsUint256;
mapping(uint256 => uint256[]) private requestedHandles;

constructor() {}

function addParamsEBool(uint256 requestID, ebool _ebool) internal {
paramsEBool[requestID].push(_ebool);
Expand Down Expand Up @@ -199,11 +198,13 @@ abstract contract DecryptionOracleCaller {
uint256[] memory ctsHandles,
bytes4 callbackSelector
) internal returns (uint256 requestID) {
requestID = counterRequest;
FHEVMConfig.FHEVMConfigStruct storage $ = Impl.getFHEVMConfig();
IACL($.ACLAddress).allowForDecryption(ctsHandles);
DecryptionOracleConfigStruct storage $$ = getDecryptionOracleConfig();
requestID = IDecryptionOracle($$.DecryptionOracleAddress).requestDecryption(ctsHandles, callbackSelector);
IDecryptionOracle($$.DecryptionOracleAddress).requestDecryption(requestID, ctsHandles, callbackSelector);
saveRequestedHandles(requestID, ctsHandles);
counterRequest++;
}

/// @dev this function should be called inside the callback function the dApp contract to verify the signatures
Expand Down
24 changes: 15 additions & 9 deletions contracts/test/asyncDecrypt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ if (networkName === 'hardhat') {
relayer = new ethers.Wallet(privKeyRelayer!, ethers.provider);
}

const argEvents = '(uint256 indexed requestID, uint256[] cts, address contractCaller, bytes4 callbackSelector)';
const argEvents =
'(uint256 indexed counter, uint256 requestID, uint256[] cts, address contractCaller, bytes4 callbackSelector)';
const ifaceEventDecryption = new ethers.Interface(['event DecryptionRequest' + argEvents]);

let decryptionOracle: DecryptionOracle;
Expand All @@ -55,10 +56,15 @@ export const initDecryptionOracle = async (): Promise<void> => {
}
// this function will emit logs for every request and fulfilment of a decryption
decryptionOracle = await ethers.getContractAt('DecryptionOracle', parsedEnv.DECRYPTION_ORACLE_ADDRESS);
decryptionOracle.on('DecryptionRequest', async (requestID, cts, contractCaller, callbackSelector, eventData) => {
const blockNumber = eventData.log.blockNumber;
console.log(`${await currentTime()} - Requested decrypt on block ${blockNumber} (requestID ${requestID})`);
});
decryptionOracle.on(
'DecryptionRequest',
async (counter, requestID, cts, contractCaller, callbackSelector, eventData) => {
const blockNumber = eventData.log.blockNumber;
console.log(
`${await currentTime()} - Requested decrypt on block ${blockNumber} (counter ${counter} - requestID ${requestID})`,
);
},
);
};

export const awaitAllDecryptionResults = async (): Promise<void> => {
Expand Down Expand Up @@ -92,10 +98,10 @@ const fulfillAllPastRequestsIds = async (mocked: boolean) => {
const pastRequests = await ethers.provider.getLogs(filterDecryption);
for (const request of pastRequests) {
const event = ifaceEventDecryption.parseLog(request);
const requestID = event.args[0];
const handles = event.args[1];
const contractCaller = event.args[2];
const callbackSelector = event.args[3];
const requestID = event.args[1];
const handles = event.args[2];
const contractCaller = event.args[3];
const callbackSelector = event.args[4];
const typesList = handles.map((handle) => parseInt(handle.toString(16).slice(-4, -2), 16));
// if request is not already fulfilled
if (mocked && !toSkip.includes(requestID)) {
Expand Down

0 comments on commit cdf2525

Please sign in to comment.