Skip to content

Commit

Permalink
Deprecate RandomState (using names only)
Browse files Browse the repository at this point in the history
  • Loading branch information
flying-sheep committed Nov 18, 2024
1 parent ac19bb3 commit 092237a
Show file tree
Hide file tree
Showing 24 changed files with 72 additions and 61 deletions.
4 changes: 4 additions & 0 deletions src/scanpy/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,19 @@
from pathlib import Path
from typing import TYPE_CHECKING, Literal, ParamSpec, TypeVar, cast, overload

import numpy as np
from packaging.version import Version

if TYPE_CHECKING:
from collections.abc import Callable
from importlib.metadata import PackageMetadata


P = ParamSpec("P")
R = TypeVar("R")

_LegacyRandom = int | np.random.RandomState | None


if TYPE_CHECKING:
# type checkers are confused and can only see …core.Array
Expand Down
10 changes: 6 additions & 4 deletions src/scanpy/_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import re
import sys
import warnings
from collections.abc import Sequence
from contextlib import contextmanager, suppress
from enum import Enum
from functools import partial, reduce, singledispatch, wraps
Expand Down Expand Up @@ -56,12 +57,13 @@
from anndata import AnnData
from numpy.typing import ArrayLike, DTypeLike, NDArray

from .._compat import _LegacyRandom
from ..neighbors import NeighborsParams, RPForestDict


# e.g. https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html
# maybe in the future random.Generator
AnyRandom = int | np.random.RandomState | None
SeedLike = int | np.integer | Sequence[int] | np.random.SeedSequence
RNGLike = np.random.Generator | np.random.BitGenerator

LegacyUnionType = type(Union[int, str]) # noqa: UP007


Expand Down Expand Up @@ -493,7 +495,7 @@ def moving_average(a: np.ndarray, n: int):
return ret[n - 1 :] / n


def get_random_state(seed: AnyRandom) -> np.random.RandomState:
def _get_legacy_random(seed: _LegacyRandom) -> np.random.RandomState:
if isinstance(seed, np.random.RandomState):
return seed
return np.random.RandomState(seed)
Expand Down
4 changes: 2 additions & 2 deletions src/scanpy/datasets/_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
if TYPE_CHECKING:
from typing import Literal

from .._utils import AnyRandom
from .._compat import _LegacyRandom

