From de02751fd74de0cefa3a3f7558fa5f0300e4cfc7 Mon Sep 17 00:00:00 2001 From: Joep Vanlier Date: Mon, 16 Dec 2024 11:49:27 +0100 Subject: [PATCH] dwelltime: handle numerical corner cases - make sure we don't differentiate over the constraint edge when calculating standard errors - ensure that we never perform 0/0 when calculating Jacobians, which would result in unnecessary warnings --- lumicks/pylake/population/dwelltime.py | 34 ++++++++++++++----- .../population/tests/test_dwelltimes.py | 21 ++++++------ 2 files changed, 37 insertions(+), 18 deletions(-) diff --git a/lumicks/pylake/population/dwelltime.py b/lumicks/pylake/population/dwelltime.py index aa82c2f93..adb359d79 100644 --- a/lumicks/pylake/population/dwelltime.py +++ b/lumicks/pylake/population/dwelltime.py @@ -1071,10 +1071,21 @@ def _exponential_mixture_log_likelihood_jacobian(params, t, t_min, t_max, t_step # The derivative of logsumexp is given by: sum(exp(fi(x)) dfi(x)/dx) / sum(exp(fi(x))) total_denom = np.exp(scipy.special.logsumexp(components, axis=0)) sum_components = np.sum(np.exp(components), axis=0) - dtotal_damp = (sum_components * dlognorm_damp + np.exp(components) * dlogamp_damp) / total_denom - dtotal_dtau = ( - sum_components * dlognorm_dtau + np.exp(components) * dlogtauterm_dtau - ) / total_denom + + dtotal_damp = sum_components * dlognorm_damp + np.exp(components) * dlogamp_damp + dtotal_dtau = sum_components * dlognorm_dtau + np.exp(components) * dlogtauterm_dtau + + def safe_divide(arr, denominator): + """We only want to divide components that contribute to the output, otherwise we can + get numerical issues with divisions by zero. Note that we are _not_ masking on the + denominator here, but on the numerator. 
We do not want to silence any divisions of non-zero + values by zero which are invalid and should be caught.""" + for idx in range(arr.shape[0]): + mask = abs(arr[idx, :]) > 0 + arr[idx, mask] = arr[idx, mask] / denominator[mask] + + safe_divide(dtotal_damp, total_denom) + safe_divide(dtotal_dtau, total_denom) unsummed_gradient = np.vstack((dtotal_damp, dtotal_dtau)) return -np.sum(unsummed_gradient, axis=1) @@ -1264,6 +1275,8 @@ def _handle_amplitude_constraint( def _exponential_mle_bounds(n_components, min_observation_time, max_observation_time): return ( + # Note: the standard error computation relies on the lower bound on the amplitude as it + # keeps the amplitude from going negative. *[(1e-9, 1.0 - 1e-9) for _ in range(n_components)], *[ ( @@ -1276,7 +1289,9 @@ def _exponential_mle_bounds(n_components, min_observation_time, max_observation_ def _calculate_std_errs(jac_fun, constraints, num_free_amps, current_params, fitted_param_mask): - hessian_approx = numerical_jacobian(jac_fun, current_params[fitted_param_mask], dx=1e-6) + # The minimum bound on amplitudes is 1e-9, by making the max step 1e-10, we ensure that + # we never hit a singularity here. 
+ hessian_approx = numerical_jacobian(jac_fun, current_params[fitted_param_mask], dx=1e-10) if constraints: from scipy.linalg import null_space @@ -1418,9 +1433,12 @@ def jac_fun(params): std_errs = np.full(current_params.shape, np.nan) if use_jacobian: - std_errs[fitted_param_mask] = _calculate_std_errs( - jac_fun, constraints, num_free_amps, current_params, fitted_param_mask - ) + try: + std_errs[fitted_param_mask] = _calculate_std_errs( + jac_fun, constraints, num_free_amps, current_params, fitted_param_mask + ) + except np.linalg.linalg.LinAlgError: + pass # We silence these until the standard error API is publicly available return current_params, -result.fun, std_errs diff --git a/lumicks/pylake/population/tests/test_dwelltimes.py b/lumicks/pylake/population/tests/test_dwelltimes.py index 37538c52a..370ab5e8a 100644 --- a/lumicks/pylake/population/tests/test_dwelltimes.py +++ b/lumicks/pylake/population/tests/test_dwelltimes.py @@ -256,16 +256,17 @@ def test_dwelltime_profiles(exponential_data, exp_name, reference_bounds, reinte @pytest.mark.parametrize( # fmt:off - "exp_name, n_components, ref_std_errs", + "exp_name, n_components, ref_std_errs, tolerance", [ - ("dataset_2exp", 1, [np.nan, 0.117634]), # Amplitude is not fitted! - ("dataset_2exp", 2, [0.072455, 0.072456, 0.212814, 0.449388]), - ("dataset_2exp_discrete", 2, [0.068027, 0.068027, 0.21403 , 0.350355]), - ("dataset_2exp_discrete", 3, [0.097556, 0.380667, 0.395212, 0.252004, 1.229997, 4.500617]), - ("dataset_2exp_discrete", 4, [9.755185e-02, 4.999662e-05, 3.788707e-01, 3.934488e-01, 2.520029e-01, 1.889606e+00, 1.227551e+00, 4.489603e+00]), + ("dataset_2exp", 1, [np.nan, 0.117634], 1e-4), # Amplitude is not fitted! 
+ ("dataset_2exp", 2, [0.072455, 0.072456, 0.212814, 0.449388], 1e-4), + ("dataset_2exp_discrete", 2, [0.068027, 0.068027, 0.21403 , 0.350355], 1e-4), + # Over-fitted, hence coarse tolerances + ("dataset_2exp_discrete", 3, [0.0976, 0.377, 0.39, 0.25, 1.22, 4.46], 1e-1), + ("dataset_2exp_discrete", 4, [0.0976, 0.000036, 0.374, 0.389, 0.252, 1.56, 1.22, 4.43], 1e-1), ] ) -def test_std_errs(exponential_data, exp_name, n_components, ref_std_errs): +def test_std_errs(exponential_data, exp_name, n_components, ref_std_errs, tolerance): dataset = exponential_data[exp_name] fit = DwelltimeModel( @@ -274,9 +275,9 @@ def test_std_errs(exponential_data, exp_name, n_components, ref_std_errs): **dataset["parameters"].observation_limits, discretization_timestep=dataset["parameters"].dt, ) - np.testing.assert_allclose(fit._std_errs, ref_std_errs, rtol=1e-4) - np.testing.assert_allclose(fit._err_amplitudes, ref_std_errs[:n_components], rtol=1e-4) - np.testing.assert_allclose(fit._err_lifetimes, ref_std_errs[n_components:], rtol=1e-4) + np.testing.assert_allclose(fit._std_errs, ref_std_errs, rtol=tolerance) + np.testing.assert_allclose(fit._err_amplitudes, ref_std_errs[:n_components], rtol=tolerance) + np.testing.assert_allclose(fit._err_lifetimes, ref_std_errs[n_components:], rtol=tolerance) @pytest.mark.parametrize("n_components", [2, 1])