Skip to content

Commit

Permalink
feat: include proxy ABIs in contract-type ABIs (#2413)
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey authored Dec 12, 2024
1 parent 98db2da commit fdbbb07
Show file tree
Hide file tree
Showing 11 changed files with 235 additions and 42 deletions.
37 changes: 37 additions & 0 deletions src/ape/api/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,43 @@ class ProxyInfoAPI(BaseModel):
target: AddressType
"""The address of the implementation contract."""

type_name: str = ""

@model_validator(mode="before")
@classmethod
def _validate_type_name(cls, model):
if "type_name" in model:
return model

elif _type := model.get("type"):
# Attempt to figure out the type name.
if name := getattr(_type, "name", None):
# ProxyEnum - such as from 'ape-ethereum'.
model["type_name"] = name
else:
# Not sure.
try:
model["type_name"] = f"{_type}"
except Exception:
pass

return model

@log_instead_of_fail(default="<ProxyInfoAPI>")
def __repr__(self) -> str:
if _type := self.type_name:
return f"<Proxy {_type} target={self.target}>"

return "<Proxy target={self.target}"

@property
def abi(self) -> Optional["MethodABI"]:
"""
Some proxies have special ABIs which may not exist in their
contract-types by default, such as Safe's ``masterCopy()``.
"""
return None


class EcosystemAPI(ExtraAttributesMixin, BaseInterfaceModel):
"""
Expand Down
76 changes: 62 additions & 14 deletions src/ape/managers/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,28 +796,40 @@ def cache_deployment(self, contract_instance: ContractInstance):
contract_instance (:class:`~ape.contracts.base.ContractInstance`): The contract
to cache.
"""

address = contract_instance.address
contract_type = contract_instance.contract_type
contract_type = contract_instance.contract_type # may be a proxy

# Cache contract type in memory before proxy check,
# in case it is needed somewhere. It may get overridden.
self._local_contract_types[address] = contract_type

proxy_info = self.provider.network.ecosystem.get_proxy_info(address)
if proxy_info:
if proxy_info := self.provider.network.ecosystem.get_proxy_info(address):
# The user is caching a deployment of a proxy with the target already set.
self.cache_proxy_info(address, proxy_info)
contract_type = self.get(proxy_info.target) or contract_type
if contract_type:
if implementation_contract := self.get(proxy_info.target):
updated_proxy_contract = _get_combined_contract_type(
contract_type, proxy_info, implementation_contract
)
self._cache_contract_type(address, updated_proxy_contract)

# Use this contract type in the user's contract instance.
contract_instance.contract_type = updated_proxy_contract

else:
# No implementation yet. Just cache proxy.
self._cache_contract_type(address, contract_type)

return
else:
# Regular contract. Cache normally.
self._cache_contract_type(address, contract_type)

# Cache the deployment now.
txn_hash = contract_instance.txn_hash
self._cache_contract_type(address, contract_type)
if contract_type.name:
self._cache_deployment(address, contract_type, txn_hash)

return contract_type

def cache_proxy_info(self, address: AddressType, proxy_info: ProxyInfoAPI):
"""
Cache proxy info for a particular address, useful for plugins adding already
Expand Down Expand Up @@ -1058,7 +1070,6 @@ def get(
Optional[ContractType]: The contract type if it was able to get one,
otherwise the default parameter.
"""

try:
address_key: AddressType = self.conversion_manager.convert(address, AddressType)
except ConversionError:
Expand All @@ -1085,19 +1096,35 @@ def get(
return default

if not (contract_type := self._get_contract_type_from_disk(address_key)):
# Contract could be a minimal proxy
# Contract is not cached yet. Check broader sources, such as an explorer.
# First, detect if this is a proxy.
proxy_info = self._local_proxies.get(address_key) or self._get_proxy_info_from_disk(
address_key
)

if not proxy_info:
proxy_info = self.provider.network.ecosystem.get_proxy_info(address_key)
if proxy_info and self._is_live_network:
self._cache_proxy_info_to_disk(address_key, proxy_info)

if proxy_info:
# Contract is a proxy.
self._local_proxies[address_key] = proxy_info
return self.get(proxy_info.target, default=default)
implementation_contract_type = self.get(proxy_info.target, default=default)
proxy_contract_type = (
self._get_contract_type_from_explorer(address_key)
if fetch_from_explorer
else None
)
if proxy_contract_type:
contract_type_to_cache = _get_combined_contract_type(
proxy_contract_type, proxy_info, implementation_contract_type
)
else:
contract_type_to_cache = implementation_contract_type

self._local_contract_types[address_key] = contract_type_to_cache
self._cache_contract_to_disk(address_key, contract_type_to_cache)
return contract_type_to_cache

if not self.provider.get_code(address_key):
if default:
Expand Down Expand Up @@ -1271,6 +1298,7 @@ def instance_from_receipt(
Args:
receipt (:class:`~ape.api.transactions.ReceiptAPI`): The receipt.
contract_type (ContractType): The deployed contract type.
Returns:
:class:`~ape.contracts.base.ContractInstance`
Expand Down Expand Up @@ -1347,14 +1375,14 @@ def _get_proxy_info_from_disk(self, address: AddressType) -> Optional[ProxyInfoA
if not address_file.is_file():
return None

return ProxyInfoAPI.model_validate_json(address_file.read_text())
return ProxyInfoAPI.model_validate_json(address_file.read_text(encoding="utf8"))

def _get_blueprint_from_disk(self, blueprint_id: str) -> Optional[ContractType]:
contract_file = self._blueprint_cache / f"{blueprint_id}.json"
if not contract_file.is_file():
return None

return ContractType.model_validate_json(contract_file.read_text())
return ContractType.model_validate_json(contract_file.read_text(encoding="utf8"))

def _get_contract_type_from_explorer(self, address: AddressType) -> Optional[ContractType]:
if not self._network.explorer:
Expand Down Expand Up @@ -1766,3 +1794,23 @@ def get_receipt(self, transaction_hash: str) -> ReceiptAPI:
raise TransactionNotFoundError(transaction_hash=transaction_hash)

return receipt


def _get_combined_contract_type(
proxy_contract_type: ContractType,
proxy_info: ProxyInfoAPI,
implementation_contract_type: ContractType,
) -> ContractType:
proxy_abis = [
abi for abi in proxy_contract_type.abi if abi.type in ("error", "event", "function")
]

# Include "hidden" ABIs, such as Safe's `masterCopy()`.
if proxy_info.abi and proxy_info.abi.signature not in [
abi.signature for abi in implementation_contract_type.abi
]:
proxy_abis.append(proxy_info.abi)

combined_contract_type = implementation_contract_type.model_copy(deep=True)
combined_contract_type.abi.extend(proxy_abis)
return combined_contract_type
10 changes: 5 additions & 5 deletions src/ape_ethereum/ecosystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,8 +462,7 @@ def get_proxy_info(self, address: AddressType) -> Optional[ProxyInfo]:
if isinstance(contract_code, bytes):
contract_code = to_hex(contract_code)

code = contract_code[2:]
if not code:
if not (code := contract_code[2:]):
return None

patterns = {
Expand Down Expand Up @@ -515,7 +514,7 @@ def str_to_slot(text):
if _type == ProxyType.Beacon:
target = ContractCall(IMPLEMENTATION_ABI, target)(skip_trace=True)

return ProxyInfo(type=_type, target=target)
return ProxyInfo(type=_type, target=target, abi=IMPLEMENTATION_ABI)

# safe >=1.1.0 provides `masterCopy()`, which is also stored in slot 0
# call it and check that target matches
Expand All @@ -525,7 +524,8 @@ def str_to_slot(text):
target = self.conversion_manager.convert(slot_0[-20:], AddressType)
# NOTE: `target` is set in initialized proxies
if target != ZERO_ADDRESS and target == singleton:
return ProxyInfo(type=ProxyType.GnosisSafe, target=target)
return ProxyInfo(type=ProxyType.GnosisSafe, target=target, abi=MASTER_COPY_ABI)

except ApeException:
pass

Expand All @@ -541,7 +541,7 @@ def str_to_slot(text):
target = ContractCall(IMPLEMENTATION_ABI, address)(skip_trace=True)
# avoid recursion
if target != ZERO_ADDRESS:
return ProxyInfo(type=ProxyType.Delegate, target=target)
return ProxyInfo(type=ProxyType.Delegate, target=target, abi=IMPLEMENTATION_ABI)

except (ApeException, ValueError):
pass
Expand Down
11 changes: 10 additions & 1 deletion src/ape_ethereum/proxies.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import IntEnum, auto
from typing import cast
from typing import Optional, cast

from eth_pydantic_types.hex import HexStr
from ethpm_types import ContractType, MethodABI
Expand Down Expand Up @@ -69,6 +69,15 @@ class ProxyType(IntEnum):
class ProxyInfo(ProxyInfoAPI):
type: ProxyType

def __init__(self, **kwargs):
abi = kwargs.pop("abi", None)
super().__init__(**kwargs)
self._abi = abi

@property
def abi(self) -> Optional[MethodABI]:
return self._abi


MASTER_COPY_ABI = MethodABI(
type="function",
Expand Down
6 changes: 6 additions & 0 deletions tests/functional/geth/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@
from tests.functional.data.python import TRACE_RESPONSE


@pytest.fixture(scope="session")
def safe_proxy_container(get_contract_type):
proxy_type = get_contract_type("SafeProxy")
return ContractContainer(proxy_type)


@pytest.fixture
def parity_trace_response():
return TRACE_RESPONSE
Expand Down
55 changes: 55 additions & 0 deletions tests/functional/geth/test_contracts_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import pytest

from ape.exceptions import ContractNotFoundError
from tests.conftest import geth_process_test


@geth_process_test
def test_get_proxy_from_explorer(
mock_explorer,
create_mock_sepolia,
safe_proxy_container,
geth_account,
vyper_contract_container,
geth_provider,
chain,
):
"""
Simulated when you get a contract from Etherscan for the first time
but that contract is a proxy. We expect both proxy and target ABIs
to be cached under the proxy's address.
"""
target_contract = geth_account.deploy(vyper_contract_container, 10011339315)
proxy_contract = geth_account.deploy(safe_proxy_container, target_contract.address)

# Ensure both of these are not cached so we have to rely on our fake explorer.
del chain.contracts[target_contract.address]
del chain.contracts[proxy_contract.address]
# Sanity check.
with pytest.raises(ContractNotFoundError):
_ = chain.contracts.instance_at(proxy_contract.address)

def get_contract_type(address, *args, **kwargs):
# Mock etherscan backend.
if address == target_contract.address:
return target_contract.contract_type
elif address == proxy_contract.address:
return proxy_contract.contract_type

raise ValueError("Fake explorer only knows about proxy and target contracts.")

with create_mock_sepolia() as network:
# Setup our network to use our fake explorer.
mock_explorer.get_contract_type.side_effect = get_contract_type
network.__dict__["explorer"] = mock_explorer

# Typical flow: user attempts to get an un-cached contract type from Etherscan.
# That contract may be a proxy, in which case we should get a type
# w/ both proxy ABIs and the target ABIs.
contract_from_explorer = chain.contracts.instance_at(proxy_contract.address)

# Ensure we can call proxy methods!
assert contract_from_explorer.masterCopy # No attr error!

# Ensure we can call target methods!
assert contract_from_explorer.myNumber # No attr error!
24 changes: 16 additions & 8 deletions tests/functional/geth/test_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,20 +57,28 @@ def test_uups_proxy(get_contract_type, geth_contract, owner, ethereum):


@geth_process_test
def test_gnosis_safe(get_contract_type, geth_contract, owner, ethereum):
_type = get_contract_type("SafeProxy")
contract = ContractContainer(_type)

def test_gnosis_safe(safe_proxy_container, geth_contract, owner, ethereum, chain):
# Setup a proxy contract.
target = geth_contract.address
proxy_instance = owner.deploy(safe_proxy_container, target)

contract_instance = owner.deploy(contract, target)

actual = ethereum.get_proxy_info(contract_instance.address)

# (test)
actual = ethereum.get_proxy_info(proxy_instance.address)
assert actual is not None
assert actual.type == ProxyType.GnosisSafe
assert actual.target == target

# Ensure we can call the proxy-method.
assert proxy_instance.masterCopy()

# Ensure we can call target methods.
assert isinstance(proxy_instance.myNumber(), int)

# Ensure this works with new instances.
proxy_instance_ref_2 = chain.contracts.instance_at(proxy_instance.address)
assert proxy_instance_ref_2.masterCopy()
assert isinstance(proxy_instance_ref_2.myNumber(), int)


@geth_process_test
def test_openzeppelin(get_contract_type, geth_contract, owner, ethereum, sender):
Expand Down
24 changes: 19 additions & 5 deletions tests/functional/test_accounts.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,16 +286,30 @@ def test_deploy_and_not_publish(owner, contract_container, dummy_live_network, m
def test_deploy_proxy(owner, vyper_contract_instance, proxy_contract_container, chain):
target = vyper_contract_instance.address
proxy = owner.deploy(proxy_contract_container, target)

# Ensure we can call both proxy and target methods on it.
assert proxy.implementation # No attr err
assert proxy.myNumber # No attr err

# Ensure was properly cached.
assert proxy.address in chain.contracts._local_contract_types
assert proxy.address in chain.contracts._local_proxies

actual = chain.contracts._local_proxies[proxy.address]
assert actual.target == target
assert actual.type == ProxyType.Delegate
# Show the cached proxy info is correct.
proxy_info = chain.contracts._local_proxies[proxy.address]
assert proxy_info.target == target
assert proxy_info.type == ProxyType.Delegate
assert proxy_info.abi.name == "implementation"

# Show we get the implementation contract type using the proxy address
implementation = chain.contracts.instance_at(proxy.address)
assert implementation.contract_type == vyper_contract_instance.contract_type
re_contract = chain.contracts.instance_at(proxy.address)
assert re_contract.contract_type == proxy.contract_type

# Show proxy methods are not available on target alone.
target = chain.contracts.instance_at(proxy_info.target)
assert target.myNumber # No attr err
with pytest.raises(AttributeError):
_ = target.implementation


def test_deploy_instance(owner, vyper_contract_instance):
Expand Down
Loading

0 comments on commit fdbbb07

Please sign in to comment.