VisiumSampleID = Literal[
"V1_Breast_Cancer_Block_A_Section_1",
Expand Down Expand Up @@ -63,7 +63,7 @@ def blobs(
n_centers: int = 5,
cluster_std: float = 1.0,
n_observations: int = 640,
random_state: AnyRandom = 0,
random_state: _LegacyRandom = 0,
) -> AnnData:
"""\
Gaussian Blobs.
Expand Down
4 changes: 2 additions & 2 deletions src/scanpy/external/pp/_dca.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from anndata import AnnData

from ..._utils import AnyRandom
from ..._compat import _LegacyRandom

_AEType = Literal["zinb-conddisp", "zinb", "nb-conddisp", "nb"]

Expand Down Expand Up @@ -62,7 +62,7 @@ def dca(
early_stop: int = 15,
batch_size: int = 32,
optimizer: str = "RMSprop",
random_state: AnyRandom = 0,
random_state: _LegacyRandom = 0,
threads: int | None = None,
learning_rate: float | None = None,
verbose: bool = False,
Expand Down
4 changes: 2 additions & 2 deletions src/scanpy/external/pp/_magic.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from anndata import AnnData

from ..._utils import AnyRandom
from ..._compat import _LegacyRandom

MIN_VERSION = "2.0"

Expand All @@ -36,7 +36,7 @@ def magic(
n_pca: int | None = 100,
solver: Literal["exact", "approximate"] = "exact",
knn_dist: str = "euclidean",
random_state: AnyRandom = None,
random_state: _LegacyRandom = None,
n_jobs: int | None = None,
verbose: bool = False,
copy: bool | None = None,
Expand Down
4 changes: 2 additions & 2 deletions src/scanpy/external/tl/_phate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from anndata import AnnData

from ..._utils import AnyRandom
from ..._compat import _LegacyRandom


@old_positionals(
Expand Down Expand Up @@ -49,7 +49,7 @@ def phate(
mds_dist: str = "euclidean",
mds: Literal["classic", "metric", "nonmetric"] = "metric",
n_jobs: int | None = None,
random_state: AnyRandom = None,
random_state: _LegacyRandom = None,
verbose: bool | int | None = None,
copy: bool = False,
**kwargs,
Expand Down
12 changes: 6 additions & 6 deletions src/scanpy/neighbors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from igraph import Graph
from scipy.sparse import csr_matrix

from .._utils import AnyRandom
from .._compat import _LegacyRandom
from ._types import KnnTransformerLike, _Metric, _MetricFn


Expand All @@ -54,13 +54,13 @@ class KwdsForTransformer(TypedDict):
n_neighbors: int
metric: _Metric | _MetricFn
metric_params: Mapping[str, Any]
random_state: AnyRandom
random_state: _LegacyRandom


class NeighborsParams(TypedDict):
n_neighbors: int
method: _Method
random_state: AnyRandom
random_state: _LegacyRandom
metric: _Metric | _MetricFn
metric_kwds: NotRequired[Mapping[str, Any]]
use_rep: NotRequired[str]
Expand All @@ -79,7 +79,7 @@ def neighbors(
transformer: KnnTransformerLike | _KnownTransformer | None = None,
metric: _Metric | _MetricFn = "euclidean",
metric_kwds: Mapping[str, Any] = MappingProxyType({}),
random_state: AnyRandom = 0,
random_state: _LegacyRandom = 0,
key_added: str | None = None,
copy: bool = False,
) -> AnnData | None:
Expand Down Expand Up @@ -521,7 +521,7 @@ def compute_neighbors(
transformer: KnnTransformerLike | _KnownTransformer | None = None,
metric: _Metric | _MetricFn = "euclidean",
metric_kwds: Mapping[str, Any] = MappingProxyType({}),
random_state: AnyRandom = 0,
random_state: _LegacyRandom = 0,
) -> None:
"""\
Compute distances and connectivities of neighbors.
Expand Down Expand Up @@ -757,7 +757,7 @@ def compute_eigen(
n_comps: int = 15,
sym: bool | None = None,
sort: Literal["decrease", "increase"] = "decrease",
random_state: AnyRandom = 0,
random_state: _LegacyRandom = 0,
):
"""\
Compute eigen decomposition of transition matrix.
Expand Down
3 changes: 2 additions & 1 deletion src/scanpy/plotting/_tools/paga.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from matplotlib.colors import Colormap
from scipy.sparse import spmatrix

from ..._compat import _LegacyRandom
from ...tools._draw_graph import _Layout as _LayoutWithoutEqTree
from .._utils import _FontSize, _FontWeight, _LegendLoc

Expand Down Expand Up @@ -210,7 +211,7 @@ def _compute_pos(
adjacency_solid: spmatrix | np.ndarray,
*,
layout: _Layout | None = None,
random_state: _sc_utils.AnyRandom = 0,
random_state: _LegacyRandom = 0,
init_pos: np.ndarray | None = None,
adj_tree=None,
root: int = 0,
Expand Down
5 changes: 3 additions & 2 deletions src/scanpy/preprocessing/_pca/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@
from scipy import sparse
from scipy.sparse import spmatrix

from ..._utils import AnyRandom, Empty
from ..._compat import _LegacyRandom
from ..._utils import Empty

CSMatrix = sparse.csr_matrix | sparse.csc_matrix

Expand Down Expand Up @@ -70,7 +71,7 @@ def pca(
layer: str | None = None,
zero_center: bool | None = True,
svd_solver: SvdSolver | None = None,
random_state: AnyRandom = 0,
random_state: _LegacyRandom = 0,
return_info: bool = False,
mask_var: NDArray[np.bool_] | str | None | Empty = _empty,
use_highly_variable: bool | None = None,
Expand Down
4 changes: 2 additions & 2 deletions src/scanpy/preprocessing/_pca/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from scipy import sparse
from sklearn.decomposition import PCA

from .._utils import AnyRandom
from ..._compat import _LegacyRandom

CSMatrix = sparse.csr_matrix | sparse.csc_matrix

Expand All @@ -29,7 +29,7 @@ def _pca_compat_sparse(
*,
solver: Literal["arpack", "lobpcg"],
mu: NDArray[np.floating] | None = None,
random_state: AnyRandom = None,
random_state: _LegacyRandom = None,
) -> tuple[NDArray[np.floating], PCA]:
"""Sparse PCA for scikit-learn <1.4"""
random_state = check_random_state(random_state)
Expand Down
4 changes: 2 additions & 2 deletions src/scanpy/preprocessing/_recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
if TYPE_CHECKING:
from anndata import AnnData

from .._utils import AnyRandom
from .._compat import _LegacyRandom


@old_positionals(
Expand All @@ -36,7 +36,7 @@ def recipe_weinreb17(
cv_threshold: int = 2,
n_pcs: int = 50,
svd_solver="randomized",
random_state: AnyRandom = 0,
random_state: _LegacyRandom = 0,
copy: bool = False,
) -> AnnData | None:
"""\
Expand Down
8 changes: 4 additions & 4 deletions src/scanpy/preprocessing/_scrublet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from .core import Scrublet

if TYPE_CHECKING:
from ..._utils import AnyRandom
from ..._compat import _LegacyRandom
from ...neighbors import _Metric, _MetricFn


Expand Down Expand Up @@ -58,7 +58,7 @@ def scrublet(
threshold: float | None = None,
verbose: bool = True,
copy: bool = False,
random_state: AnyRandom = 0,
random_state: _LegacyRandom = 0,
) -> AnnData | None:
"""\
Predict doublets using Scrublet :cite:p:`Wolock2019`.
Expand Down Expand Up @@ -309,7 +309,7 @@ def _scrublet_call_doublets(
knn_dist_metric: _Metric | _MetricFn = "euclidean",
get_doublet_neighbor_parents: bool = False,
threshold: float | None = None,
random_state: AnyRandom = 0,
random_state: _LegacyRandom = 0,
verbose: bool = True,
) -> AnnData:
"""\
Expand Down Expand Up @@ -503,7 +503,7 @@ def scrublet_simulate_doublets(
layer: str | None = None,
sim_doublet_ratio: float = 2.0,
synthetic_doublet_umi_subsampling: float = 1.0,
random_seed: AnyRandom = 0,
random_seed: _LegacyRandom = 0,
) -> AnnData:
"""\
Simulate doublets by adding the counts of random observed transcriptome pairs.
Expand Down
10 changes: 5 additions & 5 deletions src/scanpy/preprocessing/_scrublet/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from scipy import sparse

from ... import logging as logg
from ..._utils import get_random_state
from ..._utils import _get_legacy_random
from ...neighbors import (
Neighbors,
_get_indices_distances_from_sparse_matrix,
Expand All @@ -21,7 +21,7 @@
from numpy.random import RandomState
from numpy.typing import NDArray

from ..._utils import AnyRandom
from ..._compat import _LegacyRandom
from ...neighbors import _Metric, _MetricFn

__all__ = ["Scrublet"]
Expand Down Expand Up @@ -73,7 +73,7 @@ class Scrublet:
n_neighbors: InitVar[int | None] = None
expected_doublet_rate: float = 0.1
stdev_doublet_rate: float = 0.02
random_state: InitVar[AnyRandom] = 0
random_state: InitVar[_LegacyRandom] = 0

# private fields

Expand Down Expand Up @@ -174,7 +174,7 @@ def __post_init__(
counts_obs: sparse.csr_matrix | sparse.csc_matrix | NDArray[np.integer],
total_counts_obs: NDArray[np.integer] | None,
n_neighbors: int | None,
random_state: AnyRandom,
random_state: _LegacyRandom,
) -> None:
self._counts_obs = sparse.csc_matrix(counts_obs)
self._total_counts_obs = (
Expand All @@ -187,7 +187,7 @@ def __post_init__(
if n_neighbors is None
else n_neighbors
)
self._random_state = get_random_state(random_state)
self._random_state = _get_legacy_random(random_state)

def simulate_doublets(
self,
Expand Down
6 changes: 3 additions & 3 deletions src/scanpy/preprocessing/_scrublet/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
if TYPE_CHECKING:
from typing import Literal

from ..._utils import AnyRandom
from ..._compat import _LegacyRandom
from .core import Scrublet


Expand Down Expand Up @@ -49,7 +49,7 @@ def truncated_svd(
self: Scrublet,
n_prin_comps: int = 30,
*,
random_state: AnyRandom = 0,
random_state: _LegacyRandom = 0,
algorithm: Literal["arpack", "randomized"] = "arpack",
) -> None:
if self._counts_sim_norm is None:
Expand All @@ -68,7 +68,7 @@ def pca(
self: Scrublet,
n_prin_comps: int = 50,
*,
random_state: AnyRandom = 0,
random_state: _LegacyRandom = 0,
svd_solver: Literal["auto", "full", "arpack", "randomized"] = "arpack",
) -> None:
if self._counts_sim_norm is None:
Expand Down
8 changes: 4 additions & 4 deletions src/scanpy/preprocessing/_scrublet/sparse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@

from scanpy.preprocessing._utils import _get_mean_var

from ..._utils import get_random_state
from ..._utils import _get_legacy_random

if TYPE_CHECKING:
from numpy.typing import NDArray

from ..._utils import AnyRandom
from ..._compat import _LegacyRandom


def sparse_multiply(
Expand Down Expand Up @@ -47,10 +47,10 @@ def subsample_counts(
*,
rate: float,
original_totals,
random_seed: AnyRandom = 0,
random_seed: _LegacyRandom = 0,
) -> tuple[sparse.csr_matrix | sparse.csc_matrix, NDArray[np.int64]]:
if rate < 1:
random_seed = get_random_state(random_seed)
random_seed = _get_legacy_random(random_seed)
E.data = random_seed.binomial(np.round(E.data).astype(int), rate)
current_totals = np.asarray(E.sum(1)).squeeze()
unsampled_orig_totals = original_totals - current_totals
Expand Down
Loading

0 comments on commit 092237a

Please sign in to comment.