From d3504114d2d36580b08ee7b0dc794ab7d0b24f1b Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Mon, 18 Nov 2024 16:57:27 +0100 Subject: [PATCH] WIP --- src/scanpy/_compat.py | 22 ++++++++++++++++++++++ src/scanpy/preprocessing/_simple.py | 27 ++++++++++++++------------- tests/test_utils.py | 21 ++++++++++++++++++++- 3 files changed, 56 insertions(+), 14 deletions(-) diff --git a/src/scanpy/_compat.py b/src/scanpy/_compat.py index c5fa4dbe84..d3de7b73ef 100644 --- a/src/scanpy/_compat.py +++ b/src/scanpy/_compat.py @@ -9,12 +9,14 @@ from pathlib import Path from typing import TYPE_CHECKING, Literal, ParamSpec, TypeVar, cast, overload +import numpy as np from packaging.version import Version if TYPE_CHECKING: from collections.abc import Callable from importlib.metadata import PackageMetadata + P = ParamSpec("P") R = TypeVar("R") @@ -194,3 +196,23 @@ def _numba_threading_layer() -> Layer: f" ({available=}, {numba.config.THREADING_LAYER_PRIORITY=})" ) raise ValueError(msg) + + +_LegacyRandom = int | np.random.RandomState | None + + +def _legacy_numpy_gen( + random_state: _LegacyRandom | None = None, +) -> np.random.RandomState: + """Return a random generator that behaves like the legacy one.""" + + if random_state is not None: + if isinstance(random_state, np.random.RandomState): + np.random.set_state(random_state.get_state(legacy=False)) + return random_state + np.random.seed(random_state) + state = np.random.get_state(legacy=True) + assert isinstance(state, tuple) + bit_gen = np.random.MT19937() + bit_gen.state = state + return np.random.RandomState(bit_gen) diff --git a/src/scanpy/preprocessing/_simple.py b/src/scanpy/preprocessing/_simple.py index e578cfdb90..9ec518f412 100644 --- a/src/scanpy/preprocessing/_simple.py +++ b/src/scanpy/preprocessing/_simple.py @@ -18,7 +18,7 @@ from sklearn.utils import check_array, sparsefuncs from .. import logging as logg -from .._compat import njit, old_positionals +from .._compat import _legacy_numpy_gen, njit, old_positionals from .._settings import settings as sett from .._utils import ( _check_array_function_arguments, @@ -51,8 +51,8 @@ from numpy.typing import NDArray from scipy.sparse import csc_matrix - from .._compat import DaskArray - from .._utils import AnyRandom + from .._compat import DaskArray, _LegacyRandom + from .._utils import RNGLike, SeedLike CSMatrix = csr_matrix | csc_matrix @@ -834,7 +834,7 @@ def sample( fraction: float | None = None, *, n: int | None = None, - random_state: AnyRandom = 0, + rng: RNGLike | SeedLike | None = 0, copy: Literal[False] = False, replace: bool = False, axis: Literal["obs", 0, "var", 1] = "obs", @@ -845,7 +845,7 @@ def sample( fraction: float | None = None, *, n: int | None = None, - random_state: AnyRandom = 0, + rng: RNGLike | SeedLike | None = None, copy: Literal[True], replace: bool = False, axis: Literal["obs", 0, "var", 1] = "obs", @@ -856,7 +856,7 @@ def sample( fraction: float | None = None, *, n: int | None = None, - random_state: AnyRandom = 0, + rng: RNGLike | SeedLike | None = None, copy: bool = False, replace: bool = False, axis: Literal["obs", 0, "var", 1] = "obs", @@ -866,7 +866,7 @@ def sample( fraction: float | None = None, *, n: int | None = None, - random_state: AnyRandom = 0, + rng: RNGLike | SeedLike | None = None, copy: bool = False, replace: bool = False, axis: Literal["obs", 0, "var", 1] = "obs", @@ -927,8 +927,9 @@ def sample( raise TypeError(msg) del fraction - np.random.seed(random_state) - indices = np.random.choice(old_n, size=n, replace=replace) + if not isinstance(rng, np.random.RandomState): + rng = np.random.default_rng(rng) + indices = rng.choice(old_n, size=n, replace=replace) subset = data[indices] if axis_name == "obs" else data[:, indices] if not isinstance(data, AnnData): @@ -956,7 +957,7 @@ def subsample( fraction: float | None = None, *, n_obs: int | None = None, - random_state: AnyRandom = 0, + random_state: _LegacyRandom = 0, copy: bool = False, ) -> AnnData | tuple[np.ndarray | CSMatrix, NDArray[np.int64]] | None: """\ @@ -991,7 +992,7 @@ def subsample( data=data, fraction=fraction, n=n_obs, - random_state=random_state, + rng=_legacy_numpy_gen(random_state), copy=copy, replace=False, axis=0, @@ -1004,7 +1005,7 @@ def downsample_counts( counts_per_cell: int | Collection[int] | None = None, total_counts: int | None = None, *, - random_state: AnyRandom = 0, + random_state: _LegacyRandom = 0, replace: bool = False, copy: bool = False, ) -> AnnData | None: @@ -1140,7 +1141,7 @@ def _downsample_array( col: np.ndarray, target: int, *, - random_state: AnyRandom = 0, + random_state: _LegacyRandom = 0, replace: bool = True, inplace: bool = False, ): diff --git a/tests/test_utils.py b/tests/test_utils.py index f8a38a5f9d..aebd6b7ec5 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -9,7 +9,7 @@ from packaging.version import Version from scipy.sparse import csr_matrix, issparse -from scanpy._compat import DaskArray, pkg_version +from scanpy._compat import DaskArray, _legacy_numpy_gen, pkg_version from scanpy._utils import ( axis_mul_or_truediv, axis_sum, @@ -247,3 +247,22 @@ def test_is_constant_dask(request: pytest.FixtureRequest, axis, expected, block_ x = da.from_array(np.array(x_data), chunks=2).map_blocks(block_type) result = is_constant(x, axis=axis).compute() np.testing.assert_array_equal(expected, result) + + +@pytest.mark.parametrize("seed", [0, 1, 1256712675]) +@pytest.mark.parametrize("func", ["choice"]) +def test_legacy_numpy_gen(seed: int, func: str): + arr_module = _mk_random(seed, func, legacy=True) + arr_generator = _mk_random(seed, func, legacy=False) + np.testing.assert_array_equal(arr_module, arr_generator) + + +def _mk_random(seed: int, func: str, *, legacy: bool) -> np.ndarray: + np.random.seed(seed) + gen = np.random if legacy else _legacy_numpy_gen() + match func: + case "choice": + arr = np.arange(1000) + return gen.choice(arr, size=(100, 100)) + case _: + pytest.fail(f"Unknown {func=}")