Skip to content

Commit

Permalink
ENH Split vst and fit_size_factors functions into fit & transform (#185)
Browse files Browse the repository at this point in the history
* split vst functions into fit & transform

* fix leakage of size factors

* handle exceptions and document specific cases and functions

* update tests

* fix docstring

* CI: Add ruff pre-commit hook (#192)

* CI: add ruff

* ci: fix ruff line length argument

* ci: update ruff configuration, remove flake8 and isort

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* CI: integrate isort options to ruff, remove nbqa

* CI: fix force-single-line typo

* CI: add ruff exceptions for sphinx gallery examples

* chore: fix linting

* chore: linting

* ci: add exception to linting for sphinx examples

* docs: remove extra period in sphinx examples

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* CI Remove precommit from workflow (#195)

* CI  remove test_docstrings (duplicate with ruff, which runs pydocstyle) (#196)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactor: minor docstring and line comment edits

* refactor: minor docstring and line comment edits

---------

Co-authored-by: SimonGrouard <[email protected]>
Co-authored-by: Boris Muzellec <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Adam Gayoso <[email protected]>
Co-authored-by: Boris MUZELLEC <[email protected]>
  • Loading branch information
6 people authored Nov 24, 2023
1 parent e64c413 commit 7ede677
Show file tree
Hide file tree
Showing 4 changed files with 261 additions and 189 deletions.
3 changes: 1 addition & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,4 @@ repos:
rev: v1.7.0
hooks:
- id: mypy
exclude: ^(tests/|docs/source/conf.py)

exclude: ^(tests/|docs/source/conf.py)
121 changes: 105 additions & 16 deletions pydeseq2/dds.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

from pydeseq2.default_inference import DefaultInference
from pydeseq2.inference import Inference
from pydeseq2.preprocessing import deseq2_norm
from pydeseq2.preprocessing import deseq2_norm_fit
from pydeseq2.preprocessing import deseq2_norm_transform
from pydeseq2.utils import build_design_matrix
from pydeseq2.utils import dispersion_trend
from pydeseq2.utils import make_scatter
Expand Down Expand Up @@ -145,6 +146,19 @@ class DeseqDataSet(ad.AnnData):
quiet : bool
Suppress deseq2 status updates during fit.
fit_type: str
Either "parametric" or "mean" for the type of fitting of dispersions to the
mean intensity. "parametric": fit a dispersion-mean relation via a robust
gamma-family GLM. "mean": use the mean of gene-wise dispersion estimates.
(default: ``"parametric"``).
logmeans: numpy.ndarray
Gene-wise mean log counts, computed in ``preprocessing.deseq2_norm_fit()``.
filtered_genes: numpy.ndarray
Genes whose log means are different from -∞, computed in
``preprocessing.deseq2_norm_fit()``.
References
----------
.. bibliography::
Expand Down Expand Up @@ -252,6 +266,8 @@ def __init__(
self.min_replicates = min_replicates
self.beta_tol = beta_tol
self.quiet = quiet
self.logmeans = None
self.filtered_genes = None

# Initialize the inference object.
self.inference = inference or DefaultInference()
Expand All @@ -261,8 +277,30 @@ def vst(
use_design: bool = False,
fit_type: Literal["parametric", "mean"] = "parametric",
) -> None:
"""Fit a variance stabilizing transformation, and apply it to normalized counts.
Results are stored in ``dds.layers["vst_counts"]``.
Parameters
----------
use_design : bool
Whether to use the full design matrix to fit dispersions and the trend curve.
If False, only an intercept is used. (default: ``False``).
fit_type: str
Either "parametric" or "mean" for the type of fitting of dispersions to the
mean intensity. "parametric": fit a dispersion-mean relation via a robust
gamma-family GLM. "mean": use the mean of gene-wise dispersion estimates.
(default: ``"parametric"``).
"""
Fit a variance stabilizing transformation, and apply it to normalized counts.
self.vst_fit(use_design=use_design, fit_type=fit_type)
self.layers["vst_counts"] = self.vst_transform()

def vst_fit(
self,
use_design: bool = False,
fit_type: Literal["parametric", "mean"] = "parametric",
) -> None:
"""Fit a variance stabilizing transformation.
The fitted parameters are kept on the dataset for use by ``vst_transform``; ``dds.layers["vst_counts"]`` is only written by ``vst``.
Expand All @@ -277,9 +315,12 @@ def vst(
gamma-family GLM. mean - use the mean of gene-wise dispersion estimates.
(default: ``"parametric"``).
"""
# Start by fitting median-of-ratio size factors, if not already present.
if "size_factors" not in self.obsm:
self.fit_size_factors()
self.fit_type = fit_type # to re-use inside vst_transform

# Start by fitting median-of-ratio size factors if not already present,
# or if they were computed iteratively
if "size_factors" not in self.obsm or self.logmeans is None:
self.fit_size_factors() # by default, fit_type != "iterative"

if use_design:
# Check that the dispersion trend curve was fitted. If not, fit it.
Expand All @@ -294,33 +335,77 @@ def vst(
)
# Fit the trend curve with an intercept design
self.fit_genewise_dispersions()
if fit_type == "parametric":
if self.fit_type == "parametric":
self.fit_dispersion_trend()

# Restore the design matrix and free buffer
self.obsm["design_matrix"] = self.obsm["design_matrix_buffer"].copy()
del self.obsm["design_matrix_buffer"]

# Apply VST
if fit_type == "parametric":
def vst_transform(self, counts: Optional[np.ndarray] = None) -> np.ndarray:
"""Apply the variance stabilizing transformation.
Uses the results from the ``vst_fit`` method.
Parameters
----------
counts : numpy.ndarray
Counts to transform. If ``None``, use the counts from the current dataset.
(default: ``None``).
Returns
-------
numpy.ndarray
Variance stabilized counts.
"""
if "size_factors" not in self.obsm:
raise RuntimeError(
"The vst_fit method should be called prior to vst_transform."
)

if counts is None:
# the transformed counts will be the current ones
normed_counts = self.layers["normed_counts"]
else:
if self.logmeans is None:
# the size factors were still computed iteratively
warnings.warn(
"The size factors were fitted iteratively. They will "
"be re-computed with the counts to be transformed. In a train/test "
"setting with a downstream task, this would result in a leak of "
"data from test to train set.",
UserWarning,
stacklevel=2,
)
logmeans, filtered_genes = deseq2_norm_fit(counts)
else:
logmeans, filtered_genes = self.logmeans, self.filtered_genes

normed_counts, _ = deseq2_norm_transform(counts, logmeans, filtered_genes)

if self.fit_type == "parametric":
a0, a1 = self.uns["trend_coeffs"]
cts = self.layers["normed_counts"]
self.layers["vst_counts"] = np.log2(
(1 + a1 + 2 * a0 * cts + 2 * np.sqrt(a0 * cts * (1 + a1 + a0 * cts)))
return np.log2(
(
1
+ a1
+ 2 * a0 * normed_counts
+ 2 * np.sqrt(a0 * normed_counts * (1 + a1 + a0 * normed_counts))
)
/ (4 * a0)
)
elif fit_type == "mean":
elif self.fit_type == "mean":
gene_dispersions = self.varm["genewise_dispersions"]
use_for_mean = gene_dispersions > 10 * self.min_disp
mean_disp = trim_mean(gene_dispersions[use_for_mean], proportiontocut=0.001)
self.layers["vst_counts"] = (
2 * np.arcsinh(np.sqrt(mean_disp * self.layers["normed_counts"]))
return (
2 * np.arcsinh(np.sqrt(mean_disp * normed_counts))
- np.log(mean_disp)
- np.log(4)
) / np.log(2)
else:
raise NotImplementedError(
f"Found fit_type '{fit_type}'. Expected 'parametric' or 'mean'."
f"Found fit_type '{self.fit_type}'. Expected 'parametric' or 'mean'."
)

def deseq2(self) -> None:
Expand Down Expand Up @@ -379,7 +464,11 @@ def fit_size_factors(
)
self._fit_iterate_size_factors()
else:
self.layers["normed_counts"], self.obsm["size_factors"] = deseq2_norm(self.X)
self.logmeans, self.filtered_genes = deseq2_norm_fit(self.X)
(
self.layers["normed_counts"],
self.obsm["size_factors"],
) = deseq2_norm_transform(self.X, self.logmeans, self.filtered_genes)
end = time.time()

if not self.quiet:
Expand Down
64 changes: 62 additions & 2 deletions pydeseq2/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
def deseq2_norm(
counts: Union[pd.DataFrame, np.ndarray]
) -> Tuple[Union[pd.DataFrame, np.ndarray], Union[pd.DataFrame, np.ndarray]]:
"""
Return normalized counts and size_factors.
"""Return normalized counts and size_factors.
Uses the median of ratios method.
Expand All @@ -27,12 +26,73 @@ def deseq2_norm(
size_factors : pandas.DataFrame or ndarray
DESeq2 normalization factors.
"""
logmeans, filtered_genes = deseq2_norm_fit(counts)
deseq2_counts, size_factors = deseq2_norm_transform(counts, logmeans, filtered_genes)
return deseq2_counts, size_factors


def deseq2_norm_fit(
counts: Union[pd.DataFrame, np.ndarray]
) -> Tuple[np.ndarray, np.ndarray]:
"""Return ``logmeans`` and ``filtered_genes``, needed in the median of ratios method.
``logmeans`` and ``filtered_genes`` can then be used to normalize external datasets.
Parameters
----------
counts : pandas.DataFrame or ndarray
Raw counts. One column per gene, one row per sample.
Returns
-------
logmeans : ndarray
Gene-wise mean log counts.
filtered_genes : ndarray
Genes whose log means are different from -∞.
"""
# Compute gene-wise mean log counts
with np.errstate(divide="ignore"): # ignore division by zero warnings
log_counts = np.log(counts)
logmeans = log_counts.mean(0)
# Filter out genes with -∞ log means
filtered_genes = ~np.isinf(logmeans)

return logmeans, filtered_genes


def deseq2_norm_transform(
counts: Union[pd.DataFrame, np.ndarray],
logmeans: np.ndarray,
filtered_genes: np.ndarray,
) -> Tuple[Union[pd.DataFrame, np.ndarray], Union[pd.DataFrame, np.ndarray]]:
"""Return normalized counts and size factors from the median of ratios method.
Can be applied to an external dataset, using the ``logmeans`` and ``filtered_genes``
previously computed by ``deseq2_norm_fit``.
Parameters
----------
counts : pandas.DataFrame or ndarray
Raw counts. One column per gene, one row per sample.
logmeans : ndarray
Gene-wise mean log counts.
filtered_genes : ndarray
Genes whose log means are different from -∞.
Returns
-------
deseq2_counts : pandas.DataFrame or ndarray
DESeq2 normalized counts.
One column per gene, rows are indexed by sample barcodes.
size_factors : pandas.DataFrame or ndarray
DESeq2 normalization factors.
"""
with np.errstate(divide="ignore"): # ignore division by zero warnings
log_counts = np.log(counts)
# Subtract filtered log means from log counts
if isinstance(log_counts, pd.DataFrame):
log_ratios = log_counts.loc[:, filtered_genes] - logmeans[filtered_genes]
Expand Down
Loading

0 comments on commit 7ede677

Please sign in to comment.