From be5d04e4744dba2bde58140059df9af199a8498a Mon Sep 17 00:00:00 2001 From: Juliya Smith Date: Mon, 2 Oct 2023 14:40:43 -0500 Subject: [PATCH] feat: pydantic v2 --- setup.py | 7 +-- src/ape/_cli.py | 2 +- src/ape/_pydantic_compat.py | 47 ----------------- src/ape/api/config.py | 13 ++--- src/ape/api/networks.py | 6 +-- src/ape/api/projects.py | 19 +++---- src/ape/api/providers.py | 33 ++++++------ src/ape/api/query.py | 10 ++-- src/ape/api/transactions.py | 28 +++++------ src/ape/contracts/base.py | 8 +-- src/ape/managers/chain.py | 18 +++---- src/ape/managers/config.py | 36 +++++++------ src/ape/managers/converters.py | 38 ++++++++++++-- src/ape/managers/project/dependency.py | 11 ++-- src/ape/managers/project/manager.py | 12 +++-- src/ape/managers/project/types.py | 4 +- src/ape/types/__init__.py | 22 +++----- src/ape/types/address.py | 19 ++++++- src/ape/types/coverage.py | 24 ++++----- src/ape/types/trace.py | 50 +++++++++---------- src/ape/utils/abi.py | 14 ++++-- src/ape/utils/basemodel.py | 18 ++----- src/ape_cache/query.py | 2 +- src/ape_compile/__init__.py | 7 +-- src/ape_ethereum/ecosystem.py | 19 +++---- src/ape_ethereum/multicall/handlers.py | 2 +- src/ape_ethereum/transactions.py | 18 +++---- src/ape_geth/__init__.py | 2 +- src/ape_geth/provider.py | 12 ++--- src/ape_plugins/_cli.py | 2 +- src/ape_plugins/utils.py | 7 +-- src/ape_test/__init__.py | 3 +- src/ape_test/accounts.py | 2 +- src/ape_test/provider.py | 9 ++-- tests/functional/conftest.py | 5 +- .../conversion/test_encode_structs.py | 5 +- tests/functional/test_block.py | 4 +- tests/functional/test_block_container.py | 2 +- tests/functional/test_config.py | 6 +-- tests/functional/test_contract.py | 7 ++- tests/functional/test_contract_instance.py | 9 ++-- tests/functional/test_ecosystem.py | 6 +-- tests/functional/test_plugins.py | 2 +- tests/functional/test_project.py | 26 +++++----- tests/functional/test_query.py | 2 +- tests/functional/test_transaction.py | 6 +-- tests/functional/test_types.py | 10 ++-- 47 files changed, 294 insertions(+), 320 deletions(-) delete mode 100644 src/ape/_pydantic_compat.py diff --git a/setup.py b/setup.py index 77876ff758..ed900a6040 100644 --- a/setup.py +++ b/setup.py @@ -102,7 +102,8 @@ "packaging>=23.0,<24", "pandas>=1.3.0,<2", "pluggy>=1.3,<2", - "pydantic>=1.10.8,<3", + "pydantic>=2.4.0,<3", + "pydantic-settings>=2.0.3,<3", "PyGithub>=1.59,<2", "pytest>=6.0,<8.0", "python-dateutil>=2.8.2,<3", @@ -124,8 +125,8 @@ "web3[tester]>=6.7.0,<7", # ** Dependencies maintained by ApeWorX ** "eip712>=0.2.1,<0.3", - "ethpm-types>=0.5.6,<0.6", - "evm-trace>=0.1.0a23", + "ethpm-types>=0.6.0,<0.7", + "evm-trace>=0.1.0a26", ], entry_points={ "console_scripts": ["ape=ape._cli:cli"], diff --git a/src/ape/_cli.py b/src/ape/_cli.py index f51778e1cd..82c9985bbb 100644 --- a/src/ape/_cli.py +++ b/src/ape/_cli.py @@ -23,7 +23,7 @@ def display_config(ctx, param, value): from ape import project click.echo("# Current configuration") - click.echo(yaml.dump(project.config_manager.dict())) + click.echo(yaml.dump(project.config_manager.model_dump(mode="json"))) ctx.exit() # NOTE: Must exit to bypass running ApeCLI diff --git a/src/ape/_pydantic_compat.py b/src/ape/_pydantic_compat.py deleted file mode 100644 index 16bfe4b3a8..0000000000 --- a/src/ape/_pydantic_compat.py +++ /dev/null @@ -1,47 +0,0 @@ -# support both pydantic v1 and v2 - -try: - from pydantic.v1 import ( # type: ignore - BaseModel, - BaseSettings, - Extra, - Field, - FileUrl, - HttpUrl, - NonNegativeInt, - PositiveInt, - ValidationError, - 
root_validator, - validator, - ) - from pydantic.v1.dataclasses import dataclass # type: ignore -except ImportError: - from pydantic import ( # type: ignore - BaseModel, - BaseSettings, - Extra, - Field, - FileUrl, - HttpUrl, - NonNegativeInt, - PositiveInt, - ValidationError, - root_validator, - validator, - ) - from pydantic.dataclasses import dataclass # type: ignore - -__all__ = ( - "BaseModel", - "BaseSettings", - "dataclass", - "Extra", - "Field", - "FileUrl", - "HttpUrl", - "NonNegativeInt", - "PositiveInt", - "ValidationError", - "root_validator", - "validator", -) diff --git a/src/ape/api/config.py b/src/ape/api/config.py index b08a6a8086..fe4264c3ef 100644 --- a/src/ape/api/config.py +++ b/src/ape/api/config.py @@ -1,7 +1,8 @@ from enum import Enum from typing import Any, Dict, Optional, TypeVar -from ape._pydantic_compat import BaseModel, BaseSettings +from pydantic import ConfigDict +from pydantic_settings import BaseSettings T = TypeVar("T") @@ -14,10 +15,6 @@ class ConfigEnum(str, Enum): """ -class ConfigDict(BaseModel): - __root__: dict = {} - - class PluginConfig(BaseSettings): """ A base plugin configuration class. Each plugin that includes @@ -26,7 +23,7 @@ class PluginConfig(BaseSettings): @classmethod def from_overrides(cls, overrides: Dict) -> "PluginConfig": - default_values = cls().dict() + default_values = cls().model_dump() def update(root: Dict, value_map: Dict): for key, val in value_map.items(): @@ -54,9 +51,7 @@ def get(self, key: str, default: Optional[T] = None) -> T: return self.__dict__.get(key, default) -class GenericConfig(PluginConfig): +class GenericConfig(ConfigDict): """ The default class used when no specialized class is used. """ - - __root__: dict = {} diff --git a/src/ape/api/networks.py b/src/ape/api/networks.py index 5d7e9ab40f..52699bf23d 100644 --- a/src/ape/api/networks.py +++ b/src/ape/api/networks.py @@ -9,10 +9,9 @@ serializable_unsigned_transaction_from_dict, ) from eth_utils import keccak, to_int -from ethpm_types import ContractType, HexBytes +from ethpm_types import BaseModel, ContractType, HexBytes from ethpm_types.abi import ABIType, ConstructorABI, EventABI, MethodABI -from ape._pydantic_compat import BaseModel from ape.exceptions import ( NetworkError, NetworkMismatchError, @@ -139,8 +138,7 @@ def serialize_transaction(self, transaction: "TransactionAPI") -> bytes: if not self.signature: raise SignatureError("The transaction is not signed.") - txn_data = self.dict(exclude={"sender"}) - + txn_data = self.model_dump(exclude={"sender"}) unsigned_txn = serializable_unsigned_transaction_from_dict(txn_data) signature = ( self.signature.v, diff --git a/src/ape/api/projects.py b/src/ape/api/projects.py index b386257682..edfd5817b2 100644 --- a/src/ape/api/projects.py +++ b/src/ape/api/projects.py @@ -7,10 +7,10 @@ from ethpm_types import Checksum, ContractType, PackageManifest, Source from ethpm_types.manifest import PackageName from ethpm_types.source import Content -from ethpm_types.utils import Algorithm, AnyUrl, compute_checksum +from ethpm_types.utils import Algorithm, compute_checksum from packaging.version import InvalidVersion, Version +from pydantic import AnyUrl, ValidationError -from ape._pydantic_compat import ValidationError from ape.logging import logger from ape.utils import ( BaseInterfaceModel, @@ -111,7 +111,7 @@ def cached_manifest(self) -> Optional[PackageManifest]: continue path = self._cache_folder / f"{contract_type.name}.json" - path.write_text(contract_type.json()) + path.write_text(contract_type.model_dump_json()) 
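# Illustrative sketch (not part of this patch; assumes pydantic>=2 and a hypothetical
# model): the v1 helpers `.dict()`, `.json()`, `.parse_obj()`, and `.parse_file()` used
# throughout the old code become `.model_dump()`, `.model_dump_json()`,
# `.model_validate()`, and `.model_validate_json()`. `parse_file` has no direct v2
# replacement, which is why the patch reads the file text first and then validates it.
from pathlib import Path
from pydantic import BaseModel

class CachedArtifact(BaseModel):  # hypothetical stand-in for a cached contract type
    name: str
    abi: list = []

def round_trip(path: Path, artifact: CachedArtifact) -> CachedArtifact:
    path.write_text(artifact.model_dump_json())               # v1: artifact.json()
    return CachedArtifact.model_validate_json(path.read_text())  # v1: parse_file(path)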
# Rely on individual cache files. self._contracts = manifest.contract_types @@ -144,9 +144,10 @@ def contracts(self) -> Dict[str, ContractType]: continue contract_name = p.stem - contract_type = ContractType.parse_file(p) - if contract_type.name is None: - contract_type.name = contract_name + contract_type = ContractType.model_validate_json(p.read_text()) + contract_type.name = ( + contract_name if contract_type.name is None else contract_type.name + ) contracts[contract_type.name] = contract_type self._contracts = contracts @@ -215,7 +216,7 @@ def _create_source_dict( hash=compute_checksum(source_path.read_bytes()), ), urls=[], - content=Content(__root__={i + 1: x for i, x in enumerate(text.splitlines())}), + content=Content(root={i + 1: x for i, x in enumerate(text.splitlines())}), imports=source_imports.get(key, []), references=source_references.get(key, []), ) @@ -431,7 +432,7 @@ def _get_sources(self, project: ProjectAPI) -> List[Path]: def _write_manifest_to_cache(self, manifest: PackageManifest): self._target_manifest_cache_file.unlink(missing_ok=True) self._target_manifest_cache_file.parent.mkdir(exist_ok=True, parents=True) - self._target_manifest_cache_file.write_text(manifest.json()) + self._target_manifest_cache_file.write_text(manifest.model_dump_json()) self._cached_manifest = manifest @@ -440,7 +441,7 @@ def _load_manifest_from_file(file_path: Path) -> Optional[PackageManifest]: return None try: - return PackageManifest.parse_file(file_path) + return PackageManifest.model_validate_json(file_path.read_text()) except ValidationError as err: logger.warning(f"Existing manifest file '{file_path}' corrupted. Re-building.") logger.debug(str(err)) diff --git a/src/ape/api/providers.py b/src/ape/api/providers.py index 5e52f79671..8b0d6a2def 100644 --- a/src/ape/api/providers.py +++ b/src/ape/api/providers.py @@ -21,12 +21,12 @@ from ethpm_types import HexBytes from evm_trace import CallTreeNode as EvmCallTreeNode from evm_trace import TraceFrame as EvmTraceFrame +from pydantic import Field, model_validator from web3 import Web3 from web3.exceptions import ContractLogicError as Web3ContractLogicError from web3.exceptions import MethodUnavailable, TimeExhausted, TransactionNotFound from web3.types import RPCEndpoint, TxParams -from ape._pydantic_compat import Field, root_validator, validator from ape.api.config import PluginConfig from ape.api.networks import LOCAL_NETWORK_NAME, NetworkAPI from ape.api.query import BlockTransactionQuery @@ -95,20 +95,12 @@ class BlockAPI(BaseInterfaceModel): def datetime(self) -> datetime.datetime: return datetime.datetime.fromtimestamp(self.timestamp, tz=datetime.timezone.utc) - @root_validator(pre=True) + @model_validator(mode="before") def convert_parent_hash(cls, data): parent_hash = data.get("parent_hash", data.get("parentHash")) or EMPTY_BYTES32 data["parentHash"] = parent_hash return data - @validator("hash", "parent_hash", pre=True) - def validate_hexbytes(cls, value): - # NOTE: pydantic treats these values as bytes and throws an error - if value and not isinstance(value, bytes): - return HexBytes(value) - - return value - @cached_property def transactions(self) -> List[TransactionAPI]: query = BlockTransactionQuery(columns=["*"], block_id=self.hash) @@ -867,7 +859,7 @@ def estimate_gas_cost(self, txn: TransactionAPI, **kwargs) -> int: return the block maximum gas limit. """ - txn_dict = txn.dict() + txn_dict = txn.model_dump(mode="json") # Force the use of hex values to support a wider range of nodes. 
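# Illustrative sketch (not part of this patch; assumes pydantic>=2 and a hypothetical
# model): plain `model_dump()` keeps native Python objects, while
# `model_dump(mode="json")` coerces values to JSON-safe primitives, which is why the
# transaction dictionaries sent over RPC in this patch are dumped in "json" mode.
import datetime
from pydantic import BaseModel

class DumpExample(BaseModel):  # hypothetical example model
    when: datetime.datetime

ex = DumpExample(when=datetime.datetime(2023, 10, 2, tzinfo=datetime.timezone.utc))
assert isinstance(ex.model_dump()["when"], datetime.datetime)
assert isinstance(ex.model_dump(mode="json")["when"], str)  # ISO 8601 string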
if isinstance(txn_dict.get("type"), int): @@ -1149,7 +1141,7 @@ def _eth_call(self, arguments: List) -> bytes: return HexBytes(result) def _prepare_call(self, txn: TransactionAPI, **kwargs) -> List: - txn_dict = txn.dict() + txn_dict = txn.model_dump(mode="json") fields_to_convert = ("data", "chainId", "value") for field in fields_to_convert: value = txn_dict.get(field) @@ -1212,7 +1204,9 @@ def get_receipt( except TimeExhausted as err: raise TransactionNotFoundError(txn_hash, error_messsage=str(err)) from err - network_config: Dict = self.network.config.dict().get(self.network.name, {}) + network_config: Dict = self.network.config.model_dump(mode="json").get( + self.network.name, {} + ) max_retries = network_config.get("max_get_transaction_retries", DEFAULT_MAX_RETRIES_TX) txn = {} for attempt in range(max_retries): @@ -1382,10 +1376,10 @@ def get_contract_logs(self, log_filter: LogFilter) -> Iterator[ContractLog]: def fetch_log_page(block_range): start, stop = block_range - page_filter = log_filter.copy(update=dict(start_block=start, stop_block=stop)) + page_filter = log_filter.model_copy(update=dict(start_block=start, stop_block=stop)) # eth-tester expects a different format, let web3 handle the conversions for it raw = "EthereumTester" not in self.client_version - logs = self._get_logs(page_filter.dict(), raw) + logs = self._get_logs(page_filter.model_dump(mode="json"), raw) return self.network.ecosystem.decode_logs(logs, *log_filter.events) with ThreadPoolExecutor(self.concurrency) as pool: @@ -1455,7 +1449,8 @@ def send_transaction(self, txn: TransactionAPI) -> ReceiptAPI: if txn.sender not in self.web3.eth.accounts: self.chain_manager.provider.unlock_account(txn.sender) - txn_hash = self.web3.eth.send_transaction(cast(TxParams, txn.dict())) + txn_data = cast(TxParams, txn.model_dump(mode="json")) + txn_hash = self.web3.eth.send_transaction(cast(TxParams, txn_data)) except (ValueError, Web3ContractLogicError) as err: vm_err = self.get_virtual_machine_error(err, txn=txn) @@ -1474,7 +1469,7 @@ def send_transaction(self, txn: TransactionAPI) -> ReceiptAPI: self.chain_manager.history.append(receipt) if receipt.failed: - txn_dict = receipt.transaction.dict() + txn_dict = receipt.transaction.model_dump(mode="json") txn_params = cast(TxParams, txn_dict) # Replay txn to get revert reason @@ -1508,7 +1503,7 @@ def _create_call_tree_node( inputs=evm_call.calldata if "CREATE" in call_type else evm_call.calldata[4:].hex(), method_id=evm_call.calldata[:4].hex(), outputs=evm_call.returndata.hex(), - raw=evm_call.dict(), + raw=evm_call.model_dump(mode="json"), txn_hash=txn_hash, ) @@ -1531,7 +1526,7 @@ def _create_trace_frame(self, evm_frame: EvmTraceFrame) -> TraceFrame: gas_cost=evm_frame.gas_cost, depth=evm_frame.depth, contract_address=address, - raw=evm_frame.dict(), + raw=evm_frame.model_dump(mode="json"), ) def _make_request(self, endpoint: str, parameters: List) -> Any: diff --git a/src/ape/api/query.py b/src/ape/api/query.py index 3067d505bd..5b9ae8ced4 100644 --- a/src/ape/api/query.py +++ b/src/ape/api/query.py @@ -1,9 +1,9 @@ from functools import lru_cache from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Type, Union -from ethpm_types.abi import EventABI, MethodABI +from ethpm_types.abi import BaseModel, EventABI, MethodABI +from pydantic import NonNegativeInt, PositiveInt, model_validator -from ape._pydantic_compat import BaseModel, NonNegativeInt, PositiveInt, root_validator from ape.api.transactions import ReceiptAPI, TransactionAPI from ape.logging import 
logger from ape.types import AddressType @@ -22,7 +22,7 @@ # TODO: Replace with `functools.cache` when Py3.8 dropped @lru_cache(maxsize=None) def _basic_columns(Model: Type[BaseInterfaceModel]) -> Set[str]: - columns = set(Model.__fields__) + columns = set(Model.model_fields) # TODO: Remove once `ReceiptAPI` fields cleaned up for better processing if Model == ReceiptAPI: @@ -104,7 +104,7 @@ class _BaseBlockQuery(_BaseQuery): stop_block: NonNegativeInt step: PositiveInt = 1 - @root_validator(pre=True) + @model_validator(mode="before") def check_start_block_before_stop_block(cls, values): if values["stop_block"] < values["start_block"]: raise ValueError( @@ -141,7 +141,7 @@ class AccountTransactionQuery(_BaseQuery): start_nonce: NonNegativeInt = 0 stop_nonce: NonNegativeInt - @root_validator(pre=True) + @model_validator(mode="before") def check_start_nonce_before_stop_nonce(cls, values: Dict) -> Dict: if values["stop_nonce"] < values["start_nonce"]: raise ValueError( diff --git a/src/ape/api/transactions.py b/src/ape/api/transactions.py index 71b82fe805..8df453ec78 100644 --- a/src/ape/api/transactions.py +++ b/src/ape/api/transactions.py @@ -6,9 +6,10 @@ from eth_utils import is_0x_prefixed, is_hex, to_int from ethpm_types import HexBytes from ethpm_types.abi import EventABI, MethodABI +from pydantic import ConfigDict, field_validator +from pydantic.fields import Field from tqdm import tqdm # type: ignore -from ape._pydantic_compat import Field, validator from ape.api.explorers import ExplorerAPI from ape.exceptions import ( NetworkError, @@ -53,7 +54,7 @@ class TransactionAPI(BaseInterfaceModel): gas_limit: Optional[int] = Field(None, alias="gas") nonce: Optional[int] = None # NOTE: `Optional` only to denote using default behavior value: int = 0 - data: bytes = b"" + data: HexBytes = HexBytes("") type: int max_fee: Optional[int] = None max_priority_fee: Optional[int] = None @@ -63,10 +64,9 @@ class TransactionAPI(BaseInterfaceModel): signature: Optional[TransactionSignature] = Field(None, exclude=True) - class Config: - allow_population_by_field_name = True + model_config = ConfigDict(populate_by_name=True) - @validator("gas_limit", pre=True, allow_reuse=True) + @field_validator("gas_limit", mode="before") def validate_gas_limit(cls, value): if value is None: if not cls.network_manager.active_provider: @@ -91,21 +91,21 @@ def validate_gas_limit(cls, value): return value - @validator("max_fee", "max_priority_fee", pre=True, allow_reuse=True) + @field_validator("max_fee", "max_priority_fee", mode="before") def convert_fees(cls, value): if isinstance(value, str): return cls.conversion_manager.convert(value, int) return value - @validator("data", pre=True) + @field_validator("data", mode="before") def validate_data(cls, value): if isinstance(value, str): return HexBytes(value) return value - @validator("value", pre=True) + @field_validator("value", mode="before") def validate_value(cls, value): if isinstance(value, int): return value @@ -169,12 +169,12 @@ def serialize_transaction(self) -> bytes: """ def __repr__(self) -> str: - data = self.dict() + data = self.model_dump(mode="json") params = ", ".join(f"{k}={v}" for k, v in data.items()) return f"<{self.__class__.__name__} {params}>" def __str__(self) -> str: - data = self.dict() + data = self.model_dump(mode="json") if len(data["data"]) > 9: # only want to specify encoding if data["data"] is a string if isinstance(data["data"], str): @@ -269,14 +269,14 @@ def __repr__(self) -> str: def __ape_extra_attributes__(self) -> 
Iterator[ExtraModelAttributes]: yield ExtraModelAttributes(name="transaction", attributes=self.transaction) - @validator("transaction", pre=True) + @field_validator("transaction", mode="before") def confirm_transaction(cls, value): if isinstance(value, dict): - value = TransactionAPI.parse_obj(value) + value = TransactionAPI.model_validate(value) return value - @validator("txn_hash", pre=True) + @field_validator("txn_hash", mode="before") def validate_txn_hash(cls, value): return HexBytes(value).hex() @@ -472,7 +472,7 @@ def return_value(self) -> Any: @raises_not_implemented def source_traceback(self) -> SourceTraceback: # type: ignore[empty-body] """ - A pythonic style traceback for both failing and non-failing receipts. + A Pythonic style traceback for both failing and non-failing receipts. Requires a provider that implements :meth:~ape.api.providers.ProviderAPI.get_transaction_trace`. """ diff --git a/src/ape/contracts/base.py b/src/ape/contracts/base.py index a463018ea4..04508e0160 100644 --- a/src/ape/contracts/base.py +++ b/src/ape/contracts/base.py @@ -554,12 +554,8 @@ def query( f"'stop={stop_block}' cannot be greater than " f"the chain length ({self.chain_manager.blocks.height})." ) - - if columns[0] == "*": - columns = list(ContractLog.__fields__) # type: ignore - query: Dict = { - "columns": columns, + "columns": list(ContractLog.model_fields) if columns[0] == "*" else columns, "event": self.abi, "start_block": start_block, "stop_block": stop_block, @@ -631,7 +627,7 @@ def range( addresses = list(set([contract_address] + (extra_addresses or []))) contract_event_query = ContractEventQuery( - columns=list(ContractLog.__fields__.keys()), + columns=list(ContractLog.model_fields.keys()), contract=addresses, event=self.abi, search_topics=search_topics, diff --git a/src/ape/managers/chain.py b/src/ape/managers/chain.py index f46dea2726..5107931dcc 100644 --- a/src/ape/managers/chain.py +++ b/src/ape/managers/chain.py @@ -223,7 +223,7 @@ def range( # Note: the range `stop_block` is a non-inclusive stop, while the # `.query` method uses an inclusive stop, so we must adjust downwards. 
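# Illustrative sketch (not part of this patch; assumes pydantic>=2 and a hypothetical
# model): the v1 class attribute `Model.__fields__` is renamed to `Model.model_fields`,
# a dict of field name -> FieldInfo, so column lists like the ones in these queries
# can be built with `list(SomeModel.model_fields)`.
from typing import Optional
from pydantic import BaseModel

class ExampleBlock(BaseModel):  # hypothetical stand-in for BlockAPI
    number: int
    hash: Optional[bytes] = None

assert list(ExampleBlock.model_fields) == ["number", "hash"]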
query = BlockQuery( - columns=list(self.head.__fields__), # TODO: fetch the block fields from EcosystemAPI + columns=list(self.head.model_fields), # TODO: fetch the block fields from EcosystemAPI start_block=start, stop_block=stop - 1, step=step, @@ -524,7 +524,7 @@ def __getitem_int(self, index: int) -> ReceiptAPI: next( self.query_manager.query( AccountTransactionQuery( - columns=list(ReceiptAPI.__fields__), + columns=list(ReceiptAPI.model_fields), account=self.address, start_nonce=index, stop_nonce=index, @@ -562,7 +562,7 @@ def __getitem_slice(self, indices: slice) -> List[ReceiptAPI]: list( self.query_manager.query( AccountTransactionQuery( - columns=list(ReceiptAPI.__fields__), + columns=list(ReceiptAPI.model_fields), account=self.address, start_nonce=start, stop_nonce=stop - 1, @@ -1343,21 +1343,21 @@ def _get_contract_type_from_disk(self, address: AddressType) -> Optional[Contrac if not address_file.is_file(): return None - return ContractType.parse_file(address_file) + return ContractType.model_validate_json(address_file.read_text()) def _get_proxy_info_from_disk(self, address: AddressType) -> Optional[ProxyInfoAPI]: address_file = self._proxy_info_cache / f"{address}.json" if not address_file.is_file(): return None - return ProxyInfoAPI.parse_file(address_file) + return ProxyInfoAPI.model_validate_json(address_file.read_text()) def _get_blueprint_from_disk(self, blueprint_id: str) -> Optional[ContractType]: contract_file = self._blueprint_cache / f"{blueprint_id}.json" if not contract_file.is_file(): return None - return ContractType.parse_file(contract_file) + return ContractType.model_validate_json(contract_file.read_text()) def _get_contract_type_from_explorer(self, address: AddressType) -> Optional[ContractType]: if not self._network.explorer: @@ -1378,17 +1378,17 @@ def _get_contract_type_from_explorer(self, address: AddressType) -> Optional[Con def _cache_contract_to_disk(self, address: AddressType, contract_type: ContractType): self._contract_types_cache.mkdir(exist_ok=True, parents=True) address_file = self._contract_types_cache / f"{address}.json" - address_file.write_text(contract_type.json()) + address_file.write_text(contract_type.model_dump_json()) def _cache_proxy_info_to_disk(self, address: AddressType, proxy_info: ProxyInfoAPI): self._proxy_info_cache.mkdir(exist_ok=True, parents=True) address_file = self._proxy_info_cache / f"{address}.json" - address_file.write_text(proxy_info.json()) + address_file.write_text(proxy_info.model_dump_json()) def _cache_blueprint_to_disk(self, blueprint_id: str, contract_type: ContractType): self._blueprint_cache.mkdir(exist_ok=True, parents=True) blueprint_file = self._blueprint_cache / f"{blueprint_id}.json" - blueprint_file.write_text(contract_type.json()) + blueprint_file.write_text(contract_type.model_dump_json()) def _load_deployments_cache(self) -> Dict: return ( diff --git a/src/ape/managers/config.py b/src/ape/managers/config.py index ba10efc5de..0d307c2f33 100644 --- a/src/ape/managers/config.py +++ b/src/ape/managers/config.py @@ -3,7 +3,9 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Union -from ape._pydantic_compat import root_validator +from ethpm_types import PackageMeta +from pydantic import RootModel, model_validator + from ape.api import ConfigDict, DependencyAPI, PluginConfig from ape.exceptions import ConfigError from ape.logging import logger @@ -12,7 +14,6 @@ if TYPE_CHECKING: from .project import ProjectManager -from ethpm_types import BaseModel, 
PackageMeta CONFIG_FILE_NAME = "ape-config.yaml" @@ -27,16 +28,13 @@ class CompilerConfig(PluginConfig): """List of globular files to ignore""" -class DeploymentConfigCollection(BaseModel): - __root__: Dict - - @root_validator(pre=True) - def validate_deployments(cls, data: Dict): - root_data = data.get("__root__", data) - valid_ecosystems = root_data.pop("valid_ecosystems", {}) - valid_networks = root_data.pop("valid_networks", {}) +class DeploymentConfigCollection(RootModel[dict]): + @model_validator(mode="before") + def validate_deployments(cls, data: Dict, info): + valid_ecosystems = data.pop("valid_ecosystems", {}) + valid_networks = data.pop("valid_networks", {}) valid_data: Dict = {} - for ecosystem_name, networks in root_data.items(): + for ecosystem_name, networks in data.items(): if ecosystem_name not in valid_ecosystems: logger.warning(f"Invalid ecosystem '{ecosystem_name}' in deployments config.") continue @@ -69,7 +67,7 @@ def validate_deployments(cls, data: Dict): network_name: valid_deployments, } - return {"__root__": valid_data} + return valid_data class ConfigManager(BaseInterfaceModel): @@ -128,9 +126,9 @@ class ConfigManager(BaseInterfaceModel): _cached_configs: Dict[str, Dict[str, Any]] = {} - @root_validator(pre=True) + @model_validator(mode="before") def check_config_for_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]: - extra = [key for key in values.keys() if key not in cls.__fields__] + extra = [key for key in values.keys() if key not in cls.model_fields] if extra: logger.warning(f"Unprocessed extra config fields not set '{extra}'.") @@ -149,11 +147,11 @@ def _plugin_configs(self) -> Dict[str, PluginConfig]: self.name = cache.get("name", "") self.version = cache.get("version", "") self.default_ecosystem = cache.get("default_ecosystem", "ethereum") - self.meta = PackageMeta.parse_obj(cache.get("meta", {})) + self.meta = PackageMeta.model_validate(cache.get("meta", {})) self.dependencies = cache.get("dependencies", []) self.deployments = cache.get("deployments", {}) self.contracts_folder = cache.get("contracts_folder", self.PROJECT_FOLDER / "contracts") - self.compiler = CompilerConfig.parse_obj(cache.get("compiler", {})) + self.compiler = CompilerConfig.model_validate(cache.get("compiler", {})) return cache # First, load top-level configs. Then, load all the plugin configs. 
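# Illustrative sketch (not part of this patch; assumes pydantic>=2, minimal hypothetical
# root model): v1 custom-root models declared a `__root__` field; in v2 they subclass
# `RootModel`, are constructed with `root=...` (as `DeploymentConfigCollection(root={...})`
# is in this patch), and expose the wrapped value as `.root`.
from pydantic import RootModel

class Deployments(RootModel[dict]):  # hypothetical minimal root model
    pass

d = Deployments(root={"ethereum": {"local": []}})
assert d.root["ethereum"] == {"local": []}
assert d.model_dump() == {"ethereum": {"local": []}}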
@@ -172,13 +170,13 @@ def _plugin_configs(self) -> Dict[str, PluginConfig]: self.name = configs["name"] = user_config.pop("name", "") self.version = configs["version"] = user_config.pop("version", "") meta_dict = user_config.pop("meta", {}) - meta_obj = PackageMeta.parse_obj(meta_dict) + meta_obj = PackageMeta.model_validate(meta_dict) configs["meta"] = meta_dict self.meta = meta_obj self.default_ecosystem = configs["default_ecosystem"] = user_config.pop( "default_ecosystem", "ethereum" ) - compiler_dict = user_config.pop("compiler", CompilerConfig().dict()) + compiler_dict = user_config.pop("compiler", CompilerConfig().model_dump(mode="json")) configs["compiler"] = compiler_dict self.compiler = CompilerConfig(**compiler_dict) @@ -203,7 +201,7 @@ def _plugin_configs(self) -> Dict[str, PluginConfig]: valid_ecosystems = dict(self.plugin_manager.ecosystems) valid_network_names = [n[1] for n in [e[1] for e in self.plugin_manager.networks]] self.deployments = configs["deployments"] = DeploymentConfigCollection( - __root__={ + root={ **deployments, "valid_ecosystems": valid_ecosystems, "valid_networks": valid_network_names, diff --git a/src/ape/managers/converters.py b/src/ape/managers/converters.py index 199c70ca9f..dd9f717f74 100644 --- a/src/ape/managers/converters.py +++ b/src/ape/managers/converters.py @@ -372,11 +372,39 @@ def convert_method_args( return self.convert(pre_processed_args, tuple) def convert_method_kwargs(self, kwargs) -> Dict: - fields = TransactionAPI.__fields__ + fields = TransactionAPI.model_fields + def get_real_type(type_): + all_types = getattr(type_, "_typevar_types", []) + if not all_types or not isinstance(all_types, (list, tuple)): + return type_ + + # Filter out None + valid_types = [t for t in all_types if t is not None] + if len(valid_types) == 1: + # This is something like Optional[int], + # however, if the user provides a value, + # we want to convert to the non-optional type. + return valid_types[0] + + # Not sure if this is possible; the converter may fail. + return valid_types + + annotations = {name: get_real_type(f.annotation) for name, f in fields.items()} kwargs_to_convert = {k: v for k, v in kwargs.items() if k == "sender" or k in fields} - converted_fields = { - k: self.convert(v, AddressType if k == "sender" else fields[k].type_) - for k, v in kwargs_to_convert.items() - } + converted_fields = {} + for field_name, value in kwargs_to_convert.items(): + type_ = AddressType if field_name == "sender" else annotations.get(field_name) + if type_: + try: + converted_value = self.convert(value, type_) + except ConversionError: + # Ignore conversion errors and use the values as-is. 
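# Illustrative sketch (not part of this patch; assumes pydantic>=2 and a hypothetical
# model): v1 exposed resolved field types via `ModelField.type_`/`.outer_type_`; v2's
# `FieldInfo.annotation` keeps the raw annotation, so an `Optional[...]` must be
# unwrapped by hand (the converter's helper above, or `typing.get_args` as shown here).
from typing import Optional, Union, get_args, get_origin
from pydantic import BaseModel

class ExampleTxn(BaseModel):  # hypothetical example model
    max_fee: Optional[int] = None

annotation = ExampleTxn.model_fields["max_fee"].annotation
if get_origin(annotation) is Union:
    annotation = next(t for t in get_args(annotation) if t is not type(None))
assert annotation is int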
+ converted_value = value + + else: + converted_value = value + + converted_fields[field_name] = converted_value + return {**kwargs, **converted_fields} diff --git a/src/ape/managers/project/dependency.py b/src/ape/managers/project/dependency.py index d4bdf4ff29..254723c0e8 100644 --- a/src/ape/managers/project/dependency.py +++ b/src/ape/managers/project/dependency.py @@ -6,10 +6,9 @@ from typing import Dict, Iterable, List, Optional, Type from ethpm_types import PackageManifest -from ethpm_types.utils import AnyUrl +from pydantic import AnyUrl, FileUrl, HttpUrl, model_validator from semantic_version import NpmSpec, Version # type: ignore -from ape._pydantic_compat import FileUrl, HttpUrl, root_validator from ape.api import DependencyAPI from ape.exceptions import ProjectError, UnknownVersionError from ape.logging import logger @@ -236,7 +235,7 @@ def uri(self) -> AnyUrl: elif self._reference: _uri = f"{_uri}/tree/{self._reference}" - return HttpUrl(_uri, scheme="https") + return HttpUrl(_uri) def __repr__(self): return f"<{self.__class__.__name__} github={self.github}>" @@ -289,7 +288,7 @@ class LocalDependency(DependencyAPI): local: str version: str = "local" - @root_validator() + @model_validator(mode="before") def validate_contracts_folder(cls, value): if value.get("contracts_folder") not in (None, "contracts"): return value @@ -320,7 +319,7 @@ def version_id(self) -> str: @property def uri(self) -> AnyUrl: - return FileUrl(self.path.as_uri(), scheme="file") + return FileUrl(self.path.as_uri()) def extract_manifest(self, use_cache: bool = True) -> PackageManifest: return self._extract_local_manifest(self.path, use_cache=use_cache) @@ -397,7 +396,7 @@ def version_from_local_json(self) -> Optional[str]: @property def uri(self) -> AnyUrl: _uri = f"https://www.npmjs.com/package/{self.npm}/v/{self.version}" - return HttpUrl(_uri, scheme="https") + return HttpUrl(_uri) def extract_manifest(self, use_cache: bool = True) -> PackageManifest: if use_cache and self.cached_manifest: diff --git a/src/ape/managers/project/manager.py b/src/ape/managers/project/manager.py index 69dc1ec921..061a8ad870 100644 --- a/src/ape/managers/project/manager.py +++ b/src/ape/managers/project/manager.py @@ -7,7 +7,8 @@ from ethpm_types.contract_type import BIP122_URI from ethpm_types.manifest import PackageName from ethpm_types.source import Compiler, ContractSource -from ethpm_types.utils import AnyUrl, Hex +from ethpm_types.utils import Hex +from pydantic import AnyUrl from ape.api import DependencyAPI, ProjectAPI from ape.api.networks import LOCAL_NETWORK_NAME @@ -232,7 +233,8 @@ def tracked_deployments(self) -> Dict[BIP122_URI, Dict[str, EthPMContractInstanc for ecosystem_path in [x for x in self._package_deployments_folder.iterdir() if x.is_dir()]: for deployment_path in [x for x in ecosystem_path.iterdir() if x.suffix == ".json"]: - ethpm_instance = EthPMContractInstance.parse_file(deployment_path) + text = deployment_path.read_text() + ethpm_instance = EthPMContractInstance.model_validate_json(text) if not ethpm_instance: continue @@ -458,7 +460,7 @@ def __getattr__(self, attr_name: str) -> Any: # We know if we get here that the path does not exist. path = self.local_project._cache_folder / f"{ct.name}.json" - path.write_text(ct.json()) + path.write_text(ct.model_dump_json()) if self.local_project._contracts is None: self.local_project._contracts = {ct.name: ct} else: @@ -743,7 +745,7 @@ def track_deployment(self, contract: ContractInstance): logger.debug("Deployment already tracked. 
Re-tracking.") destination.unlink() - destination.write_text(artifact.json()) + destination.write_text(artifact.model_dump_json()) def _create_contract_source(self, contract_type: ContractType) -> Optional[ContractSource]: if not (source_id := contract_type.source_id): @@ -772,7 +774,7 @@ def _get_contract(self, name: str) -> Optional[ContractContainer]: return None # def publish_manifest(self): - # manifest = self.manifest.dict() + # manifest = self.manifest.model_dump(mode="json") # if not manifest["name"]: # raise ProjectError("Need name to release manifest") # if not manifest["version"]: diff --git a/src/ape/managers/project/types.py b/src/ape/managers/project/types.py index 705b1a46ad..203b11345b 100644 --- a/src/ape/managers/project/types.py +++ b/src/ape/managers/project/types.py @@ -187,12 +187,12 @@ def create_manifest( version=self.version, ) # Cache the updated manifest so `self.cached_manifest` reads it next time - self.manifest_cachefile.write_text(manifest.json()) + self.manifest_cachefile.write_text(manifest.model_dump_json()) self._cached_manifest = manifest if compiled_contract_types: for name, contract_type in compiled_contract_types.items(): file = self.project_manager.local_project._cache_folder / f"{name}.json" - file.write_text(contract_type.json()) + file.write_text(contract_type.model_dump_json()) self._contracts = self._contracts or {} self._contracts[name] = contract_type diff --git a/src/ape/types/__init__.py b/src/ape/types/__init__.py index 131382e26b..0d7915db87 100644 --- a/src/ape/types/__init__.py +++ b/src/ape/types/__init__.py @@ -32,9 +32,9 @@ ) from ethpm_types.abi import EventABI from ethpm_types.source import Closure +from pydantic import BaseModel, field_validator, model_validator from web3.types import FilterParams -from ape._pydantic_compat import BaseModel, root_validator, validator from ape.types.address import AddressType, RawAddress from ape.types.coverage import ( ContractCoverage, @@ -84,7 +84,7 @@ class AutoGasLimit(BaseModel): A multiplier to estimated gas. 
""" - @validator("multiplier", pre=True) + @field_validator("multiplier", mode="before") def validate_multiplier(cls, value): if isinstance(value, str): return float(value) @@ -136,7 +136,7 @@ class LogFilter(BaseModel): stop_block: Optional[int] = None # Use block height selectors: Dict[str, EventABI] = {} - @root_validator() + @model_validator(mode="before") def compute_selectors(cls, values): values["selectors"] = { encode_hex(keccak(text=event.selector)): event for event in values.get("events", []) @@ -144,17 +144,11 @@ def compute_selectors(cls, values): return values - @validator("start_block", pre=True) + @field_validator("start_block", mode="before") def validate_start_block(cls, value): return value or 0 - @validator("addresses", pre=True, each_item=True) - def validate_addresses(cls, value): - from ape import convert - - return convert(value, AddressType) - - def dict(self, client=None): + def model_dump(self, *args, **kwargs): _Hash32 = Union[Hash32, HexBytes, HexStr] topics = cast(Sequence[Optional[Union[_Hash32, Sequence[_Hash32]]]], self.topic_filter) return FilterParams( @@ -236,7 +230,7 @@ class BaseContractLog(BaseInterfaceModel): event_arguments: Dict[str, Any] = {} """The arguments to the event, including both indexed and non-indexed data.""" - @validator("contract_address", pre=True) + @field_validator("contract_address", mode="before") def validate_address(cls, value): return cls.conversion_manager.convert(value, AddressType) @@ -275,7 +269,7 @@ class ContractLog(BaseContractLog): Is `None` when from the pending block. """ - @validator("block_number", "log_index", "transaction_index", pre=True) + @field_validator("block_number", "log_index", "transaction_index", mode="before") def validate_hex_ints(cls, value): if value is None: # Should only happen for optionals. @@ -286,7 +280,7 @@ def validate_hex_ints(cls, value): return value - @validator("contract_address", pre=True) + @field_validator("contract_address", mode="before") def validate_address(cls, value): return cls.conversion_manager.convert(value, AddressType) diff --git a/src/ape/types/address.py b/src/ape/types/address.py index 9c724225fa..e2b97a144e 100644 --- a/src/ape/types/address.py +++ b/src/ape/types/address.py @@ -1,13 +1,28 @@ -from typing import Union +from importlib import import_module +from typing import Annotated, Any, Union -from eth_typing import ChecksumAddress as AddressType +from eth_typing import ChecksumAddress from ethpm_types import HexBytes +from pydantic import BeforeValidator RawAddress = Union[str, int, HexBytes] """ A raw data-type representation of an address. """ + +def validate_address(addr: Any) -> ChecksumAddress: + # NOTE: Unable to ape.utils.ZERO_ADDRESS here because of a cyclic import. 
+ return ( + getattr(import_module("ape"), "convert")(addr, AddressType) + if addr + else "0x0000000000000000000000000000000000000000" + ) + + +AddressType = Annotated[ChecksumAddress, BeforeValidator(validate_address)] + + __all__ = [ "AddressType", "RawAddress", diff --git a/src/ape/types/coverage.py b/src/ape/types/coverage.py index 462b6ca1b1..1c85971277 100644 --- a/src/ape/types/coverage.py +++ b/src/ape/types/coverage.py @@ -9,8 +9,8 @@ import requests from ethpm_types import BaseModel from ethpm_types.source import ContractSource, SourceLocation +from pydantic import NonNegativeInt, field_validator -from ape._pydantic_compat import NonNegativeInt, validator from ape.logging import logger from ape.utils.misc import get_current_timestamp_ms from ape.version import version as ape_version @@ -211,8 +211,8 @@ def line_rate(self) -> float: # This function has hittable statements. return self.lines_covered / self.lines_valid if self.lines_valid > 0 else 0 - def dict(self, *args, **kwargs) -> dict: - attribs = super().dict(*args, **kwargs) + def model_dump(self, *args, **kwargs) -> dict: + attribs = super().model_dump(*args, **kwargs) # Add coverage stats. attribs["lines_covered"] = self.lines_covered @@ -335,8 +335,8 @@ def __getitem__(self, function_name: str) -> FunctionCoverage: raise IndexError(f"Function '{function_name}' not found.") - def dict(self, *args, **kwargs) -> dict: - attribs = super().dict(*args, **kwargs) + def model_dump(self, *args, **kwargs) -> dict: + attribs = super().model_dump(*args, **kwargs) # Add coverage stats. attribs["lines_covered"] = self.lines_covered @@ -436,8 +436,8 @@ def function_rate(self) -> float: """ return self.function_hits / self.total_functions if self.total_functions > 0 else 0 - def dict(self, *args, **kwargs) -> dict: - attribs = super().dict(*args, **kwargs) + def model_dump(self, *args, **kwargs) -> dict: + attribs = super().model_dump(*args, **kwargs) # Add coverage stats. attribs["lines_covered"] = self.lines_covered @@ -535,8 +535,8 @@ def function_rate(self) -> float: """ return self.function_hits / self.total_functions if self.total_functions > 0 else 0 - def dict(self, *args, **kwargs) -> dict: - attribs = super().dict(*args, **kwargs) + def model_dump(self, *args, **kwargs) -> dict: + attribs = super().model_dump(*args, **kwargs) # Add coverage stats. attribs["lines_covered"] = self.lines_covered @@ -575,7 +575,7 @@ class CoverageReport(BaseModel): Each project with individual coverage tracked. """ - @validator("timestamp", pre=True) + @field_validator("timestamp", mode="before") def validate_timestamp(cls, value): # Default to current UTC timestamp (ms). return value or get_current_timestamp_ms() @@ -998,8 +998,8 @@ def _set_common_td(self, tbody_tr: Any, src_or_fn: Any): stmt_cov_td = SubElement(tbody_tr, "td", {}, **{"class": "column4"}) stmt_cov_td.text = f"{round(src_or_fn.line_rate * 100, 2)}%" - def dict(self, *args, **kwargs) -> dict: - attribs = super().dict(*args, **kwargs) + def model_dump(self, *args, **kwargs) -> dict: + attribs = super().model_dump(*args, **kwargs) # Add coverage stats. 
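# Illustrative sketch (not part of this patch; assumes pydantic>=2 and a hypothetical
# model): the coverage classes' v1 `.dict()` overrides become `model_dump()` overrides
# that extend the output of `super().model_dump(*args, **kwargs)` with derived stats.
from pydantic import BaseModel

class ExampleCoverage(BaseModel):  # hypothetical example model
    hits: int = 0
    misses: int = 0

    def model_dump(self, *args, **kwargs) -> dict:
        data = super().model_dump(*args, **kwargs)
        data["total"] = self.hits + self.misses  # add a computed statistic
        return data

assert ExampleCoverage(hits=2, misses=1).model_dump()["total"] == 3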
attribs["lines_covered"] = self.lines_covered diff --git a/src/ape/types/trace.py b/src/ape/types/trace.py index ccc613f152..852c3368d7 100644 --- a/src/ape/types/trace.py +++ b/src/ape/types/trace.py @@ -6,10 +6,10 @@ from ethpm_types.ast import SourceLocation from ethpm_types.source import Closure, Content, Function, SourceStatement, Statement from evm_trace.gas import merge_reports +from pydantic import Field, RootModel from rich.table import Table from rich.tree import Tree -from ape._pydantic_compat import Field from ape.types.address import AddressType from ape.utils.basemodel import BaseInterfaceModel from ape.utils.misc import is_evm_precompile, is_zero_hex @@ -323,9 +323,9 @@ def line_numbers(self) -> List[int]: def content(self) -> Content: result: Dict[int, str] = {} for node in self.source_statements: - result = {**result, **node.content.__root__} + result = {**result, **node.content.root} - return Content(__root__=result) + return Content(root=result) @property def source_header(self) -> str: @@ -409,7 +409,7 @@ def extend( new_lines = {no: ln.rstrip() for no, ln in content.items() if no >= content_start} if new_lines: # Add the next statement in this sequence. - content = Content(__root__=new_lines) + content = Content(root=new_lines) statement = SourceStatement(asts=asts, content=content, pcs=pcs) self.statements.append(statement) @@ -471,23 +471,21 @@ def next_statement(self) -> Optional[SourceStatement]: content_dict = {} for ast in next_stmt_asts: sub_content = function.get_content(ast.line_numbers) - content_dict = {**sub_content.__root__} + content_dict = {**sub_content.root} if not content_dict: return None sorted_dict = {k: content_dict[k] for k in sorted(content_dict)} - content = Content(__root__=sorted_dict) + content = Content(root=sorted_dict) return SourceStatement(asts=next_stmt_asts, content=content) -class SourceTraceback(BaseModel): +class SourceTraceback(RootModel[List[ControlFlow]]): """ A full execution traceback including source code. """ - __root__: List[ControlFlow] - @classmethod def create( cls, @@ -497,41 +495,41 @@ def create( ): trace, second_trace = tee(trace) if not second_trace or not (accessor := next(second_trace, None)): - return cls.parse_obj([]) + return cls.model_validate([]) if not (source_id := contract_type.source_id): - return cls.parse_obj([]) + return cls.model_validate([]) ext = f".{source_id.split('.')[-1]}" if ext not in accessor.compiler_manager.registered_compilers: - return cls.parse_obj([]) + return cls.model_validate([]) compiler = accessor.compiler_manager.registered_compilers[ext] try: return compiler.trace_source(contract_type, trace, HexBytes(data)) except NotImplementedError: - return cls.parse_obj([]) + return cls.model_validate([]) def __str__(self) -> str: return self.format() def __repr__(self) -> str: - return f"" + return f"" def __len__(self) -> int: - return len(self.__root__) + return len(self.root) def __iter__(self) -> Iterator[ControlFlow]: # type: ignore[override] - yield from self.__root__ + yield from self.root def __getitem__(self, idx: int) -> ControlFlow: try: - return self.__root__[idx] + return self.root[idx] except IndexError as err: raise IndexError(f"Control flow index '{idx}' out of range.") from err def __setitem__(self, key, value): - return self.__root__.__setitem__(key, value) + return self.root.__setitem__(key, value) @property def revert_type(self) -> Optional[str]: @@ -546,7 +544,7 @@ def append(self, __object) -> None: """ Append the given control flow to this one. 
""" - self.__root__.append(__object) + self.root.append(__object) def extend(self, __iterable) -> None: """ @@ -555,14 +553,14 @@ def extend(self, __iterable) -> None: if not isinstance(__iterable, SourceTraceback): raise TypeError("Can only extend another traceback object.") - self.__root__.extend(__iterable.__root__) + self.root.extend(__iterable.root) @property def last(self) -> Optional[ControlFlow]: """ The last control flow in the traceback, if there is one. """ - return self.__root__[-1] if len(self.__root__) else None + return self.root[-1] if len(self.root) else None @property def execution(self) -> List[ControlFlow]: @@ -570,27 +568,27 @@ def execution(self) -> List[ControlFlow]: All the control flows in order. Each set of statements in a control flow is separated by a jump. """ - return list(self.__root__) + return list(self.root) @property def statements(self) -> List[Statement]: """ All statements from each control flow. """ - return list(chain(*[x.statements for x in self.__root__])) + return list(chain(*[x.statements for x in self.root])) @property def source_statements(self) -> List[SourceStatement]: """ All source statements from each control flow. """ - return list(chain(*[x.source_statements for x in self.__root__])) + return list(chain(*[x.source_statements for x in self.root])) def format(self) -> str: """ Get a formatted traceback string for displaying to users. """ - if not len(self.__root__): + if not len(self.root): # No calls. return "" @@ -598,7 +596,7 @@ def format(self) -> str: indent = " " last_depth = None segments: List[str] = [] - for control_flow in reversed(self.__root__): + for control_flow in reversed(self.root): if last_depth is None or control_flow.depth == last_depth - 1: if control_flow.depth == 0 and len(segments) >= 1: # Ignore 0-layer segments if source code was hit diff --git a/src/ape/utils/abi.py b/src/ape/utils/abi.py index 0a9b21ebde..f8de5bc547 100644 --- a/src/ape/utils/abi.py +++ b/src/ape/utils/abi.py @@ -102,7 +102,7 @@ def _encode(self, _type: ABIType, value: Any): and isinstance(value, (list, tuple)) and len(_type.components or []) > 0 ): - non_array_type_data = _type.dict() + non_array_type_data = _type.model_dump(mode="json") non_array_type_data["type"] = "tuple" non_array_type = ABIType(**non_array_type_data) return [self._encode(non_array_type, v) for v in value] @@ -147,8 +147,12 @@ def _decode( elif has_array_of_tuples_return: item_type_str = str(_types[0].type).split("[")[0] - data = {**_types[0].dict(), "type": item_type_str, "internalType": item_type_str} - output_type = ABIType.parse_obj(data) + data = { + **_types[0].model_dump(mode="json"), + "type": item_type_str, + "internalType": item_type_str, + } + output_type = ABIType.model_validate(data) if isinstance(values, (list, tuple)) and not values[0]: # Only returned an empty list. @@ -166,11 +170,11 @@ def _decode( if item_type_str == "tuple": # Either an array of structs or nested structs. 
item_type_data = { - **output_type.dict(), + **output_type.model_dump(mode="json"), "type": item_type_str, "internalType": item_type_str, } - item_type = ABIType.parse_obj(item_type_data) + item_type = ABIType.model_validate(item_type_data) if is_struct(output_type): parsed_item = self._decode([item_type], [value]) diff --git a/src/ape/utils/basemodel.py b/src/ape/utils/basemodel.py index 361a9a0786..ddc787c498 100644 --- a/src/ape/utils/basemodel.py +++ b/src/ape/utils/basemodel.py @@ -5,7 +5,6 @@ from ape.exceptions import ApeAttributeError, ProviderNotConnectedError from ape.logging import logger -from ape.utils.misc import cached_property, singledispatchmethod if TYPE_CHECKING: from ape.api.providers import ProviderAPI @@ -134,7 +133,11 @@ class ExtraModelAttributes(_BaseModel): """Whether to use these in ``__getitem__``.""" def __contains__(self, name: str) -> bool: - attr_dict = self.attributes if isinstance(self.attributes, dict) else self.attributes.dict() + attr_dict = ( + self.attributes + if isinstance(self.attributes, dict) + else self.attributes.model_dump(mode="json") + ) if name in attr_dict: return True @@ -248,17 +251,6 @@ class BaseInterfaceModel(BaseInterface, BaseModel): An abstract base-class with manager access on a pydantic base model. """ - class Config: - # NOTE: Due to https://github.com/samuelcolvin/pydantic/issues/1241 we have - # to add this cached property workaround in order to avoid this error: - - # TypeError: cannot pickle '_thread.RLock' object - - keep_untouched = (cached_property, singledispatchmethod) - arbitrary_types_allowed = True - underscore_attrs_are_private = True - copy_on_model_validation = "none" - def __dir__(self) -> List[str]: """ **NOTE**: Should integrate options in IPython tab-completion. diff --git a/src/ape_cache/query.py b/src/ape_cache/query.py index 8cca70eee1..f74231c945 100644 --- a/src/ape_cache/query.py +++ b/src/ape_cache/query.py @@ -352,7 +352,7 @@ def _perform_contract_events_query(self, query: ContractEventQuery) -> Iterator[ # NOTE: Should be unreachable if estimated correctly raise QueryEngineError(f"Could not perform query:\n{query}") - yield from map(lambda row: ContractLog.parse_obj(dict(row.items())), result) + yield from map(lambda row: ContractLog.model_validate(dict(row.items())), result) @singledispatchmethod def _cache_update_clause(self, query: QueryType) -> Insert: diff --git a/src/ape_compile/__init__.py b/src/ape_compile/__init__.py index a829ae2bfc..364807706c 100644 --- a/src/ape_compile/__init__.py +++ b/src/ape_compile/__init__.py @@ -1,7 +1,8 @@ from typing import Any, List, Optional +from pydantic import Field, field_validator + from ape import plugins -from ape._pydantic_compat import Field, validator from ape.api import PluginConfig from ape.logging import logger @@ -26,7 +27,7 @@ class Config(PluginConfig): Source exclusion globs across all file types. 
""" - @validator("evm_version") + @field_validator("evm_version") def warn_deprecate(cls, value): if value: logger.warning( @@ -36,7 +37,7 @@ def warn_deprecate(cls, value): return None - @validator("exclude", pre=True) + @field_validator("exclude", mode="before") def validate_exclude(cls, value): return value or [] diff --git a/src/ape_ethereum/ecosystem.py b/src/ape_ethereum/ecosystem.py index b6dfd5ce12..1f0f8f9efb 100644 --- a/src/ape_ethereum/ecosystem.py +++ b/src/ape_ethereum/ecosystem.py @@ -16,8 +16,8 @@ ) from ethpm_types import ContractType, HexBytes from ethpm_types.abi import ABIType, ConstructorABI, EventABI, MethodABI +from pydantic import Field, field_validator -from ape._pydantic_compat import Field, validator from ape.api import BlockAPI, EcosystemAPI, PluginConfig, ReceiptAPI, TransactionAPI from ape.api.networks import LOCAL_NETWORK_NAME from ape.contracts.base import ContractCall @@ -105,13 +105,10 @@ class NetworkConfig(PluginConfig): base_fee_multiplier: float = 1.0 """A multiplier to apply to a transaction base fee.""" - class Config: - smart_union = True - - @validator("gas_limit", pre=True, allow_reuse=True) + @field_validator("gas_limit", mode="before") def validate_gas_limit(cls, value): if isinstance(value, dict) and "auto" in value: - return AutoGasLimit.parse_obj(value["auto"]) + return AutoGasLimit.model_validate(value["auto"]) elif value in ("auto", "max") or isinstance(value, AutoGasLimit): return value @@ -182,7 +179,7 @@ class Block(BlockAPI): EMPTY_BYTES32, alias="parentHash" ) # NOTE: genesis block has no parent hash - @validator( + @field_validator( "base_fee", "difficulty", "gas_limit", @@ -191,7 +188,7 @@ class Block(BlockAPI): "size", "timestamp", "total_difficulty", - pre=True, + mode="before", ) def validate_ints(cls, value): return to_int(value) if value else 0 @@ -395,7 +392,7 @@ def decode_block(self, data: Dict) -> BlockAPI: if "transactions" in data: data["num_transactions"] = len(data["transactions"]) - return Block.parse_obj(data) + return Block.model_validate(data) def _python_type_for_abi_type(self, abi_type: ABIType) -> Union[Type, Tuple, List]: # NOTE: An array can be an array of tuples, so we start with an array check @@ -442,7 +439,7 @@ def encode_calldata(self, abi: Union[ConstructorABI, MethodABI], *args) -> HexBy def decode_calldata(self, abi: Union[ConstructorABI, MethodABI], calldata: bytes) -> Dict: raw_input_types = [i.canonical_type for i in abi.inputs] - input_types = [parse_type(i.dict()) for i in abi.inputs] + input_types = [parse_type(i.model_dump(mode="json")) for i in abi.inputs] try: raw_input_values = decode(raw_input_types, calldata) @@ -475,7 +472,7 @@ def decode_returndata(self, abi: MethodABI, raw_data: bytes) -> Tuple[Any, ...]: elif not isinstance(vm_return_values, (tuple, list)): vm_return_values = (vm_return_values,) - output_types = [parse_type(o.dict()) for o in abi.outputs] + output_types = [parse_type(o.model_dump(mode="json")) for o in abi.outputs] output_values = [ self.decode_primitive_value(v, t) for v, t in zip(vm_return_values, output_types) ] diff --git a/src/ape_ethereum/multicall/handlers.py b/src/ape_ethereum/multicall/handlers.py index 22f65f3c2a..48fd969061 100644 --- a/src/ape_ethereum/multicall/handlers.py +++ b/src/ape_ethereum/multicall/handlers.py @@ -64,7 +64,7 @@ def contract(self) -> ContractInstance: # else use our backend (with less methods) contract = self.chain_manager.contracts.instance_at( MULTICALL3_ADDRESS, - contract_type=ContractType.parse_obj(MULTICALL3_CONTRACT_TYPE), + 
+            contract_type=ContractType.model_validate(MULTICALL3_CONTRACT_TYPE),
         )
         if self.provider.chain_id not in SUPPORTED_CHAINS and contract.code != MULTICALL3_CODE:
diff --git a/src/ape_ethereum/transactions.py b/src/ape_ethereum/transactions.py
index 3d4c3e59e2..0451f6c617 100644
--- a/src/ape_ethereum/transactions.py
+++ b/src/ape_ethereum/transactions.py
@@ -11,8 +11,8 @@
 from eth_utils import decode_hex, encode_hex, keccak, to_hex, to_int
 from ethpm_types import ContractType, HexBytes
 from ethpm_types.abi import EventABI, MethodABI
+from pydantic import BaseModel, Field, field_validator, model_validator
 
-from ape._pydantic_compat import BaseModel, Field, root_validator, validator
 from ape.api import ReceiptAPI, TransactionAPI
 from ape.contracts import ContractEvent
 from ape.exceptions import OutOfGasError, SignatureError, TransactionError
@@ -56,7 +56,7 @@ def serialize_transaction(self) -> bytes:
         if not self.signature:
             raise SignatureError("The transaction is not signed.")
 
-        txn_data = self.dict(exclude={"sender"})
+        txn_data = self.model_dump(exclude={"sender"})
         unsigned_txn = serializable_unsigned_transaction_from_dict(txn_data)
         signature = (self.signature.v, to_int(self.signature.r), to_int(self.signature.s))
         signed_txn = encode_transaction(unsigned_txn, signature)
@@ -87,7 +87,7 @@ class StaticFeeTransaction(BaseTransaction):
     type: int = Field(TransactionType.STATIC.value, exclude=True)
     max_fee: Optional[int] = Field(None, exclude=True)
 
-    @root_validator(pre=True, allow_reuse=True)
+    @model_validator(mode="before")
     def calculate_read_only_max_fee(cls, values) -> Dict:
         # NOTE: Work-around, Pydantic doesn't handle calculated fields well.
         values["max_fee"] = values.get("gas_limit", 0) * values.get("gas_price", 0)
@@ -100,12 +100,12 @@ class DynamicFeeTransaction(BaseTransaction):
     and ``maxPriorityFeePerGas`` fields.
     """
 
-    max_priority_fee: Optional[int] = Field(None, alias="maxPriorityFeePerGas")
-    max_fee: Optional[int] = Field(None, alias="maxFeePerGas")
+    max_priority_fee: Optional[int] = Field(None, alias="maxPriorityFeePerGas")  # type: ignore
+    max_fee: Optional[int] = Field(None, alias="maxFeePerGas")  # type: ignore
     type: int = Field(TransactionType.DYNAMIC.value)
     access_list: List[AccessList] = Field(default_factory=list, alias="accessList")
 
-    @validator("type", allow_reuse=True)
+    @field_validator("type")
     def check_type(cls, value):
         return value.value if isinstance(value, TransactionType) else value
 
@@ -119,7 +119,7 @@ class AccessListTransaction(BaseTransaction):
     type: int = Field(TransactionType.ACCESS_LIST.value)
     access_list: List[AccessList] = Field(default_factory=list, alias="accessList")
 
-    @validator("type", allow_reuse=True)
+    @field_validator("type")
     def check_type(cls, value):
         return value.value if isinstance(value, TransactionType) else value
 
@@ -173,7 +173,7 @@ def source_traceback(self) -> SourceTraceback:
         if contract_type := self.contract_type:
             return SourceTraceback.create(contract_type, self.trace, HexBytes(self.data))
 
-        return SourceTraceback.parse_obj([])
+        return SourceTraceback.model_validate([])
 
     def raise_for_status(self):
         if self.gas_limit is not None and self.ran_out_of_gas:
@@ -276,7 +276,7 @@ def get_default_log(
         return ContractLog(
             block_hash=self.block.hash,
             block_number=self.block_number,
-            event_arguments={"__root__": _log["data"]},
+            event_arguments={"root": _log["data"]},
             event_name=f"<{name}>",
             log_index=logs[-1].log_index + 1 if logs else 0,
             transaction_hash=self.txn_hash,
diff --git a/src/ape_geth/__init__.py b/src/ape_geth/__init__.py
index 93efbd0009..ccf8d88ceb 100644
--- a/src/ape_geth/__init__.py
+++ b/src/ape_geth/__init__.py
@@ -12,7 +12,7 @@ def config_class():
 
 @plugins.register(plugins.ProviderPlugin)
 def providers():
-    networks_dict = GethNetworkConfig().dict()
+    networks_dict = GethNetworkConfig().model_dump(mode="json")
     networks_dict.pop(LOCAL_NETWORK_NAME)
     for network_name in networks_dict:
         yield "ethereum", network_name, GethProvider
diff --git a/src/ape_geth/provider.py b/src/ape_geth/provider.py
index 2d414c130f..8c46d0772c 100644
--- a/src/ape_geth/provider.py
+++ b/src/ape_geth/provider.py
@@ -25,6 +25,7 @@
 from geth.chain import initialize_chain  # type: ignore
 from geth.process import BaseGethProcess  # type: ignore
 from geth.wrapper import construct_test_chain_kwargs  # type: ignore
+from pydantic_settings import SettingsConfigDict
 from requests.exceptions import ConnectionError
 from web3 import HTTPProvider, Web3
 from web3.exceptions import ExtraDataLengthError
@@ -35,7 +36,6 @@
 from web3.providers.auto import load_provider_from_environment
 from yarl import URL
 
-from ape._pydantic_compat import Extra
 from ape.api import (
     PluginConfig,
     SubprocessProvider,
@@ -226,9 +226,7 @@ class GethConfig(PluginConfig):
     ipc_path: Optional[Path] = None
     data_dir: Optional[Path] = None
 
-    class Config:
-        # For allowing all other EVM-based ecosystem plugins
-        extra = Extra.allow
+    model_config = SettingsConfigDict(extra="allow")
 
 
 class GethNotInstalledError(ConnectionError):
@@ -260,7 +258,7 @@ def uri(self) -> str:
             # Use adhoc, scripted value
             return self.provider_settings["uri"]
 
-        config = self.config.dict().get(self.network.ecosystem.name, None)
+        config = self.config.model_dump(mode="json").get(self.network.ecosystem.name, None)
         if config is None:
             return DEFAULT_SETTINGS["uri"]
 
@@ -381,7 +379,7 @@ def _get_parity_call_tree(self, txn_hash: str) -> CallTreeNode:
         if not result:
             raise ProviderError(f"Failed to get trace for '{txn_hash}'.")
 
-        traces = ParityTraceList.parse_obj(result)
+        traces = ParityTraceList.model_validate(result)
         evm_call = get_calltree_from_parity_trace(traces)
         return self._create_call_tree_node(evm_call, txn_hash=txn_hash)
 
@@ -456,7 +454,7 @@ def connect(self):
             self.start()
 
     def start(self, timeout: int = 20):
-        test_config = self.config_manager.get_config("test").dict()
+        test_config = self.config_manager.get_config("test").model_dump(mode="json")
 
         # Allow configuring a custom executable besides your $PATH geth.
         if self.geth_config.executable is not None:
diff --git a/src/ape_plugins/_cli.py b/src/ape_plugins/_cli.py
index c1607fa781..ba70caead7 100644
--- a/src/ape_plugins/_cli.py
+++ b/src/ape_plugins/_cli.py
@@ -36,7 +36,7 @@ def load_from_file(ctx, file_path: Path) -> List[PluginMetadata]:
     if file_path.is_file():
         config = load_config(file_path)
         if plugins := config.get("plugins"):
-            return [PluginMetadata.parse_obj(d) for d in plugins]
+            return [PluginMetadata.model_validate(d) for d in plugins]
 
     ctx.obj.logger.warning(f"No plugins found at '{file_path}'.")
     return []
diff --git a/src/ape_plugins/utils.py b/src/ape_plugins/utils.py
index 9f19d64094..8e9c90d638 100644
--- a/src/ape_plugins/utils.py
+++ b/src/ape_plugins/utils.py
@@ -4,8 +4,9 @@
 from functools import cached_property
 from typing import Iterator, List, Optional, Sequence, Set, Tuple
 
+from pydantic import model_validator
+
 from ape.__modules__ import __modules__
-from ape._pydantic_compat import root_validator
 from ape.logging import logger
 from ape.plugins import clean_plugin_name
 from ape.utils import BaseInterfaceModel, get_package_version, github_client
@@ -86,7 +87,7 @@ class PluginMetadataList(BaseModel):
 
     @classmethod
     def from_package_names(cls, packages: Sequence[str]) -> "PluginMetadataList":
-        PluginMetadataList.update_forward_refs()
+        PluginMetadataList.model_rebuild()
        core = PluginGroup(plugin_type=PluginType.CORE)
         available = PluginGroup(plugin_type=PluginType.AVAILABLE)
         installed = PluginGroup(plugin_type=PluginType.INSTALLED)
@@ -136,7 +137,7 @@ class PluginMetadata(BaseInterfaceModel):
     version: Optional[str] = None
     """The version requested, if there is one."""
 
-    @root_validator(pre=True)
+    @model_validator(mode="before")
     def validate_name(cls, values):
         if "name" not in values:
             raise ValueError("'name' required.")
diff --git a/src/ape_test/__init__.py b/src/ape_test/__init__.py
index 53bf1de479..6a4507aecc 100644
--- a/src/ape_test/__init__.py
+++ b/src/ape_test/__init__.py
@@ -1,7 +1,8 @@
 from typing import Dict, List, NewType, Optional, Union
 
+from pydantic import NonNegativeInt
+
 from ape import plugins
-from ape._pydantic_compat import NonNegativeInt
 from ape.api import PluginConfig
 from ape.api.networks import LOCAL_NETWORK_NAME
 from ape.utils import DEFAULT_HD_PATH, DEFAULT_NUMBER_OF_TEST_ACCOUNTS, DEFAULT_TEST_MNEMONIC
diff --git a/src/ape_test/accounts.py b/src/ape_test/accounts.py
index c596b3b165..484100a5ee 100644
--- a/src/ape_test/accounts.py
+++ b/src/ape_test/accounts.py
@@ -111,7 +111,7 @@ def sign_message(self, msg: SignableMessage) -> Optional[MessageSignature]:
 
     def sign_transaction(self, txn: TransactionAPI, **kwargs) -> Optional[TransactionAPI]:
         # Signs anything that's given to it
-        signature = EthAccount.sign_transaction(txn.dict(), self.private_key)
+        signature = EthAccount.sign_transaction(txn.model_dump(mode="json"), self.private_key)
         txn.signature = TransactionSignature(
             v=signature.v,
             r=to_bytes(signature.r),
diff --git a/src/ape_test/provider.py b/src/ape_test/provider.py
index 3cbb5e7db1..65910dd070 100644
--- a/src/ape_test/provider.py
+++ b/src/ape_test/provider.py
@@ -79,7 +79,7 @@ def estimate_gas_cost(self, txn: TransactionAPI, **kwargs) -> int:
         block_id = kwargs.pop("block_identifier", kwargs.pop("block_id", None))
         estimate_gas = self.web3.eth.estimate_gas
-        txn_dict = txn.dict()
+        txn_dict = txn.model_dump(mode="json")
         if txn_dict.get("gas") == "auto":
             # Remove from dict before estimating
             txn_dict.pop("gas")
@@ -96,7 +96,8 @@ def estimate_gas_cost(self, txn: TransactionAPI, **kwargs) -> int:
                 # and then set it back.
                 expected_nonce, actual_nonce = gas_match.groups()
                 txn.nonce = int(expected_nonce)
-                value = estimate_gas(txn.dict(), block_identifier=block_id)  # type: ignore
+                txn_params: TxParams = cast(TxParams, txn.model_dump(mode="json"))
+                value = estimate_gas(txn_params, block_identifier=block_id)
                 txn.nonce = int(actual_nonce)
                 return value
 
@@ -118,7 +119,7 @@ def chain_id(self) -> int:
         except ProviderNotConnectedError:
             result = self.provider_settings.get("chain_id", self.config.provider.chain_id)
 
-        self.cached_chain_id = result
+        self.cached_chain_id: Optional[int] = result
         return result
 
     @property
@@ -177,7 +178,7 @@ def send_transaction(self, txn: TransactionAPI) -> ReceiptAPI:
         self.chain_manager.history.append(receipt)
 
         if receipt.failed:
-            txn_dict = txn.dict()
+            txn_dict = txn.model_dump(mode="json")
             txn_dict["nonce"] += 1
             txn_params = cast(TxParams, txn_dict)
diff --git a/tests/functional/conftest.py b/tests/functional/conftest.py
index 83e64bf7d9..4222686f7e 100644
--- a/tests/functional/conftest.py
+++ b/tests/functional/conftest.py
@@ -23,7 +23,8 @@
 @pytest.fixture(scope="session")
 def get_contract_type():
     def fn(name: str) -> ContractType:
-        return ContractType.parse_file(CONTRACTS_FOLDER / f"{name}.json")
+        content = (CONTRACTS_FOLDER / f"{name}.json").read_text()
+        return ContractType.model_validate_json(content)
     return fn
 
 @pytest.fixture
@@ -256,7 +257,7 @@ def contract_getter(address):
             / "mainnet"
             / f"{address}.json"
         )
-        contract = ContractType.parse_file(path)
+        contract = ContractType.model_validate_json(path.read_text())
         chain.contracts._local_contract_types[address] = contract
         return contract
diff --git a/tests/functional/conversion/test_encode_structs.py b/tests/functional/conversion/test_encode_structs.py
index cd0373a9fa..2ec863998b 100644
--- a/tests/functional/conversion/test_encode_structs.py
+++ b/tests/functional/conversion/test_encode_structs.py
@@ -1,13 +1,12 @@
 from typing import Dict, Tuple, cast
 
 import pytest
-from ethpm_types import HexBytes
+from ethpm_types import BaseModel, HexBytes
 from ethpm_types.abi import MethodABI
 
-from ape._pydantic_compat import BaseModel
 from ape.types import AddressType
 
-ABI = MethodABI.parse_obj(
+ABI = MethodABI.model_validate(
     {
         "type": "function",
         "name": "test",
diff --git a/tests/functional/test_block.py b/tests/functional/test_block.py
index d9398bc7c3..e3b3e581ca 100644
--- a/tests/functional/test_block.py
+++ b/tests/functional/test_block.py
@@ -7,7 +7,7 @@ def block(chain):
 
 
 def test_block_dict(block):
-    actual = block.dict()
+    actual = block.model_dump(mode="json")
     expected = {
         "baseFeePerGas": 1000000000,
         "difficulty": 0,
@@ -25,7 +25,7 @@ def test_block_dict(block):
 
 
 def test_block_json(block):
-    actual = block.json()
+    actual = block.model_dump_json()
     expected = (
         '{"baseFeePerGas":1000000000,"difficulty":0,"gasLimit":30029122,"gasUsed":0,'
         f'"hash":"{block.hash.hex()}",'
diff --git a/tests/functional/test_block_container.py b/tests/functional/test_block_container.py
index 68ba9bc3ff..4ddae4ab6d 100644
--- a/tests/functional/test_block_container.py
+++ b/tests/functional/test_block_container.py
@@ -61,7 +61,7 @@ def test_block_range_negative_start(chain_that_mined_5):
     with pytest.raises(ValueError) as err:
         _ = [b for b in chain_that_mined_5.blocks.range(-1, 3, step=2)]
 
-    assert "ensure this value is greater than or equal to 0" in str(err.value)
+    assert "Input should be greater than or equal to 0" in str(err.value)
 
 
 def test_block_range_out_of_order(chain_that_mined_5):
diff --git a/tests/functional/test_config.py b/tests/functional/test_config.py
index 0b614c6cf2..4a5aef7eda 100644
--- a/tests/functional/test_config.py
+++ b/tests/functional/test_config.py
@@ -16,9 +16,9 @@ def test_integer_deployment_addresses(networks):
         "valid_ecosystems": {"ethereum": networks.ethereum},
         "valid_networks": [LOCAL_NETWORK_NAME],
     }
-    config = DeploymentConfigCollection(__root__=data)
+    config = DeploymentConfigCollection(root=data)
     assert (
-        config.__root__["ethereum"]["local"][0]["address"]
+        config.root["ethereum"]["local"][0]["address"]
         == "0x0c25212c557d00024b7Ca3df3238683A35541354"
     )
 
@@ -35,7 +35,7 @@ def test_bad_value_in_deployments(
     ecosystem_dict = {e: all_ecosystems[e] for e in ecosystem_names if e in all_ecosystems}
     data = {**deployments, "valid_ecosystems": ecosystem_dict, "valid_networks": network_names}
     ape_caplog.assert_last_log_with_retries(
-        lambda: DeploymentConfigCollection(__root__=data),
+        lambda: DeploymentConfigCollection(root=data),
         f"Invalid {err_part}",
     )
 
diff --git a/tests/functional/test_contract.py b/tests/functional/test_contract.py
index ed22e31397..5513be7fb6 100644
--- a/tests/functional/test_contract.py
+++ b/tests/functional/test_contract.py
@@ -15,7 +15,8 @@ def test_contract_from_abi(contract_instance):
 
 def test_contract_from_abi_list(contract_instance):
     contract = Contract(
-        contract_instance.address, abi=[abi.dict() for abi in contract_instance.contract_type.abi]
+        contract_instance.address,
+        abi=[abi.model_dump(mode="json") for abi in contract_instance.contract_type.abi],
     )
 
     assert isinstance(contract, ContractInstance)
@@ -26,7 +27,9 @@ def test_contract_from_abi_list(contract_instance):
 def test_contract_from_json_str(contract_instance):
     contract = Contract(
         contract_instance.address,
-        abi=json.dumps([abi.dict() for abi in contract_instance.contract_type.abi]),
+        abi=json.dumps(
+            [abi.model_dump(mode="json") for abi in contract_instance.contract_type.abi]
+        ),
     )
 
     assert isinstance(contract, ContractInstance)
diff --git a/tests/functional/test_contract_instance.py b/tests/functional/test_contract_instance.py
index 48524e6787..88b7e61f2e 100644
--- a/tests/functional/test_contract_instance.py
+++ b/tests/functional/test_contract_instance.py
@@ -3,10 +3,9 @@
 
 import pytest
 from eth_utils import is_checksum_address, to_hex
-from ethpm_types import ContractType, HexBytes
+from ethpm_types import BaseModel, ContractType, HexBytes
 
 from ape import Contract
-from ape._pydantic_compat import BaseModel
 from ape.api import TransactionAPI
 from ape.contracts import ContractInstance
 from ape.exceptions import (
@@ -255,6 +254,7 @@ def test_nested_structs(contract_instance, owner, chain):
         == chain.blocks[-2].hash
     )
     assert isinstance(actual_1.t.b, bytes)
+    assert getattr(actual_1.t.b, "__orig_class__") is HexBytes
 
     assert (
         actual_2.t.b
         == actual_2.t["b"]
         == chain.blocks[-2].hash
     )
     assert isinstance(actual_2.t.b, bytes)
+    assert getattr(actual_2.t.b, "__orig_class__") is HexBytes
 
 
 def test_nested_structs_in_tuples(contract_instance, owner, chain):
@@ -754,13 +755,13 @@ def test_value_to_non_payable_fallback_and_no_receive(
     and you try to send a value, it fails.
     """
     # Hack to set fallback as non-payable.
-    contract_type_data = vyper_fallback_contract_type.dict()
+    contract_type_data = vyper_fallback_contract_type.model_dump(mode="json")
     for abi in contract_type_data["abi"]:
         if abi.get("type") == "fallback":
             abi["stateMutability"] = "non-payable"
             break
 
-    new_contract_type = ContractType.parse_obj(contract_type_data)
+    new_contract_type = ContractType.model_validate(contract_type_data)
     contract = owner.chain_manager.contracts.instance_at(
         vyper_fallback_contract.address, contract_type=new_contract_type
     )
diff --git a/tests/functional/test_ecosystem.py b/tests/functional/test_ecosystem.py
index 6222ed21cd..9d9db5d6d5 100644
--- a/tests/functional/test_ecosystem.py
+++ b/tests/functional/test_ecosystem.py
@@ -91,10 +91,10 @@ def test_block_handles_snake_case_parent_hash(eth_tester_provider, sender, recei
     # Replace 'parentHash' key with 'parent_hash'
     latest_block = eth_tester_provider.get_block("latest")
-    latest_block_dict = eth_tester_provider.get_block("latest").dict()
+    latest_block_dict = eth_tester_provider.get_block("latest").model_dump(mode="json")
     latest_block_dict["parent_hash"] = latest_block_dict.pop("parentHash")
 
-    redefined_block = Block.parse_obj(latest_block_dict)
+    redefined_block = Block.model_validate(latest_block_dict)
 
     assert redefined_block.parent_hash == latest_block.parent_hash
 
@@ -293,7 +293,7 @@ def test_decode_return_data_non_empty_padding_bytes(ethereum):
         "000000000000000000000000000000000000000000000000000000000000012696e73756666"
         "696369656e742066756e64730000000000000000000000000000"
     )
-    abi = MethodABI.parse_obj(
+    abi = MethodABI.model_validate(
         {
             "type": "function",
             "name": "transfer",
diff --git a/tests/functional/test_plugins.py b/tests/functional/test_plugins.py
index 37f163b630..142691f93e 100644
--- a/tests/functional/test_plugins.py
+++ b/tests/functional/test_plugins.py
@@ -63,7 +63,7 @@ def test_names(self, name):
         assert metadata.package_name == "ape-foo-bar"
         assert metadata.module_name == "ape_foo_bar"
 
-    def test_model_when_version_included_with_name(self):
+    def test_model_validator_when_version_included_with_name(self):
         # This allows parsing requirements files easier
         metadata = PluginMetadata(name="ape-foo-bar==0.5.0")
         assert metadata.name == "foo-bar"
diff --git a/tests/functional/test_project.py b/tests/functional/test_project.py
index 4829a64a87..6945102eea 100644
--- a/tests/functional/test_project.py
+++ b/tests/functional/test_project.py
@@ -71,7 +71,7 @@ def contract_type_1(vyper_contract_type):
 def existing_source_path(vyper_contract_type, contract_type_0, contracts_folder):
     source_path = contracts_folder / "NewContract_0.json"
     source_path.touch()
-    source_path.write_text(contract_type_0.json())
+    source_path.write_text(contract_type_0.model_dump_json())
     yield source_path
     if source_path.is_file():
         source_path.unlink()
@@ -86,9 +86,9 @@ def manifest_with_non_existent_sources(
     manifest.contract_types["NewContract_1"] = contract_type_1
     # Previous refs shouldn't interfere (bugfix related)
     manifest.sources["NewContract_0.json"] = Source(
-        content=contract_type_0.json(), references=["NewContract_1.json"]
+        content=contract_type_0.model_dump_json(), references=["NewContract_1.json"]
     )
-    manifest.sources["NewContract_1.json"] = Source(content=contract_type_1.json())
+    manifest.sources["NewContract_1.json"] = Source(content=contract_type_1.model_dump_json())
 
     return manifest
 
@@ -101,10 +101,10 @@ def project_without_deployments(project):
 
 
 def _make_new_contract(existing_contract: ContractType, name: str):
-    source_text = existing_contract.json()
+    source_text = existing_contract.model_dump_json()
     source_text = source_text.replace(f"{existing_contract.name}.vy", f"{name}.json")
     source_text = source_text.replace(existing_contract.name or "", name)
-    return ContractType.parse_raw(source_text)
+    return ContractType.model_validate_json(source_text)
 
 
 def test_extract_manifest(project_with_dependency_config):
@@ -131,10 +131,12 @@ def test_cached_manifest_when_sources_missing(
     cache_location.touch()
     name = "NOTEXISTS"
     source_id = f"{name}.json"
-    contract_type = ContractType.parse_obj({"contractName": name, "abi": [], "sourceId": source_id})
+    contract_type = ContractType.model_validate(
+        {"contractName": name, "abi": [], "sourceId": source_id}
+    )
     path = ape_project._cache_folder / source_id
-    path.write_text(contract_type.json())
-    cache_location.write_text(manifest_with_non_existent_sources.json())
+    path.write_text(contract_type.model_dump_json())
+    cache_location.write_text(manifest_with_non_existent_sources.model_dump_json())
 
     manifest = ape_project.cached_manifest
 
@@ -157,7 +159,7 @@ def test_create_manifest_when_file_changed_with_cached_references_that_no_longer
     ape_project._cache_folder.mkdir(exist_ok=True)
     cache_location.touch()
 
-    cache_location.write_text(manifest_with_non_existent_sources.json())
+    cache_location.write_text(manifest_with_non_existent_sources.model_dump_json())
 
     # Change content
     source_text = existing_source_path.read_text()
@@ -240,7 +242,7 @@ def test_track_deployment(
     expected_uri = f"blockchain://{bip122_chain_id}/block/{expected_block_hash}"
     expected_name = contract.contract_type.name
     expected_code = contract.contract_type.runtime_bytecode
-    actual_from_file = EthPMContractInstance.parse_raw(deployment_path.read_text())
+    actual_from_file = EthPMContractInstance.model_validate_json(deployment_path.read_text())
     actual_from_class = project_without_deployments.tracked_deployments[expected_uri][name]
 
     assert actual_from_file.address == actual_from_class.address == address
@@ -277,7 +279,7 @@ def test_track_deployment_from_previously_deployed_contract(
     expected_uri = f"blockchain://{bip122_chain_id}/block/{expected_block_hash}"
     expected_name = contract.contract_type.name
     expected_code = contract.contract_type.runtime_bytecode
-    actual_from_file = EthPMContractInstance.parse_raw(path.read_text())
+    actual_from_file = EthPMContractInstance.model_validate_json(path.read_text())
     actual_from_class = project_without_deployments.tracked_deployments[expected_uri][name]
     assert actual_from_file.address == actual_from_class.address == address
     assert actual_from_file.contract_type == actual_from_class.contract_type == expected_name
@@ -320,7 +322,7 @@ def test_track_deployment_from_unknown_contract_given_txn_hash(
     contract = Contract(address, txn_hash=txn_hash)
     project.track_deployment(contract)
     path = base_deployments_path / f"{contract.contract_type.name}.json"
-    actual = EthPMContractInstance.parse_raw(path.read_text())
+    actual = EthPMContractInstance.model_validate_json(path.read_text())
     assert actual.address == address
     assert actual.contract_type == contract.contract_type.name
     assert actual.transaction == txn_hash
diff --git a/tests/functional/test_query.py b/tests/functional/test_query.py
index 865e89c6c3..0c4a77c195 100644
--- a/tests/functional/test_query.py
+++ b/tests/functional/test_query.py
@@ -77,7 +77,7 @@ class Model(BaseInterfaceModel):
 
 def test_column_expansion():
     columns = validate_and_expand_columns(["*"], Model)
-    assert columns == list(Model.__fields__)
+    assert columns == list(Model.model_fields)
 
 
 def test_column_validation(eth_tester_provider, ape_caplog):
diff --git a/tests/functional/test_transaction.py b/tests/functional/test_transaction.py
index c83c12b73b..d6f5b99526 100644
--- a/tests/functional/test_transaction.py
+++ b/tests/functional/test_transaction.py
@@ -60,17 +60,17 @@ def test_txn_hash(owner, eth_tester_provider, ethereum):
 
 def test_whitespace_in_transaction_data():
     data = b"Should not clip whitespace\t\n"
     txn_dict = {"data": data}
-    txn = StaticFeeTransaction.parse_obj(txn_dict)
+    txn = StaticFeeTransaction.model_validate(txn_dict)
     assert txn.data == data, "Whitespace should not be removed from data"
 
 
 def test_transaction_dict_excludes_none_values():
     txn = StaticFeeTransaction()
     txn.value = 1000000
-    actual = txn.dict()
+    actual = txn.model_dump(mode="json")
     assert "value" in actual
 
     txn.value = None  # type: ignore
-    actual = txn.dict()
+    actual = txn.model_dump(mode="json")
     assert "value" not in actual
diff --git a/tests/functional/test_types.py b/tests/functional/test_types.py
index 2c0ec7d620..6ebab059eb 100644
--- a/tests/functional/test_types.py
+++ b/tests/functional/test_types.py
@@ -50,7 +50,7 @@ def log():
 
 
 def test_contract_log_serialization(log, zero_address):
-    obj = ContractLog.parse_obj(log.dict())
+    obj = ContractLog.model_validate(log.model_dump(mode="json"))
     assert obj.contract_address == zero_address
     assert obj.block_hash == BLOCK_HASH
     assert obj.block_number == BLOCK_NUMBER
@@ -61,7 +61,7 @@ def test_contract_log_serialization(log, zero_address):
 
 
 def test_contract_log_serialization_with_hex_strings_and_non_checksum_addresses(log, zero_address):
-    data = log.dict()
+    data = log.model_dump(mode="json")
     data["log_index"] = to_hex(log.log_index)
     data["transaction_index"] = to_hex(log.transaction_index)
     data["block_number"] = to_hex(log.block_number)
@@ -79,12 +79,12 @@ def test_contract_log_serialization_with_hex_strings_and_non_checksum_addresses(
 
 def test_contract_log_str(log):
-    obj = ContractLog.parse_obj(log.dict())
+    obj = ContractLog.model_validate(log.model_dump(mode="json"))
     assert str(obj) == "MyEvent(foo=0 bar=1)"
 
 
 def test_contract_log_repr(log):
-    obj = ContractLog.parse_obj(log.dict())
+    obj = ContractLog.model_validate(log.model_dump(mode="json"))
     assert repr(obj) == "<MyEvent foo=0 bar=1>"
 
 
 def test_contract_log_access(log):
@@ -96,7 +96,7 @@ def test_contract_log_access(log):
 
 
 def test_topic_filter_encoding():
-    event_abi = EventABI.parse_raw(RAW_EVENT_ABI)
+    event_abi = EventABI.model_validate_json(RAW_EVENT_ABI)
     log_filter = LogFilter.from_event(
         event=event_abi, search_topics={"newVersion": "0x8c44Cc5c0f5CD2f7f17B9Aca85d456df25a61Ae8"}
     )
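
Migration note (illustrative, not part of the patch): the hunks above apply the same pydantic v1 -> v2 renames throughout -- dict()/json() become model_dump()/model_dump_json(), parse_obj/parse_raw become model_validate/model_validate_json, @validator/@root_validator become @field_validator/@model_validator, and __root__-based models become RootModel with a .root attribute. A minimal, self-contained sketch of that pattern, using hypothetical ExampleTxn/ExampleRoot models rather than ape's own classes:

from typing import Dict, Optional

from pydantic import BaseModel, Field, RootModel, field_validator, model_validator


class ExampleTxn(BaseModel):
    # Hypothetical model illustrating the v1 -> v2 validator and dump renames.
    gas_limit: int = 0
    gas_price: int = 0
    max_fee: Optional[int] = Field(None, exclude=True)

    @model_validator(mode="before")
    @classmethod
    def compute_max_fee(cls, values: Dict) -> Dict:
        # v2 replacement for @root_validator(pre=True): runs on the raw input dict.
        values["max_fee"] = values.get("gas_limit", 0) * values.get("gas_price", 0)
        return values

    @field_validator("gas_limit")
    @classmethod
    def check_gas_limit(cls, value: int) -> int:
        # v2 replacement for @validator("gas_limit").
        return int(value)


class ExampleRoot(RootModel[Dict[str, int]]):
    # v2 replacement for a v1 model whose only field was __root__.
    pass


txn = ExampleTxn.model_validate({"gas_limit": 2, "gas_price": 3})  # was parse_obj()
assert txn.max_fee == 6
assert "max_fee" not in txn.model_dump()  # was dict(); exclude=True fields stay excluded
assert ExampleRoot.model_validate({"a": 1}).root == {"a": 1}  # was .__root__
print(txn.model_dump_json())  # was json()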