diff --git a/docs/userguides/console.md b/docs/userguides/console.md index 65d5fa50d7..2badd72fa4 100644 --- a/docs/userguides/console.md +++ b/docs/userguides/console.md @@ -74,7 +74,7 @@ If you include a function named `ape_init_extras`, it will be executed with the ```python def ape_init_extras(chain): - return {"web3": chain.provider._web3} + return {"web3": chain.provider.web3} ``` Then `web3` will be available to use immediately. diff --git a/docs/userguides/contracts.md b/docs/userguides/contracts.md index d75dd23e73..b22277af81 100644 --- a/docs/userguides/contracts.md +++ b/docs/userguides/contracts.md @@ -224,7 +224,7 @@ In the example above, the bytes value returned contains the method ID selector p Alternatively, you can decode input: ```python -from ethpm_types import HexBytes +from eth_pydantic_types import HexBytes from ape import Contract contract = Contract("0x...") diff --git a/docs/userguides/developing_plugins.md b/docs/userguides/developing_plugins.md index 8d3a364384..906454c513 100644 --- a/docs/userguides/developing_plugins.md +++ b/docs/userguides/developing_plugins.md @@ -39,7 +39,7 @@ class MyProvider(ProviderAPI): _web3: Web3 = None # type: ignore def connect(self): - self._web3 = Web3(HTTPProvider(str("https://localhost:1337"))) + self.cached_web3 = Web3(HTTPProvider(str("https://localhost:1337"))) """Implement rest of abstract methods""" ``` diff --git a/setup.py b/setup.py index cb9e412881..463bdd7fa2 100644 --- a/setup.py +++ b/setup.py @@ -102,7 +102,8 @@ "packaging>=23.0,<24", "pandas>=1.3.0,<2", "pluggy>=1.3,<2", - "pydantic>=1.10.8,<3", + "pydantic>=2.4.0,<3", + "pydantic-settings>=2.0.3,<3", "PyGithub>=1.59,<2", "pytest>=6.0,<8.0", "python-dateutil>=2.8.2,<3", @@ -124,8 +125,8 @@ "web3[tester]>=6.7.0,<7", # ** Dependencies maintained by ApeWorX ** "eip712>=0.2.1,<0.3", - "ethpm-types>=0.5.10,<0.6", - "evm-trace>=0.1.0a23", + "ethpm-types>=0.6.0,<0.7", + "evm-trace>=0.1.0a26", ], entry_points={ "console_scripts": ["ape=ape._cli:cli"], diff --git a/src/ape/_cli.py b/src/ape/_cli.py index f51778e1cd..82c9985bbb 100644 --- a/src/ape/_cli.py +++ b/src/ape/_cli.py @@ -23,7 +23,7 @@ def display_config(ctx, param, value): from ape import project click.echo("# Current configuration") - click.echo(yaml.dump(project.config_manager.dict())) + click.echo(yaml.dump(project.config_manager.model_dump(mode="json"))) ctx.exit() # NOTE: Must exit to bypass running ApeCLI diff --git a/src/ape/_pydantic_compat.py b/src/ape/_pydantic_compat.py deleted file mode 100644 index 16bfe4b3a8..0000000000 --- a/src/ape/_pydantic_compat.py +++ /dev/null @@ -1,47 +0,0 @@ -# support both pydantic v1 and v2 - -try: - from pydantic.v1 import ( # type: ignore - BaseModel, - BaseSettings, - Extra, - Field, - FileUrl, - HttpUrl, - NonNegativeInt, - PositiveInt, - ValidationError, - root_validator, - validator, - ) - from pydantic.v1.dataclasses import dataclass # type: ignore -except ImportError: - from pydantic import ( # type: ignore - BaseModel, - BaseSettings, - Extra, - Field, - FileUrl, - HttpUrl, - NonNegativeInt, - PositiveInt, - ValidationError, - root_validator, - validator, - ) - from pydantic.dataclasses import dataclass # type: ignore - -__all__ = ( - "BaseModel", - "BaseSettings", - "dataclass", - "Extra", - "Field", - "FileUrl", - "HttpUrl", - "NonNegativeInt", - "PositiveInt", - "ValidationError", - "root_validator", - "validator", -) diff --git a/src/ape/api/address.py b/src/ape/api/address.py index 9a14195bac..8f59323e98 100644 --- a/src/ape/api/address.py +++ b/src/ape/api/address.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING, Any, List -from ethpm_types import HexBytes +from eth_pydantic_types import HexBytes from ape.exceptions import ConversionError from ape.types import AddressType, ContractCode diff --git a/src/ape/api/compiler.py b/src/ape/api/compiler.py index 17b1c4f311..73d80ada52 100644 --- a/src/ape/api/compiler.py +++ b/src/ape/api/compiler.py @@ -2,7 +2,8 @@ from pathlib import Path from typing import Dict, Iterator, List, Optional, Set, Tuple -from ethpm_types import ContractType, HexBytes +from eth_pydantic_types import HexBytes +from ethpm_types import ContractType from ethpm_types.source import Content, ContractSource from evm_trace.geth import TraceFrame as EvmTraceFrame from evm_trace.geth import create_call_node_data diff --git a/src/ape/api/config.py b/src/ape/api/config.py index b08a6a8086..fe4264c3ef 100644 --- a/src/ape/api/config.py +++ b/src/ape/api/config.py @@ -1,7 +1,8 @@ from enum import Enum from typing import Any, Dict, Optional, TypeVar -from ape._pydantic_compat import BaseModel, BaseSettings +from pydantic import ConfigDict +from pydantic_settings import BaseSettings T = TypeVar("T") @@ -14,10 +15,6 @@ class ConfigEnum(str, Enum): """ -class ConfigDict(BaseModel): - __root__: dict = {} - - class PluginConfig(BaseSettings): """ A base plugin configuration class. Each plugin that includes @@ -26,7 +23,7 @@ class PluginConfig(BaseSettings): @classmethod def from_overrides(cls, overrides: Dict) -> "PluginConfig": - default_values = cls().dict() + default_values = cls().model_dump() def update(root: Dict, value_map: Dict): for key, val in value_map.items(): @@ -54,9 +51,7 @@ def get(self, key: str, default: Optional[T] = None) -> T: return self.__dict__.get(key, default) -class GenericConfig(PluginConfig): +class GenericConfig(ConfigDict): """ The default class used when no specialized class is used. """ - - __root__: dict = {} diff --git a/src/ape/api/convert.py b/src/ape/api/convert.py index adbb676930..e2048badfb 100644 --- a/src/ape/api/convert.py +++ b/src/ape/api/convert.py @@ -5,7 +5,7 @@ ConvertedType = TypeVar("ConvertedType") -class ConverterAPI(Generic[ConvertedType], BaseInterfaceModel): +class ConverterAPI(BaseInterfaceModel, Generic[ConvertedType]): @abstractmethod def is_convertible(self, value: Any) -> bool: """ diff --git a/src/ape/api/networks.py b/src/ape/api/networks.py index c37386ffd4..63869e8e45 100644 --- a/src/ape/api/networks.py +++ b/src/ape/api/networks.py @@ -19,11 +19,11 @@ encode_transaction, serializable_unsigned_transaction_from_dict, ) +from eth_pydantic_types import HexBytes from eth_utils import keccak, to_int -from ethpm_types import ContractType, HexBytes +from ethpm_types import BaseModel, ContractType from ethpm_types.abi import ABIType, ConstructorABI, EventABI, MethodABI -from ape._pydantic_compat import BaseModel from ape.exceptions import ( NetworkError, NetworkMismatchError, @@ -86,7 +86,8 @@ class EcosystemAPI(BaseInterfaceModel): fee_token_decimals: int = 18 """The number of the decimals the fee token has.""" - _default_network: Optional[str] = None + cached_default_network: Optional[str] = None + """The default network of the ecosystem, such as ``local``.""" def __repr__(self) -> str: return f"<{self.name}>" @@ -151,8 +152,7 @@ def serialize_transaction(self, transaction: "TransactionAPI") -> bytes: if not self.signature: raise SignatureError("The transaction is not signed.") - txn_data = self.dict(exclude={"sender"}) - + txn_data = self.model_dump(exclude={"sender"}) unsigned_txn = serializable_unsigned_transaction_from_dict(txn_data) signature = ( self.signature.v, @@ -267,7 +267,7 @@ def default_network(self) -> str: str """ - if network := self._default_network: + if network := self.cached_default_network: # Was set programatically. return network @@ -284,7 +284,7 @@ def default_network(self) -> str: return self.networks[0] # Very unlikely scenario. - raise ValueError("No networks found.") + raise NetworkError("No networks found.") def set_default_network(self, network_name: str): """ @@ -505,7 +505,7 @@ def get_method_selector(self, abi: MethodABI) -> HexBytes: Override example:: from ape.api import EcosystemAPI - from ethpm_types import HexBytes + from eth_pydantic_types import HexBytes class MyEcosystem(EcosystemAPI): def get_method_selector(self, abi: MethodABI) -> HexBytes: diff --git a/src/ape/api/projects.py b/src/ape/api/projects.py index 297fbcde87..4b63250c64 100644 --- a/src/ape/api/projects.py +++ b/src/ape/api/projects.py @@ -5,12 +5,11 @@ from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union from ethpm_types import Checksum, ContractType, PackageManifest, Source -from ethpm_types.manifest import PackageName from ethpm_types.source import Content -from ethpm_types.utils import Algorithm, AnyUrl, compute_checksum +from ethpm_types.utils import Algorithm, compute_checksum from packaging.version import InvalidVersion, Version +from pydantic import AnyUrl, ValidationError -from ape._pydantic_compat import ValidationError from ape.exceptions import ProjectError from ape.logging import logger from ape.utils import ( @@ -112,7 +111,7 @@ def cached_manifest(self) -> Optional[PackageManifest]: continue path = self._cache_folder / f"{contract_type.name}.json" - path.write_text(contract_type.json()) + path.write_text(contract_type.model_dump_json()) # Rely on individual cache files. self._contracts = manifest.contract_types @@ -147,10 +146,8 @@ def contracts(self) -> Dict[str, ContractType]: continue contract_name = p.stem - contract_type = ContractType.parse_file(p) - if contract_type.name is None: - contract_type.name = contract_name - + contract_type = ContractType.model_validate_json(p.read_text()) + contract_type.name = contract_name if contract_type.name is None else contract_type.name contracts[contract_type.name] = contract_type self._contracts = contracts @@ -182,7 +179,7 @@ def _create_manifest( initial_manifest: Optional[PackageManifest] = None, ) -> PackageManifest: manifest = initial_manifest or PackageManifest() - manifest.name = PackageName(__root__=name.lower()) if name is not None else manifest.name + manifest.name = name.lower() if name is not None else manifest.name manifest.version = version or manifest.version manifest.sources = cls._create_source_dict(source_paths, contracts_path) manifest.contract_types = contract_types @@ -219,7 +216,7 @@ def _create_source_dict( hash=compute_checksum(source_path.read_bytes()), ), urls=[], - content=Content(__root__={i + 1: x for i, x in enumerate(text.splitlines())}), + content=Content(root={i + 1: x for i, x in enumerate(text.splitlines())}), imports=source_imports.get(key, []), references=source_references.get(key, []), ) @@ -470,7 +467,7 @@ def _get_sources(self, project: ProjectAPI) -> List[Path]: def _write_manifest_to_cache(self, manifest: PackageManifest): self._target_manifest_cache_file.unlink(missing_ok=True) self._target_manifest_cache_file.parent.mkdir(exist_ok=True, parents=True) - self._target_manifest_cache_file.write_text(manifest.json()) + self._target_manifest_cache_file.write_text(manifest.model_dump_json()) self._cached_manifest = manifest @@ -479,7 +476,7 @@ def _load_manifest_from_file(file_path: Path) -> Optional[PackageManifest]: return None try: - return PackageManifest.parse_file(file_path) + return PackageManifest.model_validate_json(file_path.read_text()) except ValidationError as err: logger.warning(f"Existing manifest file '{file_path}' corrupted. Re-building.") logger.debug(str(err)) diff --git a/src/ape/api/providers.py b/src/ape/api/providers.py index cb7fe5bfac..78a84d0ae2 100644 --- a/src/ape/api/providers.py +++ b/src/ape/api/providers.py @@ -16,17 +16,17 @@ from subprocess import DEVNULL, PIPE, Popen from typing import Any, Dict, Iterator, List, Optional, Union, cast +from eth_pydantic_types import HexBytes from eth_typing import BlockNumber, HexStr from eth_utils import add_0x_prefix, to_hex -from ethpm_types import HexBytes from evm_trace import CallTreeNode as EvmCallTreeNode from evm_trace import TraceFrame as EvmTraceFrame +from pydantic import Field, model_validator from web3 import Web3 from web3.exceptions import ContractLogicError as Web3ContractLogicError from web3.exceptions import MethodUnavailable, TimeExhausted, TransactionNotFound from web3.types import RPCEndpoint, TxParams -from ape._pydantic_compat import Field, root_validator, validator from ape.api.config import PluginConfig from ape.api.networks import LOCAL_NETWORK_NAME, NetworkAPI from ape.api.query import BlockTransactionQuery @@ -95,20 +95,12 @@ class BlockAPI(BaseInterfaceModel): def datetime(self) -> datetime.datetime: return datetime.datetime.fromtimestamp(self.timestamp, tz=datetime.timezone.utc) - @root_validator(pre=True) + @model_validator(mode="before") def convert_parent_hash(cls, data): parent_hash = data.get("parent_hash", data.get("parentHash")) or EMPTY_BYTES32 data["parentHash"] = parent_hash return data - @validator("hash", "parent_hash", pre=True) - def validate_hexbytes(cls, value): - # NOTE: pydantic treats these values as bytes and throws an error - if value and not isinstance(value, bytes): - return HexBytes(value) - - return value - @cached_property def transactions(self) -> List[TransactionAPI]: query = BlockTransactionQuery(columns=["*"], block_id=self.hash) @@ -134,10 +126,10 @@ class ProviderAPI(BaseInterfaceModel): data_folder: Path """The path to the ``.ape`` directory.""" - request_header: dict + request_header: Dict """A header to set on HTTP/RPC requests.""" - cached_chain_id: Optional[int] = None + cached_chain_id: Optional[int] = Field(None, exclude=True) """Implementation providers may use this to cache and re-use chain ID.""" block_page_size: int = 100 @@ -762,8 +754,8 @@ class Web3Provider(ProviderAPI, ABC): `web3.py `__ python package. """ - _web3: Optional[Web3] = None - _client_version: Optional[str] = None + cached_web3: Optional[Web3] = None + cached_client_version: Optional[str] = None def __init__(self, *args, **kwargs): logger.create_logger("web3.RequestManager", handlers=(_sanitize_web3_url,)) @@ -776,10 +768,10 @@ def web3(self) -> Web3: Access to the ``web3`` object as if you did ``Web3(HTTPProvider(uri))``. """ - if not self._web3: + if not self.cached_web3: raise ProviderNotConnectedError() - return self._web3 + return self.cached_web3 @property def http_uri(self) -> Optional[str]: @@ -809,14 +801,14 @@ def ws_uri(self) -> Optional[str]: @property def client_version(self) -> str: - if not self._web3: + if not self.cached_web3: return "" # NOTE: Gets reset to `None` on `connect()` and `disconnect()`. - if self._client_version is None: - self._client_version = self.web3.client_version + if self.cached_client_version is None: + self.cached_client_version = self.web3.client_version - return self._client_version + return self.cached_client_version @property def base_fee(self) -> int: @@ -857,10 +849,10 @@ def _get_last_base_fee(self) -> int: @property def is_connected(self) -> bool: - if self._web3 is None: + if self.cached_web3 is None: return False - return run_until_complete(self._web3.is_connected()) + return run_until_complete(self.cached_web3.is_connected()) @property def max_gas(self) -> int: @@ -909,7 +901,7 @@ def estimate_gas_cost(self, txn: TransactionAPI, **kwargs) -> int: return the block maximum gas limit. """ - txn_dict = txn.dict() + txn_dict = txn.model_dump(mode="json") # Force the use of hex values to support a wider range of nodes. if isinstance(txn_dict.get("type"), int): @@ -1195,7 +1187,7 @@ def _eth_call(self, arguments: List) -> bytes: return HexBytes(result) def _prepare_call(self, txn: TransactionAPI, **kwargs) -> List: - txn_dict = txn.dict() + txn_dict = txn.model_dump(mode="json") fields_to_convert = ("data", "chainId", "value") for field in fields_to_convert: value = txn_dict.get(field) @@ -1258,7 +1250,9 @@ def get_receipt( except TimeExhausted as err: raise TransactionNotFoundError(txn_hash, error_messsage=str(err)) from err - network_config: Dict = self.network.config.dict().get(self.network.name, {}) + network_config: Dict = self.network.config.model_dump(mode="json").get( + self.network.name, {} + ) max_retries = network_config.get("max_get_transaction_retries", DEFAULT_MAX_RETRIES_TX) txn = {} for attempt in range(max_retries): @@ -1430,10 +1424,10 @@ def get_contract_logs(self, log_filter: LogFilter) -> Iterator[ContractLog]: def fetch_log_page(block_range): start, stop = block_range - page_filter = log_filter.copy(update=dict(start_block=start, stop_block=stop)) + page_filter = log_filter.model_copy(update=dict(start_block=start, stop_block=stop)) # eth-tester expects a different format, let web3 handle the conversions for it raw = "EthereumTester" not in self.client_version - logs = self._get_logs(page_filter.dict(), raw) + logs = self._get_logs(page_filter.model_dump(mode="json"), raw) return self.network.ecosystem.decode_logs(logs, *log_filter.events) with ThreadPoolExecutor(self.concurrency) as pool: @@ -1503,7 +1497,8 @@ def send_transaction(self, txn: TransactionAPI) -> ReceiptAPI: if txn.sender not in self.web3.eth.accounts: self.chain_manager.provider.unlock_account(txn.sender) - txn_hash = self.web3.eth.send_transaction(cast(TxParams, txn.dict())) + txn_data = cast(TxParams, txn.model_dump(mode="json")) + txn_hash = self.web3.eth.send_transaction(cast(TxParams, txn_data)) except (ValueError, Web3ContractLogicError) as err: vm_err = self.get_virtual_machine_error(err, txn=txn) @@ -1522,7 +1517,7 @@ def send_transaction(self, txn: TransactionAPI) -> ReceiptAPI: self.chain_manager.history.append(receipt) if receipt.failed: - txn_dict = receipt.transaction.dict() + txn_dict = receipt.transaction.model_dump(mode="json") txn_params = cast(TxParams, txn_dict) # Replay txn to get revert reason @@ -1556,7 +1551,7 @@ def _create_call_tree_node( inputs=evm_call.calldata if "CREATE" in call_type else evm_call.calldata[4:].hex(), method_id=evm_call.calldata[:4].hex(), outputs=evm_call.returndata.hex(), - raw=evm_call.dict(), + raw=evm_call.model_dump(mode="json"), txn_hash=txn_hash, ) @@ -1579,7 +1574,7 @@ def _create_trace_frame(self, evm_frame: EvmTraceFrame) -> TraceFrame: gas_cost=evm_frame.gas_cost, depth=evm_frame.depth, contract_address=address, - raw=evm_frame.dict(), + raw=evm_frame.model_dump(mode="json"), ) def _make_request(self, endpoint: str, parameters: Optional[List] = None) -> Any: @@ -1809,7 +1804,10 @@ def disconnect(self): Subclasses override this method to do provider-specific disconnection tasks. """ - self.cached_chain_id = None + # NOTE: Setting it this way mostly because of a mypy issue. + default_value = self.model_fields["cached_chain_id"].default + self.cached_chain_id = default_value + if self.process: self.stop() diff --git a/src/ape/api/query.py b/src/ape/api/query.py index 3067d505bd..5b9ae8ced4 100644 --- a/src/ape/api/query.py +++ b/src/ape/api/query.py @@ -1,9 +1,9 @@ from functools import lru_cache from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Type, Union -from ethpm_types.abi import EventABI, MethodABI +from ethpm_types.abi import BaseModel, EventABI, MethodABI +from pydantic import NonNegativeInt, PositiveInt, model_validator -from ape._pydantic_compat import BaseModel, NonNegativeInt, PositiveInt, root_validator from ape.api.transactions import ReceiptAPI, TransactionAPI from ape.logging import logger from ape.types import AddressType @@ -22,7 +22,7 @@ # TODO: Replace with `functools.cache` when Py3.8 dropped @lru_cache(maxsize=None) def _basic_columns(Model: Type[BaseInterfaceModel]) -> Set[str]: - columns = set(Model.__fields__) + columns = set(Model.model_fields) # TODO: Remove once `ReceiptAPI` fields cleaned up for better processing if Model == ReceiptAPI: @@ -104,7 +104,7 @@ class _BaseBlockQuery(_BaseQuery): stop_block: NonNegativeInt step: PositiveInt = 1 - @root_validator(pre=True) + @model_validator(mode="before") def check_start_block_before_stop_block(cls, values): if values["stop_block"] < values["start_block"]: raise ValueError( @@ -141,7 +141,7 @@ class AccountTransactionQuery(_BaseQuery): start_nonce: NonNegativeInt = 0 stop_nonce: NonNegativeInt - @root_validator(pre=True) + @model_validator(mode="before") def check_start_nonce_before_stop_nonce(cls, values: Dict) -> Dict: if values["stop_nonce"] < values["start_nonce"]: raise ValueError( diff --git a/src/ape/api/transactions.py b/src/ape/api/transactions.py index 71b82fe805..94fc4094d6 100644 --- a/src/ape/api/transactions.py +++ b/src/ape/api/transactions.py @@ -3,12 +3,13 @@ from datetime import datetime from typing import IO, TYPE_CHECKING, Any, Iterator, List, Optional, Union +from eth_pydantic_types import HexBytes from eth_utils import is_0x_prefixed, is_hex, to_int -from ethpm_types import HexBytes from ethpm_types.abi import EventABI, MethodABI +from pydantic import ConfigDict, field_validator +from pydantic.fields import Field from tqdm import tqdm # type: ignore -from ape._pydantic_compat import Field, validator from ape.api.explorers import ExplorerAPI from ape.exceptions import ( NetworkError, @@ -53,7 +54,7 @@ class TransactionAPI(BaseInterfaceModel): gas_limit: Optional[int] = Field(None, alias="gas") nonce: Optional[int] = None # NOTE: `Optional` only to denote using default behavior value: int = 0 - data: bytes = b"" + data: HexBytes = HexBytes("") type: int max_fee: Optional[int] = None max_priority_fee: Optional[int] = None @@ -63,10 +64,9 @@ class TransactionAPI(BaseInterfaceModel): signature: Optional[TransactionSignature] = Field(None, exclude=True) - class Config: - allow_population_by_field_name = True + model_config = ConfigDict(populate_by_name=True) - @validator("gas_limit", pre=True, allow_reuse=True) + @field_validator("gas_limit", mode="before") def validate_gas_limit(cls, value): if value is None: if not cls.network_manager.active_provider: @@ -91,21 +91,21 @@ def validate_gas_limit(cls, value): return value - @validator("max_fee", "max_priority_fee", pre=True, allow_reuse=True) + @field_validator("max_fee", "max_priority_fee", mode="before") def convert_fees(cls, value): if isinstance(value, str): return cls.conversion_manager.convert(value, int) return value - @validator("data", pre=True) + @field_validator("data", mode="before") def validate_data(cls, value): if isinstance(value, str): return HexBytes(value) return value - @validator("value", pre=True) + @field_validator("value", mode="before") def validate_value(cls, value): if isinstance(value, int): return value @@ -169,12 +169,12 @@ def serialize_transaction(self) -> bytes: """ def __repr__(self) -> str: - data = self.dict() + data = self.model_dump(mode="json") params = ", ".join(f"{k}={v}" for k, v in data.items()) return f"<{self.__class__.__name__} {params}>" def __str__(self) -> str: - data = self.dict() + data = self.model_dump(mode="json") if len(data["data"]) > 9: # only want to specify encoding if data["data"] is a string if isinstance(data["data"], str): @@ -267,16 +267,16 @@ def __repr__(self) -> str: return f"<{self.__class__.__name__} {self.txn_hash}>" def __ape_extra_attributes__(self) -> Iterator[ExtraModelAttributes]: - yield ExtraModelAttributes(name="transaction", attributes=self.transaction) + yield ExtraModelAttributes(name="transaction", attributes=vars(self.transaction)) - @validator("transaction", pre=True) + @field_validator("transaction", mode="before") def confirm_transaction(cls, value): if isinstance(value, dict): - value = TransactionAPI.parse_obj(value) + value = TransactionAPI.model_validate(value) return value - @validator("txn_hash", pre=True) + @field_validator("txn_hash", mode="before") def validate_txn_hash(cls, value): return HexBytes(value).hex() @@ -472,7 +472,7 @@ def return_value(self) -> Any: @raises_not_implemented def source_traceback(self) -> SourceTraceback: # type: ignore[empty-body] """ - A pythonic style traceback for both failing and non-failing receipts. + A Pythonic style traceback for both failing and non-failing receipts. Requires a provider that implements :meth:~ape.api.providers.ProviderAPI.get_transaction_trace`. """ diff --git a/src/ape/contracts/base.py b/src/ape/contracts/base.py index 8d27f7d26a..e8eb000c3c 100644 --- a/src/ape/contracts/base.py +++ b/src/ape/contracts/base.py @@ -7,7 +7,8 @@ import click import pandas as pd -from ethpm_types import ContractType, HexBytes +from eth_pydantic_types import HexBytes +from ethpm_types import ContractType from ethpm_types.abi import ConstructorABI, ErrorABI, EventABI, MethodABI from ape.api import AccountAPI, Address, ReceiptAPI, TransactionAPI @@ -554,12 +555,8 @@ def query( f"'stop={stop_block}' cannot be greater than " f"the chain length ({self.chain_manager.blocks.height})." ) - - if columns[0] == "*": - columns = list(ContractLog.__fields__) # type: ignore - query: Dict = { - "columns": columns, + "columns": list(ContractLog.model_fields) if columns[0] == "*" else columns, "event": self.abi, "start_block": start_block, "stop_block": stop_block, @@ -631,7 +628,7 @@ def range( addresses = list(set([contract_address] + (extra_addresses or []))) contract_event_query = ContractEventQuery( - columns=list(ContractLog.__fields__.keys()), + columns=list(ContractLog.model_fields.keys()), contract=addresses, event=self.abi, search_topics=search_topics, diff --git a/src/ape/managers/chain.py b/src/ape/managers/chain.py index 42fcaad0d2..527e0ad918 100644 --- a/src/ape/managers/chain.py +++ b/src/ape/managers/chain.py @@ -223,7 +223,7 @@ def range( # Note: the range `stop_block` is a non-inclusive stop, while the # `.query` method uses an inclusive stop, so we must adjust downwards. query = BlockQuery( - columns=list(self.head.__fields__), # TODO: fetch the block fields from EcosystemAPI + columns=list(self.head.model_fields), # TODO: fetch the block fields from EcosystemAPI start_block=start, stop_block=stop - 1, step=step, @@ -516,7 +516,7 @@ def __getitem_int(self, index: int) -> ReceiptAPI: next( self.query_manager.query( AccountTransactionQuery( - columns=list(ReceiptAPI.__fields__), + columns=list(ReceiptAPI.model_fields), account=self.address, start_nonce=index, stop_nonce=index, @@ -554,7 +554,7 @@ def __getitem_slice(self, indices: slice) -> List[ReceiptAPI]: list( self.query_manager.query( AccountTransactionQuery( - columns=list(ReceiptAPI.__fields__), + columns=list(ReceiptAPI.model_fields), account=self.address, start_nonce=start, stop_nonce=stop - 1, @@ -1335,21 +1335,21 @@ def _get_contract_type_from_disk(self, address: AddressType) -> Optional[Contrac if not address_file.is_file(): return None - return ContractType.parse_file(address_file) + return ContractType.model_validate_json(address_file.read_text()) def _get_proxy_info_from_disk(self, address: AddressType) -> Optional[ProxyInfoAPI]: address_file = self._proxy_info_cache / f"{address}.json" if not address_file.is_file(): return None - return ProxyInfoAPI.parse_file(address_file) + return ProxyInfoAPI.model_validate_json(address_file.read_text()) def _get_blueprint_from_disk(self, blueprint_id: str) -> Optional[ContractType]: contract_file = self._blueprint_cache / f"{blueprint_id}.json" if not contract_file.is_file(): return None - return ContractType.parse_file(contract_file) + return ContractType.model_validate_json(contract_file.read_text()) def _get_contract_type_from_explorer(self, address: AddressType) -> Optional[ContractType]: if not self._network.explorer: @@ -1370,17 +1370,17 @@ def _get_contract_type_from_explorer(self, address: AddressType) -> Optional[Con def _cache_contract_to_disk(self, address: AddressType, contract_type: ContractType): self._contract_types_cache.mkdir(exist_ok=True, parents=True) address_file = self._contract_types_cache / f"{address}.json" - address_file.write_text(contract_type.json()) + address_file.write_text(contract_type.model_dump_json()) def _cache_proxy_info_to_disk(self, address: AddressType, proxy_info: ProxyInfoAPI): self._proxy_info_cache.mkdir(exist_ok=True, parents=True) address_file = self._proxy_info_cache / f"{address}.json" - address_file.write_text(proxy_info.json()) + address_file.write_text(proxy_info.model_dump_json()) def _cache_blueprint_to_disk(self, blueprint_id: str, contract_type: ContractType): self._blueprint_cache.mkdir(exist_ok=True, parents=True) blueprint_file = self._blueprint_cache / f"{blueprint_id}.json" - blueprint_file.write_text(contract_type.json()) + blueprint_file.write_text(contract_type.model_dump_json()) def _load_deployments_cache(self) -> Dict: return ( diff --git a/src/ape/managers/compilers.py b/src/ape/managers/compilers.py index ca4b7c557f..cc8cf07d86 100644 --- a/src/ape/managers/compilers.py +++ b/src/ape/managers/compilers.py @@ -62,7 +62,7 @@ def registered_compilers(self) -> Dict[str, CompilerAPI]: for plugin_name, (extensions, compiler_class) in self.plugin_manager.register_compiler: # TODO: Investigate side effects of loading compiler plugins. # See if this needs to be refactored. - self.config_manager.get_config(plugin_name=plugin_name) + self.config_manager.get_config(plugin_name) compiler = compiler_class() @@ -80,7 +80,7 @@ def get_compiler(self, name: str, settings: Optional[Dict] = None) -> Optional[C if settings is not None and settings != compiler.compiler_settings: # Use a new instance to support multiple compilers of same type. - return compiler.copy(update={"compiler_settings": settings}) + return compiler.model_copy(update={"compiler_settings": settings}) return compiler @@ -179,13 +179,15 @@ def compile( ) try: - existing_contract = ContractType.parse_file(existing_artifact) + existing_contract = ContractType.model_validate_json( + existing_artifact.read_text() + ) except Exception: existing_artifact.unlink() else: - if existing_contract.source_id: - path = self.project_manager.lookup_path(existing_contract.source_id) + if existing_id := existing_contract.source_id: + path = self.project_manager.lookup_path(existing_id) if path and existing_contract.source_id != contract_type.source_id: error_message = ( f"{ContractType.__name__} collision '{contract_name}'." @@ -196,6 +198,10 @@ def compile( # Artifact remaining from deleted contract, can delete. existing_artifact.unlink() + else: + # Is probably invalid and will be replaced by an artifact with an ID. + existing_artifact.unlink() + contract_types_dict[contract_name] = contract_type return contract_types_dict diff --git a/src/ape/managers/config.py b/src/ape/managers/config.py index ba10efc5de..c5bd45b7e7 100644 --- a/src/ape/managers/config.py +++ b/src/ape/managers/config.py @@ -3,7 +3,9 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Union -from ape._pydantic_compat import root_validator +from ethpm_types import PackageMeta +from pydantic import RootModel, model_validator + from ape.api import ConfigDict, DependencyAPI, PluginConfig from ape.exceptions import ConfigError from ape.logging import logger @@ -12,7 +14,6 @@ if TYPE_CHECKING: from .project import ProjectManager -from ethpm_types import BaseModel, PackageMeta CONFIG_FILE_NAME = "ape-config.yaml" @@ -27,16 +28,13 @@ class CompilerConfig(PluginConfig): """List of globular files to ignore""" -class DeploymentConfigCollection(BaseModel): - __root__: Dict - - @root_validator(pre=True) - def validate_deployments(cls, data: Dict): - root_data = data.get("__root__", data) - valid_ecosystems = root_data.pop("valid_ecosystems", {}) - valid_networks = root_data.pop("valid_networks", {}) +class DeploymentConfigCollection(RootModel[dict]): + @model_validator(mode="before") + def validate_deployments(cls, data: Dict, info): + valid_ecosystems = data.pop("valid_ecosystems", {}) + valid_networks = data.pop("valid_networks", {}) valid_data: Dict = {} - for ecosystem_name, networks in root_data.items(): + for ecosystem_name, networks in data.items(): if ecosystem_name not in valid_ecosystems: logger.warning(f"Invalid ecosystem '{ecosystem_name}' in deployments config.") continue @@ -69,7 +67,7 @@ def validate_deployments(cls, data: Dict): network_name: valid_deployments, } - return {"__root__": valid_data} + return valid_data class ConfigManager(BaseInterfaceModel): @@ -126,11 +124,11 @@ class ConfigManager(BaseInterfaceModel): default_ecosystem: str = "ethereum" """The default ecosystem to use. Defaults to ``"ethereum"``.""" - _cached_configs: Dict[str, Dict[str, Any]] = {} + cached_configs: Dict[str, Dict[str, Any]] = {} - @root_validator(pre=True) + @model_validator(mode="before") def check_config_for_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]: - extra = [key for key in values.keys() if key not in cls.__fields__] + extra = [key for key in values.keys() if key not in cls.model_fields] if extra: logger.warning(f"Unprocessed extra config fields not set '{extra}'.") @@ -144,16 +142,16 @@ def packages_folder(self) -> Path: @property def _plugin_configs(self) -> Dict[str, PluginConfig]: project_name = self.PROJECT_FOLDER.stem - if project_name in self._cached_configs: - cache = self._cached_configs[project_name] + if project_name in self.cached_configs: + cache = self.cached_configs[project_name] self.name = cache.get("name", "") self.version = cache.get("version", "") self.default_ecosystem = cache.get("default_ecosystem", "ethereum") - self.meta = PackageMeta.parse_obj(cache.get("meta", {})) + self.meta = PackageMeta.model_validate(cache.get("meta", {})) self.dependencies = cache.get("dependencies", []) self.deployments = cache.get("deployments", {}) self.contracts_folder = cache.get("contracts_folder", self.PROJECT_FOLDER / "contracts") - self.compiler = CompilerConfig.parse_obj(cache.get("compiler", {})) + self.compiler = CompilerConfig.model_validate(cache.get("compiler", {})) return cache # First, load top-level configs. Then, load all the plugin configs. @@ -172,13 +170,13 @@ def _plugin_configs(self) -> Dict[str, PluginConfig]: self.name = configs["name"] = user_config.pop("name", "") self.version = configs["version"] = user_config.pop("version", "") meta_dict = user_config.pop("meta", {}) - meta_obj = PackageMeta.parse_obj(meta_dict) + meta_obj = PackageMeta.model_validate(meta_dict) configs["meta"] = meta_dict self.meta = meta_obj self.default_ecosystem = configs["default_ecosystem"] = user_config.pop( "default_ecosystem", "ethereum" ) - compiler_dict = user_config.pop("compiler", CompilerConfig().dict()) + compiler_dict = user_config.pop("compiler", CompilerConfig().model_dump(mode="json")) configs["compiler"] = compiler_dict self.compiler = CompilerConfig(**compiler_dict) @@ -203,7 +201,7 @@ def _plugin_configs(self) -> Dict[str, PluginConfig]: valid_ecosystems = dict(self.plugin_manager.ecosystems) valid_network_names = [n[1] for n in [e[1] for e in self.plugin_manager.networks]] self.deployments = configs["deployments"] = DeploymentConfigCollection( - __root__={ + root={ **deployments, "valid_ecosystems": valid_ecosystems, "valid_networks": valid_network_names, @@ -230,7 +228,7 @@ def _plugin_configs(self) -> Dict[str, PluginConfig]: "Plugins may not be installed yet or keys may be mis-spelled." ) - self._cached_configs[project_name] = configs + self.cached_configs[project_name] = configs return configs def __repr__(self): @@ -242,7 +240,7 @@ def load(self, force_reload: bool = False) -> "ConfigManager": """ if force_reload: - self._cached_configs = {} + self.cached_configs = {} _ = self._plugin_configs return self diff --git a/src/ape/managers/converters.py b/src/ape/managers/converters.py index 199c70ca9f..46adbc05a7 100644 --- a/src/ape/managers/converters.py +++ b/src/ape/managers/converters.py @@ -3,6 +3,8 @@ from typing import Any, Dict, List, Sequence, Tuple, Type, Union from dateutil.parser import parse # type: ignore +from eth_pydantic_types import HexBytes +from eth_typing.evm import ChecksumAddress from eth_utils import ( is_0x_prefixed, is_checksum_address, @@ -12,7 +14,7 @@ to_hex, to_int, ) -from ethpm_types import ConstructorABI, EventABI, HexBytes, MethodABI +from ethpm_types import ConstructorABI, EventABI, MethodABI from ape.api import ConverterAPI, TransactionAPI from ape.api.address import BaseAddress @@ -116,7 +118,7 @@ def convert(self, value: str) -> AddressType: ``AddressType`` """ - return to_checksum_address(value) + return AddressType(to_checksum_address(value)) class BytesAddressConverter(ConverterAPI): @@ -270,7 +272,7 @@ def _converters(self) -> Dict[Type, List[ConverterAPI]]: for plugin_name, (conversion_type, converter_class) in self.plugin_manager.converters: converter = converter_class() if conversion_type not in converters: - options = ", ".join([t.__name__ for t in converters]) + options = ", ".join([_get_type_name_from_type(t) for t in converters]) raise ConversionError(f"Type '{conversion_type}' must be one of [{options}].") converters[conversion_type].append(converter) @@ -329,7 +331,7 @@ def convert(self, value: Any, type: Union[Type, Tuple, List]) -> Any: ) elif type not in self._converters: - options = ", ".join([t.__name__ for t in self._converters]) + options = ", ".join([_get_type_name_from_type(t) for t in self._converters]) raise ConversionError(f"Type '{type}' must be one of [{options}].") elif self.is_type(value, type) and not isinstance(value, (list, tuple)): @@ -372,11 +374,49 @@ def convert_method_args( return self.convert(pre_processed_args, tuple) def convert_method_kwargs(self, kwargs) -> Dict: - fields = TransactionAPI.__fields__ + fields = TransactionAPI.model_fields + + def get_real_type(type_): + all_types = getattr(type_, "_typevar_types", []) + if not all_types or not isinstance(all_types, (list, tuple)): + return type_ + + # Filter out None + valid_types = [t for t in all_types if t is not None] + if len(valid_types) == 1: + # This is something like Optional[int], + # however, if the user provides a value, + # we want to convert to the non-optional type. + return valid_types[0] + + # Not sure if this is possible; the converter may fail. + return valid_types + annotations = {name: get_real_type(f.annotation) for name, f in fields.items()} kwargs_to_convert = {k: v for k, v in kwargs.items() if k == "sender" or k in fields} - converted_fields = { - k: self.convert(v, AddressType if k == "sender" else fields[k].type_) - for k, v in kwargs_to_convert.items() - } + converted_fields = {} + for field_name, value in kwargs_to_convert.items(): + type_ = AddressType if field_name == "sender" else annotations.get(field_name) + if type_: + try: + converted_value = self.convert(value, type_) + except ConversionError: + # Ignore conversion errors and use the values as-is. + converted_value = value + + else: + converted_value = value + + converted_fields[field_name] = converted_value + return {**kwargs, **converted_fields} + + +def _get_type_name_from_type(var_type: Type) -> str: + if var_type.__name__ == "Annotated" and (real_types := getattr(var_type, "__args__", None)): + # Is a NewType + result = real_types[0].__name__ + if result == ChecksumAddress.__name__: + return "AddressType" + + return var_type.__name__ diff --git a/src/ape/managers/project/dependency.py b/src/ape/managers/project/dependency.py index e6add04fa1..7fe26818cc 100644 --- a/src/ape/managers/project/dependency.py +++ b/src/ape/managers/project/dependency.py @@ -6,10 +6,9 @@ from typing import Dict, Iterable, List, Optional, Type from ethpm_types import PackageManifest -from ethpm_types.utils import AnyUrl +from pydantic import AnyUrl, FileUrl, HttpUrl, model_validator from semantic_version import NpmSpec, Version # type: ignore -from ape._pydantic_compat import FileUrl, HttpUrl, root_validator from ape.api import DependencyAPI from ape.exceptions import ProjectError, UnknownVersionError from ape.logging import logger @@ -236,7 +235,7 @@ def uri(self) -> AnyUrl: elif self._reference: _uri = f"{_uri}/tree/{self._reference}" - return HttpUrl(_uri, scheme="https") + return HttpUrl(_uri) def __repr__(self): return f"<{self.__class__.__name__} github={self.github}>" @@ -289,7 +288,7 @@ class LocalDependency(DependencyAPI): local: str version: str = "local" - @root_validator() + @model_validator(mode="before") def validate_contracts_folder(cls, value): if value.get("contracts_folder") not in (None, "contracts"): return value @@ -320,7 +319,7 @@ def version_id(self) -> str: @property def uri(self) -> AnyUrl: - return FileUrl(self.path.as_uri(), scheme="file") + return FileUrl(self.path.as_uri()) def extract_manifest(self, use_cache: bool = True) -> PackageManifest: return self._extract_local_manifest(self.path, use_cache=use_cache) @@ -405,7 +404,7 @@ def version_from_local_json(self) -> Optional[str]: @property def uri(self) -> AnyUrl: _uri = f"https://www.npmjs.com/package/{self.npm}/v/{self.version}" - return HttpUrl(_uri, scheme="https") + return HttpUrl(_uri) def extract_manifest(self, use_cache: bool = True) -> PackageManifest: if use_cache and self.cached_manifest: diff --git a/src/ape/managers/project/manager.py b/src/ape/managers/project/manager.py index 6a744c4e51..60b1ce0700 100644 --- a/src/ape/managers/project/manager.py +++ b/src/ape/managers/project/manager.py @@ -2,12 +2,12 @@ from pathlib import Path from typing import Any, Dict, Iterable, List, Optional, Type, Union, cast +from eth_pydantic_types import Bip122Uri, HexStr from ethpm_types import ContractInstance as EthPMContractInstance from ethpm_types import ContractType, PackageManifest, PackageMeta, Source -from ethpm_types.contract_type import BIP122_URI from ethpm_types.manifest import PackageName from ethpm_types.source import Compiler, ContractSource -from ethpm_types.utils import AnyUrl, Hex +from pydantic import AnyUrl from ape.api import DependencyAPI, ProjectAPI from ape.contracts import ContractContainer, ContractInstance, ContractNamespace @@ -217,7 +217,7 @@ def meta(self) -> PackageMeta: return self.config_manager.meta @property - def tracked_deployments(self) -> Dict[BIP122_URI, Dict[str, EthPMContractInstance]]: + def tracked_deployments(self) -> Dict[Bip122Uri, Dict[str, EthPMContractInstance]]: """ Deployments that have been explicitly tracked via :meth:`~ape.managers.project.manager.ProjectManager.track_deployment`. @@ -225,17 +225,18 @@ def tracked_deployments(self) -> Dict[BIP122_URI, Dict[str, EthPMContractInstanc of this package. """ - deployments: Dict[BIP122_URI, Dict[str, EthPMContractInstance]] = {} + deployments: Dict[Bip122Uri, Dict[str, EthPMContractInstance]] = {} if not self._package_deployments_folder.is_dir(): return deployments for ecosystem_path in [x for x in self._package_deployments_folder.iterdir() if x.is_dir()]: for deployment_path in [x for x in ecosystem_path.iterdir() if x.suffix == ".json"]: - ethpm_instance = EthPMContractInstance.parse_file(deployment_path) + text = deployment_path.read_text() + ethpm_instance = EthPMContractInstance.model_validate_json(text) if not ethpm_instance: continue - uri = BIP122_URI(f"blockchain://{ecosystem_path.name}/block/{ethpm_instance.block}") + uri = Bip122Uri(f"blockchain://{ecosystem_path.name}/block/{ethpm_instance.block}") deployments[uri] = {deployment_path.stem: ethpm_instance} return deployments @@ -457,7 +458,7 @@ def __getattr__(self, attr_name: str) -> Any: # We know if we get here that the path does not exist. path = self.local_project._cache_folder / f"{ct.name}.json" - path.write_text(ct.json()) + path.write_text(ct.model_dump_json()) if self.local_project._contracts is None: self.local_project._contracts = {ct.name: ct} else: @@ -721,10 +722,10 @@ def track_deployment(self, contract: ContractInstance): block_hash = block_hash_bytes.hex() artifact = EthPMContractInstance( - address=cast(Hex, contract.address), + address=contract.address, block=block_hash, contractType=contract_name, - transaction=cast(Hex, contract.txn_hash), + transaction=cast(HexStr, contract.txn_hash), runtimeBytecode=contract.contract_type.runtime_bytecode, ) @@ -741,7 +742,7 @@ def track_deployment(self, contract: ContractInstance): logger.debug("Deployment already tracked. Re-tracking.") destination.unlink() - destination.write_text(artifact.json()) + destination.write_text(artifact.model_dump_json()) def _create_contract_source(self, contract_type: ContractType) -> Optional[ContractSource]: if not (source_id := contract_type.source_id): @@ -772,7 +773,7 @@ def _get_contract(self, name: str) -> Optional[ContractContainer]: return None # def publish_manifest(self): - # manifest = self.manifest.dict() + # manifest = self.manifest.model_dump(mode="json") # if not manifest["name"]: # raise ProjectError("Need name to release manifest") # if not manifest["version"]: diff --git a/src/ape/managers/project/types.py b/src/ape/managers/project/types.py index 705b1a46ad..203b11345b 100644 --- a/src/ape/managers/project/types.py +++ b/src/ape/managers/project/types.py @@ -187,12 +187,12 @@ def create_manifest( version=self.version, ) # Cache the updated manifest so `self.cached_manifest` reads it next time - self.manifest_cachefile.write_text(manifest.json()) + self.manifest_cachefile.write_text(manifest.model_dump_json()) self._cached_manifest = manifest if compiled_contract_types: for name, contract_type in compiled_contract_types.items(): file = self.project_manager.local_project._cache_folder / f"{name}.json" - file.write_text(contract_type.json()) + file.write_text(contract_type.model_dump_json()) self._contracts = self._contracts or {} self._contracts[name] = contract_type diff --git a/src/ape/types/__init__.py b/src/ape/types/__init__.py index 11a061127e..6b4f1a4540 100644 --- a/src/ape/types/__init__.py +++ b/src/ape/types/__init__.py @@ -17,6 +17,7 @@ from eth_abi.abi import encode from eth_abi.packed import encode_packed +from eth_pydantic_types import HexBytes from eth_typing import Hash32, HexStr from eth_utils import encode_hex, keccak, to_hex from ethpm_types import ( @@ -25,16 +26,15 @@ Checksum, Compiler, ContractType, - HexBytes, PackageManifest, PackageMeta, Source, ) from ethpm_types.abi import EventABI from ethpm_types.source import Closure +from pydantic import BaseModel, field_validator, model_validator from web3.types import FilterParams -from ape._pydantic_compat import BaseModel, root_validator, validator from ape.types.address import AddressType, RawAddress from ape.types.coverage import ( ContractCoverage, @@ -84,7 +84,7 @@ class AutoGasLimit(BaseModel): A multiplier to estimated gas. """ - @validator("multiplier", pre=True) + @field_validator("multiplier", mode="before") def validate_multiplier(cls, value): if isinstance(value, str): return float(value) @@ -136,7 +136,7 @@ class LogFilter(BaseModel): stop_block: Optional[int] = None # Use block height selectors: Dict[str, EventABI] = {} - @root_validator() + @model_validator(mode="before") def compute_selectors(cls, values): values["selectors"] = { encode_hex(keccak(text=event.selector)): event for event in values.get("events", []) @@ -144,17 +144,11 @@ def compute_selectors(cls, values): return values - @validator("start_block", pre=True) + @field_validator("start_block", mode="before") def validate_start_block(cls, value): return value or 0 - @validator("addresses", pre=True, each_item=True) - def validate_addresses(cls, value): - from ape import convert - - return convert(value, AddressType) - - def dict(self, client=None): + def model_dump(self, *args, **kwargs): _Hash32 = Union[Hash32, HexBytes, HexStr] topics = cast(Sequence[Optional[Union[_Hash32, Sequence[_Hash32]]]], self.topic_filter) return FilterParams( @@ -236,7 +230,7 @@ class BaseContractLog(BaseInterfaceModel): event_arguments: Dict[str, Any] = {} """The arguments to the event, including both indexed and non-indexed data.""" - @validator("contract_address", pre=True) + @field_validator("contract_address", mode="before") def validate_address(cls, value): return cls.conversion_manager.convert(value, AddressType) @@ -275,7 +269,7 @@ class ContractLog(BaseContractLog): Is `None` when from the pending block. """ - @validator("block_number", "log_index", "transaction_index", pre=True) + @field_validator("block_number", "log_index", "transaction_index", mode="before") def validate_hex_ints(cls, value): if value is None: # Should only happen for optionals. @@ -286,7 +280,7 @@ def validate_hex_ints(cls, value): return value - @validator("contract_address", pre=True) + @field_validator("contract_address", mode="before") def validate_address(cls, value): return cls.conversion_manager.convert(value, AddressType) diff --git a/src/ape/types/address.py b/src/ape/types/address.py index 9c724225fa..c5bdb46b53 100644 --- a/src/ape/types/address.py +++ b/src/ape/types/address.py @@ -1,13 +1,38 @@ -from typing import Union +from importlib import import_module +from typing import Annotated, Any, Optional, Union -from eth_typing import ChecksumAddress as AddressType -from ethpm_types import HexBytes +from eth_pydantic_types import Address as _Address +from eth_pydantic_types import HashBytes20, HashStr20 +from eth_typing import ChecksumAddress +from pydantic_core.core_schema import ValidationInfo -RawAddress = Union[str, int, HexBytes] +RawAddress = Union[str, int, HashStr20, HashBytes20] """ A raw data-type representation of an address. """ + +class _AddressValidator(_Address): + """ + An address in Ape. This types works the same as + ``eth_pydantic_types.address.Address`` for most cases, + (validated size and checksumming), unless your ecosystem + has a different address type, either in bytes-length or + checksumming algorithm. + """ + + @classmethod + def __eth_pydantic_validate__(cls, value: Any, info: Optional[ValidationInfo] = None) -> str: + return ( + getattr(import_module("ape"), "convert")(value, AddressType) + if value + else "0x0000000000000000000000000000000000000000" + ) + + +AddressType = Annotated[ChecksumAddress, _AddressValidator] + + __all__ = [ "AddressType", "RawAddress", diff --git a/src/ape/types/coverage.py b/src/ape/types/coverage.py index 462b6ca1b1..1c85971277 100644 --- a/src/ape/types/coverage.py +++ b/src/ape/types/coverage.py @@ -9,8 +9,8 @@ import requests from ethpm_types import BaseModel from ethpm_types.source import ContractSource, SourceLocation +from pydantic import NonNegativeInt, field_validator -from ape._pydantic_compat import NonNegativeInt, validator from ape.logging import logger from ape.utils.misc import get_current_timestamp_ms from ape.version import version as ape_version @@ -211,8 +211,8 @@ def line_rate(self) -> float: # This function has hittable statements. return self.lines_covered / self.lines_valid if self.lines_valid > 0 else 0 - def dict(self, *args, **kwargs) -> dict: - attribs = super().dict(*args, **kwargs) + def model_dump(self, *args, **kwargs) -> dict: + attribs = super().model_dump(*args, **kwargs) # Add coverage stats. attribs["lines_covered"] = self.lines_covered @@ -335,8 +335,8 @@ def __getitem__(self, function_name: str) -> FunctionCoverage: raise IndexError(f"Function '{function_name}' not found.") - def dict(self, *args, **kwargs) -> dict: - attribs = super().dict(*args, **kwargs) + def model_dump(self, *args, **kwargs) -> dict: + attribs = super().model_dump(*args, **kwargs) # Add coverage stats. attribs["lines_covered"] = self.lines_covered @@ -436,8 +436,8 @@ def function_rate(self) -> float: """ return self.function_hits / self.total_functions if self.total_functions > 0 else 0 - def dict(self, *args, **kwargs) -> dict: - attribs = super().dict(*args, **kwargs) + def model_dump(self, *args, **kwargs) -> dict: + attribs = super().model_dump(*args, **kwargs) # Add coverage stats. attribs["lines_covered"] = self.lines_covered @@ -535,8 +535,8 @@ def function_rate(self) -> float: """ return self.function_hits / self.total_functions if self.total_functions > 0 else 0 - def dict(self, *args, **kwargs) -> dict: - attribs = super().dict(*args, **kwargs) + def model_dump(self, *args, **kwargs) -> dict: + attribs = super().model_dump(*args, **kwargs) # Add coverage stats. attribs["lines_covered"] = self.lines_covered @@ -575,7 +575,7 @@ class CoverageReport(BaseModel): Each project with individual coverage tracked. """ - @validator("timestamp", pre=True) + @field_validator("timestamp", mode="before") def validate_timestamp(cls, value): # Default to current UTC timestamp (ms). return value or get_current_timestamp_ms() @@ -998,8 +998,8 @@ def _set_common_td(self, tbody_tr: Any, src_or_fn: Any): stmt_cov_td = SubElement(tbody_tr, "td", {}, **{"class": "column4"}) stmt_cov_td.text = f"{round(src_or_fn.line_rate * 100, 2)}%" - def dict(self, *args, **kwargs) -> dict: - attribs = super().dict(*args, **kwargs) + def model_dump(self, *args, **kwargs) -> dict: + attribs = super().model_dump(*args, **kwargs) # Add coverage stats. attribs["lines_covered"] = self.lines_covered diff --git a/src/ape/types/trace.py b/src/ape/types/trace.py index ccc613f152..f932810652 100644 --- a/src/ape/types/trace.py +++ b/src/ape/types/trace.py @@ -2,14 +2,15 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Set, Union -from ethpm_types import ASTNode, BaseModel, ContractType, HexBytes +from eth_pydantic_types import HexBytes +from ethpm_types import ASTNode, BaseModel, ContractType from ethpm_types.ast import SourceLocation from ethpm_types.source import Closure, Content, Function, SourceStatement, Statement from evm_trace.gas import merge_reports +from pydantic import Field, RootModel from rich.table import Table from rich.tree import Tree -from ape._pydantic_compat import Field from ape.types.address import AddressType from ape.utils.basemodel import BaseInterfaceModel from ape.utils.misc import is_evm_precompile, is_zero_hex @@ -323,9 +324,9 @@ def line_numbers(self) -> List[int]: def content(self) -> Content: result: Dict[int, str] = {} for node in self.source_statements: - result = {**result, **node.content.__root__} + result = {**result, **node.content.root} - return Content(__root__=result) + return Content(root=result) @property def source_header(self) -> str: @@ -409,7 +410,7 @@ def extend( new_lines = {no: ln.rstrip() for no, ln in content.items() if no >= content_start} if new_lines: # Add the next statement in this sequence. - content = Content(__root__=new_lines) + content = Content(root=new_lines) statement = SourceStatement(asts=asts, content=content, pcs=pcs) self.statements.append(statement) @@ -471,23 +472,21 @@ def next_statement(self) -> Optional[SourceStatement]: content_dict = {} for ast in next_stmt_asts: sub_content = function.get_content(ast.line_numbers) - content_dict = {**sub_content.__root__} + content_dict = {**sub_content.root} if not content_dict: return None sorted_dict = {k: content_dict[k] for k in sorted(content_dict)} - content = Content(__root__=sorted_dict) + content = Content(root=sorted_dict) return SourceStatement(asts=next_stmt_asts, content=content) -class SourceTraceback(BaseModel): +class SourceTraceback(RootModel[List[ControlFlow]]): """ A full execution traceback including source code. """ - __root__: List[ControlFlow] - @classmethod def create( cls, @@ -497,41 +496,41 @@ def create( ): trace, second_trace = tee(trace) if not second_trace or not (accessor := next(second_trace, None)): - return cls.parse_obj([]) + return cls.model_validate([]) if not (source_id := contract_type.source_id): - return cls.parse_obj([]) + return cls.model_validate([]) ext = f".{source_id.split('.')[-1]}" if ext not in accessor.compiler_manager.registered_compilers: - return cls.parse_obj([]) + return cls.model_validate([]) compiler = accessor.compiler_manager.registered_compilers[ext] try: return compiler.trace_source(contract_type, trace, HexBytes(data)) except NotImplementedError: - return cls.parse_obj([]) + return cls.model_validate([]) def __str__(self) -> str: return self.format() def __repr__(self) -> str: - return f"" + return f"" def __len__(self) -> int: - return len(self.__root__) + return len(self.root) def __iter__(self) -> Iterator[ControlFlow]: # type: ignore[override] - yield from self.__root__ + yield from self.root def __getitem__(self, idx: int) -> ControlFlow: try: - return self.__root__[idx] + return self.root[idx] except IndexError as err: raise IndexError(f"Control flow index '{idx}' out of range.") from err def __setitem__(self, key, value): - return self.__root__.__setitem__(key, value) + return self.root.__setitem__(key, value) @property def revert_type(self) -> Optional[str]: @@ -546,7 +545,7 @@ def append(self, __object) -> None: """ Append the given control flow to this one. """ - self.__root__.append(__object) + self.root.append(__object) def extend(self, __iterable) -> None: """ @@ -555,14 +554,14 @@ def extend(self, __iterable) -> None: if not isinstance(__iterable, SourceTraceback): raise TypeError("Can only extend another traceback object.") - self.__root__.extend(__iterable.__root__) + self.root.extend(__iterable.root) @property def last(self) -> Optional[ControlFlow]: """ The last control flow in the traceback, if there is one. """ - return self.__root__[-1] if len(self.__root__) else None + return self.root[-1] if len(self.root) else None @property def execution(self) -> List[ControlFlow]: @@ -570,27 +569,27 @@ def execution(self) -> List[ControlFlow]: All the control flows in order. Each set of statements in a control flow is separated by a jump. """ - return list(self.__root__) + return list(self.root) @property def statements(self) -> List[Statement]: """ All statements from each control flow. """ - return list(chain(*[x.statements for x in self.__root__])) + return list(chain(*[x.statements for x in self.root])) @property def source_statements(self) -> List[SourceStatement]: """ All source statements from each control flow. """ - return list(chain(*[x.source_statements for x in self.__root__])) + return list(chain(*[x.source_statements for x in self.root])) def format(self) -> str: """ Get a formatted traceback string for displaying to users. """ - if not len(self.__root__): + if not len(self.root): # No calls. return "" @@ -598,7 +597,7 @@ def format(self) -> str: indent = " " last_depth = None segments: List[str] = [] - for control_flow in reversed(self.__root__): + for control_flow in reversed(self.root): if last_depth is None or control_flow.depth == last_depth - 1: if control_flow.depth == 0 and len(segments) >= 1: # Ignore 0-layer segments if source code was hit diff --git a/src/ape/utils/abi.py b/src/ape/utils/abi.py index 8e49b1a89e..611ed15363 100644 --- a/src/ape/utils/abi.py +++ b/src/ape/utils/abi.py @@ -4,8 +4,8 @@ from eth_abi import decode, grammar from eth_abi.exceptions import DecodingError, InsufficientDataBytes +from eth_pydantic_types import HexBytes from eth_utils import decode_hex -from ethpm_types import HexBytes from ethpm_types.abi import ABIType, ConstructorABI, EventABI, EventABIType, MethodABI from ape.logging import logger @@ -106,7 +106,7 @@ def _encode(self, _type: ABIType, value: Any): and isinstance(value, (list, tuple)) and len(_type.components or []) > 0 ): - non_array_type_data = _type.dict() + non_array_type_data = _type.model_dump(mode="json") non_array_type_data["type"] = "tuple" non_array_type = ABIType(**non_array_type_data) return [self._encode(non_array_type, v) for v in value] @@ -151,8 +151,12 @@ def _decode( elif has_array_of_tuples_return: item_type_str = str(_types[0].type).split("[")[0] - data = {**_types[0].dict(), "type": item_type_str, "internalType": item_type_str} - output_type = ABIType.parse_obj(data) + data = { + **_types[0].model_dump(mode="json"), + "type": item_type_str, + "internalType": item_type_str, + } + output_type = ABIType.model_validate(data) if isinstance(values, (list, tuple)) and not values[0]: # Only returned an empty list. @@ -170,11 +174,11 @@ def _decode( if item_type_str == "tuple": # Either an array of structs or nested structs. item_type_data = { - **output_type.dict(), + **output_type.model_dump(mode="json"), "type": item_type_str, "internalType": item_type_str, } - item_type = ABIType.parse_obj(item_type_data) + item_type = ABIType.model_validate(item_type_data) if is_struct(output_type): parsed_item = self._decode([item_type], [value]) @@ -206,7 +210,7 @@ def _create_struct(self, out_abi: ABIType, out_value: Any) -> Optional[Any]: # Likely an empty tuple or not a struct. return None - internal_type = out_abi.internalType + internal_type = out_abi.internal_type if out_abi.name == "" and internal_type and "struct " in internal_type: name = internal_type.replace("struct ", "").split(".")[-1] else: diff --git a/src/ape/utils/basemodel.py b/src/ape/utils/basemodel.py index 5551b6b007..9489ac73d9 100644 --- a/src/ape/utils/basemodel.py +++ b/src/ape/utils/basemodel.py @@ -1,11 +1,12 @@ from abc import ABC from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterator, List, Optional, Union, cast -from ethpm_types import BaseModel as _BaseModel +from ethpm_types import BaseModel as EthpmTypesBaseModel +from pydantic import BaseModel as RootBaseModel +from pydantic import ConfigDict from ape.exceptions import ApeAttributeError, ApeIndexError, ProviderNotConnectedError from ape.logging import logger -from ape.utils.misc import cached_property, singledispatchmethod if TYPE_CHECKING: from ape.api.providers import ProviderAPI @@ -112,7 +113,7 @@ def _get_alt(name: str) -> Optional[str]: return alt -class ExtraModelAttributes(_BaseModel): +class ExtraModelAttributes(EthpmTypesBaseModel): """ A class for defining extra model attributes. """ @@ -124,7 +125,7 @@ class ExtraModelAttributes(_BaseModel): we can show a more accurate exception message. """ - attributes: Union[Dict[str, Any], "BaseModel"] + attributes: Union[Dict[str, Any], RootBaseModel] """The attributes.""" include_getattr: bool = True @@ -140,7 +141,11 @@ class ExtraModelAttributes(_BaseModel): """ def __contains__(self, name: str) -> bool: - attr_dict = self.attributes if isinstance(self.attributes, dict) else self.attributes.dict() + attr_dict = ( + self.attributes + if isinstance(self.attributes, dict) + else self.attributes.model_dump(by_alias=False) + ) if name in attr_dict: return True @@ -179,11 +184,13 @@ def _get(self, name: str) -> Optional[Any]: ) -class BaseModel(_BaseModel): +class BaseModel(EthpmTypesBaseModel): """ An ape-pydantic BaseModel. """ + model_config = ConfigDict(arbitrary_types_allowed=True) + def __ape_extra_attributes__(self) -> Iterator[ExtraModelAttributes]: """ Override this method to supply extra attributes @@ -202,6 +209,9 @@ def __getattr__(self, name: str) -> Any: account :meth:`~ape.utils.basemodel.BaseModel.__ape_extra_attributes__`. """ + if name in self.__private_attributes__: + return self.__private_attributes__[name] + try: return super().__getattribute__(name) except AttributeError: @@ -263,7 +273,7 @@ def __getitem__(self, name: Any) -> Any: # The user did not supply any extra __getitem__ attributes. # Do what you would have normally done. - return super().__getitem__(name) + return super().__getitem__(name) # type: ignore class BaseInterfaceModel(BaseInterface, BaseModel): @@ -271,17 +281,6 @@ class BaseInterfaceModel(BaseInterface, BaseModel): An abstract base-class with manager access on a pydantic base model. """ - class Config: - # NOTE: Due to https://github.com/samuelcolvin/pydantic/issues/1241 we have - # to add this cached property workaround in order to avoid this error: - - # TypeError: cannot pickle '_thread.RLock' object - - keep_untouched = (cached_property, singledispatchmethod) - arbitrary_types_allowed = True - underscore_attrs_are_private = True - copy_on_model_validation = "none" - def __dir__(self) -> List[str]: """ **NOTE**: Should integrate options in IPython tab-completion. diff --git a/src/ape/utils/misc.py b/src/ape/utils/misc.py index 1c958c3bd3..e6c305bed3 100644 --- a/src/ape/utils/misc.py +++ b/src/ape/utils/misc.py @@ -9,8 +9,8 @@ import requests import yaml +from eth_pydantic_types import HexBytes from eth_utils import is_0x_prefixed -from ethpm_types import HexBytes from importlib_metadata import PackageNotFoundError, distributions, packages_distributions from importlib_metadata import version as version_metadata from tqdm.auto import tqdm # type: ignore diff --git a/src/ape/utils/testing.py b/src/ape/utils/testing.py index 3dbe69ab1d..e74deedfb6 100644 --- a/src/ape/utils/testing.py +++ b/src/ape/utils/testing.py @@ -4,7 +4,7 @@ from eth_account import Account from eth_account.hdaccount import HDPath from eth_account.hdaccount.mnemonic import Mnemonic -from ethpm_types import HexBytes +from eth_pydantic_types import HexBytes DEFAULT_NUMBER_OF_TEST_ACCOUNTS = 10 DEFAULT_TEST_MNEMONIC = "test test test test test test test test test test test junk" diff --git a/src/ape/utils/trace.py b/src/ape/utils/trace.py index 4ccbf15371..acfe043d9d 100644 --- a/src/ape/utils/trace.py +++ b/src/ape/utils/trace.py @@ -3,8 +3,8 @@ from statistics import mean, median from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from eth_pydantic_types import HexBytes from eth_utils import is_0x_prefixed -from ethpm_types import HexBytes from rich.box import SIMPLE from rich.table import Table from rich.tree import Tree diff --git a/src/ape_accounts/accounts.py b/src/ape_accounts/accounts.py index 588e8038a5..5f3035f9cf 100644 --- a/src/ape_accounts/accounts.py +++ b/src/ape_accounts/accounts.py @@ -6,8 +6,9 @@ import click from eth_account import Account as EthAccount from eth_keys import keys # type: ignore +from eth_pydantic_types import HexBytes from eth_utils import to_bytes -from ethpm_types import HexBytes +from pydantic import PrivateAttr from ape.api import AccountAPI, AccountContainerAPI, TransactionAPI from ape.exceptions import AccountsError @@ -47,8 +48,8 @@ def __len__(self) -> int: class KeyfileAccount(AccountAPI): keyfile_path: Path locked: bool = True - __autosign: bool = False - __cached_key: Optional[HexBytes] = None + __autosign: bool = PrivateAttr(default=False) + __cached_key: Optional[HexBytes] = PrivateAttr(default=None) def __repr__(self): # NOTE: Prevent errors from preventing repr from working. diff --git a/src/ape_cache/query.py b/src/ape_cache/query.py index 59d11d8b7c..38876ddc04 100644 --- a/src/ape_cache/query.py +++ b/src/ape_cache/query.py @@ -352,7 +352,7 @@ def _perform_contract_events_query(self, query: ContractEventQuery) -> Iterator[ # NOTE: Should be unreachable if estimated correctly raise QueryEngineError(f"Could not perform query:\n{query}") - yield from map(lambda row: ContractLog.parse_obj(dict(row.items())), result) + yield from map(lambda row: ContractLog.model_validate(dict(row.items())), result) @singledispatchmethod def _cache_update_clause(self, query: QueryType) -> Insert: diff --git a/src/ape_compile/__init__.py b/src/ape_compile/__init__.py index a829ae2bfc..364807706c 100644 --- a/src/ape_compile/__init__.py +++ b/src/ape_compile/__init__.py @@ -1,7 +1,8 @@ from typing import Any, List, Optional +from pydantic import Field, field_validator + from ape import plugins -from ape._pydantic_compat import Field, validator from ape.api import PluginConfig from ape.logging import logger @@ -26,7 +27,7 @@ class Config(PluginConfig): Source exclusion globs across all file types. """ - @validator("evm_version") + @field_validator("evm_version") def warn_deprecate(cls, value): if value: logger.warning( @@ -36,7 +37,7 @@ def warn_deprecate(cls, value): return None - @validator("exclude", pre=True) + @field_validator("exclude", mode="before") def validate_exclude(cls, value): return value or [] diff --git a/src/ape_ethereum/ecosystem.py b/src/ape_ethereum/ecosystem.py index eb87d9064f..4bfd8e5db0 100644 --- a/src/ape_ethereum/ecosystem.py +++ b/src/ape_ethereum/ecosystem.py @@ -4,6 +4,7 @@ from eth_abi import decode, encode from eth_abi.exceptions import InsufficientDataBytes, NonEmptyPaddingBytes +from eth_pydantic_types import HexBytes from eth_typing import Hash32, HexStr from eth_utils import ( encode_hex, @@ -14,10 +15,10 @@ keccak, to_checksum_address, ) -from ethpm_types import ContractType, HexBytes +from ethpm_types import ContractType from ethpm_types.abi import ABIType, ConstructorABI, EventABI, MethodABI +from pydantic import Field, field_validator -from ape._pydantic_compat import Field, validator from ape.api import BlockAPI, EcosystemAPI, PluginConfig, ReceiptAPI, TransactionAPI from ape.api.networks import LOCAL_NETWORK_NAME from ape.contracts.base import ContractCall @@ -105,13 +106,10 @@ class NetworkConfig(PluginConfig): base_fee_multiplier: float = 1.0 """A multiplier to apply to a transaction base fee.""" - class Config: - smart_union = True - - @validator("gas_limit", pre=True, allow_reuse=True) + @field_validator("gas_limit", mode="before") def validate_gas_limit(cls, value): if isinstance(value, dict) and "auto" in value: - return AutoGasLimit.parse_obj(value["auto"]) + return AutoGasLimit.model_validate(value["auto"]) elif value in ("auto", "max") or isinstance(value, AutoGasLimit): return value @@ -138,9 +136,7 @@ class ForkedNetworkConfig(NetworkConfig): """ -def _create_local_config( - default_provider: Optional[str] = None, use_fork: bool = False, **kwargs -) -> NetworkConfig: +def _create_local_config(default_provider: Optional[str] = None, use_fork: bool = False, **kwargs): return _create_config( base_fee_multiplier=1.0, default_provider=default_provider, @@ -155,7 +151,7 @@ def _create_local_config( def _create_config( required_confirmations: int = 2, base_fee_multiplier: float = DEFAULT_LIVE_NETWORK_BASE_FEE_MULTIPLIER, - cls: Type[NetworkConfig] = NetworkConfig, + cls: Type = NetworkConfig, **kwargs, ) -> NetworkConfig: return cls( @@ -193,7 +189,7 @@ class Block(BlockAPI): EMPTY_BYTES32, alias="parentHash" ) # NOTE: genesis block has no parent hash - @validator( + @field_validator( "base_fee", "difficulty", "gas_limit", @@ -202,7 +198,7 @@ class Block(BlockAPI): "size", "timestamp", "total_difficulty", - pre=True, + mode="before", ) def validate_ints(cls, value): return to_int(value) if value else 0 @@ -406,7 +402,7 @@ def decode_block(self, data: Dict) -> BlockAPI: if "transactions" in data: data["num_transactions"] = len(data["transactions"]) - return Block.parse_obj(data) + return Block.model_validate(data) def _python_type_for_abi_type(self, abi_type: ABIType) -> Union[Type, Tuple, List]: # NOTE: An array can be an array of tuples, so we start with an array check @@ -453,7 +449,7 @@ def encode_calldata(self, abi: Union[ConstructorABI, MethodABI], *args) -> HexBy def decode_calldata(self, abi: Union[ConstructorABI, MethodABI], calldata: bytes) -> Dict: raw_input_types = [i.canonical_type for i in abi.inputs] - input_types = [parse_type(i.dict()) for i in abi.inputs] + input_types = [parse_type(i.model_dump(mode="json")) for i in abi.inputs] try: raw_input_values = decode(raw_input_types, calldata) @@ -486,7 +482,7 @@ def decode_returndata(self, abi: MethodABI, raw_data: bytes) -> Tuple[Any, ...]: elif not isinstance(vm_return_values, (tuple, list)): vm_return_values = (vm_return_values,) - output_types = [parse_type(o.dict()) for o in abi.outputs] + output_types = [parse_type(o.model_dump(mode="json")) for o in abi.outputs] output_values = [ self.decode_primitive_value(v, t) for v, t in zip(vm_return_values, output_types) ] diff --git a/src/ape_ethereum/multicall/handlers.py b/src/ape_ethereum/multicall/handlers.py index 14d948ba1b..60c6cf8918 100644 --- a/src/ape_ethereum/multicall/handlers.py +++ b/src/ape_ethereum/multicall/handlers.py @@ -73,7 +73,7 @@ def contract(self) -> ContractInstance: # else use our backend (with less methods) contract = self.chain_manager.contracts.instance_at( MULTICALL3_ADDRESS, - contract_type=ContractType.parse_obj(MULTICALL3_CONTRACT_TYPE), + contract_type=ContractType.model_validate(MULTICALL3_CONTRACT_TYPE), ) if self.provider.chain_id not in self.supported_chains and contract.code != MULTICALL3_CODE: diff --git a/src/ape_ethereum/proxies.py b/src/ape_ethereum/proxies.py index 40fd31fa8e..387b7828db 100644 --- a/src/ape_ethereum/proxies.py +++ b/src/ape_ethereum/proxies.py @@ -1,9 +1,9 @@ from enum import IntEnum, auto from typing import cast +from eth_pydantic_types.hex import HexStr from ethpm_types import ContractType, MethodABI from ethpm_types.abi import ABIType -from ethpm_types.utils import Hex from lazyasd import LazyObject # type: ignore from ape.api.networks import ProxyInfoAPI @@ -92,7 +92,7 @@ class ProxyInfo(ProxyInfoAPI): def _make_minimal_proxy(address: str = MINIMAL_PROXY_TARGET_PLACEHOLDER) -> ContractContainer: address = address.replace("0x", "") - code = cast(Hex, MINIMAL_PROXY_BYTES.replace(MINIMAL_PROXY_TARGET_PLACEHOLDER, address)) + code = cast(HexStr, MINIMAL_PROXY_BYTES.replace(MINIMAL_PROXY_TARGET_PLACEHOLDER, address)) bytecode = {"bytecode": code} contract_type = ContractType(abi=[], deploymentBytecode=bytecode) return ContractContainer(contract_type=contract_type) diff --git a/src/ape_ethereum/transactions.py b/src/ape_ethereum/transactions.py index 3d4c3e59e2..9e9a30cca1 100644 --- a/src/ape_ethereum/transactions.py +++ b/src/ape_ethereum/transactions.py @@ -8,11 +8,12 @@ encode_transaction, serializable_unsigned_transaction_from_dict, ) +from eth_pydantic_types import HexBytes from eth_utils import decode_hex, encode_hex, keccak, to_hex, to_int -from ethpm_types import ContractType, HexBytes +from ethpm_types import ContractType from ethpm_types.abi import EventABI, MethodABI +from pydantic import BaseModel, Field, field_validator, model_validator -from ape._pydantic_compat import BaseModel, Field, root_validator, validator from ape.api import ReceiptAPI, TransactionAPI from ape.contracts import ContractEvent from ape.exceptions import OutOfGasError, SignatureError, TransactionError @@ -56,7 +57,7 @@ def serialize_transaction(self) -> bytes: if not self.signature: raise SignatureError("The transaction is not signed.") - txn_data = self.dict(exclude={"sender"}) + txn_data = self.model_dump(exclude={"sender", "type"}) unsigned_txn = serializable_unsigned_transaction_from_dict(txn_data) signature = (self.signature.v, to_int(self.signature.r), to_int(self.signature.s)) signed_txn = encode_transaction(unsigned_txn, signature) @@ -83,11 +84,11 @@ class StaticFeeTransaction(BaseTransaction): """ gas_price: Optional[int] = Field(None, alias="gasPrice") - max_priority_fee: Optional[int] = Field(None, exclude=True) + max_priority_fee: Optional[int] = Field(None, exclude=True) # type: ignore type: int = Field(TransactionType.STATIC.value, exclude=True) - max_fee: Optional[int] = Field(None, exclude=True) + max_fee: Optional[int] = Field(None, exclude=True) # type: ignore - @root_validator(pre=True, allow_reuse=True) + @model_validator(mode="before") def calculate_read_only_max_fee(cls, values) -> Dict: # NOTE: Work-around, Pydantic doesn't handle calculated fields well. values["max_fee"] = values.get("gas_limit", 0) * values.get("gas_price", 0) @@ -100,12 +101,12 @@ class DynamicFeeTransaction(BaseTransaction): and ``maxPriorityFeePerGas`` fields. """ - max_priority_fee: Optional[int] = Field(None, alias="maxPriorityFeePerGas") - max_fee: Optional[int] = Field(None, alias="maxFeePerGas") + max_priority_fee: Optional[int] = Field(None, alias="maxPriorityFeePerGas") # type: ignore + max_fee: Optional[int] = Field(None, alias="maxFeePerGas") # type: ignore type: int = Field(TransactionType.DYNAMIC.value) access_list: List[AccessList] = Field(default_factory=list, alias="accessList") - @validator("type", allow_reuse=True) + @field_validator("type") def check_type(cls, value): return value.value if isinstance(value, TransactionType) else value @@ -119,7 +120,7 @@ class AccessListTransaction(BaseTransaction): type: int = Field(TransactionType.ACCESS_LIST.value) access_list: List[AccessList] = Field(default_factory=list, alias="accessList") - @validator("type", allow_reuse=True) + @field_validator("type") def check_type(cls, value): return value.value if isinstance(value, TransactionType) else value @@ -173,7 +174,7 @@ def source_traceback(self) -> SourceTraceback: if contract_type := self.contract_type: return SourceTraceback.create(contract_type, self.trace, HexBytes(self.data)) - return SourceTraceback.parse_obj([]) + return SourceTraceback.model_validate([]) def raise_for_status(self): if self.gas_limit is not None and self.ran_out_of_gas: @@ -276,7 +277,7 @@ def get_default_log( return ContractLog( block_hash=self.block.hash, block_number=self.block_number, - event_arguments={"__root__": _log["data"]}, + event_arguments={"root": _log["data"]}, event_name=f"<{name}>", log_index=logs[-1].log_index + 1 if logs else 0, transaction_hash=self.txn_hash, diff --git a/src/ape_geth/__init__.py b/src/ape_geth/__init__.py index 05b66fe3bd..4975061306 100644 --- a/src/ape_geth/__init__.py +++ b/src/ape_geth/__init__.py @@ -13,7 +13,7 @@ def config_class(): @plugins.register(plugins.ProviderPlugin) def providers(): - networks_dict = GethNetworkConfig().dict() + networks_dict = GethNetworkConfig().model_dump(mode="json") networks_dict.pop(LOCAL_NETWORK_NAME) for network_name in networks_dict: yield "ethereum", network_name, GethProvider diff --git a/src/ape_geth/provider.py b/src/ape_geth/provider.py index 27dde32017..46a2c3d1a6 100644 --- a/src/ape_geth/provider.py +++ b/src/ape_geth/provider.py @@ -11,9 +11,9 @@ import ijson # type: ignore import requests +from eth_pydantic_types import HexBytes from eth_typing import HexStr from eth_utils import add_0x_prefix, to_hex, to_wei -from ethpm_types import HexBytes from evm_trace import CallType, ParityTraceList from evm_trace import TraceFrame as EvmTraceFrame from evm_trace import ( @@ -26,6 +26,7 @@ from geth.chain import initialize_chain # type: ignore from geth.process import BaseGethProcess # type: ignore from geth.wrapper import construct_test_chain_kwargs # type: ignore +from pydantic_settings import SettingsConfigDict from requests.exceptions import ConnectionError from web3 import HTTPProvider, Web3 from web3.exceptions import ExtraDataLengthError @@ -36,7 +37,6 @@ from web3.providers.auto import load_provider_from_environment from yarl import URL -from ape._pydantic_compat import Extra from ape.api import ( PluginConfig, ReceiptAPI, @@ -233,9 +233,7 @@ class GethConfig(PluginConfig): ipc_path: Optional[Path] = None data_dir: Optional[Path] = None - class Config: - # For allowing all other EVM-based ecosystem plugins - extra = Extra.allow + model_config = SettingsConfigDict(extra="allow") class GethNotInstalledError(ConnectionError): @@ -250,8 +248,6 @@ def __init__(self): class BaseGethProvider(Web3Provider, ABC): - _client_version: Optional[str] = None - # optimal values for geth block_page_size: int = 5000 concurrency: int = 16 @@ -259,7 +255,7 @@ class BaseGethProvider(Web3Provider, ABC): name: str = "geth" """Is ``None`` until known.""" - _can_use_parity_traces: Optional[bool] = None + can_use_parity_traces: Optional[bool] = None @property def uri(self) -> str: @@ -267,7 +263,7 @@ def uri(self) -> str: # Use adhoc, scripted value return self.provider_settings["uri"] - config = self.settings.dict().get(self.network.ecosystem.name, None) + config = self.config.model_dump(mode="json").get(self.network.ecosystem.name, None) if config is None: return DEFAULT_SETTINGS["uri"] @@ -307,8 +303,9 @@ def _ots_api_level(self) -> Optional[int]: return None def _set_web3(self): - self._client_version = None # Clear cached version when connecting to another URI. - self._web3 = _create_web3(self.uri, ipc_path=self.ipc_path) + # Clear cached version when connecting to another URI. + self.cached_client_version = None # type: ignore + self.cached_web3 = _create_web3(self.uri, ipc_path=self.ipc_path) def _complete_connect(self): client_version = self.client_version.lower() @@ -357,9 +354,9 @@ def _complete_connect(self): self.network.verify_chain_id(chain_id) def disconnect(self): - self._can_use_parity_traces = None - self._web3 = None - self._client_version = None + self.can_use_parity_traces = None + self.cached_web3 = None # type: ignore + self.cached_client_version = None # type: ignore def get_transaction_trace(self, txn_hash: str) -> Iterator[TraceFrame]: frames = self._stream_request( @@ -374,26 +371,26 @@ def _get_transaction_trace_using_call_tracer(self, txn_hash: str) -> Dict: ) def get_call_tree(self, txn_hash: str) -> CallTreeNode: - if self._can_use_parity_traces is True: + if self.can_use_parity_traces is True: return self._get_parity_call_tree(txn_hash) - elif self._can_use_parity_traces is False: + elif self.can_use_parity_traces is False: return self._get_geth_call_tree(txn_hash) elif "erigon" in self.client_version.lower(): tree = self._get_parity_call_tree(txn_hash) - self._can_use_parity_traces = True + self.can_use_parity_traces = True return tree try: # Try the Parity traces first, in case node client supports it. tree = self._get_parity_call_tree(txn_hash) except (ValueError, APINotImplementedError, ProviderError): - self._can_use_parity_traces = False + self.can_use_parity_traces = False return self._get_geth_call_tree(txn_hash) # Parity style works. - self._can_use_parity_traces = True + self.can_use_parity_traces = True return tree def _get_parity_call_tree(self, txn_hash: str) -> CallTreeNode: @@ -401,7 +398,7 @@ def _get_parity_call_tree(self, txn_hash: str) -> CallTreeNode: if not result: raise ProviderError(f"Failed to get trace for '{txn_hash}'.") - traces = ParityTraceList.parse_obj(result) + traces = ParityTraceList.model_validate(result) evm_call = get_calltree_from_parity_trace(traces) return self._create_call_tree_node(evm_call, txn_hash=txn_hash) @@ -464,9 +461,9 @@ def _stream_request(self, method: str, params: List, iter_path="result.item"): class GethDev(BaseGethProvider, TestProviderAPI, SubprocessProvider): - _process: Optional[GethDevProcess] = None + cached_process: Optional[GethDevProcess] = None name: str = "geth" - _can_use_parity_traces = False + can_use_parity_traces: Optional[bool] = False @property def process_name(self) -> str: @@ -495,7 +492,7 @@ def connect(self): self.start() def start(self, timeout: int = 20): - test_config = self.config_manager.get_config("test").dict() + test_config = self.config_manager.get_config("test").model_dump(mode="json") # Allow configuring a custom executable besides your $PATH geth. if self.settings.executable is not None: diff --git a/src/ape_plugins/_cli.py b/src/ape_plugins/_cli.py index c1607fa781..ba70caead7 100644 --- a/src/ape_plugins/_cli.py +++ b/src/ape_plugins/_cli.py @@ -36,7 +36,7 @@ def load_from_file(ctx, file_path: Path) -> List[PluginMetadata]: if file_path.is_file(): config = load_config(file_path) if plugins := config.get("plugins"): - return [PluginMetadata.parse_obj(d) for d in plugins] + return [PluginMetadata.model_validate(d) for d in plugins] ctx.obj.logger.warning(f"No plugins found at '{file_path}'.") return [] diff --git a/src/ape_plugins/utils.py b/src/ape_plugins/utils.py index 00dba29df0..0fdd167591 100644 --- a/src/ape_plugins/utils.py +++ b/src/ape_plugins/utils.py @@ -4,8 +4,9 @@ from functools import cached_property from typing import Iterator, List, Optional, Sequence, Set, Tuple +from pydantic import model_validator + from ape.__modules__ import __modules__ -from ape._pydantic_compat import root_validator from ape.logging import logger from ape.plugins import clean_plugin_name from ape.utils import BaseInterfaceModel, get_package_version, github_client @@ -86,7 +87,7 @@ class PluginMetadataList(BaseModel): @classmethod def from_package_names(cls, packages: Sequence[str]) -> "PluginMetadataList": - PluginMetadataList.update_forward_refs() + PluginMetadataList.model_rebuild() core = PluginGroup(plugin_type=PluginType.CORE) available = PluginGroup(plugin_type=PluginType.AVAILABLE) installed = PluginGroup(plugin_type=PluginType.INSTALLED) @@ -136,7 +137,7 @@ class PluginMetadata(BaseInterfaceModel): version: Optional[str] = None """The version requested, if there is one.""" - @root_validator(pre=True) + @model_validator(mode="before") def validate_name(cls, values): if "name" not in values: raise ValueError("'name' required.") diff --git a/src/ape_pm/compiler.py b/src/ape_pm/compiler.py index d4e1d9770b..1656ba383e 100644 --- a/src/ape_pm/compiler.py +++ b/src/ape_pm/compiler.py @@ -2,8 +2,9 @@ from pathlib import Path from typing import List, Optional, Set +from eth_pydantic_types import HexBytes from eth_utils import is_0x_prefixed -from ethpm_types import ContractType, HexBytes +from ethpm_types import ContractType from ape.api import CompilerAPI from ape.exceptions import CompilerError, ContractLogicError diff --git a/src/ape_test/__init__.py b/src/ape_test/__init__.py index 53bf1de479..6a4507aecc 100644 --- a/src/ape_test/__init__.py +++ b/src/ape_test/__init__.py @@ -1,7 +1,8 @@ from typing import Dict, List, NewType, Optional, Union +from pydantic import NonNegativeInt + from ape import plugins -from ape._pydantic_compat import NonNegativeInt from ape.api import PluginConfig from ape.api.networks import LOCAL_NETWORK_NAME from ape.utils import DEFAULT_HD_PATH, DEFAULT_NUMBER_OF_TEST_ACCOUNTS, DEFAULT_TEST_MNEMONIC diff --git a/src/ape_test/accounts.py b/src/ape_test/accounts.py index f0f18e8219..69bbff1bfe 100644 --- a/src/ape_test/accounts.py +++ b/src/ape_test/accounts.py @@ -10,31 +10,29 @@ class TestAccountContainer(TestAccountContainerAPI): - _num_generated: int - _accounts: List["TestAccount"] - _mnemonic: str - _num_of_accounts: int - _hd_path: str + num_generated: int = 0 + cached_accounts: List["TestAccount"] = [] + mnemonic: str = "" + num_of_accounts: int = 0 + hd_path: str = "" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.init() def init(self): - self._num_generated = 0 - self._accounts = [] - self._mnemonic = self.config["mnemonic"] - self._num_of_accounts = self.config["number_of_accounts"] - self._hd_path = self.config["hd_path"] + self.mnemonic = self.config["mnemonic"] + self.num_of_accounts = self.config["number_of_accounts"] + self.hd_path = self.config["hd_path"] for index, account in enumerate(self._dev_accounts): - self._accounts.append( + self.cached_accounts.append( TestAccount( index=index, address_str=account.address, private_key=account.private_key ) ) def __len__(self) -> int: - return self._num_of_accounts + return self.num_of_accounts @property def config(self): @@ -43,14 +41,14 @@ def config(self): @property def _dev_accounts(self) -> List[GeneratedDevAccount]: return generate_dev_accounts( - self._mnemonic, - number_of_accounts=self._num_of_accounts, - hd_path_format=self._hd_path, + self.mnemonic, + number_of_accounts=self.num_of_accounts, + hd_path_format=self.hd_path, ) @property def aliases(self) -> Iterator[str]: - for index in range(self._num_of_accounts): + for index in range(self.num_of_accounts): yield f"TEST::{index}" def _is_config_changed(self): @@ -58,9 +56,9 @@ def _is_config_changed(self): current_number = self.config["number_of_accounts"] current_hd_path = self.config["hd_path"] return ( - self._mnemonic != current_mnemonic - or self._num_of_accounts != current_number - or self._hd_path != current_hd_path + self.mnemonic != current_mnemonic + or self.num_of_accounts != current_number + or self.hd_path != current_hd_path ) @property @@ -69,14 +67,14 @@ def accounts(self) -> Iterator["TestAccount"]: config_changed = self._is_config_changed() if config_changed: self.init() - for account in self._accounts: + for account in self.cached_accounts: yield account def generate_account(self) -> "TestAccountAPI": - new_index = self._num_of_accounts + self._num_generated - self._num_generated += 1 + new_index = self.num_of_accounts + self.num_generated + self.num_generated += 1 generated_account = generate_dev_accounts( - self._mnemonic, 1, hd_path_format=self._hd_path, start_index=new_index + self.mnemonic, 1, hd_path_format=self.hd_path, start_index=new_index )[0] acc = TestAccount( index=new_index, @@ -111,7 +109,7 @@ def sign_message(self, msg: SignableMessage) -> Optional[MessageSignature]: def sign_transaction(self, txn: TransactionAPI, **kwargs) -> Optional[TransactionAPI]: # Signs anything that's given to it - signature = EthAccount.sign_transaction(txn.dict(), self.private_key) + signature = EthAccount.sign_transaction(txn.model_dump(mode="json"), self.private_key) txn.signature = TransactionSignature( v=signature.v, r=to_bytes(signature.r), diff --git a/src/ape_test/provider.py b/src/ape_test/provider.py index a270e52055..dd5eb620a7 100644 --- a/src/ape_test/provider.py +++ b/src/ape_test/provider.py @@ -1,14 +1,15 @@ import re from ast import literal_eval +from re import Pattern from typing import Dict, Optional, cast from eth.exceptions import HeaderNotFound +from eth_pydantic_types import HexBytes from eth_tester.backends import PyEVMBackend # type: ignore from eth_tester.exceptions import TransactionFailed # type: ignore from eth_utils import is_0x_prefixed from eth_utils.exceptions import ValidationError from eth_utils.toolz import merge -from ethpm_types import HexBytes from web3 import EthereumTesterProvider, Web3 from web3.exceptions import ContractPanicError from web3.providers.eth_tester.defaults import API_ENDPOINTS, static_return @@ -32,44 +33,49 @@ class EthTesterProviderConfig(PluginConfig): class LocalProvider(TestProviderAPI, Web3Provider): - _evm_backend: Optional[PyEVMBackend] = None - _CANNOT_AFFORD_GAS_PATTERN = re.compile( + cached_evm_backend: Optional[PyEVMBackend] = None + CANNOT_AFFORD_GAS_PATTERN: Pattern = re.compile( r"Sender b'[\\*|\w]*' cannot afford txn gas (\d+) with account balance (\d+)" ) - _INVALID_NONCE_PATTERN = re.compile(r"Invalid transaction nonce: Expected (\d+), but got (\d+)") + INVALID_NONCE_PATTERN: Pattern = re.compile( + r"Invalid transaction nonce: Expected (\d+), but got (\d+)" + ) @property def evm_backend(self) -> PyEVMBackend: - if self._evm_backend is None: + if self.cached_evm_backend is None: raise ProviderNotConnectedError() - return self._evm_backend + return self.cached_evm_backend def connect(self): chain_id = self.settings.chain_id - if self._web3 is not None: + if self.cached_web3 is not None: connected_chain_id = self._make_request("eth_chainId") if connected_chain_id == chain_id: # Is already connected and settings have not changed. return - self._evm_backend = PyEVMBackend.from_mnemonic( + self.cached_evm_backend = PyEVMBackend.from_mnemonic( mnemonic=self.config.mnemonic, num_accounts=self.config.number_of_accounts, ) endpoints = {**API_ENDPOINTS} endpoints["eth"] = merge(endpoints["eth"], {"chainId": static_return(chain_id)}) - tester = EthereumTesterProvider(ethereum_tester=self._evm_backend, api_endpoints=endpoints) - self._web3 = Web3(tester) + tester = EthereumTesterProvider( + ethereum_tester=self.cached_evm_backend, api_endpoints=endpoints + ) + self.cached_web3 = Web3(tester) def disconnect(self): - self.cached_chain_id = None - self._web3 = None - self._evm_backend = None + # NOTE: This type ignore seems like a bug in pydantic. + self.cached_chain_id = None # type: ignore + self.cached_web3 = None # type: ignore + self.cached_evm_backend = None # type: ignore self.provider_settings = {} def update_settings(self, new_settings: Dict): - self.cached_chain_id = None + self.cached_chain_id = None # type: ignore[assignment] self.provider_settings = {**self.provider_settings, **new_settings} self.disconnect() self.connect() @@ -84,22 +90,23 @@ def estimate_gas_cost(self, txn: TransactionAPI, **kwargs) -> int: block_id = kwargs.pop("block_identifier", kwargs.pop("block_id", None)) estimate_gas = self.web3.eth.estimate_gas - txn_dict = txn.dict() + txn_dict = txn.model_dump(mode="json") txn_dict.pop("gas", None) - txn_data = cast(TxParams, txn_dict) + try: return estimate_gas(txn_data, block_identifier=block_id) except (ValidationError, TransactionFailed) as err: ape_err = self.get_virtual_machine_error(err, txn=txn) - gas_match = self._INVALID_NONCE_PATTERN.match(str(ape_err)) + gas_match = self.INVALID_NONCE_PATTERN.match(str(ape_err)) if gas_match: # Sometimes, EthTester is confused about the sender nonce # during gas estimation. Retry using the "expected" gas # and then set it back. expected_nonce, actual_nonce = gas_match.groups() txn.nonce = int(expected_nonce) - value = estimate_gas(txn.dict(), block_identifier=block_id) # type: ignore + txn_params: TxParams = cast(TxParams, txn.model_dump(mode="json")) + value = estimate_gas(txn_params, block_identifier=block_id) txn.nonce = int(actual_nonce) return value @@ -113,13 +120,13 @@ def estimate_gas_cost(self, txn: TransactionAPI, **kwargs) -> int: @property def settings(self) -> EthTesterProviderConfig: - return EthTesterProviderConfig.parse_obj( - {**self.config.provider.dict(), **self.provider_settings} + return EthTesterProviderConfig.model_validate( + {**self.config.provider.model_dump(mode="json"), **self.provider_settings} ) @property def chain_id(self) -> int: - if self.cached_chain_id: + if self.cached_chain_id is not None: return self.cached_chain_id try: @@ -148,7 +155,7 @@ def base_fee(self) -> int: return self._get_last_base_fee() def send_call(self, txn: TransactionAPI, **kwargs) -> bytes: - data = txn.dict(exclude_none=True) + data = txn.model_dump(mode="json", exclude_none=True) block_id = kwargs.pop("block_identifier", kwargs.pop("block_id", None)) state = kwargs.pop("state_override", None) call_kwargs = {"block_identifier": block_id, "state_override": state} @@ -186,7 +193,7 @@ def send_transaction(self, txn: TransactionAPI) -> ReceiptAPI: self.chain_manager.history.append(receipt) if receipt.failed: - txn_dict = txn.dict() + txn_dict = txn.model_dump(mode="json") txn_dict["nonce"] += 1 txn_params = cast(TxParams, txn_dict) @@ -238,7 +245,7 @@ def mine(self, num_blocks: int = 1): def get_virtual_machine_error(self, exception: Exception, **kwargs) -> VirtualMachineError: if isinstance(exception, ValidationError): - match = self._CANNOT_AFFORD_GAS_PATTERN.match(str(exception)) + match = self.CANNOT_AFFORD_GAS_PATTERN.match(str(exception)) if match: txn_gas, bal = match.groups() sender = getattr(kwargs["txn"], "sender") diff --git a/tests/conftest.py b/tests/conftest.py index 66590fbce9..f074de63e9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -214,9 +214,9 @@ def eth_tester_provider(ethereum): @pytest.fixture def mock_provider(mock_web3, eth_tester_provider): web3 = eth_tester_provider.web3 - eth_tester_provider._web3 = mock_web3 + eth_tester_provider.cached_web3 = mock_web3 yield eth_tester_provider - eth_tester_provider._web3 = web3 + eth_tester_provider.cached_web3 = web3 @pytest.fixture diff --git a/tests/functional/conftest.py b/tests/functional/conftest.py index f1babaf700..68b4201efd 100644 --- a/tests/functional/conftest.py +++ b/tests/functional/conftest.py @@ -5,7 +5,8 @@ from typing import Optional, cast import pytest -from ethpm_types import ContractType, HexBytes, MethodABI +from eth_pydantic_types import HexBytes +from ethpm_types import ContractType, MethodABI import ape from ape.api import TransactionAPI @@ -24,7 +25,8 @@ @pytest.fixture(scope="session") def get_contract_type(): def fn(name: str) -> ContractType: - return ContractType.parse_file(CONTRACTS_FOLDER / f"{name}.json") + content = (CONTRACTS_FOLDER / f"{name}.json").read_text() + return ContractType.model_validate_json(content) return fn @@ -36,7 +38,7 @@ def fn(name: str) -> ContractType: APE_PROJECT_FOLDER = BASE_PROJECTS_DIRECTORY / "ApeProject" BASE_SOURCES_DIRECTORY = (Path(__file__).parent / "data/sources").absolute() -CALL_WITH_STRUCT_INPUT = MethodABI.parse_obj( +CALL_WITH_STRUCT_INPUT = MethodABI.model_validate( { "type": "function", "name": "getTradeableOrderWithSignature", @@ -84,7 +86,7 @@ def fn(name: str) -> ContractType: ], } ) -METHOD_WITH_STRUCT_INPUT = MethodABI.parse_obj( +METHOD_WITH_STRUCT_INPUT = MethodABI.model_validate( { "type": "function", "name": "getTradeableOrderWithSignature", @@ -354,7 +356,7 @@ def contract_getter(address): / "mainnet" / f"{address}.json" ) - contract = ContractType.parse_file(path) + contract = ContractType.model_validate_json(path.read_text()) chain.contracts._local_contract_types[address] = contract return contract diff --git a/tests/functional/conversion/test_encode_structs.py b/tests/functional/conversion/test_encode_structs.py index cd0373a9fa..c113cbd3e4 100644 --- a/tests/functional/conversion/test_encode_structs.py +++ b/tests/functional/conversion/test_encode_structs.py @@ -1,13 +1,13 @@ from typing import Dict, Tuple, cast import pytest -from ethpm_types import HexBytes +from eth_pydantic_types import HexBytes +from ethpm_types import BaseModel from ethpm_types.abi import MethodABI -from ape._pydantic_compat import BaseModel from ape.types import AddressType -ABI = MethodABI.parse_obj( +ABI = MethodABI.model_validate( { "type": "function", "name": "test", diff --git a/tests/functional/conversion/test_hex.py b/tests/functional/conversion/test_hex.py index f8423e2daa..a950d841d2 100644 --- a/tests/functional/conversion/test_hex.py +++ b/tests/functional/conversion/test_hex.py @@ -1,5 +1,5 @@ import pytest -from ethpm_types import HexBytes +from eth_pydantic_types import HexBytes from ape.exceptions import ConversionError from ape.managers.converters import HexConverter, HexIntConverter diff --git a/tests/functional/data/python/__init__.py b/tests/functional/data/python/__init__.py index dd42dee30c..f6dd0b4d53 100644 --- a/tests/functional/data/python/__init__.py +++ b/tests/functional/data/python/__init__.py @@ -1,4 +1,4 @@ -from ethpm_types import HexBytes +from eth_pydantic_types import HexBytes from ape_ethereum.transactions import TransactionStatusEnum diff --git a/tests/functional/geth/conftest.py b/tests/functional/geth/conftest.py index 6bf6c3c771..cc3bd678e1 100644 --- a/tests/functional/geth/conftest.py +++ b/tests/functional/geth/conftest.py @@ -68,10 +68,10 @@ def mock_geth(geth_provider, mock_web3): data_folder=Path("."), request_header={}, ) - original_web3 = provider._web3 - provider._web3 = mock_web3 + original_web3 = provider.cached_web3 + provider.cached_web3 = mock_web3 yield provider - provider._web3 = original_web3 + provider.cached_web3 = original_web3 @pytest.fixture diff --git a/tests/functional/test_block.py b/tests/functional/test_block.py index d9398bc7c3..e3b3e581ca 100644 --- a/tests/functional/test_block.py +++ b/tests/functional/test_block.py @@ -7,7 +7,7 @@ def block(chain): def test_block_dict(block): - actual = block.dict() + actual = block.model_dump(mode="json") expected = { "baseFeePerGas": 1000000000, "difficulty": 0, @@ -25,7 +25,7 @@ def test_block_dict(block): def test_block_json(block): - actual = block.json() + actual = block.model_dump_json() expected = ( '{"baseFeePerGas":1000000000,"difficulty":0,"gasLimit":30029122,"gasUsed":0,' f'"hash":"{block.hash.hex()}",' diff --git a/tests/functional/test_block_container.py b/tests/functional/test_block_container.py index 68ba9bc3ff..72ff4e3036 100644 --- a/tests/functional/test_block_container.py +++ b/tests/functional/test_block_container.py @@ -3,7 +3,7 @@ from typing import List import pytest -from ethpm_types import HexBytes +from eth_pydantic_types import HexBytes from ape.exceptions import ChainError @@ -61,7 +61,7 @@ def test_block_range_negative_start(chain_that_mined_5): with pytest.raises(ValueError) as err: _ = [b for b in chain_that_mined_5.blocks.range(-1, 3, step=2)] - assert "ensure this value is greater than or equal to 0" in str(err.value) + assert "Input should be greater than or equal to 0" in str(err.value) def test_block_range_out_of_order(chain_that_mined_5): diff --git a/tests/functional/test_config.py b/tests/functional/test_config.py index 0b614c6cf2..4a5aef7eda 100644 --- a/tests/functional/test_config.py +++ b/tests/functional/test_config.py @@ -16,9 +16,9 @@ def test_integer_deployment_addresses(networks): "valid_ecosystems": {"ethereum": networks.ethereum}, "valid_networks": [LOCAL_NETWORK_NAME], } - config = DeploymentConfigCollection(__root__=data) + config = DeploymentConfigCollection(root=data) assert ( - config.__root__["ethereum"]["local"][0]["address"] + config.root["ethereum"]["local"][0]["address"] == "0x0c25212c557d00024b7Ca3df3238683A35541354" ) @@ -35,7 +35,7 @@ def test_bad_value_in_deployments( ecosystem_dict = {e: all_ecosystems[e] for e in ecosystem_names if e in all_ecosystems} data = {**deployments, "valid_ecosystems": ecosystem_dict, "valid_networks": network_names} ape_caplog.assert_last_log_with_retries( - lambda: DeploymentConfigCollection(__root__=data), + lambda: DeploymentConfigCollection(root=data), f"Invalid {err_part}", ) diff --git a/tests/functional/test_contract.py b/tests/functional/test_contract.py index ed22e31397..5513be7fb6 100644 --- a/tests/functional/test_contract.py +++ b/tests/functional/test_contract.py @@ -15,7 +15,8 @@ def test_contract_from_abi(contract_instance): def test_contract_from_abi_list(contract_instance): contract = Contract( - contract_instance.address, abi=[abi.dict() for abi in contract_instance.contract_type.abi] + contract_instance.address, + abi=[abi.model_dump(mode="json") for abi in contract_instance.contract_type.abi], ) assert isinstance(contract, ContractInstance) @@ -26,7 +27,9 @@ def test_contract_from_abi_list(contract_instance): def test_contract_from_json_str(contract_instance): contract = Contract( contract_instance.address, - abi=json.dumps([abi.dict() for abi in contract_instance.contract_type.abi]), + abi=json.dumps( + [abi.model_dump(mode="json") for abi in contract_instance.contract_type.abi] + ), ) assert isinstance(contract, ContractInstance) diff --git a/tests/functional/test_contract_event.py b/tests/functional/test_contract_event.py index 1d4d951a6d..0b6e794d56 100644 --- a/tests/functional/test_contract_event.py +++ b/tests/functional/test_contract_event.py @@ -3,8 +3,9 @@ from typing import Optional import pytest +from eth_pydantic_types import HexBytes from eth_utils import to_hex -from ethpm_types import ContractType, HexBytes +from ethpm_types import ContractType from ape.api import ReceiptAPI from ape.exceptions import ChainError @@ -246,8 +247,8 @@ def test_contract_two_events_with_same_name( ): interface_path = contracts_folder / "Interface.json" impl_path = contracts_folder / "InterfaceImplementation.json" - interface_contract_type = ContractType.parse_raw(interface_path.read_text()) - impl_contract_type = ContractType.parse_raw(impl_path.read_text()) + interface_contract_type = ContractType.model_validate_json(interface_path.read_text()) + impl_contract_type = ContractType.model_validate_json(impl_path.read_text()) event_name = "FooEvent" # Ensure test is setup correctly in case scenario-data changed on accident diff --git a/tests/functional/test_contract_instance.py b/tests/functional/test_contract_instance.py index 48524e6787..6731f31ce9 100644 --- a/tests/functional/test_contract_instance.py +++ b/tests/functional/test_contract_instance.py @@ -2,11 +2,11 @@ from typing import List, Tuple import pytest +from eth_pydantic_types import HexBytes from eth_utils import is_checksum_address, to_hex -from ethpm_types import ContractType, HexBytes +from ethpm_types import BaseModel, ContractType from ape import Contract -from ape._pydantic_compat import BaseModel from ape.api import TransactionAPI from ape.contracts import ContractInstance from ape.exceptions import ( @@ -255,6 +255,7 @@ def test_nested_structs(contract_instance, owner, chain): == chain.blocks[-2].hash ) assert isinstance(actual_1.t.b, bytes) + assert getattr(actual_1.t.b, "__orig_class__") is HexBytes assert ( actual_2.t.b == actual_2.t["b"] @@ -263,6 +264,7 @@ def test_nested_structs(contract_instance, owner, chain): == chain.blocks[-2].hash ) assert isinstance(actual_2.t.b, bytes) + assert getattr(actual_2.t.b, "__orig_class__") is HexBytes def test_nested_structs_in_tuples(contract_instance, owner, chain): @@ -754,13 +756,13 @@ def test_value_to_non_payable_fallback_and_no_receive( and you try to send a value, it fails. """ # Hack to set fallback as non-payable. - contract_type_data = vyper_fallback_contract_type.dict() + contract_type_data = vyper_fallback_contract_type.model_dump(mode="json") for abi in contract_type_data["abi"]: if abi.get("type") == "fallback": abi["stateMutability"] = "non-payable" break - new_contract_type = ContractType.parse_obj(contract_type_data) + new_contract_type = ContractType.model_validate(contract_type_data) contract = owner.chain_manager.contracts.instance_at( vyper_fallback_contract.address, contract_type=new_contract_type ) diff --git a/tests/functional/test_contract_method_handler.py b/tests/functional/test_contract_method_handler.py index 06433cf181..60e8dbe62a 100644 --- a/tests/functional/test_contract_method_handler.py +++ b/tests/functional/test_contract_method_handler.py @@ -1,4 +1,4 @@ -from ethpm_types import HexBytes +from eth_pydantic_types import HexBytes from ape.contracts.base import ContractMethodHandler diff --git a/tests/functional/test_ecosystem.py b/tests/functional/test_ecosystem.py index 3c165eef8e..29a10e47ff 100644 --- a/tests/functional/test_ecosystem.py +++ b/tests/functional/test_ecosystem.py @@ -1,8 +1,8 @@ from typing import Any, Dict import pytest +from eth_pydantic_types import HexBytes from eth_typing import HexAddress, HexStr -from ethpm_types import HexBytes from ethpm_types.abi import ABIType, EventABI, MethodABI from ape.api.networks import LOCAL_NETWORK_NAME @@ -91,10 +91,10 @@ def test_block_handles_snake_case_parent_hash(eth_tester_provider, sender, recei # Replace 'parentHash' key with 'parent_hash' latest_block = eth_tester_provider.get_block("latest") - latest_block_dict = eth_tester_provider.get_block("latest").dict() + latest_block_dict = eth_tester_provider.get_block("latest").model_dump(mode="json") latest_block_dict["parent_hash"] = latest_block_dict.pop("parentHash") - redefined_block = Block.parse_obj(latest_block_dict) + redefined_block = Block.model_validate(latest_block_dict) assert redefined_block.parent_hash == latest_block.parent_hash @@ -355,7 +355,7 @@ def test_decode_return_data_non_empty_padding_bytes(ethereum): "000000000000000000000000000000000000000000000000000000000000012696e73756666" "696369656e742066756e64730000000000000000000000000000" ) - abi = MethodABI.parse_obj( + abi = MethodABI.model_validate( { "type": "function", "name": "transfer", diff --git a/tests/functional/test_multicall.py b/tests/functional/test_multicall.py index c678bbe9b8..2e1a020d7f 100644 --- a/tests/functional/test_multicall.py +++ b/tests/functional/test_multicall.py @@ -1,7 +1,8 @@ from typing import List import pytest -from ethpm_types import ContractType, HexBytes +from eth_pydantic_types import HexBytes +from ethpm_types import ContractType from ape.exceptions import APINotImplementedError from ape_ethereum.multicall import Call diff --git a/tests/functional/test_plugins.py b/tests/functional/test_plugins.py index c6e551e1ea..fc47f27da4 100644 --- a/tests/functional/test_plugins.py +++ b/tests/functional/test_plugins.py @@ -63,7 +63,7 @@ def test_names(self, name): assert metadata.package_name == "ape-foo-bar" assert metadata.module_name == "ape_foo_bar" - def test_model_when_version_included_with_name(self): + def test_model_validator_when_version_included_with_name(self): # This allows parsing requirements files easier metadata = PluginMetadata(name="ape-foo-bar==0.5.0") assert metadata.name == "foo-bar" diff --git a/tests/functional/test_project.py b/tests/functional/test_project.py index 1b81ca2a6d..ac8aa4b0b5 100644 --- a/tests/functional/test_project.py +++ b/tests/functional/test_project.py @@ -72,7 +72,7 @@ def contract_type_1(vyper_contract_type): def existing_source_path(vyper_contract_type, contract_type_0, contracts_folder): source_path = contracts_folder / "NewContract_0.json" source_path.touch() - source_path.write_text(contract_type_0.json()) + source_path.write_text(contract_type_0.model_dump_json()) yield source_path if source_path.is_file(): source_path.unlink() @@ -87,9 +87,9 @@ def manifest_with_non_existent_sources( manifest.contract_types["NewContract_1"] = contract_type_1 # Previous refs shouldn't interfere (bugfix related) manifest.sources["NewContract_0.json"] = Source( - content=contract_type_0.json(), references=["NewContract_1.json"] + content=contract_type_0.model_dump_json(), references=["NewContract_1.json"] ) - manifest.sources["NewContract_1.json"] = Source(content=contract_type_1.json()) + manifest.sources["NewContract_1.json"] = Source(content=contract_type_1.model_dump_json()) return manifest @@ -102,10 +102,10 @@ def project_without_deployments(project): def _make_new_contract(existing_contract: ContractType, name: str): - source_text = existing_contract.json() + source_text = existing_contract.model_dump_json() source_text = source_text.replace(f"{existing_contract.name}.vy", f"{name}.json") source_text = source_text.replace(existing_contract.name or "", name) - return ContractType.parse_raw(source_text) + return ContractType.model_validate_json(source_text) def test_extract_manifest(project_with_dependency_config): @@ -132,10 +132,12 @@ def test_cached_manifest_when_sources_missing( cache_location.touch() name = "NOTEXISTS" source_id = f"{name}.json" - contract_type = ContractType.parse_obj({"contractName": name, "abi": [], "sourceId": source_id}) + contract_type = ContractType.model_validate( + {"contractName": name, "abi": [], "sourceId": source_id} + ) path = ape_project._cache_folder / source_id - path.write_text(contract_type.json()) - cache_location.write_text(manifest_with_non_existent_sources.json()) + path.write_text(contract_type.model_dump_json()) + cache_location.write_text(manifest_with_non_existent_sources.model_dump_json()) manifest = ape_project.cached_manifest @@ -158,7 +160,7 @@ def test_create_manifest_when_file_changed_with_cached_references_that_no_longer ape_project._cache_folder.mkdir(exist_ok=True) cache_location.touch() - cache_location.write_text(manifest_with_non_existent_sources.json()) + cache_location.write_text(manifest_with_non_existent_sources.model_dump_json()) # Change content source_text = existing_source_path.read_text() @@ -241,7 +243,7 @@ def test_track_deployment( expected_uri = f"blockchain://{bip122_chain_id}/block/{expected_block_hash}" expected_name = contract.contract_type.name expected_code = contract.contract_type.runtime_bytecode - actual_from_file = EthPMContractInstance.parse_raw(deployment_path.read_text()) + actual_from_file = EthPMContractInstance.model_validate_json(deployment_path.read_text()) actual_from_class = project_without_deployments.tracked_deployments[expected_uri][name] assert actual_from_file.address == actual_from_class.address == address @@ -278,7 +280,7 @@ def test_track_deployment_from_previously_deployed_contract( expected_uri = f"blockchain://{bip122_chain_id}/block/{expected_block_hash}" expected_name = contract.contract_type.name expected_code = contract.contract_type.runtime_bytecode - actual_from_file = EthPMContractInstance.parse_raw(path.read_text()) + actual_from_file = EthPMContractInstance.model_validate_json(path.read_text()) actual_from_class = project_without_deployments.tracked_deployments[expected_uri][name] assert actual_from_file.address == actual_from_class.address == address assert actual_from_file.contract_type == actual_from_class.contract_type == expected_name @@ -321,7 +323,7 @@ def test_track_deployment_from_unknown_contract_given_txn_hash( contract = Contract(address, txn_hash=txn_hash) project.track_deployment(contract) path = base_deployments_path / f"{contract.contract_type.name}.json" - actual = EthPMContractInstance.parse_raw(path.read_text()) + actual = EthPMContractInstance.model_validate_json(path.read_text()) assert actual.address == address assert actual.contract_type == contract.contract_type.name assert actual.transaction == txn_hash diff --git a/tests/functional/test_provider.py b/tests/functional/test_provider.py index 9df5d27dad..b3d4e38661 100644 --- a/tests/functional/test_provider.py +++ b/tests/functional/test_provider.py @@ -71,11 +71,11 @@ def test_chain_id_is_cached(eth_tester_provider): _ = eth_tester_provider.chain_id # Unset `_web3` to show that it is not used in a second call to `chain_id`. - web3 = eth_tester_provider._web3 - eth_tester_provider._web3 = None + web3 = eth_tester_provider.cached_web3 + eth_tester_provider.cached_web3 = None chain_id = eth_tester_provider.chain_id assert chain_id == DEFAULT_TEST_CHAIN_ID - eth_tester_provider._web3 = web3 # Undo + eth_tester_provider.cached_web3 = web3 # Undo def test_chain_id_when_disconnected(eth_tester_provider): diff --git a/tests/functional/test_query.py b/tests/functional/test_query.py index 865e89c6c3..0c4a77c195 100644 --- a/tests/functional/test_query.py +++ b/tests/functional/test_query.py @@ -77,7 +77,7 @@ class Model(BaseInterfaceModel): def test_column_expansion(): columns = validate_and_expand_columns(["*"], Model) - assert columns == list(Model.__fields__) + assert columns == list(Model.model_fields) def test_column_validation(eth_tester_provider, ape_caplog): diff --git a/tests/functional/test_receipt.py b/tests/functional/test_receipt.py index 336d8d867c..328d126178 100644 --- a/tests/functional/test_receipt.py +++ b/tests/functional/test_receipt.py @@ -196,5 +196,5 @@ def test_track_coverage(deploy_receipt, mocker): def test_access_from_tx(deploy_receipt): - actual = deploy_receipt.receiver + actual = deploy_receipt.receipt assert actual == "" diff --git a/tests/functional/test_transaction.py b/tests/functional/test_transaction.py index b562f242d5..ded4149c2d 100644 --- a/tests/functional/test_transaction.py +++ b/tests/functional/test_transaction.py @@ -1,5 +1,5 @@ import pytest -from ethpm_types import HexBytes +from eth_pydantic_types import HexBytes from ape_ethereum.transactions import DynamicFeeTransaction, StaticFeeTransaction, TransactionType @@ -64,17 +64,17 @@ def test_txn_hash_and_receipt(owner, eth_tester_provider, ethereum): def test_whitespace_in_transaction_data(): data = b"Should not clip whitespace\t\n" txn_dict = {"data": data} - txn = StaticFeeTransaction.parse_obj(txn_dict) + txn = StaticFeeTransaction.model_validate(txn_dict) assert txn.data == data, "Whitespace should not be removed from data" def test_transaction_dict_excludes_none_values(): txn = StaticFeeTransaction() txn.value = 1000000 - actual = txn.dict() + actual = txn.model_dump(mode="json") assert "value" in actual txn.value = None # type: ignore - actual = txn.dict() + actual = txn.model_dump(mode="json") assert "value" not in actual diff --git a/tests/functional/test_types.py b/tests/functional/test_types.py index 2c0ec7d620..6ebab059eb 100644 --- a/tests/functional/test_types.py +++ b/tests/functional/test_types.py @@ -50,7 +50,7 @@ def log(): def test_contract_log_serialization(log, zero_address): - obj = ContractLog.parse_obj(log.dict()) + obj = ContractLog.model_validate(log.model_dump(mode="json")) assert obj.contract_address == zero_address assert obj.block_hash == BLOCK_HASH assert obj.block_number == BLOCK_NUMBER @@ -61,7 +61,7 @@ def test_contract_log_serialization(log, zero_address): def test_contract_log_serialization_with_hex_strings_and_non_checksum_addresses(log, zero_address): - data = log.dict() + data = log.model_dump(mode="json") data["log_index"] = to_hex(log.log_index) data["transaction_index"] = to_hex(log.transaction_index) data["block_number"] = to_hex(log.block_number) @@ -79,12 +79,12 @@ def test_contract_log_serialization_with_hex_strings_and_non_checksum_addresses( def test_contract_log_str(log): - obj = ContractLog.parse_obj(log.dict()) + obj = ContractLog.model_validate(log.model_dump(mode="json")) assert str(obj) == "MyEvent(foo=0 bar=1)" def test_contract_log_repr(log): - obj = ContractLog.parse_obj(log.dict()) + obj = ContractLog.model_validate(log.model_dump(mode="json")) assert repr(obj) == "" @@ -96,7 +96,7 @@ def test_contract_log_access(log): def test_topic_filter_encoding(): - event_abi = EventABI.parse_raw(RAW_EVENT_ABI) + event_abi = EventABI.model_validate_json(RAW_EVENT_ABI) log_filter = LogFilter.from_event( event=event_abi, search_topics={"newVersion": "0x8c44Cc5c0f5CD2f7f17B9Aca85d456df25a61Ae8"} ) diff --git a/tests/functional/utils/test_abi.py b/tests/functional/utils/test_abi.py index 4015fd45e4..260529447e 100644 --- a/tests/functional/utils/test_abi.py +++ b/tests/functional/utils/test_abi.py @@ -1,5 +1,5 @@ import pytest -from ethpm_types import HexBytes +from eth_pydantic_types import HexBytes from ethpm_types.abi import EventABI, EventABIType from ape.utils.abi import LogInputABICollection diff --git a/tests/functional/utils/test_misc.py b/tests/functional/utils/test_misc.py index 1e1cee2330..8bef1e1301 100644 --- a/tests/functional/utils/test_misc.py +++ b/tests/functional/utils/test_misc.py @@ -1,5 +1,5 @@ import pytest -from ethpm_types import HexBytes +from eth_pydantic_types import HexBytes from packaging.version import Version from web3.types import Wei