Skip to content

Commit

Permalink
chore: catch up code state from statsmodels PR
Browse files Browse the repository at this point in the history
  • Loading branch information
BorisMuzellec committed Feb 16, 2024
1 parent c151c10 commit 0d912b9
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 38 deletions.
86 changes: 49 additions & 37 deletions pydeseq2/dds.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class DeseqDataSet(ad.AnnData):
specifying the factor of interest and the reference (control) level against which
we're testing, e.g. ``["condition", "A"]``. (default: ``None``).
trend_fit_type : str
disp_function_type : str
Either "parametric", "local" or "mean", for the type of fitting of dispersions
trend curve. If "parametric" is selected but the fitting fails, it will switch to
"local". (default: ``"parametric"``).
Expand Down Expand Up @@ -188,7 +188,7 @@ def __init__(
design_factors: Union[str, List[str]] = "condition",
continuous_factors: Optional[List[str]] = None,
ref_level: Optional[List[str]] = None,
trend_fit_type: Literal["parametric", "local", "mean"] = "parametric",
disp_function_type: Literal["parametric", "local", "mean"] = "parametric",
min_mu: float = 0.5,
min_disp: float = 1e-8,
max_disp: float = 10.0,
Expand Down Expand Up @@ -274,7 +274,7 @@ def __init__(
# Check that the design matrix has full rank
self._check_full_rank_design()

self.trend_fit_type = trend_fit_type
self.disp_function_type = disp_function_type
self.min_mu = min_mu
self.min_disp = min_disp
self.max_disp = np.maximum(max_disp, self.n_obs)
Expand Down Expand Up @@ -486,7 +486,7 @@ def fit_size_factors(
warnings.warn(
"Every gene contains at least one zero, "
"cannot compute log geometric means. Switching to iterative mode.",
RuntimeWarning,
UserWarning,
stacklevel=2,
)
self._fit_iterate_size_factors()
Expand Down Expand Up @@ -578,10 +578,10 @@ def fit_genewise_dispersions(self) -> None:
def fit_dispersion_trend(self) -> None:
r"""Fit the dispersion trend curve.
Three methods are available, depending on the ``trend_fit_type`` attribute:
Three methods are available, depending on the ``disp_function_type`` attribute:
"parametric", "local" and "mean".
"""
if self.trend_fit_type == "parametric":
if self.disp_function_type == "parametric":
try:
self._fit_parametric_trend()
except RuntimeError:
Expand All @@ -591,42 +591,42 @@ def fit_dispersion_trend(self) -> None:
UserWarning,
stacklevel=2,
)
self.trend_fit_type = "local"
self.disp_function_type = "local"

if (self.uns["trend_coeffs"] == 0).any():
warnings.warn(
f"self.trend_fit_type={self.trend_fit_type}, but the "
f"self.disp_function_type={self.disp_function_type}, but the "
f"dispersion trend was not well captured by the function: "
f"y = a / x + b. Switchiing to local regression.",
f"y = a / x + b. Switching to local regression.",
UserWarning,
stacklevel=2,
)
self.trend_fit_type = "local"
self.disp_function_type = "local"
del self.uns["trend_coeffs"]

if self.trend_fit_type == "local":
if self.disp_function_type == "local":
try:
self._fit_local_trend()
except (ValueError, RuntimeError):
print("Local trend fit did not converge, switching to mean fit.")
self.trend_fit_type = "mean"
self.disp_function_type = "mean"

if self.trend_fit_type == "mean":
if self.disp_function_type == "mean":
self._fit_mean_trend()

if self.trend_fit_type not in ["parametric", "local", "mean"]:
if self.disp_function_type not in ["parametric", "local", "mean"]:
raise NotImplementedError(
f"Unknown trend_fit_type: {self.trend_fit_type}. "
f"Unknown disp_function_type: {self.disp_function_type}. "
"Expected 'parametric', 'local' or 'mean'."
)

def disp_function(self, x):
"""Return the dispersion trend function at x."""
if self.uns["disp_function_type"] == "parametric":
if self.disp_function_type == "parametric":
return dispersion_trend(x, self.uns["trend_coeffs"])
elif self.uns["disp_function_type"] == "local":
elif self.disp_function_type == "local":
return np.exp(self.uns["loess"].predict(np.log(x)).values)
elif self.uns["disp_function_type"] == "mean":
elif self.disp_function_type == "mean":
return self.uns["mean_disp"]

def fit_dispersion_prior(self) -> None:
Expand Down Expand Up @@ -890,24 +890,29 @@ def _fit_parametric_trend(self) -> None:
old_coeffs = pd.Series([0.1, 0.1])
coeffs = pd.Series([1.0, 1.0])

while (np.log(np.abs(coeffs / old_coeffs)) ** 2).sum() >= 1e-6:
old_coeffs = coeffs
coeffs, predictions = self.inference.dispersion_trend_gamma_glm(
covariates, targets
)
# Filter out genes that are too far away from the curve before refitting
pred_ratios = (
self[:, covariates.index].varm["genewise_dispersions"] / predictions
)
try:
while (coeffs > 0).all() and (
np.log(np.abs(coeffs / old_coeffs)) ** 2
).sum() >= 1e-6:
old_coeffs = coeffs
coeffs, predictions = self.inference.dispersion_trend_gamma_glm(
covariates, targets
)
# Filter out genes that are too far away from the curve before refitting
pred_ratios = (
self[:, covariates.index].varm["genewise_dispersions"] / predictions
)

targets.drop(
targets[(pred_ratios < 1e-4) | (pred_ratios >= 15)].index,
inplace=True,
)
covariates.drop(
covariates[(pred_ratios < 1e-4) | (pred_ratios >= 15)].index,
inplace=True,
)
targets.drop(
targets[(pred_ratios < 1e-4) | (pred_ratios >= 15)].index,
inplace=True,
)
covariates.drop(
covariates[(pred_ratios < 1e-4) | (pred_ratios >= 15)].index,
inplace=True,
)
except RuntimeError as e:
raise e

end = time.time()

Expand Down Expand Up @@ -1114,6 +1119,7 @@ def _refit_without_outliers(
min_replicates=self.min_replicates,
beta_tol=self.beta_tol,
inference=self.inference,
disp_function_type=self.disp_function_type,
)

# Use the same size factors
Expand All @@ -1129,8 +1135,7 @@ def _refit_without_outliers(
# Note: the trend curve is not refitted.
sub_dds.varm["_normed_means"] = sub_dds.layers["normed_counts"].mean(0)

sub_dds.uns["disp_function"] = self.uns["disp_function"]
sub_dds.varm["fitted_dispersions"] = self.uns["disp_function"](
sub_dds.varm["fitted_dispersions"] = self.disp_function(
sub_dds.varm["_normed_means"][sub_dds.varm["non_zero"]]
)

Expand Down Expand Up @@ -1215,6 +1220,13 @@ def objective(p):
& self.varm["non_zero"]
]

if len(use_for_mean_genes) == 0:
print(
"No genes have a dispersion above 10 * min_disp in "
"_fit_iterate_size_factors."
)
break

mean_disp = trimmed_mean(
self[:, use_for_mean_genes].varm["genewise_dispersions"], trim=0.001
)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_edge_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ def test_zero_inflated():
counts_df.iloc[idx, :] = 0

dds = DeseqDataSet(counts=counts_df, metadata=metadata)
with pytest.warns(RuntimeWarning):
with pytest.warns(UserWarning):
dds.deseq2()


Expand Down

0 comments on commit 0d912b9

Please sign in to comment.