Skip to content

Commit

Permalink
feat: pydantic v2
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey committed Oct 2, 2023
1 parent 536d591 commit 160a2ca
Show file tree
Hide file tree
Showing 47 changed files with 311 additions and 333 deletions.
7 changes: 4 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@
"packaging>=23.0,<24",
"pandas>=1.3.0,<2",
"pluggy>=1.3,<2",
"pydantic>=1.10.8,<3",
"pydantic>=2.4.0,<3",
"pydantic-settings>=2.0.3,<3",
"PyGithub>=1.59,<2",
"pytest>=6.0,<8.0",
"python-dateutil>=2.8.2,<3",
Expand All @@ -124,8 +125,8 @@
"web3[tester]>=6.7.0,<7",
# ** Dependencies maintained by ApeWorX **
"eip712>=0.2.1,<0.3",
"ethpm-types>=0.5.6,<0.6",
"evm-trace>=0.1.0a23",
"ethpm-types>=0.6.0,<0.7",
"evm-trace>=0.1.0a26",
],
entry_points={
"console_scripts": ["ape=ape._cli:cli"],
Expand Down
2 changes: 1 addition & 1 deletion src/ape/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def display_config(ctx, param, value):
from ape import project

click.echo("# Current configuration")
click.echo(yaml.dump(project.config_manager.dict()))
click.echo(yaml.dump(project.config_manager.model_dump(mode="json")))

ctx.exit() # NOTE: Must exit to bypass running ApeCLI

Expand Down
47 changes: 0 additions & 47 deletions src/ape/_pydantic_compat.py

This file was deleted.

13 changes: 4 additions & 9 deletions src/ape/api/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from enum import Enum
from typing import Any, Dict, Optional, TypeVar

from ape._pydantic_compat import BaseModel, BaseSettings
from pydantic import ConfigDict
from pydantic_settings import BaseSettings

T = TypeVar("T")

Expand All @@ -14,10 +15,6 @@ class ConfigEnum(str, Enum):
"""


class ConfigDict(BaseModel):
__root__: dict = {}


class PluginConfig(BaseSettings):
"""
A base plugin configuration class. Each plugin that includes
Expand All @@ -26,7 +23,7 @@ class PluginConfig(BaseSettings):

@classmethod
def from_overrides(cls, overrides: Dict) -> "PluginConfig":
default_values = cls().dict()
default_values = cls().model_dump()

def update(root: Dict, value_map: Dict):
for key, val in value_map.items():
Expand Down Expand Up @@ -54,9 +51,7 @@ def get(self, key: str, default: Optional[T] = None) -> T:
return self.__dict__.get(key, default)


class GenericConfig(PluginConfig):
class GenericConfig(ConfigDict):
"""
The default class used when no specialized class is used.
"""

__root__: dict = {}
6 changes: 2 additions & 4 deletions src/ape/api/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@
serializable_unsigned_transaction_from_dict,
)
from eth_utils import keccak, to_int
from ethpm_types import ContractType, HexBytes
from ethpm_types import BaseModel, ContractType, HexBytes
from ethpm_types.abi import ABIType, ConstructorABI, EventABI, MethodABI

