Skip to content

Commit

Permalink
fix: add rollback to pending of newer transactions for leader appeals
Browse files Browse the repository at this point in the history
  • Loading branch information
kstroobants committed Dec 20, 2024
1 parent c2d18ca commit df3f7e4
Show file tree
Hide file tree
Showing 3 changed files with 188 additions and 138 deletions.
244 changes: 124 additions & 120 deletions backend/consensus/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,79 @@ def contract_snapshot_factory(
return ContractSnapshot(contract_address, session)


class TransactionContext:
"""
Class representing the context of a transaction.
Attributes:
transaction (Transaction): The transaction.
transactions_processor (TransactionsProcessor): Instance responsible for handling transaction operations within the database.
snapshot (ChainSnapshot): Snapshot of the chain state.
accounts_manager (AccountsManager): Manager for accounts.
contract_snapshot_factory (Callable[[str], ContractSnapshot]): Factory function to create contract snapshots.
node_factory (Callable[[dict, ExecutionMode, ContractSnapshot, Receipt | None, MessageHandler, Callable[[str], ContractSnapshot]], Node]): Factory function to create nodes.
msg_handler (MessageHandler): Handler for messaging.
consensus_data (ConsensusData): Data related to the consensus process.
iterator_rotation (Iterator[list] | None): Iterator for rotating validators.
remaining_validators (list): List of remaining validators.
num_validators (int): Number of validators.
contract_snapshot (ContractSnapshot | None): Snapshot of the contract state.
votes (dict): Dictionary of votes.
validator_nodes (list): List of validator nodes.
validation_results (list): List of validation results.
"""

def __init__(
self,
transaction: Transaction,
transactions_processor: TransactionsProcessor,
snapshot: ChainSnapshot,
accounts_manager: AccountsManager,
contract_snapshot_factory: Callable[[str], ContractSnapshot],
node_factory: Callable[
[
dict,
ExecutionMode,
ContractSnapshot,
Receipt | None,
MessageHandler,
Callable[[str], ContractSnapshot],
],
Node,
],
msg_handler: MessageHandler,
):
"""
Initialize the TransactionContext.
Args:
transaction (Transaction): The transaction.
transactions_processor (TransactionsProcessor): Instance responsible for handling transaction operations within the database.
snapshot (ChainSnapshot): Snapshot of the chain state.
accounts_manager (AccountsManager): Manager for accounts.
contract_snapshot_factory (Callable[[str], ContractSnapshot]): Factory function to create contract snapshots.
node_factory (Callable[[dict, ExecutionMode, ContractSnapshot, Receipt | None, MessageHandler, Callable[[str], ContractSnapshot]], Node]): Factory function to create nodes.
msg_handler (MessageHandler): Handler for messaging.
"""
self.transaction = transaction
self.transactions_processor = transactions_processor
self.snapshot = snapshot
self.accounts_manager = accounts_manager
self.contract_snapshot_factory = contract_snapshot_factory
self.node_factory = node_factory
self.msg_handler = msg_handler
self.consensus_data = ConsensusData(
votes={}, leader_receipt=None, validators=[]
)
self.iterator_rotation: Iterator[list] | None = None
self.remaining_validators: list = []
self.num_validators: int = 0
self.contract_snapshot_supplier: Callable[[], ContractSnapshot] | None = None
self.votes: dict = {}
self.validator_nodes: list = []
self.validation_results: list = []


class ConsensusAlgorithm:
"""
Class representing the consensus algorithm.
Expand Down Expand Up @@ -542,6 +615,9 @@ async def _appeal_window(self):
next_state = await state.handle(context)
if next_state is None:
break
elif next_state == "leader_appeal_success":
self.rollback_transactions(context)
break
state = next_state
session.commit()

Expand Down Expand Up @@ -599,40 +675,14 @@ async def _appeal_window(self):
next_state = await state.handle(context)
if next_state is None:
break
elif isinstance(next_state, PendingState):
# Rollback all future transactions for the current contract
# Stop the _crawl_snapshot and the _run_consensus for the current contract
address = (
transaction.to_address
or transaction.from_address
)
self.stop_pending_queue_task(address)

# Wait until task is finished
while self.is_pending_queue_task_running(
address
):
time.sleep(1)

# Empty the pending queue
self.queues[address] = asyncio.Queue()

# Set all transactions with higher created_at to PENDING
future_transactions = transactions_processor.get_newer_transactions(
transaction.hash
elif next_state == "validator_appeal_success":
self.rollback_transactions(context)
ConsensusAlgorithm.dispatch_transaction_status_update(
context.transactions_processor,
context.transaction.hash,
TransactionStatus.PENDING,
context.msg_handler,
)
for (
future_transaction
) in future_transactions:
ConsensusAlgorithm.dispatch_transaction_status_update(
context.transactions_processor,
future_transaction["hash"],
TransactionStatus.PENDING,
context.msg_handler,
)

# Start the queue loop again
self.start_pending_queue_task(address)

# Transaction will be picked up by _crawl_snapshot
break
Expand All @@ -646,6 +696,37 @@ async def _appeal_window(self):
# Sleep for a short duration before the next iteration
await asyncio.sleep(1)

def rollback_transactions(self, context: TransactionContext):
"""
Rollback newer transactions.
"""
# Rollback all future transactions for the current contract
# Stop the _crawl_snapshot and the _run_consensus for the current contract
address = context.transaction.to_address or context.transaction.from_address
self.stop_pending_queue_task(address)

# Wait until task is finished
while self.is_pending_queue_task_running(address):
time.sleep(1)

# Empty the pending queue
self.queues[address] = asyncio.Queue()

# Set all transactions with higher created_at to PENDING
future_transactions = context.transactions_processor.get_newer_transactions(
context.transaction.hash
)
for future_transaction in future_transactions:
ConsensusAlgorithm.dispatch_transaction_status_update(
context.transactions_processor,
future_transaction["hash"],
TransactionStatus.PENDING,
context.msg_handler,
)

# Start the queue loop again
self.start_pending_queue_task(address)

@staticmethod
def get_extra_validators(
snapshot: ChainSnapshot, consensus_data: ConsensusData, appeal_failed: int
Expand Down Expand Up @@ -767,79 +848,6 @@ def set_finality_window_time(self, time: int):
self.finality_window_time = time


class TransactionContext:
"""
Class representing the context of a transaction.
Attributes:
transaction (Transaction): The transaction.
transactions_processor (TransactionsProcessor): Instance responsible for handling transaction operations within the database.
snapshot (ChainSnapshot): Snapshot of the chain state.
accounts_manager (AccountsManager): Manager for accounts.
contract_snapshot_factory (Callable[[str], ContractSnapshot]): Factory function to create contract snapshots.
node_factory (Callable[[dict, ExecutionMode, ContractSnapshot, Receipt | None, MessageHandler, Callable[[str], ContractSnapshot]], Node]): Factory function to create nodes.
msg_handler (MessageHandler): Handler for messaging.
consensus_data (ConsensusData): Data related to the consensus process.
iterator_rotation (Iterator[list] | None): Iterator for rotating validators.
remaining_validators (list): List of remaining validators.
num_validators (int): Number of validators.
contract_snapshot (ContractSnapshot | None): Snapshot of the contract state.
votes (dict): Dictionary of votes.
validator_nodes (list): List of validator nodes.
validation_results (list): List of validation results.
"""

def __init__(
self,
transaction: Transaction,
transactions_processor: TransactionsProcessor,
snapshot: ChainSnapshot,
accounts_manager: AccountsManager,
contract_snapshot_factory: Callable[[str], ContractSnapshot],
node_factory: Callable[
[
dict,
ExecutionMode,
ContractSnapshot,
Receipt | None,
MessageHandler,
Callable[[str], ContractSnapshot],
],
Node,
],
msg_handler: MessageHandler,
):
"""
Initialize the TransactionContext.
Args:
transaction (Transaction): The transaction.
transactions_processor (TransactionsProcessor): Instance responsible for handling transaction operations within the database.
snapshot (ChainSnapshot): Snapshot of the chain state.
accounts_manager (AccountsManager): Manager for accounts.
contract_snapshot_factory (Callable[[str], ContractSnapshot]): Factory function to create contract snapshots.
node_factory (Callable[[dict, ExecutionMode, ContractSnapshot, Receipt | None, MessageHandler, Callable[[str], ContractSnapshot]], Node]): Factory function to create nodes.
msg_handler (MessageHandler): Handler for messaging.
"""
self.transaction = transaction
self.transactions_processor = transactions_processor
self.snapshot = snapshot
self.accounts_manager = accounts_manager
self.contract_snapshot_factory = contract_snapshot_factory
self.node_factory = node_factory
self.msg_handler = msg_handler
self.consensus_data = ConsensusData(
votes={}, leader_receipt=None, validators=[]
)
self.iterator_rotation: Iterator[list] | None = None
self.remaining_validators: list = []
self.num_validators: int = 0
self.contract_snapshot_supplier: Callable[[], ContractSnapshot] | None = None
self.votes: dict = {}
self.validator_nodes: list = []
self.validation_results: list = []


class TransactionState(ABC):
"""
Abstract base class representing a state in the transaction process.
Expand Down Expand Up @@ -1186,17 +1194,11 @@ async def handle(self, context):
context.transactions_processor.set_transaction_result(
context.transaction.hash, context.consensus_data.to_dict()
)
ConsensusAlgorithm.dispatch_transaction_status_update(
context.transactions_processor,
context.transaction.hash,
TransactionStatus.PENDING,
context.msg_handler,
)
context.transactions_processor.set_transaction_appeal_failed(
context.transaction.hash,
0,
)
return PendingState()
return "validator_appeal_success"

else:
# Not appealed, update consensus data with current votes and validators
Expand Down Expand Up @@ -1244,12 +1246,6 @@ async def handle(self, context):
)
context.transaction.appealed = False

# Set the transaction appeal undetermined status to false
context.transactions_processor.set_transaction_appeal_undetermined(
context.transaction.hash, False
)
context.transaction.appeal_undetermined = False

# Set the transaction result
context.transactions_processor.set_transaction_result(
context.transaction.hash, context.consensus_data.to_dict()
Expand Down Expand Up @@ -1316,7 +1312,15 @@ async def handle(self, context):
leader_receipt.contract_state
)

return None
# Set the transaction appeal undetermined status to false and return appeal status
if context.transaction.appeal_undetermined:
context.transactions_processor.set_transaction_appeal_undetermined(
context.transaction.hash, False
)
context.transaction.appeal_undetermined = False
return "leader_appeal_success"
else:
return None


class UndeterminedState(TransactionState):
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/consensus/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1741,7 +1741,7 @@ def get_vote():
node_factory=node_factory_supplier,
)

time.sleep(DEFAULT_FINALITY_WINDOW + 5)
time.sleep(DEFAULT_FINALITY_WINDOW + 7)

assert (
transactions_processor.get_transaction_by_hash(transaction.hash)["status"]
Expand Down
Loading

0 comments on commit df3f7e4

Please sign in to comment.