Last-layer prediction rigidity (LLPR)-based uncertainty quantification for MACE #601

Open · wants to merge 11 commits into develop
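This PR wraps a trained MACE (or ScaleShiftMACE) model in an LLPRModel that adds last-layer prediction rigidity (LLPR) uncertainty estimates on top of the usual outputs. A minimal usage sketch, assuming a trained model and a train_loader built with the MACE data pipeline (the variable names below are placeholders, not part of this diff):

from mace.modules import LLPRModel

llpr_model = LLPRModel(model)  # wrap the trained model
llpr_model.compute_covariance(train_loader)  # accumulate last-layer feature covariance over the training set
llpr_model.compute_inv_covariance(C=1.0, sigma=1.0)  # C and sigma are calibration hyperparameters

# Uncertainties are returned alongside the usual outputs; they remain None
# until compute_inv_covariance has been called.
output = llpr_model(batch, compute_energy_uncertainty=True)
energy_uncertainty = output["energy_uncertainty"]  # one value per structure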
2 changes: 2 additions & 0 deletions .gitignore
@@ -33,3 +33,5 @@ dist/
*.xyz
/checkpoints
*.model

.history
1 change: 1 addition & 0 deletions mace/modules/__init__.py
@@ -37,6 +37,7 @@
EnergyDipolesMACE,
ScaleShiftBOTNet,
ScaleShiftMACE,
LLPRModel,
)
from .radial import BesselBasis, GaussianBasis, PolynomialCutoff, ZBLBasis
from .symmetric_contraction import SymmetricContraction
289 changes: 289 additions & 0 deletions mace/modules/models.py
@@ -30,11 +30,16 @@
from .utils import (
compute_fixed_charge_dipole,
compute_forces,
compute_ll_feat_gradients,
get_conditional_huber_force_mask,
get_edge_vectors_and_lengths,
get_huber_mask,
get_outputs,
get_symmetric_displacement,
)

from torch.utils.data import DataLoader
# pylint: disable=C0302


@@ -314,6 +319,20 @@ def forward(
}


def readout_is_linear(obj: Any) -> bool:
    # TorchScript compilation erases Python class identity, so scripted
    # readouts are identified by their original class name instead.
    if isinstance(obj, torch.jit.RecursiveScriptModule):
        return obj.original_name == "LinearReadoutBlock"
    return isinstance(obj, LinearReadoutBlock)


def readout_is_nonlinear(obj: Any) -> bool:
    if isinstance(obj, torch.jit.RecursiveScriptModule):
        return obj.original_name == "NonLinearReadoutBlock"
    return isinstance(obj, NonLinearReadoutBlock)


@compile_mode("script")
class ScaleShiftMACE(MACE):
def __init__(
@@ -1107,3 +1126,273 @@ def forward(
"atomic_dipoles": atomic_dipoles,
}
return output


@compile_mode("script")
class LLPRModel(torch.nn.Module):
def __init__(
self,
model: Union[MACE, ScaleShiftMACE],
ll_feat_format: str = "avg",
):
super().__init__()
self.orig_model = model

# determine ll_feat size from readout layers
self.hidden_sizes_before_readout = []
self.hidden_sizes = []
self.readouts_are_linear = []
self.hidden_size_sum = 0
for readout in self.orig_model.readouts.children():
if readout_is_linear(readout):
self.readouts_are_linear.append(True)
cur_size_before_readout = sum(irrep.dim for irrep in o3.Irreps(readout.linear.irreps_in))
self.hidden_sizes_before_readout.append(cur_size_before_readout)
cur_size = o3.Irreps(readout.linear.irreps_in)[0].dim
self.hidden_sizes.append(cur_size)
self.hidden_size_sum += cur_size
elif readout_is_nonlinear(readout):
self.readouts_are_linear.append(False)
cur_size_before_readout = sum(irrep.dim for irrep in o3.Irreps(readout.linear_1.irreps_in))
self.hidden_sizes_before_readout.append(cur_size_before_readout)
cur_size = o3.Irreps(readout.linear_2.irreps_in).dim
self.hidden_sizes.append(cur_size)
self.hidden_size_sum += cur_size
else:
raise TypeError("Unknown readout block type for LLPR at initialization!")

# initialize (inv_)covariance matrices
self.register_buffer("covariance",
torch.zeros((self.hidden_size_sum, self.hidden_size_sum),
device=next(self.orig_model.parameters()).device
)
)
self.register_buffer("inv_covariance",
torch.zeros((self.hidden_size_sum, self.hidden_size_sum),
device=next(self.orig_model.parameters()).device
)
)

# extra params associated with LLPR
self.ll_feat_format = ll_feat_format
self.covariance_computed = False
self.covariance_gradients_computed = False
self.inv_covariance_computed = False
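# Intended workflow, inferred from these flags: call compute_covariance()
# on the training set first, then compute_inv_covariance() with the
# calibration hyperparameters; forward() only fills the uncertainty
# outputs once inv_covariance_computed is True.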

def forward(
self,
data: Dict[str, torch.Tensor],
compute_force: bool = True,
compute_virials: bool = False,
compute_stress: bool = False,
compute_displacement: bool = False,
compute_energy_uncertainty: bool = True,
compute_force_uncertainty: bool = False,
compute_virial_uncertainty: bool = False,
compute_stress_uncertainty: bool = False,
) -> Dict[str, Optional[torch.Tensor]]:
if compute_force_uncertainty and not compute_force:
raise RuntimeError("Cannot compute force uncertainty without computing forces")
if compute_virial_uncertainty and not compute_virials:
raise RuntimeError("Cannot compute virial uncertainty without computing virials")
if compute_stress_uncertainty and not compute_stress:
raise RuntimeError("Cannot compute stress uncertainty without computing stress")

num_graphs = data["ptr"].numel() - 1

output = self.orig_model(
    data,
    training=(compute_force_uncertainty or compute_stress_uncertainty or compute_virial_uncertainty),
    compute_force=compute_force,
    compute_virials=compute_virials,
    compute_stress=compute_stress,
    compute_displacement=compute_displacement,
)
ll_feats = self.aggregate_ll_features(output["node_feats"], data["batch"], num_graphs)

energy_uncertainty = None
force_uncertainty = None
virial_uncertainty = None
stress_uncertainty = None
if self.inv_covariance_computed:
if compute_force_uncertainty or compute_virial_uncertainty or compute_stress_uncertainty:
f_grads, v_grads, s_grads = compute_ll_feat_gradients(
ll_feats=ll_feats,
displacement=output["displacement"],
batch_dict=data.to_dict(),
compute_force=compute_force_uncertainty,
compute_virials=compute_virial_uncertainty,
compute_stress=compute_stress_uncertainty,
)
else:
f_grads, v_grads, s_grads = None, None, None
ll_feats = ll_feats.detach()
if compute_energy_uncertainty:
energy_uncertainty = torch.einsum("ij, jk, ik -> i",
ll_feats,
self.inv_covariance,
ll_feats
)
if compute_force_uncertainty:
force_uncertainty = torch.einsum("iaj, jk, iak -> ia",
f_grads,
self.inv_covariance,
f_grads
)
if compute_virial_uncertainty:
virial_uncertainty = torch.einsum("iabj, jk, iabk -> iab",
v_grads,
self.inv_covariance,
v_grads
)
if compute_stress_uncertainty:
stress_uncertainty = torch.einsum("iabj, jk, iabk -> iab",
s_grads,
self.inv_covariance,
s_grads
)

output["energy_uncertainty"] = energy_uncertainty
output["force_uncertainty"] = force_uncertainty
output["virial_uncertainty"] = virial_uncertainty
output["stress_uncertainty"] = stress_uncertainty

return output

def aggregate_ll_features(
self,
ll_feats: torch.Tensor,
indices: torch.Tensor,
num_graphs: int
) -> torch.Tensor:
# Aggregates (sums) node features over each structure
ll_feats_list = torch.split(ll_feats, self.hidden_sizes_before_readout, dim=-1)
ll_feats_list = [
    (feats if is_linear else readout.non_linearity(readout.linear_1(feats)))[:, :size]
    for feats, readout, size, is_linear in zip(
        ll_feats_list,
        self.orig_model.readouts.children(),
        self.hidden_sizes,
        self.readouts_are_linear,
    )
]

# Aggregate node features
ll_feats_cat = torch.cat(ll_feats_list, dim=-1)
ll_feats_agg = scatter_sum(
src=ll_feats_cat, index=indices, dim=0, dim_size=num_graphs
)

return ll_feats_agg
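# Shape note for the method above: node features arrive with one block of
# columns per readout (sum of hidden_sizes_before_readout in total); after
# the per-readout projections, scatter_sum pools them into a single
# [num_graphs, hidden_size_sum] matrix with one feature row per structure.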

def compute_covariance(
self,
train_loader: DataLoader,
include_energy: bool = True,
include_forces: bool = False,
include_virials: bool = False,
include_stresses: bool = False,
is_universal: bool = False,
huber_delta: float = 0.01,
) -> None:
# Utility function to compute the covariance matrix for a training set.

for batch in train_loader:
batch = batch.to(self.covariance.device)
batch_dict = batch.to_dict()
output = self.orig_model(
batch_dict,
training=(include_forces or include_virials or include_stresses),
compute_force=(is_universal and include_forces), # we need this for the Huber loss force mask
compute_virials=False,
compute_stress=(is_universal and include_stresses), # we need this for the Huber loss stress mask
compute_displacement=(include_virials or include_stresses),
)

num_graphs = batch_dict["ptr"].numel() - 1
num_atoms = batch_dict["ptr"][1:] - batch_dict["ptr"][:-1]
ll_feats = self.aggregate_ll_features(output["node_feats"], batch_dict["batch"], num_graphs)

if include_forces or include_virials or include_stresses:
f_grads, v_grads, s_grads = compute_ll_feat_gradients(
ll_feats=ll_feats,
displacement=output["displacement"],
batch_dict=batch_dict,
compute_force=include_forces,
compute_virials=include_virials,
compute_stress=include_stresses,
)
else:
f_grads, v_grads, s_grads = None, None, None
ll_feats = ll_feats.detach()

if include_energy:
# Account for the weighting of structures and targets
# Apply Huber loss mask if universal model
cur_weights = torch.mul(batch.weight, batch.energy_weight)
if is_universal:
huber_mask = get_huber_mask(
output["energy"],
batch["energy"],
huber_delta,
)
cur_weights = torch.mul(cur_weights, huber_mask)
ll_feats = torch.mul(ll_feats, cur_weights.unsqueeze(-1)**(0.5))
# Normalize by the number of atoms, matching the per-atom energy loss
self.covariance += (ll_feats / num_atoms.unsqueeze(-1)).T @ (ll_feats / num_atoms.unsqueeze(-1))

if include_forces:
# Account for the weighting of structures and targets
# Apply Huber loss mask if universal model
f_conf_weights = batch.weight[batch.batch]
f_forces_weights = batch.forces_weight[batch.batch]
cur_f_weights = torch.mul(f_conf_weights, f_forces_weights)
cur_f_weights = cur_f_weights.view(-1, 1).expand(-1, 3)
if is_universal:
huber_mask_force = get_conditional_huber_force_mask(
output["forces"],
batch["forces"],
huber_delta,
)
cur_f_weights = torch.mul(cur_f_weights, huber_mask_force)
f_grads = torch.mul(f_grads, cur_f_weights.unsqueeze(-1)**(0.5))
f_grads = f_grads.reshape(-1, ll_feats.shape[-1])
self.covariance += f_grads.T @ f_grads

if include_virials:
# No Huber mask in the case of virials as it was not used in the
# universal model
cur_v_weights = torch.mul(batch.weight, batch.virials_weight)
cur_v_weights = cur_v_weights.view(-1, 1, 1).expand(-1, 3, 3)
v_grads = torch.mul(v_grads, cur_v_weights.unsqueeze(-1)**(0.5))
v_grads = v_grads.reshape(-1, ll_feats.shape[-1])
self.covariance += v_grads.T @ v_grads

if include_stresses:
# Account for the weighting of structures and targets
# Apply Huber loss mask if universal model
cur_s_weights = torch.mul(batch.weight, batch.stress_weight)
cur_s_weights = cur_s_weights.view(-1, 1, 1).expand(-1, 3, 3)
if is_universal:
huber_mask_stress = get_huber_mask(
output["stress"],
batch["stress"],
huber_delta,
)
cur_s_weights = torch.mul(cur_s_weights, huber_mask_stress)
s_grads = torch.mul(s_grads, cur_s_weights.unsqueeze(-1)**(0.5))
s_grads = s_grads.reshape(-1, ll_feats.shape[-1])
# The stresses seem to be normalized by n_atoms in the normal loss, but
# not in the universal loss.
if is_universal:
self.covariance += s_grads.T @ s_grads
else:
# repeat num_atoms for 9 elements of stress tensor
s_grads = s_grads / num_atoms.repeat_interleave(9).unsqueeze(-1)
self.covariance += s_grads.T @ s_grads

self.covariance_computed = True

def compute_inv_covariance(self, C: float, sigma: float) -> None:
# Utility function to compute the inverse covariance from the hyperparameters C and sigma of the uncertainty model.
if not self.covariance_computed:
raise RuntimeError("You must compute the covariance matrix before "
"computing the inverse covariance matrix!")
self.inv_covariance = C * torch.linalg.inv(
self.covariance + sigma**2 * torch.eye(self.hidden_size_sum, device=self.covariance.device)
)
self.inv_covariance_computed = True
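# For reference, the estimator implemented by compute_covariance,
# compute_inv_covariance, and forward is:
#     covariance     = Phi^T Phi                  (Phi: stacked weighted ll_feats/gradients)
#     inv_covariance = C * (Phi^T Phi + sigma^2 I)^{-1}
#     uncertainty(x) = phi(x)^T inv_covariance phi(x)
# where phi(x) are the aggregated last-layer features of structure x. A
# torch.linalg.solve or Cholesky factorization could replace the explicit
# inverse for numerical stability; that is a possible refinement, not part
# of this diff.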

def reset_matrices(self) -> None:
# Utility function to reset covariance and inv covariance matrices.
self.covariance = torch.zeros_like(self.covariance)
self.inv_covariance = torch.zeros_like(self.inv_covariance)
self.covariance_computed = False
self.inv_covariance_computed = False
self.covariance_gradients_computed = False