-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
d111100
commit 4a03350
Showing
6 changed files
with
273 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
""" | ||
Tools to improve Prefect concurrency. | ||
""" | ||
from typing import Any, TypeVar | ||
|
||
from prefect.futures import PrefectFuture | ||
from prefect.tasks import Task | ||
from prefect.utilities.callables import get_call_parameters | ||
from typing_extensions import ParamSpec | ||
|
||
from . import logging, states | ||
|
||
# Type variables for this module's annotations (e.g. `Task[P, R]`).
T = TypeVar("T")  # Generic type var for capturing the inner return type of async funcs
R = TypeVar("R")  # The return type of the user's function
P = ParamSpec("P")  # The parameters of the task
|
||
|
||
class BatchTask:
    """Wraps a `Task` to perform `Task.map` in batches."""

    def __init__(self, task: "Task[P, R]", size: int):
        """Create a `BatchTask` to wrap a `Task` and perform `Task.map` in batches.

        Parameters
        ----------
        task : Task[P, R]
            The task to wrap.
        size : int
            The size of the batches to perform `Task.map` on. Must be at least 1.

        Raises
        ------
        ValueError
            If `size` is less than 1.

        Examples
        --------
        >>> from prefect import task
        >>> from prefecto.concurrency import BatchTask
        >>> @task
        ... def add(a, b):
        ...     return a + b
        >>> batch_add = BatchTask(add, 3)
        >>> batch_add.map([1,2,3,4,5], [2,3,4,5,6])
        [
            PrefectFuture<Success: 3>,
            PrefectFuture<Success: 5>,
            PrefectFuture<Success: 7>,
            PrefectFuture<Success: 9>,
            PrefectFuture<Success: 11>
        ]
        """
        # A non-positive size cannot produce batches; fail here with a clear
        # message instead of a ZeroDivisionError/odd slices later on.
        if size < 1:
            raise ValueError(f"Expected 'size' to be at least 1, got {size}.")
        self.task: Task = task
        self.size: int = size

    def _make_batches(self, **params) -> list[dict[str, list[Any]]]:
        """Create batches of arguments to pass to the `Task.map` calls.

        Parameters
        ----------
        **params
            Keyword arguments where each value is a sliceable iterable of equal
            length. Should be at least one keyword argument.

        Returns
        -------
        list[dict[str, list[Any]]]
            A list of dictionaries where each dictionary has the same keys as the
            provided keyword arguments. The values of the dictionaries are slices
            with lengths no greater than `BatchTask.size`. Empty iterables
            produce an empty list.

        Raises
        ------
        ValueError
            If no keyword arguments are given, if a value has no `__iter__`, or
            if the values are not all the same length.

        Examples
        --------
        >>> BatchTask(task, 3)._make_batches(a=[1,2,3,4,5], b=[2,3,4,5,6])
        [
            {"a": [1,2,3], "b": [2,3,4]},
            {"a": [4,5], "b": [5,6]}
        ]
        """
        parameters = sorted(params)
        if not parameters:
            raise ValueError("Must provide at least one iterable.")

        # Validate all are iterables
        for k in parameters:
            if not hasattr(params[k], "__iter__"):
                raise ValueError(f"Expected '{k}' to be an iterable.")

        first = parameters[0]
        length = len(params[first])

        # Assure all of equal length
        for k in parameters[1:]:
            if len(params[k]) != length:
                raise ValueError(
                    f"Expected all iterables to be of length {length} like "
                    f"'{first}'. '{k}' is length {len(params[k])}."
                )

        # Stepping by `size` covers the full batches and the trailing remainder
        # in one pass, including inputs shorter than a single batch.
        return [
            {p: params[p][start : start + self.size] for p in parameters}
            for start in range(0, length, self.size)
        ]

    def map(self, *args, **kwds) -> "list[PrefectFuture]":
        """Perform a `Task.map` operation in batches of the keyword arguments. The
        arguments must be iterables of equal length.

        Parameters
        ----------
        *args
            Positional arguments to pass to the task.
        **kwds
            Keyword arguments to pass to the task.

        Returns
        -------
        list[PrefectFuture]
            The futures from every batch, in submission order.
        """
        # Bind positional and keyword arguments to the task function's parameter
        # names so that batching can work purely on keywords.
        parameters = get_call_parameters(self.task.fn, args, kwds, apply_defaults=False)
        batches = self._make_batches(**parameters)

        return self._map(batches)

    def _map(self, batches: "list[dict[str, list[Any]]]") -> "list[PrefectFuture]":
        """Applies `Task.map` to each batch, one batch at a time.

        Every batch except the last is polled until all of its futures are in a
        terminal state before the next batch is submitted. The last batch is
        submitted and returned without waiting.

        Parameters
        ----------
        batches : list[dict[str, list[Any]]]
            Keyword arguments for each `Task.map` call, as produced by
            `_make_batches`.

        Returns
        -------
        list[PrefectFuture]
            The futures from every batch, in submission order. Empty if there
            are no batches.
        """
        logger = logging.get_prefect_or_default_logger()
        results: list[PrefectFuture] = []
        total = len(batches)

        for number, batch in enumerate(batches, start=1):
            logger.debug(f"Mapping {self.task.name} batch {number} of {total}.")
            # Map the batch
            futures = self.task.map(**batch)
            results.extend(futures)

            if number == total:
                # Nothing is queued behind the last batch, so do not block on it.
                break

            # Poll futures until none of them are active.
            # NOTE(review): this is a tight polling loop with no delay between
            # rounds; consider a short sleep if `get_state` is costly.
            while not all(states.is_terminal(f.get_state()) for f in futures):
                pass

        return results
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
""" | ||
Prefect logging utilities. | ||
""" | ||
import logging | ||
|
||
from prefect.logging import get_run_logger | ||
|
||
|
||
def get_prefect_or_default_logger( | ||
__default: logging.Logger | str | None = None, | ||
) -> logging.Logger: | ||
"""Gets the Prefect logger if the global context is set. Returns the `__default` or | ||
root logger if not. | ||
""" | ||
# Type check the default logger | ||
if not isinstance(__default, (logging.Logger, str, type(None))): | ||
raise TypeError( | ||
f"Expected `__default` to be a `logging.Logger`, `str`, or `None`, " | ||
f"got `{type(__default).__name__}`." | ||
) | ||
try: | ||
return get_run_logger() | ||
except RuntimeError: | ||
if isinstance(__default, str): | ||
return logging.getLogger(__default) | ||
return __default or logging.getLogger() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
""" | ||
Tools to improve Prefect states. | ||
""" | ||
from prefect import states | ||
|
||
|
||
def is_terminal(state: states.State) -> bool:
    """Report whether `state` is terminal.

    A state is treated as terminal when it is any of:
    - Cancelled
    - Completed
    - Crashed
    - Failed
    """
    checks = (
        state.is_cancelled,
        state.is_completed,
        state.is_crashed,
        state.is_failed,
    )
    # Checks run in the order listed and stop at the first one that passes.
    return any(check() for check in checks)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
