Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Diamond refactor #65

Merged
merged 14 commits into from
Oct 9, 2023
120 changes: 105 additions & 15 deletions contracts/diamond/Diamond.sol
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,26 @@ contract Diamond is DiamondStorage {
using EnumerableSet for EnumerableSet.Bytes32Set;
using EnumerableSet for EnumerableSet.AddressSet;

enum FacetAction {
Add,
Replace,
Remove
}

struct Facet {
address facetAddress;
FacetAction action;
bytes4[] functionSelectors;
}

event DiamondCut(Facet[] facets, address initFacet, bytes initData);

/**
* @notice The payable fallback function that delegatecall's the facet with associated selector
*/
// solhint-disable-next-line
fallback() external payable virtual {
address facet_ = getFacetBySelector(msg.sig);
address facet_ = facetAddress(msg.sig);

require(facet_ != address(0), "Diamond: selector is not registered");

Expand All @@ -50,14 +64,46 @@ contract Diamond is DiamondStorage {
}
}

/**
* @notice Add/replace/remove any number of functions and optionally execute a function with delegatecall
* @param facets_ Contains the facet addresses and function selectors
* @param initFacet_ The address of the contract or facet to execute initData_
* @param initData_ A function call, including function selector and arguments initData_ is executed with delegatecall on initFacet_
*/
function _diamondCut(
Facet[] memory facets_,
address initFacet_,
bytes memory initData_
) internal virtual {
for (uint256 i; i < facets_.length; i++) {
bytes4[] memory _functionSelectors = facets_[i].functionSelectors;
address _facetAddress = facets_[i].facetAddress;

FacetAction _action = facets_[i].action;

if (_action == FacetAction.Add) {
_addFacet(_facetAddress, _functionSelectors);
} else if (_action == FacetAction.Remove) {
_removeFacet(_facetAddress, _functionSelectors);
} else {
_updateFacet(_facetAddress, _functionSelectors);
}
}

emit DiamondCut(facets_, initFacet_, initData_);

_initializeDiamondCut(initFacet_, initData_);
}

/**
* @notice The internal function to add facets to a diamond (aka diamondCut())
* @param facet_ the implementation address
* @param selectors_ the function selectors the implementation has
*/
function _addFacet(address facet_, bytes4[] memory selectors_) internal {
function _addFacet(address facet_, bytes4[] memory selectors_) internal virtual {
require(facet_ != address(0), "Diamond: facet cannot be zero address");
require(facet_.isContract(), "Diamond: facet is not a contract");
require(selectors_.length > 0, "Diamond: no selectors provided");
require(selectors_.length != 0, "Diamond: no selectors provided");

DStorage storage _ds = _getDiamondStorage();

Expand All @@ -79,8 +125,9 @@ contract Diamond is DiamondStorage {
* @param facet_ the implementation to be removed. The facet itself will be removed only if there are no selectors left
* @param selectors_ the selectors of that implementation to be removed
*/
function _removeFacet(address facet_, bytes4[] memory selectors_) internal {
require(selectors_.length > 0, "Diamond: no selectors provided");
function _removeFacet(address facet_, bytes4[] memory selectors_) internal virtual {
require(facet_ != address(0), "Diamond: facet cannot be zero address");
require(selectors_.length != 0, "Diamond: no selectors provided");

DStorage storage _ds = _getDiamondStorage();

Expand All @@ -100,18 +147,61 @@ contract Diamond is DiamondStorage {
}

/**
* @notice The internal function to update the facets of the diamond
* @notice The internal function to update the facet selectors of the diamond
* @param facet_ the facet to update
* @param fromSelectors_ the selectors to remove from the facet
* @param toSelectors_ the selectors to add to the facet
* @param selectors_ the selectors of the facet
*/
function _updateFacet(
address facet_,
bytes4[] memory fromSelectors_,
bytes4[] memory toSelectors_
) internal {
_addFacet(facet_, toSelectors_);
_removeFacet(facet_, fromSelectors_);
function _updateFacet(address facet_, bytes4[] memory selectors_) internal virtual {
require(facet_ != address(0), "Diamond: facet cannot be zero address");
require(facet_.isContract(), "Diamond: facet is not a contract");
require(selectors_.length != 0, "Diamond: no selectors provided");

DStorage storage _ds = _getDiamondStorage();

for (uint256 i; i < selectors_.length; i++) {
bytes4 selector_ = selectors_[i];
address oldFacet_ = facetAddress(selector_);

require(oldFacet_ != facet_, "Diamond: cannot replace to the same facet");
require(oldFacet_ != address(0), "Diamond: no facet found for selector");

// replace old facet address
_ds.selectorToFacet[selector_] = facet_;
_ds.facetToSelectors[facet_].add(bytes32(selector_));

// remove old facet address
_ds.facetToSelectors[oldFacet_].remove(bytes32(selector_));

if (_ds.facetToSelectors[oldFacet_].length() == 0) {
_ds.facets.remove(oldFacet_);
}
}

_ds.facets.add(facet_);
}

/**
* @notice The internal function to initialize the diamond cut.
* @param initFacet_ the address of the contract or facet to execute initData_
* @param initData_ a function call, including function selector and arguments, to be executed with delegatecall on initFacet_
*/
function _initializeDiamondCut(address initFacet_, bytes memory initData_) internal virtual {
if (initFacet_ == address(0)) {
return;
}

require(initFacet_.isContract(), "Diamond: init_ address has no code");

(bool success_, bytes memory err_) = initFacet_.delegatecall(initData_);

if (!success_) {
require(err_.length > 0, "Diamond: initialization function reverted");
// bubble up error
// @solidity memory-safe-assembly
assembly {
revert(add(32, err_), mload(err_))
}
}
}

function _beforeFallback(address facet_, bytes4 selector_) internal virtual {}
Expand Down
36 changes: 30 additions & 6 deletions contracts/diamond/DiamondStorage.sol
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ abstract contract DiamondStorage {
EnumerableSet.AddressSet facets;
}

struct FacetInfo {
address facetAddress;
bytes4[] functionSelectors;
}

/**
* @notice The internal function to get the diamond proxy storage
* @return _ds the struct from the DIAMOND_STORAGE_SLOT
Expand All @@ -39,19 +44,30 @@ abstract contract DiamondStorage {
}

/**
* @notice The function to get all the facets of this diamond
* @return facets_ the array of facets' addresses
* @notice The function to get all the facets and their selectors
* @return facets_ the array of FacetInfo
*/
function getFacets() public view returns (address[] memory facets_) {
return _getDiamondStorage().facets.values();
function facets() public view returns (FacetInfo[] memory facets_) {
EnumerableSet.AddressSet storage _facets = _getDiamondStorage().facets;

facets_ = new FacetInfo[](_facets.length());

for (uint256 i = 0; i < facets_.length; i++) {
address facet_ = _facets.at(i);

facets_[i].facetAddress = facet_;
facets_[i].functionSelectors = facetFunctionSelectors(facet_);
}
}

/**
* @notice The function to get all the selectors assigned to the facet
* @param facet_ the facet to get assigned selectors of
* @return selectors_ the array of assigned selectors
*/
function getFacetSelectors(address facet_) public view returns (bytes4[] memory selectors_) {
function facetFunctionSelectors(
address facet_
) public view returns (bytes4[] memory selectors_) {
EnumerableSet.Bytes32Set storage _f2s = _getDiamondStorage().facetToSelectors[facet_];

selectors_ = new bytes4[](_f2s.length());
Expand All @@ -61,12 +77,20 @@ abstract contract DiamondStorage {
}
}

/**
* @notice The function to get all the facets of this diamond
* @return facets_ the array of facets' addresses
*/
function facetAddresses() public view returns (address[] memory facets_) {
return _getDiamondStorage().facets.values();
}

/**
* @notice The function to get associated facet by the selector
* @param selector_ the selector
* @return facet_ the associated facet address
*/
function getFacetBySelector(bytes4 selector_) public view returns (address facet_) {
function facetAddress(bytes4 selector_) public view returns (address facet_) {
return _getDiamondStorage().selectorToFacet[selector_];
}
}
19 changes: 19 additions & 0 deletions contracts/diamond/introspection/DiamondERC165.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.4;

import {ERC165} from "@openzeppelin/contracts/utils/introspection/ERC165.sol";

/**
* @notice DiamondERC165 - Contract implementing ERC165 interface for Diamonds
*/
contract DiamondERC165 is ERC165 {
function supportsInterface(bytes4 interfaceId) public view virtual override returns (bool) {
// This section of code provides support for the Diamond Loupe and Diamond Cut interfaces.
// Diamond Loupe interface is defined as: 0x48e2b093
// Diamond Cut interface is defined as: 0x1f931c1c
return
interfaceId == 0x1f931c1c ||
interfaceId == 0x48e2b093 ||
super.supportsInterface(interfaceId);
}
}
20 changes: 8 additions & 12 deletions contracts/diamond/presets/OwnableDiamond/OwnableDiamond.sol
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,15 @@ contract OwnableDiamond is Diamond, OwnableDiamondStorage {
_getOwnableDiamondStorage().owner = newOwner_;
}

function addFacet(address facet_, bytes4[] memory selectors_) public virtual onlyOwner {
_addFacet(facet_, selectors_);
function diamondCut(Facet[] memory facets_) public onlyOwner {
diamondCut(facets_, address(0), "");
}

function removeFacet(address facet_, bytes4[] memory selectors_) public virtual onlyOwner {
_removeFacet(facet_, selectors_);
}

function updateFacet(
address facet_,
bytes4[] memory fromSelectors_,
bytes4[] memory toSelectors_
) public virtual onlyOwner {
_updateFacet(facet_, fromSelectors_, toSelectors_);
function diamondCut(
Facet[] memory facets_,
address init_,
bytes memory initData_
) public onlyOwner {
_diamondCut(facets_, init_, initData_);
}
}
2 changes: 1 addition & 1 deletion contracts/mock/diamond/DummyFacet.sol
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ pragma solidity ^0.8.4;
import {DummyStorage} from "./DummyStorage.sol";

contract DummyFacet is DummyStorage {
function setDummyString(string calldata dummyString_) external {
function setDummyString(string memory dummyString_) public {
getDummyFacetStorage().dummyString = dummyString_;
}

Expand Down
21 changes: 21 additions & 0 deletions contracts/mock/diamond/DummyInit.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.4;

import {DummyFacet} from "./DummyFacet.sol";

contract DummyInit is DummyFacet {
event Initialized();

function init() external {
setDummyString("dummy facet initialized");
emit Initialized();
}

function initWithError() external pure {
revert();
}

function initWithErrorMsg() external pure {
revert("DiamondInit: init error");
}
}
14 changes: 14 additions & 0 deletions contracts/mock/diamond/OwnableDiamondMock.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.4;

import {OwnableDiamond} from "../../diamond/presets/OwnableDiamond/OwnableDiamond.sol";

contract OwnableDiamondMock is OwnableDiamond {
function diamondCutShort(Facet[] memory facets_) public {
diamondCut(facets_);
}

function diamondCutLong(Facet[] memory facets_, address init_, bytes memory calldata_) public {
diamondCut(facets_, init_, calldata_);
}
}
2 changes: 1 addition & 1 deletion contracts/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@solarity/solidity-lib",
"version": "2.5.11",
"version": "2.5.12",
"license": "MIT",
"author": "Distributed Lab",
"readme": "README.md",
Expand Down
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@solarity/solidity-lib",
"version": "2.5.11",
"version": "2.5.12",
"license": "MIT",
"author": "Distributed Lab",
"description": "Solidity Library by Distributed Lab",
Expand Down
Loading