Skip to content

Commit

Permalink
refactor: merge 638-appeals-undetermined-transactions into 662-consen…
Browse files Browse the repository at this point in the history
…sus-appeal-success-should-rollback-all-newer-transactions-to-pending
  • Loading branch information
kstroobants committed Dec 20, 2024
2 parents 9a63a1e + 8357219 commit c2d18ca
Show file tree
Hide file tree
Showing 25 changed files with 508 additions and 138 deletions.
200 changes: 149 additions & 51 deletions backend/consensus/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,11 @@ async def _run_consensus(self):
transaction: Transaction = await queue.get()
with self.get_session() as session:

async def exec_transaction_with_session_handling():
async def exec_transaction_with_session_handling(
session: Session,
transaction: Transaction,
queue_address: str,
):
await self.exec_transaction(
transaction,
TransactionsProcessor(session),
Expand All @@ -228,7 +232,11 @@ async def exec_transaction_with_session_handling():
False
)

tg.create_task(exec_transaction_with_session_handling())
tg.create_task(
exec_transaction_with_session_handling(
session, transaction, queue_address
)
)

except Exception as e:
print("Error running consensus", e)
Expand Down Expand Up @@ -304,7 +312,7 @@ async def exec_transaction(
)

# Begin state transitions starting from PendingState
state = PendingState()
state = PendingState(called_from_pending_queue=True)
while True:
next_state = await state.handle(context)
if next_state is None:
Expand Down Expand Up @@ -505,15 +513,36 @@ async def _appeal_window(self):
transactions_processor.set_transaction_appeal(
transaction.hash, False
)
transaction.appeal_undetermined = True
transaction.appealed = False

# Set the status to PENDING, transaction will be picked up by _crawl_snapshot
ConsensusAlgorithm.dispatch_transaction_status_update(
transactions_processor,
transaction.hash,
TransactionStatus.PENDING,
self.msg_handler,
)

# Create a transaction context for the appeal process
context = TransactionContext(
transaction=transaction,
transactions_processor=transactions_processor,
snapshot=chain_snapshot,
accounts_manager=AccountsManager(session),
contract_snapshot_factory=lambda contract_address: contract_snapshot_factory(
contract_address, session, transaction
),
node_factory=node_factory,
msg_handler=self.msg_handler,
)

# Begin state transitions starting from PendingState
state = PendingState(called_from_pending_queue=False)
while True:
next_state = await state.handle(context)
if next_state is None:
break
state = next_state
session.commit()

