diff --git a/src/prefecto/__init__.py b/src/prefecto/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/prefecto/concurrency.py b/src/prefecto/concurrency.py new file mode 100644 index 0000000..3610060 --- /dev/null +++ b/src/prefecto/concurrency.py @@ -0,0 +1,167 @@ +""" +Tools to improve Prefect concurrently. + +""" +from typing import Any, TypeVar + +from prefect.futures import PrefectFuture +from prefect.tasks import Task +from prefect.utilities.callables import get_call_parameters +from typing_extensions import ParamSpec + +from . import logging, states + +T = TypeVar("T") # Generic type var for capturing the inner return type of async funcs +R = TypeVar("R") # The return type of the user's function +P = ParamSpec("P") # The parameters of the task + + +class BatchTask: + """Wraps a `Task` to perform `Task.map` in batches.""" + + def __init__(self, task: Task[P, R], size: int): + """Create a `BatchTask` to wrap a `Task` and perform `Task.map` in batches. + + Parameters + ---------- + task : Task[P, R] + The task to wrap. + size : int + The size of the batches to perform `Task.map` on. + + Examples + -------- + >>> from prefect import task + >>> from prefecto.concurrency import BatchTask + >>> @task + ... def add(a, b): + ... return a + b + >>> batch_add = BatchTask(add, 3) + >>> batch_add.map([1,2,3,4,5], [2,3,4,5,6]) + [ + PrefectFuture, + PrefectFuture, + PrefectFuture, + PrefectFuture, + PrefectFuture + ] + """ + self.task: Task = task + self.size: int = size + + def _make_batches(self, **params) -> list[dict[str, list[Any]]]: + """Create batches of arguments to pass to the `Task.map` calls. + + Parameters + ---------- + **params + Keyword arguments where each value is an iterable of equal length. Should + be at least one keyword argument. + + Returns + ------- + list[dict[str, list[Any]]] + A list of dictionaries where each dictionary has the same keys as the + provided keyword arguments. The values of the dictionaries are lists with + lengths no greater than `BatchTask.size`. + + Examples + -------- + + >>> BatchTask(task, 3)._make_batches(a=[1,2,3,4,5], b=[2,3,4,5,6]) + [ + {"a": [1,2,3], "b": [2,3,4]}, + {"a": [4,5], "b": [4,5,6]} + ] + """ + parameters = sorted(params.keys()) + if len(parameters) == 0: + raise ValueError("Must provide at least one iterable.") + + # Validate all are iterables + for k in parameters: + if not hasattr(params[k], "__iter__"): + raise ValueError(f"Expected '{k}' to be an iterable.") + + length = len(params[parameters[0]]) + + # Assure all of equal length + if len(parameters) > 1: + for k in parameters[1:]: + if not len(params[k]) == length: + raise ValueError( + f"Expected all iterables to be of length {length} like " + f"'{parameters[0]}'. '{k}' is length {len(params[k])}." + ) + + batches = [] + + for i in range(length // self.size): + batch = {p: [] for p in parameters} + for p in parameters: + batch[p] = params[p][i * self.size : (i + 1) * self.size] + batches.append(batch) + + # Add the remainder if there is one + if length % self.size != 0: + batch = {p: [] for p in parameters} + for p in parameters: + batch[p] = params[p][(i + 1) * self.size :] + batches.append(batch) + + return batches + + def map(self, *args, **kwds) -> list[PrefectFuture]: + """Perform a `Task.map` operation in batches of the keyword arguments. The + arguments must be iterables of equal length. + + Parameters + ---------- + *args + Positional arguments to pass to the task. + **kwds + Keyword arguments to pass to the task. + + Returns + ------- + list[PrefectFuture] + A list of futures for each batch. + """ + parameters = get_call_parameters(self.task.fn, args, kwds, apply_defaults=False) + batches = self._make_batches(**parameters) + + return self._map(batches) + + def _map(self, batches: list[dict[str, list[Any]]]) -> list[PrefectFuture]: + """Applies `Task.map` to each batch. + + Args: + batches (list[dict[str, list[Any]]]): _description_ + + Returns: + list[PrefectFuture]: _description_ + """ + logger = logging.get_prefect_or_default_logger() + results: list[PrefectFuture] = [] + for i, batch in enumerate(batches[:-1]): + logger.debug(f"Mapping {self.task.name} batch {i+1} of {len(batches)}.") + # Map the batch + futures = self.task.map(**batch) + results.extend(futures) + # Poll futures to ensure they are not active. + is_processing: bool = False + while is_processing: + for f in futures: + if not states.is_terminal(f.get_state()): + # If any future is still processing, break and poll again. + is_processing = True + break + else: + is_processing = False + + # Map the last batch + logger.debug( + f"Mapping {self.task.name} batch {len(batches)} of {len(batches)}." + ) + results.extend(self.task.map(**batches[-1])) + return results diff --git a/src/prefecto/logging.py b/src/prefecto/logging.py new file mode 100644 index 0000000..ba6c03e --- /dev/null +++ b/src/prefecto/logging.py @@ -0,0 +1,26 @@ +""" +Prefect logging utilities. +""" +import logging + +from prefect.logging import get_run_logger + + +def get_prefect_or_default_logger( + __default: logging.Logger | str | None = None, +) -> logging.Logger: + """Gets the Prefect logger if the global context is set. Returns the `__default` or + root logger if not. + """ + # Type check the default logger + if not isinstance(__default, (logging.Logger, str, type(None))): + raise TypeError( + f"Expected `__default` to be a `logging.Logger`, `str`, or `None`, " + f"got `{type(__default).__name__}`." + ) + try: + return get_run_logger() + except RuntimeError: + if isinstance(__default, str): + return logging.getLogger(__default) + return __default or logging.getLogger() diff --git a/src/prefecto/states.py b/src/prefecto/states.py new file mode 100644 index 0000000..a5912ec --- /dev/null +++ b/src/prefecto/states.py @@ -0,0 +1,25 @@ +""" +Tools to improve Prefect states. + +""" +from prefect import states + + +def is_terminal(state: states.State) -> bool: + """Return True if the state is terminal. Terminal states are: + + - Cancelled + - Completed + - Crashed + - Failed + """ + TERMINALS = [ + state.is_cancelled, + state.is_completed, + state.is_crashed, + state.is_failed, + ] + for terminal in TERMINALS: + if terminal(): + return True + return False diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py new file mode 100644 index 0000000..91701ec --- /dev/null +++ b/tests/test_concurrency.py @@ -0,0 +1,39 @@ +""" +Unit tests for the `concurrency` module. + +""" +from prefect import flow, task + +from prefecto.concurrency import BatchTask + + +@task +def add(a, b): + """Add two numbers.""" + return a + b + + +class TestBatchTask: + """Unit tests for `BatchTask`.""" + + def test_make_batches(self): + """Test `_make_batches`.""" + batches = BatchTask(add, 3)._make_batches(a=[1, 2, 3, 4, 5], b=[2, 3, 4, 5, 6]) + assert batches == [{"a": [1, 2, 3], "b": [2, 3, 4]}, {"a": [4, 5], "b": [5, 6]}] + + def test_map(self, harness): + """Test `BatchTask.map`.""" + + @task + def realize(futures: list[int]): + """Converts futures to their values.""" + return futures + + @flow + def test() -> list[int]: + """Test flow.""" + futures = BatchTask(add, 3).map([1, 2, 3, 4, 5], [2, 3, 4, 5, 6]) + return realize(futures) + + result = test() + assert result == [3, 5, 7, 9, 11] diff --git a/tests/test_logging.py b/tests/test_logging.py new file mode 100644 index 0000000..e2b5fef --- /dev/null +++ b/tests/test_logging.py @@ -0,0 +1,16 @@ +""" +Tests for the logging module. +""" +import logging + +from prefecto.logging import get_prefect_or_default_logger + + +def test_get_prefect_or_default_logger(): + """Tests `get_prefect_or_default_logger`.""" + assert get_prefect_or_default_logger().__class__ == logging.RootLogger + assert logging.getLogger("not root").__class__ == logging.Logger + assert ( + get_prefect_or_default_logger(logging.Logger("not root")).__class__ + == logging.Logger + )