pt: fix loss training when no data available (#3571)
Fix #3482 and #3483.
iProzd authored Mar 22, 2024
1 parent e47478f commit dc14719
Showing 7 changed files with 158 additions and 44 deletions.
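In short, the fix gates each loss term behind the per-label `find_*` flag: the prefactor is multiplied by the flag so systems without that label contribute nothing to the total loss, and logged metrics go through a new `display_if_exist` helper so they show as NaN rather than a misleading number. A minimal sketch of the pattern (simplified names, not the verbatim DeePMD-kit code):

```python
import torch


def display_if_exist(metric: torch.Tensor, find_property: float) -> torch.Tensor:
    """Return the metric when the label was found in the data, otherwise NaN."""
    return metric if bool(find_property) else torch.nan


def energy_term(pred_e, label_e, pref_e, find_energy, atom_norm):
    # Zero the prefactor when the batch carries no energy labels,
    # so this term adds nothing to the total loss.
    pref_e = pref_e * find_energy
    l2 = torch.mean(torch.square(pred_e - label_e))
    loss = atom_norm * (pref_e * l2)
    # The logged metric becomes NaN instead of a misleading value.
    rmse_e = display_if_exist(l2.sqrt() * atom_norm, find_energy)
    return loss, {"rmse_e": rmse_e}


# With find_energy == 0.0 the term contributes nothing and logs NaN.
loss, log = energy_term(
    torch.ones(4), torch.zeros(4), pref_e=1.0, find_energy=0.0, atom_norm=0.25
)
assert loss.item() == 0.0
assert torch.isnan(torch.as_tensor(log["rmse_e"]))
```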
59 changes: 40 additions & 19 deletions deepmd/pt/loss/ener.py
@@ -124,15 +124,21 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
# more_loss['test_keys'] = [] # showed when doing dp test
atom_norm = 1.0 / natoms
if self.has_e and "energy" in model_pred and "energy" in label:
find_energy = label.get("find_energy", 0.0)
pref_e = pref_e * find_energy
if not self.use_l1_all:
l2_ener_loss = torch.mean(
torch.square(model_pred["energy"] - label["energy"])
)
if not self.inference:
more_loss["l2_ener_loss"] = l2_ener_loss.detach()
more_loss["l2_ener_loss"] = self.display_if_exist(
l2_ener_loss.detach(), find_energy
)
loss += atom_norm * (pref_e * l2_ener_loss)
rmse_e = l2_ener_loss.sqrt() * atom_norm
more_loss["rmse_e"] = rmse_e.detach()
more_loss["rmse_e"] = self.display_if_exist(
rmse_e.detach(), find_energy
)
# more_loss['log_keys'].append('rmse_e')
else: # use l1 and for all atoms
l1_ener_loss = F.l1_loss(
@@ -141,24 +147,31 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
reduction="sum",
)
loss += pref_e * l1_ener_loss
more_loss["mae_e"] = F.l1_loss(
model_pred["energy"].reshape(-1),
label["energy"].reshape(-1),
reduction="mean",
).detach()
more_loss["mae_e"] = self.display_if_exist(
F.l1_loss(
model_pred["energy"].reshape(-1),
label["energy"].reshape(-1),
reduction="mean",
).detach(),
find_energy,
)
# more_loss['log_keys'].append('rmse_e')
if mae:
mae_e = (
torch.mean(torch.abs(model_pred["energy"] - label["energy"]))
* atom_norm
)
more_loss["mae_e"] = mae_e.detach()
more_loss["mae_e"] = self.display_if_exist(mae_e.detach(), find_energy)
mae_e_all = torch.mean(
torch.abs(model_pred["energy"] - label["energy"])
)
more_loss["mae_e_all"] = mae_e_all.detach()
more_loss["mae_e_all"] = self.display_if_exist(
mae_e_all.detach(), find_energy
)

if self.has_f and "force" in model_pred and "force" in label:
find_force = label.get("find_force", 0.0)
pref_f = pref_f * find_force
if "force_target_mask" in model_pred:
force_target_mask = model_pred["force_target_mask"]
else:
@@ -174,40 +187,48 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
diff_f = label["force"] - model_pred["force"]
l2_force_loss = torch.mean(torch.square(diff_f))
if not self.inference:
more_loss["l2_force_loss"] = l2_force_loss.detach()
more_loss["l2_force_loss"] = self.display_if_exist(
l2_force_loss.detach(), find_force
)
loss += (pref_f * l2_force_loss).to(GLOBAL_PT_FLOAT_PRECISION)
rmse_f = l2_force_loss.sqrt()
more_loss["rmse_f"] = rmse_f.detach()
more_loss["rmse_f"] = self.display_if_exist(rmse_f.detach(), find_force)
else:
l1_force_loss = F.l1_loss(
label["force"], model_pred["force"], reduction="none"
)
if force_target_mask is not None:
l1_force_loss *= force_target_mask
force_cnt = force_target_mask.squeeze(-1).sum(-1)
more_loss["mae_f"] = (
l1_force_loss.mean(-1).sum(-1) / force_cnt
).mean()
more_loss["mae_f"] = self.display_if_exist(
(l1_force_loss.mean(-1).sum(-1) / force_cnt).mean(), find_force
)
l1_force_loss = (l1_force_loss.sum(-1).sum(-1) / force_cnt).sum()
else:
more_loss["mae_f"] = l1_force_loss.mean().detach()
more_loss["mae_f"] = self.display_if_exist(
l1_force_loss.mean().detach(), find_force
)
l1_force_loss = l1_force_loss.sum(-1).mean(-1).sum()
loss += (pref_f * l1_force_loss).to(GLOBAL_PT_FLOAT_PRECISION)
if mae:
mae_f = torch.mean(torch.abs(diff_f))
more_loss["mae_f"] = mae_f.detach()
more_loss["mae_f"] = self.display_if_exist(mae_f.detach(), find_force)

if self.has_v and "virial" in model_pred and "virial" in label:
find_virial = label.get("find_virial", 0.0)
pref_v = pref_v * find_virial
diff_v = label["virial"] - model_pred["virial"].reshape(-1, 9)
l2_virial_loss = torch.mean(torch.square(diff_v))
if not self.inference:
more_loss["l2_virial_loss"] = l2_virial_loss.detach()
more_loss["l2_virial_loss"] = self.display_if_exist(
l2_virial_loss.detach(), find_virial
)
loss += atom_norm * (pref_v * l2_virial_loss)
rmse_v = l2_virial_loss.sqrt() * atom_norm
more_loss["rmse_v"] = rmse_v.detach()
more_loss["rmse_v"] = self.display_if_exist(rmse_v.detach(), find_virial)
if mae:
mae_v = torch.mean(torch.abs(diff_v)) * atom_norm
more_loss["mae_v"] = mae_v.detach()
more_loss["mae_v"] = self.display_if_exist(mae_v.detach(), find_virial)
if not self.inference:
more_loss["rmse"] = torch.sqrt(loss.detach())
return model_pred, loss, more_loss
65 changes: 48 additions & 17 deletions deepmd/pt/loss/ener_spin.py
@@ -98,15 +98,21 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
# more_loss['test_keys'] = [] # showed when doing dp test
atom_norm = 1.0 / natoms
if self.has_e and "energy" in model_pred and "energy" in label:
find_energy = label.get("find_energy", 0.0)
pref_e = pref_e * find_energy
if not self.use_l1_all:
l2_ener_loss = torch.mean(
torch.square(model_pred["energy"] - label["energy"])
)
if not self.inference:
more_loss["l2_ener_loss"] = l2_ener_loss.detach()
more_loss["l2_ener_loss"] = self.display_if_exist(
l2_ener_loss.detach(), find_energy
)
loss += atom_norm * (pref_e * l2_ener_loss)
rmse_e = l2_ener_loss.sqrt() * atom_norm
more_loss["rmse_e"] = rmse_e.detach()
more_loss["rmse_e"] = self.display_if_exist(
rmse_e.detach(), find_energy
)
# more_loss['log_keys'].append('rmse_e')
else: # use l1 and for all atoms
l1_ener_loss = F.l1_loss(
@@ -115,44 +121,61 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
reduction="sum",
)
loss += pref_e * l1_ener_loss
more_loss["mae_e"] = F.l1_loss(
model_pred["energy"].reshape(-1),
label["energy"].reshape(-1),
reduction="mean",
).detach()
more_loss["mae_e"] = self.display_if_exist(
F.l1_loss(
model_pred["energy"].reshape(-1),
label["energy"].reshape(-1),
reduction="mean",
).detach(),
find_energy,
)
# more_loss['log_keys'].append('rmse_e')
if mae:
mae_e = (
torch.mean(torch.abs(model_pred["energy"] - label["energy"]))
* atom_norm
)
more_loss["mae_e"] = mae_e.detach()
more_loss["mae_e"] = self.display_if_exist(mae_e.detach(), find_energy)
mae_e_all = torch.mean(
torch.abs(model_pred["energy"] - label["energy"])
)
more_loss["mae_e_all"] = mae_e_all.detach()
more_loss["mae_e_all"] = self.display_if_exist(
mae_e_all.detach(), find_energy
)

if self.has_fr and "force" in model_pred and "force" in label:
find_force_r = label.get("find_force", 0.0)
pref_fr = pref_fr * find_force_r
if not self.use_l1_all:
diff_fr = label["force"] - model_pred["force"]
l2_force_real_loss = torch.mean(torch.square(diff_fr))
if not self.inference:
more_loss["l2_force_r_loss"] = l2_force_real_loss.detach()
more_loss["l2_force_r_loss"] = self.display_if_exist(
l2_force_real_loss.detach(), find_force_r
)
loss += (pref_fr * l2_force_real_loss).to(GLOBAL_PT_FLOAT_PRECISION)
rmse_fr = l2_force_real_loss.sqrt()
more_loss["rmse_fr"] = rmse_fr.detach()
more_loss["rmse_fr"] = self.display_if_exist(
rmse_fr.detach(), find_force_r
)
if mae:
mae_fr = torch.mean(torch.abs(diff_fr))
more_loss["mae_fr"] = mae_fr.detach()
more_loss["mae_fr"] = self.display_if_exist(
mae_fr.detach(), find_force_r
)
else:
l1_force_real_loss = F.l1_loss(
label["force"], model_pred["force"], reduction="none"
)
more_loss["mae_fr"] = l1_force_real_loss.mean().detach()
more_loss["mae_fr"] = self.display_if_exist(
l1_force_real_loss.mean().detach(), find_force_r
)
l1_force_real_loss = l1_force_real_loss.sum(-1).mean(-1).sum()
loss += (pref_fr * l1_force_real_loss).to(GLOBAL_PT_FLOAT_PRECISION)

if self.has_fm and "force_mag" in model_pred and "force_mag" in label:
find_force_m = label.get("find_force_mag", 0.0)
pref_fm = pref_fm * find_force_m
nframes = model_pred["force_mag"].shape[0]
atomic_mask = model_pred["mask_mag"].expand([-1, -1, 3])
label_force_mag = label["force_mag"][atomic_mask].view(nframes, -1, 3)
@@ -163,18 +186,26 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
diff_fm = label_force_mag - model_pred_force_mag
l2_force_mag_loss = torch.mean(torch.square(diff_fm))
if not self.inference:
more_loss["l2_force_m_loss"] = l2_force_mag_loss.detach()
more_loss["l2_force_m_loss"] = self.display_if_exist(
l2_force_mag_loss.detach(), find_force_m
)
loss += (pref_fm * l2_force_mag_loss).to(GLOBAL_PT_FLOAT_PRECISION)
rmse_fm = l2_force_mag_loss.sqrt()
more_loss["rmse_fm"] = rmse_fm.detach()
more_loss["rmse_fm"] = self.display_if_exist(
rmse_fm.detach(), find_force_m
)
if mae:
mae_fm = torch.mean(torch.abs(diff_fm))
more_loss["mae_fm"] = mae_fm.detach()
more_loss["mae_fm"] = self.display_if_exist(
mae_fm.detach(), find_force_m
)
else:
l1_force_mag_loss = F.l1_loss(
label_force_mag, model_pred_force_mag, reduction="none"
)
more_loss["mae_fm"] = l1_force_mag_loss.mean().detach()
more_loss["mae_fm"] = self.display_if_exist(
l1_force_mag_loss.mean().detach(), find_force_m
)
l1_force_mag_loss = l1_force_mag_loss.sum(-1).mean(-1).sum()
loss += (pref_fm * l1_force_mag_loss).to(GLOBAL_PT_FLOAT_PRECISION)

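Note that the spin loss gates the real and magnetic force terms independently: the real forces read `find_force` while the magnetic forces read `find_force_mag`, so a dataset labeling only one kind still trains the other. A tiny lookup sketch (flag values illustrative only):

```python
label = {"find_force": 1.0, "find_force_mag": 0.0}  # illustrative flags
find_force_r = label.get("find_force", 0.0)      # real forces present -> term active
find_force_m = label.get("find_force_mag", 0.0)  # magnetic forces absent -> term off
```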
13 changes: 13 additions & 0 deletions deepmd/pt/loss/loss.py
@@ -28,3 +28,16 @@ def forward(self, input_dict, model, label, natoms, learning_rate):
def label_requirement(self) -> List[DataRequirementItem]:
"""Return data label requirements needed for this loss calculation."""
pass

@staticmethod
def display_if_exist(loss: torch.Tensor, find_property: float) -> torch.Tensor:
"""Display NaN if labeled property is not found.
Parameters
----------
loss : torch.Tensor
the loss tensor
find_property : float
whether the property is found
"""
return loss if bool(find_property) else torch.nan
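A quick usage sketch for the new helper (the import path and the class name `TaskLoss` are assumptions about the repository layout, and the calls are illustrative, not part of the commit): because it is a `@staticmethod`, loss subclasses wrap each value they store in `more_loss` with it, passing the matching `find_*` flag.

```python
import torch

from deepmd.pt.loss.loss import TaskLoss  # assumed import path for the base class

rmse_e = torch.tensor(0.02)
print(TaskLoss.display_if_exist(rmse_e, 1.0))  # tensor(0.0200): label present
print(TaskLoss.display_if_exist(rmse_e, 0.0))  # nan: no reference data in batch
```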
24 changes: 17 additions & 7 deletions deepmd/pt/loss/tensor.py
@@ -95,6 +95,8 @@ def forward(self, input_dict, model, label, natoms, learning_rate=0.0, mae=False
and self.tensor_name in model_pred
and "atomic_" + self.label_name in label
):
find_local = label.get("find_" + "atomic_" + self.label_name, 0.0)
local_weight = self.local_weight * find_local
local_tensor_pred = model_pred[self.tensor_name].reshape(
[-1, natoms, self.tensor_size]
)
@@ -108,15 +110,21 @@ def forward(self, input_dict, model, label, natoms, learning_rate=0.0, mae=False
diff = diff[model_pred["mask"].reshape([-1]).bool()]
l2_local_loss = torch.mean(torch.square(diff))
if not self.inference:
more_loss[f"l2_local_{self.tensor_name}_loss"] = l2_local_loss.detach()
loss += self.local_weight * l2_local_loss
more_loss[f"l2_local_{self.tensor_name}_loss"] = self.display_if_exist(
l2_local_loss.detach(), find_local
)
loss += local_weight * l2_local_loss
rmse_local = l2_local_loss.sqrt()
more_loss[f"rmse_local_{self.tensor_name}"] = rmse_local.detach()
more_loss[f"rmse_local_{self.tensor_name}"] = self.display_if_exist(
rmse_local.detach(), find_local
)
if (
self.has_global_weight
and "global_" + self.tensor_name in model_pred
and self.label_name in label
):
find_global = label.get("find_" + self.label_name, 0.0)
global_weight = self.global_weight * find_global
global_tensor_pred = model_pred["global_" + self.tensor_name].reshape(
[-1, self.tensor_size]
)
@@ -132,12 +140,14 @@ def forward(self, input_dict, model, label, natoms, learning_rate=0.0, mae=False
atom_num = natoms
l2_global_loss = torch.mean(torch.square(diff))
if not self.inference:
more_loss[f"l2_global_{self.tensor_name}_loss"] = (
l2_global_loss.detach()
more_loss[f"l2_global_{self.tensor_name}_loss"] = self.display_if_exist(
l2_global_loss.detach(), find_global
)
loss += self.global_weight * l2_global_loss
loss += global_weight * l2_global_loss
rmse_global = l2_global_loss.sqrt() / atom_num
more_loss[f"rmse_global_{self.tensor_name}"] = rmse_global.detach()
more_loss[f"rmse_global_{self.tensor_name}"] = self.display_if_exist(
rmse_global.detach(), find_global
)
return model_pred, loss, more_loss

@property
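The tensor loss reads two different flags: the per-atom branch keys off `"find_" + "atomic_" + self.label_name`, while the global branch keys off `"find_" + self.label_name`. A small lookup sketch with a hypothetical label dict (the label name is a placeholder, not taken from the commit):

```python
label = {
    "atomic_polarizability": [0.1] * 9,
    "find_atomic_polarizability": 1.0,  # per-atom labels present in this batch
    "polarizability": [0.9] * 9,
    "find_polarizability": 0.0,         # no global labels in this batch
}
label_name = "polarizability"
find_local = label.get("find_" + "atomic_" + label_name, 0.0)  # -> 1.0, local term on
find_global = label.get("find_" + label_name, 0.0)             # -> 0.0, global term off
```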
3 changes: 2 additions & 1 deletion deepmd/pt/train/training.py
@@ -1074,7 +1074,7 @@ def get_data(self, is_train=True, task_key="Default"):
if item_key in input_keys:
input_dict[item_key] = batch_data[item_key]
else:
if item_key not in ["sid", "fid"] and "find_" not in item_key:
if item_key not in ["sid", "fid"]:
label_dict[item_key] = batch_data[item_key]
log_dict = {}
if "fid" in batch_data:
@@ -1109,6 +1109,7 @@ def print_header(self, fout, train_results, valid_results):
for k in sorted(train_results[model_key].keys()):
print_str += prop_fmt % (k + f"_trn_{model_key}")
print_str += " %8s\n" % "lr"
print_str += "# If there is no available reference data, rmse_*_{val,trn} will print nan\n"
fout.write(print_str)
fout.flush()

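The one-line change to `get_data` above is what makes the flags reachable: the old filter dropped any key containing "find_" from `label_dict`, which would leave the new `label.get("find_*", 0.0)` lookups in the losses always falling back to 0.0. A simplified sketch of the new filtering (keys and values illustrative):

```python
batch_data = {
    "energy": [[-93.2]],
    "find_energy": 1.0,  # flag emitted by the data system
    "sid": 3,            # system id, bookkeeping only
    "fid": 7,            # frame id, bookkeeping only
}
input_keys = []  # pretend none of these keys are model inputs

label_dict = {}
for item_key, value in batch_data.items():
    if item_key in input_keys:
        continue
    # The old behaviour also skipped any key containing "find_";
    # now only the bookkeeping keys are excluded.
    if item_key not in ["sid", "fid"]:
        label_dict[item_key] = value

assert "find_energy" in label_dict
```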
2 changes: 2 additions & 0 deletions source/tests/pt/model/test_model.py
@@ -352,7 +352,9 @@ def test_consistency(self):
}
label = {
"energy": batch["energy"].to(env.DEVICE),
"find_energy": 1.0,
"force": batch["force"].to(env.DEVICE),
"find_force": 1.0,
}
cur_lr = my_lr.value(self.wanted_step)
model_predict, loss, _ = my_loss(