from ape._pydantic_compat import BaseModel
from ape.exceptions import (
NetworkError,
NetworkMismatchError,
Expand Down Expand Up @@ -139,8 +138,7 @@ def serialize_transaction(self, transaction: "TransactionAPI") -> bytes:
if not self.signature:
raise SignatureError("The transaction is not signed.")

txn_data = self.dict(exclude={"sender"})

txn_data = self.model_dump(exclude={"sender"})
unsigned_txn = serializable_unsigned_transaction_from_dict(txn_data)
signature = (
self.signature.v,
Expand Down
19 changes: 10 additions & 9 deletions src/ape/api/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
from ethpm_types import Checksum, ContractType, PackageManifest, Source
from ethpm_types.manifest import PackageName
from ethpm_types.source import Content
from ethpm_types.utils import Algorithm, AnyUrl, compute_checksum
from ethpm_types.utils import Algorithm, compute_checksum
from packaging.version import InvalidVersion, Version
from pydantic import AnyUrl, ValidationError

from ape._pydantic_compat import ValidationError
from ape.logging import logger
from ape.utils import (
BaseInterfaceModel,
Expand Down Expand Up @@ -111,7 +111,7 @@ def cached_manifest(self) -> Optional[PackageManifest]:
continue

path = self._cache_folder / f"{contract_type.name}.json"
path.write_text(contract_type.json())
path.write_text(contract_type.model_dump_json())

# Rely on individual cache files.
self._contracts = manifest.contract_types
Expand Down Expand Up @@ -144,9 +144,10 @@ def contracts(self) -> Dict[str, ContractType]:
continue

contract_name = p.stem
contract_type = ContractType.parse_file(p)
if contract_type.name is None:
contract_type.name = contract_name
contract_type = ContractType.model_validate_json(p.read_text())
contract_type.name = (
contract_name if contract_type.name is None else contract_type.name
)

contracts[contract_type.name] = contract_type
self._contracts = contracts
Expand Down Expand Up @@ -215,7 +216,7 @@ def _create_source_dict(
hash=compute_checksum(source_path.read_bytes()),
),
urls=[],
content=Content(__root__={i + 1: x for i, x in enumerate(text.splitlines())}),
content=Content(root={i + 1: x for i, x in enumerate(text.splitlines())}),
imports=source_imports.get(key, []),
references=source_references.get(key, []),
)
Expand Down Expand Up @@ -431,7 +432,7 @@ def _get_sources(self, project: ProjectAPI) -> List[Path]:
def _write_manifest_to_cache(self, manifest: PackageManifest):
self._target_manifest_cache_file.unlink(missing_ok=True)
self._target_manifest_cache_file.parent.mkdir(exist_ok=True, parents=True)
self._target_manifest_cache_file.write_text(manifest.json())
self._target_manifest_cache_file.write_text(manifest.model_dump_json())
self._cached_manifest = manifest


Expand All @@ -440,7 +441,7 @@ def _load_manifest_from_file(file_path: Path) -> Optional[PackageManifest]:
return None

try:
return PackageManifest.parse_file(file_path)
return PackageManifest.model_validate_json(file_path.read_text())
except ValidationError as err:
logger.warning(f"Existing manifest file '{file_path}' corrupted. Re-building.")
logger.debug(str(err))
Expand Down
35 changes: 15 additions & 20 deletions src/ape/api/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@
from ethpm_types import HexBytes
from evm_trace import CallTreeNode as EvmCallTreeNode
from evm_trace import TraceFrame as EvmTraceFrame
from pydantic import Field, model_validator
from web3 import Web3
from web3.exceptions import ContractLogicError as Web3ContractLogicError
from web3.exceptions import MethodUnavailable, TimeExhausted, TransactionNotFound
from web3.types import RPCEndpoint, TxParams

from ape._pydantic_compat import Field, root_validator, validator
from ape.api.config import PluginConfig
from ape.api.networks import LOCAL_NETWORK_NAME, NetworkAPI
from ape.api.query import BlockTransactionQuery
Expand Down Expand Up @@ -95,20 +95,12 @@ class BlockAPI(BaseInterfaceModel):
def datetime(self) -> datetime.datetime:
return datetime.datetime.fromtimestamp(self.timestamp, tz=datetime.timezone.utc)

@root_validator(pre=True)
@model_validator(mode="before")
def convert_parent_hash(cls, data):
parent_hash = data.get("parent_hash", data.get("parentHash")) or EMPTY_BYTES32
data["parentHash"] = parent_hash
return data

@validator("hash", "parent_hash", pre=True)
def validate_hexbytes(cls, value):
# NOTE: pydantic treats these values as bytes and throws an error
if value and not isinstance(value, bytes):
return HexBytes(value)

return value

@cached_property
def transactions(self) -> List[TransactionAPI]:
query = BlockTransactionQuery(columns=["*"], block_id=self.hash)
Expand Down Expand Up @@ -867,7 +859,7 @@ def estimate_gas_cost(self, txn: TransactionAPI, **kwargs) -> int:
return the block maximum gas limit.
"""

txn_dict = txn.dict()
txn_dict = txn.model_dump(mode="json")

# Force the use of hex values to support a wider range of nodes.
if isinstance(txn_dict.get("type"), int):
Expand Down Expand Up @@ -1149,7 +1141,7 @@ def _eth_call(self, arguments: List) -> bytes:
return HexBytes(result)

def _prepare_call(self, txn: TransactionAPI, **kwargs) -> List:
txn_dict = txn.dict()
txn_dict = txn.model_dump(mode="json")
fields_to_convert = ("data", "chainId", "value")
for field in fields_to_convert:
value = txn_dict.get(field)
Expand Down Expand Up @@ -1212,7 +1204,9 @@ def get_receipt(
except TimeExhausted as err:
raise TransactionNotFoundError(txn_hash, error_messsage=str(err)) from err

network_config: Dict = self.network.config.dict().get(self.network.name, {})
network_config: Dict = self.network.config.model_dump(mode="json").get(
self.network.name, {}
)
max_retries = network_config.get("max_get_transaction_retries", DEFAULT_MAX_RETRIES_TX)
txn = {}
for attempt in range(max_retries):
Expand Down Expand Up @@ -1382,10 +1376,10 @@ def get_contract_logs(self, log_filter: LogFilter) -> Iterator[ContractLog]:

def fetch_log_page(block_range):
start, stop = block_range
page_filter = log_filter.copy(update=dict(start_block=start, stop_block=stop))
page_filter = log_filter.model_copy(update=dict(start_block=start, stop_block=stop))
# eth-tester expects a different format, let web3 handle the conversions for it
raw = "EthereumTester" not in self.client_version
logs = self._get_logs(page_filter.dict(), raw)
logs = self._get_logs(page_filter.model_dump(mode="json"), raw)
return self.network.ecosystem.decode_logs(logs, *log_filter.events)

with ThreadPoolExecutor(self.concurrency) as pool:
Expand Down Expand Up @@ -1455,7 +1449,8 @@ def send_transaction(self, txn: TransactionAPI) -> ReceiptAPI:
if txn.sender not in self.web3.eth.accounts:
self.chain_manager.provider.unlock_account(txn.sender)

txn_hash = self.web3.eth.send_transaction(cast(TxParams, txn.dict()))
txn_data = cast(TxParams, txn.model_dump(mode="json"))
txn_hash = self.web3.eth.send_transaction(cast(TxParams, txn_data))

except (ValueError, Web3ContractLogicError) as err:
vm_err = self.get_virtual_machine_error(err, txn=txn)
Expand All @@ -1474,7 +1469,7 @@ def send_transaction(self, txn: TransactionAPI) -> ReceiptAPI:
self.chain_manager.history.append(receipt)

if receipt.failed:
txn_dict = receipt.transaction.dict()
txn_dict = receipt.transaction.model_dump(mode="json")
txn_params = cast(TxParams, txn_dict)

# Replay txn to get revert reason
Expand Down Expand Up @@ -1508,7 +1503,7 @@ def _create_call_tree_node(
inputs=evm_call.calldata if "CREATE" in call_type else evm_call.calldata[4:].hex(),
method_id=evm_call.calldata[:4].hex(),
outputs=evm_call.returndata.hex(),
raw=evm_call.dict(),
raw=evm_call.model_dump(mode="json"),
txn_hash=txn_hash,
)

Expand All @@ -1531,7 +1526,7 @@ def _create_trace_frame(self, evm_frame: EvmTraceFrame) -> TraceFrame:
gas_cost=evm_frame.gas_cost,
depth=evm_frame.depth,
contract_address=address,
raw=evm_frame.dict(),
raw=evm_frame.model_dump(mode="json"),
)

def _make_request(self, endpoint: str, parameters: List) -> Any:
Expand Down Expand Up @@ -1755,7 +1750,7 @@ def disconnect(self):
Subclasses override this method to do provider-specific disconnection tasks.
"""

self.cached_chain_id = None
self.cached_chain_id = None # type: ignore
if self.process:
self.stop()

Expand Down
10 changes: 5 additions & 5 deletions src/ape/api/query.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from functools import lru_cache
from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Type, Union

from ethpm_types.abi import EventABI, MethodABI
from ethpm_types.abi import BaseModel, EventABI, MethodABI
from pydantic import NonNegativeInt, PositiveInt, model_validator

from ape._pydantic_compat import BaseModel, NonNegativeInt, PositiveInt, root_validator
from ape.api.transactions import ReceiptAPI, TransactionAPI
from ape.logging import logger
from ape.types import AddressType
Expand All @@ -22,7 +22,7 @@
# TODO: Replace with `functools.cache` when Py3.8 dropped
@lru_cache(maxsize=None)
def _basic_columns(Model: Type[BaseInterfaceModel]) -> Set[str]:
columns = set(Model.__fields__)
columns = set(Model.model_fields)

# TODO: Remove once `ReceiptAPI` fields cleaned up for better processing
if Model == ReceiptAPI:
Expand Down Expand Up @@ -104,7 +104,7 @@ class _BaseBlockQuery(_BaseQuery):
stop_block: NonNegativeInt
step: PositiveInt = 1

@root_validator(pre=True)
@model_validator(mode="before")
def check_start_block_before_stop_block(cls, values):
if values["stop_block"] < values["start_block"]:
raise ValueError(
Expand Down Expand Up @@ -141,7 +141,7 @@ class AccountTransactionQuery(_BaseQuery):
start_nonce: NonNegativeInt = 0
stop_nonce: NonNegativeInt

@root_validator(pre=True)
@model_validator(mode="before")
def check_start_nonce_before_stop_nonce(cls, values: Dict) -> Dict:
if values["stop_nonce"] < values["start_nonce"]:
raise ValueError(
Expand Down
Loading

0 comments on commit 160a2ca

Please sign in to comment.