Skip to content

Commit

Permalink
batch mapping
Browse files Browse the repository at this point in the history
  • Loading branch information
dominictarro committed Jun 13, 2023
1 parent d111100 commit 4a03350
Show file tree
Hide file tree
Showing 6 changed files with 273 additions and 0 deletions.
Empty file added src/prefecto/__init__.py
Empty file.
167 changes: 167 additions & 0 deletions src/prefecto/concurrency.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
"""
Tools to improve Prefect concurrently.
"""
from typing import Any, TypeVar

from prefect.futures import PrefectFuture
from prefect.tasks import Task
from prefect.utilities.callables import get_call_parameters
from typing_extensions import ParamSpec

from . import logging, states

T = TypeVar("T") # Generic type var for capturing the inner return type of async funcs
R = TypeVar("R") # The return type of the user's function
P = ParamSpec("P") # The parameters of the task


class BatchTask:
"""Wraps a `Task` to perform `Task.map` in batches."""

def __init__(self, task: Task[P, R], size: int):
"""Create a `BatchTask` to wrap a `Task` and perform `Task.map` in batches.
Parameters
----------
task : Task[P, R]
The task to wrap.
size : int
The size of the batches to perform `Task.map` on.
Examples
--------
>>> from prefect import task
>>> from prefecto.concurrency import BatchTask
>>> @task
... def add(a, b):
... return a + b
>>> batch_add = BatchTask(add, 3)
>>> batch_add.map([1,2,3,4,5], [2,3,4,5,6])
[
PrefectFuture<Success: 3>,
PrefectFuture<Success: 5>,
PrefectFuture<Success: 7>,
PrefectFuture<Success: 9>,
PrefectFuture<Success: 11>
]
"""
self.task: Task = task
self.size: int = size

def _make_batches(self, **params) -> list[dict[str, list[Any]]]:
"""Create batches of arguments to pass to the `Task.map` calls.
Parameters
----------
**params
Keyword arguments where each value is an iterable of equal length. Should
be at least one keyword argument.
Returns
-------
list[dict[str, list[Any]]]
A list of dictionaries where each dictionary has the same keys as the
provided keyword arguments. The values of the dictionaries are lists with
lengths no greater than `BatchTask.size`.
Examples
--------
>>> BatchTask(task, 3)._make_batches(a=[1,2,3,4,5], b=[2,3,4,5,6])
[
{"a": [1,2,3], "b": [2,3,4]},
{"a": [4,5], "b": [4,5,6]}
]
"""
parameters = sorted(params.keys())
if len(parameters) == 0:
raise ValueError("Must provide at least one iterable.")

# Validate all are iterables
for k in parameters:
if not hasattr(params[k], "__iter__"):
raise ValueError(f"Expected '{k}' to be an iterable.")

length = len(params[parameters[0]])

# Assure all of equal length
if len(parameters) > 1:
for k in parameters[1:]:
if not len(params[k]) == length:
raise ValueError(
f"Expected all iterables to be of length {length} like "
f"'{parameters[0]}'. '{k}' is length {len(params[k])}."
)

batches = []

