From 8b039ea66c873bcf5a004b74e161294d26d409fb Mon Sep 17 00:00:00 2001 From: Santiago Castro Date: Tue, 19 Nov 2024 15:15:10 -0800 Subject: [PATCH] Add generic typing to multicore_utils.py --- metaflow/multicore_utils.py | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/metaflow/multicore_utils.py b/metaflow/multicore_utils.py index 202783aceca..28e14df6568 100644 --- a/metaflow/multicore_utils.py +++ b/metaflow/multicore_utils.py @@ -6,7 +6,7 @@ import time import metaflow.tracing as tracing -from typing import Any, Callable, Iterable, Iterator, List, Optional +from typing import Any, Callable, Iterable, Iterator, List, Optional, Tuple, TypeVar try: # Python 2 @@ -30,7 +30,11 @@ class MulticoreException(Exception): pass -def _spawn(func, arg, dir): +_A = TypeVar("_A") +_R = TypeVar("_R") + + +def _spawn(func: Callable[[_A], _R], arg: _A, dir: Optional[str]) -> Tuple[int, str]: with NamedTemporaryFile(prefix="parallel_map_", dir=dir, delete=False) as tmpfile: output_file = tmpfile.name @@ -63,11 +67,11 @@ def _spawn(func, arg, dir): def parallel_imap_unordered( - func: Callable[[Any], Any], - iterable: Iterable[Any], + func: Callable[[_A], _R], + iterable: Iterable[_A], max_parallel: Optional[int] = None, dir: Optional[str] = None, -) -> Iterator[Any]: +) -> Iterator[_R]: """ Parallelizes execution of a function using multiprocessing. The result order is not guaranteed. @@ -79,9 +83,9 @@ def parallel_imap_unordered( iterable : Iterable[Any] Iterable over arguments to pass to fun max_parallel int, optional, default None - Maximum parallelism. If not specified, uses the number of CPUs + Maximum parallelism. If not specified, it uses the number of CPUs dir : str, optional, default None - If specified, directory where temporary files are created + If specified, it's the directory where temporary files are created Yields ------ @@ -121,14 +125,14 @@ def parallel_imap_unordered( def parallel_map( - func: Callable[[Any], Any], - iterable: Iterable[Any], + func: Callable[[_A], _R], + iterable: Iterable[_A], max_parallel: Optional[int] = None, dir: Optional[str] = None, -) -> List[Any]: +) -> List[_R]: """ Parallelizes execution of a function using multiprocessing. The result - order is that of the arguments in `iterable` + order is that of the arguments in `iterable`. Parameters ---------- @@ -137,9 +141,9 @@ def parallel_map( iterable : Iterable[Any] Iterable over arguments to pass to fun max_parallel int, optional, default None - Maximum parallelism. If not specified, uses the number of CPUs + Maximum parallelism. If not specified, it uses the number of CPUs dir : str, optional, default None - If specified, directory where temporary files are created + If specified, it's the directory where temporary files are created Returns ------- @@ -155,4 +159,4 @@ def wrapper(arg_with_idx): res = parallel_imap_unordered( wrapper, enumerate(iterable), max_parallel=max_parallel, dir=dir ) - return [r for idx, r in sorted(res)] + return [r for _, r in sorted(res)]