diff --git a/src/prefect/task_engine.py b/src/prefect/task_engine.py index 242730cd5cf2..7c39dcc5d6c7 100644 --- a/src/prefect/task_engine.py +++ b/src/prefect/task_engine.py @@ -69,7 +69,7 @@ exception_to_failed_state, return_value_to_state, ) -from prefect.transactions import Transaction, TransactionState, transaction +from prefect.transactions import Transaction, transaction from prefect.utilities.annotations import NotSet from prefect.utilities.asyncutils import run_coro_as_sync from prefect.utilities.callables import call_with_parameters, parameters_to_args_kwargs @@ -403,7 +403,8 @@ def handle_success(self, result: R, transaction: Transaction) -> R: ) transaction.stage( terminal_state.data, - on_rollback_hooks=[ + on_rollback_hooks=[self.handle_rollback] + + [ _with_transaction_hook_logging(hook, "rollback", self.logger) for hook in self.task.on_rollback_hooks ], @@ -515,6 +516,18 @@ def record_terminal_state_timing(self, state: State) -> None: state.timestamp - self.task_run.state.timestamp ) + def handle_rollback(self, txn: Transaction) -> None: + # transaction rollbacks can occur outside of the engine's context + # so we need to ensure a client is available + with ClientContext.get_or_create() as client_ctx: + self.set_state( + Completed( + name="RolledBack", + message="Task rolled back as part of transaction", + ), + client=client_ctx.sync_client, + ) + @contextmanager def setup_run_context(self, client: Optional[SyncPrefectClient] = None): from prefect.utilities.engine import ( @@ -725,22 +738,6 @@ def start( def transaction_context(self) -> Generator[Transaction, None, None]: result_factory = getattr(TaskRunContext.get(), "result_factory", None) - def __prefect_set_task_state_to_rolled_back(txn: Transaction): - if txn.state != TransactionState.STAGED: - # there's nothing to rollback so don't set the state to rolled back - # For example, a task that raises an unhandled exception while running - # does not trigger any on_rollback hooks b/c the transaction was never staged. - return - - with ClientContext.get_or_create() as client_ctx: - self.set_state( - Completed( - name="RolledBack", - message="Task rolled back as part of transaction", - ), - client=client_ctx.sync_client, - ) - # refresh cache setting is now repurposes as overwrite transaction record overwrite = ( self.task.refresh_cache @@ -752,7 +749,6 @@ def __prefect_set_task_state_to_rolled_back(txn: Transaction): store=ResultFactoryStore(result_factory=result_factory), overwrite=overwrite, logger=self.logger, - on_rollback_hooks=[__prefect_set_task_state_to_rolled_back], ) as txn: yield txn diff --git a/src/prefect/transactions.py b/src/prefect/transactions.py index 8a9e65e48be7..41c6c07d3001 100644 --- a/src/prefect/transactions.py +++ b/src/prefect/transactions.py @@ -254,7 +254,6 @@ def transaction( commit_mode: Optional[CommitMode] = None, overwrite: bool = False, logger: Optional[PrefectLogAdapter] = None, - **kwargs, ) -> Generator[Transaction, None, None]: """ A context manager for opening and managing a transaction. @@ -316,6 +315,5 @@ def transaction( commit_mode=commit_mode, overwrite=overwrite, logger=logger, - **kwargs, ) as txn: yield txn