Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add replace option to subsample and rename function to sample #943

Open
wants to merge 35 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
bf922e1
Add replace option to subsample.
gokceneraslan Dec 2, 2019
32baba1
Merge branch 'master' into withreplacement
gokceneraslan Apr 20, 2020
671ec71
Add sc.pp.sample with axis argument.
gokceneraslan Apr 20, 2020
9e0739b
Fix fraction doc
gokceneraslan Apr 20, 2020
8ec8cf3
Add to release notes
gokceneraslan Apr 20, 2020
cdf4c65
Merge branch 'main' into withreplacement
flying-sheep Nov 14, 2024
fdf524a
refactor
flying-sheep Nov 14, 2024
061a19d
Refactor tests
flying-sheep Nov 14, 2024
8528f2d
Merge branch 'main' into withreplacement
flying-sheep Nov 14, 2024
06d4280
handle array case in test
flying-sheep Nov 14, 2024
6eeab2e
Test errors
flying-sheep Nov 14, 2024
b1f5061
prettier deprecations
flying-sheep Nov 14, 2024
cec8aff
docs
flying-sheep Nov 14, 2024
daa147e
ignore dask warning correctly
flying-sheep Nov 14, 2024
3c31abd
sig exception
flying-sheep Nov 14, 2024
d350411
WIP
flying-sheep Nov 18, 2024
f02725a
Merge branch 'main' into withreplacement
flying-sheep Nov 19, 2024
c24e9b2
remove duplicate _LegacyRandom
flying-sheep Nov 19, 2024
e246f02
undo compat thing
flying-sheep Nov 19, 2024
4ad40b7
fix backwards compat
flying-sheep Nov 19, 2024
1b8c81e
Use fake Generator
flying-sheep Nov 19, 2024
594d961
backwards compat test
flying-sheep Nov 19, 2024
00fdd77
Merge branch 'main' into withreplacement
flying-sheep Nov 21, 2024
59a171c
Fix tests for old Pythons
flying-sheep Nov 21, 2024
59adc76
test that random state is modified
flying-sheep Nov 21, 2024
ef27db0
Fix util
flying-sheep Nov 21, 2024
c471e94
types
flying-sheep Nov 21, 2024
3028dff
move deprecated stuff
flying-sheep Nov 21, 2024
f11b6ba
Use deprecation decorator
flying-sheep Nov 21, 2024
735f00a
relnote
flying-sheep Nov 21, 2024
4d54700
Merge branch 'pa/deprecated' into withreplacement
flying-sheep Nov 21, 2024
0a5b284
fix dask warning stuff
flying-sheep Nov 21, 2024
0ca9411
oops
flying-sheep Nov 21, 2024
f587cdf
Merge branch 'main' into withreplacement
flying-sheep Nov 22, 2024
396b21a
Bump numpy to version that has get_bit_generator
flying-sheep Nov 22, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/deprecated.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@

pp.filter_genes_dispersion
pp.normalize_per_cell
pp.subsample
```
2 changes: 1 addition & 1 deletion docs/api/preprocessing.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ For visual quality control, see {func}`~scanpy.pl.highest_expr_genes` and
pp.normalize_total
pp.regress_out
pp.scale
pp.subsample
pp.sample
pp.downsample_counts
```

Expand Down
1 change: 1 addition & 0 deletions docs/release-notes/943.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{func}`~scanpy.pp.sample` supports both upsampling and downsampling of observations and variables. {func}`~scanpy.pp.subsample` is now deprecated. {smaller}`G Eraslan` & {smaller}`P Angerer`
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ classifiers = [
]
dependencies = [
"anndata>=0.8",
"numpy>=1.23",
"numpy>=1.24",
"matplotlib>=3.6",
"pandas >=1.5",
"scipy>=1.8",
Expand Down
41 changes: 40 additions & 1 deletion src/scanpy/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import sys
import warnings
from dataclasses import dataclass, field
from functools import cache, partial, wraps
from functools import WRAPPER_ASSIGNMENTS, cache, partial, wraps
from importlib.util import find_spec
from pathlib import Path
from typing import TYPE_CHECKING, Literal, ParamSpec, TypeVar, cast, overload
Expand Down Expand Up @@ -211,3 +211,42 @@
f" ({available=}, {numba.config.THREADING_LAYER_PRIORITY=})"
)
raise ValueError(msg)


