diff --git a/examples/ex_gaussian_mixture.py b/examples/ex_gaussian_mixture.py index 985bfa53..9f3af3bb 100644 --- a/examples/ex_gaussian_mixture.py +++ b/examples/ex_gaussian_mixture.py @@ -1,19 +1,16 @@ +import matplotlib.pyplot as plt +import numpy as np + from navlie.batch.gaussian_mixtures import ( GaussianMixtureResidual, + HessianSumMixtureResidual, MaxMixtureResidual, - SumMixtureResidual, MaxSumMixtureResidual, - HessianSumMixtureResidual, + SumMixtureResidual, ) +from navlie.batch.problem import Problem from navlie.batch.residuals import PriorResidual - from navlie.lib.states import VectorState -import os -import matplotlib.pyplot as plt -import numpy as np -import seaborn as sns -from pathlib import Path -from navlie.batch.problem import Problem def main(): @@ -36,7 +33,7 @@ def main(): "Sum-Mixture": SumMixtureResidual(component_residuals, weights), "Max-Sum-Mixture": MaxSumMixtureResidual(component_residuals, weights, 10), "Hessian-Sum-Mixture": HessianSumMixtureResidual( - component_residuals, weights, True, 0.1 + component_residuals, weights, True ), } diff --git a/navlie/batch/gaussian_mixtures.py b/navlie/batch/gaussian_mixtures.py index 7597b819..e56a3907 100644 --- a/navlie/batch/gaussian_mixtures.py +++ b/navlie/batch/gaussian_mixtures.py @@ -3,6 +3,7 @@ from navlie import State from navlie.batch.residuals import Residual from abc import ABC, abstractmethod +from typing import Dict class GaussianMixtureResidual(Residual, ABC): @@ -60,7 +61,7 @@ def mix_errors( self, error_value_list: List[np.ndarray], sqrt_info_matrix_list: List[np.ndarray], - ) -> Tuple[np.ndarray, List[np.ndarray]]: + ) -> Tuple[np.ndarray]: """Each mixture must implement this method.. Compute the factor error from the errors corresponding to each component @@ -79,7 +80,8 @@ def mix_jacobians( error_value_list: List[np.ndarray], jacobian_list_of_lists: List[List[np.ndarray]], sqrt_info_matrix_list: List[np.ndarray], - ) -> Tuple[np.ndarray, List[np.ndarray]]: + reused_values: Dict = None, + ) -> Tuple[List[np.ndarray]]: """Each mixture must implement this method. For every state, compute Jacobian of the Gaussian mixture w.r.t. that state @@ -144,6 +146,11 @@ def evaluate_component_residuals( sqrt_info_matrix_list.append(error.sqrt_info_matrix(cur_states)) self.sqrt_info_matrix_list = sqrt_info_matrix_list + # For the not NLS-compatible HSM version, these values need to be reused for Hessian computation. + self.error_value_list_cache = error_value_list + self.jacobian_list_of_lists_cache = jacobian_list_of_lists + self.sqrt_info_matrix_list_cache = sqrt_info_matrix_list + return error_value_list, jacobian_list_of_lists, sqrt_info_matrix_list def evaluate( @@ -156,10 +163,13 @@ def evaluate( jacobian_list_of_lists, sqrt_info_matrix_list, ) = self.evaluate_component_residuals(states, compute_jacobians) - e = self.mix_errors(error_value_list, sqrt_info_matrix_list) + e, reused_values = self.mix_errors(error_value_list, sqrt_info_matrix_list) if compute_jacobians: jac_list = self.mix_jacobians( - error_value_list, jacobian_list_of_lists, sqrt_info_matrix_list + error_value_list, + jacobian_list_of_lists, + sqrt_info_matrix_list, + reused_values, ) return e, jac_list return e @@ -184,11 +194,12 @@ def mix_errors( self, error_value_list: List[np.ndarray], sqrt_info_matrix_list: List[np.ndarray], - ) -> Tuple[np.ndarray, List[np.ndarray]]: + ) -> Tuple[np.ndarray, Dict]: alphas = [ weight * np.linalg.det(sqrt_info_matrix) for weight, sqrt_info_matrix in zip(self.weights, sqrt_info_matrix_list) ] + # Maximum component obtained as # K = argmax alpha_k exp(-0.5 e^\trans e) # = argmin -2* log alpha_k + e^\trans e @@ -207,25 +218,32 @@ def mix_errors( nonlinear_part = np.array(np.log(alpha_max / alpha_k)).reshape(-1) nonlinear_part = np.sqrt(2) * np.sqrt(nonlinear_part) e_mix = np.concatenate([linear_part, nonlinear_part]) - return e_mix + + reused_values = {"alphas": alphas, "res_values": res_values} + + return e_mix, reused_values def mix_jacobians( self, error_value_list: List[np.ndarray], jacobian_list_of_lists: List[List[np.ndarray]], sqrt_info_matrix_list: List[np.ndarray], + reused_values: Dict = None, ) -> Tuple[np.ndarray, List[np.ndarray]]: - alphas = [ - weight * np.linalg.det(sqrt_info_matrix) - for weight, sqrt_info_matrix in zip(self.weights, sqrt_info_matrix_list) - ] - - res_values = np.array( - [ - -np.log(alpha) + 0.5 * e.T @ e - for alpha, e in zip(alphas, error_value_list) + if reused_values is not None: + alphas = reused_values["alphas"] + res_values = reused_values["res_values"] + else: + alphas = [ + weight * np.linalg.det(sqrt_info_matrix) + for weight, sqrt_info_matrix in zip(self.weights, sqrt_info_matrix_list) ] - ) + res_values = np.array( + [ + -np.log(alpha) + 0.5 * e.T @ e + for alpha, e in zip(alphas, error_value_list) + ] + ) dominant_idx = np.argmin(res_values) jac_list_linear_part: List[np.ndarray] = jacobian_list_of_lists[dominant_idx] @@ -290,7 +308,14 @@ def mix_errors( nonlinear_part = self.compute_nonlinear_part(scalar_errors_differences, alphas) e_mix = np.concatenate([linear_part, nonlinear_part]) - return e_mix + reused_values = { + "alphas": alphas, + "nonlinear_part": nonlinear_part, + "scalar_errors_differences": scalar_errors_differences, + "res_values": res_values, + } + + return e_mix, reused_values def compute_nonlinear_part( self, scalar_errors_differences: List[np.ndarray], alphas: List[float] @@ -317,36 +342,54 @@ def mix_jacobians( error_value_list: List[np.ndarray], jacobian_list_of_lists: List[List[np.ndarray]], sqrt_info_matrix_list: List[np.ndarray], + reused_values: Dict = None, ) -> Tuple[np.ndarray, List[np.ndarray]]: n_state_list = len(jacobian_list_of_lists[0]) - alphas = [ - weight * np.linalg.det(sqrt_info_matrix) - for weight, sqrt_info_matrix in zip(self.weights, sqrt_info_matrix_list) - ] - - # LINEAR PART - res_values = np.array( - [ - -np.log(alpha) + 0.5 * e.T @ e - for alpha, e in zip(alphas, error_value_list) + if reused_values is not None: + alphas = reused_values["alphas"] + scalar_errors_differences = reused_values["scalar_errors_differences"] + e_nl = reused_values["nonlinear_part"] + res_values = reused_values["res_values"] + else: + alphas = [ + weight * np.linalg.det(sqrt_info_matrix) + for weight, sqrt_info_matrix in zip(self.weights, sqrt_info_matrix_list) ] - ) + # LINEAR PART + res_values = np.array( + [ + -np.log(alpha) + 0.5 * e.T @ e + for alpha, e in zip(alphas, error_value_list) + ] + ) + err_kmax = error_value_list[dominant_idx] + scalar_errors_differences = [ + -0.5 * e.T @ e + 0.5 * err_kmax.T @ err_kmax for e in error_value_list + ] + + # NONLINEAR PART + # Compute error + e_nl = self.compute_nonlinear_part(scalar_errors_differences, alphas) + dominant_idx = np.argmin(res_values) + err_kmax = error_value_list[dominant_idx] jac_list_linear_part: List[np.ndarray] = jacobian_list_of_lists[dominant_idx] + # Loop through every state to compute Jacobian with respect to it. + jac_list_nl = [] - err_kmax = error_value_list[dominant_idx] + sum_exp = np.sum( + [ + alpha * np.exp(delta) + for alpha, delta in zip(alphas, scalar_errors_differences) + ] + ) - scalar_errors_differences = [ - -0.5 * e.T @ e + 0.5 * err_kmax.T @ err_kmax for e in error_value_list + drho_df_list = [ + alpha * np.exp(delta) / sum_exp + for alpha, delta in zip(alphas, scalar_errors_differences) ] - # NONLINEAR PART - # Compute error - e_nl = self.compute_nonlinear_part(scalar_errors_differences, alphas) - - # Loop through every state to compute Jacobian with respect to it. - jac_list_nl = [] for lv1 in range(n_state_list): jacobian_list_components_wrt_cur_state = [ jac_list[lv1] for jac_list in jacobian_list_of_lists @@ -355,33 +398,20 @@ def mix_jacobians( jac_dom = jacobian_list_components_wrt_cur_state[dominant_idx] n_x = jacobian_list_components_wrt_cur_state[0].shape[1] numerator = np.zeros((1, n_x)) - denominator = 0.0 + numerator_list = [ - -alpha - * np.exp(scal_err) - * ( - e_k.reshape(1, -1) @ -jac_e_i - + err_kmax.reshape(1, -1) @ jac_dom - ) - for alpha, scal_err, e_k, jac_e_i in zip( - alphas, - scalar_errors_differences, + -drho * (e_k.reshape(1, -1) @ -jac_e_i) + for e_k, jac_e_i, drho in zip( error_value_list, jacobian_list_components_wrt_cur_state, + drho_df_list, ) ] - denominator_list = [ - alpha * np.exp(scal_err) - for alpha, scal_err in zip(alphas, scalar_errors_differences) - ] for term in numerator_list: numerator += term - - for term in denominator_list: - denominator += term - denominator = denominator * e_nl - jac_list_nl.append(numerator / denominator) + numerator -= err_kmax.reshape(1, -1) @ jac_dom + jac_list_nl.append(numerator / e_nl) else: jac_list_nl.append(None) @@ -445,7 +475,8 @@ def mix_errors( ) ) e = np.sqrt(2) * np.sqrt(normalization_const + scalar_errors[kmax] - sum_term) - return e + reused_values = {"alphas": alphas, "scalar_errors": scalar_errors, "e_sm": e} + return e, reused_values def mix_jacobians( self, @@ -454,13 +485,20 @@ def mix_jacobians( List[np.ndarray] ], # outer list is components, inner list states sqrt_info_matrix_list: List[np.ndarray], + reused_values: Dict = None, ) -> Tuple[np.ndarray, List[np.ndarray]]: + if reused_values is not None: + alpha_list = reused_values["alphas"] + e_sm = reused_values["e_sm"] + else: + alpha_list = [ + weight * np.linalg.det(sqrt_info_matrix) + for weight, sqrt_info_matrix in zip(self.weights, sqrt_info_matrix_list) + ] + e_sm, _ = self.mix_errors(error_value_list, sqrt_info_matrix_list) + n_state_list = len(jacobian_list_of_lists[0]) - alpha_list = [ - weight * np.linalg.det(sqrt_info_matrix) - for weight, sqrt_info_matrix in zip(self.weights, sqrt_info_matrix_list) - ] - e_sm = self.mix_errors(error_value_list, sqrt_info_matrix_list) + error_value_list = [e.reshape(-1, 1) for e in error_value_list] eTe_list = [e.T @ e for e in error_value_list] eTe_dom = min(eTe_list) @@ -521,21 +559,26 @@ class HessianSumMixtureResidual(GaussianMixtureResidual): } """ - sum_mixture_residual: SumMixtureResidual no_use_complex_numbers: bool - normalization_constant: float def __init__( self, errors: List[Residual], weights, no_use_complex_numbers=True, - normalization_constant=0.1, ): super().__init__(errors, weights) self.sum_mixture_residual = SumMixtureResidual(errors, weights) self.no_use_complex_numbers = no_use_complex_numbers - self.normalization_constant = normalization_constant + + @staticmethod + def get_normalization_constant(alphas: List[float]): + alpha_sum = np.sum(alphas) + log_sum = 0.0 + for lv1 in range(len(alphas)): + log_sum = log_sum + alphas[lv1] * np.exp(alpha_sum / alphas[lv1]) + + return np.log(log_sum) def mix_errors( self, @@ -546,33 +589,30 @@ def mix_errors( weight * np.linalg.det(sqrt_info_matrix) for weight, sqrt_info_matrix in zip(self.weights, sqrt_info_matrix_list) ] - error_value_list = [e.reshape(-1, 1) for e in error_value_list] - eTe_list = [e.T @ e for e in error_value_list] - - # Normalize all the exponent arguments to avoid numerical issues. - eTe_dom = min(eTe_list) + normalization_constant = self.get_normalization_constant(alpha_list) + f_list = [0.5 * np.sum(e**2) for e in error_value_list] + kmax = np.argmin(np.array(f_list)) + f_kmax = f_list[kmax] sum_exp = np.sum( - [ - alpha * np.exp(0.5 * eTe_dom - 0.5 * e.T @ e) - for alpha, e in zip(alpha_list, error_value_list) - ] + [alpha * np.exp(f_kmax - f) for alpha, f in zip(alpha_list, f_list)] ) + drho_df_list = [ - alpha * np.exp(0.5 * eTe_dom - 0.5 * eTe) / sum_exp - for alpha, eTe in zip(alpha_list, eTe_list) + alpha * np.exp(f_kmax - f) / sum_exp for alpha, f in zip(alpha_list, f_list) ] hsm_error = np.vstack( - [np.sqrt(drho) * e for drho, e in zip(drho_df_list, error_value_list)] + [ + np.sqrt(drho) * e.reshape(-1, 1) + for drho, e in zip(drho_df_list, error_value_list) + ] ).squeeze() - desired_loss = np.sum( - self.sum_mixture_residual.mix_errors( - error_value_list, sqrt_info_matrix_list - ) - ** 2 - ) + # When the loss is computed at the end, it is computed as 1/2 * e^\trans e. + # The normalization constant is a bound on 2*logsumexp minus the norm of hsm_error. + # This works out to at the end evaluate normalization_constant + f_kmax - np.log(sum_exp). + desired_loss = 2 * (normalization_constant + f_kmax - np.log(sum_exp)) if not self.no_use_complex_numbers: current_loss = np.sum(hsm_error**2) @@ -585,11 +625,7 @@ def mix_errors( ) if self.no_use_complex_numbers: current_loss = np.sum(hsm_error**2) - - delta = self.normalization_constant + desired_loss - current_loss - if delta < 0: - self.normalization_constant = delta + 1 - delta = self.normalization_constant + desired_loss - current_loss + delta = desired_loss - current_loss diff = np.array(np.sqrt(delta)) hsm_error = np.concatenate( @@ -598,37 +634,47 @@ def mix_errors( np.atleast_1d(np.array(diff)), ] ) - return hsm_error + reused_values = { + "alphas": alpha_list, + "f_list": f_list, + "sum_exp": sum_exp, + "normalization_constant": normalization_constant, + "sum_exp": sum_exp, + "drho_df_list": drho_df_list, + } + return hsm_error, reused_values def mix_jacobians( self, error_value_list: List[np.ndarray], jacobian_list_of_lists: List[List[np.ndarray]], sqrt_info_matrix_list: List[np.ndarray], + reused_values: Dict = None, ) -> List[np.ndarray]: n_state_list = len(jacobian_list_of_lists[0]) + if reused_values is not None: + drho_df_list = reused_values["drho_df_list"] + else: + alpha_list = [ + weight * np.linalg.det(sqrt_info_matrix) + for weight, sqrt_info_matrix in zip(self.weights, sqrt_info_matrix_list) + ] + error_value_list = [e.reshape(-1, 1) for e in error_value_list] + eTe_list = [e.T @ e for e in error_value_list] - alpha_list = [ - weight * np.linalg.det(sqrt_info_matrix) - for weight, sqrt_info_matrix in zip(self.weights, sqrt_info_matrix_list) - ] - error_value_list = [e.reshape(-1, 1) for e in error_value_list] - eTe_list = [e.T @ e for e in error_value_list] + # Normalize all the exponent arguments to avoid numerical issues. + eTe_dom = min(eTe_list) + sum_exp = np.sum( + [ + alpha * np.exp(0.5 * eTe_dom - 0.5 * e.T @ e) + for alpha, e in zip(alpha_list, error_value_list) + ] + ) - # Normalize all the exponent arguments to avoid numerical issues. - eTe_dom = min(eTe_list) - exp_list = [np.exp(0.5 * eTe_dom - 0.5 * e.T @ e) for e in error_value_list] - sum_exp = np.sum( - [ - alpha * np.exp(0.5 * eTe_dom - 0.5 * e.T @ e) - for alpha, e in zip(alpha_list, error_value_list) + drho_df_list = [ + alpha * np.exp(0.5 * eTe_dom - 0.5 * eTe) / sum_exp + for alpha, eTe in zip(alpha_list, eTe_list) ] - ) - - drho_df_list = [ - alpha * np.exp(0.5 * eTe_dom - 0.5 * eTe) / sum_exp - for alpha, eTe in zip(alpha_list, eTe_list) - ] jac_list = [] for lv1 in range(n_state_list):