diff --git a/src/prefect/task_engine.py b/src/prefect/task_engine.py index cec9eb7e4e7f..4868647be24c 100644 --- a/src/prefect/task_engine.py +++ b/src/prefect/task_engine.py @@ -62,6 +62,7 @@ ) from prefect.states import ( AwaitingRetry, + Completed, Failed, Paused, Pending, @@ -244,6 +245,21 @@ def example_flow(): msg=msg, ) + def handle_rollback(self, txn: Transaction) -> None: + assert self.task_run is not None + + rolled_back_state = Completed( + name="RolledBack", + message="Task rolled back as part of transaction", + ) + + self._last_event = emit_task_run_state_change_event( + task_run=self.task_run, + initial_state=self.state, + validated_state=rolled_back_state, + follows=self._last_event, + ) + @dataclass class SyncTaskRunEngine(BaseTaskRunEngine[P, R]): @@ -463,7 +479,8 @@ def handle_success(self, result: R, transaction: Transaction) -> R: ) transaction.stage( terminal_state.data, - on_rollback_hooks=[ + on_rollback_hooks=[self.handle_rollback] + + [ _with_transaction_hook_logging(hook, "rollback", self.logger) for hook in self.task.on_rollback_hooks ], @@ -1013,7 +1030,8 @@ async def handle_success(self, result: R, transaction: Transaction) -> R: ) transaction.stage( terminal_state.data, - on_rollback_hooks=[ + on_rollback_hooks=[self.handle_rollback] + + [ _with_transaction_hook_logging(hook, "rollback", self.logger) for hook in self.task.on_rollback_hooks ], diff --git a/tests/test_task_engine.py b/tests/test_task_engine.py index 2d2c94765f09..c51411eb68c6 100644 --- a/tests/test_task_engine.py +++ b/tests/test_task_engine.py @@ -46,6 +46,7 @@ ) from prefect.task_runners import ThreadPoolTaskRunner from prefect.testing.utilities import exceptions_equal +from prefect.transactions import transaction from prefect.utilities.callables import get_call_parameters from prefect.utilities.engine import propose_state @@ -2574,3 +2575,167 @@ def foo(): return proof_that_i_ran assert the_flow() == proof_that_i_ran + + +class TestTransactionHooks: + async def test_task_transitions_to_rolled_back_on_transaction_rollback( + self, + events_pipeline, + prefect_client, + enable_client_side_task_run_orchestration, + ): + if not enable_client_side_task_run_orchestration: + pytest.xfail( + "The Task Run Recorder is not enabled to handle state transitions via events" + ) + + task_run_state = None + + @task + def foo(): + pass + + @foo.on_rollback + def rollback(txn): + pass + + @flow + def txn_flow(): + with transaction(): + nonlocal task_run_state + task_run_state = foo(return_state=True) + raise ValueError("txn failed") + + txn_flow(return_state=True) + + task_run_id = task_run_state.state_details.task_run_id + + await events_pipeline.process_events() + task_run_states = await prefect_client.read_task_run_states(task_run_id) + + state_names = [state.name for state in task_run_states] + assert state_names == [ + "Pending", + "Running", + "Completed", + "RolledBack", + ] + + async def test_task_transitions_to_rolled_back_on_transaction_rollback_async( + self, + events_pipeline, + prefect_client, + enable_client_side_task_run_orchestration, + ): + if not enable_client_side_task_run_orchestration: + pytest.xfail( + "The Task Run Recorder is not enabled to handle state transitions via events" + ) + + task_run_state = None + + @task + async def foo(): + pass + + @foo.on_rollback + def rollback(txn): + pass + + @flow + async def txn_flow(): + with transaction(): + nonlocal task_run_state + task_run_state = await foo(return_state=True) + raise ValueError("txn failed") + + await txn_flow(return_state=True) + + task_run_id = task_run_state.state_details.task_run_id + + await events_pipeline.process_events() + task_run_states = await prefect_client.read_task_run_states(task_run_id) + + state_names = [state.name for state in task_run_states] + assert state_names == [ + "Pending", + "Running", + "Completed", + "RolledBack", + ] + + def test_rollback_errors_are_logged(self, caplog): + @task + def foo(): + pass + + @foo.on_rollback + def rollback(txn): + raise RuntimeError("whoops!") + + @flow + def txn_flow(): + with transaction(): + foo() + raise ValueError("txn failed") + + txn_flow(return_state=True) + assert "An error was encountered while running rollback hook" in caplog.text + assert "RuntimeError" in caplog.text + assert "whoops!" in caplog.text + + def test_rollback_hook_execution_and_completion_are_logged(self, caplog): + @task + def foo(): + pass + + @foo.on_rollback + def rollback(txn): + pass + + @flow + def txn_flow(): + with transaction(): + foo() + raise ValueError("txn failed") + + txn_flow(return_state=True) + assert "Running rollback hook 'rollback'" in caplog.text + assert "Rollback hook 'rollback' finished running successfully" in caplog.text + + def test_commit_errors_are_logged(self, caplog): + @task + def foo(): + pass + + @foo.on_commit + def rollback(txn): + raise RuntimeError("whoops!") + + @flow + def txn_flow(): + with transaction(): + foo() + + txn_flow(return_state=True) + assert "An error was encountered while running commit hook" in caplog.text + assert "RuntimeError" in caplog.text + assert "whoops!" in caplog.text + + def test_commit_hook_execution_and_completion_are_logged(self, caplog): + @task + def foo(): + pass + + @foo.on_commit + def commit(txn): + pass + + @flow + def txn_flow(): + with transaction(): + foo() + + txn_flow(return_state=True) + assert "Running commit hook 'commit'" in caplog.text + assert "Commit hook 'commit' finished running successfully" in caplog.text diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 1210dea1e08a..ae81fe638b74 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -5104,83 +5104,3 @@ def add_em_up(*args, **kwargs): assert await get_background_task_run_parameters( add_em_up, future.state.state_details.task_parameters_id ) == {"parameters": {"args": (42,), "kwargs": {"y": 42}}, "context": ANY} - - -def test_rollback_errors_are_logged(caplog): - @task - def foo(): - pass - - @foo.on_rollback - def rollback(txn): - raise RuntimeError("whoops!") - - @flow - def txn_flow(): - with transaction(): - foo() - raise ValueError("txn failed") - - txn_flow(return_state=True) - assert "An error was encountered while running rollback hook" in caplog.text - assert "RuntimeError" in caplog.text - assert "whoops!" in caplog.text - - -def test_rollback_hook_execution_and_completion_are_logged(caplog): - @task - def foo(): - pass - - @foo.on_rollback - def rollback(txn): - pass - - @flow - def txn_flow(): - with transaction(): - foo() - raise ValueError("txn failed") - - txn_flow(return_state=True) - assert "Running rollback hook 'rollback'" in caplog.text - assert "Rollback hook 'rollback' finished running successfully" in caplog.text - - -def test_commit_errors_are_logged(caplog): - @task - def foo(): - pass - - @foo.on_commit - def rollback(txn): - raise RuntimeError("whoops!") - - @flow - def txn_flow(): - with transaction(): - foo() - - txn_flow(return_state=True) - assert "An error was encountered while running commit hook" in caplog.text - assert "RuntimeError" in caplog.text - assert "whoops!" in caplog.text - - -def test_commit_hook_execution_and_completion_are_logged(caplog): - @task - def foo(): - pass - - @foo.on_commit - def commit(txn): - pass - - @flow - def txn_flow(): - with transaction(): - foo() - - txn_flow(return_state=True) - assert "Running commit hook 'commit'" in caplog.text - assert "Commit hook 'commit' finished running successfully" in caplog.text