Run the original model inside the wrapper
frostedoyster committed Sep 23, 2024
1 parent ab38be6 commit 1a57319
Showing 2 changed files with 115 additions and 141 deletions.
1 change: 1 addition & 0 deletions mace/modules/__init__.py
@@ -37,6 +37,7 @@
EnergyDipolesMACE,
ScaleShiftBOTNet,
ScaleShiftMACE,
LLPRModel,
)
from .radial import BesselBasis, GaussianBasis, PolynomialCutoff, ZBLBasis
from .symmetric_contraction import SymmetricContraction
255 changes: 114 additions & 141 deletions mace/modules/models.py
@@ -314,6 +314,20 @@ def forward(
}


def readout_is_linear(obj: Any):
if isinstance(obj, torch.jit.RecursiveScriptModule):
return obj.original_name == "LinearReadoutBlock"
else:
return isinstance(obj, LinearReadoutBlock)


def readout_is_nonlinear(obj: Any):
if isinstance(obj, torch.jit.RecursiveScriptModule):
return obj.original_name == "NonLinearReadoutBlock"
else:
return isinstance(obj, NonLinearReadoutBlock)
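These two helpers check `original_name` because TorchScript compilation erases Python class identity: an `isinstance` check against the original readout class fails once the module has been scripted into a `RecursiveScriptModule`. A minimal, self-contained illustration of that behavior (a toy module, not a MACE block):

```python
import torch


class Toy(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


scripted = torch.jit.script(Toy())
print(isinstance(scripted, Toy))  # False: scripting erases the original class
print(type(scripted).__name__)    # "RecursiveScriptModule"
print(scripted.original_name)     # "Toy"
```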


@compile_mode("script")
class ScaleShiftMACE(MACE):
def __init__(
@@ -1114,6 +1128,7 @@ class LLPRModel(torch.nn.Module):
def __init__(
self,
model: Union[MACE, ScaleShiftMACE],
ll_feat_format: str = "avg",
):
super().__init__()
self.orig_model = model
@@ -1154,148 +1169,80 @@ def __init__(
)

# extra params associated with LLPR
self.ll_feat_format = ll_feat_format
self.covariance_computed = False
self.covariance_gradients_computed = False
self.inv_covariance_computed = False

def aggregate_features(self, ll_feats: torch.Tensor, indices: torch.Tensor, num_graphs: int, num_atoms: torch.Tensor) -> torch.Tensor:
ll_feats_list = torch.split(ll_feats, self.hidden_sizes_before_readout, dim=-1)
ll_feats_list = [
    (ll_feats if is_linear else readout.non_linearity(readout.linear_1(ll_feats)))[:, :size]
    for ll_feats, readout, size, is_linear in zip(
        ll_feats_list,
        self.orig_model.readouts.children(),
        self.hidden_sizes,
        self.readouts_are_linear,
    )
]

# Aggregate node features
ll_feats_cat = torch.cat(ll_feats_list, dim=-1)
ll_feats_agg = scatter_sum(
src=ll_feats_cat, index=indices, dim=0, dim_size=num_graphs
)

return ll_feats_agg
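`aggregate_features` reduces per-atom feature rows to one row per structure with `scatter_sum`; `dim_size=num_graphs` fixes the output size so every graph gets a row. A minimal sketch of that reduction, assuming the `mace.tools.scatter` import path used elsewhere in this repository:

```python
import torch

from mace.tools.scatter import scatter_sum

feats = torch.randn(5, 8)              # 5 atoms, 8 last-layer features each
batch = torch.tensor([0, 0, 0, 1, 1])  # atoms 0-2 belong to graph 0, atoms 3-4 to graph 1
per_graph = scatter_sum(src=feats, index=batch, dim=0, dim_size=2)
print(per_graph.shape)                 # torch.Size([2, 8])
```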
self.already_printed = False

def forward(
self,
data: Dict[str, torch.Tensor],
training: bool = False,
compute_force: bool = True,
compute_virials: bool = False,
compute_stress: bool = False,
compute_displacement: bool = False,
compute_energy_uncertainty: bool = True,
compute_force_uncertainty: bool = False,
compute_virial_uncertainty: bool = False,
compute_stress_uncertainty: bool = False,
) -> Dict[str, Optional[torch.Tensor]]:
if compute_force_uncertainty and not compute_force:
raise RuntimeError("Cannot compute force uncertainty without computing forces")
if compute_virial_uncertainty and not compute_virials:
raise RuntimeError("Cannot compute virial uncertainty without computing virials")
if compute_stress_uncertainty and not compute_stress:
raise RuntimeError("Cannot compute stress uncertainty without computing stress")

num_graphs = data["ptr"].numel() - 1
num_atoms = data["ptr"][1:] - data["ptr"][:-1]

output = self.orig_model(
data, (compute_force_uncertainty or compute_stress_uncertainty or compute_virial_uncertainty), compute_force, compute_virials, compute_stress, compute_displacement
data, training, compute_force, compute_virials, compute_stress, compute_displacement
)

ll_feats = output["node_feats"]
ll_feats_list = torch.split(ll_feats, self.hidden_sizes_before_readout, dim=-1)
ll_feats_list = [
    (ll_feats if is_linear else readout.non_linearity(readout.linear_1(ll_feats)))[:, :size]
    for ll_feats, readout, size, is_linear in zip(
        ll_feats_list,
        self.orig_model.readouts.children(),
        self.hidden_sizes,
        self.readouts_are_linear,
    )
]

# Aggregate node features
ll_feats_cat = torch.cat(ll_feats_list, dim=-1)
ll_feats_agg = scatter_sum(
src=ll_feats_cat, index=data["batch"], dim=0, dim_size=num_graphs
)
ll_feats = self.aggregate_features(output["node_feats"], data["batch"], num_graphs, num_atoms)

energy_uncertainty = None
force_uncertainty = None
virial_uncertainty = None
stress_uncertainty = None
if self.ll_feat_format == "sum":
ll_feats_out = ll_feats_agg

elif self.ll_feat_format == "avg":
ll_feats_out = torch.div(ll_feats_agg, num_atoms.unsqueeze(-1))

elif self.ll_feat_format == "raw":
ll_feats_out = ll_feats_cat

else:
raise RuntimeError("Unsupported last layer feature format!")
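The new `ll_feat_format` switch only changes the normalization and granularity of the exposed features. A minimal sketch of the three options with hypothetical shapes (plain `index_add_` stands in for `scatter_sum` to keep it self-contained):

```python
import torch

ll_feats_cat = torch.randn(5, 8)       # per-atom features: 5 atoms, width 8
batch = torch.tensor([0, 0, 0, 1, 1])  # graph assignment of each atom
num_atoms = torch.tensor([3, 2])

ll_feats_agg = torch.zeros(2, 8).index_add_(0, batch, ll_feats_cat)

sum_feats = ll_feats_agg                            # "sum": extensive, one row per structure
avg_feats = ll_feats_agg / num_atoms.unsqueeze(-1)  # "avg": intensive, one row per structure
raw_feats = ll_feats_cat                            # "raw": one row per atom
```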

# return uncertainty if inv_covariance matrix is available
if self.inv_covariance_computed:
if compute_force_uncertainty or compute_virial_uncertainty or compute_stress_uncertainty:
f_grads, v_grads, s_grads = compute_ll_feat_gradients(
ll_feats=ll_feats,
displacement=output["displacement"],
batch_dict=data.to_dict(),
compute_force=compute_force_uncertainty,
compute_virials=compute_virial_uncertainty,
compute_stress=compute_stress_uncertainty,
)
else:
f_grads, v_grads, s_grads = None, None, None
ll_feats = ll_feats.detach()
if compute_energy_uncertainty:
energy_uncertainty = torch.einsum("ij, jk, ik -> i",
ll_feats,
self.inv_covariance,
ll_feats
)
if compute_force_uncertainty:
force_uncertainty = torch.einsum("iaj, jk, iak -> ia",
f_grads,
self.inv_covariance,
f_grads
)
if compute_virial_uncertainty:
virial_uncertainty = torch.einsum("iabj, jk, iabk -> iab",
v_grads,
self.inv_covariance,
v_grads
)
if compute_stress_uncertainty:
stress_uncertainty = torch.einsum("iabj, jk, iabk -> iab",
s_grads,
self.inv_covariance,
s_grads
)

output["energy_uncertainty"] = energy_uncertainty
output["force_uncertainty"] = force_uncertainty
output["virial_uncertainty"] = virial_uncertainty
output["stress_uncertainty"] = stress_uncertainty
uncertainty = torch.einsum("ij, jk, ik -> i",
ll_feats_agg,
self.inv_covariance,
ll_feats_agg
)
uncertainty = uncertainty.unsqueeze(1)
else:
uncertainty = None

output["ll_feats"] = ll_feats_out
output["uncertainty"] = uncertainty

return output
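The einsum above evaluates the LLPR quadratic form: for each structure i with aggregated last-layer features f_i, the uncertainty is f_iᵀ Σ⁻¹ f_i. A self-contained check of that contraction, with an identity matrix standing in for the calibrated inverse covariance:

```python
import torch

ll_feats_agg = torch.randn(2, 8)  # one aggregated feature row per structure
inv_covariance = torch.eye(8)     # stand-in for the calibrated inverse covariance

# "ij, jk, ik -> i" is the per-row quadratic form f_i @ Sigma^{-1} @ f_i
uncertainty = torch.einsum("ij, jk, ik -> i", ll_feats_agg, inv_covariance, ll_feats_agg)
expected = ((ll_feats_agg @ inv_covariance) * ll_feats_agg).sum(dim=1)
assert torch.allclose(uncertainty, expected)
```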

def compute_covariance(
self,
train_loader: DataLoader,
include_forces: bool = False,
include_virials: bool = False,
include_stresses: bool = False,
is_universal: bool = False,
huber_delta: float = 0.01,
huber_delta: float = 0.1,
) -> None:
# if not is_universal:
# raise NotImplementedError("Only universal loss models are supported for LLPR")

import tqdm
# Utility function to compute the covariance matrix for a training set.
# Note that this function accumulates the covariance step-wise, so it can
# be called multiple times on subsets of the same training set.

if not is_universal:
raise NotImplementedError("Only universal loss models are supported for LLPR")

import tqdm
for batch in tqdm.tqdm(train_loader):
batch.to(self.covariance.device)
batch_dict = batch.to_dict()
output = self.orig_model(
batch_dict,
training=(include_forces or include_virials or include_stresses),
compute_force=(is_universal and include_forces), # we need this for the Huber loss force mask
compute_virials=False,
compute_stress=False,
compute_displacement=(include_virials or include_stresses),
)

num_graphs = batch_dict["ptr"].numel() - 1
num_atoms = batch_dict["ptr"][1:] - batch_dict["ptr"][:-1]
ll_feats = self.aggregate_features(output["node_feats"], batch_dict["batch"], num_graphs, num_atoms)

if include_forces or include_virials or include_stresses:
f_grads, v_grads, s_grads = compute_ll_feat_gradients(
ll_feats=ll_feats,
displacement=output["displacement"],
batch_dict=batch_dict,
compute_force=include_forces,
compute_virials=include_virials,
compute_stress=include_stresses,
)
else:
f_grads, v_grads, s_grads = None, None, None
ll_feats = ll_feats.detach()

output = self.forward(batch_dict)
ll_feats = output["ll_feats"].detach()
# Account for the weighting of structures and targets
# Apply Huber loss mask if universal model
cur_weights = torch.mul(batch.weight, batch.energy_weight)
@@ -1306,54 +1253,80 @@ def compute_covariance(
huber_delta,
)
cur_weights *= huber_mask
ll_feats = torch.mul(ll_feats, cur_weights.unsqueeze(-1)**(0.5))
ll_feats = torch.mul(ll_feats, cur_weights.unsqueeze(-1))
self.covariance += ll_feats.T @ ll_feats
self.covariance_computed = True
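Scaling each feature row before the outer product folds the per-structure weights into the accumulated Gram matrix; with the rows scaled by the weight itself (no square root), each structure enters the covariance with its weight squared. A minimal sketch of one accumulation step:

```python
import torch

covariance = torch.zeros(8, 8)

ll_feats = torch.randn(4, 8)  # aggregated features for a batch of 4 structures
weights = torch.rand(4)       # per-structure weights (energy weight, Huber mask, ...)

weighted = ll_feats * weights.unsqueeze(-1)
covariance += weighted.T @ weighted  # adds sum_i w_i**2 * outer(f_i, f_i)
```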

if include_forces:
# Account for the weighting of structures and targets
# Apply Huber loss mask if universal model
f_conf_weights = torch.stack([batch.weight[ii] for ii in batch.batch])
f_forces_weights = torch.stack([batch.forces_weight[ii] for ii in batch.batch])
cur_f_weights = torch.mul(f_conf_weights, f_forces_weights)
if is_universal:
huber_mask_force = get_conditional_huber_force_mask(
output["forces"],
batch["forces"],
huber_delta,
)
cur_f_weights *= huber_mask_force
f_grads = torch.mul(f_grads, cur_f_weights.view(-1, 1, 1)**(0.5))
f_grads = f_grads.reshape(-1, ll_feats.shape[-1])
self.covariance += f_grads.T @ f_grads
def add_gradients_to_covariance(
self,
train_loader: DataLoader,
training: bool = True,
compute_virials: bool = False,
compute_stress: bool = False,
is_universal: bool = False,
huber_delta: float = 0.1,
) -> None:

compute_displacement = compute_virials or compute_stress

for batch in train_loader:
batch.to(self.covariance.device)
batch_dict = batch.to_dict()
output = self.forward(batch_dict,
training=training,
compute_displacement=compute_displacement,
compute_force=True,
compute_virials=compute_virials,
compute_stress=compute_stress,
)
ll_feats = output["ll_feats"]

f_grads, v_grads, s_grads = compute_ll_feat_gradients(
ll_feats=ll_feats,
displacement=output["displacement"],
batch_dict=batch_dict,
training=training,
compute_virials=compute_virials,
compute_stress=compute_stress,
)

if include_virials:
# Account for the weighting of structures and targets
# Apply Huber loss mask if universal model
f_conf_weights = torch.stack([batch.weight[ii] for ii in batch.batch])
f_forces_weights = torch.stack([batch.forces_weight[ii] for ii in batch.batch])
cur_f_weights = torch.mul(f_conf_weights, f_forces_weights)
if is_universal:
huber_mask_force = get_conditional_huber_force_mask(
output["forces"],
batch["forces"],
huber_delta,
)
cur_f_weights *= huber_mask_force
f_grads = torch.mul(f_grads, cur_f_weights.view(-1, 1, 1))
f_grads = f_grads.reshape(-1, ll_feats.shape[-1])
self.covariance += f_grads.T @ f_grads

if compute_virials:
cur_v_weights = torch.mul(batch.weight, batch.virials_weight)
v_grads = torch.mul(v_grads, cur_v_weights.view(-1, 1, 1, 1)**(0.5))
v_grads = v_grads.reshape(-1, ll_feats.shape[-1])
# AS FAR AS I KNOW NO VIRIALS IN THE UNIVERSAL MODEL
v_grads = torch.mul(v_grads, cur_v_weights.view(-1, 1, 1, 1))
v_grads = v_grads.reshape(-1, ll_feats.shape[-1])
self.covariance += v_grads.T @ v_grads

if include_stresses:
if compute_stress:
cur_s_weights = torch.mul(batch.weight, batch.stress_weight)
cur_s_weights = cur_s_weights.view(-1, 1, 1).expand(-1, 3, 3)
if is_universal:
huber_mask_stress = get_huber_mask(
output["stress"],
batch["stress"],
huber_delta,
)
cur_s_weights *= huber_mask_stress
s_grads = torch.mul(s_grads, cur_s_weights.view(-1, 1, 1, 1)**(0.5))
s_grads = torch.mul(s_grads, cur_s_weights.view(-1, 1, 1, 1))
s_grads = s_grads.reshape(-1, ll_feats.shape[-1])
# The stresses seem to be normalized by n_atoms in the normal loss, but
# not in the universal loss.
if is_universal:
self.covariance += s_grads.T @ s_grads
else:
# repeat num_atoms for 9 elements of stress tensor
self.covariance += (s_grads / num_atoms.repeat_interleave(9).unsqueeze(-1)).T \
@ (s_grads / num_atoms.repeat_interleave(9).unsqueeze(-1))
self.covariance += s_grads.T @ s_grads

self.covariance_computed = True
self.covariance_gradients_computed = True
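Taken together, the intended calibration flow appears to be: wrap a trained model, accumulate the covariance (plus, optionally, its gradient contributions), invert it, then read uncertainties from `forward`. A hedged sketch of that flow; `model`, `train_loader`, and `batch` are assumed to exist, and the `C`/`sigma` values are placeholders, not recommendations:

```python
from mace.modules import LLPRModel

llpr = LLPRModel(model, ll_feat_format="avg")  # wrap a trained MACE/ScaleShiftMACE model
llpr.compute_covariance(train_loader, is_universal=True)
llpr.add_gradients_to_covariance(train_loader, is_universal=True)  # optional force terms
llpr.compute_inv_covariance(C=1.0, sigma=1.0)  # hyperparameters of the uncertainty model
output = llpr(batch.to_dict())
print(output["uncertainty"].shape)             # (num_graphs, 1)
```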

def compute_inv_covariance(self, C: float, sigma: float) -> None:
# Utility function to set the hyperparameters of the uncertainty model.
