diff --git a/src/prefecto/concurrency/__init__.py b/src/prefecto/concurrency/__init__.py new file mode 100644 index 0000000..46bee5a --- /dev/null +++ b/src/prefecto/concurrency/__init__.py @@ -0,0 +1,7 @@ +from prefecto.concurrency.batch_task import BatchTask +from prefecto.concurrency.kill_switch import ( + KillSwitch, + AnyFailedSwitch, + CountSwitch, + RateSwitch, +) diff --git a/src/prefecto/concurrency.py b/src/prefecto/concurrency/batch_task.py similarity index 92% rename from src/prefecto/concurrency.py rename to src/prefecto/concurrency/batch_task.py index 75cc4f4..fa3673a 100644 --- a/src/prefecto/concurrency.py +++ b/src/prefecto/concurrency/batch_task.py @@ -13,7 +13,8 @@ from prefect.utilities.callables import get_call_parameters from typing_extensions import ParamSpec -from . import logging, states +from prefecto import logging, states +from prefecto.concurrency.kill_switch import KillSwitch T = TypeVar("T") # Generic type var for capturing the inner return type of async funcs R = TypeVar("R") # The return type of the user's function @@ -34,10 +35,18 @@ class BatchTask: """ - def __init__(self, task: Task[P, R], size: int): + def __init__( + self, task: Task[P, R], size: int, kill_switch: KillSwitch | None = None + ): self.task: Task = task self.size: int = size + if kill_switch is not None and not isinstance(kill_switch, KillSwitch): + raise TypeError( + f"Expected 'kill_switch' to be a subclass of 'KillSwitch', got {type(kill_switch)}." + ) + self._kill_switch = kill_switch + def _make_batches(self, **params: MapArgument) -> list[Batch]: """Create batches of arguments to pass to the `Task.map` calls. @@ -204,6 +213,10 @@ def _map(self, batches: list[Batch]) -> list[PrefectFuture]: if all(states.is_terminal(f.get_state()) for f in futures): is_processing = False + if self._kill_switch is not None: + for f in futures: + self._kill_switch.raise_if_triggered(f.get_state()) + # Map the last batch logger.debug( f"Mapping {self.task.name} batch {len(batches)} of {len(batches)}." diff --git a/src/prefecto/concurrency/kill_switch.py b/src/prefecto/concurrency/kill_switch.py new file mode 100644 index 0000000..e60c385 --- /dev/null +++ b/src/prefecto/concurrency/kill_switch.py @@ -0,0 +1,114 @@ +""" +Kill switch logic classes for stopping the execution of a `BatchTask`. + +""" + +import abc + +from prefect.states import State + + +class KillSwitchError(Exception): + """Error raised when a kill switch is activated.""" + + def __init__(self, message: str, ks: "KillSwitch"): + super().__init__(message) + self.ks = ks + + +class KillSwitch(abc.ABC): + """Abstract base class for a kill switch.""" + + @abc.abstractmethod + def should_flip_switch(self, state: State) -> bool: + """Check if this state should flip the kill switch. + + Returns: + `True` if the kill switch should be activated, `False` otherwise. + + """ + + @abc.abstractmethod + def raise_if_triggered(self, state: State): + """Check a state and raise a `KillSwitchError` if the kill switch has been activated. + + Raises: + KillSwitchError: If the kill switch has been activated. + + """ + + +class AnyFailedSwitch(KillSwitch): + """A kill switch that activates if any tasks failed.""" + + def should_flip_switch(self, state: State) -> bool: + """Check if the state is failed or crashed.""" + return state.is_failed() or state.is_crashed() + + def raise_if_triggered(self, state: State): + """Raise a `KillSwitchError` if the state is failed or crashed.""" + if self.should_flip_switch(state): + raise KillSwitchError("Failed task detected.", self) + + +class CountSwitch(KillSwitch): + """A kill switch that activates after a certain number of tasks fail. + + Args: + count (int): The number of states after which to activate the kill switch. + + """ + + def __init__(self, max_count: int): + self.max_count = max_count + self._current_count = 0 + + def should_flip_switch(self, state: State) -> bool: + """Increment the count if the state is failed or crashed and return if the count exceeds + the maximum. + """ + if state.is_failed() or state.is_crashed(): + self._current_count += 1 + return self._current_count >= self.max_count + + def raise_if_triggered(self, state: State): + """Raise a `KillSwitchError` if the count exceeds the maximum.""" + if self.should_flip_switch(state): + raise KillSwitchError(f"{self.max_count} failed tasks detected.", self) + + +class RateSwitch(KillSwitch): + """A kill switch that activates after the failure rate exceeds a certain threshold. + Requires a minimum number of states to sample. + + Args: + min_sample (int): The minimum number of states to sample. + max_fail_rate (float): The maximum frequency of failed states. + + """ + + def __init__(self, min_sample: int, max_fail_rate: float): + self.min_sample = min_sample + self.max_fail_rate = max_fail_rate + self._current_count = 0 + self._failed_count = 0 + + def should_flip_switch(self, state: State) -> bool: + """Increment the count if the state is failed or crashed and return if the failure rate + exceeds the maximum tolerable rate. + """ + self._current_count += 1 + if state.is_failed() or state.is_crashed(): + self._failed_count += 1 + return ( + self._current_count >= self.min_sample + and self._failed_count / self._current_count >= self.max_fail_rate + ) + + def raise_if_triggered(self, state: State): + """Raise a `KillSwitchError` if the failure rate exceeds the maximum tolerable rate.""" + if self.should_flip_switch(state): + raise KillSwitchError( + f"Failure rate exceeded {self.max_fail_rate} after {self.min_sample} samples.", + self, + ) diff --git a/tests/test_concurrency.py b/tests/concurrency/test_batch_task.py similarity index 63% rename from tests/test_concurrency.py rename to tests/concurrency/test_batch_task.py index 9f06c53..bc3aa83 100644 --- a/tests/test_concurrency.py +++ b/tests/concurrency/test_batch_task.py @@ -8,7 +8,7 @@ import pytest from prefect import flow, task, unmapped -from prefecto.concurrency import BatchTask +from prefecto.concurrency.batch_task import BatchTask @task @@ -77,3 +77,35 @@ def test() -> list[int]: result = test() assert result == expectation + + def test_map_with_kill_switch(self, harness): + """Test `BatchTask.map` with a kill switch.""" + from prefecto.concurrency.kill_switch import CountSwitch, KillSwitchError + + @flow + def test() -> list[int]: + """Test flow.""" + bt = BatchTask(add, 3, CountSwitch(2)) + bt.map([1, 2, 3, 4, 5, 6, 7, 8, 9], ["x", 1, 1, "y", 1, 1, 1, 1, 1]) + + with pytest.raises(KillSwitchError) as exc: + test() + assert isinstance(exc.value.ks, CountSwitch) + assert exc.value.ks._current_count == 2 + assert exc.value.ks._max_count == 2 + + def test_map_with_kill_switch_within_batch(self, harness): + """Test `BatchTask.map` with a kill switch.""" + from prefecto.concurrency.kill_switch import CountSwitch, KillSwitchError + + @flow + def test() -> list[int]: + """Test flow.""" + bt = BatchTask(add, 4, CountSwitch(2)) + bt.map([1, 2, 3, 4, 5], ["x", 1, "y", 1, 1]) + + with pytest.raises(KillSwitchError) as exc: + test() + assert isinstance(exc.value.ks, CountSwitch) + assert exc.value.ks._current_count == 2 + assert exc.value.ks._max_count == 2 diff --git a/tests/concurrency/test_kill_switch.py b/tests/concurrency/test_kill_switch.py new file mode 100644 index 0000000..3fc1ab6 --- /dev/null +++ b/tests/concurrency/test_kill_switch.py @@ -0,0 +1,33 @@ +import pytest +from prefect import states + +from prefecto.concurrency.kill_switch import ( + AnyFailedSwitch, + CountSwitch, + KillSwitchError, + RateSwitch, +) + + +@pytest.mark.asyncio +async def test_any_kill_switch(): + with pytest.raises(KillSwitchError): + AnyFailedSwitch().raise_if_triggered(states.Failed()) + + +@pytest.mark.asyncio +async def test_count_kill_switch(): + ks = CountSwitch(2) + ks.raise_if_triggered(states.Failed()) + ks.raise_if_triggered(states.Completed()) + with pytest.raises(KillSwitchError): + ks.raise_if_triggered(states.Failed()) + + +@pytest.mark.asyncio +async def test_rate_kill_switch(): + ks = RateSwitch(3, 0.5) + ks.raise_if_triggered(states.Completed()) + ks.raise_if_triggered(states.Failed()) + with pytest.raises(KillSwitchError): + ks.raise_if_triggered(states.Crashed())