Skip to content

Commit

Permalink
Simplify management of when the rollback state hook should run.
Browse files Browse the repository at this point in the history
  • Loading branch information
collincchoy committed Jul 23, 2024
1 parent d373d92 commit d100280
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 21 deletions.
34 changes: 15 additions & 19 deletions src/prefect/task_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
exception_to_failed_state,
return_value_to_state,
)
from prefect.transactions import Transaction, TransactionState, transaction
from prefect.transactions import Transaction, transaction
from prefect.utilities.annotations import NotSet
from prefect.utilities.asyncutils import run_coro_as_sync
from prefect.utilities.callables import call_with_parameters, parameters_to_args_kwargs
Expand Down Expand Up @@ -403,7 +403,8 @@ def handle_success(self, result: R, transaction: Transaction) -> R:
)
transaction.stage(
terminal_state.data,
on_rollback_hooks=[
on_rollback_hooks=[self.handle_rollback]
+ [
_with_transaction_hook_logging(hook, "rollback", self.logger)
for hook in self.task.on_rollback_hooks
],
Expand Down Expand Up @@ -515,6 +516,18 @@ def record_terminal_state_timing(self, state: State) -> None:
state.timestamp - self.task_run.state.timestamp
)

def handle_rollback(self, txn: Transaction) -> None:
# transaction rollbacks can occur outside of the engine's context
# so we need to ensure a client is available
with ClientContext.get_or_create() as client_ctx:
self.set_state(
Completed(
name="RolledBack",
message="Task rolled back as part of transaction",
),
client=client_ctx.sync_client,
)

@contextmanager
def setup_run_context(self, client: Optional[SyncPrefectClient] = None):
from prefect.utilities.engine import (
Expand Down Expand Up @@ -725,22 +738,6 @@ def start(
def transaction_context(self) -> Generator[Transaction, None, None]:
result_factory = getattr(TaskRunContext.get(), "result_factory", None)

def __prefect_set_task_state_to_rolled_back(txn: Transaction):
if txn.state != TransactionState.STAGED:
# there's nothing to rollback so don't set the state to rolled back
# For example, a task that raises an unhandled exception while running
# does not trigger any on_rollback hooks b/c the transaction was never staged.
return

with ClientContext.get_or_create() as client_ctx:
self.set_state(
Completed(
name="RolledBack",
message="Task rolled back as part of transaction",
),
client=client_ctx.sync_client,
)

# refresh cache setting is now repurposes as overwrite transaction record
overwrite = (
self.task.refresh_cache
Expand All @@ -752,7 +749,6 @@ def __prefect_set_task_state_to_rolled_back(txn: Transaction):
store=ResultFactoryStore(result_factory=result_factory),
overwrite=overwrite,
logger=self.logger,
on_rollback_hooks=[__prefect_set_task_state_to_rolled_back],
) as txn:
yield txn

Expand Down
2 changes: 0 additions & 2 deletions src/prefect/transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,6 @@ def transaction(
commit_mode: Optional[CommitMode] = None,
overwrite: bool = False,
logger: Optional[PrefectLogAdapter] = None,
**kwargs,
) -> Generator[Transaction, None, None]:
"""
A context manager for opening and managing a transaction.
Expand Down Expand Up @@ -316,6 +315,5 @@ def transaction(
commit_mode=commit_mode,
overwrite=overwrite,
logger=logger,
**kwargs,
) as txn:
yield txn

0 comments on commit d100280

Please sign in to comment.