diff --git a/deepmd/pt/loss/ener.py b/deepmd/pt/loss/ener.py
index edae53a771..ccc23b690c 100644
--- a/deepmd/pt/loss/ener.py
+++ b/deepmd/pt/loss/ener.py
@@ -124,15 +124,21 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
         # more_loss['test_keys'] = []  # showed when doing dp test
         atom_norm = 1.0 / natoms
         if self.has_e and "energy" in model_pred and "energy" in label:
+            find_energy = label.get("find_energy", 0.0)
+            pref_e = pref_e * find_energy
             if not self.use_l1_all:
                 l2_ener_loss = torch.mean(
                     torch.square(model_pred["energy"] - label["energy"])
                 )
                 if not self.inference:
-                    more_loss["l2_ener_loss"] = l2_ener_loss.detach()
+                    more_loss["l2_ener_loss"] = self.display_if_exist(
+                        l2_ener_loss.detach(), find_energy
+                    )
                 loss += atom_norm * (pref_e * l2_ener_loss)
                 rmse_e = l2_ener_loss.sqrt() * atom_norm
-                more_loss["rmse_e"] = rmse_e.detach()
+                more_loss["rmse_e"] = self.display_if_exist(
+                    rmse_e.detach(), find_energy
+                )
                 # more_loss['log_keys'].append('rmse_e')
             else:  # use l1 and for all atoms
                 l1_ener_loss = F.l1_loss(
@@ -141,24 +147,31 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
                     reduction="sum",
                 )
                 loss += pref_e * l1_ener_loss
-                more_loss["mae_e"] = F.l1_loss(
-                    model_pred["energy"].reshape(-1),
-                    label["energy"].reshape(-1),
-                    reduction="mean",
-                ).detach()
+                more_loss["mae_e"] = self.display_if_exist(
+                    F.l1_loss(
+                        model_pred["energy"].reshape(-1),
+                        label["energy"].reshape(-1),
+                        reduction="mean",
+                    ).detach(),
+                    find_energy,
+                )
                 # more_loss['log_keys'].append('rmse_e')
             if mae:
                 mae_e = (
                     torch.mean(torch.abs(model_pred["energy"] - label["energy"]))
                     * atom_norm
                 )
-                more_loss["mae_e"] = mae_e.detach()
+                more_loss["mae_e"] = self.display_if_exist(mae_e.detach(), find_energy)
                 mae_e_all = torch.mean(
                     torch.abs(model_pred["energy"] - label["energy"])
                 )
-                more_loss["mae_e_all"] = mae_e_all.detach()
+                more_loss["mae_e_all"] = self.display_if_exist(
+                    mae_e_all.detach(), find_energy
+                )
 
         if self.has_f and "force" in model_pred and "force" in label:
+            find_force = label.get("find_force", 0.0)
+            pref_f = pref_f * find_force
             if "force_target_mask" in model_pred:
                 force_target_mask = model_pred["force_target_mask"]
             else:
@@ -174,10 +187,12 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
                 diff_f = label["force"] - model_pred["force"]
                 l2_force_loss = torch.mean(torch.square(diff_f))
                 if not self.inference:
-                    more_loss["l2_force_loss"] = l2_force_loss.detach()
+                    more_loss["l2_force_loss"] = self.display_if_exist(
+                        l2_force_loss.detach(), find_force
+                    )
                 loss += (pref_f * l2_force_loss).to(GLOBAL_PT_FLOAT_PRECISION)
                 rmse_f = l2_force_loss.sqrt()
-                more_loss["rmse_f"] = rmse_f.detach()
+                more_loss["rmse_f"] = self.display_if_exist(rmse_f.detach(), find_force)
             else:
                 l1_force_loss = F.l1_loss(
                     label["force"], model_pred["force"], reduction="none"
@@ -185,29 +200,35 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
                 if force_target_mask is not None:
                     l1_force_loss *= force_target_mask
                     force_cnt = force_target_mask.squeeze(-1).sum(-1)
-                    more_loss["mae_f"] = (
-                        l1_force_loss.mean(-1).sum(-1) / force_cnt
-                    ).mean()
+                    more_loss["mae_f"] = self.display_if_exist(
+                        (l1_force_loss.mean(-1).sum(-1) / force_cnt).mean(), find_force
+                    )
                     l1_force_loss = (l1_force_loss.sum(-1).sum(-1) / force_cnt).sum()
                 else:
-                    more_loss["mae_f"] = l1_force_loss.mean().detach()
+                    more_loss["mae_f"] = self.display_if_exist(
+                        l1_force_loss.mean().detach(), find_force
+                    )
                     l1_force_loss = l1_force_loss.sum(-1).mean(-1).sum()
                 loss += (pref_f * l1_force_loss).to(GLOBAL_PT_FLOAT_PRECISION)
             if mae:
                 mae_f = torch.mean(torch.abs(diff_f))
-                more_loss["mae_f"] = mae_f.detach()
+                more_loss["mae_f"] = self.display_if_exist(mae_f.detach(), find_force)
 
         if self.has_v and "virial" in model_pred and "virial" in label:
+            find_virial = label.get("find_virial", 0.0)
+            pref_v = pref_v * find_virial
             diff_v = label["virial"] - model_pred["virial"].reshape(-1, 9)
             l2_virial_loss = torch.mean(torch.square(diff_v))
             if not self.inference:
-                more_loss["l2_virial_loss"] = l2_virial_loss.detach()
+                more_loss["l2_virial_loss"] = self.display_if_exist(
+                    l2_virial_loss.detach(), find_virial
+                )
             loss += atom_norm * (pref_v * l2_virial_loss)
             rmse_v = l2_virial_loss.sqrt() * atom_norm
-            more_loss["rmse_v"] = rmse_v.detach()
+            more_loss["rmse_v"] = self.display_if_exist(rmse_v.detach(), find_virial)
             if mae:
                 mae_v = torch.mean(torch.abs(diff_v)) * atom_norm
-                more_loss["mae_v"] = mae_v.detach()
+                more_loss["mae_v"] = self.display_if_exist(mae_v.detach(), find_virial)
         if not self.inference:
             more_loss["rmse"] = torch.sqrt(loss.detach())
         return model_pred, loss, more_loss
diff --git a/deepmd/pt/loss/ener_spin.py b/deepmd/pt/loss/ener_spin.py
index 1f10e3cf5f..3bd81adf77 100644
--- a/deepmd/pt/loss/ener_spin.py
+++ b/deepmd/pt/loss/ener_spin.py
@@ -98,15 +98,21 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
         # more_loss['test_keys'] = []  # showed when doing dp test
         atom_norm = 1.0 / natoms
         if self.has_e and "energy" in model_pred and "energy" in label:
+            find_energy = label.get("find_energy", 0.0)
+            pref_e = pref_e * find_energy
             if not self.use_l1_all:
                 l2_ener_loss = torch.mean(
                     torch.square(model_pred["energy"] - label["energy"])
                 )
                 if not self.inference:
-                    more_loss["l2_ener_loss"] = l2_ener_loss.detach()
+                    more_loss["l2_ener_loss"] = self.display_if_exist(
+                        l2_ener_loss.detach(), find_energy
+                    )
                 loss += atom_norm * (pref_e * l2_ener_loss)
                 rmse_e = l2_ener_loss.sqrt() * atom_norm
-                more_loss["rmse_e"] = rmse_e.detach()
+                more_loss["rmse_e"] = self.display_if_exist(
+                    rmse_e.detach(), find_energy
+                )
                 # more_loss['log_keys'].append('rmse_e')
             else:  # use l1 and for all atoms
                 l1_ener_loss = F.l1_loss(
@@ -115,44 +121,61 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
                     reduction="sum",
                 )
                 loss += pref_e * l1_ener_loss
-                more_loss["mae_e"] = F.l1_loss(
-                    model_pred["energy"].reshape(-1),
-                    label["energy"].reshape(-1),
-                    reduction="mean",
-                ).detach()
+                more_loss["mae_e"] = self.display_if_exist(
+                    F.l1_loss(
+                        model_pred["energy"].reshape(-1),
+                        label["energy"].reshape(-1),
+                        reduction="mean",
+                    ).detach(),
+                    find_energy,
+                )
                 # more_loss['log_keys'].append('rmse_e')
             if mae:
                 mae_e = (
                     torch.mean(torch.abs(model_pred["energy"] - label["energy"]))
                     * atom_norm
                 )
-                more_loss["mae_e"] = mae_e.detach()
+                more_loss["mae_e"] = self.display_if_exist(mae_e.detach(), find_energy)
                 mae_e_all = torch.mean(
                     torch.abs(model_pred["energy"] - label["energy"])
                 )
-                more_loss["mae_e_all"] = mae_e_all.detach()
+                more_loss["mae_e_all"] = self.display_if_exist(
+                    mae_e_all.detach(), find_energy
+                )
 
         if self.has_fr and "force" in model_pred and "force" in label:
+            find_force_r = label.get("find_force", 0.0)
+            pref_fr = pref_fr * find_force_r
             if not self.use_l1_all:
                 diff_fr = label["force"] - model_pred["force"]
                 l2_force_real_loss = torch.mean(torch.square(diff_fr))
                 if not self.inference:
-                    more_loss["l2_force_r_loss"] = l2_force_real_loss.detach()
+                    more_loss["l2_force_r_loss"] = self.display_if_exist(
+                        l2_force_real_loss.detach(), find_force_r
+                    )
                 loss += (pref_fr * l2_force_real_loss).to(GLOBAL_PT_FLOAT_PRECISION)
                 rmse_fr = l2_force_real_loss.sqrt()
-                more_loss["rmse_fr"] = rmse_fr.detach()
+                more_loss["rmse_fr"] = self.display_if_exist(
+                    rmse_fr.detach(), find_force_r
+                )
                 if mae:
                     mae_fr = torch.mean(torch.abs(diff_fr))
-                    more_loss["mae_fr"] = mae_fr.detach()
+                    more_loss["mae_fr"] = self.display_if_exist(
+                        mae_fr.detach(), find_force_r
+                    )
             else:
                 l1_force_real_loss = F.l1_loss(
                     label["force"], model_pred["force"], reduction="none"
                 )
-                more_loss["mae_fr"] = l1_force_real_loss.mean().detach()
+                more_loss["mae_fr"] = self.display_if_exist(
+                    l1_force_real_loss.mean().detach(), find_force_r
+                )
                 l1_force_real_loss = l1_force_real_loss.sum(-1).mean(-1).sum()
                 loss += (pref_fr * l1_force_real_loss).to(GLOBAL_PT_FLOAT_PRECISION)
 
         if self.has_fm and "force_mag" in model_pred and "force_mag" in label:
+            find_force_m = label.get("find_force_mag", 0.0)
+            pref_fm = pref_fm * find_force_m
             nframes = model_pred["force_mag"].shape[0]
             atomic_mask = model_pred["mask_mag"].expand([-1, -1, 3])
             label_force_mag = label["force_mag"][atomic_mask].view(nframes, -1, 3)
@@ -163,18 +186,26 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
                 diff_fm = label_force_mag - model_pred_force_mag
                 l2_force_mag_loss = torch.mean(torch.square(diff_fm))
                 if not self.inference:
-                    more_loss["l2_force_m_loss"] = l2_force_mag_loss.detach()
+                    more_loss["l2_force_m_loss"] = self.display_if_exist(
+                        l2_force_mag_loss.detach(), find_force_m
+                    )
                 loss += (pref_fm * l2_force_mag_loss).to(GLOBAL_PT_FLOAT_PRECISION)
                 rmse_fm = l2_force_mag_loss.sqrt()
-                more_loss["rmse_fm"] = rmse_fm.detach()
+                more_loss["rmse_fm"] = self.display_if_exist(
+                    rmse_fm.detach(), find_force_m
+                )
                 if mae:
                     mae_fm = torch.mean(torch.abs(diff_fm))
-                    more_loss["mae_fm"] = mae_fm.detach()
+                    more_loss["mae_fm"] = self.display_if_exist(
+                        mae_fm.detach(), find_force_m
+                    )
             else:
                 l1_force_mag_loss = F.l1_loss(
                     label_force_mag, model_pred_force_mag, reduction="none"
                 )
-                more_loss["mae_fm"] = l1_force_mag_loss.mean().detach()
+                more_loss["mae_fm"] = self.display_if_exist(
+                    l1_force_mag_loss.mean().detach(), find_force_m
+                )
                 l1_force_mag_loss = l1_force_mag_loss.sum(-1).mean(-1).sum()
                 loss += (pref_fm * l1_force_mag_loss).to(GLOBAL_PT_FLOAT_PRECISION)
 
diff --git a/deepmd/pt/loss/loss.py b/deepmd/pt/loss/loss.py
index cc253424ca..7e26f6571a 100644
--- a/deepmd/pt/loss/loss.py
+++ b/deepmd/pt/loss/loss.py
@@ -28,3 +28,16 @@ def forward(self, input_dict, model, label, natoms, learning_rate):
     def label_requirement(self) -> List[DataRequirementItem]:
         """Return data label requirements needed for this loss calculation."""
         pass
+
+    @staticmethod
+    def display_if_exist(loss: torch.Tensor, find_property: float) -> torch.Tensor:
+        """Display NaN if labeled property is not found.
+
+        Parameters
+        ----------
+        loss : torch.Tensor
+            the loss tensor
+        find_property : float
+            whether the property is found
+        """
+        return loss if bool(find_property) else torch.nan
diff --git a/deepmd/pt/loss/tensor.py b/deepmd/pt/loss/tensor.py
index 238e6a7796..3dd91d203e 100644
--- a/deepmd/pt/loss/tensor.py
+++ b/deepmd/pt/loss/tensor.py
@@ -95,6 +95,8 @@ def forward(self, input_dict, model, label, natoms, learning_rate=0.0, mae=False
             and self.tensor_name in model_pred
             and "atomic_" + self.label_name in label
         ):
+            find_local = label.get("find_" + "atomic_" + self.label_name, 0.0)
+            local_weight = self.local_weight * find_local
             local_tensor_pred = model_pred[self.tensor_name].reshape(
                 [-1, natoms, self.tensor_size]
             )
@@ -108,15 +110,21 @@ def forward(self, input_dict, model, label, natoms, learning_rate=0.0, mae=False
                 diff = diff[model_pred["mask"].reshape([-1]).bool()]
             l2_local_loss = torch.mean(torch.square(diff))
             if not self.inference:
-                more_loss[f"l2_local_{self.tensor_name}_loss"] = l2_local_loss.detach()
-            loss += self.local_weight * l2_local_loss
+                more_loss[f"l2_local_{self.tensor_name}_loss"] = self.display_if_exist(
+                    l2_local_loss.detach(), find_local
+                )
+            loss += local_weight * l2_local_loss
             rmse_local = l2_local_loss.sqrt()
-            more_loss[f"rmse_local_{self.tensor_name}"] = rmse_local.detach()
+            more_loss[f"rmse_local_{self.tensor_name}"] = self.display_if_exist(
+                rmse_local.detach(), find_local
+            )
         if (
             self.has_global_weight
             and "global_" + self.tensor_name in model_pred
             and self.label_name in label
         ):
+            find_global = label.get("find_" + self.label_name, 0.0)
+            global_weight = self.global_weight * find_global
             global_tensor_pred = model_pred["global_" + self.tensor_name].reshape(
                 [-1, self.tensor_size]
             )
@@ -132,12 +140,14 @@ def forward(self, input_dict, model, label, natoms, learning_rate=0.0, mae=False
                 atom_num = natoms
             l2_global_loss = torch.mean(torch.square(diff))
             if not self.inference:
-                more_loss[f"l2_global_{self.tensor_name}_loss"] = (
-                    l2_global_loss.detach()
+                more_loss[f"l2_global_{self.tensor_name}_loss"] = self.display_if_exist(
+                    l2_global_loss.detach(), find_global
                 )
-            loss += self.global_weight * l2_global_loss
+            loss += global_weight * l2_global_loss
             rmse_global = l2_global_loss.sqrt() / atom_num
-            more_loss[f"rmse_global_{self.tensor_name}"] = rmse_global.detach()
+            more_loss[f"rmse_global_{self.tensor_name}"] = self.display_if_exist(
+                rmse_global.detach(), find_global
+            )
         return model_pred, loss, more_loss
 
     @property
diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py
index aa1ec1c206..1bea24d717 100644
--- a/deepmd/pt/train/training.py
+++ b/deepmd/pt/train/training.py
@@ -1074,7 +1074,7 @@ def get_data(self, is_train=True, task_key="Default"):
             if item_key in input_keys:
                 input_dict[item_key] = batch_data[item_key]
             else:
-                if item_key not in ["sid", "fid"] and "find_" not in item_key:
+                if item_key not in ["sid", "fid"]:
                     label_dict[item_key] = batch_data[item_key]
         log_dict = {}
         if "fid" in batch_data:
@@ -1109,6 +1109,7 @@ def print_header(self, fout, train_results, valid_results):
             for k in sorted(train_results[model_key].keys()):
                 print_str += prop_fmt % (k + f"_trn_{model_key}")
         print_str += "   %8s\n" % "lr"
+        print_str += "# If there is no available reference data, rmse_*_{val,trn} will print nan\n"
         fout.write(print_str)
         fout.flush()
 
diff --git a/source/tests/pt/model/test_model.py b/source/tests/pt/model/test_model.py
index aa1c0dd969..493d6e2cc3 100644
--- a/source/tests/pt/model/test_model.py
+++ b/source/tests/pt/model/test_model.py
@@ -352,7 +352,9 @@ def test_consistency(self):
         }
         label = {
             "energy": batch["energy"].to(env.DEVICE),
+            "find_energy": 1.0,
             "force": batch["force"].to(env.DEVICE),
+            "find_force": 1.0,
         }
         cur_lr = my_lr.value(self.wanted_step)
         model_predict, loss, _ = my_loss(
diff --git a/source/tests/pt/test_loss.py b/source/tests/pt/test_loss.py
index 2abb22c2a9..17b05dadc6 100644
--- a/source/tests/pt/test_loss.py
+++ b/source/tests/pt/test_loss.py
@@ -147,6 +147,14 @@ def setUp(self):
             "virial": torch.from_numpy(p_virial),
         }
         self.label = {
+            "energy": torch.from_numpy(l_energy),
+            "find_energy": 1.0,
+            "force": torch.from_numpy(l_force),
+            "find_force": 1.0,
+            "virial": torch.from_numpy(l_virial),
+            "find_virial": 1.0,
+        }
+        self.label_absent = {
             "energy": torch.from_numpy(l_energy),
             "force": torch.from_numpy(l_force),
             "virial": torch.from_numpy(l_virial),
@@ -182,14 +190,24 @@ def fake_model():
             self.nloc,
             self.cur_lr,
         )
+        _, my_loss_absent, my_more_loss_absent = mine(
+            {},
+            fake_model,
+            self.label_absent,
+            self.nloc,
+            self.cur_lr,
+        )
         my_loss = my_loss.detach().cpu()
+        my_loss_absent = my_loss_absent.detach().cpu()
         self.assertTrue(np.allclose(base_loss, my_loss.numpy()))
+        self.assertTrue(np.allclose(0.0, my_loss_absent.numpy()))
         for key in ["ener", "force", "virial"]:
             self.assertTrue(
                 np.allclose(
                     base_more_loss["l2_%s_loss" % key], my_more_loss["l2_%s_loss" % key]
                 )
             )
+            self.assertTrue(np.isnan(my_more_loss_absent["l2_%s_loss" % key]))
 
 
 class TestEnerSpinLoss(unittest.TestCase):
@@ -326,6 +344,14 @@ def setUp(self):
             ),
         }
         self.label = {
+            "energy": torch.from_numpy(l_energy),
+            "find_energy": 1.0,
+            "force": torch.from_numpy(l_force_real).reshape(nframes, self.nloc, 3),
+            "find_force": 1.0,
+            "force_mag": torch.from_numpy(l_force_mag).reshape(nframes, self.nloc, 3),
+            "find_force_mag": 1.0,
+        }
+        self.label_absent = {
             "energy": torch.from_numpy(l_energy),
             "force": torch.from_numpy(l_force_real).reshape(nframes, self.nloc, 3),
             "force_mag": torch.from_numpy(l_force_mag).reshape(nframes, self.nloc, 3),
@@ -361,14 +387,24 @@ def fake_model():
             self.nloc_tf,  # use tf natoms pref
             self.cur_lr,
         )
+        _, my_loss_absent, my_more_loss_absent = mine(
+            {},
+            fake_model,
+            self.label_absent,
+            self.nloc_tf,  # use tf natoms pref
+            self.cur_lr,
+        )
         my_loss = my_loss.detach().cpu()
+        my_loss_absent = my_loss_absent.detach().cpu()
         self.assertTrue(np.allclose(base_loss, my_loss.numpy()))
+        self.assertTrue(np.allclose(0.0, my_loss_absent.numpy()))
         for key in ["ener", "force_r", "force_m"]:
             self.assertTrue(
                 np.allclose(
                     base_more_loss["l2_%s_loss" % key], my_more_loss["l2_%s_loss" % key]
                 )
             )
+            self.assertTrue(np.isnan(my_more_loss_absent["l2_%s_loss" % key]))
 
 
 if __name__ == "__main__":
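For reference (not part of the patch): a minimal standalone sketch of how the display_if_exist helper added in deepmd/pt/loss/loss.py is meant to interact with the find_* flags used throughout the hunks above; the frame data below is a made-up example, not taken from the repository.

# When a label dict lacks its "find_<property>" key, the flag defaults to 0.0:
# the logged metric becomes NaN and the zeroed prefactor removes that term
# from the total loss, matching the behavior the tests above assert.
import torch


def display_if_exist(loss: torch.Tensor, find_property: float) -> torch.Tensor:
    """Return the metric unchanged if the property was found, else NaN."""
    return loss if bool(find_property) else torch.nan


label = {"energy": torch.tensor([1.0])}  # no "find_energy" key in this frame
find_energy = label.get("find_energy", 0.0)  # -> 0.0
pref_e = 1.0 * find_energy  # prefactor zeroed, term contributes nothing
l2_ener_loss = torch.mean(torch.square(torch.tensor([0.5]) - label["energy"]))
print(display_if_exist(l2_ener_loss.detach(), find_energy))  # nan
print(pref_e * l2_ener_loss)  # tensor(0.)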