diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml
index bb8f5169e..1a5930308 100644
--- a/.github/workflows/main.yaml
+++ b/.github/workflows/main.yaml
@@ -1,4 +1,4 @@
-name: Run tests, build docs, publish to TestPyPI
+name: Tests, docs, package
 
 on:
   push:
diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml
index 433da896d..bd15a7663 100644
--- a/.github/workflows/publish.yaml
+++ b/.github/workflows/publish.yaml
@@ -16,6 +16,8 @@ on:
         required: true
 
 env:
+  GITHUB_BOT_USERNAME: github-actions[bot]
+  GITHUB_BOT_EMAIL: 41898282+github-actions[bot]@users.noreply.github.com
   PY_COLORS: 1
 
 jobs:
diff --git a/CHANGELOG.md b/CHANGELOG.md
index bc82e515b..499931010 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,10 @@
 # Changelog
 
+## Unreleased
+
+- Refactoring of parallel module. Old imports will stop working in v0.9.0
+  [PR #421](https://github.com/aai-institute/pyDVL/pull/421)
+
 ## 0.7.0 - 📚🆕 Documentation and IF overhaul, new methods and bug fixes 💥🐞
 
 This is our first β release! We have worked hard to deliver improvements across
diff --git a/README.md b/README.md
index 3950e2272..25ebe6c7f 100644
--- a/README.md
+++ b/README.md
@@ -82,6 +82,9 @@ model. We implement methods from the following papers:
 - Naman Agarwal, Brian Bullins, and Elad Hazan,
   [Second-Order Stochastic Optimization for Machine Learning in Linear Time](https://www.jmlr.org/papers/v18/16-491.html),
   Journal of Machine Learning Research 18 (2017): 1-40.
+- Schioppa, Andrea, Polina Zablotskaia, David Vilar, and Artem Sokolov.
+  [Scaling Up Influence Functions](http://arxiv.org/abs/2112.03052).
+  In Proceedings of the AAAI-22. arXiv, 2021.
 
 # Installation
diff --git a/src/pydvl/utils/parallel/__init__.py b/src/pydvl/parallel/__init__.py
similarity index 74%
rename from src/pydvl/utils/parallel/__init__.py
rename to src/pydvl/parallel/__init__.py
index 76319b197..c8feca005 100644
--- a/src/pydvl/utils/parallel/__init__.py
+++ b/src/pydvl/parallel/__init__.py
@@ -1,6 +1,6 @@
 """
 This module provides a common interface to parallelization backends. The list of
-supported backends is [here][pydvl.utils.parallel.backends]. Backends can be
+supported backends is [here][pydvl.parallel.backends]. Backends can be
 selected with the `backend` argument of an instance of
 [ParallelConfig][pydvl.utils.config.ParallelConfig], as seen in the examples
 below.
@@ -9,8 +9,7 @@
 basic high-level pattern is
 
 ```python
-from pydvl.utils.parallel import init_executor
-from pydvl.utils.config import ParallelConfig
+from pydvl.parallel import init_executor, ParallelConfig
 
 config = ParallelConfig(backend="ray")
 with init_executor(max_workers=1, config=config) as executor:
@@ -22,8 +21,7 @@
 Running a map-reduce job is also easy:
 
 ```python
-from pydvl.utils.parallel import init_executor
-from pydvl.utils.config import ParallelConfig
+from pydvl.parallel import init_executor, ParallelConfig
 
 config = ParallelConfig(backend="joblib")
 with init_executor(config=config) as executor:
@@ -32,11 +30,14 @@
 ```
 
 There is an alternative map-reduce implementation
-[MapReduceJob][pydvl.utils.parallel.map_reduce.MapReduceJob] which internally
+[MapReduceJob][pydvl.parallel.map_reduce.MapReduceJob] which internally
 uses joblib's higher level API with `Parallel()`
 """
 
+# HACK to avoid circular imports
+from ..utils.types import *  # pylint: disable=wrong-import-order
 from .backend import *
 from .backends import *
+from .config import *
 from .futures import *
 from .map_reduce import *
diff --git a/src/pydvl/utils/parallel/backend.py b/src/pydvl/parallel/backend.py
similarity index 97%
rename from src/pydvl/utils/parallel/backend.py
rename to src/pydvl/parallel/backend.py
index 0191c7be6..84f885a7b 100644
--- a/src/pydvl/utils/parallel/backend.py
+++ b/src/pydvl/parallel/backend.py
@@ -5,10 +5,10 @@
 from abc import abstractmethod
 from concurrent.futures import Executor
 from enum import Flag, auto
-from typing import Any, Callable, Type, TypeVar
+from typing import Any, Callable, Type
 
-from ..config import ParallelConfig
-from ..types import NoPublicConstructor
+from ..utils.types import NoPublicConstructor
+from .config import ParallelConfig
 
 __all__ = [
     "init_parallel_backend",
diff --git a/src/pydvl/utils/parallel/backends/__init__.py b/src/pydvl/parallel/backends/__init__.py
similarity index 100%
rename from src/pydvl/utils/parallel/backends/__init__.py
rename to src/pydvl/parallel/backends/__init__.py
diff --git a/src/pydvl/utils/parallel/backends/joblib.py b/src/pydvl/parallel/backends/joblib.py
similarity index 89%
rename from src/pydvl/utils/parallel/backends/joblib.py
rename to src/pydvl/parallel/backends/joblib.py
index c75618fbf..c78ea8b90 100644
--- a/src/pydvl/utils/parallel/backends/joblib.py
+++ b/src/pydvl/parallel/backends/joblib.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import logging
 from concurrent.futures import Executor
 from typing import Callable, TypeVar, cast
 
@@ -7,19 +8,21 @@
 from joblib import delayed
 from joblib.externals.loky import get_reusable_executor
 
-from pydvl.utils import ParallelConfig
-from pydvl.utils.parallel.backend import BaseParallelBackend, CancellationPolicy, log
+from pydvl.parallel.backend import BaseParallelBackend, CancellationPolicy
+from pydvl.parallel.config import ParallelConfig
 
 __all__ = ["JoblibParallelBackend"]
 
 T = TypeVar("T")
 
+log = logging.getLogger(__name__)
+
 
 class JoblibParallelBackend(BaseParallelBackend, backend_name="joblib"):
     """Class used to wrap joblib to make it transparent to algorithms.
 
     It shouldn't be initialized directly. You should instead call
-    [init_parallel_backend()][pydvl.utils.parallel.backend.init_parallel_backend].
+    [init_parallel_backend()][pydvl.parallel.backend.init_parallel_backend].
 
     Args:
         config: instance of [ParallelConfig][pydvl.utils.config.ParallelConfig]
diff --git a/src/pydvl/utils/parallel/backends/ray.py b/src/pydvl/parallel/backends/ray.py
similarity index 91%
rename from src/pydvl/utils/parallel/backends/ray.py
rename to src/pydvl/parallel/backends/ray.py
index a0f0f6603..7e078fde1 100644
--- a/src/pydvl/utils/parallel/backends/ray.py
+++ b/src/pydvl/parallel/backends/ray.py
@@ -7,8 +7,8 @@
 from ray import ObjectRef
 from ray.util.joblib import register_ray
 
-from pydvl.utils import ParallelConfig
-from pydvl.utils.parallel.backend import BaseParallelBackend, CancellationPolicy
+from pydvl.parallel.backend import BaseParallelBackend, CancellationPolicy
+from pydvl.parallel.config import ParallelConfig
 
 __all__ = ["RayParallelBackend"]
 
@@ -20,7 +20,7 @@ class RayParallelBackend(BaseParallelBackend, backend_name="ray"):
     """Class used to wrap ray to make it transparent to algorithms.
 
     It shouldn't be initialized directly. You should instead call
-    [init_parallel_backend()][pydvl.utils.parallel.backend.init_parallel_backend].
+    [init_parallel_backend()][pydvl.parallel.backend.init_parallel_backend].
 
     Args:
         config: instance of [ParallelConfig][pydvl.utils.config.ParallelConfig]
@@ -43,7 +43,7 @@ def executor(
         config: ParallelConfig = ParallelConfig(),
         cancel_futures: CancellationPolicy = CancellationPolicy.PENDING,
     ) -> Executor:
-        from pydvl.utils.parallel.futures.ray import RayExecutor
+        from pydvl.parallel.futures.ray import RayExecutor
 
         return RayExecutor(max_workers, config=config, cancel_futures=cancel_futures)  # type: ignore
diff --git a/src/pydvl/parallel/config.py b/src/pydvl/parallel/config.py
new file mode 100644
index 000000000..362fde210
--- /dev/null
+++ b/src/pydvl/parallel/config.py
@@ -0,0 +1,30 @@
+import logging
+from dataclasses import dataclass
+from typing import Literal, Optional, Tuple, Union
+
+__all__ = ["ParallelConfig"]
+
+
+@dataclass(frozen=True)
+class ParallelConfig:
+    """Configuration for parallel computation backend.
+
+    Args:
+        backend: Type of backend to use. Defaults to 'joblib'
+        address: Address of existing remote or local cluster to use.
+        n_cpus_local: Number of CPUs to use when creating a local ray cluster.
+            This has no effect when using an existing ray cluster.
+        logging_level: Logging level for the parallel backend's worker.
+        wait_timeout: Timeout in seconds for waiting on futures.
+ """ + + backend: Literal["joblib", "ray"] = "joblib" + address: Optional[Union[str, Tuple[str, int]]] = None + n_cpus_local: Optional[int] = None + logging_level: int = logging.WARNING + wait_timeout: float = 1.0 + + def __post_init__(self) -> None: + # FIXME: this is specific to ray + if self.address is not None and self.n_cpus_local is not None: + raise ValueError("When `address` is set, `n_cpus_local` should be None.") diff --git a/src/pydvl/utils/parallel/futures/__init__.py b/src/pydvl/parallel/futures/__init__.py similarity index 80% rename from src/pydvl/utils/parallel/futures/__init__.py rename to src/pydvl/parallel/futures/__init__.py index 937eb2a95..c42026ecf 100644 --- a/src/pydvl/utils/parallel/futures/__init__.py +++ b/src/pydvl/parallel/futures/__init__.py @@ -2,11 +2,11 @@ from contextlib import contextmanager from typing import Generator, Optional -from pydvl.utils.config import ParallelConfig -from pydvl.utils.parallel.backend import BaseParallelBackend +from pydvl.parallel.backend import BaseParallelBackend +from pydvl.parallel.config import ParallelConfig try: - from pydvl.utils.parallel.futures.ray import RayExecutor + from pydvl.parallel.futures.ray import RayExecutor except ImportError: pass @@ -30,8 +30,8 @@ def init_executor( ??? Examples ``` python - from pydvl.utils.parallel.futures import init_executor - from pydvl.utils.config import ParallelConfig + from pydvl.parallel import init_executor, ParallelConfig + config = ParallelConfig(backend="ray") with init_executor(max_workers=1, config=config) as executor: future = executor.submit(lambda x: x + 1, 1) @@ -39,7 +39,7 @@ def init_executor( assert result == 2 ``` ``` python - from pydvl.utils.parallel.futures import init_executor + from pydvl.parallel.futures import init_executor with init_executor() as executor: results = list(executor.map(lambda x: x + 1, range(5))) assert results == [1, 2, 3, 4, 5] diff --git a/src/pydvl/utils/parallel/futures/ray.py b/src/pydvl/parallel/futures/ray.py similarity index 97% rename from src/pydvl/utils/parallel/futures/ray.py rename to src/pydvl/parallel/futures/ray.py index 677320396..1a9658744 100644 --- a/src/pydvl/utils/parallel/futures/ray.py +++ b/src/pydvl/parallel/futures/ray.py @@ -11,11 +11,11 @@ import ray from deprecate import deprecated -from pydvl.utils import ParallelConfig +from pydvl.parallel.config import ParallelConfig __all__ = ["RayExecutor"] -from pydvl.utils.parallel import CancellationPolicy +from pydvl.parallel import CancellationPolicy T = TypeVar("T") @@ -26,7 +26,7 @@ class RayExecutor(Executor): """Asynchronous executor using Ray that implements the concurrent.futures API. It shouldn't be initialized directly. You should instead call - [init_executor()][pydvl.utils.parallel.futures.init_executor]. + [init_executor()][pydvl.parallel.futures.init_executor]. Args: max_workers: Maximum number of concurrent tasks. Each task can request @@ -41,7 +41,7 @@ class RayExecutor(Executor): pending futures, but not running ones, as done by [concurrent.futures.ProcessPoolExecutor][]. Additionally, `All` cancels all pending and running futures, and `None` doesn't cancel - any. See [CancellationPolicy][pydvl.utils.parallel.backend.CancellationPolicy] + any. 
See [CancellationPolicy][pydvl.parallel.backend.CancellationPolicy] """ @deprecated( @@ -145,7 +145,7 @@ def shutdown( This method tries to mimic the behaviour of [Executor.shutdown][concurrent.futures.Executor.shutdown] while allowing one more value for ``cancel_futures`` which instructs it - to use the [CancellationPolicy][pydvl.utils.parallel.backend.CancellationPolicy] + to use the [CancellationPolicy][pydvl.parallel.backend.CancellationPolicy] defined upon construction. Args: diff --git a/src/pydvl/utils/parallel/map_reduce.py b/src/pydvl/parallel/map_reduce.py similarity index 96% rename from src/pydvl/utils/parallel/map_reduce.py rename to src/pydvl/parallel/map_reduce.py index 149cd2752..e24711f24 100644 --- a/src/pydvl/utils/parallel/map_reduce.py +++ b/src/pydvl/parallel/map_reduce.py @@ -14,10 +14,10 @@ from numpy.random import SeedSequence from numpy.typing import NDArray -from ..config import ParallelConfig -from ..functional import maybe_add_argument -from ..types import MapFunction, ReduceFunction, Seed, ensure_seed_sequence +from ..utils.functional import maybe_add_argument +from ..utils.types import MapFunction, ReduceFunction, Seed, ensure_seed_sequence from .backend import init_parallel_backend +from .config import ParallelConfig __all__ = ["MapReduceJob"] @@ -54,7 +54,7 @@ class MapReduceJob(Generic[T, R]): A simple usage example with 2 jobs: ``` pycon - >>> from pydvl.utils.parallel import MapReduceJob + >>> from pydvl.parallel import MapReduceJob >>> import numpy as np >>> map_reduce_job: MapReduceJob[np.ndarray, np.ndarray] = MapReduceJob( ... np.arange(5), @@ -68,7 +68,7 @@ class MapReduceJob(Generic[T, R]): When passed a single object as input, it will be repeated for each job: ``` pycon - >>> from pydvl.utils.parallel import MapReduceJob + >>> from pydvl.parallel import MapReduceJob >>> import numpy as np >>> map_reduce_job: MapReduceJob[int, np.ndarray] = MapReduceJob( ... 5, diff --git a/src/pydvl/utils/config.py b/src/pydvl/utils/config.py index b5a4a6743..6e240bffc 100644 --- a/src/pydvl/utils/config.py +++ b/src/pydvl/utils/config.py @@ -1,37 +1,14 @@ -import logging from dataclasses import dataclass, field -from typing import Iterable, Literal, Optional, Tuple, Union +from typing import Iterable, Optional, Tuple from pymemcache.serde import PickleSerde -PICKLE_VERSION = 5 # python >= 3.8 - -__all__ = ["ParallelConfig", "MemcachedClientConfig", "MemcachedConfig"] - - -@dataclass(frozen=True) -class ParallelConfig: - """Configuration for parallel computation backend. +from pydvl.parallel.config import ParallelConfig - Args: - backend: Type of backend to use. Defaults to 'joblib' - address: Address of existing remote or local cluster to use. - n_cpus_local: Number of CPUs to use when creating a local ray cluster. - This has no effect when using an existing ray cluster. - logging_level: Logging level for the parallel backend's worker. - wait_timeout: Timeout in seconds for waiting on futures. 
- """ +PICKLE_VERSION = 5 # python >= 3.8 - backend: Literal["joblib", "ray"] = "joblib" - address: Optional[Union[str, Tuple[str, int]]] = None - n_cpus_local: Optional[int] = None - logging_level: int = logging.WARNING - wait_timeout: float = 1.0 - def __post_init__(self) -> None: - # FIXME: this is specific to ray - if self.address is not None and self.n_cpus_local is not None: - raise ValueError("When `address` is set, `n_cpus_local` should be None.") +__all__ = ["MemcachedClientConfig", "MemcachedConfig", "ParallelConfig"] @dataclass(frozen=True) diff --git a/src/pydvl/utils/parallel.py b/src/pydvl/utils/parallel.py new file mode 100644 index 000000000..2df52605f --- /dev/null +++ b/src/pydvl/utils/parallel.py @@ -0,0 +1,24 @@ +""" +# This module is deprecated + +!!! warning "Redirects" + Imports from this module will be redirected to + [pydvl.parallel][pydvl.parallel] only until v0.9.0. Please update your + imports. +""" +import logging + +from ..parallel.backend import * +from ..parallel.config import * +from ..parallel.futures import * +from ..parallel.map_reduce import * + +log = logging.getLogger(__name__) + +# This string for the benefit of deprecation searches: +# remove_in="0.9.0" +log.warning( + "Importing parallel tools from pydvl.utils is deprecated. " + "Please import directly from pydvl.parallel. " + "Redirected imports will be removed in v0.9.0" +) diff --git a/src/pydvl/value/least_core/common.py b/src/pydvl/value/least_core/common.py index f29e48a4a..2de8e7e3a 100644 --- a/src/pydvl/value/least_core/common.py +++ b/src/pydvl/value/least_core/common.py @@ -7,7 +7,8 @@ import numpy as np from numpy.typing import NDArray -from pydvl.utils import MapReduceJob, ParallelConfig, Status, Utility +from pydvl.parallel import MapReduceJob, ParallelConfig +from pydvl.utils import Status, Utility from pydvl.value import ValuationResult __all__ = [ diff --git a/src/pydvl/value/least_core/montecarlo.py b/src/pydvl/value/least_core/montecarlo.py index fc2f9fe92..23a095278 100644 --- a/src/pydvl/value/least_core/montecarlo.py +++ b/src/pydvl/value/least_core/montecarlo.py @@ -5,10 +5,8 @@ import numpy as np from numpy.typing import NDArray -from pydvl.utils.config import ParallelConfig +from pydvl.parallel import MapReduceJob, ParallelConfig, effective_n_jobs from pydvl.utils.numeric import random_powerset -from pydvl.utils.parallel import MapReduceJob -from pydvl.utils.parallel.backend import effective_n_jobs from pydvl.utils.progress import maybe_progress from pydvl.utils.utility import Utility from pydvl.value.least_core.common import LeastCoreProblem, lc_solve_problem diff --git a/src/pydvl/value/loo/loo.py b/src/pydvl/value/loo/loo.py index 893594260..a507f6aad 100644 --- a/src/pydvl/value/loo/loo.py +++ b/src/pydvl/value/loo/loo.py @@ -4,7 +4,8 @@ from tqdm import tqdm -from pydvl.utils import ParallelConfig, Utility, effective_n_jobs, init_executor +from pydvl.parallel import ParallelConfig, effective_n_jobs, init_executor +from pydvl.utils import Utility from pydvl.value.result import ValuationResult __all__ = ["compute_loo"] diff --git a/src/pydvl/value/semivalues.py b/src/pydvl/value/semivalues.py index 488a25037..97a778ba9 100644 --- a/src/pydvl/value/semivalues.py +++ b/src/pydvl/value/semivalues.py @@ -80,7 +80,8 @@ from deprecate import deprecated from tqdm import tqdm -from pydvl.utils import ParallelConfig, Utility +from pydvl.parallel.config import ParallelConfig +from pydvl.utils import Utility from pydvl.utils.types import Seed from pydvl.value import ValuationResult 
 from pydvl.value.sampler import (
diff --git a/src/pydvl/value/shapley/gt.py b/src/pydvl/value/shapley/gt.py
index b78ac346c..b193af6e5 100644
--- a/src/pydvl/value/shapley/gt.py
+++ b/src/pydvl/value/shapley/gt.py
@@ -17,7 +17,7 @@
 ## References
 
 [^1]: Jia, R. et al., 2019.
-    [Towards Efficient Data Valuation Based on the Shapley Value](http://proceedings.mlr.press/v89/jia19a.html).
+    [Towards Efficient Data Valuation Based on the Shapley Value](https://proceedings.mlr.press/v89/jia19a.html).
     In: Proceedings of the 22nd International Conference on Artificial Intelligence
     and Statistics, pp. 1167–1176. PMLR.
 """
 import logging
@@ -29,9 +29,9 @@
 from numpy.random import SeedSequence
 from numpy.typing import NDArray
 
-from pydvl.utils import MapReduceJob, ParallelConfig, Utility, maybe_progress
+from pydvl.parallel import MapReduceJob, ParallelConfig, effective_n_jobs
+from pydvl.utils import Utility, maybe_progress
 from pydvl.utils.numeric import random_subset_of_size
-from pydvl.utils.parallel.backend import effective_n_jobs
 from pydvl.utils.status import Status
 from pydvl.utils.types import Seed, ensure_seed_sequence
 from pydvl.value import ValuationResult
diff --git a/src/pydvl/value/shapley/montecarlo.py b/src/pydvl/value/shapley/montecarlo.py
index e9aba420b..1046326a5 100644
--- a/src/pydvl/value/shapley/montecarlo.py
+++ b/src/pydvl/value/shapley/montecarlo.py
@@ -56,10 +56,15 @@
 from numpy.typing import NDArray
 from tqdm import tqdm
 
-from pydvl.utils import effective_n_jobs, init_executor, init_parallel_backend
-from pydvl.utils.config import ParallelConfig
+from pydvl.parallel import (
+    CancellationPolicy,
+    MapReduceJob,
+    ParallelConfig,
+    effective_n_jobs,
+    init_executor,
+    init_parallel_backend,
+)
 from pydvl.utils.numeric import random_powerset
-from pydvl.utils.parallel import CancellationPolicy, MapReduceJob
 from pydvl.utils.types import Seed, ensure_seed_sequence
 from pydvl.utils.utility import Utility
 from pydvl.value.result import ValuationResult
diff --git a/src/pydvl/value/shapley/naive.py b/src/pydvl/value/shapley/naive.py
index d1a29e8fd..46f1c11a5 100644
--- a/src/pydvl/value/shapley/naive.py
+++ b/src/pydvl/value/shapley/naive.py
@@ -1,12 +1,13 @@
 import math
 import warnings
 from itertools import permutations
-from typing import Collection, List
+from typing import List
 
 import numpy as np
 from numpy.typing import NDArray
 
-from pydvl.utils import MapReduceJob, ParallelConfig, Utility, maybe_progress, powerset
+from pydvl.parallel import MapReduceJob, ParallelConfig
+from pydvl.utils import Utility, maybe_progress, powerset
 from pydvl.utils.status import Status
 from pydvl.value.result import ValuationResult
diff --git a/src/pydvl/value/shapley/owen.py b/src/pydvl/value/shapley/owen.py
index 69b5dda89..07b9e972b 100644
--- a/src/pydvl/value/shapley/owen.py
+++ b/src/pydvl/value/shapley/owen.py
@@ -16,7 +16,8 @@
 from numpy.typing import NDArray
 from tqdm import tqdm
 
-from pydvl.utils import MapReduceJob, ParallelConfig, Utility, random_powerset
+from pydvl.parallel import MapReduceJob, ParallelConfig
+from pydvl.utils import Utility, random_powerset
 from pydvl.utils.types import Seed
 from pydvl.value import ValuationResult
 from pydvl.value.stopping import MinUpdates
diff --git a/src/pydvl/value/shapley/truncated.py b/src/pydvl/value/shapley/truncated.py
index 9efe87480..7d65ae06b 100644
--- a/src/pydvl/value/shapley/truncated.py
+++ b/src/pydvl/value/shapley/truncated.py
@@ -13,7 +13,8 @@
 import numpy as np
 from deprecate import deprecated
 
-from pydvl.utils import ParallelConfig, Utility, running_moments
+from pydvl.parallel.config import ParallelConfig
+from pydvl.utils import Utility, running_moments
 from pydvl.value import ValuationResult
 from pydvl.value.stopping import StoppingCriterion
diff --git a/tests/conftest.py b/tests/conftest.py
index 2cf9de4f8..c99a38d3f 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -10,8 +10,8 @@
 from sklearn import datasets
 from sklearn.utils import Bunch
 
+from pydvl.parallel.backend import available_cpus
 from pydvl.utils import Dataset, MemcachedClientConfig
-from pydvl.utils.parallel.backend import available_cpus
 
 if TYPE_CHECKING:
     from _pytest.config import Config
diff --git a/tests/utils/conftest.py b/tests/utils/conftest.py
index 923d391fb..19c8a0798 100644
--- a/tests/utils/conftest.py
+++ b/tests/utils/conftest.py
@@ -1,6 +1,6 @@
 import pytest
 
-from pydvl.utils.config import ParallelConfig
+from pydvl.parallel.config import ParallelConfig
 
 
 @pytest.fixture(scope="module", params=["joblib", "ray-local", "ray-external"])
diff --git a/tests/utils/test_caching.py b/tests/utils/test_caching.py
index aae0dd416..fbfc3c11b 100644
--- a/tests/utils/test_caching.py
+++ b/tests/utils/test_caching.py
@@ -5,7 +5,8 @@
 import pytest
 from numpy.typing import NDArray
 
-from pydvl.utils import MapReduceJob, memcached
+from pydvl.parallel import MapReduceJob
+from pydvl.utils import memcached
 
 logger = logging.getLogger(__name__)
diff --git a/tests/utils/test_parallel.py b/tests/utils/test_parallel.py
index 8ba145aa8..8818df065 100644
--- a/tests/utils/test_parallel.py
+++ b/tests/utils/test_parallel.py
@@ -2,14 +2,14 @@
 import os
 import time
 from functools import partial, reduce
-from typing import List, Optional
+from typing import Optional
 
 import numpy as np
 import pytest
 
-from pydvl.utils.parallel import MapReduceJob, init_parallel_backend
-from pydvl.utils.parallel.backend import effective_n_jobs
-from pydvl.utils.parallel.futures import init_executor
+from pydvl.parallel import MapReduceJob, init_parallel_backend
+from pydvl.parallel.backend import effective_n_jobs
+from pydvl.parallel.futures import init_executor
 from pydvl.utils.types import Seed
 
 
@@ -97,10 +97,7 @@ def map_reduce_job_and_parameters(parallel_config, n_jobs, request):
 def test_map_reduce_job(map_reduce_job_and_parameters, indices, expected):
     map_reduce_job, n_jobs = map_reduce_job_and_parameters
     result = map_reduce_job(indices)()
-    if not isinstance(result, np.ndarray):
-        assert result == expected
-    else:
-        assert (result == expected).all()
+    assert np.all(result == expected)
 
 
 @pytest.mark.parametrize(
@@ -233,7 +230,7 @@ def test_future_cancellation(parallel_config):
     if parallel_config.backend != "ray":
         pytest.skip("Currently this test only works with Ray")
 
-    from pydvl.utils.parallel import CancellationPolicy
+    from pydvl.parallel import CancellationPolicy
 
     with init_executor(
         config=parallel_config, cancel_futures=CancellationPolicy.NONE
diff --git a/tests/value/conftest.py b/tests/value/conftest.py
index d19256beb..38cddfdf8 100644
--- a/tests/value/conftest.py
+++ b/tests/value/conftest.py
@@ -5,8 +5,8 @@
 from sklearn.pipeline import make_pipeline
 from sklearn.preprocessing import PolynomialFeatures
 
+from pydvl.parallel.config import ParallelConfig
 from pydvl.utils import Dataset, SupervisedModel, Utility
-from pydvl.utils.config import ParallelConfig
 from pydvl.utils.status import Status
 from pydvl.value import ValuationResult
diff --git a/tests/value/shapley/test_knn.py b/tests/value/shapley/test_knn.py
index 663937de2..1ca7a1fbc 100644
--- a/tests/value/shapley/test_knn.py
+++ b/tests/value/shapley/test_knn.py
@@ -5,8 +5,8 @@
 from sklearn.metrics import make_scorer
 from sklearn.neighbors import KNeighborsClassifier
 
+from pydvl.parallel.backend import available_cpus
 from pydvl.utils.dataset import Dataset
-from pydvl.utils.parallel.backend import available_cpus
 from pydvl.utils.score import Scorer
 from pydvl.utils.utility import Utility
 from pydvl.value.shapley.knn import knn_shapley
diff --git a/tests/value/shapley/test_montecarlo.py b/tests/value/shapley/test_montecarlo.py
index b7961cdb8..3024ed198 100644
--- a/tests/value/shapley/test_montecarlo.py
+++ b/tests/value/shapley/test_montecarlo.py
@@ -5,14 +5,8 @@
 import pytest
 from sklearn.linear_model import LinearRegression
 
-from pydvl.utils import (
-    Dataset,
-    GroupedDataset,
-    MemcachedConfig,
-    ParallelConfig,
-    Status,
-    Utility,
-)
+from pydvl.parallel.config import ParallelConfig
+from pydvl.utils import Dataset, GroupedDataset, MemcachedConfig, Status, Utility
 from pydvl.utils.numeric import num_samples_permutation_hoeffding
 from pydvl.utils.score import Scorer, squashed_r2
 from pydvl.utils.types import Seed
diff --git a/tests/value/test_semivalues.py b/tests/value/test_semivalues.py
index ec937d028..ea10dd339 100644
--- a/tests/value/test_semivalues.py
+++ b/tests/value/test_semivalues.py
@@ -4,7 +4,7 @@
 import numpy as np
 import pytest
 
-from pydvl.utils import ParallelConfig
+from pydvl.parallel.config import ParallelConfig
 from pydvl.value.sampler import (
     AntitheticSampler,
     DeterministicPermutationSampler,
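A minimal before/after sketch of the migration this diff implies, assembled from the docstring examples and the deprecation shim above. The explicit `config=` argument to `MapReduceJob` is an assumption based on its use of `ParallelConfig` elsewhere in the diff; the remaining names and arguments appear verbatim in the hunks.

```python
# Deprecated imports: redirected through the src/pydvl/utils/parallel.py shim
# above (with a warning) only until v0.9.0:
#   from pydvl.utils.parallel import MapReduceJob
#   from pydvl.utils.config import ParallelConfig

# New imports after this refactoring:
import numpy as np

from pydvl.parallel import MapReduceJob, ParallelConfig

# Same map-reduce example as in the MapReduceJob docstring, with an explicit
# backend selection (config= is assumed to be accepted by the constructor).
config = ParallelConfig(backend="joblib")
map_reduce_job: MapReduceJob[np.ndarray, np.ndarray] = MapReduceJob(
    np.arange(5),
    map_func=np.sum,
    reduce_func=np.sum,
    config=config,
    n_jobs=2,
)
result = map_reduce_job()
assert result == 10  # each job sums its chunk of 0..4, the reduce step sums again
```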