Skip to content

Commit

Permalink
dwelltime: handle numerical corner cases
Browse files Browse the repository at this point in the history
- make sure we don't differentiate over the constraint edge when calculating standard errors
- ensure that we never perform 0/0 when calculating Jacobians, which would result in unnecessary warnings
  • Loading branch information
JoepVanlier committed Dec 16, 2024
1 parent 1b9b020 commit de02751
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 18 deletions.
34 changes: 26 additions & 8 deletions lumicks/pylake/population/dwelltime.py
Original file line number Diff line number Diff line change
Expand Up @@ -1071,10 +1071,21 @@ def _exponential_mixture_log_likelihood_jacobian(params, t, t_min, t_max, t_step
# The derivative of logsumexp is given by: sum(exp(fi(x)) dfi(x)/dx) / sum(exp(fi(x)))
total_denom = np.exp(scipy.special.logsumexp(components, axis=0))
sum_components = np.sum(np.exp(components), axis=0)
dtotal_damp = (sum_components * dlognorm_damp + np.exp(components) * dlogamp_damp) / total_denom
dtotal_dtau = (
sum_components * dlognorm_dtau + np.exp(components) * dlogtauterm_dtau
) / total_denom

dtotal_damp = sum_components * dlognorm_damp + np.exp(components) * dlogamp_damp
dtotal_dtau = sum_components * dlognorm_dtau + np.exp(components) * dlogtauterm_dtau

def safe_divide(arr, denominator):
    """Divide, in place, only the non-zero entries of ``arr`` by ``denominator``.

    Entries that are exactly zero do not contribute to the output, so dividing
    them (possibly as 0/0) would only raise spurious numerical warnings. Note
    that we deliberately do *not* mask divisions of non-zero values by zero:
    those are genuinely invalid and should surface as warnings.

    Parameters
    ----------
    arr : numpy.ndarray
        2D array whose non-zero entries are divided in place; mutated, nothing
        is returned.
    denominator : numpy.ndarray
        1D array broadcastable against the rows of ``arr``.
    """
    # One boolean mask over the whole array replaces the original per-row loop;
    # broadcasting the denominator keeps the element-wise pairing identical.
    contributing = np.abs(arr) > 0
    arr[contributing] = arr[contributing] / np.broadcast_to(denominator, arr.shape)[contributing]

safe_divide(dtotal_damp, total_denom)
safe_divide(dtotal_dtau, total_denom)
unsummed_gradient = np.vstack((dtotal_damp, dtotal_dtau))

return -np.sum(unsummed_gradient, axis=1)
Expand Down Expand Up @@ -1264,6 +1275,8 @@ def _handle_amplitude_constraint(

def _exponential_mle_bounds(n_components, min_observation_time, max_observation_time):
return (
# Note: the standard error computation relies on the lower bound on the amplitude as it
# keeps the amplitude from going negative.
*[(1e-9, 1.0 - 1e-9) for _ in range(n_components)],
*[
(
Expand All @@ -1276,7 +1289,9 @@ def _exponential_mle_bounds(n_components, min_observation_time, max_observation_


def _calculate_std_errs(jac_fun, constraints, num_free_amps, current_params, fitted_param_mask):
hessian_approx = numerical_jacobian(jac_fun, current_params[fitted_param_mask], dx=1e-6)
# The minimum bound on amplitudes is 1e-9, by making the max step 1e-10, we ensure that
# we never hit a singularity here.
hessian_approx = numerical_jacobian(jac_fun, current_params[fitted_param_mask], dx=1e-10)

if constraints:
from scipy.linalg import null_space
Expand Down Expand Up @@ -1418,9 +1433,12 @@ def jac_fun(params):

std_errs = np.full(current_params.shape, np.nan)
if use_jacobian:
std_errs[fitted_param_mask] = _calculate_std_errs(
jac_fun, constraints, num_free_amps, current_params, fitted_param_mask
)
try:
std_errs[fitted_param_mask] = _calculate_std_errs(
jac_fun, constraints, num_free_amps, current_params, fitted_param_mask
)
except np.linalg.linalg.LinAlgError:
pass # We silence these until the standard error API is publicly available

return current_params, -result.fun, std_errs

Expand Down
21 changes: 11 additions & 10 deletions lumicks/pylake/population/tests/test_dwelltimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,16 +256,17 @@ def test_dwelltime_profiles(exponential_data, exp_name, reference_bounds, reinte

@pytest.mark.parametrize(
# fmt:off
"exp_name, n_components, ref_std_errs",
"exp_name, n_components, ref_std_errs, tolerance",
[
("dataset_2exp", 1, [np.nan, 0.117634]), # Amplitude is not fitted!
("dataset_2exp", 2, [0.072455, 0.072456, 0.212814, 0.449388]),
("dataset_2exp_discrete", 2, [0.068027, 0.068027, 0.21403 , 0.350355]),
("dataset_2exp_discrete", 3, [0.097556, 0.380667, 0.395212, 0.252004, 1.229997, 4.500617]),
("dataset_2exp_discrete", 4, [9.755185e-02, 4.999662e-05, 3.788707e-01, 3.934488e-01, 2.520029e-01, 1.889606e+00, 1.227551e+00, 4.489603e+00]),
("dataset_2exp", 1, [np.nan, 0.117634], 1e-4), # Amplitude is not fitted!
("dataset_2exp", 2, [0.072455, 0.072456, 0.212814, 0.449388], 1e-4),
("dataset_2exp_discrete", 2, [0.068027, 0.068027, 0.21403 , 0.350355], 1e-4),
# Over-fitted, hence coarse tolerances
("dataset_2exp_discrete", 3, [0.0976, 0.377, 0.39, 0.25, 1.22, 4.46], 1e-1),
("dataset_2exp_discrete", 4, [0.0976, 0.000036, 0.374, 0.389, 0.252, 1.56, 1.22, 4.43], 1e-1),
]
)
def test_std_errs(exponential_data, exp_name, n_components, ref_std_errs):
def test_std_errs(exponential_data, exp_name, n_components, ref_std_errs, tolerance):
dataset = exponential_data[exp_name]

fit = DwelltimeModel(
Expand All @@ -274,9 +275,9 @@ def test_std_errs(exponential_data, exp_name, n_components, ref_std_errs):
**dataset["parameters"].observation_limits,
discretization_timestep=dataset["parameters"].dt,
)
np.testing.assert_allclose(fit._std_errs, ref_std_errs, rtol=1e-4)
np.testing.assert_allclose(fit._err_amplitudes, ref_std_errs[:n_components], rtol=1e-4)
np.testing.assert_allclose(fit._err_lifetimes, ref_std_errs[n_components:], rtol=1e-4)
np.testing.assert_allclose(fit._std_errs, ref_std_errs, rtol=tolerance)
np.testing.assert_allclose(fit._err_amplitudes, ref_std_errs[:n_components], rtol=tolerance)
np.testing.assert_allclose(fit._err_lifetimes, ref_std_errs[n_components:], rtol=tolerance)


@pytest.mark.parametrize("n_components", [2, 1])
Expand Down

0 comments on commit de02751

Please sign in to comment.