Skip to content

Commit

Permalink
fix: negative block number support in ContractLog.range queries (#2388
Browse files Browse the repository at this point in the history
)
  • Loading branch information
antazoey authored Dec 12, 2024
1 parent fbdbe52 commit e92446f
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 37 deletions.
68 changes: 49 additions & 19 deletions src/ape/contracts/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,19 +634,19 @@ def query(
# perf: pandas import is really slow. Avoid importing at module level.
import pandas as pd

HEAD = self.chain_manager.blocks.height
if start_block < 0:
start_block = self.chain_manager.blocks.height + start_block
start_block = HEAD + start_block

if stop_block is None:
stop_block = self.chain_manager.blocks.height
stop_block = HEAD

elif stop_block < 0:
stop_block = self.chain_manager.blocks.height + stop_block
stop_block = HEAD + stop_block

elif stop_block > self.chain_manager.blocks.height:
elif stop_block > HEAD:
raise ChainError(
f"'stop={stop_block}' cannot be greater than "
f"the chain length ({self.chain_manager.blocks.height})."
f"'stop={stop_block}' cannot be greater than the chain length ({HEAD})."
)
query: dict = {
"columns": list(ContractLog.__pydantic_fields__) if columns[0] == "*" else columns,
Expand Down Expand Up @@ -692,12 +692,12 @@ def range(
Returns:
Iterator[:class:`~ape.contracts.base.ContractLog`]
"""

if not (contract_address := getattr(self.contract, "address", None)):
return

start_block = None
stop_block = None
HEAD = self.chain_manager.blocks.height # Current block height

if stop is None:
contract = None
Expand All @@ -706,27 +706,57 @@ def range(
except Exception:
pass

if contract:
if creation := contract.creation_metadata:
start_block = creation.block

stop_block = start_or_stop
# Determine the start block from contract creation metadata
if contract and (creation := contract.creation_metadata):
start_block = creation.block

# Handle single parameter usage (like Python's range(stop))
if start_or_stop == 0:
# stop==0 is the same as stop==HEAD
# because of the -1 (turns to negative).
stop_block = HEAD + 1
elif start_or_stop >= 0:
# Given like range(1)
stop_block = min(start_or_stop - 1, HEAD)
else:
# Give like range(-1)
stop_block = HEAD + start_or_stop

elif start_or_stop is not None and stop is not None:
start_block = start_or_stop
stop_block = stop - 1

stop_block = min(stop_block, self.chain_manager.blocks.height)
# Handle cases where both start and stop are provided
if start_or_stop >= 0:
start_block = min(start_or_stop, HEAD)
else:
# Negative start relative to HEAD
adjusted_value = HEAD + start_or_stop + 1
start_block = max(adjusted_value, 0)

if stop == 0:
# stop==0 is the same as stop==HEAD
# because of the -1 (turns to negative).
stop_block = HEAD
elif stop > 0:
# Positive stop, capped to the chain HEAD
stop_block = min(stop - 1, HEAD)
else:
# Negative stop.
adjusted_value = HEAD + stop
stop_block = max(adjusted_value, 0)

# Gather all addresses to query (contract and any extra ones provided)
addresses = list(set([contract_address] + (extra_addresses or [])))

# Construct the event query
contract_event_query = ContractEventQuery(
columns=list(ContractLog.__pydantic_fields__),
columns=list(ContractLog.__pydantic_fields__), # Ensure all necessary columns
contract=addresses,
event=self.abi,
search_topics=search_topics,
start_block=start_block or 0,
stop_block=stop_block,
start_block=start_block or 0, # Default to block 0 if not set
stop_block=stop_block, # None means query to the current HEAD
)

# Execute the query and yield results
yield from self.query_manager.query(contract_event_query) # type: ignore

def from_receipt(self, receipt: "ReceiptAPI") -> list[ContractLog]:
Expand Down
81 changes: 63 additions & 18 deletions tests/functional/test_contract_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@


@pytest.fixture
def assert_log_values(owner, chain):
def assert_log_values(owner):
def _assert_log_values(log: ContractLog, number: int, previous_number: Optional[int] = None):
assert isinstance(log.b, bytes)
expected_previous_number = number - 1 if previous_number is None else previous_number
Expand All @@ -32,7 +32,7 @@ def _assert_log_values(log: ContractLog, number: int, previous_number: Optional[
return _assert_log_values


def test_contract_logs_from_receipts(owner, contract_instance, assert_log_values):
def test_from_receipts(owner, contract_instance, assert_log_values):
event_type = contract_instance.NumberChange

# Invoke a transaction 3 times that generates 3 logs.
Expand All @@ -55,7 +55,7 @@ def assert_receipt_logs(receipt: "ReceiptAPI", num: int):
assert_receipt_logs(receipt_2, 3)


def test_contract_logs_from_event_type(contract_instance, owner, assert_log_values):
def test_from_event_type(contract_instance, owner, assert_log_values):
event_type = contract_instance.NumberChange
start_num = 6
size = 20
Expand All @@ -76,7 +76,7 @@ def test_contract_logs_from_event_type(contract_instance, owner, assert_log_valu
assert_log_values(log, num)


def test_contract_logs_index_access(contract_instance, owner, assert_log_values):
def test_index_access(contract_instance, owner, assert_log_values):
event_type = contract_instance.NumberChange

contract_instance.setNumber(1, sender=owner)
Expand All @@ -93,7 +93,7 @@ def test_contract_logs_index_access(contract_instance, owner, assert_log_values)
assert event_type[-1] == contract_instance.NumberChange(newNum=3, prevNum=2)


def test_contract_logs_splicing(contract_instance, owner, assert_log_values):
def test_splicing(contract_instance, owner, assert_log_values):
event_type = contract_instance.NumberChange

contract_instance.setNumber(1, sender=owner)
Expand All @@ -113,7 +113,7 @@ def test_contract_logs_splicing(contract_instance, owner, assert_log_values):
assert_log_values(log, 2)


def test_contract_logs_range(chain, contract_instance, owner, assert_log_values):
def test_range(chain, contract_instance, owner, assert_log_values):
contract_instance.setNumber(1, sender=owner)
start = chain.blocks.height
logs = [
Expand All @@ -126,7 +126,7 @@ def test_contract_logs_range(chain, contract_instance, owner, assert_log_values)
assert_log_values(logs[0], 1)


def test_contract_logs_range_by_address(
def test_range_by_address(
mocker, chain, eth_tester_provider, accounts, contract_instance, owner, assert_log_values
):
get_logs_spy = mocker.spy(eth_tester_provider.tester.ethereum_tester, "get_logs")
Expand Down Expand Up @@ -157,29 +157,27 @@ def test_contract_logs_range_by_address(
assert logs == [contract_instance.AddressChange(newAddress=accounts[1])]


def test_contracts_log_multiple_addresses(
def test_range_multiple_addresses(
chain, contract_instance, contract_container, owner, assert_log_values
):
another_instance = contract_container.deploy(0, sender=owner)
start_block = chain.blocks.height
contract_instance.setNumber(1, sender=owner)
another_instance.setNumber(1, sender=owner)

logs = [
log
for log in contract_instance.NumberChange.range(
logs = list(
contract_instance.NumberChange.range(
start_block,
start_block + 100,
search_topics={"newNum": 1},
extra_addresses=[another_instance.address],
)
]
assert len(logs) == 2, "Unexpected number of logs"
)
assert len(logs) == 2, f"Unexpected number of logs: {len(logs)}"
assert logs[0] == contract_instance.NumberChange(newNum=1, prevNum=0)
assert logs[1] == another_instance.NumberChange(newNum=1, prevNum=0)


def test_contract_logs_range_start_and_stop(contract_instance, owner, chain):
def test_range_start_and_stop(contract_instance, owner, chain):
# Create 1 event
contract_instance.setNumber(1, sender=owner)

Expand All @@ -194,16 +192,63 @@ def test_contract_logs_range_start_and_stop(contract_instance, owner, chain):
assert len(logs) == 3, "Unexpected number of logs"


def test_contract_logs_range_only_stop(contract_instance, owner, chain):
# Create 1 event
def test_range_only_stop(contract_instance, owner, chain):
# Create 3 events
start = chain.blocks.height
contract_instance.setNumber(1, sender=owner)
contract_instance.setNumber(2, sender=owner)
contract_instance.setNumber(3, sender=owner)

stop = start + 100 # Stop can be bigger than height, it doesn't not matter
logs = [log for log in contract_instance.NumberChange.range(stop)]
assert len(logs) >= 3, "Unexpected number of logs"
assert len(logs) >= 3, f"Unexpected number of logs: {len(logs)}"


def test_range_negative_start(contract_instance, owner):
# Create 2 events
contract_instance.setNumber(1, sender=owner)
contract_instance.setNumber(2, sender=owner)
logs = [log for log in contract_instance.NumberChange.range(-2, 0)]
assert len(logs) == 2


def test_range_negative_start_and_stop(contract_instance, owner):
# Create 3 events
contract_instance.setNumber(1, sender=owner)
contract_instance.setNumber(2, sender=owner)
contract_instance.setNumber(3, sender=owner)

query_result = [log for log in contract_instance.NumberChange.range(-1, 0)]
assert len(query_result) == 1, "Should only be 1"
assert query_result[0].newNum == 3 # Was the last parameter.
query_result = [log for log in contract_instance.NumberChange.range(-2, -1)]
assert len(query_result) == 1, "Should only be 1"
assert query_result[0].newNum == 2 # Was the penultimate parameter.
query_result = [log for log in contract_instance.NumberChange.range(-3, -2)]
assert len(query_result) == 1, "Should only be 1"
assert query_result[0].newNum == 1 # Was the penultimate parameter.
logs = [log for log in contract_instance.NumberChange.range(-3, -1)]
assert len(logs) == 2
assert [x.newNum for x in logs] == [1, 2]
logs = [log for log in contract_instance.NumberChange.range(-3, 0)]
assert len(logs) == 3
assert [x.newNum for x in logs] == [1, 2, 3]


def test_range_negative_stop_only(contract_instance, owner):
# Create 2 events
contract_instance.setNumber(1, sender=owner)
contract_instance.setNumber(2, sender=owner)

# Get _all_ logs.
logs = [log for log in contract_instance.NumberChange.range(0)]
assert len(logs) == 2
assert [x.newNum for x in logs] == [1, 2]

# Basically means go from 0 to the second to last
logs = [log for log in contract_instance.NumberChange.range(-1)]
assert len(logs) == 1
assert logs[0].newNum == 1


def test_poll_logs_stop_block_not_in_future(
Expand Down

0 comments on commit e92446f

Please sign in to comment.