From fdbbb0701951c490ec6cad628349252be9beb207 Mon Sep 17 00:00:00 2001 From: antazoey Date: Fri, 13 Dec 2024 01:38:48 +0700 Subject: [PATCH] feat: include proxy ABIs in contract-type ABIs (#2413) --- src/ape/api/networks.py | 37 +++++++++ src/ape/managers/chain.py | 76 +++++++++++++++---- src/ape_ethereum/ecosystem.py | 10 +-- src/ape_ethereum/proxies.py | 11 ++- tests/functional/geth/conftest.py | 6 ++ tests/functional/geth/test_contracts_cache.py | 55 ++++++++++++++ tests/functional/geth/test_proxy.py | 24 ++++-- tests/functional/test_accounts.py | 24 ++++-- tests/functional/test_contract_container.py | 25 ++++-- tests/functional/test_contract_instance.py | 4 +- tests/functional/test_proxy.py | 5 +- 11 files changed, 235 insertions(+), 42 deletions(-) create mode 100644 tests/functional/geth/test_contracts_cache.py diff --git a/src/ape/api/networks.py b/src/ape/api/networks.py index f83510086f..b8938152da 100644 --- a/src/ape/api/networks.py +++ b/src/ape/api/networks.py @@ -65,6 +65,43 @@ class ProxyInfoAPI(BaseModel): target: AddressType """The address of the implementation contract.""" + type_name: str = "" + + @model_validator(mode="before") + @classmethod + def _validate_type_name(cls, model): + if "type_name" in model: + return model + + elif _type := model.get("type"): + # Attempt to figure out the type name. + if name := getattr(_type, "name", None): + # ProxyEnum - such as from 'ape-ethereum'. + model["type_name"] = name + else: + # Not sure. + try: + model["type_name"] = f"{_type}" + except Exception: + pass + + return model + + @log_instead_of_fail(default="") + def __repr__(self) -> str: + if _type := self.type_name: + return f"" + + return " Optional["MethodABI"]: + """ + Some proxies have special ABIs which may not exist in their + contract-types by default, such as Safe's ``masterCopy()``. + """ + return None + class EcosystemAPI(ExtraAttributesMixin, BaseInterfaceModel): """ diff --git a/src/ape/managers/chain.py b/src/ape/managers/chain.py index be87fb5ed4..0ba2379cad 100644 --- a/src/ape/managers/chain.py +++ b/src/ape/managers/chain.py @@ -796,28 +796,40 @@ def cache_deployment(self, contract_instance: ContractInstance): contract_instance (:class:`~ape.contracts.base.ContractInstance`): The contract to cache. """ - address = contract_instance.address - contract_type = contract_instance.contract_type + contract_type = contract_instance.contract_type # may be a proxy # Cache contract type in memory before proxy check, # in case it is needed somewhere. It may get overridden. self._local_contract_types[address] = contract_type - proxy_info = self.provider.network.ecosystem.get_proxy_info(address) - if proxy_info: + if proxy_info := self.provider.network.ecosystem.get_proxy_info(address): + # The user is caching a deployment of a proxy with the target already set. self.cache_proxy_info(address, proxy_info) - contract_type = self.get(proxy_info.target) or contract_type - if contract_type: + if implementation_contract := self.get(proxy_info.target): + updated_proxy_contract = _get_combined_contract_type( + contract_type, proxy_info, implementation_contract + ) + self._cache_contract_type(address, updated_proxy_contract) + + # Use this contract type in the user's contract instance. + contract_instance.contract_type = updated_proxy_contract + + else: + # No implementation yet. Just cache proxy. self._cache_contract_type(address, contract_type) - return + else: + # Regular contract. Cache normally. + self._cache_contract_type(address, contract_type) + # Cache the deployment now. txn_hash = contract_instance.txn_hash - self._cache_contract_type(address, contract_type) if contract_type.name: self._cache_deployment(address, contract_type, txn_hash) + return contract_type + def cache_proxy_info(self, address: AddressType, proxy_info: ProxyInfoAPI): """ Cache proxy info for a particular address, useful for plugins adding already @@ -1058,7 +1070,6 @@ def get( Optional[ContractType]: The contract type if it was able to get one, otherwise the default parameter. """ - try: address_key: AddressType = self.conversion_manager.convert(address, AddressType) except ConversionError: @@ -1085,19 +1096,35 @@ def get( return default if not (contract_type := self._get_contract_type_from_disk(address_key)): - # Contract could be a minimal proxy + # Contract is not cached yet. Check broader sources, such as an explorer. + # First, detect if this is a proxy. proxy_info = self._local_proxies.get(address_key) or self._get_proxy_info_from_disk( address_key ) - if not proxy_info: proxy_info = self.provider.network.ecosystem.get_proxy_info(address_key) if proxy_info and self._is_live_network: self._cache_proxy_info_to_disk(address_key, proxy_info) if proxy_info: + # Contract is a proxy. self._local_proxies[address_key] = proxy_info - return self.get(proxy_info.target, default=default) + implementation_contract_type = self.get(proxy_info.target, default=default) + proxy_contract_type = ( + self._get_contract_type_from_explorer(address_key) + if fetch_from_explorer + else None + ) + if proxy_contract_type: + contract_type_to_cache = _get_combined_contract_type( + proxy_contract_type, proxy_info, implementation_contract_type + ) + else: + contract_type_to_cache = implementation_contract_type + + self._local_contract_types[address_key] = contract_type_to_cache + self._cache_contract_to_disk(address_key, contract_type_to_cache) + return contract_type_to_cache if not self.provider.get_code(address_key): if default: @@ -1271,6 +1298,7 @@ def instance_from_receipt( Args: receipt (:class:`~ape.api.transactions.ReceiptAPI`): The receipt. + contract_type (ContractType): The deployed contract type. Returns: :class:`~ape.contracts.base.ContractInstance` @@ -1347,14 +1375,14 @@ def _get_proxy_info_from_disk(self, address: AddressType) -> Optional[ProxyInfoA if not address_file.is_file(): return None - return ProxyInfoAPI.model_validate_json(address_file.read_text()) + return ProxyInfoAPI.model_validate_json(address_file.read_text(encoding="utf8")) def _get_blueprint_from_disk(self, blueprint_id: str) -> Optional[ContractType]: contract_file = self._blueprint_cache / f"{blueprint_id}.json" if not contract_file.is_file(): return None - return ContractType.model_validate_json(contract_file.read_text()) + return ContractType.model_validate_json(contract_file.read_text(encoding="utf8")) def _get_contract_type_from_explorer(self, address: AddressType) -> Optional[ContractType]: if not self._network.explorer: @@ -1766,3 +1794,23 @@ def get_receipt(self, transaction_hash: str) -> ReceiptAPI: raise TransactionNotFoundError(transaction_hash=transaction_hash) return receipt + + +def _get_combined_contract_type( + proxy_contract_type: ContractType, + proxy_info: ProxyInfoAPI, + implementation_contract_type: ContractType, +) -> ContractType: + proxy_abis = [ + abi for abi in proxy_contract_type.abi if abi.type in ("error", "event", "function") + ] + + # Include "hidden" ABIs, such as Safe's `masterCopy()`. + if proxy_info.abi and proxy_info.abi.signature not in [ + abi.signature for abi in implementation_contract_type.abi + ]: + proxy_abis.append(proxy_info.abi) + + combined_contract_type = implementation_contract_type.model_copy(deep=True) + combined_contract_type.abi.extend(proxy_abis) + return combined_contract_type diff --git a/src/ape_ethereum/ecosystem.py b/src/ape_ethereum/ecosystem.py index 50bf9cee65..165aef43b4 100644 --- a/src/ape_ethereum/ecosystem.py +++ b/src/ape_ethereum/ecosystem.py @@ -462,8 +462,7 @@ def get_proxy_info(self, address: AddressType) -> Optional[ProxyInfo]: if isinstance(contract_code, bytes): contract_code = to_hex(contract_code) - code = contract_code[2:] - if not code: + if not (code := contract_code[2:]): return None patterns = { @@ -515,7 +514,7 @@ def str_to_slot(text): if _type == ProxyType.Beacon: target = ContractCall(IMPLEMENTATION_ABI, target)(skip_trace=True) - return ProxyInfo(type=_type, target=target) + return ProxyInfo(type=_type, target=target, abi=IMPLEMENTATION_ABI) # safe >=1.1.0 provides `masterCopy()`, which is also stored in slot 0 # call it and check that target matches @@ -525,7 +524,8 @@ def str_to_slot(text): target = self.conversion_manager.convert(slot_0[-20:], AddressType) # NOTE: `target` is set in initialized proxies if target != ZERO_ADDRESS and target == singleton: - return ProxyInfo(type=ProxyType.GnosisSafe, target=target) + return ProxyInfo(type=ProxyType.GnosisSafe, target=target, abi=MASTER_COPY_ABI) + except ApeException: pass @@ -541,7 +541,7 @@ def str_to_slot(text): target = ContractCall(IMPLEMENTATION_ABI, address)(skip_trace=True) # avoid recursion if target != ZERO_ADDRESS: - return ProxyInfo(type=ProxyType.Delegate, target=target) + return ProxyInfo(type=ProxyType.Delegate, target=target, abi=IMPLEMENTATION_ABI) except (ApeException, ValueError): pass diff --git a/src/ape_ethereum/proxies.py b/src/ape_ethereum/proxies.py index 387b7828db..d3205e05a8 100644 --- a/src/ape_ethereum/proxies.py +++ b/src/ape_ethereum/proxies.py @@ -1,5 +1,5 @@ from enum import IntEnum, auto -from typing import cast +from typing import Optional, cast from eth_pydantic_types.hex import HexStr from ethpm_types import ContractType, MethodABI @@ -69,6 +69,15 @@ class ProxyType(IntEnum): class ProxyInfo(ProxyInfoAPI): type: ProxyType + def __init__(self, **kwargs): + abi = kwargs.pop("abi", None) + super().__init__(**kwargs) + self._abi = abi + + @property + def abi(self) -> Optional[MethodABI]: + return self._abi + MASTER_COPY_ABI = MethodABI( type="function", diff --git a/tests/functional/geth/conftest.py b/tests/functional/geth/conftest.py index 97bdffb6e5..f7f7e30ad4 100644 --- a/tests/functional/geth/conftest.py +++ b/tests/functional/geth/conftest.py @@ -8,6 +8,12 @@ from tests.functional.data.python import TRACE_RESPONSE +@pytest.fixture(scope="session") +def safe_proxy_container(get_contract_type): + proxy_type = get_contract_type("SafeProxy") + return ContractContainer(proxy_type) + + @pytest.fixture def parity_trace_response(): return TRACE_RESPONSE diff --git a/tests/functional/geth/test_contracts_cache.py b/tests/functional/geth/test_contracts_cache.py new file mode 100644 index 0000000000..1c3cb5a6de --- /dev/null +++ b/tests/functional/geth/test_contracts_cache.py @@ -0,0 +1,55 @@ +import pytest + +from ape.exceptions import ContractNotFoundError +from tests.conftest import geth_process_test + + +@geth_process_test +def test_get_proxy_from_explorer( + mock_explorer, + create_mock_sepolia, + safe_proxy_container, + geth_account, + vyper_contract_container, + geth_provider, + chain, +): + """ + Simulated when you get a contract from Etherscan for the first time + but that contract is a proxy. We expect both proxy and target ABIs + to be cached under the proxy's address. + """ + target_contract = geth_account.deploy(vyper_contract_container, 10011339315) + proxy_contract = geth_account.deploy(safe_proxy_container, target_contract.address) + + # Ensure both of these are not cached so we have to rely on our fake explorer. + del chain.contracts[target_contract.address] + del chain.contracts[proxy_contract.address] + # Sanity check. + with pytest.raises(ContractNotFoundError): + _ = chain.contracts.instance_at(proxy_contract.address) + + def get_contract_type(address, *args, **kwargs): + # Mock etherscan backend. + if address == target_contract.address: + return target_contract.contract_type + elif address == proxy_contract.address: + return proxy_contract.contract_type + + raise ValueError("Fake explorer only knows about proxy and target contracts.") + + with create_mock_sepolia() as network: + # Setup our network to use our fake explorer. + mock_explorer.get_contract_type.side_effect = get_contract_type + network.__dict__["explorer"] = mock_explorer + + # Typical flow: user attempts to get an un-cached contract type from Etherscan. + # That contract may be a proxy, in which case we should get a type + # w/ both proxy ABIs and the target ABIs. + contract_from_explorer = chain.contracts.instance_at(proxy_contract.address) + + # Ensure we can call proxy methods! + assert contract_from_explorer.masterCopy # No attr error! + + # Ensure we can call target methods! + assert contract_from_explorer.myNumber # No attr error! diff --git a/tests/functional/geth/test_proxy.py b/tests/functional/geth/test_proxy.py index 4d39b186ae..f58e20685b 100644 --- a/tests/functional/geth/test_proxy.py +++ b/tests/functional/geth/test_proxy.py @@ -57,20 +57,28 @@ def test_uups_proxy(get_contract_type, geth_contract, owner, ethereum): @geth_process_test -def test_gnosis_safe(get_contract_type, geth_contract, owner, ethereum): - _type = get_contract_type("SafeProxy") - contract = ContractContainer(_type) - +def test_gnosis_safe(safe_proxy_container, geth_contract, owner, ethereum, chain): + # Setup a proxy contract. target = geth_contract.address + proxy_instance = owner.deploy(safe_proxy_container, target) - contract_instance = owner.deploy(contract, target) - - actual = ethereum.get_proxy_info(contract_instance.address) - + # (test) + actual = ethereum.get_proxy_info(proxy_instance.address) assert actual is not None assert actual.type == ProxyType.GnosisSafe assert actual.target == target + # Ensure we can call the proxy-method. + assert proxy_instance.masterCopy() + + # Ensure we can call target methods. + assert isinstance(proxy_instance.myNumber(), int) + + # Ensure this works with new instances. + proxy_instance_ref_2 = chain.contracts.instance_at(proxy_instance.address) + assert proxy_instance_ref_2.masterCopy() + assert isinstance(proxy_instance_ref_2.myNumber(), int) + @geth_process_test def test_openzeppelin(get_contract_type, geth_contract, owner, ethereum, sender): diff --git a/tests/functional/test_accounts.py b/tests/functional/test_accounts.py index 6f023d09d3..7a7ecfe41c 100644 --- a/tests/functional/test_accounts.py +++ b/tests/functional/test_accounts.py @@ -286,16 +286,30 @@ def test_deploy_and_not_publish(owner, contract_container, dummy_live_network, m def test_deploy_proxy(owner, vyper_contract_instance, proxy_contract_container, chain): target = vyper_contract_instance.address proxy = owner.deploy(proxy_contract_container, target) + + # Ensure we can call both proxy and target methods on it. + assert proxy.implementation # No attr err + assert proxy.myNumber # No attr err + + # Ensure was properly cached. assert proxy.address in chain.contracts._local_contract_types assert proxy.address in chain.contracts._local_proxies - actual = chain.contracts._local_proxies[proxy.address] - assert actual.target == target - assert actual.type == ProxyType.Delegate + # Show the cached proxy info is correct. + proxy_info = chain.contracts._local_proxies[proxy.address] + assert proxy_info.target == target + assert proxy_info.type == ProxyType.Delegate + assert proxy_info.abi.name == "implementation" # Show we get the implementation contract type using the proxy address - implementation = chain.contracts.instance_at(proxy.address) - assert implementation.contract_type == vyper_contract_instance.contract_type + re_contract = chain.contracts.instance_at(proxy.address) + assert re_contract.contract_type == proxy.contract_type + + # Show proxy methods are not available on target alone. + target = chain.contracts.instance_at(proxy_info.target) + assert target.myNumber # No attr err + with pytest.raises(AttributeError): + _ = target.implementation def test_deploy_instance(owner, vyper_contract_instance): diff --git a/tests/functional/test_contract_container.py b/tests/functional/test_contract_container.py index 45e74087bc..870b21e5c7 100644 --- a/tests/functional/test_contract_container.py +++ b/tests/functional/test_contract_container.py @@ -100,20 +100,33 @@ def test_deployments(owner, eth_tester_provider, vyper_contract_container): def test_deploy_proxy( - owner, project, vyper_contract_instance, proxy_contract_container, chain, eth_tester_provider + owner, vyper_contract_instance, proxy_contract_container, chain, eth_tester_provider ): target = vyper_contract_instance.address proxy = proxy_contract_container.deploy(target, sender=owner) + + # Ensure we can call both proxy and target methods on it. + assert proxy.implementation # No attr err + assert proxy.myNumber # No attr err + + # Ensure caching works. assert proxy.address in chain.contracts._local_contract_types assert proxy.address in chain.contracts._local_proxies - actual = chain.contracts._local_proxies[proxy.address] - assert actual.target == target - assert actual.type == ProxyType.Delegate + # Show the cached proxy info is correct. + proxy_info = chain.contracts._local_proxies[proxy.address] + assert proxy_info.target == target + assert proxy_info.type == ProxyType.Delegate # Show we get the implementation contract type using the proxy address - implementation = chain.contracts.instance_at(proxy.address) - assert implementation.contract_type == vyper_contract_instance.contract_type + re_contract = chain.contracts.instance_at(proxy.address) + assert re_contract.contract_type == proxy.contract_type + + # Show proxy methods are not available on target alone. + target = chain.contracts.instance_at(proxy_info.target) + assert target.myNumber # No attr err + with pytest.raises(AttributeError): + _ = target.implementation def test_source_path_in_project(project_with_contract): diff --git a/tests/functional/test_contract_instance.py b/tests/functional/test_contract_instance.py index b0e50cb156..e3f5deb562 100644 --- a/tests/functional/test_contract_instance.py +++ b/tests/functional/test_contract_instance.py @@ -819,7 +819,7 @@ def test_get_error_by_signature(error_contract): def test_selector_identifiers(vyper_contract_instance): - assert len(vyper_contract_instance.selector_identifiers.keys()) == 54 + assert len(vyper_contract_instance.selector_identifiers.keys()) >= 54 assert vyper_contract_instance.selector_identifiers["balances(address)"] == "0x27e235e3" assert vyper_contract_instance.selector_identifiers["owner()"] == "0x8da5cb5b" assert ( @@ -829,7 +829,7 @@ def test_selector_identifiers(vyper_contract_instance): def test_identifier_lookup(vyper_contract_instance): - assert len(vyper_contract_instance.identifier_lookup.keys()) == 54 + assert len(vyper_contract_instance.identifier_lookup.keys()) >= 54 assert vyper_contract_instance.identifier_lookup["0x27e235e3"].selector == "balances(address)" assert vyper_contract_instance.identifier_lookup["0x8da5cb5b"].selector == "owner()" assert ( diff --git a/tests/functional/test_proxy.py b/tests/functional/test_proxy.py index 6d4302bd04..3a8bd6bd7f 100644 --- a/tests/functional/test_proxy.py +++ b/tests/functional/test_proxy.py @@ -5,9 +5,12 @@ """ -def test_minimal_proxy(ethereum, minimal_proxy): +def test_minimal_proxy(ethereum, minimal_proxy, chain): actual = ethereum.get_proxy_info(minimal_proxy.address) assert actual is not None assert actual.type == ProxyType.Minimal # It is the placeholder value still. assert actual.target == "0xBEbeBeBEbeBebeBeBEBEbebEBeBeBebeBeBebebe" + # Show getting the contract using the proxy address. + contract = chain.contracts.instance_at(minimal_proxy.address) + assert contract.contract_type.abi == [] # No target ABIs; no proxy ABIs either.