From e92446fbcd8b2727810b89451ed9927a5891f267 Mon Sep 17 00:00:00 2001 From: antazoey Date: Thu, 12 Dec 2024 23:24:00 +0700 Subject: [PATCH] fix: negative block number support in `ContractLog.range` queries (#2388) --- src/ape/contracts/base.py | 68 +++++++++++++++------ tests/functional/test_contract_event.py | 81 +++++++++++++++++++------ 2 files changed, 112 insertions(+), 37 deletions(-) diff --git a/src/ape/contracts/base.py b/src/ape/contracts/base.py index 62fe4f7108..7e6067de56 100644 --- a/src/ape/contracts/base.py +++ b/src/ape/contracts/base.py @@ -634,19 +634,19 @@ def query( # perf: pandas import is really slow. Avoid importing at module level. import pandas as pd + HEAD = self.chain_manager.blocks.height if start_block < 0: - start_block = self.chain_manager.blocks.height + start_block + start_block = HEAD + start_block if stop_block is None: - stop_block = self.chain_manager.blocks.height + stop_block = HEAD elif stop_block < 0: - stop_block = self.chain_manager.blocks.height + stop_block + stop_block = HEAD + stop_block - elif stop_block > self.chain_manager.blocks.height: + elif stop_block > HEAD: raise ChainError( - f"'stop={stop_block}' cannot be greater than " - f"the chain length ({self.chain_manager.blocks.height})." + f"'stop={stop_block}' cannot be greater than the chain length ({HEAD})." ) query: dict = { "columns": list(ContractLog.__pydantic_fields__) if columns[0] == "*" else columns, @@ -692,12 +692,12 @@ def range( Returns: Iterator[:class:`~ape.contracts.base.ContractLog`] """ - if not (contract_address := getattr(self.contract, "address", None)): return start_block = None stop_block = None + HEAD = self.chain_manager.blocks.height # Current block height if stop is None: contract = None @@ -706,27 +706,57 @@ def range( except Exception: pass - if contract: - if creation := contract.creation_metadata: - start_block = creation.block - - stop_block = start_or_stop + # Determine the start block from contract creation metadata + if contract and (creation := contract.creation_metadata): + start_block = creation.block + + # Handle single parameter usage (like Python's range(stop)) + if start_or_stop == 0: + # stop==0 is the same as stop==HEAD + # because of the -1 (turns to negative). + stop_block = HEAD + 1 + elif start_or_stop >= 0: + # Given like range(1) + stop_block = min(start_or_stop - 1, HEAD) + else: + # Give like range(-1) + stop_block = HEAD + start_or_stop elif start_or_stop is not None and stop is not None: - start_block = start_or_stop - stop_block = stop - 1 - - stop_block = min(stop_block, self.chain_manager.blocks.height) + # Handle cases where both start and stop are provided + if start_or_stop >= 0: + start_block = min(start_or_stop, HEAD) + else: + # Negative start relative to HEAD + adjusted_value = HEAD + start_or_stop + 1 + start_block = max(adjusted_value, 0) + + if stop == 0: + # stop==0 is the same as stop==HEAD + # because of the -1 (turns to negative). + stop_block = HEAD + elif stop > 0: + # Positive stop, capped to the chain HEAD + stop_block = min(stop - 1, HEAD) + else: + # Negative stop. + adjusted_value = HEAD + stop + stop_block = max(adjusted_value, 0) + # Gather all addresses to query (contract and any extra ones provided) addresses = list(set([contract_address] + (extra_addresses or []))) + + # Construct the event query contract_event_query = ContractEventQuery( - columns=list(ContractLog.__pydantic_fields__), + columns=list(ContractLog.__pydantic_fields__), # Ensure all necessary columns contract=addresses, event=self.abi, search_topics=search_topics, - start_block=start_block or 0, - stop_block=stop_block, + start_block=start_block or 0, # Default to block 0 if not set + stop_block=stop_block, # None means query to the current HEAD ) + + # Execute the query and yield results yield from self.query_manager.query(contract_event_query) # type: ignore def from_receipt(self, receipt: "ReceiptAPI") -> list[ContractLog]: diff --git a/tests/functional/test_contract_event.py b/tests/functional/test_contract_event.py index 8dfa555b7e..7ab5a2c9a1 100644 --- a/tests/functional/test_contract_event.py +++ b/tests/functional/test_contract_event.py @@ -17,7 +17,7 @@ @pytest.fixture -def assert_log_values(owner, chain): +def assert_log_values(owner): def _assert_log_values(log: ContractLog, number: int, previous_number: Optional[int] = None): assert isinstance(log.b, bytes) expected_previous_number = number - 1 if previous_number is None else previous_number @@ -32,7 +32,7 @@ def _assert_log_values(log: ContractLog, number: int, previous_number: Optional[ return _assert_log_values -def test_contract_logs_from_receipts(owner, contract_instance, assert_log_values): +def test_from_receipts(owner, contract_instance, assert_log_values): event_type = contract_instance.NumberChange # Invoke a transaction 3 times that generates 3 logs. @@ -55,7 +55,7 @@ def assert_receipt_logs(receipt: "ReceiptAPI", num: int): assert_receipt_logs(receipt_2, 3) -def test_contract_logs_from_event_type(contract_instance, owner, assert_log_values): +def test_from_event_type(contract_instance, owner, assert_log_values): event_type = contract_instance.NumberChange start_num = 6 size = 20 @@ -76,7 +76,7 @@ def test_contract_logs_from_event_type(contract_instance, owner, assert_log_valu assert_log_values(log, num) -def test_contract_logs_index_access(contract_instance, owner, assert_log_values): +def test_index_access(contract_instance, owner, assert_log_values): event_type = contract_instance.NumberChange contract_instance.setNumber(1, sender=owner) @@ -93,7 +93,7 @@ def test_contract_logs_index_access(contract_instance, owner, assert_log_values) assert event_type[-1] == contract_instance.NumberChange(newNum=3, prevNum=2) -def test_contract_logs_splicing(contract_instance, owner, assert_log_values): +def test_splicing(contract_instance, owner, assert_log_values): event_type = contract_instance.NumberChange contract_instance.setNumber(1, sender=owner) @@ -113,7 +113,7 @@ def test_contract_logs_splicing(contract_instance, owner, assert_log_values): assert_log_values(log, 2) -def test_contract_logs_range(chain, contract_instance, owner, assert_log_values): +def test_range(chain, contract_instance, owner, assert_log_values): contract_instance.setNumber(1, sender=owner) start = chain.blocks.height logs = [ @@ -126,7 +126,7 @@ def test_contract_logs_range(chain, contract_instance, owner, assert_log_values) assert_log_values(logs[0], 1) -def test_contract_logs_range_by_address( +def test_range_by_address( mocker, chain, eth_tester_provider, accounts, contract_instance, owner, assert_log_values ): get_logs_spy = mocker.spy(eth_tester_provider.tester.ethereum_tester, "get_logs") @@ -157,29 +157,27 @@ def test_contract_logs_range_by_address( assert logs == [contract_instance.AddressChange(newAddress=accounts[1])] -def test_contracts_log_multiple_addresses( +def test_range_multiple_addresses( chain, contract_instance, contract_container, owner, assert_log_values ): another_instance = contract_container.deploy(0, sender=owner) start_block = chain.blocks.height contract_instance.setNumber(1, sender=owner) another_instance.setNumber(1, sender=owner) - - logs = [ - log - for log in contract_instance.NumberChange.range( + logs = list( + contract_instance.NumberChange.range( start_block, start_block + 100, search_topics={"newNum": 1}, extra_addresses=[another_instance.address], ) - ] - assert len(logs) == 2, "Unexpected number of logs" + ) + assert len(logs) == 2, f"Unexpected number of logs: {len(logs)}" assert logs[0] == contract_instance.NumberChange(newNum=1, prevNum=0) assert logs[1] == another_instance.NumberChange(newNum=1, prevNum=0) -def test_contract_logs_range_start_and_stop(contract_instance, owner, chain): +def test_range_start_and_stop(contract_instance, owner, chain): # Create 1 event contract_instance.setNumber(1, sender=owner) @@ -194,8 +192,8 @@ def test_contract_logs_range_start_and_stop(contract_instance, owner, chain): assert len(logs) == 3, "Unexpected number of logs" -def test_contract_logs_range_only_stop(contract_instance, owner, chain): - # Create 1 event +def test_range_only_stop(contract_instance, owner, chain): + # Create 3 events start = chain.blocks.height contract_instance.setNumber(1, sender=owner) contract_instance.setNumber(2, sender=owner) @@ -203,7 +201,54 @@ def test_contract_logs_range_only_stop(contract_instance, owner, chain): stop = start + 100 # Stop can be bigger than height, it doesn't not matter logs = [log for log in contract_instance.NumberChange.range(stop)] - assert len(logs) >= 3, "Unexpected number of logs" + assert len(logs) >= 3, f"Unexpected number of logs: {len(logs)}" + + +def test_range_negative_start(contract_instance, owner): + # Create 2 events + contract_instance.setNumber(1, sender=owner) + contract_instance.setNumber(2, sender=owner) + logs = [log for log in contract_instance.NumberChange.range(-2, 0)] + assert len(logs) == 2 + + +def test_range_negative_start_and_stop(contract_instance, owner): + # Create 3 events + contract_instance.setNumber(1, sender=owner) + contract_instance.setNumber(2, sender=owner) + contract_instance.setNumber(3, sender=owner) + + query_result = [log for log in contract_instance.NumberChange.range(-1, 0)] + assert len(query_result) == 1, "Should only be 1" + assert query_result[0].newNum == 3 # Was the last parameter. + query_result = [log for log in contract_instance.NumberChange.range(-2, -1)] + assert len(query_result) == 1, "Should only be 1" + assert query_result[0].newNum == 2 # Was the penultimate parameter. + query_result = [log for log in contract_instance.NumberChange.range(-3, -2)] + assert len(query_result) == 1, "Should only be 1" + assert query_result[0].newNum == 1 # Was the penultimate parameter. + logs = [log for log in contract_instance.NumberChange.range(-3, -1)] + assert len(logs) == 2 + assert [x.newNum for x in logs] == [1, 2] + logs = [log for log in contract_instance.NumberChange.range(-3, 0)] + assert len(logs) == 3 + assert [x.newNum for x in logs] == [1, 2, 3] + + +def test_range_negative_stop_only(contract_instance, owner): + # Create 2 events + contract_instance.setNumber(1, sender=owner) + contract_instance.setNumber(2, sender=owner) + + # Get _all_ logs. + logs = [log for log in contract_instance.NumberChange.range(0)] + assert len(logs) == 2 + assert [x.newNum for x in logs] == [1, 2] + + # Basically means go from 0 to the second to last + logs = [log for log in contract_instance.NumberChange.range(-1)] + assert len(logs) == 1 + assert logs[0].newNum == 1 def test_poll_logs_stop_block_not_in_future(