Skip to content

Commit

Permalink
profiles: provide the option to plot w/stderr
Browse files Browse the repository at this point in the history
  • Loading branch information
JoepVanlier committed Dec 13, 2024
1 parent f0ad9df commit cd12f63
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 8 deletions.
3 changes: 2 additions & 1 deletion lumicks/pylake/fitting/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __init__(
fixed=False,
shared=False,
unit=None,
stderr=None,
):
"""Model parameter
Expand Down Expand Up @@ -90,7 +91,7 @@ def __init__(
from the data. See also: :meth:`~lumicks.pylake.FdFit.profile_likelihood()`.
"""

self.stderr = None
self.stderr = stderr
"""Standard error of this parameter.
Standard errors are calculated after fitting the model. These asymptotic errors are based
Expand Down
12 changes: 11 additions & 1 deletion lumicks/pylake/fitting/profile_likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,17 +639,27 @@ def chi2(self):
def p(self):
return self.parameters[:, self.profile_info.profiled_parameter_index]

def plot(self, *, significance_level=None, **kwargs):
def plot(self, *, significance_level=None, std_err=None, **kwargs):
"""Plot profile likelihood
Parameters
----------
significance_level : float, optional
Desired significance level (resulting in a 100 * (1 - alpha)% confidence interval) to
plot. Default is the significance level specified when the profile was generated.
std_err : float | None
If provided, also make a quadratic plot based on a standard error.
"""
import matplotlib.pyplot as plt

if std_err:
x = np.arange(-3 * std_err, 3 * std_err, 0.1 * std_err)
plt.plot(
self.p[np.argmin(self.chi2)] + x,
self.profile_info.minimum_chi2 + x**2 / (2 * std_err**2),
"k--",
)

dash_length = 5
plt.plot(self.p, self.chi2, **kwargs)

Expand Down
24 changes: 18 additions & 6 deletions lumicks/pylake/population/dwelltime.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,10 @@ def fit_func(params, lb, ub, fitted):
)
parameters = Params(
**{
key: Parameter(param, lower_bound=lb, upper_bound=ub)
for key, param, (lb, ub) in zip(keys, dwelltime_model._parameters, bounds)
key: Parameter(param, lower_bound=lb, upper_bound=ub, stderr=std_err)
for key, param, (lb, ub), std_err in zip(
keys, dwelltime_model._parameters, bounds, dwelltime_model._std_errs
)
}
)

Expand Down Expand Up @@ -204,7 +206,7 @@ def n_components(self):
"""Number of components in the model."""
return self.model.n_components

def plot(self, alpha=None):
def plot(self, alpha=None, **kwargs):
"""Plot the profile likelihoods for the parameters of a model.
Confidence interval is indicated by the region where the profile crosses the chi squared
Expand All @@ -219,13 +221,23 @@ def plot(self, alpha=None):
"""
import matplotlib.pyplot as plt

with_stderr = kwargs.pop("with_stderr") if "with_stderr" in kwargs else False

std_errs = self.model._std_errs[~np.isnan(self.model._std_errs)]
if self.n_components == 1:
next(iter(self.profiles.values())).plot(significance_level=alpha)
next(iter(self.profiles.values())).plot(
significance_level=alpha,
std_err=std_errs[0] if with_stderr else None,
)
else:
plot_idx = np.reshape(np.arange(1, len(self.profiles) + 1), (-1, 2)).T.flatten()
for idx, profile in zip(plot_idx, self.profiles.values()):
for par_idx, (idx, profile) in enumerate(zip(plot_idx, self.profiles.values())):
plt.subplot(self.n_components, 2, idx)
profile.plot(significance_level=alpha)
profile.plot(
significance_level=alpha,
std_err=std_errs[par_idx] if with_stderr else None,
**kwargs,
)


@dataclass(frozen=True)
Expand Down

0 comments on commit cd12f63

Please sign in to comment.