def _legacy_numpy_gen(
random_state: _LegacyRandom | None = None,
) -> np.random.Generator:
"""Return a random generator that behaves like the legacy one."""

if random_state is not None:
if isinstance(random_state, np.random.RandomState):
np.random.set_state(random_state.get_state(legacy=False))
return _FakeRandomGen(random_state)

Check warning on line 224 in src/scanpy/_compat.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/_compat.py#L223-L224

Added lines #L223 - L224 were not covered by tests
np.random.seed(random_state)
return _FakeRandomGen(np.random.RandomState(np.random.get_bit_generator()))


class _FakeRandomGen(np.random.Generator):
_state: np.random.RandomState

def __init__(self, random_state: np.random.RandomState) -> None:
self._state = random_state

@classmethod
def _delegate(cls) -> None:
for name, meth in np.random.Generator.__dict__.items():
if name.startswith("_") or not callable(meth):
continue

def mk_wrapper(name: str):
# Old pytest versions try to run the doctests
@wraps(meth, assigned=set(WRAPPER_ASSIGNMENTS) - {"__doc__"})
def wrapper(self: _FakeRandomGen, *args, **kwargs):
return getattr(self._state, name)(*args, **kwargs)

return wrapper

setattr(cls, name, mk_wrapper(name))


_FakeRandomGen._delegate()
4 changes: 3 additions & 1 deletion src/scanpy/preprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from ..neighbors import neighbors
from ._combat import combat
from ._deprecated.highly_variable_genes import filter_genes_dispersion
from ._deprecated.sampling import subsample
from ._highly_variable_genes import highly_variable_genes
from ._normalization import normalize_total
from ._pca import pca
Expand All @@ -17,8 +18,8 @@
log1p,
normalize_per_cell,
regress_out,
sample,
sqrt,
subsample,
)

