Skip to content

Commit

Permalink
Add generic typing to multicore_utils.py
Browse files Browse the repository at this point in the history
  • Loading branch information
bryant1410 authored Nov 19, 2024
1 parent 05f9756 commit 8b039ea
Showing 1 changed file with 18 additions and 14 deletions.
32 changes: 18 additions & 14 deletions metaflow/multicore_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import time
import metaflow.tracing as tracing

from typing import Any, Callable, Iterable, Iterator, List, Optional
from typing import Any, Callable, Iterable, Iterator, List, Optional, Tuple, TypeVar

try:
# Python 2
Expand All @@ -30,7 +30,11 @@ class MulticoreException(Exception):
pass


def _spawn(func, arg, dir):
_A = TypeVar("_A")
_R = TypeVar("_R")


def _spawn(func: Callable[[_A], _R], arg: _A, dir: Optional[str]) -> Tuple[int, str]:
with NamedTemporaryFile(prefix="parallel_map_", dir=dir, delete=False) as tmpfile:
output_file = tmpfile.name

Expand Down Expand Up @@ -63,11 +67,11 @@ def _spawn(func, arg, dir):


def parallel_imap_unordered(
func: Callable[[Any], Any],
iterable: Iterable[Any],
func: Callable[[_A], _R],
iterable: Iterable[_A],
max_parallel: Optional[int] = None,
dir: Optional[str] = None,
) -> Iterator[Any]:
) -> Iterator[_R]:
"""
Parallelizes execution of a function using multiprocessing. The result
order is not guaranteed.
Expand All @@ -79,9 +83,9 @@ def parallel_imap_unordered(
iterable : Iterable[Any]
Iterable over arguments to pass to func
max_parallel : int, optional, default None
Maximum parallelism. If not specified, uses the number of CPUs
Maximum parallelism. If not specified, it uses the number of CPUs
dir : str, optional, default None
If specified, directory where temporary files are created
If specified, it's the directory where temporary files are created
Yields
------
Expand Down Expand Up @@ -121,14 +125,14 @@ def parallel_imap_unordered(


def parallel_map(
func: Callable[[Any], Any],
iterable: Iterable[Any],
func: Callable[[_A], _R],
iterable: Iterable[_A],
max_parallel: Optional[int] = None,
dir: Optional[str] = None,
) -> List[Any]:
) -> List[_R]:
"""
Parallelizes execution of a function using multiprocessing. The result
order is that of the arguments in `iterable`
order is that of the arguments in `iterable`.
Parameters
----------
Expand All @@ -137,9 +141,9 @@ def parallel_map(
iterable : Iterable[Any]
Iterable over arguments to pass to func
max_parallel : int, optional, default None
Maximum parallelism. If not specified, uses the number of CPUs
Maximum parallelism. If not specified, it uses the number of CPUs
dir : str, optional, default None
If specified, directory where temporary files are created
If specified, it's the directory where temporary files are created
Returns
-------
Expand All @@ -155,4 +159,4 @@ def wrapper(arg_with_idx):
res = parallel_imap_unordered(
wrapper, enumerate(iterable), max_parallel=max_parallel, dir=dir
)
return [r for idx, r in sorted(res)]
return [r for _, r in sorted(res)]

0 comments on commit 8b039ea

Please sign in to comment.