Skip to content

Commit

Permalink
Move parallel module outside utils
Browse files Browse the repository at this point in the history
  • Loading branch information
mdbenito committed Sep 3, 2023
1 parent e5c117a commit 433e004
Show file tree
Hide file tree
Showing 28 changed files with 132 additions and 98 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
This module provides a common interface to parallelization backends. The list of
supported backends is [here][pydvl.utils.parallel.backends]. Backends can be
supported backends is [here][pydvl.parallel.backends]. Backends can be
selected with the `backend` argument of an instance of
[ParallelConfig][pydvl.utils.config.ParallelConfig], as seen in the examples
below.
Expand All @@ -9,8 +9,7 @@
basic high-level pattern is
```python
from pydvl.utils.parallel import init_executo
from pydvl.utils.config import ParallelConfig
from pydvl.parallel import init_executor, ParallelConfig
config = ParallelConfig(backend="ray")
with init_executor(max_workers=1, config=config) as executor:
Expand All @@ -22,8 +21,7 @@
Running a map-reduce job is also easy:
```python
from pydvl.utils.parallel import init_executor
from pydvl.utils.config import ParallelConfig
from pydvl.parallel import init_executor, ParallelConfig
config = ParallelConfig(backend="joblib")
with init_executor(config=config) as executor:
Expand All @@ -32,11 +30,14 @@
```
There is an alternative map-reduce implementation
[MapReduceJob][pydvl.utils.parallel.map_reduce.MapReduceJob] which internally
[MapReduceJob][pydvl.parallel.map_reduce.MapReduceJob] which internally
uses joblib's higher level API with `Parallel()`
"""
# HACK to avoid circular imports
from ..utils.types import * # pylint: disable=wrong-import-order
from .backend import *
from .backends import *
from .config import *
from .futures import *
from .map_reduce import *

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
from abc import abstractmethod
from concurrent.futures import Executor
from enum import Flag, auto
from typing import Any, Callable, Type, TypeVar
from typing import Any, Callable, Type

from ..config import ParallelConfig
from ..types import NoPublicConstructor
from ..utils.types import NoPublicConstructor
from .config import ParallelConfig

__all__ = [
"init_parallel_backend",
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,25 +1,28 @@
from __future__ import annotations

import logging
from concurrent.futures import Executor
from typing import Callable, TypeVar, cast

import joblib
from joblib import delayed
from joblib.externals.loky import get_reusable_executor

from pydvl.utils import ParallelConfig
from pydvl.utils.parallel.backend import BaseParallelBackend, CancellationPolicy, log
from pydvl.parallel.backend import BaseParallelBackend, CancellationPolicy
from pydvl.parallel.config import ParallelConfig

__all__ = ["JoblibParallelBackend"]

T = TypeVar("T")

log = logging.getLogger(__name__)


class JoblibParallelBackend(BaseParallelBackend, backend_name="joblib"):
"""Class used to wrap joblib to make it transparent to algorithms.
It shouldn't be initialized directly. You should instead call
[init_parallel_backend()][pydvl.utils.parallel.backend.init_parallel_backend].
[init_parallel_backend()][pydvl.parallel.backend.init_parallel_backend].
Args:
config: instance of [ParallelConfig][pydvl.utils.config.ParallelConfig]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from ray import ObjectRef
from ray.util.joblib import register_ray

from pydvl.utils import ParallelConfig
from pydvl.utils.parallel.backend import BaseParallelBackend, CancellationPolicy
from pydvl.parallel.backend import BaseParallelBackend, CancellationPolicy
from pydvl.parallel.config import ParallelConfig

__all__ = ["RayParallelBackend"]

Expand All @@ -20,7 +20,7 @@ class RayParallelBackend(BaseParallelBackend, backend_name="ray"):
"""Class used to wrap ray to make it transparent to algorithms.
It shouldn't be initialized directly. You should instead call
[init_parallel_backend()][pydvl.utils.parallel.backend.init_parallel_backend].
[init_parallel_backend()][pydvl.parallel.backend.init_parallel_backend].
Args:
config: instance of [ParallelConfig][pydvl.utils.config.ParallelConfig]
Expand All @@ -43,7 +43,7 @@ def executor(
config: ParallelConfig = ParallelConfig(),
cancel_futures: CancellationPolicy = CancellationPolicy.PENDING,
) -> Executor:
from pydvl.utils.parallel.futures.ray import RayExecutor
from pydvl.parallel.futures.ray import RayExecutor

return RayExecutor(max_workers, config=config, cancel_futures=cancel_futures) # type: ignore

Expand Down
30 changes: 30 additions & 0 deletions src/pydvl/parallel/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import logging
from dataclasses import dataclass
from typing import Literal, Optional, Tuple, Union

__all__ = ["ParallelConfig"]


@dataclass(frozen=True)
class ParallelConfig:
"""Configuration for parallel computation backend.
Args:
backend: Type of backend to use. Defaults to 'joblib'
address: Address of existing remote or local cluster to use.
n_cpus_local: Number of CPUs to use when creating a local ray cluster.
This has no effect when using an existing ray cluster.
logging_level: Logging level for the parallel backend's worker.
wait_timeout: Timeout in seconds for waiting on futures.
"""

backend: Literal["joblib", "ray"] = "joblib"
address: Optional[Union[str, Tuple[str, int]]] = None
n_cpus_local: Optional[int] = None
logging_level: int = logging.WARNING
wait_timeout: float = 1.0

def __post_init__(self) -> None:
# FIXME: this is specific to ray
if self.address is not None and self.n_cpus_local is not None:
raise ValueError("When `address` is set, `n_cpus_local` should be None.")
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
from contextlib import contextmanager
from typing import Generator, Optional

from pydvl.utils.config import ParallelConfig
from pydvl.utils.parallel.backend import BaseParallelBackend
from pydvl.parallel.backend import BaseParallelBackend
from pydvl.parallel.config import ParallelConfig

try:
from pydvl.utils.parallel.futures.ray import RayExecutor
from pydvl.parallel.futures.ray import RayExecutor
except ImportError:
pass

Expand All @@ -30,16 +30,16 @@ def init_executor(
??? Examples
``` python
from pydvl.utils.parallel.futures import init_executor
from pydvl.utils.config import ParallelConfig
from pydvl.parallel import init_executor, ParallelConfig
config = ParallelConfig(backend="ray")
with init_executor(max_workers=1, config=config) as executor:
future = executor.submit(lambda x: x + 1, 1)
result = future.result()
assert result == 2
```
``` python
from pydvl.utils.parallel.futures import init_executor
from pydvl.parallel.futures import init_executor
with init_executor() as executor:
results = list(executor.map(lambda x: x + 1, range(5)))
assert results == [1, 2, 3, 4, 5]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
import ray
from deprecate import deprecated

