Skip to content

Commit

Permalink
Transition rolled back tasks into Completed(name="RolledBack") (#14721
Browse files Browse the repository at this point in the history
)
  • Loading branch information
collincchoy authored Jul 29, 2024
1 parent 9a4c901 commit fede95b
Show file tree
Hide file tree
Showing 3 changed files with 185 additions and 82 deletions.
22 changes: 20 additions & 2 deletions src/prefect/task_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
)
from prefect.states import (
AwaitingRetry,
Completed,
Failed,
Paused,
Pending,
Expand Down Expand Up @@ -244,6 +245,21 @@ def example_flow():
msg=msg,
)

def handle_rollback(self, txn: Transaction) -> None:
assert self.task_run is not None

rolled_back_state = Completed(
name="RolledBack",
message="Task rolled back as part of transaction",
)

self._last_event = emit_task_run_state_change_event(
task_run=self.task_run,
initial_state=self.state,
validated_state=rolled_back_state,
follows=self._last_event,
)


@dataclass
class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
Expand Down Expand Up @@ -463,7 +479,8 @@ def handle_success(self, result: R, transaction: Transaction) -> R:
)
transaction.stage(
terminal_state.data,
on_rollback_hooks=[
on_rollback_hooks=[self.handle_rollback]
+ [
_with_transaction_hook_logging(hook, "rollback", self.logger)
for hook in self.task.on_rollback_hooks
],
Expand Down Expand Up @@ -1013,7 +1030,8 @@ async def handle_success(self, result: R, transaction: Transaction) -> R:
)
transaction.stage(
terminal_state.data,
on_rollback_hooks=[
on_rollback_hooks=[self.handle_rollback]
+ [
_with_transaction_hook_logging(hook, "rollback", self.logger)
for hook in self.task.on_rollback_hooks
],
Expand Down
165 changes: 165 additions & 0 deletions tests/test_task_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
)
from prefect.task_runners import ThreadPoolTaskRunner
from prefect.testing.utilities import exceptions_equal
from prefect.transactions import transaction
from prefect.utilities.callables import get_call_parameters
from prefect.utilities.engine import propose_state

Expand Down Expand Up @@ -2574,3 +2575,167 @@ def foo():
return proof_that_i_ran

assert the_flow() == proof_that_i_ran


class TestTransactionHooks:
async def test_task_transitions_to_rolled_back_on_transaction_rollback(
self,
events_pipeline,
prefect_client,
enable_client_side_task_run_orchestration,
):
if not enable_client_side_task_run_orchestration:
pytest.xfail(
"The Task Run Recorder is not enabled to handle state transitions via events"
)

task_run_state = None

@task
def foo():
pass

@foo.on_rollback
def rollback(txn):
pass

@flow
def txn_flow():
with transaction():
nonlocal task_run_state
task_run_state = foo(return_state=True)
raise ValueError("txn failed")

txn_flow(return_state=True)

task_run_id = task_run_state.state_details.task_run_id

await events_pipeline.process_events()
task_run_states = await prefect_client.read_task_run_states(task_run_id)

state_names = [state.name for state in task_run_states]
assert state_names == [
"Pending",
"Running",
"Completed",
"RolledBack",
]

async def test_task_transitions_to_rolled_back_on_transaction_rollback_async(
self,
events_pipeline,
prefect_client,
enable_client_side_task_run_orchestration,
):
if not enable_client_side_task_run_orchestration:
pytest.xfail(
"The Task Run Recorder is not enabled to handle state transitions via events"
)

task_run_state = None

@task
async def foo():
pass

@foo.on_rollback
def rollback(txn):
pass

@flow
async def txn_flow():
with transaction():
nonlocal task_run_state
task_run_state = await foo(return_state=True)
raise ValueError("txn failed")

await txn_flow(return_state=True)

task_run_id = task_run_state.state_details.task_run_id

await events_pipeline.process_events()
task_run_states = await prefect_client.read_task_run_states(task_run_id)

state_names = [state.name for state in task_run_states]
assert state_names == [
"Pending",
"Running",
"Completed",
"RolledBack",
]

def test_rollback_errors_are_logged(self, caplog):
@task
def foo():
pass

@foo.on_rollback
def rollback(txn):
raise RuntimeError("whoops!")

@flow
def txn_flow():
with transaction():
foo()
raise ValueError("txn failed")

txn_flow(return_state=True)
assert "An error was encountered while running rollback hook" in caplog.text
assert "RuntimeError" in caplog.text
assert "whoops!" in caplog.text

def test_rollback_hook_execution_and_completion_are_logged(self, caplog):
@task
def foo():
pass

@foo.on_rollback
def rollback(txn):
pass

@flow
def txn_flow():
with transaction():
foo()
raise ValueError("txn failed")

txn_flow(return_state=True)
assert "Running rollback hook 'rollback'" in caplog.text
assert "Rollback hook 'rollback' finished running successfully" in caplog.text

def test_commit_errors_are_logged(self, caplog):
@task
def foo():
pass

@foo.on_commit
def rollback(txn):
raise RuntimeError("whoops!")

@flow
def txn_flow():
with transaction():
foo()

txn_flow(return_state=True)
assert "An error was encountered while running commit hook" in caplog.text
assert "RuntimeError" in caplog.text
assert "whoops!" in caplog.text

def test_commit_hook_execution_and_completion_are_logged(self, caplog):
@task
def foo():
pass

@foo.on_commit
def commit(txn):
pass

@flow
def txn_flow():
with transaction():
foo()

txn_flow(return_state=True)
assert "Running commit hook 'commit'" in caplog.text
assert "Commit hook 'commit' finished running successfully" in caplog.text
80 changes: 0 additions & 80 deletions tests/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5104,83 +5104,3 @@ def add_em_up(*args, **kwargs):
assert await get_background_task_run_parameters(
add_em_up, future.state.state_details.task_parameters_id
) == {"parameters": {"args": (42,), "kwargs": {"y": 42}}, "context": ANY}


def test_rollback_errors_are_logged(caplog):
@task
def foo():
pass

@foo.on_rollback
def rollback(txn):
raise RuntimeError("whoops!")

@flow
def txn_flow():
with transaction():
foo()
raise ValueError("txn failed")

txn_flow(return_state=True)
assert "An error was encountered while running rollback hook" in caplog.text
assert "RuntimeError" in caplog.text
assert "whoops!" in caplog.text


def test_rollback_hook_execution_and_completion_are_logged(caplog):
@task
def foo():
pass

@foo.on_rollback
def rollback(txn):
pass

@flow
def txn_flow():
with transaction():
foo()
raise ValueError("txn failed")

txn_flow(return_state=True)
assert "Running rollback hook 'rollback'" in caplog.text
assert "Rollback hook 'rollback' finished running successfully" in caplog.text


def test_commit_errors_are_logged(caplog):
@task
def foo():
pass

@foo.on_commit
def rollback(txn):
raise RuntimeError("whoops!")

@flow
def txn_flow():
with transaction():
foo()

txn_flow(return_state=True)
assert "An error was encountered while running commit hook" in caplog.text
assert "RuntimeError" in caplog.text
assert "whoops!" in caplog.text


def test_commit_hook_execution_and_completion_are_logged(caplog):
@task
def foo():
pass

@foo.on_commit
def commit(txn):
pass

@flow
def txn_flow():
with transaction():
foo()

txn_flow(return_state=True)
assert "Running commit hook 'commit'" in caplog.text
assert "Commit hook 'commit' finished running successfully" in caplog.text

0 comments on commit fede95b

Please sign in to comment.