From 0d912b9d1fb63a7cec7590298169f2b3ca2e8535 Mon Sep 17 00:00:00 2001 From: Boris MUZELLEC Date: Fri, 16 Feb 2024 10:44:46 +0100 Subject: [PATCH] chore: catch up code state from statsmodels PR --- pydeseq2/dds.py | 86 +++++++++++++++++++++++----------------- tests/test_edge_cases.py | 2 +- 2 files changed, 50 insertions(+), 38 deletions(-) diff --git a/pydeseq2/dds.py b/pydeseq2/dds.py index a98a29f2..879101ba 100644 --- a/pydeseq2/dds.py +++ b/pydeseq2/dds.py @@ -73,7 +73,7 @@ class DeseqDataSet(ad.AnnData): specifying the factor of interest and the reference (control) level against which we're testing, e.g. ``["condition", "A"]``. (default: ``None``). - trend_fit_type : str + disp_function_type : str Either "parametric", "local" or "mean", for the type of fitting of dispersions trend curve.If "parametric" is selected but the fitting fails, it will switch to "local". (default: ``"parametric"``). @@ -188,7 +188,7 @@ def __init__( design_factors: Union[str, List[str]] = "condition", continuous_factors: Optional[List[str]] = None, ref_level: Optional[List[str]] = None, - trend_fit_type: Literal["parametric", "local", "mean"] = "parametric", + disp_function_type: Literal["parametric", "local", "mean"] = "parametric", min_mu: float = 0.5, min_disp: float = 1e-8, max_disp: float = 10.0, @@ -274,7 +274,7 @@ def __init__( # Check that the design matrix has full rank self._check_full_rank_design() - self.trend_fit_type = trend_fit_type + self.disp_function_type = disp_function_type self.min_mu = min_mu self.min_disp = min_disp self.max_disp = np.maximum(max_disp, self.n_obs) @@ -486,7 +486,7 @@ def fit_size_factors( warnings.warn( "Every gene contains at least one zero, " "cannot compute log geometric means. Switching to iterative mode.", - RuntimeWarning, + UserWarning, stacklevel=2, ) self._fit_iterate_size_factors() @@ -578,10 +578,10 @@ def fit_genewise_dispersions(self) -> None: def fit_dispersion_trend(self) -> None: r"""Fit the dispersion trend curve. - Three methods are available, depending on the ``trend_fit_type`` attribute: + Three methods are available, depending on the ``disp_function_type`` attribute: "parametric", "local" and "mean". """ - if self.trend_fit_type == "parametric": + if self.disp_function_type == "parametric": try: self._fit_parametric_trend() except RuntimeError: @@ -591,42 +591,42 @@ def fit_dispersion_trend(self) -> None: UserWarning, stacklevel=2, ) - self.trend_fit_type = "local" + self.disp_function_type = "local" if (self.uns["trend_coeffs"] == 0).any(): warnings.warn( - f"self.trend_fit_type={self.trend_fit_type}, but the " + f"self.disp_function_type={self.disp_function_type}, but the " f"dispersion trend was not well captured by the function: " - f"y = a / x + b. Switchiing to local regression.", + f"y = a / x + b. Switching to local regression.", UserWarning, stacklevel=2, ) - self.trend_fit_type = "local" + self.disp_function_type = "local" del self.uns["trend_coeffs"] - if self.trend_fit_type == "local": + if self.disp_function_type == "local": try: self._fit_local_trend() except (ValueError, RuntimeError): print("Local trend fit did not converge, switching to mean fit.") - self.trend_fit_type = "mean" + self.disp_function_type = "mean" - if self.trend_fit_type == "mean": + if self.disp_function_type == "mean": self._fit_mean_trend() - if self.trend_fit_type not in ["parametric", "local", "mean"]: + if self.disp_function_type not in ["parametric", "local", "mean"]: raise NotImplementedError( - f"Unknown trend_fit_type: {self.trend_fit_type}. " + f"Unknown disp_function_type: {self.disp_function_type}. " "Expected 'parametric', 'local' or 'mean'." ) def disp_function(self, x): """Return the dispersion trend function at x.""" - if self.uns["disp_function_type"] == "parametric": + if self.disp_function_type == "parametric": return dispersion_trend(x, self.uns["trend_coeffs"]) - elif self.uns["disp_function_type"] == "local": + elif self.disp_function_type == "local": return np.exp(self.uns["loess"].predict(np.log(x)).values) - elif self.uns["disp_function_type"] == "mean": + elif self.disp_function_type == "mean": return self.uns["mean_disp"] def fit_dispersion_prior(self) -> None: @@ -890,24 +890,29 @@ def _fit_parametric_trend(self) -> None: old_coeffs = pd.Series([0.1, 0.1]) coeffs = pd.Series([1.0, 1.0]) - while (np.log(np.abs(coeffs / old_coeffs)) ** 2).sum() >= 1e-6: - old_coeffs = coeffs - coeffs, predictions = self.inference.dispersion_trend_gamma_glm( - covariates, targets - ) - # Filter out genes that are too far away from the curve before refitting - pred_ratios = ( - self[:, covariates.index].varm["genewise_dispersions"] / predictions - ) + try: + while (coeffs > 0).all() and ( + np.log(np.abs(coeffs / old_coeffs)) ** 2 + ).sum() >= 1e-6: + old_coeffs = coeffs + coeffs, predictions = self.inference.dispersion_trend_gamma_glm( + covariates, targets + ) + # Filter out genes that are too far away from the curve before refitting + pred_ratios = ( + self[:, covariates.index].varm["genewise_dispersions"] / predictions + ) - targets.drop( - targets[(pred_ratios < 1e-4) | (pred_ratios >= 15)].index, - inplace=True, - ) - covariates.drop( - covariates[(pred_ratios < 1e-4) | (pred_ratios >= 15)].index, - inplace=True, - ) + targets.drop( + targets[(pred_ratios < 1e-4) | (pred_ratios >= 15)].index, + inplace=True, + ) + covariates.drop( + covariates[(pred_ratios < 1e-4) | (pred_ratios >= 15)].index, + inplace=True, + ) + except RuntimeError as e: + raise e end = time.time() @@ -1114,6 +1119,7 @@ def _refit_without_outliers( min_replicates=self.min_replicates, beta_tol=self.beta_tol, inference=self.inference, + disp_function_type=self.disp_function_type, ) # Use the same size factors @@ -1129,8 +1135,7 @@ def _refit_without_outliers( # Note: the trend curve is not refitted. sub_dds.varm["_normed_means"] = sub_dds.layers["normed_counts"].mean(0) - sub_dds.uns["disp_function"] = self.uns["disp_function"] - sub_dds.varm["fitted_dispersions"] = self.uns["disp_function"]( + sub_dds.varm["fitted_dispersions"] = self.disp_function( sub_dds.varm["_normed_means"][sub_dds.varm["non_zero"]] ) @@ -1215,6 +1220,13 @@ def objective(p): & self.varm["non_zero"] ] + if len(use_for_mean_genes) == 0: + print( + "No genes have a dispersion above 10 * min_disp in " + "_fit_iterate_size_factors." + ) + break + mean_disp = trimmed_mean( self[:, use_for_mean_genes].varm["genewise_dispersions"], trim=0.001 ) diff --git a/tests/test_edge_cases.py b/tests/test_edge_cases.py index 5a2b8a71..b93a853e 100644 --- a/tests/test_edge_cases.py +++ b/tests/test_edge_cases.py @@ -468,7 +468,7 @@ def test_zero_inflated(): counts_df.iloc[idx, :] = 0 dds = DeseqDataSet(counts=counts_df, metadata=metadata) - with pytest.warns(RuntimeWarning): + with pytest.warns(UserWarning): dds.deseq2()