__all__ = [
Expand All @@ -40,6 +41,7 @@
"log1p",
"normalize_per_cell",
"regress_out",
"sample",
"scale",
"sqrt",
"subsample",
Expand Down
60 changes: 60 additions & 0 deletions src/scanpy/preprocessing/_deprecated/sampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from ..._compat import _legacy_numpy_gen, old_positionals
from .._simple import sample

if TYPE_CHECKING:
import numpy as np
from anndata import AnnData
from numpy.typing import NDArray
from scipy.sparse import csc_matrix, csr_matrix

from ..._compat import _LegacyRandom

CSMatrix = csr_matrix | csc_matrix


@old_positionals("n_obs", "random_state", "copy")
def subsample(
data: AnnData | np.ndarray | CSMatrix,
fraction: float | None = None,
*,
n_obs: int | None = None,
random_state: _LegacyRandom = 0,
copy: bool = False,
) -> AnnData | tuple[np.ndarray | CSMatrix, NDArray[np.int64]] | None:
"""\
Subsample to a fraction of the number of observations.

.. deprecated:: 1.11.0

Use :func:`~scanpy.pp.sample` instead.

Parameters
----------
data
The (annotated) data matrix of shape `n_obs` × `n_vars`.
Rows correspond to cells and columns to genes.
fraction
Subsample to this `fraction` of the number of observations.
n_obs
Subsample to this number of observations.
random_state
Random seed to change subsampling.
copy
If an :class:`~anndata.AnnData` is passed,
determines whether a copy is returned.

Returns
-------
Returns `X[obs_indices], obs_indices` if data is array-like, otherwise
subsamples the passed :class:`~anndata.AnnData` (`copy == False`) or
returns a subsampled copy of it (`copy == True`).
"""

rng = _legacy_numpy_gen(random_state)
return sample(
data=data, fraction=fraction, n=n_obs, rng=rng, copy=copy, replace=False, axis=0
)
150 changes: 105 additions & 45 deletions src/scanpy/preprocessing/_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import warnings
from functools import singledispatch
from itertools import repeat
from typing import TYPE_CHECKING, TypeVar
from typing import TYPE_CHECKING, TypeVar, overload

import numba
import numpy as np
Expand All @@ -22,6 +22,7 @@
from .._settings import settings as sett
from .._utils import (
_check_array_function_arguments,
_resolve_axis,
axis_sum,
is_backed_type,
raise_not_implemented_error_if_backed_type,
Expand All @@ -33,24 +34,24 @@
from ._distributed import materialize_as_ndarray
from ._utils import _to_dense

# install dask if available
try:
import dask.array as da
except ImportError:
da = None

# backwards compat
from ._deprecated.highly_variable_genes import filter_genes_dispersion # noqa: F401

if TYPE_CHECKING:
from collections.abc import Collection, Iterable, Sequence
from numbers import Number
from typing import Literal

import pandas as pd
from numpy.typing import NDArray
from scipy.sparse import csc_matrix

from .._compat import DaskArray, _LegacyRandom
from .._utils import RNGLike, SeedLike

CSMatrix = csr_matrix | csc_matrix


@old_positionals(
Expand Down Expand Up @@ -825,67 +826,126 @@
return np.vstack(responses_chunk_list)


@old_positionals("n_obs", "random_state", "copy")
def subsample(
data: AnnData | np.ndarray | spmatrix,
@overload
def sample(
data: AnnData,
fraction: float | None = None,
*,
n_obs: int | None = None,
random_state: _LegacyRandom = 0,
n: int | None = None,
rng: RNGLike | SeedLike | None = 0,
copy: Literal[False] = False,
replace: bool = False,
axis: Literal["obs", 0, "var", 1] = "obs",
) -> None: ...
@overload
def sample(
data: AnnData,
fraction: float | None = None,
*,
n: int | None = None,
rng: RNGLike | SeedLike | None = None,
copy: Literal[True],
replace: bool = False,
axis: Literal["obs", 0, "var", 1] = "obs",
) -> AnnData: ...
@overload
def sample(
data: np.ndarray | CSMatrix,
fraction: float | None = None,
*,
n: int | None = None,
rng: RNGLike | SeedLike | None = None,
copy: bool = False,
replace: bool = False,
axis: Literal["obs", 0, "var", 1] = "obs",
) -> tuple[np.ndarray | CSMatrix, NDArray[np.int64]]: ...
def sample(
data: AnnData | np.ndarray | CSMatrix,
fraction: float | None = None,
*,
n: int | None = None,
rng: RNGLike | SeedLike | None = None,
copy: bool = False,
) -> AnnData | tuple[np.ndarray | spmatrix, NDArray[np.int64]] | None:
replace: bool = False,
axis: Literal["obs", 0, "var", 1] = "obs",
) -> AnnData | None | tuple[np.ndarray | CSMatrix, NDArray[np.int64]]:
"""\
Subsample to a fraction of the number of observations.
Sample observations or variables with or without replacement.

Parameters
----------
data
The (annotated) data matrix of shape `n_obs` × `n_vars`.
Rows correspond to cells and columns to genes.
fraction
Subsample to this `fraction` of the number of observations.
n_obs
Subsample to this number of observations.
Sample to this `fraction` of the number of observations or variables.
This can be larger than 1.0, if `replace=True`.
See `axis` and `replace`.
n
Sample to this number of observations or variables. See `axis`.
random_state
Random seed to change subsampling.
copy
If an :class:`~anndata.AnnData` is passed,
determines whether a copy is returned.
replace
If True, samples are drawn with replacement.
axis
Sample `obs`\\ ervations (axis 0) or `var`\\ iables (axis 1).

Returns
-------
Returns `X[obs_indices], obs_indices` if data is array-like, otherwise
subsamples the passed :class:`~anndata.AnnData` (`copy == False`) or
returns a subsampled copy of it (`copy == True`).
If `isinstance(data, AnnData)` and `copy=False`,
this function returns `None`. Otherwise:

`data[indices, :]` | `data[:, indices]` (depending on `axis`)
If `data` is array-like or `copy=True`, returns the subset.
`indices` : numpy.ndarray
If `data` is array-like, also returns the indices into the original.
"""
np.random.seed(random_state)
old_n_obs = data.n_obs if isinstance(data, AnnData) else data.shape[0]
if n_obs is not None:
new_n_obs = n_obs
elif fraction is not None:
if fraction > 1 or fraction < 0:
raise ValueError(f"`fraction` needs to be within [0, 1], not {fraction}")
new_n_obs = int(fraction * old_n_obs)
logg.debug(f"... subsampled to {new_n_obs} data points")
else:
raise ValueError("Either pass `n_obs` or `fraction`.")
obs_indices = np.random.choice(old_n_obs, size=new_n_obs, replace=False)
if isinstance(data, AnnData):
if data.isbacked:
if copy:
return data[obs_indices].to_memory()
else:
raise NotImplementedError(
"Inplace subsampling is not implemented for backed objects."
)
else:
if copy:
return data[obs_indices].copy()
else:
data._inplace_subset_obs(obs_indices)
axis, axis_name = _resolve_axis(axis)
old_n = data.shape[axis]
match (fraction, n):
case (None, None):
msg = "Either `fraction` or `n` must be set."
raise TypeError(msg)
case (None, _):
pass
case (_, None):
if fraction < 0:
msg = f"`{fraction=}` needs to be nonnegative."
raise ValueError(msg)
if not replace and fraction > 1:
msg = f"If `replace=False`, `{fraction=}` needs to be within [0, 1]."
raise ValueError(msg)
n = int(fraction * old_n)
logg.debug(f"... sampled to {n} {axis_name}")
case _:
msg = "Providing both `fraction` and `n` is not allowed."
raise TypeError(msg)
del fraction

rng = np.random.default_rng(rng)
indices = rng.choice(old_n, size=n, replace=replace)
subset = data[indices] if axis_name == "obs" else data[:, indices]

if not isinstance(data, AnnData):
assert not isinstance(subset, AnnData)
if copy:
subset = subset.copy()

Check warning on line 935 in src/scanpy/preprocessing/_simple.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/preprocessing/_simple.py#L935

Added line #L935 was not covered by tests
return subset, indices
assert isinstance(subset, AnnData)
if copy:
return subset.to_memory() if data.isbacked else subset.copy()

# in-place
if data.isbacked:
msg = "Inplace sampling (`copy=False`) is not implemented for backed objects."
raise NotImplementedError(msg)
if axis_name == "obs":
data._inplace_subset_obs(indices)
else:
X = data
return X[obs_indices], obs_indices
data._inplace_subset_var(indices)


@renamed_arg("target_counts", "counts_per_cell")
Expand Down
1 change: 1 addition & 0 deletions tests/test_package_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ class ExpectedSig(TypedDict):
copy_sigs["sc.pp.filter_cells"] = None # unclear `inplace` situation
copy_sigs["sc.pp.filter_genes"] = None # unclear `inplace` situation
copy_sigs["sc.pp.subsample"] = None # returns indices along matrix
copy_sigs["sc.pp.sample"] = None # returns indices along matrix
# partial exceptions: “data” instead of “adata”
copy_sigs["sc.pp.log1p"]["first_name"] = "data"
copy_sigs["sc.pp.normalize_per_cell"]["first_name"] = "data"
Expand Down
Loading
Loading