else:
Expand Down Expand Up @@ -625,7 +654,28 @@ def get_extra_validators(
Get extra validators for the appeal process according to the following formula:
- when appeal_failed = 0, add n + 2 validators
- when appeal_failed > 0, add (2 * appeal_failed * n + 1) + 2 validators
Nota that for appeal_failed > 0, the set contains the old validators from the previous appeal round and new validators.
Note that for appeal_failed > 0, the returned set contains the old validators
from the previous appeal round and new validators.
Selection of the extra validators:
appeal_failed | PendingState | Reused validators | Extra selected | Total
| validators | from the previous | validators for the | validators
| | appeal round | appeal |
----------------------------------------------------------------------------------
0 | n | 0 | n+2 | 2n+2
1 | n | n+2 | n+1 | 3n+3
2 | n | 2n+3 | 2n | 5n+3
3 | n | 4n+3 | 2n | 7n+3
└───────┬──────┘ └─────────┬────────┘
│ |
Validators after the ◄────────┘ └──► Validators during the appeal
appeal. This equals for appeal_failed > 0
the Total validators = (2*appeal_failed*n+1)+2
of the row above, This is the formula from
and are in consensus_data. above and it is what is
For appeal_failed > 0 returned by this function
= (2*appeal_failed-1)*n+3
This is used to calculate n
Args:
snapshot (ChainSnapshot): Snapshot of the chain state.
Expand All @@ -638,37 +688,24 @@ def get_extra_validators(
# Get all validators
validators = snapshot.get_all_validators()

if appeal_failed > 0:
# Create a dictionary to map addresses to validator entries
validator_map = {
validator["address"]: validator for validator in validators
}

# List to store current validators for each receipt
current_validators = [
validator_map[consensus_data.leader_receipt.node_config["address"]]
]
else:
current_validators = []

# Set to track addresses found in receipts
receipt_addresses = set([consensus_data.leader_receipt.node_config["address"]])
# Create a dictionary to map addresses to validator entries
validator_map = {validator["address"]: validator for validator in validators}

# Iterate over receipts to find matching validators
for receipt in consensus_data.validators:
address = receipt.node_config["address"]
receipt_addresses.add(address)
if appeal_failed > 0:
if address in validator_map:
current_validators.append(validator_map[address])
# List containing addresses found in leader and validator receipts
receipt_addresses = [consensus_data.leader_receipt.node_config["address"]] + [
receipt.node_config["address"] for receipt in consensus_data.validators
]

# Get all validators where the address is not in the receipts
not_used_validators = [
validator
for validator in validators
if validator["address"] not in receipt_addresses
# Get leader and current validators from consensus data receipt addresses
current_validators = [
validator_map.pop(receipt_address)
for receipt_address in receipt_addresses
if receipt_address in validator_map
]

# Set not_used_validators to the remaining validators in validator_map
not_used_validators = list(validator_map.values())

if len(not_used_validators) == 0:
raise ValueError(
"No validators found for appeal, waiting for next appeal request: "
Expand Down Expand Up @@ -824,6 +861,15 @@ class PendingState(TransactionState):
Class representing the pending state of a transaction.
"""

def __init__(self, called_from_pending_queue: bool):
"""
Initialize the PendingState.
Args:
called_from_pending_queue (bool): Indicates if the PendingState was called from the pending queue.
"""
self.called_from_pending_queue = called_from_pending_queue

async def handle(self, context):
"""
Handle the pending state transition.
Expand All @@ -841,6 +887,11 @@ async def handle(self, context):
)
)

# Transaction should not be processed from the pending queue if it is a leader appeal
# This is to filter out the transaction picked up by _crawl_snapshot
if self.called_from_pending_queue and context.transaction.appeal_undetermined:
return None

if context.transaction.status != TransactionStatus.PENDING:
# This is a patch for a TOCTOU problem we have https://github.com/yeagerai/genlayer-simulator/issues/387
# Problem: Pending transactions are checked by `_crawl_snapshot`, which appends them to queues. These queues are consumed by `_run_consensus`, which processes the transactions. This means that a transaction can be processed multiple times, since `_crawl_snapshot` can append the same transaction to the queue multiple times.
Expand Down Expand Up @@ -1193,6 +1244,12 @@ async def handle(self, context):
)
context.transaction.appealed = False

# Set the transaction appeal undetermined status to false
context.transactions_processor.set_transaction_appeal_undetermined(
context.transaction.hash, False
)
context.transaction.appeal_undetermined = False

# Set the transaction result
context.transactions_processor.set_transaction_result(
context.transaction.hash, context.consensus_data.to_dict()
Expand Down Expand Up @@ -1222,6 +1279,9 @@ async def handle(self, context):
# Retrieve the leader's receipt from the consensus data
leader_receipt = context.consensus_data.leader_receipt

# Get the contract snapshot for the transaction's target address
leaders_contract_snapshot = context.contract_snapshot_supplier()

# Do not deploy the contract if the execution failed
if leader_receipt.execution_result == ExecutionResultStatus.SUCCESS:
# Get the contract snapshot for the transaction's target address
Expand Down Expand Up @@ -1344,26 +1404,64 @@ async def handle(self, context):
context.msg_handler,
)

# Insert pending transactions generated by contract-to-contract calls
pending_transactions = (
context.transaction.consensus_data.leader_receipt.pending_transactions
)
for pending_transaction in pending_transactions:
nonce = context.transactions_processor.get_transaction_count(
context.transaction.to_address
)
context.transactions_processor.insert_transaction(
context.transaction.to_address, # new calls are done by the contract
pending_transaction.address,
{
"calldata": pending_transaction.calldata,
},
value=0, # we only handle EOA transfers at the moment, so no value gets transferred
type=TransactionType.RUN_CONTRACT.value,
nonce=nonce,
leader_only=context.transaction.leader_only, # Cascade
triggered_by_hash=context.transaction.hash,
if context.transaction.status != TransactionStatus.UNDETERMINED:
# Insert pending transactions generated by contract-to-contract calls
pending_transactions = (
context.transaction.consensus_data.leader_receipt.pending_transactions
)
for pending_transaction in pending_transactions:
nonce = context.transactions_processor.get_transaction_count(
context.transaction.to_address
)
data: dict
transaction_type: TransactionType
if pending_transaction.is_deploy():
transaction_type = TransactionType.DEPLOY_CONTRACT
new_contract_address: str
if pending_transaction.salt_nonce == 0:
# NOTE: this address is random, which doesn't 100% align with consensus spec
new_contract_address = (
context.accounts_manager.create_new_account().address
)
else:
from eth_utils.crypto import keccak
from backend.node.types import Address
from backend.node.base import SIMULATOR_CHAIN_ID

arr = bytearray()
arr.append(1)
arr.extend(Address(context.transaction.to_address).as_bytes)
arr.extend(
pending_transaction.salt_nonce.to_bytes(
32, "big", signed=False
)
)
arr.extend(SIMULATOR_CHAIN_ID.to_bytes(32, "big", signed=False))
new_contract_address = Address(keccak(arr)[:20]).as_hex
context.accounts_manager.create_new_account_with_address(
new_contract_address
)
pending_transaction.address = new_contract_address
data = {
"contract_address": new_contract_address,
"contract_code": pending_transaction.code,
"calldata": pending_transaction.calldata,
}
else:
transaction_type = TransactionType.RUN_CONTRACT
data = {
"calldata": pending_transaction.calldata,
}
context.transactions_processor.insert_transaction(
context.transaction.to_address, # new calls are done by the contract
pending_transaction.address,
data,
value=0, # we only handle EOA transfers at the moment, so no value gets transferred
type=transaction_type.value,
nonce=nonce,
leader_only=context.transaction.leader_only, # Cascade
triggered_by_hash=context.transaction.hash,
)


def rotate(nodes: list) -> Iterator[list]:
Expand Down
2 changes: 1 addition & 1 deletion backend/database_handler/contract_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class ContractSnapshot:

contract_address: str
contract_code: str
encoded_state: dict[str, dict[str, str]]
encoded_state: dict[str, str]

def __init__(self, contract_address: str | None, session: Session):
self.session = session
Expand Down
29 changes: 17 additions & 12 deletions backend/node/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

from .types import Address

SIMULATOR_CHAIN_ID: typing.Final[int] = 61999


class _SnapshotView(genvmbase.StateProxy):
def __init__(
Expand All @@ -43,14 +45,14 @@ def _get_snapshot(self, addr: Address) -> ContractSnapshot:
return res

def get_code(self, addr: Address) -> bytes:
return self._get_snapshot(addr).contract_code.encode("utf-8")
return base64.b64decode(self._get_snapshot(addr).contract_code)

def storage_read(
self, account: Address, slot: bytes, index: int, le: int, /
) -> tuple[bytes, int]:
) -> bytes:
snap = self._get_snapshot(account)
for_acc = snap.encoded_state.setdefault(account.as_b64, {})
for_slot = for_acc.setdefault(base64.b64encode(slot).decode("ascii"), "")
slot_id = base64.b64encode(slot).decode("ascii")
for_slot = snap.encoded_state.setdefault(slot_id, "")
data = bytearray(base64.b64decode(for_slot))
data.extend(b"\x00" * (index + le - len(data)))
return data[index : index + le]
Expand All @@ -66,14 +68,13 @@ def storage_write(
assert account == self.contract_address
assert not self.readonly
snap = self._get_snapshot(account)
for_acc = snap.encoded_state.setdefault(account.as_b64, {})
slot_id = base64.b64encode(slot).decode("ascii")
for_slot = for_acc.setdefault(slot_id, "")
for_slot = snap.encoded_state.setdefault(slot_id, "")
data = bytearray(base64.b64decode(for_slot))
mem = memoryview(got)
data.extend(b"\x00" * (index + len(mem) - len(data)))
data[index : index + len(mem)] = mem
for_acc[slot_id] = base64.b64encode(data).decode("utf-8")
snap.encoded_state[slot_id] = base64.b64encode(data).decode("utf-8")


class Node:
Expand Down Expand Up @@ -102,10 +103,11 @@ async def exec_transaction(self, transaction: Transaction) -> Receipt:
transaction_data = transaction.data
assert transaction.from_address is not None
if transaction.type == TransactionType.DEPLOY_CONTRACT:
code = base64.b64decode(transaction_data["contract_code"])
calldata = base64.b64decode(transaction_data["calldata"])
receipt = await self.deploy_contract(
transaction.from_address,
transaction_data["contract_code"],
code,
calldata,
transaction.hash,
transaction.created_at,
Expand Down Expand Up @@ -149,13 +151,15 @@ def _date_from_str(self, date: str | None) -> datetime.datetime | None:
async def deploy_contract(
self,
from_address: str,
code_to_deploy: str,
code_to_deploy: bytes,
calldata: bytes,
transaction_hash: str,
transaction_created_at: str | None = None,
) -> Receipt:
assert self.contract_snapshot is not None
self.contract_snapshot.contract_code = code_to_deploy
self.contract_snapshot.contract_code = base64.b64encode(code_to_deploy).decode(
"ascii"
)
return await self._run_genvm(
from_address,
calldata,
Expand Down Expand Up @@ -220,9 +224,9 @@ async def _execution_finished(
)
)

async def get_contract_schema(self, code: str) -> str:
async def get_contract_schema(self, code: bytes) -> str:
genvm = self._create_genvm()
res = await genvm.get_contract_schema(code.encode("utf-8"))
res = await genvm.get_contract_schema(code)
await self._execution_finished(res, None)
err_data = {
"stdout": res.stdout,
Expand Down Expand Up @@ -298,6 +302,7 @@ async def _run_genvm(
leader_results=leader_res,
config=json.dumps(config),
date=transaction_datetime,
chain_id=SIMULATOR_CHAIN_ID,
)
await self._execution_finished(res, transaction_hash)

Expand Down
Loading

0 comments on commit c2d18ca

Please sign in to comment.