Skip to content

Commit

Permalink
Add generic typing to multicore_utils.py (#2147)
Browse files Browse the repository at this point in the history
  • Loading branch information
bryant1410 authored Dec 4, 2024
1 parent 2efd754 commit c7cb8b5
Showing 1 changed file with 31 additions and 14 deletions.
45 changes: 31 additions & 14 deletions metaflow/multicore_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,18 @@
import time
import metaflow.tracing as tracing

from typing import Any, Callable, Iterable, Iterator, List, Optional
from typing import (
Any,
Callable,
Iterable,
Iterator,
List,
Optional,
NoReturn,
Tuple,
TypeVar,
Union,
)

try:
# Python 2
Expand All @@ -30,7 +41,13 @@ class MulticoreException(Exception):
pass


def _spawn(func, arg, dir):
_A = TypeVar("_A")
_R = TypeVar("_R")


def _spawn(
func: Callable[[_A], _R], arg: _A, dir: Optional[str]
) -> Union[Tuple[int, str], NoReturn]:
with NamedTemporaryFile(prefix="parallel_map_", dir=dir, delete=False) as tmpfile:
output_file = tmpfile.name

Expand Down Expand Up @@ -63,11 +80,11 @@ def _spawn(func, arg, dir):


def parallel_imap_unordered(
func: Callable[[Any], Any],
iterable: Iterable[Any],
func: Callable[[_A], _R],
iterable: Iterable[_A],
max_parallel: Optional[int] = None,
dir: Optional[str] = None,
) -> Iterator[Any]:
) -> Iterator[_R]:
"""
Parallelizes execution of a function using multiprocessing. The result
order is not guaranteed.
Expand All @@ -79,9 +96,9 @@ def parallel_imap_unordered(
iterable : Iterable[Any]
Iterable over arguments to pass to func
max_parallel : int, optional, default None
Maximum parallelism. If not specified, uses the number of CPUs
Maximum parallelism. If not specified, it uses the number of CPUs
dir : str, optional, default None
If specified, directory where temporary files are created
If specified, it's the directory where temporary files are created
Yields
------
Expand Down Expand Up @@ -121,14 +138,14 @@ def parallel_imap_unordered(


def parallel_map(
func: Callable[[Any], Any],
iterable: Iterable[Any],
func: Callable[[_A], _R],
iterable: Iterable[_A],
max_parallel: Optional[int] = None,
dir: Optional[str] = None,
) -> List[Any]:
) -> List[_R]:
"""
Parallelizes execution of a function using multiprocessing. The result
order is that of the arguments in `iterable`
order is that of the arguments in `iterable`.
Parameters
----------
Expand All @@ -137,9 +154,9 @@ def parallel_map(
iterable : Iterable[Any]
Iterable over arguments to pass to func
max_parallel : int, optional, default None
Maximum parallelism. If not specified, uses the number of CPUs
Maximum parallelism. If not specified, it uses the number of CPUs
dir : str, optional, default None
If specified, directory where temporary files are created
If specified, it's the directory where temporary files are created
Returns
-------
Expand All @@ -155,4 +172,4 @@ def wrapper(arg_with_idx):
res = parallel_imap_unordered(
wrapper, enumerate(iterable), max_parallel=max_parallel, dir=dir
)
return [r for idx, r in sorted(res)]
return [r for _, r in sorted(res)]

0 comments on commit c7cb8b5

Please sign in to comment.