Skip to content

Commit

Permalink
fix: unreleased regression from using network_option with other optio…
Browse files Browse the repository at this point in the history
…ns [APE-1612] (#1772)
  • Loading branch information
antazoey authored Dec 13, 2023
1 parent d859829 commit d1aa1a1
Show file tree
Hide file tree
Showing 9 changed files with 142 additions and 34 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@
"packaging>=23.0,<24",
"pandas>=1.3.0,<2",
"pluggy>=1.3,<2",
"pydantic>=2.5.0,<3",
"pydantic>=2.5.2,<3",
"pydantic-settings>=2.0.3,<3",
"PyGithub>=1.59,<2",
"pytest>=6.0,<8.0",
Expand Down
2 changes: 1 addition & 1 deletion src/ape/api/accounts.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ def address(self) -> AddressType:
def sign_message(self, msg: Any, **signer_options) -> Optional[MessageSignature]:
raise NotImplementedError("This account cannot sign messages")

def sign_transaction(self, txn: TransactionAPI, **kwargs) -> Optional[TransactionAPI]:
def sign_transaction(self, txn: TransactionAPI, **signer_options) -> Optional[TransactionAPI]:
# Returns input transaction unsigned (since it doesn't have access to the key)
return txn

Expand Down
44 changes: 29 additions & 15 deletions src/ape/cli/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,12 +191,19 @@ def network_option(
"""

def decorator(f):
# These are the available network object names you can request.
network_object_names = ("ecosystem", "network", "provider")

# All kwargs in the defined @click.commmand().
command_signature = inspect.signature(f)
command_kwargs = [x.name for x in command_signature.parameters.values()]

# Any combinaiton of ["ecosystem", "network", "provider"]
requested_network_objects = [x for x in command_kwargs if x in network_object_names]

# When using network_option, handle parsing now so we can pass to
# callback outside of command context.
user_cb = kwargs.pop("callback", None)
signature = inspect.signature(f)
network_args = ("ecosystem", "network", "provider")
requested_data = [x.name for x in signature.parameters.values() if x.name in network_args]
user_callback = kwargs.pop("callback", None)

def callback(ctx, param, value):
is_legacy = param.type.base_type is str
Expand Down Expand Up @@ -224,23 +231,32 @@ def callback(ctx, param, value):
}

# Set the actual values.
for item in requested_data:
for item in requested_network_objects:
instance = choice_classes[item]
ctx.params[item] = instance

# else: provider is None, meaning not connected intentionally.

return value if user_cb is None else user_cb(ctx, param, value)
return value if user_callback is None else user_callback(ctx, param, value)

# Prevent argument errors but initializing callback to use None placeholders.
partial_kwargs: Dict = {}
for arg_type in network_args:
if arg_type in requested_data:
for arg_type in network_object_names:
if arg_type in requested_network_objects:
partial_kwargs[arg_type] = None

# Set this property for click framework to function properly.
partial_f = partial(f, **partial_kwargs)
partial_f.__name__ = f.__name__ # type: ignore[attr-defined]
if partial_kwargs:
wrapped_f = partial(f, **partial_kwargs)

# NOTE: The following is needed for click internals.
wrapped_f.__name__ = f.__name__ # type: ignore[attr-defined]

# Add other click parameters.
if hasattr(f, "__click_params__"):
wrapped_f.__click_params__ = f.__click_params__ # type: ignore[attr-defined]
else:
# No network kwargs are used. No need for partial wrapper.
wrapped_f = f

# Use NetworkChoice option. Raises:
kwargs["type"] = None
Expand All @@ -254,17 +270,15 @@ def callback(ctx, param, value):
kwargs["callback"] = callback

# Create the actual option.
option = click.option(
return click.option(
default=default,
ecosystem=ecosystem,
network=network,
provider=provider,
required=required,
cls=NetworkOption,
**kwargs,
)(partial_f)

return option
)(wrapped_f)

return decorator

Expand Down
4 changes: 1 addition & 3 deletions src/ape/types/address.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,7 @@ class _AddressValidator(_Address, ManagerAccessMixin):
@classmethod
def __eth_pydantic_validate__(cls, value: Any, info: Optional[ValidationInfo] = None) -> str:
if type(value) in (list, tuple):
return cls.conversion_manager.convert(
value, List[AddressType]
) # type: ignore[valid-type]
return cls.conversion_manager.convert(value, List[AddressType])

return (
cls.conversion_manager.convert(value, AddressType)
Expand Down
4 changes: 1 addition & 3 deletions src/ape_accounts/accounts.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,6 @@ def delete(self):
self.keyfile_path.unlink()

def sign_message(self, msg: Any, **signer_options) -> Optional[MessageSignature]:
user_approves = False

if isinstance(msg, str):
user_approves = self.__autosign or click.confirm(f"Message: {msg}\n\nSign: ")
msg = encode_defunct(text=msg)
Expand Down Expand Up @@ -211,7 +209,7 @@ def sign_message(self, msg: Any, **signer_options) -> Optional[MessageSignature]
s=to_bytes(signed_msg.s),
)

def sign_transaction(self, txn: TransactionAPI, **kwargs) -> Optional[TransactionAPI]:
def sign_transaction(self, txn: TransactionAPI, **signer_options) -> Optional[TransactionAPI]:
user_approves = self.__autosign or click.confirm(f"{txn}\n\nSign: ")
if not user_approves:
return None
Expand Down
10 changes: 4 additions & 6 deletions src/ape_ethereum/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,9 +688,7 @@ def assert_chain_activity():
return

# Set the last action, used for checking timeouts and re-orgs.
last = YieldAction(
number=block.number, hash=block.hash, time=time.time()
) # type: ignore
last = YieldAction(number=block.number, hash=block.hash, time=time.time())

def poll_logs(
self,
Expand Down Expand Up @@ -1115,7 +1113,7 @@ def _ots_api_level(self) -> Optional[int]:

def _set_web3(self):
# Clear cached version when connecting to another URI.
self._client_version = None # type: ignore
self._client_version = None
self._web3 = _create_web3(self.uri, ipc_path=self.ipc_path)

def _complete_connect(self):
Expand Down Expand Up @@ -1166,8 +1164,8 @@ def _complete_connect(self):

def disconnect(self):
self.can_use_parity_traces = None
self._web3 = None # type: ignore
self._client_version = None # type: ignore
self._web3 = None
self._client_version = None

def get_transaction_trace(self, txn_hash: str) -> Iterator[TraceFrame]:
frames = self._stream_request(
Expand Down
2 changes: 1 addition & 1 deletion src/ape_test/accounts.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def sign_message(self, msg: Any, **signer_options) -> Optional[MessageSignature]
)
return None

def sign_transaction(self, txn: TransactionAPI, **kwargs) -> Optional[TransactionAPI]:
def sign_transaction(self, txn: TransactionAPI, **signer_options) -> Optional[TransactionAPI]:
# Signs anything that's given to it
signature = EthAccount.sign_transaction(txn.model_dump(mode="json"), self.private_key)
txn.signature = TransactionSignature(
Expand Down
6 changes: 3 additions & 3 deletions src/ape_test/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,12 @@ def connect(self):

def disconnect(self):
# NOTE: This type ignore seems like a bug in pydantic.
self._web3 = None # type: ignore
self._evm_backend = None # type: ignore
self._web3 = None
self._evm_backend = None
self.provider_settings = {}

def update_settings(self, new_settings: Dict):
self._cached_chain_id = None # type: ignore[assignment]
self._cached_chain_id = None
self.provider_settings = {**self.provider_settings, **new_settings}
self.disconnect()
self.connect()
Expand Down
102 changes: 101 additions & 1 deletion tests/functional/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from ape.logging import logger

OUTPUT_FORMAT = "__TEST__{0}:{1}:{2}_"
OTHER_OPTION_VALUE = "TEST_OTHER_OPTION"
other_option = click.option("--other", default=OTHER_OPTION_VALUE)


@pytest.fixture
Expand Down Expand Up @@ -179,6 +181,41 @@ def test_network_option_unknown(runner, network_cmd):
assert result.exit_code != 0


def test_network_option_with_other_option(runner):
"""
To prove can use the `@network_option` with other options
in the same command (was issue during production where could not!).
"""

# Scenario: Using network_option but not using the value in the command callback.
# (Potentially handling independently).
@click.command()
@network_option()
@other_option
def solo_option(other):
click.echo(other)

# Scenario: Using the network option with another option.
# This use-case is way more common than the one above.
@click.command()
@network_option()
@other_option
def with_net(network, other):
click.echo(network.name)
click.echo(other)

def run(cmd, fail_msg=None):
res = runner.invoke(cmd, [], catch_exceptions=False)
fail_msg = f"{fail_msg}\n{res.output}" if fail_msg else res.output
assert res.exit_code == 0, fail_msg
assert OTHER_OPTION_VALUE in res.output, fail_msg
return res

run(solo_option, fail_msg="Failed when used without network kwargs")
result = run(with_net, fail_msg="Failed when used with network kwargs")
assert "local" in result.output


@pytest.mark.parametrize(
"network_input",
(
Expand Down Expand Up @@ -454,6 +491,64 @@ def cmd(ecosystem, network, provider):
assert "ethereum:local:test" in result.output, result.output


def test_connected_provider_command_use_custom_options(runner):
"""
Ensure custom options work when using `ConnectedProviderCommand`.
(There was an issue during development where we could not).
"""

# Scenario: Custom option and using network object.
@click.command(cls=ConnectedProviderCommand)
@other_option
def use_net(network, other):
click.echo(network.name)
click.echo(other)

# Scenario: Only using custom option.
@click.command(cls=ConnectedProviderCommand)
@other_option
def solo_other(other):
click.echo(other)

# Scenario: Option explicit (shouldn't matter)
@click.command(cls=ConnectedProviderCommand)
@network_option()
@other_option
def explicit_option(other):
click.echo(other)

@click.command(cls=ConnectedProviderCommand)
@network_option()
@click.argument("other_arg")
@other_option
def with_arg(other_arg, other, provider):
click.echo(other)
click.echo(provider.name)
click.echo(other_arg)

spec = ("--network", "ethereum:local:test")

def run(cmd, extra_args=None):
arguments = [*spec, *(extra_args or [])]
res = runner.invoke(cmd, arguments, catch_exceptions=False)
assert res.exit_code == 0, res.output
assert OTHER_OPTION_VALUE in res.output
return res

result = run(use_net)
assert "local" in result.output, result.output # Echos network object

result = run(solo_other)
assert "local" not in result.output, result.output

run(explicit_option)

argument = "_extra_"
result = run(with_arg, extra_args=[argument])
assert "test" in result.output
assert argument in result.output


# TODO: Delete for 0.8.
def test_deprecated_network_bound_command(runner):
with pytest.warns(
Expand All @@ -463,9 +558,14 @@ def test_deprecated_network_bound_command(runner):

@click.command(cls=NetworkBoundCommand)
@network_option()
def cmd(network):
# NOTE: Must also make sure can use other options with this combo!
# (was issue where could not).
@click.option("--other", default=OTHER_OPTION_VALUE)
def cmd(network, other):
click.echo(network)
click.echo(other)

result = runner.invoke(cmd, ["--network", "ethereum:local:test"], catch_exceptions=False)
assert result.exit_code == 0, result.output
assert "ethereum:local:test" in result.output, result.output
assert OTHER_OPTION_VALUE in result.output

0 comments on commit d1aa1a1

Please sign in to comment.