diff --git a/contracts/Orchestrator.sol b/contracts/Orchestrator.sol index 4ef4b160..7e0bdc68 100644 --- a/contracts/Orchestrator.sol +++ b/contracts/Orchestrator.sol @@ -22,6 +22,8 @@ contract Orchestrator is Ownable { // Stable ordering is not guaranteed. Transaction[] public transactions; + address[] private whitelist; + IUFragmentsPolicy public policy; /** @@ -42,7 +44,24 @@ contract Orchestrator is Ownable { */ function rebase() external { require(msg.sender == tx.origin); // solhint-disable-line avoid-tx-origin + _rebase(); + } + /** + * @notice Contract account entry point to initiate rebase operation. + * The + */ + function rebaseFromContract() external { + for (uint256 i = 0; i < whitelist.length; i++) { + if (whitelist[i] == msg.sender) { + _rebase(); + return; + } + } + revert("rebaseFromContract called from non-whitelisted account"); + } + + function _rebase() private { policy.rebase(); for (uint256 i = 0; i < transactions.length; i++) { @@ -94,4 +113,32 @@ contract Orchestrator is Ownable { function transactionsSize() external view returns (uint256) { return transactions.length; } + + /** + * @param account Address of contract to whitelist. + */ + function whitelistAccount(address account) external onlyOwner { + for (uint256 i = 0; i < whitelist.length; i++) { + if (whitelist[i] == account) { + return; + } + } + whitelist.push(account); + } + + /** + * @param account Address of contract to remove from whitelist. + */ + function unlistAccount(address account) external onlyOwner { + for (uint256 i = whitelist.length; i > 0; i--) { + if (whitelist[i - 1] == account) { + whitelist[i - 1] = whitelist[whitelist.length - 1]; + whitelist.pop(); + } + } + } + + function getWhitelist() external view returns (address[] memory) { + return whitelist; + } } diff --git a/contracts/mocks/RebaseCallerContract.sol b/contracts/mocks/RebaseCallerContract.sol index d13ba3d0..6b3d62d4 100644 --- a/contracts/mocks/RebaseCallerContract.sol +++ b/contracts/mocks/RebaseCallerContract.sol @@ -10,4 +10,9 @@ contract RebaseCallerContract { // pay back flash loan. return true; } + + function callRebaseFromContract(address orchestrator) public returns (bool) { + Orchestrator(orchestrator).rebaseFromContract(); + return true; + } } diff --git a/test/unit/Orchestrator.ts b/test/unit/Orchestrator.ts index 6744b1e9..966989b9 100644 --- a/test/unit/Orchestrator.ts +++ b/test/unit/Orchestrator.ts @@ -7,6 +7,7 @@ import { TransactionResponse } from '@ethersproject/providers' let orchestrator: Contract, mockPolicy: Contract, mockDownstream: Contract let r: Promise let deployer: Signer, user: Signer +let rebaseCallerContracts: Contract[] async function mockedOrchestrator() { await increaseTime(86400) @@ -26,12 +27,24 @@ async function mockedOrchestrator() { ) .connect(deployer) .deploy() + + const rebaseCallerContracts = [] + for (let i = 0; i < 3; i++) { + const _rebaseCallerContract = await ( + await ethers.getContractFactory('RebaseCallerContract') + ) + .connect(deployer) + .deploy() + rebaseCallerContracts.push(_rebaseCallerContract) + } + return { deployer, user, orchestrator, mockPolicy, mockDownstream, + rebaseCallerContracts, } } @@ -354,4 +367,102 @@ describe('Orchestrator', function () { }) }) }) + + describe('whitelist functionality', async function () { + before('setup rebase caller contracts', async () => { + ;({ rebaseCallerContracts } = await waffle.loadFixture( + mockedOrchestrator, + )) + }) + + it('should return an empty list', async function () { + expect((await orchestrator.getWhitelist()).length).to.eq(0) + }) + + describe('adding to whitelist', async function () { + it('should add contract address to whitelist', async function () { + await orchestrator + .connect(deployer) + .whitelistAccount(rebaseCallerContracts[0].address) + const whitelist = await orchestrator.getWhitelist() + expect(whitelist.length).to.eq(1) + expect(whitelist[0]).to.eq(rebaseCallerContracts[0].address) + }) + + it('should disallow non-owner from adding contract address to whitelist', async function () { + await expect( + orchestrator + .connect(user) + .whitelistAccount(rebaseCallerContracts[1].address), + ).to.be.reverted + const whitelist = await orchestrator.getWhitelist() + expect(whitelist.length).to.eq(1) + expect(whitelist[0]).to.eq(rebaseCallerContracts[0].address) + }) + }) + + describe('rebasing from whitelist', async function () { + it('should allow whitelisted contract entry to rebaseFromContract', async function () { + await expect( + rebaseCallerContracts[0].callRebaseFromContract(orchestrator.address), + ).to.not.be.reverted + }) + + it('should disallow whitelisted contract entry to eoa rebase entrypoint', async function () { + await expect(rebaseCallerContracts[0].callRebase(orchestrator.address)) + .to.be.reverted + }) + + it('should rebase....', async function () { + console.log('todo') + }) + }) + + describe('removing from whitelist', async function () { + it('should disallow non-owner from removing contract address from whitelist', async function () { + await expect( + orchestrator + .connect(user) + .unlistAccount(rebaseCallerContracts[0].address), + ).to.be.reverted + const whitelist = await orchestrator.getWhitelist() + expect(whitelist.length).to.eq(1) + expect(whitelist[0]).to.eq(rebaseCallerContracts[0].address) + }) + + it('should remove address from whitelist', async function () { + await orchestrator + .connect(deployer) + .unlistAccount(rebaseCallerContracts[0].address) + const whitelist = await orchestrator.getWhitelist() + expect(whitelist.length).to.eq(0) + }) + + it('should disallow unlisted contract entry to rebaseFromContract', async function () { + await expect( + rebaseCallerContracts[0].callRebaseFromContract(orchestrator.address), + ).to.be.reverted + }) + }) + + describe('handling longer whitelist', async function () { + it('should add and allow multiple whitelisted contracts entry to rebaseFromContract', async function () { + for (let i = 0; i < rebaseCallerContracts.length; i++) { + await orchestrator + .connect(deployer) + .whitelistAccount(rebaseCallerContracts[i].address) + } + const whitelist = await orchestrator.getWhitelist() + expect(whitelist.length).to.eq(3) + for (let i = 0; i < rebaseCallerContracts.length; i++) { + expect(whitelist[i]).to.eq(rebaseCallerContracts[i].address) + await expect( + rebaseCallerContracts[i].callRebaseFromContract( + orchestrator.address, + ), + ).to.not.be.reverted + } + }) + }) + }) })