for i in range(length // self.size):
batch = {p: [] for p in parameters}
for p in parameters:
batch[p] = params[p][i * self.size : (i + 1) * self.size]
batches.append(batch)

# Add the remainder if there is one
if length % self.size != 0:
batch = {p: [] for p in parameters}
for p in parameters:
batch[p] = params[p][(i + 1) * self.size :]
batches.append(batch)

return batches

def map(self, *args, **kwds) -> list[PrefectFuture]:
"""Perform a `Task.map` operation in batches of the keyword arguments. The
arguments must be iterables of equal length.
Parameters
----------
*args
Positional arguments to pass to the task.
**kwds
Keyword arguments to pass to the task.
Returns
-------
list[PrefectFuture]
A list of futures for each batch.
"""
parameters = get_call_parameters(self.task.fn, args, kwds, apply_defaults=False)
batches = self._make_batches(**parameters)

return self._map(batches)

def _map(self, batches: list[dict[str, list[Any]]]) -> list[PrefectFuture]:
"""Applies `Task.map` to each batch.
Args:
batches (list[dict[str, list[Any]]]): _description_
Returns:
list[PrefectFuture]: _description_
"""
logger = logging.get_prefect_or_default_logger()
results: list[PrefectFuture] = []
for i, batch in enumerate(batches[:-1]):
logger.debug(f"Mapping {self.task.name} batch {i+1} of {len(batches)}.")
# Map the batch
futures = self.task.map(**batch)
results.extend(futures)
# Poll futures to ensure they are not active.
is_processing: bool = False
while is_processing:
for f in futures:
if not states.is_terminal(f.get_state()):
# If any future is still processing, break and poll again.
is_processing = True
break
else:
is_processing = False

# Map the last batch
logger.debug(
f"Mapping {self.task.name} batch {len(batches)} of {len(batches)}."
)
results.extend(self.task.map(**batches[-1]))
return results
26 changes: 26 additions & 0 deletions src/prefecto/logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""
Prefect logging utilities.
"""
import logging

from prefect.logging import get_run_logger


def get_prefect_or_default_logger(
__default: logging.Logger | str | None = None,
) -> logging.Logger:
"""Gets the Prefect logger if the global context is set. Returns the `__default` or
root logger if not.
"""
# Type check the default logger
if not isinstance(__default, (logging.Logger, str, type(None))):
raise TypeError(
f"Expected `__default` to be a `logging.Logger`, `str`, or `None`, "
f"got `{type(__default).__name__}`."
)
try:
return get_run_logger()
except RuntimeError:
if isinstance(__default, str):
return logging.getLogger(__default)
return __default or logging.getLogger()
25 changes: 25 additions & 0 deletions src/prefecto/states.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""
Tools to improve Prefect states.
"""
from prefect import states


def is_terminal(state: states.State) -> bool:
"""Return True if the state is terminal. Terminal states are:
- Cancelled
- Completed
- Crashed
- Failed
"""
TERMINALS = [
state.is_cancelled,
state.is_completed,
state.is_crashed,
state.is_failed,
]
for terminal in TERMINALS:
if terminal():
return True
return False
39 changes: 39 additions & 0 deletions tests/test_concurrency.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""
Unit tests for the `concurrency` module.
"""
from prefect import flow, task

from prefecto.concurrency import BatchTask


@task
def add(a, b):
"""Add two numbers."""
return a + b


class TestBatchTask:
"""Unit tests for `BatchTask`."""

def test_make_batches(self):
"""Test `_make_batches`."""
batches = BatchTask(add, 3)._make_batches(a=[1, 2, 3, 4, 5], b=[2, 3, 4, 5, 6])
assert batches == [{"a": [1, 2, 3], "b": [2, 3, 4]}, {"a": [4, 5], "b": [5, 6]}]

def test_map(self, harness):
"""Test `BatchTask.map`."""

@task
def realize(futures: list[int]):
"""Converts futures to their values."""
return futures

@flow
def test() -> list[int]:
"""Test flow."""
futures = BatchTask(add, 3).map([1, 2, 3, 4, 5], [2, 3, 4, 5, 6])
return realize(futures)

result = test()
assert result == [3, 5, 7, 9, 11]
16 changes: 16 additions & 0 deletions tests/test_logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""
Tests for the logging module.
"""
import logging

from prefecto.logging import get_prefect_or_default_logger


def test_get_prefect_or_default_logger():
"""Tests `get_prefect_or_default_logger`."""
assert get_prefect_or_default_logger().__class__ == logging.RootLogger
assert logging.getLogger("not root").__class__ == logging.Logger
assert (
get_prefect_or_default_logger(logging.Logger("not root")).__class__
== logging.Logger
)

0 comments on commit 4a03350

Please sign in to comment.