diff --git a/mace/modules/__init__.py b/mace/modules/__init__.py index 9278130f..24555784 100644 --- a/mace/modules/__init__.py +++ b/mace/modules/__init__.py @@ -37,6 +37,7 @@ EnergyDipolesMACE, ScaleShiftBOTNet, ScaleShiftMACE, + LLPRModel, ) from .radial import BesselBasis, GaussianBasis, PolynomialCutoff, ZBLBasis from .symmetric_contraction import SymmetricContraction diff --git a/mace/modules/models.py b/mace/modules/models.py index 3474c6df..c7ced958 100644 --- a/mace/modules/models.py +++ b/mace/modules/models.py @@ -314,6 +314,20 @@ def forward( } +def readout_is_linear(obj: Any): + if isinstance(obj, torch.jit.RecursiveScriptModule): + return obj.original_name == "LinearReadoutBlock" + else: + return isinstance(obj, LinearReadoutBlock) + + +def readout_is_nonlinear(obj: Any): + if isinstance(obj, torch.jit.RecursiveScriptModule): + return obj.original_name == "NonLinearReadoutBlock" + else: + return isinstance(obj, NonLinearReadoutBlock) + + @compile_mode("script") class ScaleShiftMACE(MACE): def __init__( @@ -1114,6 +1128,7 @@ class LLPRModel(torch.nn.Module): def __init__( self, model: Union[MACE, ScaleShiftMACE], + ll_feat_format: str = "avg", ): super().__init__() self.orig_model = model @@ -1154,148 +1169,80 @@ def __init__( ) # extra params associated with LLPR + self.ll_feat_format = ll_feat_format self.covariance_computed = False self.covariance_gradients_computed = False self.inv_covariance_computed = False - def aggregate_features(self, ll_feats: torch.Tensor, indices: torch.Tensor, num_graphs: int, num_atoms: torch.Tensor) -> torch.Tensor: - ll_feats_list = torch.split(ll_feats, self.hidden_sizes_before_readout, dim=-1) - ll_feats_list = [(ll_feats if is_linear else readout.non_linearity(readout.linear_1(ll_feats)))[:, :size] for ll_feats, readout, size, is_linear in zip(ll_feats_list, self.orig_model.readouts.children(), self.hidden_sizes, self.readouts_are_linear)] - - # Aggregate node features - ll_feats_cat = torch.cat(ll_feats_list, dim=-1) - ll_feats_agg = scatter_sum( - src=ll_feats_cat, index=indices, dim=0, dim_size=num_graphs - ) - - return ll_feats_agg + self.already_printed = False def forward( self, data: Dict[str, torch.Tensor], + training: bool = False, compute_force: bool = True, compute_virials: bool = False, compute_stress: bool = False, compute_displacement: bool = False, - compute_energy_uncertainty: bool = True, - compute_force_uncertainty: bool = False, - compute_virial_uncertainty: bool = False, - compute_stress_uncertainty: bool = False, ) -> Dict[str, Optional[torch.Tensor]]: - if compute_force_uncertainty and not compute_force: - raise RuntimeError("Cannot compute force uncertainty without computing forces") - if compute_virial_uncertainty and not compute_virials: - raise RuntimeError("Cannot compute virial uncertainty without computing virials") - if compute_stress_uncertainty and not compute_stress: - raise RuntimeError("Cannot compute stress uncertainty without computing stress") - num_graphs = data["ptr"].numel() - 1 num_atoms = data["ptr"][1:] - data["ptr"][:-1] output = self.orig_model( - data, (compute_force_uncertainty or compute_stress_uncertainty or compute_virial_uncertainty), compute_force, compute_virials, compute_stress, compute_displacement + data, training, compute_force, compute_virials, compute_stress, compute_displacement + ) + + ll_feats = output["node_feats"] + ll_feats_list = torch.split(ll_feats, self.hidden_sizes_before_readout, dim=-1) + ll_feats_list = [(ll_feats if is_linear else readout.non_linearity(readout.linear_1(ll_feats)))[:, :size] for ll_feats, readout, size, is_linear in zip(ll_feats_list, self.orig_model.readouts.children(), self.hidden_sizes, self.readouts_are_linear)] + + # Aggregate node features + ll_feats_cat = torch.cat(ll_feats_list, dim=-1) + ll_feats_agg = scatter_sum( + src=ll_feats_cat, index=data["batch"], dim=0, dim_size=num_graphs ) - ll_feats = self.aggregate_features(output["node_feats"], data["batch"], num_graphs, num_atoms) - energy_uncertainty = None - force_uncertainty = None - virial_uncertainty = None - stress_uncertainty = None + if self.ll_feat_format == "sum": + ll_feats_out = ll_feats_agg + + elif self.ll_feat_format == "avg": + ll_feats_out = torch.div(ll_feats_agg, num_atoms.unsqueeze(-1)) + + elif self.ll_feat_format == "raw": + ll_feats_out = ll_feats_cat + + else: + raise RuntimeError("Unsupported last layer feature format!") + + # return uncertainty if inv_covariance matrix is available if self.inv_covariance_computed: - if compute_force_uncertainty or compute_virial_uncertainty or compute_stress_uncertainty: - f_grads, v_grads, s_grads = compute_ll_feat_gradients( - ll_feats=ll_feats, - displacement=output["displacement"], - batch_dict=data.to_dict(), - compute_force=compute_force_uncertainty, - compute_virials=compute_virial_uncertainty, - compute_stress=compute_stress_uncertainty, - ) - else: - f_grads, v_grads, s_grads = None, None, None - ll_feats = ll_feats.detach() - if compute_energy_uncertainty: - energy_uncertainty = torch.einsum("ij, jk, ik -> i", - ll_feats, - self.inv_covariance, - ll_feats - ) - if compute_force_uncertainty: - force_uncertainty = torch.einsum("iaj, jk, iak -> ia", - f_grads, - self.inv_covariance, - f_grads - ) - if compute_virial_uncertainty: - virial_uncertainty = torch.einsum("iabj, jk, iabk -> iab", - v_grads, - self.inv_covariance, - v_grads - ) - if compute_stress_uncertainty: - stress_uncertainty = torch.einsum("iabj, jk, iabk -> iab", - s_grads, - self.inv_covariance, - s_grads - ) - - output["energy_uncertainty"] = energy_uncertainty - output["force_uncertainty"] = force_uncertainty - output["virial_uncertainty"] = virial_uncertainty - output["stress_uncertainty"] = stress_uncertainty + uncertainty = torch.einsum("ij, jk, ik -> i", + ll_feats_agg, + self.inv_covariance, + ll_feats_agg + ) + uncertainty = uncertainty.unsqueeze(1) + else: + uncertainty = None + + output["ll_feats"] = ll_feats_out + output["uncertainty"] = uncertainty return output def compute_covariance( self, train_loader: DataLoader, - include_forces: bool = False, - include_virials: bool = False, - include_stresses: bool = False, is_universal: bool = False, - huber_delta: float = 0.01, + huber_delta: float = 0.1, ) -> None: - # if not is_universal: - # raise NotImplementedError("Only universal loss models are supported for LLPR") - import tqdm # Utility function to compute the covariance matrix for a training set. - # Note that this function computes the covariance step-wise, so it can - # be used to accumulate multiple times on subsets of the same training set - - if not is_universal: - raise NotImplementedError("Only universal loss models are supported for LLPR") - - import tqdm for batch in tqdm.tqdm(train_loader): batch.to(self.covariance.device) batch_dict = batch.to_dict() - output = self.orig_model( - batch_dict, - training=(include_forces or include_virials or include_stresses), - compute_force=(is_universal and include_forces), # we need this for the Huber loss force mask - compute_virials=False, - compute_stress=False, - compute_displacement=(include_virials or include_stresses), - ) - - num_graphs = batch_dict["ptr"].numel() - 1 - num_atoms = batch_dict["ptr"][1:] - batch_dict["ptr"][:-1] - ll_feats = self.aggregate_features(output["node_feats"], batch_dict["batch"], num_graphs, num_atoms) - - if include_forces or include_virials or include_stresses: - f_grads, v_grads, s_grads = compute_ll_feat_gradients( - ll_feats=ll_feats, - displacement=output["displacement"], - batch_dict=batch_dict, - compute_force=include_forces, - compute_virials=include_virials, - compute_stress=include_stresses, - ) - else: - f_grads, v_grads, s_grads = None, None, None - ll_feats = ll_feats.detach() - + output = self.forward(batch_dict) + ll_feats = output["ll_feats"].detach() # Account for the weighting of structures and targets # Apply Huber loss mask if universal model cur_weights = torch.mul(batch.weight, batch.energy_weight) @@ -1306,35 +1253,68 @@ def compute_covariance( huber_delta, ) cur_weights *= huber_mask - ll_feats = torch.mul(ll_feats, cur_weights.unsqueeze(-1)**(0.5)) + ll_feats = torch.mul(ll_feats, cur_weights.unsqueeze(-1)) self.covariance += ll_feats.T @ ll_feats + self.covariance_computed = True - if include_forces: - # Account for the weighting of structures and targets - # Apply Huber loss mask if universal model - f_conf_weights = torch.stack([batch.weight[ii] for ii in batch.batch]) - f_forces_weights = torch.stack([batch.forces_weight[ii] for ii in batch.batch]) - cur_f_weights = torch.mul(f_conf_weights, f_forces_weights) - if is_universal: - huber_mask_force = get_conditional_huber_force_mask( - output["forces"], - batch["forces"], - huber_delta, - ) - cur_f_weights *= huber_mask_force - f_grads = torch.mul(f_grads, cur_f_weights.view(-1, 1, 1)**(0.5)) - f_grads = f_grads.reshape(-1, ll_feats.shape[-1]) - self.covariance += f_grads.T @ f_grads + def add_gradients_to_covariance( + self, + train_loader: DataLoader, + training: bool = True, + compute_virials: bool = False, + compute_stress: bool = False, + is_universal: bool = False, + huber_delta: float = 0.1, + ) -> None: + + compute_displacement = compute_virials or compute_stress + + for batch in train_loader: + batch.to(self.covariance.device) + batch_dict = batch.to_dict() + output = self.forward(batch_dict, + training=training, + compute_displacement=compute_displacement, + compute_force=True, + compute_virials=compute_virials, + compute_stress=compute_stress, + ) + ll_feats = output["ll_feats"] + + f_grads, v_grads, s_grads = compute_ll_feat_gradients( + ll_feats=ll_feats, + displacement=output["displacement"], + batch_dict=batch_dict, + training=training, + compute_virials=compute_virials, + compute_stress=compute_stress, + ) - if include_virials: + # Account for the weighting of structures and targets + # Apply Huber loss mask if universal model + f_conf_weights = torch.stack([batch.weight[ii] for ii in batch.batch]) + f_forces_weights = torch.stack([batch.forces_weight[ii] for ii in batch.batch]) + cur_f_weights = torch.mul(f_conf_weights, f_forces_weights) + if is_universal: + huber_mask_force = get_conditional_huber_force_mask( + output["forces"], + batch["forces"], + huber_delta, + ) + cur_f_weights *= huber_mask_force + f_grads = torch.mul(f_grads, cur_f_weights.view(-1, 1, 1)) + f_grads = f_grads.reshape(-1, ll_feats.shape[-1]) + self.covariance += f_grads.T @ f_grads + + if compute_virials: cur_v_weights = torch.mul(batch.weight, batch.virials_weight) - v_grads = torch.mul(v_grads, cur_v_weights.view(-1, 1, 1, 1)**(0.5)) - v_grads = v_grads.reshape(-1, ll_feats.shape[-1]) + # AS FAR AS I KNOW NO VIRIALS IN THE UNIVERSAL MODEL + v_grads = torch.mul(v_grads, cur_v_weights.view(-1, 1, 1, 1)) + v_grads = v_grads.reshape(-1, ll_feats.shape[-1]) self.covariance += v_grads.T @ v_grads - if include_stresses: + if compute_stress: cur_s_weights = torch.mul(batch.weight, batch.stress_weight) - cur_s_weights = cur_s_weights.view(-1, 1, 1).expand(-1, 3, 3) if is_universal: huber_mask_stress = get_huber_mask( output["stress"], @@ -1342,18 +1322,11 @@ def compute_covariance( huber_delta, ) cur_s_weights *= huber_mask_stress - s_grads = torch.mul(s_grads, cur_s_weights.view(-1, 1, 1, 1)**(0.5)) + s_grads = torch.mul(s_grads, cur_s_weights.view(-1, 1, 1, 1)) s_grads = s_grads.reshape(-1, ll_feats.shape[-1]) - # The stresses seem to be normalized by n_atoms in the normal loss, but - # not in the universal loss. - if is_universal: - self.covariance += s_grads.T @ s_grads - else: - # repeat num_atoms for 9 elements of stress tensor - self.covariance += (s_grads / num_atoms.repeat_interleave(9).unsqueeze(-1)).T \ - @ (s_grads / num_atoms.repeat_interleave(9).unsqueeze(-1)) + self.covariance += s_grads.T @ s_grads - self.covariance_computed = True + self.covariance_gradients_computed = True def compute_inv_covariance(self, C: float, sigma: float) -> None: # Utility function to set the hyperparameters of the uncertainty model.