""" | ||
Unit tests for the `concurrency` module. | ||
""" | ||
from prefect import flow, task | ||
|
||
from prefecto.concurrency import BatchTask | ||
|
||
|
||
@task
def add(a, b):
    """Add two numbers."""
    # Minimal task used as the mapping target in the tests below.
    total = a + b
    return total
|
||
|
||
class TestBatchTask:
    """Unit tests for `BatchTask`."""

    def test_make_batches(self):
        """`_make_batches` splits equal-length keyword iterables into batches."""
        batcher = BatchTask(add, 3)
        expected = [
            {"a": [1, 2, 3], "b": [2, 3, 4]},
            {"a": [4, 5], "b": [5, 6]},
        ]
        actual = batcher._make_batches(a=[1, 2, 3, 4, 5], b=[2, 3, 4, 5, 6])
        assert actual == expected

    def test_map(self, harness):
        """`BatchTask.map` returns one result per input pair, in order."""
        # `harness` is a fixture defined outside this file; presumably it sets
        # up a Prefect test environment — verify in conftest.

        @task
        def realize(futures: list[int]):
            """Converts futures to their values."""
            return futures

        @flow
        def test() -> list[int]:
            """Test flow."""
            batcher = BatchTask(add, 3)
            futures = batcher.map([1, 2, 3, 4, 5], [2, 3, 4, 5, 6])
            return realize(futures)

        expected = [3, 5, 7, 9, 11]
        result = test()
        assert result == expected
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
""" | ||
Tests for the logging module. | ||
""" | ||
import logging | ||
|
||
from prefecto.logging import get_prefect_or_default_logger | ||
|
||
|
||
def test_get_prefect_or_default_logger():
    """Tests `get_prefect_or_default_logger`.

    The assertions on the returned fallback assume the test runs outside a
    Prefect run context, so the fallback path is the one being exercised.
    """
    # No default: falls back to the root logger.
    assert get_prefect_or_default_logger().__class__ == logging.RootLogger

    # Sanity check on the stdlib: named loggers are plain `Logger` instances.
    assert logging.getLogger("not root").__class__ == logging.Logger

    # A `str` default is resolved by name through `logging.getLogger`.
    named = get_prefect_or_default_logger("not root")
    assert named is logging.getLogger("not root")
    assert named.__class__ == logging.Logger

    # A `Logger` default is handed back unchanged.
    custom = logging.Logger("not root")
    returned = get_prefect_or_default_logger(custom)
    assert returned.__class__ == logging.Logger
    assert returned is custom

    # Anything else is rejected with a `TypeError`.
    try:
        get_prefect_or_default_logger(42)
    except TypeError:
        pass
    else:
        raise AssertionError("Expected a TypeError for an unsupported default.")