from pydvl.utils import ParallelConfig
from pydvl.parallel.config import ParallelConfig

__all__ = ["RayExecutor"]

from pydvl.utils.parallel import CancellationPolicy
from pydvl.parallel import CancellationPolicy

T = TypeVar("T")

Expand All @@ -26,7 +26,7 @@ class RayExecutor(Executor):
"""Asynchronous executor using Ray that implements the concurrent.futures API.
It shouldn't be initialized directly. You should instead call
[init_executor()][pydvl.utils.parallel.futures.init_executor].
[init_executor()][pydvl.parallel.futures.init_executor].
Args:
max_workers: Maximum number of concurrent tasks. Each task can request
Expand All @@ -41,7 +41,7 @@ class RayExecutor(Executor):
pending futures, but not running ones, as done by
[concurrent.futures.ProcessPoolExecutor][]. Additionally, `All`
cancels all pending and running futures, and `None` doesn't cancel
any. See [CancellationPolicy][pydvl.utils.parallel.backend.CancellationPolicy]
any. See [CancellationPolicy][pydvl.parallel.backend.CancellationPolicy]
"""

@deprecated(
Expand Down Expand Up @@ -145,7 +145,7 @@ def shutdown(
This method tries to mimic the behaviour of
[Executor.shutdown][concurrent.futures.Executor.shutdown]
while allowing one more value for ``cancel_futures`` which instructs it
to use the [CancellationPolicy][pydvl.utils.parallel.backend.CancellationPolicy]
to use the [CancellationPolicy][pydvl.parallel.backend.CancellationPolicy]
defined upon construction.
Args:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
from numpy.random import SeedSequence
from numpy.typing import NDArray

from ..config import ParallelConfig
from ..functional import maybe_add_argument
from ..types import MapFunction, ReduceFunction, Seed, ensure_seed_sequence
from ..utils.functional import maybe_add_argument
from ..utils.types import MapFunction, ReduceFunction, Seed, ensure_seed_sequence
from .backend import init_parallel_backend
from .config import ParallelConfig

__all__ = ["MapReduceJob"]

Expand Down Expand Up @@ -54,7 +54,7 @@ class MapReduceJob(Generic[T, R]):
A simple usage example with 2 jobs:
``` pycon
>>> from pydvl.utils.parallel import MapReduceJob
>>> from pydvl.parallel import MapReduceJob
>>> import numpy as np
>>> map_reduce_job: MapReduceJob[np.ndarray, np.ndarray] = MapReduceJob(
... np.arange(5),
Expand All @@ -68,7 +68,7 @@ class MapReduceJob(Generic[T, R]):
When passed a single object as input, it will be repeated for each job:
``` pycon
>>> from pydvl.utils.parallel import MapReduceJob
>>> from pydvl.parallel import MapReduceJob
>>> import numpy as np
>>> map_reduce_job: MapReduceJob[int, np.ndarray] = MapReduceJob(
... 5,
Expand Down
31 changes: 4 additions & 27 deletions src/pydvl/utils/config.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,14 @@
import logging
from dataclasses import dataclass, field
from typing import Iterable, Literal, Optional, Tuple, Union
from typing import Iterable, Optional, Tuple

from pymemcache.serde import PickleSerde

PICKLE_VERSION = 5 # python >= 3.8

__all__ = ["ParallelConfig", "MemcachedClientConfig", "MemcachedConfig"]


@dataclass(frozen=True)
class ParallelConfig:
"""Configuration for parallel computation backend.
from pydvl.parallel.config import ParallelConfig

Args:
backend: Type of backend to use. Defaults to 'joblib'
address: Address of existing remote or local cluster to use.
n_cpus_local: Number of CPUs to use when creating a local ray cluster.
This has no effect when using an existing ray cluster.
logging_level: Logging level for the parallel backend's worker.
wait_timeout: Timeout in seconds for waiting on futures.
"""
PICKLE_VERSION = 5 # python >= 3.8

backend: Literal["joblib", "ray"] = "joblib"
address: Optional[Union[str, Tuple[str, int]]] = None
n_cpus_local: Optional[int] = None
logging_level: int = logging.WARNING
wait_timeout: float = 1.0

def __post_init__(self) -> None:
# FIXME: this is specific to ray
if self.address is not None and self.n_cpus_local is not None:
raise ValueError("When `address` is set, `n_cpus_local` should be None.")
__all__ = ["MemcachedClientConfig", "MemcachedConfig", "ParallelConfig"]


@dataclass(frozen=True)
Expand Down
22 changes: 22 additions & 0 deletions src/pydvl/utils/parallel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""
# This module is deprecated
!!! warning "Redirects"
Imports from this module will be redirected to
[pydvl.parallel][pydvl.parallel] only until v0.9.0. Please update your
imports.
"""
import logging

from ..parallel.backend import *
from ..parallel.config import *
from ..parallel.futures import *
from ..parallel.map_reduce import *

log = logging.getLogger(__name__)

log.warning(
"Importing parallel tools from pydvl.utils is deprecated. "
"Please import directly from pydvl.parallel. "
"Redirected imports will be removed in v0.9.0"
)
3 changes: 2 additions & 1 deletion src/pydvl/value/least_core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import numpy as np
from numpy.typing import NDArray

from pydvl.utils import MapReduceJob, ParallelConfig, Status, Utility
from pydvl.parallel import MapReduceJob, ParallelConfig
from pydvl.utils import Status, Utility
from pydvl.value import ValuationResult

__all__ = [
Expand Down
4 changes: 1 addition & 3 deletions src/pydvl/value/least_core/montecarlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@
import numpy as np
from numpy.typing import NDArray

from pydvl.utils.config import ParallelConfig
from pydvl.parallel import MapReduceJob, ParallelConfig, effective_n_jobs
from pydvl.utils.numeric import random_powerset
from pydvl.utils.parallel import MapReduceJob
from pydvl.utils.parallel.backend import effective_n_jobs
from pydvl.utils.progress import maybe_progress
from pydvl.utils.utility import Utility
from pydvl.value.least_core.common import LeastCoreProblem, lc_solve_problem
Expand Down
3 changes: 2 additions & 1 deletion src/pydvl/value/loo/loo.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

from tqdm import tqdm

from pydvl.utils import ParallelConfig, Utility, effective_n_jobs, init_executor
from pydvl.parallel import ParallelConfig, effective_n_jobs, init_executor
from pydvl.utils import Utility
from pydvl.value.result import ValuationResult

__all__ = ["compute_loo"]
Expand Down
3 changes: 2 additions & 1 deletion src/pydvl/value/semivalues.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@
from deprecate import deprecated
from tqdm import tqdm

from pydvl.utils import ParallelConfig, Utility
from pydvl.parallel.config import ParallelConfig
from pydvl.utils import Utility
from pydvl.utils.types import Seed
from pydvl.value import ValuationResult
from pydvl.value.sampler import (
Expand Down
6 changes: 3 additions & 3 deletions src/pydvl/value/shapley/gt.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
## References
[^1]: <a name="jia_efficient_2019"></a>Jia, R. et al., 2019.
[Towards Efficient Data Valuation Based on the Shapley Value](http://proceedings.mlr.press/v89/jia19a.html).
[Towards Efficient Data Valuation Based on the Shapley Value](https://proceedings.mlr.press/v89/jia19a.html).
In: Proceedings of the 22nd International Conference on Artificial Intelligence and Statistics, pp. 1167–1176. PMLR.
"""
import logging
Expand All @@ -29,9 +29,9 @@
from numpy.random import SeedSequence
from numpy.typing import NDArray

from pydvl.utils import MapReduceJob, ParallelConfig, Utility, maybe_progress
from pydvl.parallel import MapReduceJob, ParallelConfig, effective_n_jobs
from pydvl.utils import Utility, maybe_progress
from pydvl.utils.numeric import random_subset_of_size
from pydvl.utils.parallel.backend import effective_n_jobs
from pydvl.utils.status import Status
from pydvl.utils.types import Seed, ensure_seed_sequence
from pydvl.value import ValuationResult
Expand Down
Loading

0 comments on commit 433e004

Please sign in to comment.