Skip to content

Commit

Permalink
Add kill_switch parameter to BatchTask to prevent queueing excess…
Browse files Browse the repository at this point in the history
…ive, failing tasks (#30)
  • Loading branch information
dominictarro authored Apr 15, 2024
1 parent 103c0b9 commit 7cc6631
Show file tree
Hide file tree
Showing 5 changed files with 202 additions and 3 deletions.
7 changes: 7 additions & 0 deletions src/prefecto/concurrency/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from prefecto.concurrency.batch_task import BatchTask
from prefecto.concurrency.kill_switch import (
KillSwitch,
AnyFailedSwitch,
CountSwitch,
RateSwitch,
)
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
from prefect.utilities.callables import get_call_parameters
from typing_extensions import ParamSpec

from . import logging, states
from prefecto import logging, states
from prefecto.concurrency.kill_switch import KillSwitch

T = TypeVar("T") # Generic type var for capturing the inner return type of async funcs
R = TypeVar("R") # The return type of the user's function
Expand All @@ -34,10 +35,18 @@ class BatchTask:
"""

def __init__(self, task: Task[P, R], size: int):
def __init__(
self, task: Task[P, R], size: int, kill_switch: KillSwitch | None = None
):
self.task: Task = task
self.size: int = size

if kill_switch is not None and not isinstance(kill_switch, KillSwitch):
raise TypeError(
f"Expected 'kill_switch' to be a subclass of 'KillSwitch', got {type(kill_switch)}."
)
self._kill_switch = kill_switch

def _make_batches(self, **params: MapArgument) -> list[Batch]:
"""Create batches of arguments to pass to the `Task.map` calls.
Expand Down Expand Up @@ -204,6 +213,10 @@ def _map(self, batches: list[Batch]) -> list[PrefectFuture]:
if all(states.is_terminal(f.get_state()) for f in futures):
is_processing = False

if self._kill_switch is not None:
for f in futures:
self._kill_switch.raise_if_triggered(f.get_state())

# Map the last batch
logger.debug(
f"Mapping {self.task.name} batch {len(batches)} of {len(batches)}."
Expand Down
114 changes: 114 additions & 0 deletions src/prefecto/concurrency/kill_switch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
"""
Kill switch logic classes for stopping the execution of a `BatchTask`.
"""

import abc

from prefect.states import State


class KillSwitchError(Exception):
"""Error raised when a kill switch is activated."""

def __init__(self, message: str, ks: "KillSwitch"):
super().__init__(message)
self.ks = ks


class KillSwitch(abc.ABC):
"""Abstract base class for a kill switch."""

@abc.abstractmethod
def should_flip_switch(self, state: State) -> bool:
"""Check if this state should flip the kill switch.
Returns:
`True` if the kill switch should be activated, `False` otherwise.
"""

@abc.abstractmethod
def raise_if_triggered(self, state: State):
"""Check a state and raise a `KillSwitchError` if the kill switch has been activated.
Raises:
KillSwitchError: If the kill switch has been activated.
"""


class AnyFailedSwitch(KillSwitch):
"""A kill switch that activates if any tasks failed."""

def should_flip_switch(self, state: State) -> bool:
"""Check if the state is failed or crashed."""
return state.is_failed() or state.is_crashed()

def raise_if_triggered(self, state: State):
"""Raise a `KillSwitchError` if the state is failed or crashed."""
if self.should_flip_switch(state):
raise KillSwitchError("Failed task detected.", self)


class CountSwitch(KillSwitch):
"""A kill switch that activates after a certain number of tasks fail.
Args:
count (int): The number of states after which to activate the kill switch.
"""

def __init__(self, max_count: int):
self.max_count = max_count
self._current_count = 0

def should_flip_switch(self, state: State) -> bool:
"""Increment the count if the state is failed or crashed and return if the count exceeds
the maximum.
"""
if state.is_failed() or state.is_crashed():
self._current_count += 1
return self._current_count >= self.max_count

def raise_if_triggered(self, state: State):
"""Raise a `KillSwitchError` if the count exceeds the maximum."""
if self.should_flip_switch(state):
raise KillSwitchError(f"{self.max_count} failed tasks detected.", self)


class RateSwitch(KillSwitch):
"""A kill switch that activates after the failure rate exceeds a certain threshold.
Requires a minimum number of states to sample.
Args:
min_sample (int): The minimum number of states to sample.
max_fail_rate (float): The maximum frequency of failed states.
"""

def __init__(self, min_sample: int, max_fail_rate: float):
self.min_sample = min_sample
self.max_fail_rate = max_fail_rate
self._current_count = 0
self._failed_count = 0

def should_flip_switch(self, state: State) -> bool:
"""Increment the count if the state is failed or crashed and return if the failure rate
exceeds the maximum tolerable rate.
"""
self._current_count += 1
if state.is_failed() or state.is_crashed():
self._failed_count += 1
return (
self._current_count >= self.min_sample
and self._failed_count / self._current_count >= self.max_fail_rate
)

def raise_if_triggered(self, state: State):
"""Raise a `KillSwitchError` if the failure rate exceeds the maximum tolerable rate."""
if self.should_flip_switch(state):
raise KillSwitchError(
f"Failure rate exceeded {self.max_fail_rate} after {self.min_sample} samples.",
self,
)
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest
from prefect import flow, task, unmapped

from prefecto.concurrency import BatchTask
from prefecto.concurrency.batch_task import BatchTask


@task
Expand Down Expand Up @@ -77,3 +77,35 @@ def test() -> list[int]:

result = test()
assert result == expectation

def test_map_with_kill_switch(self, harness):
"""Test `BatchTask.map` with a kill switch."""
from prefecto.concurrency.kill_switch import CountSwitch, KillSwitchError

@flow
def test() -> list[int]:
"""Test flow."""
bt = BatchTask(add, 3, CountSwitch(2))
bt.map([1, 2, 3, 4, 5, 6, 7, 8, 9], ["x", 1, 1, "y", 1, 1, 1, 1, 1])

with pytest.raises(KillSwitchError) as exc:
test()
assert isinstance(exc.value.ks, CountSwitch)
assert exc.value.ks._current_count == 2
assert exc.value.ks._max_count == 2

def test_map_with_kill_switch_within_batch(self, harness):
"""Test `BatchTask.map` with a kill switch."""
from prefecto.concurrency.kill_switch import CountSwitch, KillSwitchError

@flow
def test() -> list[int]:
"""Test flow."""
bt = BatchTask(add, 4, CountSwitch(2))
bt.map([1, 2, 3, 4, 5], ["x", 1, "y", 1, 1])

with pytest.raises(KillSwitchError) as exc:
test()
assert isinstance(exc.value.ks, CountSwitch)
assert exc.value.ks._current_count == 2
assert exc.value.ks._max_count == 2
33 changes: 33 additions & 0 deletions tests/concurrency/test_kill_switch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import pytest
from prefect import states

from prefecto.concurrency.kill_switch import (
AnyFailedSwitch,
CountSwitch,
KillSwitchError,
RateSwitch,
)


@pytest.mark.asyncio
async def test_any_kill_switch():
with pytest.raises(KillSwitchError):
AnyFailedSwitch().raise_if_triggered(states.Failed())


@pytest.mark.asyncio
async def test_count_kill_switch():
ks = CountSwitch(2)
ks.raise_if_triggered(states.Failed())
ks.raise_if_triggered(states.Completed())
with pytest.raises(KillSwitchError):
ks.raise_if_triggered(states.Failed())


@pytest.mark.asyncio
async def test_rate_kill_switch():
ks = RateSwitch(3, 0.5)
ks.raise_if_triggered(states.Completed())
ks.raise_if_triggered(states.Failed())
with pytest.raises(KillSwitchError):
ks.raise_if_triggered(states.Crashed())

0 comments on commit 7cc6631

Please sign in to comment.