
Commit

WIP

flying-sheep committed Nov 18, 2024
1 parent 3c31abd commit d350411
Showing 3 changed files with 56 additions and 14 deletions.
22 changes: 22 additions & 0 deletions src/scanpy/_compat.py
@@ -9,12 +9,14 @@
from pathlib import Path
from typing import TYPE_CHECKING, Literal, ParamSpec, TypeVar, cast, overload

+import numpy as np
from packaging.version import Version

if TYPE_CHECKING:
    from collections.abc import Callable
    from importlib.metadata import PackageMetadata


P = ParamSpec("P")
R = TypeVar("R")

@@ -194,3 +196,23 @@ def _numba_threading_layer() -> Layer:
        f" ({available=}, {numba.config.THREADING_LAYER_PRIORITY=})"
    )
    raise ValueError(msg)
+
+
+_LegacyRandom = int | np.random.RandomState | None
+
+
+def _legacy_numpy_gen(
+    random_state: _LegacyRandom | None = None,
+) -> np.random.RandomState:
+    """Return a random generator that behaves like the legacy one."""
+
+    if random_state is not None:
+        if isinstance(random_state, np.random.RandomState):
+            np.random.set_state(random_state.get_state(legacy=False))
+            return random_state
+        np.random.seed(random_state)
+    state = np.random.get_state(legacy=True)
+    assert isinstance(state, tuple)
+    bit_gen = np.random.MT19937()
+    bit_gen.state = state
+    return np.random.RandomState(bit_gen)
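
A minimal standalone sketch of the state hand-off used above (NumPy only, illustrative; not taken from this diff): seeding the legacy module-level generator and copying its MT19937 state into a fresh bit generator makes a RandomState built on it reproduce the module-level draws exactly.

import numpy as np

np.random.seed(42)                        # seed the legacy module-level generator
state = np.random.get_state(legacy=True)  # capture its MT19937 state tuple

bit_gen = np.random.MT19937()
bit_gen.state = state                     # MT19937 also accepts the legacy state tuple
rs = np.random.RandomState(bit_gen)

np.random.seed(42)                        # rewind the module generator to the same point
a = np.random.choice(np.arange(10), size=5)
b = rs.choice(np.arange(10), size=5)
assert np.array_equal(a, b)               # identical draws from both
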
27 changes: 14 additions & 13 deletions src/scanpy/preprocessing/_simple.py
@@ -18,7 +18,7 @@
from sklearn.utils import check_array, sparsefuncs

from .. import logging as logg
-from .._compat import njit, old_positionals
+from .._compat import _legacy_numpy_gen, njit, old_positionals
from .._settings import settings as sett
from .._utils import (
    _check_array_function_arguments,
@@ -51,8 +51,8 @@
    from numpy.typing import NDArray
    from scipy.sparse import csc_matrix

-    from .._compat import DaskArray
-    from .._utils import AnyRandom
+    from .._compat import DaskArray, _LegacyRandom
+    from .._utils import RNGLike, SeedLike

    CSMatrix = csr_matrix | csc_matrix

@@ -834,7 +834,7 @@ def sample(
    fraction: float | None = None,
    *,
    n: int | None = None,
-    random_state: AnyRandom = 0,
+    rng: RNGLike | SeedLike | None = 0,
    copy: Literal[False] = False,
    replace: bool = False,
    axis: Literal["obs", 0, "var", 1] = "obs",
@@ -845,7 +845,7 @@
    fraction: float | None = None,
    *,
    n: int | None = None,
-    random_state: AnyRandom = 0,
+    rng: RNGLike | SeedLike | None = None,
    copy: Literal[True],
    replace: bool = False,
    axis: Literal["obs", 0, "var", 1] = "obs",
@@ -856,7 +856,7 @@
    fraction: float | None = None,
    *,
    n: int | None = None,
-    random_state: AnyRandom = 0,
+    rng: RNGLike | SeedLike | None = None,
    copy: bool = False,
    replace: bool = False,
    axis: Literal["obs", 0, "var", 1] = "obs",
@@ -866,7 +866,7 @@
    fraction: float | None = None,
    *,
    n: int | None = None,
-    random_state: AnyRandom = 0,
+    rng: RNGLike | SeedLike | None = None,
    copy: bool = False,
    replace: bool = False,
    axis: Literal["obs", 0, "var", 1] = "obs",
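
For orientation, a hedged usage sketch of the new signature (this assumes the function is exported as scanpy.pp.sample, which is not shown in this diff):

import numpy as np
import scanpy as sc

adata = sc.datasets.pbmc68k_reduced()

# `rng` accepts a plain seed ...
sub = sc.pp.sample(adata, fraction=0.1, rng=0, copy=True)

# ... or an existing Generator, so several steps can share one RNG.
gen = np.random.default_rng(0)
sub2 = sc.pp.sample(adata, n=100, rng=gen, copy=True, axis="obs")
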
@@ -927,8 +927,9 @@ def sample(
        raise TypeError(msg)
    del fraction

-    np.random.seed(random_state)
-    indices = np.random.choice(old_n, size=n, replace=replace)
+    if not isinstance(rng, np.random.RandomState):
+        rng = np.random.default_rng(rng)
+    indices = rng.choice(old_n, size=n, replace=replace)
    subset = data[indices] if axis_name == "obs" else data[:, indices]

    if not isinstance(data, AnnData):
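
The branch above normalizes every accepted `rng` value; a standalone sketch of that logic (the helper name `_normalize_rng` is hypothetical, not part of the change):

import numpy as np

def _normalize_rng(rng) -> np.random.Generator | np.random.RandomState:
    """Seeds, None, and Generators go through default_rng; a RandomState
    (e.g. the legacy shim from _legacy_numpy_gen) is used as-is."""
    if isinstance(rng, np.random.RandomState):
        return rng
    return np.random.default_rng(rng)  # a passed-in Generator is returned unaltered

# All of these yield an object exposing .choice():
_normalize_rng(0)
_normalize_rng(None)
_normalize_rng(np.random.default_rng(1))
_normalize_rng(np.random.RandomState(2))
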
@@ -956,7 +957,7 @@ def subsample(
    fraction: float | None = None,
    *,
    n_obs: int | None = None,
-    random_state: AnyRandom = 0,
+    random_state: _LegacyRandom = 0,
    copy: bool = False,
) -> AnnData | tuple[np.ndarray | CSMatrix, NDArray[np.int64]] | None:
    """\
@@ -991,7 +992,7 @@
        data=data,
        fraction=fraction,
        n=n_obs,
-        random_state=random_state,
+        rng=_legacy_numpy_gen(random_state),
        copy=copy,
        replace=False,
        axis=0,
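
A sketch of the backwards compatibility this delegation targets (illustrative; `_legacy_numpy_gen` is the private helper added in this commit): the old seed-based index draw and the new RandomState-based path should pick the same indices.

import numpy as np
from scanpy._compat import _legacy_numpy_gen

n_obs, n_sub = 700, 350

# What subsample(..., random_state=0) used to compute:
np.random.seed(0)
old_indices = np.random.choice(n_obs, size=n_sub, replace=False)

# What it now computes by delegating to sample(rng=_legacy_numpy_gen(0), ...):
new_indices = _legacy_numpy_gen(0).choice(n_obs, size=n_sub, replace=False)

assert np.array_equal(old_indices, new_indices)
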
@@ -1004,7 +1005,7 @@ def downsample_counts(
    counts_per_cell: int | Collection[int] | None = None,
    total_counts: int | None = None,
    *,
-    random_state: AnyRandom = 0,
+    random_state: _LegacyRandom = 0,
    replace: bool = False,
    copy: bool = False,
) -> AnnData | None:
@@ -1140,7 +1141,7 @@ def _downsample_array(
    col: np.ndarray,
    target: int,
    *,
-    random_state: AnyRandom = 0,
+    random_state: _LegacyRandom = 0,
    replace: bool = True,
    inplace: bool = False,
):

21 changes: 20 additions & 1 deletion tests/test_utils.py
@@ -9,7 +9,7 @@
from packaging.version import Version
from scipy.sparse import csr_matrix, issparse

-from scanpy._compat import DaskArray, pkg_version
+from scanpy._compat import DaskArray, _legacy_numpy_gen, pkg_version
from scanpy._utils import (
    axis_mul_or_truediv,
    axis_sum,
@@ -247,3 +247,22 @@ def test_is_constant_dask(request: pytest.FixtureRequest, axis, expected, block_
    x = da.from_array(np.array(x_data), chunks=2).map_blocks(block_type)
    result = is_constant(x, axis=axis).compute()
    np.testing.assert_array_equal(expected, result)
+
+
+@pytest.mark.parametrize("seed", [0, 1, 1256712675])
+@pytest.mark.parametrize("func", ["choice"])
+def test_legacy_numpy_gen(seed: int, func: str):
+    arr_module = _mk_random(seed, func, legacy=True)
+    arr_generator = _mk_random(seed, func, legacy=False)
+    np.testing.assert_array_equal(arr_module, arr_generator)
+
+
+def _mk_random(seed: int, func: str, *, legacy: bool) -> np.ndarray:
+    np.random.seed(seed)
+    gen = np.random if legacy else _legacy_numpy_gen()
+    match func:
+        case "choice":
+            arr = np.arange(1000)
+            return gen.choice(arr, size=(100, 100))
+        case _:
+            pytest.fail(f"Unknown {func=}")
