diff --git a/docs/source/api/docstrings/pydeseq2.default_inference.DefaultInference.rst b/docs/source/api/docstrings/pydeseq2.default_inference.DefaultInference.rst new file mode 100644 index 00000000..2f7e90bf --- /dev/null +++ b/docs/source/api/docstrings/pydeseq2.default_inference.DefaultInference.rst @@ -0,0 +1,19 @@ +pydeseq2.default_inference.DefaultInference +=========================================== + +.. currentmodule:: pydeseq2.default_inference + +.. autoclass:: DefaultInference + + .. rubric:: Methods + + .. autosummary:: + + ~DefaultInference.lin_reg_mu + ~DefaultInference.irls + ~DefaultInference.alpha_mle + ~DefaultInference.wald_test + ~DefaultInference.fit_rough_dispersions + ~DefaultInference.fit_moments_dispersions + ~DefaultInference.dispersion_trend_gamma_glm + ~DefaultInference.lfc_shrink_nbinom_glm diff --git a/docs/source/api/docstrings/pydeseq2.inference.Inference.rst b/docs/source/api/docstrings/pydeseq2.inference.Inference.rst new file mode 100644 index 00000000..fd04e591 --- /dev/null +++ b/docs/source/api/docstrings/pydeseq2.inference.Inference.rst @@ -0,0 +1,19 @@ +pydeseq2.inference.Inference +============================= + +.. currentmodule:: pydeseq2.inference + +.. autoclass:: Inference + + .. rubric:: Methods + + .. autosummary:: + + ~Inference.lin_reg_mu + ~Inference.irls + ~Inference.alpha_mle + ~Inference.wald_test + ~Inference.fit_rough_dispersions + ~Inference.fit_moments_dispersions + ~Inference.dispersion_trend_gamma_glm + ~Inference.lfc_shrink_nbinom_glm diff --git a/docs/source/api/index.rst b/docs/source/api/index.rst index c686873b..edbd320b 100644 --- a/docs/source/api/index.rst +++ b/docs/source/api/index.rst @@ -11,6 +11,8 @@ PyDESeq2 ~dds.DeseqDataSet ~ds.DeseqStats + ~inference.Inference + ~default_inference.DefaultInference ~utils ~grid_search ~preprocessing diff --git a/docs/source/conf.py b/docs/source/conf.py index 97ac32c8..282f7230 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -113,6 +113,9 @@ napoleon_type_aliases = { "DeseqDataSet": ":class:`DeseqDataSet `", + "Inference": ":class:`Inference `", + "DefaultInference": ":class:`DefaultInference " + "`", } # Add any paths that contain templates here, relative to this directory. diff --git a/examples/plot_minimal_pydeseq2_pipeline.py b/examples/plot_minimal_pydeseq2_pipeline.py index 28b27aab..cf1da45c 100644 --- a/examples/plot_minimal_pydeseq2_pipeline.py +++ b/examples/plot_minimal_pydeseq2_pipeline.py @@ -18,6 +18,7 @@ import pickle as pkl from pydeseq2.dds import DeseqDataSet +from pydeseq2.default_inference import DefaultInference from pydeseq2.ds import DeseqStats from pydeseq2.utils import load_example_data @@ -129,12 +130,13 @@ # log-fold change (LFC) parameters from the data, and stores them. # +inference = DefaultInference(n_cpus=8) dds = DeseqDataSet( counts=counts_df, metadata=metadata, design_factors="condition", refit_cooks=True, - n_cpus=8, + inference=inference, ) # %% @@ -217,7 +219,7 @@ # should be a *fitted* :class:`DeseqDataSet ` # object. -stat_res = DeseqStats(dds, n_cpus=8) +stat_res = DeseqStats(dds, inference=inference) # %% # It also has a set of optional keyword arguments (see the :doc:`API documentation @@ -319,7 +321,7 @@ metadata=metadata, design_factors=["group", "condition"], refit_cooks=True, - n_cpus=8, + inference=inference, ) # %% # .. note:: @@ -354,7 +356,7 @@ # ``contrast=["condition", "B", "A"]``. 
# -stat_res_B_vs_A = DeseqStats(dds, contrast=["condition", "B", "A"], n_cpus=8) +stat_res_B_vs_A = DeseqStats(dds, contrast=["condition", "B", "A"], inference=inference) # %% # .. note:: @@ -381,7 +383,7 @@ # :class:`DeseqDataSet ` # with ``contrast=["group", "Y", "X"]``, and run the analysis again. -stat_res_Y_vs_X = DeseqStats(dds, contrast=["group", "Y", "X"], n_cpus=8) +stat_res_Y_vs_X = DeseqStats(dds, contrast=["group", "Y", "X"], inference=inference) stat_res_Y_vs_X.summary() # %% diff --git a/examples/plot_pandas_io_example.py b/examples/plot_pandas_io_example.py index a26997b9..d7ca28eb 100644 --- a/examples/plot_pandas_io_example.py +++ b/examples/plot_pandas_io_example.py @@ -23,6 +23,7 @@ import pandas as pd from pydeseq2.dds import DeseqDataSet +from pydeseq2.default_inference import DefaultInference from pydeseq2.ds import DeseqStats # Replace this with the path to directory where you would like results to be saved @@ -121,12 +122,13 @@ # object from the count and metadata data that were just loaded. # +inference = DefaultInference(n_cpus=8) dds = DeseqDataSet( counts=counts_df, metadata=metadata, design_factors="condition", refit_cooks=True, - n_cpus=8, + inference=inference, ) # %% @@ -195,7 +197,7 @@ # compute p-values and adjusted p-values for differential expresion. This is the role of # the :class:`DeseqStats ` class. -stat_res = DeseqStats(dds, n_cpus=8) +stat_res = DeseqStats(dds, inference=inference) # %% # PyDESeq2 computes p-values using Wald tests. This can be done using the diff --git a/examples/plot_step_by_step.py b/examples/plot_step_by_step.py index 0bcadd14..cc98aee4 100644 --- a/examples/plot_step_by_step.py +++ b/examples/plot_step_by_step.py @@ -22,6 +22,7 @@ import pickle as pkl from pydeseq2.dds import DeseqDataSet +from pydeseq2.default_inference import DefaultInference from pydeseq2.ds import DeseqStats from pydeseq2.utils import load_example_data @@ -87,13 +88,14 @@ # in the case of the provided synthetic data, there won't be any Cooks # outliers. 
+inference = DefaultInference(n_cpus=8) dds = DeseqDataSet( counts=counts_df, metadata=metadata, design_factors="condition", # compare samples based on the "condition" # column ("B" vs "A") refit_cooks=True, - n_cpus=8, + inference=inference, ) # %% diff --git a/pydeseq2/dds.py b/pydeseq2/dds.py index e4e5b05f..11a0ab66 100644 --- a/pydeseq2/dds.py +++ b/pydeseq2/dds.py @@ -10,25 +10,16 @@ import anndata as ad # type: ignore import numpy as np import pandas as pd -import statsmodels.api as sm # type: ignore -from joblib import Parallel # type: ignore -from joblib import delayed -from joblib import parallel_backend from scipy.optimize import minimize from scipy.special import polygamma # type: ignore from scipy.stats import f # type: ignore from scipy.stats import trim_mean # type: ignore -from statsmodels.tools.sm_exceptions import DomainWarning # type: ignore +from pydeseq2.default_inference import DefaultInference +from pydeseq2.inference import Inference from pydeseq2.preprocessing import deseq2_norm from pydeseq2.utils import build_design_matrix from pydeseq2.utils import dispersion_trend -from pydeseq2.utils import fit_alpha_mle -from pydeseq2.utils import fit_lin_mu -from pydeseq2.utils import fit_moments_dispersions -from pydeseq2.utils import fit_rough_dispersions -from pydeseq2.utils import get_num_processes -from pydeseq2.utils import irls_solver from pydeseq2.utils import make_scatter from pydeseq2.utils import mean_absolute_deviation from pydeseq2.utils import nb_nll @@ -37,8 +28,6 @@ from pydeseq2.utils import test_valid_counts from pydeseq2.utils import trimmed_mean -# Ignore DomainWarning raised by statsmodels when fitting a Gamma GLM with identity link. -warnings.simplefilter("ignore", DomainWarning) # Ignore AnnData's FutureWarning about implicit data conversion. warnings.simplefilter("ignore", FutureWarning) @@ -103,16 +92,10 @@ class DeseqDataSet(ad.AnnData): .. math:: \vert dev_t - dev_{t+1}\vert / (\vert dev \vert + 0.1) < \beta_{tol}. - n_cpus : int - Number of cpus to use. If None, all available cpus will be used. - (default: ``None``). - - batch_size : int - Number of tasks to allocate to each joblib parallel worker. (default: ``128``). - - joblib_verbosity : int - The verbosity level for joblib tasks. The higher the value, the more updates - are reported. (default: ``0``). + inference : Inference + Implementation of inference routines object instance. + (default: + :class:`DefaultInference `). quiet : bool Suppress deseq2 status updates during fit. @@ -183,9 +166,7 @@ def __init__( refit_cooks: bool = True, min_replicates: int = 7, beta_tol: float = 1e-8, - n_cpus: Optional[int] = None, - batch_size: int = 128, - joblib_verbosity: int = 0, + inference: Optional[Inference] = None, quiet: bool = False, ) -> None: # Initialize the AnnData part @@ -270,11 +251,11 @@ def __init__( self.ref_level = ref_level self.min_replicates = min_replicates self.beta_tol = beta_tol - self.n_processes = get_num_processes(n_cpus) - self.batch_size = batch_size - self.joblib_verbosity = joblib_verbosity self.quiet = quiet + # Initialize the inference object. 
+ self.inference = inference or DefaultInference() + def vst( self, use_design: bool = False, @@ -435,73 +416,41 @@ def fit_genewise_dispersions(self) -> None: len(self.obsm["design_matrix"].value_counts()) == self.obsm["design_matrix"].shape[-1] ): - with parallel_backend("loky", inner_max_num_threads=1): - mu_hat_ = np.array( - Parallel( - n_jobs=self.n_processes, - verbose=self.joblib_verbosity, - batch_size=self.batch_size, - )( - delayed(fit_lin_mu)( - counts=self.X[:, i], - size_factors=self.obsm["size_factors"], - design_matrix=design_matrix, - min_mu=self.min_mu, - ) - for i in self.non_zero_idx - ) - ) + mu_hat_ = self.inference.lin_reg_mu( + counts=self.X[:, self.non_zero_idx], + size_factors=self.obsm["size_factors"], + design_matrix=design_matrix, + min_mu=self.min_mu, + ) else: - with parallel_backend("loky", inner_max_num_threads=1): - res = Parallel( - n_jobs=self.n_processes, - verbose=self.joblib_verbosity, - batch_size=self.batch_size, - )( - delayed(irls_solver)( - counts=self.X[:, i], - size_factors=self.obsm["size_factors"], - design_matrix=design_matrix, - disp=self.varm["_MoM_dispersions"][i], - min_mu=self.min_mu, - beta_tol=self.beta_tol, - ) - for i in self.non_zero_idx - ) - - _, mu_hat_, _, _ = zip(*res) - mu_hat_ = np.array(mu_hat_) + _, mu_hat_, _, _ = self.inference.irls( + counts=self.X[:, self.non_zero_idx], + size_factors=self.obsm["size_factors"], + design_matrix=design_matrix, + disp=self.varm["_MoM_dispersions"][self.non_zero_idx], + min_mu=self.min_mu, + beta_tol=self.beta_tol, + ) self.layers["_mu_hat"] = np.full((self.n_obs, self.n_vars), np.NaN) - self.layers["_mu_hat"][:, self.varm["non_zero"]] = mu_hat_.T + self.layers["_mu_hat"][:, self.varm["non_zero"]] = mu_hat_ if not self.quiet: print("Fitting dispersions...", file=sys.stderr) start = time.time() - with parallel_backend("loky", inner_max_num_threads=1): - res = Parallel( - n_jobs=self.n_processes, - verbose=self.joblib_verbosity, - batch_size=self.batch_size, - )( - delayed(fit_alpha_mle)( - counts=self.X[:, i], - design_matrix=design_matrix, - mu=self.layers["_mu_hat"][:, i], - alpha_hat=self.varm["_MoM_dispersions"][i], - min_disp=self.min_disp, - max_disp=self.max_disp, - ) - # for i in range(num_genes) - for i in self.non_zero_idx - ) + dispersions_, l_bfgs_b_converged_ = self.inference.alpha_mle( + counts=self.X[:, self.non_zero_idx], + design_matrix=design_matrix, + mu=self.layers["_mu_hat"][:, self.non_zero_idx], + alpha_hat=self.varm["_MoM_dispersions"][self.non_zero_idx], + min_disp=self.min_disp, + max_disp=self.max_disp, + ) end = time.time() if not self.quiet: print(f"... 
done in {end - start:.2f} seconds.\n", file=sys.stderr) - dispersions_, l_bfgs_b_converged_ = zip(*res) - self.varm["genewise_dispersions"] = np.full(self.n_vars, np.NaN) self.varm["genewise_dispersions"][self.varm["non_zero"]] = np.clip( dispersions_, self.min_disp, self.max_disp @@ -529,11 +478,9 @@ def fit_dispersion_trend(self) -> None: self[:, self.non_zero_genes].varm["genewise_dispersions"].copy(), index=self.non_zero_genes, ) - covariates = sm.add_constant( - pd.Series( - 1 / self[:, self.non_zero_genes].varm["_normed_means"], - index=self.non_zero_genes, - ) + covariates = pd.Series( + 1 / self[:, self.non_zero_genes].varm["_normed_means"], + index=self.non_zero_genes, ) for gene in self.non_zero_genes: @@ -549,18 +496,11 @@ def fit_dispersion_trend(self) -> None: coeffs = pd.Series([1.0, 1.0]) while (np.log(np.abs(coeffs / old_coeffs)) ** 2).sum() >= 1e-6: - glm_gamma = sm.GLM( - targets.values, - covariates.values, - family=sm.families.Gamma(link=sm.families.links.identity()), + old_coeffs = coeffs + coeffs, predictions = self.inference.dispersion_trend_gamma_glm( + covariates, targets ) - - res = glm_gamma.fit() - old_coeffs = coeffs.copy() - coeffs = res.params - # Filter out genes that are too far away from the curve before refitting - predictions = covariates.values @ coeffs pred_ratios = ( self[:, covariates.index].varm["genewise_dispersions"] / predictions ) @@ -648,32 +588,22 @@ def fit_MAP_dispersions(self) -> None: if not self.quiet: print("Fitting MAP dispersions...", file=sys.stderr) start = time.time() - with parallel_backend("loky", inner_max_num_threads=1): - res = Parallel( - n_jobs=self.n_processes, - verbose=self.joblib_verbosity, - batch_size=self.batch_size, - )( - delayed(fit_alpha_mle)( - counts=self.X[:, i], - design_matrix=design_matrix, - mu=self.layers["_mu_hat"][:, i], - alpha_hat=self.varm["fitted_dispersions"][i], - min_disp=self.min_disp, - max_disp=self.max_disp, - prior_disp_var=self.uns["prior_disp_var"].item(), - cr_reg=True, - prior_reg=True, - ) - for i in self.non_zero_idx - ) + dispersions_, l_bfgs_b_converged_ = self.inference.alpha_mle( + counts=self.X[:, self.non_zero_idx], + design_matrix=design_matrix, + mu=self.layers["_mu_hat"][:, self.non_zero_idx], + alpha_hat=self.varm["fitted_dispersions"][self.non_zero_idx], + min_disp=self.min_disp, + max_disp=self.max_disp, + prior_disp_var=self.uns["prior_disp_var"].item(), + cr_reg=True, + prior_reg=True, + ) end = time.time() if not self.quiet: print(f"... 
done in {end-start:.2f} seconds.\n", file=sys.stderr) - dispersions_, l_bfgs_b_converged_ = zip(*res) - self.varm["MAP_dispersions"] = np.full(self.n_vars, np.NaN) self.varm["MAP_dispersions"][self.varm["non_zero"]] = np.clip( dispersions_, self.min_disp, self.max_disp @@ -707,31 +637,19 @@ def fit_LFC(self) -> None: if not self.quiet: print("Fitting LFCs...", file=sys.stderr) start = time.time() - with parallel_backend("loky", inner_max_num_threads=1): - res = Parallel( - n_jobs=self.n_processes, - verbose=self.joblib_verbosity, - batch_size=self.batch_size, - )( - delayed(irls_solver)( - counts=self.X[:, i], - size_factors=self.obsm["size_factors"], - design_matrix=design_matrix, - disp=self.varm["dispersions"][i], - min_mu=self.min_mu, - beta_tol=self.beta_tol, - ) - for i in self.non_zero_idx - ) + mle_lfcs_, mu_, hat_diagonals_, converged_ = self.inference.irls( + counts=self.X[:, self.non_zero_idx], + size_factors=self.obsm["size_factors"], + design_matrix=design_matrix, + disp=self.varm["dispersions"][self.non_zero_idx], + min_mu=self.min_mu, + beta_tol=self.beta_tol, + ) end = time.time() if not self.quiet: print(f"... done in {end-start:.2f} seconds.\n", file=sys.stderr) - MLE_lfcs_, mu_, hat_diagonals_, converged_ = zip(*res) - mu_ = np.array(mu_).T - hat_diagonals_ = np.array(hat_diagonals_).T - self.varm["LFC"] = pd.DataFrame( np.NaN, index=self.var_names, @@ -740,7 +658,7 @@ def fit_LFC(self) -> None: self.varm["LFC"].update( pd.DataFrame( - MLE_lfcs_, + mle_lfcs_, index=self.non_zero_genes, columns=self.obsm["design_matrix"].columns, ) @@ -819,12 +737,13 @@ def _fit_MoM_dispersions(self) -> None: if "normed_counts" not in self.layers: self.fit_size_factors() - rde = fit_rough_dispersions( - self.layers["normed_counts"], - self.obsm["design_matrix"], + normed_counts = self.layers["normed_counts"][:, self.non_zero_idx] + rde = self.inference.fit_rough_dispersions( + normed_counts, + self.obsm["design_matrix"].values, ) - mde = fit_moments_dispersions( - self.layers["normed_counts"], self.obsm["size_factors"] + mde = self.inference.fit_moments_dispersions( + normed_counts, self.obsm["size_factors"] ) alpha_hat = np.minimum(rde, mde) @@ -970,8 +889,7 @@ def _refit_without_outliers( refit_cooks=self.refit_cooks, min_replicates=self.min_replicates, beta_tol=self.beta_tol, - n_cpus=self.n_processes, - batch_size=self.batch_size, + inference=self.inference, ) # Use the same size factors diff --git a/pydeseq2/default_inference.py b/pydeseq2/default_inference.py new file mode 100644 index 00000000..19cefb83 --- /dev/null +++ b/pydeseq2/default_inference.py @@ -0,0 +1,246 @@ +import warnings +from typing import Literal +from typing import Optional +from typing import Tuple + +import numpy as np +import pandas as pd +import statsmodels.api as sm # type: ignore +from joblib import Parallel # type: ignore +from joblib import delayed +from joblib import parallel_backend +from statsmodels.tools.sm_exceptions import DomainWarning # type: ignore + +from pydeseq2 import inference +from pydeseq2 import utils + +# Ignore DomainWarning raised by statsmodels when fitting a Gamma GLM with identity link. +warnings.simplefilter("ignore", DomainWarning) + + +class DefaultInference(inference.Inference): + """Default DESeq2-related inference methods, using scipy/sklearn/numpy. + + This object contains the interface to the default inference routines and uses + joblib internally for parallelization. Inherit this class or its parent to write + custom inference routines. 
+ + Parameters + ---------- + joblib_verbosity : int + The verbosity level for joblib tasks. The higher the value, the more updates + are reported. (default: ``0``). + batch_size : int + Number of tasks to allocate to each joblib parallel worker. (default: ``128``). + n_cpus : int + Number of cpus to use. If None, all available cpus will be used. + (default: ``None``). + backend : str + Joblib backend. + """ + + fit_rough_dispersions = staticmethod(utils.fit_rough_dispersions) # type: ignore + fit_moments_dispersions = staticmethod(utils.fit_moments_dispersions) # type: ignore + + def __init__( + self, + joblib_verbosity: int = 0, + batch_size: int = 128, + n_cpus: Optional[int] = None, + backend: str = "loky", + ): + self._joblib_verbosity = joblib_verbosity + self._batch_size = batch_size + self._n_processes = utils.get_num_processes(n_cpus) + self._backend = backend + + def lin_reg_mu( # noqa: D102 + self, + counts: np.ndarray, + size_factors: np.ndarray, + design_matrix: np.ndarray, + min_mu: float, + ) -> np.ndarray: + with parallel_backend(self._backend, inner_max_num_threads=1): + mu_hat_ = np.array( + Parallel( + n_jobs=self._n_processes, + verbose=self._joblib_verbosity, + batch_size=self._batch_size, + )( + delayed(utils.fit_lin_mu)( + counts=counts[:, i], + size_factors=size_factors, + design_matrix=design_matrix, + min_mu=min_mu, + ) + for i in range(counts.shape[1]) + ) + ) + return mu_hat_.T + + def irls( # noqa: D102 + self, + counts: np.ndarray, + size_factors: np.ndarray, + design_matrix: np.ndarray, + disp: np.ndarray, + min_mu: float, + beta_tol: float, + min_beta: float = -30, + max_beta: float = 30, + optimizer: Literal["BFGS", "L-BFGS-B"] = "L-BFGS-B", + maxiter: int = 250, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + with parallel_backend(self._backend, inner_max_num_threads=1): + res = Parallel( + n_jobs=self._n_processes, + verbose=self._joblib_verbosity, + batch_size=self._batch_size, + )( + delayed(utils.irls_solver)( + counts=counts[:, i], + size_factors=size_factors, + design_matrix=design_matrix, + disp=disp[i], + min_mu=min_mu, + beta_tol=beta_tol, + min_beta=min_beta, + max_beta=max_beta, + optimizer=optimizer, + maxiter=maxiter, + ) + for i in range(counts.shape[1]) + ) + res = zip(*res) + MLE_lfcs_, mu_hat_, hat_diagonals_, converged_ = (np.array(m) for m in res) + + return ( + MLE_lfcs_, + mu_hat_.T, + hat_diagonals_.T, + converged_, + ) + + def alpha_mle( # noqa: D102 + self, + counts: np.ndarray, + design_matrix: np.ndarray, + mu: np.ndarray, + alpha_hat: np.ndarray, + min_disp: float, + max_disp: float, + prior_disp_var: Optional[float] = None, + cr_reg: bool = True, + prior_reg: bool = False, + optimizer: Literal["BFGS", "L-BFGS-B"] = "L-BFGS-B", + ) -> Tuple[np.ndarray, np.ndarray]: + with parallel_backend(self._backend, inner_max_num_threads=1): + res = Parallel( + n_jobs=self._n_processes, + verbose=self._joblib_verbosity, + batch_size=self._batch_size, + )( + delayed(utils.fit_alpha_mle)( + counts=counts[:, i], + design_matrix=design_matrix, + mu=mu[:, i], + alpha_hat=alpha_hat[i], + min_disp=min_disp, + max_disp=max_disp, + prior_disp_var=prior_disp_var, + cr_reg=cr_reg, + prior_reg=prior_reg, + optimizer=optimizer, + ) + for i in range(counts.shape[1]) + ) + res = zip(*res) + dispersions_, l_bfgs_b_converged_ = (np.array(m) for m in res) + return dispersions_, l_bfgs_b_converged_ + + def wald_test( # noqa: D102 + self, + design_matrix: np.ndarray, + disp: np.ndarray, + lfc: np.ndarray, + mu: np.ndarray, + ridge_factor: 
np.ndarray, + contrast: np.ndarray, + lfc_null: np.ndarray, + alt_hypothesis: Optional[ + Literal["greaterAbs", "lessAbs", "greater", "less"] + ] = None, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + num_genes = mu.shape[1] + with parallel_backend("loky", inner_max_num_threads=1): + res = Parallel( + n_jobs=self._n_processes, + verbose=self._joblib_verbosity, + batch_size=self._batch_size, + )( + delayed(utils.wald_test)( + design_matrix=design_matrix, + disp=disp[i], + lfc=lfc[i], + mu=mu[:, i], + ridge_factor=ridge_factor, + contrast=contrast, + lfc_null=lfc_null, # Convert log2 to natural log + alt_hypothesis=alt_hypothesis, + ) + for i in range(num_genes) + ) + res = zip(*res) + pvals, stats, se = (np.array(m) for m in res) + + return pvals, stats, se + + def dispersion_trend_gamma_glm( # noqa: D102 + self, covariates: pd.Series, targets: pd.Series + ) -> Tuple[np.ndarray, np.ndarray]: + covariates_w_intercept = sm.add_constant(covariates) + targets_fit = targets.values + covariates_fit = covariates_w_intercept.values + glm_gamma = sm.GLM( + targets_fit, + covariates_fit, + family=sm.families.Gamma(link=sm.families.links.identity()), + ) + res = glm_gamma.fit() + coeffs = res.params + return (coeffs, covariates_fit @ coeffs) + + def lfc_shrink_nbinom_glm( # noqa: D102 + self, + design_matrix: np.ndarray, + counts: np.ndarray, + size: np.ndarray, + offset: np.ndarray, + prior_no_shrink_scale: float, + prior_scale: float, + optimizer: str, + shrink_index: int, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + with parallel_backend(self._backend, inner_max_num_threads=1): + num_genes = counts.shape[1] + res = Parallel( + n_jobs=self._n_processes, + verbose=self._joblib_verbosity, + batch_size=self._batch_size, + )( + delayed(utils.nbinomGLM)( + design_matrix=design_matrix, + counts=counts[:, i], + size=size[i], + offset=offset, + prior_no_shrink_scale=prior_no_shrink_scale, + prior_scale=prior_scale, + optimizer=optimizer, + shrink_index=shrink_index, + ) + for i in range(num_genes) + ) + res = zip(*res) + lfcs, inv_hessians, l_bfgs_b_converged_ = (np.array(m) for m in res) + return lfcs, inv_hessians, l_bfgs_b_converged_ diff --git a/pydeseq2/ds.py b/pydeseq2/ds.py index ee3b07a3..ac197d73 100644 --- a/pydeseq2/ds.py +++ b/pydeseq2/ds.py @@ -8,18 +8,14 @@ import numpy as np import pandas as pd import statsmodels.api as sm # type: ignore -from joblib import Parallel # type: ignore -from joblib import delayed # type: ignore -from joblib import parallel_backend # type: ignore from scipy.optimize import root_scalar # type: ignore from scipy.stats import f # type: ignore from statsmodels.stats.multitest import multipletests # type: ignore from pydeseq2.dds import DeseqDataSet -from pydeseq2.utils import get_num_processes +from pydeseq2.default_inference import DefaultInference +from pydeseq2.inference import Inference from pydeseq2.utils import make_MA_plot -from pydeseq2.utils import nbinomGLM -from pydeseq2.utils import wald_test class DeseqStats: @@ -58,10 +54,6 @@ class DeseqStats: Whether to perform independent filtering to correct p-value trends. (default: ``True``). - n_cpus : int - Number of cpus to use for multiprocessing. - If None, all available CPUs will be used. (default: ``None``). - prior_LFC_var : ndarray Prior variance for LFCs, used for ridge regularization. (default: ``None``). @@ -76,12 +68,10 @@ class DeseqStats: The alternative hypothesis corresponds to what the user wants to find rather than the null hypothesis. (default: ``None``). 
- batch_size : int - Number of tasks to allocate to each joblib parallel worker. (default: ``128``). - - joblib_verbosity : int - The verbosity level for joblib tasks. The higher the value, the more updates - are reported. (default: ``0``). + inference : Inference + Implementation of inference routines object instance. + (default: + :class:`DefaultInference `). quiet : bool Suppress deseq2 status updates during fit. @@ -149,14 +139,12 @@ def __init__( alpha: float = 0.05, cooks_filter: bool = True, independent_filter: bool = True, - n_cpus: Optional[int] = None, prior_LFC_var: Optional[np.ndarray] = None, lfc_null: float = 0.0, alt_hypothesis: Optional[ Literal["greaterAbs", "lessAbs", "greater", "less"] ] = None, - batch_size: int = 128, - joblib_verbosity: int = 0, + inference: Optional[Inference] = None, quiet: bool = False, ) -> None: assert ( @@ -192,11 +180,11 @@ def __init__( # Set a flag to indicate that LFCs are unshrunk self.shrunk_LFCs = False - self.n_processes = get_num_processes(n_cpus) - self.batch_size = batch_size - self.joblib_verbosity = joblib_verbosity self.quiet = quiet + # Initialize the inference object. + self.inference = inference or DefaultInference() + # If the `refit_cooks` attribute of the dds object is True, check that outliers # were actually refitted. if self.dds.refit_cooks and "replaced" not in self.dds.varm: @@ -287,7 +275,6 @@ def run_wald_test(self) -> None: Get gene-wise p-values for gene over/under-expression.` """ - num_genes = self.dds.n_vars num_vars = self.design_matrix.shape[1] # Raise a warning if LFCs are shrunk. @@ -318,30 +305,20 @@ def run_wald_test(self) -> None: if not self.quiet: print("Running Wald tests...", file=sys.stderr) start = time.time() - with parallel_backend("loky", inner_max_num_threads=1): - res = Parallel( - n_jobs=self.n_processes, - verbose=self.joblib_verbosity, - batch_size=self.batch_size, - )( - delayed(wald_test)( - design_matrix=design_matrix, - disp=self.dds.varm["dispersions"][i], - lfc=LFCs[i], - mu=mu[:, i], - ridge_factor=ridge_factor, - contrast=self.contrast_vector, - lfc_null=np.log(2) * self.lfc_null, # Convert log2 to natural log - alt_hypothesis=self.alt_hypothesis, - ) - for i in range(num_genes) - ) + pvals, stats, se = self.inference.wald_test( + design_matrix=design_matrix, + disp=self.dds.varm["dispersions"], + lfc=LFCs, + mu=mu, + ridge_factor=ridge_factor, + contrast=self.contrast_vector, + lfc_null=np.log(2) * self.lfc_null, # Convert log2 to natural log + alt_hypothesis=self.alt_hypothesis, + ) end = time.time() if not self.quiet: print(f"... 
done in {end-start:.2f} seconds.\n", file=sys.stderr) - pvals, stats, se = zip(*res) - self.p_values: pd.Series = pd.Series(pvals, index=self.dds.var_names) self.statistics: pd.Series = pd.Series(stats, index=self.dds.var_names) self.SE: pd.Series = pd.Series(se, index=self.dds.var_names) @@ -420,30 +397,20 @@ def lfc_shrink(self, coeff: Optional[str] = None) -> None: if not self.quiet: print("Fitting MAP LFCs...", file=sys.stderr) start = time.time() - with parallel_backend("loky", inner_max_num_threads=1): - res = Parallel( - n_jobs=self.n_processes, - verbose=self.joblib_verbosity, - batch_size=self.batch_size, - )( - delayed(nbinomGLM)( - design_matrix=design_matrix, - counts=self.dds.X[:, i], - size=size[i], - offset=offset, - prior_no_shrink_scale=prior_no_shrink_scale, - prior_scale=prior_scale, - optimizer="L-BFGS-B", - shrink_index=coeff_idx, - ) - for i in self.dds.non_zero_idx - ) + lfcs, inv_hessians, l_bfgs_b_converged_ = self.inference.lfc_shrink_nbinom_glm( + design_matrix=design_matrix, + counts=self.dds.X[:, self.dds.non_zero_idx], + size=size[self.dds.non_zero_idx], + offset=offset, + prior_no_shrink_scale=prior_no_shrink_scale, + prior_scale=prior_scale, + optimizer="L-BFGS-B", + shrink_index=coeff_idx, + ) end = time.time() if not self.quiet: print(f"... done in {end-start:.2f} seconds.\n", file=sys.stderr) - lfcs, inv_hessians, l_bfgs_b_converged_ = zip(*res) - self.LFC.iloc[:, coeff_idx].update( pd.Series( np.array(lfcs)[:, coeff_idx], diff --git a/pydeseq2/inference.py b/pydeseq2/inference.py new file mode 100644 index 00000000..e4e50a66 --- /dev/null +++ b/pydeseq2/inference.py @@ -0,0 +1,362 @@ +from abc import ABC +from abc import abstractmethod +from typing import Literal +from typing import Optional +from typing import Tuple + +import numpy as np +import pandas as pd + + +class Inference(ABC): + """Abstract class with DESeq2-related inference methods.""" + + @abstractmethod + def lin_reg_mu( + self, + counts: np.ndarray, + size_factors: np.ndarray, + design_matrix: np.ndarray, + min_mu: float, + ) -> np.ndarray: + """Estimate mean of negative binomial model using a linear regression. + + Used to initialize genewise dispersion models. + + Parameters + ---------- + counts : ndarray + Raw counts. + + size_factors : ndarray + Sample-wise scaling factors (obtained from median-of-ratios). + + design_matrix : ndarray + Design matrix. + + min_mu : float + Lower threshold for fitted means, for numerical stability. + (default: ``0.5``). + + Returns + ------- + ndarray + Estimated mean. + """ + + @abstractmethod + def irls( + self, + counts: np.ndarray, + size_factors: np.ndarray, + design_matrix: np.ndarray, + disp: np.ndarray, + min_mu: float, + beta_tol: float, + min_beta: float = -30, + max_beta: float = 30, + optimizer: Literal["BFGS", "L-BFGS-B"] = "L-BFGS-B", + maxiter: int = 250, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + r"""Fit a NB GLM wit log-link to predict counts from the design matrix. + + See equations (1-2) in the DESeq2 paper. + + Parameters + ---------- + counts : ndarray + Raw counts. + + size_factors : ndarray + Sample-wise scaling factors (obtained from median-of-ratios). + + design_matrix : ndarray + Design matrix. + + disp : ndarray + Gene-wise dispersion prior. + + min_mu : ndarray + Lower bound on estimated means, to ensure numerical stability. + (default: ``0.5``). + + beta_tol : float + Stopping criterion for IRWLS: + :math:`\vert dev - dev_{old}\vert / \vert dev + 0.1 \vert < \beta_{tol}`. + (default: ``1e-8``). 
+ + min_beta : float + Lower-bound on LFC. (default: ``-30``). + + max_beta : float + Upper-bound on LFC. (default: ``-30``). + + optimizer : str + Optimizing method to use in case IRLS starts diverging. + Accepted values: 'BFGS' or 'L-BFGS-B'. + NB: only 'L-BFGS-B' ensures that LFCS will + lay in the [min_beta, max_beta] range. (default: ``'L-BFGS-B'``). + + maxiter : int + Maximum number of IRLS iterations to perform before switching to L-BFGS-B. + (default: ``250``). + + Returns + ------- + beta: ndarray + Fitted (basemean, lfc) coefficients of negative binomial GLM. + + mu: ndarray + Means estimated from size factors and beta: + :math:`\mu = s_{ij} \exp(\beta^t X)`. + + H: ndarray + Diagonal of the :math:`W^{1/2} X (X^t W X)^-1 X^t W^{1/2}` + covariance matrix. + + converged: ndarray + Whether IRLS or the optimizer converged. If not and if dimension allows it, + perform grid search. + """ + + @abstractmethod + def alpha_mle( + self, + counts: np.ndarray, + design_matrix: np.ndarray, + mu: np.ndarray, + alpha_hat: np.ndarray, + min_disp: float, + max_disp: float, + prior_disp_var: Optional[float] = None, + cr_reg: bool = True, + prior_reg: bool = False, + optimizer: Literal["BFGS", "L-BFGS-B"] = "L-BFGS-B", + ) -> Tuple[np.ndarray, np.ndarray]: + """Estimate the dispersion parameter of a negative binomial GLM. + + Parameters + ---------- + counts : ndarray + Raw counts. + + design_matrix : ndarray + Design matrix. + + mu : ndarray + Mean estimation for the NB model. + + alpha_hat : ndarray + Initial dispersion estimate. + + min_disp : float + Lower threshold for dispersion parameters. + + max_disp : float + Upper threshold for dispersion parameters. + + prior_disp_var : float + Prior dispersion variance. + + cr_reg : bool + Whether to use Cox-Reid regularization. (default: ``True``). + + prior_reg : bool + Whether to use prior log-residual regularization. (default: ``False``). + + optimizer : str + Optimizing method to use. Accepted values: 'BFGS' or 'L-BFGS-B'. + (default: ``'L-BFGS-B'``). + + Returns + ------- + ndarray + Dispersion estimate. + + ndarray + Whether L-BFGS-B converged. If not, dispersion is estimated + using grid search. + """ + + @abstractmethod + def wald_test( + self, + design_matrix: np.ndarray, + disp: np.ndarray, + lfc: np.ndarray, + mu: np.ndarray, + ridge_factor: np.ndarray, + contrast: np.ndarray, + lfc_null: np.ndarray, + alt_hypothesis: Optional[ + Literal["greaterAbs", "lessAbs", "greater", "less"] + ] = None, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Run Wald test for differential expression. + + Computes Wald statistics, standard error and p-values from + dispersion and LFC estimates. + + Parameters + ---------- + design_matrix : ndarray + Design matrix. + + disp : float + Dispersion estimate. + + lfc : ndarray + Log-fold change estimate (in natural log scale). + + mu : float + Mean estimation for the NB model. + + ridge_factor : ndarray + Regularization factors. + + contrast : ndarray + Vector encoding the contrast that is being tested. + + lfc_null : float + The (log2) log fold change under the null hypothesis. + + alt_hypothesis : str or None + The alternative hypothesis for computing wald p-values. + + Returns + ------- + wald_p_value : ndarray + Estimated p-value. + + wald_statistic : ndarray + Wald statistic. + + wald_se : ndarray + Standard error of the Wald statistic. 
+ """ + + @abstractmethod + def fit_rough_dispersions( + self, normed_counts: np.ndarray, design_matrix: np.ndarray + ) -> np.ndarray: + """'Rough dispersion' estimates from linear model, as per the R code. + + Used as initial estimates in :meth:`DeseqDataSet.fit_genewise_dispersions() + `. + + Parameters + ---------- + normed_counts : ndarray + Array of deseq2-normalized read counts. Rows: samples, columns: genes. + + design_matrix : pandas.DataFrame + A DataFrame with experiment design information (to split cohorts). + Indexed by sample barcodes. Unexpanded, *with* intercept. + + Returns + ------- + ndarray + Estimated dispersion parameter for each gene. + """ + + @abstractmethod + def fit_moments_dispersions( + self, normed_counts: np.ndarray, size_factors: np.ndarray + ) -> np.ndarray: + """Dispersion estimates based on moments, as per the R code. + + Used as initial estimates in :meth:`DeseqDataSet.fit_genewise_dispersions() + `. + + Parameters + ---------- + normed_counts : ndarray + Array of deseq2-normalized read counts. Rows: samples, columns: genes. + + size_factors : ndarray + DESeq2 normalization factors. + + Returns + ------- + ndarray + Estimated dispersion parameter for each gene. + """ + + @abstractmethod + def dispersion_trend_gamma_glm( + self, covariates: pd.Series, targets: pd.Series + ) -> Tuple[np.ndarray, np.ndarray]: + """Fit a gamma glm on gene dispersions. + + The intercept should be concatenated in this method + and the first returned coefficient should be the intercept. + + Parameters + ---------- + covariates : pd.Series + Covariates for the regression (num_genes,). + targets : pd.Series + Targets for the regression (num_genes,). + + Returns + ------- + coeffs : ndarray + Coefficients of the regression. + predictions : ndarray + Predictions of the regression. + """ + + @abstractmethod + def lfc_shrink_nbinom_glm( + self, + design_matrix: np.ndarray, + counts: np.ndarray, + size: np.ndarray, + offset: np.ndarray, + prior_no_shrink_scale: float, + prior_scale: float, + optimizer: str, + shrink_index: int, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Fit a negative binomial MAP LFC using an apeGLM prior. + + Only the LFC is shrinked, and not the intercept. + + Parameters + ---------- + design_matrix : ndarray + Design matrix. + + counts : ndarray + Raw counts. + + size : ndarray + Size parameter of NB family (inverse of dispersion). + + offset : ndarray + Natural logarithm of size factor. + + prior_no_shrink_scale : float + Prior variance for the intercept. + + prior_scale : float + Prior variance for the LFC parameter. + + optimizer : str + Optimizing method to use in case IRLS starts diverging. + Accepted values: 'L-BFGS-B', 'BFGS' or 'Newton-CG'. + + shrink_index : int + Index of the LFC coordinate to shrink. (default: ``1``). + + Returns + ------- + beta: ndarray + 2-element array, containing the intercept (first) and the LFC (second). + + inv_hessian: ndarray + Inverse of the Hessian of the objective at the estimated MAP LFC. + + converged: ndarray + Whether L-BFGS-B converged for each optimization problem. + """ diff --git a/pydeseq2/utils.py b/pydeseq2/utils.py index 4ddb5fec..89460450 100644 --- a/pydeseq2/utils.py +++ b/pydeseq2/utils.py @@ -999,9 +999,6 @@ def fit_rough_dispersions( "dispersion. Please use a design with fewer variables." 
         )
 
-    # Exclude genes with all zeroes
-    normed_counts = normed_counts[:, ~(normed_counts == 0).all(axis=0)]
-
     reg = LinearRegression(fit_intercept=False)
     reg.fit(design_matrix, normed_counts)
     y_hat = reg.predict(design_matrix)
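
For quick reference, a minimal usage sketch of the refactored API introduced by this patch: the parallelization options that previously lived on ``DeseqDataSet`` and ``DeseqStats`` (``n_cpus``, ``batch_size``, ``joblib_verbosity``) are now configured on a ``DefaultInference`` instance, which both classes accept through the new ``inference`` argument. The ``my_counts_df`` / ``my_metadata_df`` frames below are made-up placeholders (not data shipped with the package), and the snippet assumes the usual ``dds.deseq2()`` / ``stat_res.summary()`` workflow from the existing examples.

import numpy as np
import pandas as pd

from pydeseq2.dds import DeseqDataSet
from pydeseq2.default_inference import DefaultInference
from pydeseq2.ds import DeseqStats

# Made-up placeholder data: a (samples x genes) integer count matrix and
# matching metadata with a two-level "condition" factor.
rng = np.random.default_rng(42)
my_counts_df = pd.DataFrame(
    rng.poisson(10, size=(6, 20)),
    index=[f"sample{i}" for i in range(6)],
    columns=[f"gene{i}" for i in range(20)],
)
my_metadata_df = pd.DataFrame(
    {"condition": ["A", "A", "A", "B", "B", "B"]}, index=my_counts_df.index
)

# Parallelization is now configured once, on the inference object...
inference = DefaultInference(n_cpus=8, batch_size=128, joblib_verbosity=0)

# ...and the same instance is shared by DeseqDataSet and DeseqStats.
dds = DeseqDataSet(
    counts=my_counts_df,
    metadata=my_metadata_df,
    design_factors="condition",
    refit_cooks=True,
    inference=inference,
)
dds.deseq2()

stat_res = DeseqStats(dds, contrast=["condition", "B", "A"], inference=inference)
stat_res.summary()

A custom backend (for instance a GPU or distributed implementation) could be plugged in the same way, by subclassing ``Inference`` or ``DefaultInference``, overriding the relevant methods, and passing the instance via ``inference=``.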