bug fix for huber mask
SanggyuChong committed Sep 20, 2024
1 parent 5689cc3 commit d11af04
Showing 3 changed files with 33 additions and 29 deletions.
49 changes: 27 additions & 22 deletions mace/modules/models.py
@@ -1197,12 +1197,11 @@ def forward(
raise RuntimeError("Cannot compute stress uncertainty without computing stress")

num_graphs = data["ptr"].numel() - 1
num_atoms = data["ptr"][1:] - data["ptr"][:-1]


output = self.orig_model(
data, (compute_force_uncertainty or compute_stress_uncertainty or compute_virial_uncertainty), compute_force, compute_virials, compute_stress, compute_displacement
)
ll_feats = self.aggregate_features(output["node_feats"], data["batch"], num_graphs)
ll_feats = self.aggregate_ll_features(output["node_feats"], data["batch"], num_graphs)

energy_uncertainty = None
force_uncertainty = None
@@ -1253,7 +1252,7 @@ def forward(

return output

def aggregate_features(
def aggregate_ll_features(
self,
ll_feats: torch.Tensor,
indices: torch.Tensor,
@@ -1282,14 +1281,8 @@ def compute_covariance(
huber_delta: float = 0.01,
) -> None:
# Utility function to compute the covariance matrix for a training set.
# Note that this function computes the covariance step-wise, so it can
# be used to accumulate multiple times on subsets of the same training set

if not is_universal:
raise NotImplementedError("Only universal loss models are supported for LLPR")

import tqdm
for batch in tqdm.tqdm(train_loader):
for batch in train_loader:
batch.to(self.covariance.device)
batch_dict = batch.to_dict()
output = self.orig_model(
@@ -1303,7 +1296,7 @@ def compute_covariance(

num_graphs = batch_dict["ptr"].numel() - 1
num_atoms = batch_dict["ptr"][1:] - batch_dict["ptr"][:-1]
ll_feats = self.aggregate_features(output["node_feats"], batch_dict["batch"], num_graphs)
ll_feats = self.aggregate_ll_features(output["node_feats"], batch_dict["batch"], num_graphs)

if include_forces or include_virials or include_stresses:
f_grads, v_grads, s_grads = compute_ll_feat_gradients(
@@ -1328,7 +1321,7 @@ def compute_covariance(
batch["energy"],
huber_delta,
)
cur_weights *= huber_mask
cur_weights = torch.mul(cur_weights, huber_mask)
ll_feats = torch.mul(ll_feats, cur_weights.unsqueeze(-1)**(0.5))
self.covariance += (ll_feats / num_atoms.unsqueeze(-1)).T @ (ll_feats / num_atoms.unsqueeze(-1))
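For context, a minimal standalone sketch of the energy-term accumulation in the loop above; tensor shapes and values here are illustrative, not taken from the repository:

# Illustrative sketch: per-structure energy contribution to the covariance,
# with Huber masking and per-atom normalization as in the loop above.
import torch

feat_dim, num_graphs = 4, 3
ll_feats = torch.randn(num_graphs, feat_dim, dtype=torch.float64)  # aggregated LL features
cur_weights = torch.ones(num_graphs, dtype=torch.float64)          # config * energy weights
num_atoms = torch.tensor([8.0, 12.0, 6.0], dtype=torch.float64)

se = torch.square(torch.randn(num_graphs, dtype=torch.float64))    # stand-in squared errors
huber_mask = (se < 0.01).int().double()                            # 1.0 in the quadratic region
cur_weights = torch.mul(cur_weights, huber_mask)                   # out-of-place, stays float64

weighted = torch.mul(ll_feats, cur_weights.unsqueeze(-1) ** 0.5)
covariance = (weighted / num_atoms.unsqueeze(-1)).T @ (weighted / num_atoms.unsqueeze(-1))
print(covariance.shape)  # (feat_dim, feat_dim)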

@@ -1338,37 +1331,49 @@ def compute_covariance(
f_conf_weights = torch.stack([batch.weight[ii] for ii in batch.batch])
f_forces_weights = torch.stack([batch.forces_weight[ii] for ii in batch.batch])
cur_f_weights = torch.mul(f_conf_weights, f_forces_weights)
cur_f_weights = cur_f_weights.view(-1, 1).expand(-1, 3)
if is_universal:
huber_mask_force = get_conditional_huber_force_mask(
output["forces"],
batch["forces"],
huber_delta,
)
cur_f_weights *= huber_mask_force
f_grads = torch.mul(f_grads, cur_f_weights.view(-1, 1, 1)**(0.5))
cur_f_weights = torch.mul(cur_f_weights, huber_mask_force)
f_grads = torch.mul(f_grads, cur_f_weights.unsqueeze(-1)**(0.5))
f_grads = f_grads.reshape(-1, ll_feats.shape[-1])
self.covariance += f_grads.T @ f_grads
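A possible reason for replacing the in-place *= with an out-of-place torch.mul (an assumption, not stated in the commit message): in-place operations fail on tensors produced by .expand(), since the expanded view aliases memory. A minimal sketch:

# Illustrative sketch: expanded views reject in-place multiplication,
# while out-of-place torch.mul returns a fresh tensor of the broadcast shape.
import torch

w = torch.rand(5, dtype=torch.float64).view(-1, 1).expand(-1, 3)   # per-atom weights broadcast over xyz
mask = (torch.rand(5, 3, dtype=torch.float64) < 0.5).double()

safe = torch.mul(w, mask)           # works: new tensor of shape (5, 3)
try:
    w *= mask                       # in-place write into an aliased expanded view
except RuntimeError as err:
    print("in-place multiply rejected:", err)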

if include_virials:
# No Huber mask in the case of virials as it was not used in the
# universal model
cur_v_weights = torch.mul(batch.weight, batch.virials_weight)
v_grads = torch.mul(v_grads, cur_v_weights.view(-1, 1, 1, 1)**(0.5))
v_grads = v_grads.reshape(-1, ll_feats.shape[-1])
cur_v_weights = cur_v_weights.view(-1, 1, 1).expand(-1, 3, 3)
v_grads = torch.mul(v_grads, cur_v_weights.unsqueeze(-1)**(0.5))
v_grads = v_grads.reshape(-1, ll_feats.shape[-1])
self.covariance += v_grads.T @ v_grads

if include_stresses:
# Account for the weighting of structures and targets
# Apply Huber loss mask if universal model
cur_s_weights = torch.mul(batch.weight, batch.stress_weight)
cur_s_weights = cur_s_weights.view(-1, 1, 1).expand(-1, 3, 3)
if is_universal:
huber_mask_stress = get_huber_mask(
output["stress"],
batch["stress"],
huber_delta,
)
cur_s_weights *= huber_mask_stress
s_grads = torch.mul(s_grads, cur_s_weights.view(-1, 1, 1, 1)**(0.5))
# The stresses seem to be normalized by n_atoms in the normal loss, but
# not in the universal loss. Here, we don't normalize
cur_s_weights = torch.mul(cur_s_weights, huber_mask_stress)
s_grads = torch.mul(s_grads, cur_s_weights.unsqueeze(-1)**(0.5))
s_grads = s_grads.reshape(-1, ll_feats.shape[-1])
self.covariance += s_grads.T @ s_grads
# The stresses seem to be normalized by n_atoms in the normal loss, but
# not in the universal loss.
if is_universal:
self.covariance += s_grads.T @ s_grads
else:
# repeat num_atoms for 9 elements of stress tensor
self.covariance += (s_grads / num_atoms.repeat_interleave(9).unsqueeze(-1)).T \
@ (s_grads / num_atoms.repeat_interleave(9).unsqueeze(-1))

self.covariance_computed = True

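For reference, a minimal sketch of the shape bookkeeping in the non-universal stress branch above (sizes are illustrative): each structure contributes nine stress components, so the per-structure atom counts are repeated nine times to match the flattened gradient matrix.

# Illustrative sketch: per-atom normalization of flattened stress gradients.
import torch

num_graphs, feat_dim = 3, 4
num_atoms = torch.tensor([8.0, 12.0, 6.0])
s_grads = torch.randn(num_graphs, 3, 3, feat_dim)            # d(stress)/d(LL feats)

s_grads = s_grads.reshape(-1, feat_dim)                      # (num_graphs * 9, feat_dim)
norm = num_atoms.repeat_interleave(9).unsqueeze(-1)          # one atom count per stress component
covariance_update = (s_grads / norm).T @ (s_grads / norm)    # (feat_dim, feat_dim)
print(covariance_update.shape)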
2 changes: 1 addition & 1 deletion mace/modules/utils.py
@@ -272,7 +272,7 @@ def get_huber_mask(
huber_delta: float = 1.0,
) -> torch.Tensor:
se = torch.square(input - target)
huber_mask = (se < huber_delta).int()
huber_mask = (se < huber_delta).int().double()
return huber_mask


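A minimal standalone sketch mirroring the updated utility: the mask is 1.0 where the squared error lies in the below-delta (quadratic) region of the Huber loss and 0.0 elsewhere, now returned as float64. Example values are illustrative:

import torch

def get_huber_mask(input, target, huber_delta=1.0):
    # Mirrors the change above: float64 mask selecting the quadratic region.
    se = torch.square(input - target)
    return (se < huber_delta).int().double()

pred = torch.tensor([0.10, 0.90, 2.50], dtype=torch.float64)
ref = torch.tensor([0.12, 0.50, 0.00], dtype=torch.float64)
print(get_huber_mask(pred, ref, huber_delta=0.01))  # tensor([1., 0., 0.], dtype=torch.float64)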
11 changes: 5 additions & 6 deletions mace/tools/llpr.py
@@ -30,20 +30,20 @@ def calibrate_llpr_params(
batch = batch.to(next(model.parameters()).device)
batch_dict = batch.to_dict()
y = batch_dict['energy']
num_graphs = batch_dict["ptr"].numel() - 1
num_atoms.append(batch_dict['ptr'][1:] - batch_dict['ptr'][:-1])
model_outputs = model(batch_dict)
predictions = model_outputs['energy'].detach()
ll_feats.append(model_outputs['ll_feats'].detach())
cur_ll_feats = model.aggregate_ll_features(
model_outputs["node_feats"], batch_dict["batch"], num_graphs
).detach()
ll_feats.append(cur_ll_feats)
actual_errors.append((y - predictions)**2)

actual_errors = torch.cat(actual_errors, dim=0)
ll_feats_all = torch.cat(ll_feats, dim=0)
num_atoms = torch.cat(num_atoms, dim=0)

# Enforce calibration on the extensive uncertainty of validation set
if model.ll_feat_format == "avg":
ll_feats_all = torch.mul(ll_feats_all, num_atoms.unsqueeze(-1))

def obj_function_wrapper(x):
x = _process_inputs(x)
try:
@@ -57,7 +57,6 @@ def obj_function_wrapper(x):
obj_value = obj_function(actual_errors, predicted_errors, **kwargs)
except torch._C._LinAlgError:
obj_value = 1e10
# HACK:
if math.isnan(obj_value):
obj_value = 1e10
return obj_value
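For context, a minimal sketch of the rescaling step in calibrate_llpr_params above (names and sizes are illustrative): when the model stores per-atom-averaged LL features, multiplying by the atom counts recovers extensive features before calibration.

import torch

num_graphs, feat_dim = 3, 4
ll_feats_all = torch.randn(num_graphs, feat_dim)   # per-atom-averaged LL features
num_atoms = torch.tensor([8.0, 12.0, 6.0])

ll_feat_format = "avg"
if ll_feat_format == "avg":
    ll_feats_all = torch.mul(ll_feats_all, num_atoms.unsqueeze(-1))  # back to extensive
print(ll_feats_all.shape)  # (num_graphs, feat_dim)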
