diff --git a/.gitignore b/.gitignore
index 3817d9f3..1297b254 100644
--- a/.gitignore
+++ b/.gitignore
@@ -33,3 +33,5 @@ dist/
 *.xyz
 /checkpoints
 *.model
+
+.history
\ No newline at end of file
diff --git a/mace/modules/__init__.py b/mace/modules/__init__.py
index 9278130f..24555784 100644
--- a/mace/modules/__init__.py
+++ b/mace/modules/__init__.py
@@ -37,6 +37,7 @@
     EnergyDipolesMACE,
     ScaleShiftBOTNet,
     ScaleShiftMACE,
+    LLPRModel,
 )
 from .radial import BesselBasis, GaussianBasis, PolynomialCutoff, ZBLBasis
 from .symmetric_contraction import SymmetricContraction
diff --git a/mace/modules/models.py b/mace/modules/models.py
index c0d8ab43..c278486e 100644
--- a/mace/modules/models.py
+++ b/mace/modules/models.py
@@ -30,11 +30,16 @@ from .utils import (
     compute_fixed_charge_dipole,
     compute_forces,
+    compute_ll_feat_gradients,
+    get_conditional_huber_force_mask,
     get_edge_vectors_and_lengths,
+    get_huber_mask,
     get_outputs,
     get_symmetric_displacement,
+
 )
+from torch.utils.data import DataLoader
 
 # pylint: disable=C0302
@@ -314,6 +319,20 @@ def forward(
         }
 
 
+def readout_is_linear(obj: Any):
+    if isinstance(obj, torch.jit.RecursiveScriptModule):
+        return obj.original_name == "LinearReadoutBlock"
+    else:
+        return isinstance(obj, LinearReadoutBlock)
+
+
+def readout_is_nonlinear(obj: Any):
+    if isinstance(obj, torch.jit.RecursiveScriptModule):
+        return obj.original_name == "NonLinearReadoutBlock"
+    else:
+        return isinstance(obj, NonLinearReadoutBlock)
+
+
 @compile_mode("script")
 class ScaleShiftMACE(MACE):
     def __init__(
@@ -1107,3 +1126,273 @@ def forward(
             "atomic_dipoles": atomic_dipoles,
         }
         return output
+
+
+@compile_mode("script")
+class LLPRModel(torch.nn.Module):
+    def __init__(
+        self,
+        model: Union[MACE, ScaleShiftMACE],
+        ll_feat_format: str = "avg",
+    ):
+        super().__init__()
+        self.orig_model = model
+
+        # determine ll_feat size from readout layers
+        self.hidden_sizes_before_readout = []
+        self.hidden_sizes = []
+        self.readouts_are_linear = []
+        self.hidden_size_sum = 0
+        for readout in self.orig_model.readouts.children():
+            if readout_is_linear(readout):
+                self.readouts_are_linear.append(True)
+                cur_size_before_readout = sum(
+                    irrep.dim for irrep in o3.Irreps(readout.linear.irreps_in)
+                )
+                self.hidden_sizes_before_readout.append(cur_size_before_readout)
+                cur_size = o3.Irreps(readout.linear.irreps_in)[0].dim
+                self.hidden_sizes.append(cur_size)
+                self.hidden_size_sum += cur_size
+            elif readout_is_nonlinear(readout):
+                self.readouts_are_linear.append(False)
+                cur_size_before_readout = sum(
+                    irrep.dim for irrep in o3.Irreps(readout.linear_1.irreps_in)
+                )
+                self.hidden_sizes_before_readout.append(cur_size_before_readout)
+                cur_size = o3.Irreps(readout.linear_2.irreps_in).dim
+                self.hidden_sizes.append(cur_size)
+                self.hidden_size_sum += cur_size
+            else:
+                raise TypeError("Unknown readout block type for LLPR at initialization!")
+
+        # initialize (inv_)covariance matrices
+        self.register_buffer(
+            "covariance",
+            torch.zeros(
+                (self.hidden_size_sum, self.hidden_size_sum),
+                device=next(self.orig_model.parameters()).device,
+            ),
+        )
+        self.register_buffer(
+            "inv_covariance",
+            torch.zeros(
+                (self.hidden_size_sum, self.hidden_size_sum),
+                device=next(self.orig_model.parameters()).device,
+            ),
+        )
+
+        # extra params associated with LLPR
+        self.ll_feat_format = ll_feat_format
+        self.covariance_computed = False
+        self.covariance_gradients_computed = False
+        self.inv_covariance_computed = False
+
+    def forward(
+        self,
+        data: Dict[str, torch.Tensor],
+        compute_force: bool = True,
+        compute_virials: bool = False,
+        compute_stress: bool = False,
+        compute_displacement: bool = False,
+        compute_energy_uncertainty: bool = True,
+        compute_force_uncertainty: bool = False,
+        compute_virial_uncertainty: bool = False,
+        compute_stress_uncertainty: bool = False,
+    ) -> Dict[str, Optional[torch.Tensor]]:
+        if compute_force_uncertainty and not compute_force:
+            raise RuntimeError("Cannot compute force uncertainty without computing forces")
+        if compute_virial_uncertainty and not compute_virials:
+            raise RuntimeError("Cannot compute virial uncertainty without computing virials")
+        if compute_stress_uncertainty and not compute_stress:
+            raise RuntimeError("Cannot compute stress uncertainty without computing stress")
+
+        num_graphs = data["ptr"].numel() - 1
+
+        output = self.orig_model(
+            data,
+            (compute_force_uncertainty or compute_stress_uncertainty or compute_virial_uncertainty),
+            compute_force,
+            compute_virials,
+            compute_stress,
+            compute_displacement,
+        )
+        ll_feats = self.aggregate_ll_features(output["node_feats"], data["batch"], num_graphs)
+
+        energy_uncertainty = None
+        force_uncertainty = None
+        virial_uncertainty = None
+        stress_uncertainty = None
+        if self.inv_covariance_computed:
+            if compute_force_uncertainty or compute_virial_uncertainty or compute_stress_uncertainty:
+                f_grads, v_grads, s_grads = compute_ll_feat_gradients(
+                    ll_feats=ll_feats,
+                    displacement=output["displacement"],
+                    batch_dict=data.to_dict(),
+                    compute_force=compute_force_uncertainty,
+                    compute_virials=compute_virial_uncertainty,
+                    compute_stress=compute_stress_uncertainty,
+                )
+            else:
+                f_grads, v_grads, s_grads = None, None, None
+            ll_feats = ll_feats.detach()
+            if compute_energy_uncertainty:
+                energy_uncertainty = torch.einsum(
+                    "ij, jk, ik -> i", ll_feats, self.inv_covariance, ll_feats
+                )
+            if compute_force_uncertainty:
+                force_uncertainty = torch.einsum(
+                    "iaj, jk, iak -> ia", f_grads, self.inv_covariance, f_grads
+                )
+            if compute_virial_uncertainty:
+                virial_uncertainty = torch.einsum(
+                    "iabj, jk, iabk -> iab", v_grads, self.inv_covariance, v_grads
+                )
+            if compute_stress_uncertainty:
+                stress_uncertainty = torch.einsum(
+                    "iabj, jk, iabk -> iab", s_grads, self.inv_covariance, s_grads
+                )
+
+        output["energy_uncertainty"] = energy_uncertainty
+        output["force_uncertainty"] = force_uncertainty
+        output["virial_uncertainty"] = virial_uncertainty
+        output["stress_uncertainty"] = stress_uncertainty
+
+        return output
+
+    def aggregate_ll_features(
+        self,
+        ll_feats: torch.Tensor,
+        indices: torch.Tensor,
+        num_graphs: int,
+    ) -> torch.Tensor:
+        # Aggregates (sums) node features over each structure
+        ll_feats_list = torch.split(ll_feats, self.hidden_sizes_before_readout, dim=-1)
+        ll_feats_list = [
+            (ll_feats if is_linear else readout.non_linearity(readout.linear_1(ll_feats)))[:, :size]
+            for ll_feats, readout, size, is_linear in zip(
+                ll_feats_list,
+                self.orig_model.readouts.children(),
+                self.hidden_sizes,
+                self.readouts_are_linear,
+            )
+        ]
+
+        # Aggregate node features
+        ll_feats_cat = torch.cat(ll_feats_list, dim=-1)
+        ll_feats_agg = scatter_sum(
+            src=ll_feats_cat, index=indices, dim=0, dim_size=num_graphs
+        )
+
+        return ll_feats_agg
+
+    def compute_covariance(
+        self,
+        train_loader: DataLoader,
+        include_energy: bool = True,
+        include_forces: bool = False,
+        include_virials: bool = False,
+        include_stresses: bool = False,
+        is_universal: bool = False,
+        huber_delta: float = 0.01,
+    ) -> None:
+        # Utility function to compute the covariance matrix for a training set.
+
+        for batch in train_loader:
+            batch.to(self.covariance.device)
+            batch_dict = batch.to_dict()
+            output = self.orig_model(
+                batch_dict,
+                training=(include_forces or include_virials or include_stresses),
+                compute_force=(is_universal and include_forces),  # we need this for the Huber loss force mask
+                compute_virials=False,
+                compute_stress=(is_universal and include_stresses),  # we need this for the Huber loss stress mask
+                compute_displacement=(include_virials or include_stresses),
+            )
+
+            num_graphs = batch_dict["ptr"].numel() - 1
+            num_atoms = batch_dict["ptr"][1:] - batch_dict["ptr"][:-1]
+            ll_feats = self.aggregate_ll_features(output["node_feats"], batch_dict["batch"], num_graphs)
+
+            if include_forces or include_virials or include_stresses:
+                f_grads, v_grads, s_grads = compute_ll_feat_gradients(
+                    ll_feats=ll_feats,
+                    displacement=output["displacement"],
+                    batch_dict=batch_dict,
+                    compute_force=include_forces,
+                    compute_virials=include_virials,
+                    compute_stress=include_stresses,
+                )
+            else:
+                f_grads, v_grads, s_grads = None, None, None
+            ll_feats = ll_feats.detach()
+
+            if include_energy:
+                # Account for the weighting of structures and targets
+                # Apply Huber loss mask if universal model
+                cur_weights = torch.mul(batch.weight, batch.energy_weight)
+                if is_universal:
+                    huber_mask = get_huber_mask(
+                        output["energy"],
+                        batch["energy"],
+                        huber_delta,
+                    )
+                    cur_weights = torch.mul(cur_weights, huber_mask)
+                ll_feats = torch.mul(ll_feats, cur_weights.unsqueeze(-1) ** (0.5))
+                self.covariance += (ll_feats / num_atoms.unsqueeze(-1)).T @ (
+                    ll_feats / num_atoms.unsqueeze(-1)
+                )
+
+            if include_forces:
+                # Account for the weighting of structures and targets
+                # Apply Huber loss mask if universal model
+                f_conf_weights = torch.stack([batch.weight[ii] for ii in batch.batch])
+                f_forces_weights = torch.stack([batch.forces_weight[ii] for ii in batch.batch])
+                cur_f_weights = torch.mul(f_conf_weights, f_forces_weights)
+                cur_f_weights = cur_f_weights.view(-1, 1).expand(-1, 3)
+                if is_universal:
+                    huber_mask_force = get_conditional_huber_force_mask(
+                        output["forces"],
+                        batch["forces"],
+                        huber_delta,
+                    )
+                    cur_f_weights = torch.mul(cur_f_weights, huber_mask_force)
+                f_grads = torch.mul(f_grads, cur_f_weights.unsqueeze(-1) ** (0.5))
+                f_grads = f_grads.reshape(-1, ll_feats.shape[-1])
+                self.covariance += f_grads.T @ f_grads
+
+            if include_virials:
+                # No Huber mask in the case of virials as it was not used in the
+                # universal model
+                cur_v_weights = torch.mul(batch.weight, batch.virials_weight)
+                cur_v_weights = cur_v_weights.view(-1, 1, 1).expand(-1, 3, 3)
+                v_grads = torch.mul(v_grads, cur_v_weights.unsqueeze(-1) ** (0.5))
+                v_grads = v_grads.reshape(-1, ll_feats.shape[-1])
+                self.covariance += v_grads.T @ v_grads
+
+            if include_stresses:
+                # Account for the weighting of structures and targets
+                # Apply Huber loss mask if universal model
+                cur_s_weights = torch.mul(batch.weight, batch.stress_weight)
+                cur_s_weights = cur_s_weights.view(-1, 1, 1).expand(-1, 3, 3)
+                if is_universal:
+                    huber_mask_stress = get_huber_mask(
+                        output["stress"],
+                        batch["stress"],
+                        huber_delta,
+                    )
+                    cur_s_weights = torch.mul(cur_s_weights, huber_mask_stress)
+                s_grads = torch.mul(s_grads, cur_s_weights.unsqueeze(-1) ** (0.5))
+                s_grads = s_grads.reshape(-1, ll_feats.shape[-1])
+                # The stresses seem to be normalized by n_atoms in the normal loss, but
+                # not in the universal loss.
+                if is_universal:
+                    self.covariance += s_grads.T @ s_grads
+                else:
+                    # repeat num_atoms for the 9 elements of the stress tensor
+                    self.covariance += (
+                        s_grads / num_atoms.repeat_interleave(9).unsqueeze(-1)
+                    ).T @ (s_grads / num_atoms.repeat_interleave(9).unsqueeze(-1))
+
+        self.covariance_computed = True
+
+    def compute_inv_covariance(self, C: float, sigma: float) -> None:
+        # Utility function to set the hyperparameters of the uncertainty model.
+        if not self.covariance_computed:
+            raise RuntimeError(
+                "You must compute the covariance matrix before "
+                "computing the inverse covariance matrix!"
+            )
+        self.inv_covariance = C * torch.linalg.inv(
+            self.covariance
+            + sigma**2 * torch.eye(self.hidden_size_sum, device=self.covariance.device)
+        )
+        self.inv_covariance_computed = True
+
+    def reset_matrices(self) -> None:
+        # Utility function to reset the covariance and inverse covariance matrices.
+        self.covariance = torch.zeros(self.covariance.shape, device=self.covariance.device)
+        self.inv_covariance = torch.zeros(self.covariance.shape, device=self.covariance.device)
+        self.covariance_computed = False
+        self.inv_covariance_computed = False
+        self.covariance_gradients_computed = False
diff --git a/mace/modules/utils.py b/mace/modules/utils.py
index d0a1e5f6..be2e09f5 100644
--- a/mace/modules/utils.py
+++ b/mace/modules/utils.py
@@ -5,7 +5,7 @@
 ###########################################################################################
 
 import logging
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Dict
 
 import numpy as np
 import torch
@@ -70,6 +70,109 @@ def compute_forces_virials(
     return -1 * forces, -1 * virials, stress
 
 
+@torch.jit.script
+def compute_ll_feat_gradients(
+    ll_feats: torch.Tensor,
+    displacement: torch.Tensor,
+    batch_dict: Dict[str, torch.Tensor],
+    compute_force: bool = True,
+    compute_virials: bool = False,
+    compute_stress: bool = False,
+) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
+
+    grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(ll_feats[:, 0])]
+    positions = batch_dict["positions"]
+
+    if compute_force and not (compute_virials or compute_stress):
+        f_grads_list = []
+        for i in range(ll_feats.shape[-1]):
+            cur_grad_f = torch.autograd.grad(
+                [ll_feats[:, i]],
+                [positions],
+                grad_outputs=grad_outputs,
+                retain_graph=(i != ll_feats.shape[-1] - 1),
+                create_graph=False,
+                allow_unused=True,
+            )[0]
+            if cur_grad_f is None:
+                cur_grad_f = torch.zeros_like(positions)
+            f_grads_list.append(cur_grad_f)
+        f_grads = torch.stack(f_grads_list)
+        f_grads = f_grads.permute(1, 2, 0)
+        v_grads = None
+        s_grads = None
+
+    elif compute_force and (compute_virials or compute_stress):
+        cell = batch_dict["cell"]
+        f_grads_list = []
+        v_grads_list = []
+        s_grads_list = []
+        for i in range(ll_feats.shape[-1]):
+            cur_grad_f, cur_grad_v = torch.autograd.grad(
+                [ll_feats[:, i]],
+                [positions, displacement],
+                grad_outputs=grad_outputs,
+                retain_graph=(i != ll_feats.shape[-1] - 1),
+                create_graph=False,
+                allow_unused=True,
+            )
+            if cur_grad_f is None:
+                cur_grad_f = torch.zeros_like(positions)
+            f_grads_list.append(cur_grad_f)
+            if cur_grad_v is None:
+                cur_grad_v = torch.zeros_like(displacement)
+            v_grads_list.append(cur_grad_v)
+        f_grads = torch.stack(f_grads_list)
+        f_grads = f_grads.permute(1, 2, 0)  # [num_atoms_batch, 3, num_ll_feats]
+        v_grads = torch.stack(v_grads_list)
+        v_grads = v_grads.permute(1, 2, 3, 0)  # [num_batch, 3, 3, num_ll_feats]
+
+        if compute_stress:
+            cell = cell.view(-1, 3, 3)
+            volume = torch.einsum(
+                "zi,zi->z",
+                cell[:, 0, :],
+                torch.cross(cell[:, 1, :], cell[:, 2, :], dim=1),
+            ).unsqueeze(-1)
+            s_grads = v_grads / volume.view(-1, 1, 1, 1)
+        else:
+            s_grads = None
+
+    elif not compute_force and (compute_virials or compute_stress):
+        cell = batch_dict["cell"]
+        v_grads_list = []
+        for i in range(ll_feats.shape[-1]):
+            cur_grad_v = torch.autograd.grad(
+                [ll_feats[:, i]],
+                [displacement],
+                grad_outputs=grad_outputs,
+                retain_graph=(i != ll_feats.shape[-1] - 1),
+                create_graph=False,
+                allow_unused=True,
+            )[0]
+            if cur_grad_v is None:
+                cur_grad_v = torch.zeros_like(displacement)
+            v_grads_list.append(cur_grad_v)
+        v_grads = torch.stack(v_grads_list)
+        v_grads = v_grads.permute(1, 2, 3, 0)  # [num_batch, 3, 3, num_ll_feats]
+
+        if compute_stress:
+            cell = cell.view(-1, 3, 3)
+            volume = torch.einsum(
+                "zi,zi->z",
+                cell[:, 0, :],
+                torch.cross(cell[:, 1, :], cell[:, 2, :], dim=1),
+            ).unsqueeze(-1)
+            s_grads = v_grads / volume.view(-1, 1, 1, 1)
+        else:
+            s_grads = None
+        f_grads = None
+    else:
+        raise RuntimeError("Unsupported configuration for computing gradients")
+
+    return f_grads, v_grads, s_grads
+
+
 def get_symmetric_displacement(
     positions: torch.Tensor,
     unit_shifts: torch.Tensor,
@@ -163,6 +266,40 @@ def compute_hessians_loop(
     return hessian
 
 
+def get_huber_mask(
+    input: torch.Tensor,
+    target: torch.Tensor,
+    huber_delta: float = 1.0,
+) -> torch.Tensor:
+    se = torch.square(input - target)
+    huber_mask = (se < huber_delta).int().double()
+    return huber_mask
+
+
+def get_conditional_huber_force_mask(
+    input: torch.Tensor,
+    target: torch.Tensor,
+    huber_delta: float,
+) -> torch.Tensor:
+    # Define the multiplication factors for each condition
+    factors = huber_delta * torch.tensor([1.0, 0.7, 0.4, 0.1])
+
+    # Apply multiplication factors based on conditions
+    c1 = torch.norm(target, dim=-1) < 100
+    c2 = (torch.norm(target, dim=-1) >= 100) & (torch.norm(target, dim=-1) < 200)
+    c3 = (torch.norm(target, dim=-1) >= 200) & (torch.norm(target, dim=-1) < 300)
+    c4 = ~(c1 | c2 | c3)
+
+    huber_mask = torch.zeros_like(input)
+
+    huber_mask[c1] = get_huber_mask(target[c1], input[c1], factors[0])
+    huber_mask[c2] = get_huber_mask(target[c2], input[c2], factors[1])
+    huber_mask[c3] = get_huber_mask(target[c3], input[c3], factors[2])
+    huber_mask[c4] = get_huber_mask(target[c4], input[c4], factors[3])
+
+    return huber_mask
+
+
 def get_outputs(
     energy: torch.Tensor,
     positions: torch.Tensor,
diff --git a/mace/tools/__init__.py b/mace/tools/__init__.py
index 8ad80243..45797ac4 100644
--- a/mace/tools/__init__.py
+++ b/mace/tools/__init__.py
@@ -3,6 +3,7 @@
 from .cg import U_matrix_real
 from .checkpoint import CheckpointHandler, CheckpointIO, CheckpointState
 from .finetuning_utils import load_foundations, load_foundations_elements
+from .llpr import calibrate_llpr_params
 from .torch_tools import (
     TensorDict,
     cartesian_to_spherical,
@@ -68,4 +69,5 @@
     "load_foundations",
     "load_foundations_elements",
     "build_preprocess_arg_parser",
+    "calibrate_llpr_params",
 ]
diff --git a/mace/tools/llpr.py b/mace/tools/llpr.py
new file mode 100644
index 00000000..2a94d74a
--- /dev/null
+++ b/mace/tools/llpr.py
@@ -0,0 +1,125 @@
+import torch
+import numpy as np
+from scipy.optimize import brute
+import math
+
+
+def calibrate_llpr_params(
+    model,
+    validation_loader,
+    function="ssl",
+    calib_bound=5,
+    calib_delta=0.1,
+    **kwargs,
+):
+    # This function optimizes the calibration parameters for LLPR on the validation set
+    # Original author: F. Bigi (@frostedoyster)
+
+    if function == "ssl":
+        obj_function = _sum_squared_log
+    elif function == "nll":
+        obj_function = _avg_nll_regression
+    else:
+        raise RuntimeError("Unsupported objective function type for LLPR uncertainty calibration!")
+
+    actual_errors = []
+    ll_feats = []
+    num_atoms = []
+    # Compute model predictions, actual errors, ll feats
+    for batch in validation_loader:
+        batch = batch.to(next(model.parameters()).device)
+        batch_dict = batch.to_dict()
+        y = batch_dict["energy"]
+        num_graphs = batch_dict["ptr"].numel() - 1
+        num_atoms.append(batch_dict["ptr"][1:] - batch_dict["ptr"][:-1])
+        model_outputs = model(batch_dict)
+        predictions = model_outputs["energy"].detach()
+        cur_ll_feats = model.aggregate_ll_features(
+            model_outputs["node_feats"], batch_dict["batch"], num_graphs
+        ).detach()
+        ll_feats.append(cur_ll_feats)
+        actual_errors.append((y - predictions) ** 2)
+
+    actual_errors = torch.cat(actual_errors, dim=0)
+    ll_feats_all = torch.cat(ll_feats, dim=0)
+    num_atoms = torch.cat(num_atoms, dim=0)
+
+    def obj_function_wrapper(x):
+        x = _process_inputs(x)
+        try:
+            model.compute_inv_covariance(*x)
+            predicted_errors = torch.einsum(
+                "ij, jk, ik -> i",
+                ll_feats_all,
+                model.inv_covariance,
+                ll_feats_all,
+            )
+            obj_value = obj_function(actual_errors, predicted_errors, **kwargs)
+        except torch._C._LinAlgError:
+            obj_value = 1e10
+        if math.isnan(obj_value):
+            obj_value = 1e10
+        return obj_value
+
+    calib_slice = slice(-1 * calib_bound, calib_bound + 0.01, calib_delta)
+    result = brute(obj_function_wrapper, ranges=[calib_slice, calib_slice])
+
+    # warn if we hit the edge of the parameter space
+    if (
+        result[0] <= -calib_bound
+        or result[0] >= calib_bound
+        or result[1] <= -calib_bound
+        or result[1] >= calib_bound
+    ):
+        print("Optimal parameters found beyond the designated parameter space!")
+
+    print(f"Calibrated LLPR parameters:\tC = {10**result[0]:.4E}\tsigma = {10**result[1]:.4E}")
+    model.compute_inv_covariance(*(_process_inputs(result)))
+
+
+def _process_inputs(x):
+    x = list(x)
+    x = [10**single_x for single_x in x]
+    return x
+
+
+def _avg_nll_regression(actual_errors, predicted_errors, energy_shift=0.0, energy_scale=1.0):
+    # This function calculates the average negative log-likelihood on the energy for a dataset
+    # Original author: F. Bigi (@frostedoyster)
+    total_nll = (
+        actual_errors / predicted_errors + torch.log(predicted_errors) + np.log(2 * np.pi)
+    ).sum().item() * 0.5
+    return total_nll / len(actual_errors)
+
+
+def _sum_squared_log(actual_errors, predicted_errors, n_samples_per_bin=1):
+    # This function calculates the sum of squared log errors on the energy for a dataset
+    # Original author: F. Bigi (@frostedoyster)
+    sort_indices = torch.argsort(predicted_errors)
+    actual_errors_sorted = actual_errors[sort_indices]
+    predicted_errors_sorted = predicted_errors[sort_indices]
+
+    n_samples = len(actual_errors)
+
+    actual_error_bins = []
+    predicted_error_bins = []
+
+    # skip the last bin, which may be incomplete
+    for i_bin in range(n_samples // n_samples_per_bin - 1):
+        actual_error_bins.append(
+            actual_errors_sorted[i_bin * n_samples_per_bin : (i_bin + 1) * n_samples_per_bin]
+        )
+        predicted_error_bins.append(
+            predicted_errors_sorted[i_bin * n_samples_per_bin : (i_bin + 1) * n_samples_per_bin]
+        )
+
+    actual_error_bins = torch.stack(actual_error_bins)
+    predicted_error_bins = torch.stack(predicted_error_bins)
+
+    # calculate means:
+    actual_error_means = actual_error_bins.mean(dim=1)
+    predicted_error_means = predicted_error_bins.mean(dim=1)
+
+    # calculate squared log errors:
+    squared_log_errors = (
+        torch.log(actual_error_means / predicted_error_means) ** 2
+    )
+
+    # calculate the sum of squared log errors:
+    sum_squared_log = squared_log_errors.sum().item()
+
+    return sum_squared_log
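
Usage sketch for the API added above (a minimal example, not added source: the trained MACE/ScaleShiftMACE `model` and the `train_loader`/`valid_loader` objects are assumed to exist already and are placeholders, built with the usual MACE data utilities; only the LLPR calls come from this change):

    import torch
    from mace.modules import LLPRModel
    from mace.tools import calibrate_llpr_params

    # `model` is an already-trained MACE/ScaleShiftMACE; `train_loader` and
    # `valid_loader` are assumed torch_geometric-style loaders of AtomicData batches.
    llpr_model = LLPRModel(model)

    # 1. Accumulate the last-layer feature covariance over the training set
    #    (energy-only here; forces/virials/stresses can be switched on).
    llpr_model.compute_covariance(train_loader, include_energy=True)

    # 2. Either set the regularization hyperparameters C and sigma directly ...
    llpr_model.compute_inv_covariance(C=1.0, sigma=1e-2)

    # ... or let the grid search on the validation set pick them.
    calibrate_llpr_params(llpr_model, valid_loader, function="ssl")

    # 3. Predictions now carry LLPR uncertainties alongside the usual outputs.
    batch = next(iter(valid_loader)).to(next(llpr_model.parameters()).device)
    output = llpr_model(batch.to_dict(), compute_energy_uncertainty=True)
    print(output["energy"], output["energy_uncertainty"])

The returned uncertainty entries are the quadratic forms f · Σ⁻¹ · fᵀ evaluated in LLPRModel.forward, so they are only meaningful after the inverse covariance has been computed or calibrated.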