From a676bf2396181d2c1c10273790541def6bbabd4b Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Thu, 16 May 2024 10:06:32 +0800 Subject: [PATCH] feat(pt): support complete form energy loss (#3782) Support atomic energy, atomic prefactor force, generalized force, `relative_f`, `enable_atom_ener_coeff` for `EnergyStdLoss`. Support atomic energy, `enable_atom_ener_coeff` for `EnergySpinLoss`. Virial support for `EnergySpinLoss` needs discussion and will be handled in another PR. ## Summary by CodeRabbit - **New Features** - Introduced new parameters for enhanced energy loss computation, including `relative_f`, `enable_atom_ener_coeff`, and `numb_generalized_coord`. - Improved handling of atomic energy loss with the addition of `pref_ae` calculation. - **Bug Fixes** - Refined conditional logic for energy computation to ensure accurate handling of new parameters. - **Tests** - Expanded test coverage with new classes and methods to validate the new features and ensure consistency. --- deepmd/pt/loss/ener.py | 225 +++++++--- deepmd/pt/loss/ener_spin.py | 113 +++-- source/tests/pt/test_loss.py | 808 ++++++++++++++++++++++++++--------- 3 files changed, 854 insertions(+), 292 deletions(-) diff --git a/deepmd/pt/loss/ener.py b/deepmd/pt/loss/ener.py index ccc23b690c..97e329935a 100644 --- a/deepmd/pt/loss/ener.py +++ b/deepmd/pt/loss/ener.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( List, + Optional, ) import torch @@ -34,6 +35,11 @@ def __init__( limit_pref_ae: float = 0.0, start_pref_pf: float = 0.0, limit_pref_pf: float = 0.0, + relative_f: Optional[float] = None, + enable_atom_ener_coeff: bool = False, + start_pref_gf: float = 0.0, + limit_pref_gf: float = 0.0, + numb_generalized_coord: int = 0, use_l1_all: bool = False, inference=False, **kwargs, @@ -64,6 +70,18 @@ def __init__( The prefactor of atomic prefactor force loss at the start of the training. 
limit_pref_pf : float The prefactor of atomic prefactor force loss at the end of the training. + relative_f : float + If provided, relative force error will be used in the loss. The difference + of force will be normalized by the magnitude of the force in the label with + a shift given by relative_f + enable_atom_ener_coeff : bool + if true, the energy will be computed as \sum_i c_i E_i + start_pref_gf : float + The prefactor of generalized force loss at the start of the training. + limit_pref_gf : float + The prefactor of generalized force loss at the end of the training. + numb_generalized_coord : int + The dimension of generalized coordinates. use_l1_all : bool Whether to use L1 loss, if False (default), it will use L2 loss. inference : bool @@ -76,10 +94,9 @@ def __init__( self.has_e = (start_pref_e != 0.0 and limit_pref_e != 0.0) or inference self.has_f = (start_pref_f != 0.0 and limit_pref_f != 0.0) or inference self.has_v = (start_pref_v != 0.0 and limit_pref_v != 0.0) or inference - - # TODO EnergyStdLoss need support for atomic energy and atomic pref self.has_ae = (start_pref_ae != 0.0 and limit_pref_ae != 0.0) or inference self.has_pf = (start_pref_pf != 0.0 and limit_pref_pf != 0.0) or inference + self.has_gf = (start_pref_gf != 0.0 and limit_pref_gf != 0.0) or inference self.start_pref_e = start_pref_e self.limit_pref_e = limit_pref_e @@ -87,6 +104,19 @@ def __init__( self.limit_pref_f = limit_pref_f self.start_pref_v = start_pref_v self.limit_pref_v = limit_pref_v + self.start_pref_ae = start_pref_ae + self.limit_pref_ae = limit_pref_ae + self.start_pref_pf = start_pref_pf + self.limit_pref_pf = limit_pref_pf + self.start_pref_gf = start_pref_gf + self.limit_pref_gf = limit_pref_gf + self.relative_f = relative_f + self.enable_atom_ener_coeff = enable_atom_ener_coeff + self.numb_generalized_coord = numb_generalized_coord + if self.has_gf and self.numb_generalized_coord < 1: + raise RuntimeError( + "When generalized force loss is used, the dimension of 
generalized coordinates should be larger than 0" + ) self.use_l1_all = use_l1_all self.inference = inference @@ -118,18 +148,35 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False): pref_e = self.limit_pref_e + (self.start_pref_e - self.limit_pref_e) * coef pref_f = self.limit_pref_f + (self.start_pref_f - self.limit_pref_f) * coef pref_v = self.limit_pref_v + (self.start_pref_v - self.limit_pref_v) * coef + pref_ae = self.limit_pref_ae + (self.start_pref_ae - self.limit_pref_ae) * coef + pref_pf = self.limit_pref_pf + (self.start_pref_pf - self.limit_pref_pf) * coef + pref_gf = self.limit_pref_gf + (self.start_pref_gf - self.limit_pref_gf) * coef + loss = torch.zeros(1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)[0] more_loss = {} # more_loss['log_keys'] = [] # showed when validation on the fly # more_loss['test_keys'] = [] # showed when doing dp test atom_norm = 1.0 / natoms if self.has_e and "energy" in model_pred and "energy" in label: + energy_pred = model_pred["energy"] + energy_label = label["energy"] + if self.enable_atom_ener_coeff and "atom_energy" in model_pred: + atom_ener_pred = model_pred["atom_energy"] + # when ener_coeff (\nu) is defined, the energy is defined as + # E = \sum_i \nu_i E_i + # instead of the sum of atomic energies. 
+ # + # A case is that we want to train reaction energy + # A + B -> C + D + # E = - E(A) - E(B) + E(C) + E(D) + # A, B, C, D could be put far away from each other + atom_ener_coeff = label["atom_ener_coeff"] + atom_ener_coeff = atom_ener_coeff.reshape(atom_ener_pred.shape) + energy_pred = torch.sum(atom_ener_coeff * atom_ener_pred, dim=1) find_energy = label.get("find_energy", 0.0) pref_e = pref_e * find_energy if not self.use_l1_all: - l2_ener_loss = torch.mean( - torch.square(model_pred["energy"] - label["energy"]) - ) + l2_ener_loss = torch.mean(torch.square(energy_pred - energy_label)) if not self.inference: more_loss["l2_ener_loss"] = self.display_if_exist( l2_ener_loss.detach(), find_energy @@ -142,77 +189,111 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False): # more_loss['log_keys'].append('rmse_e') else: # use l1 and for all atoms l1_ener_loss = F.l1_loss( - model_pred["energy"].reshape(-1), - label["energy"].reshape(-1), + energy_pred.reshape(-1), + energy_label.reshape(-1), reduction="sum", ) loss += pref_e * l1_ener_loss more_loss["mae_e"] = self.display_if_exist( F.l1_loss( - model_pred["energy"].reshape(-1), - label["energy"].reshape(-1), + energy_pred.reshape(-1), + energy_label.reshape(-1), reduction="mean", ).detach(), find_energy, ) # more_loss['log_keys'].append('rmse_e') if mae: - mae_e = ( - torch.mean(torch.abs(model_pred["energy"] - label["energy"])) - * atom_norm - ) + mae_e = torch.mean(torch.abs(energy_pred - energy_label)) * atom_norm more_loss["mae_e"] = self.display_if_exist(mae_e.detach(), find_energy) - mae_e_all = torch.mean( - torch.abs(model_pred["energy"] - label["energy"]) - ) + mae_e_all = torch.mean(torch.abs(energy_pred - energy_label)) more_loss["mae_e_all"] = self.display_if_exist( mae_e_all.detach(), find_energy ) - if self.has_f and "force" in model_pred and "force" in label: + if ( + (self.has_f or self.has_pf or self.relative_f or self.has_gf) + and "force" in model_pred + and "force" in 
label + ): find_force = label.get("find_force", 0.0) pref_f = pref_f * find_force - if "force_target_mask" in model_pred: - force_target_mask = model_pred["force_target_mask"] - else: - force_target_mask = None - if not self.use_l1_all: - if force_target_mask is not None: - diff_f = (label["force"] - model_pred["force"]) * force_target_mask - force_cnt = force_target_mask.squeeze(-1).sum(-1) - l2_force_loss = torch.mean( - torch.square(diff_f).mean(-1).sum(-1) / force_cnt - ) - else: - diff_f = label["force"] - model_pred["force"] + force_pred = model_pred["force"] + force_label = label["force"] + diff_f = (force_label - force_pred).reshape(-1) + + if self.relative_f is not None: + force_label_3 = force_label.reshape(-1, 3) + norm_f = force_label_3.norm(dim=1, keepdim=True) + self.relative_f + diff_f_3 = diff_f.reshape(-1, 3) + diff_f_3 = diff_f_3 / norm_f + diff_f = diff_f_3.reshape(-1) + + if self.has_f: + if not self.use_l1_all: l2_force_loss = torch.mean(torch.square(diff_f)) - if not self.inference: - more_loss["l2_force_loss"] = self.display_if_exist( - l2_force_loss.detach(), find_force + if not self.inference: + more_loss["l2_force_loss"] = self.display_if_exist( + l2_force_loss.detach(), find_force + ) + loss += (pref_f * l2_force_loss).to(GLOBAL_PT_FLOAT_PRECISION) + rmse_f = l2_force_loss.sqrt() + more_loss["rmse_f"] = self.display_if_exist( + rmse_f.detach(), find_force ) - loss += (pref_f * l2_force_loss).to(GLOBAL_PT_FLOAT_PRECISION) - rmse_f = l2_force_loss.sqrt() - more_loss["rmse_f"] = self.display_if_exist(rmse_f.detach(), find_force) - else: - l1_force_loss = F.l1_loss( - label["force"], model_pred["force"], reduction="none" - ) - if force_target_mask is not None: - l1_force_loss *= force_target_mask - force_cnt = force_target_mask.squeeze(-1).sum(-1) - more_loss["mae_f"] = self.display_if_exist( - (l1_force_loss.mean(-1).sum(-1) / force_cnt).mean(), find_force - ) - l1_force_loss = (l1_force_loss.sum(-1).sum(-1) / force_cnt).sum() else: + 
l1_force_loss = F.l1_loss(force_label, force_pred, reduction="none") more_loss["mae_f"] = self.display_if_exist( l1_force_loss.mean().detach(), find_force ) l1_force_loss = l1_force_loss.sum(-1).mean(-1).sum() - loss += (pref_f * l1_force_loss).to(GLOBAL_PT_FLOAT_PRECISION) - if mae: - mae_f = torch.mean(torch.abs(diff_f)) - more_loss["mae_f"] = self.display_if_exist(mae_f.detach(), find_force) + loss += (pref_f * l1_force_loss).to(GLOBAL_PT_FLOAT_PRECISION) + if mae: + mae_f = torch.mean(torch.abs(diff_f)) + more_loss["mae_f"] = self.display_if_exist( + mae_f.detach(), find_force + ) + + if self.has_pf and "atom_pref" in label: + atom_pref = label["atom_pref"] + find_atom_pref = label.get("find_atom_pref", 0.0) + pref_pf = pref_pf * find_atom_pref + atom_pref_reshape = atom_pref.reshape(-1) + l2_pref_force_loss = (torch.square(diff_f) * atom_pref_reshape).mean() + if not self.inference: + more_loss["l2_pref_force_loss"] = self.display_if_exist( + l2_pref_force_loss.detach(), find_atom_pref + ) + loss += (pref_pf * l2_pref_force_loss).to(GLOBAL_PT_FLOAT_PRECISION) + rmse_pf = l2_pref_force_loss.sqrt() + more_loss["rmse_pf"] = self.display_if_exist( + rmse_pf.detach(), find_atom_pref + ) + + if self.has_gf and "drdq" in label: + drdq = label["drdq"] + find_drdq = label.get("find_drdq", 0.0) + pref_gf = pref_gf * find_drdq + force_reshape_nframes = force_pred.reshape(-1, natoms * 3) + force_label_reshape_nframes = force_label.reshape(-1, natoms * 3) + drdq_reshape = drdq.reshape(-1, natoms * 3, self.numb_generalized_coord) + gen_force_label = torch.einsum( + "bij,bi->bj", drdq_reshape, force_label_reshape_nframes + ) + gen_force = torch.einsum( + "bij,bi->bj", drdq_reshape, force_reshape_nframes + ) + diff_gen_force = gen_force_label - gen_force + l2_gen_force_loss = torch.square(diff_gen_force).mean() + if not self.inference: + more_loss["l2_gen_force_loss"] = self.display_if_exist( + l2_gen_force_loss.detach(), find_drdq + ) + loss += (pref_gf * 
l2_gen_force_loss).to(GLOBAL_PT_FLOAT_PRECISION) + rmse_gf = l2_gen_force_loss.sqrt() + more_loss["rmse_gf"] = self.display_if_exist( + rmse_gf.detach(), find_drdq + ) if self.has_v and "virial" in model_pred and "virial" in label: find_virial = label.get("find_virial", 0.0) @@ -229,6 +310,27 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False): if mae: mae_v = torch.mean(torch.abs(diff_v)) * atom_norm more_loss["mae_v"] = self.display_if_exist(mae_v.detach(), find_virial) + + if self.has_ae and "atom_energy" in model_pred and "atom_ener" in label: + atom_ener = model_pred["atom_energy"] + atom_ener_label = label["atom_ener"] + find_atom_ener = label.get("find_atom_ener", 0.0) + pref_ae = pref_ae * find_atom_ener + atom_ener_reshape = atom_ener.reshape(-1) + atom_ener_label_reshape = atom_ener_label.reshape(-1) + l2_atom_ener_loss = torch.square( + atom_ener_label_reshape - atom_ener_reshape + ).mean() + if not self.inference: + more_loss["l2_atom_ener_loss"] = self.display_if_exist( + l2_atom_ener_loss.detach(), find_atom_ener + ) + loss += (pref_ae * l2_atom_ener_loss).to(GLOBAL_PT_FLOAT_PRECISION) + rmse_ae = l2_atom_ener_loss.sqrt() + more_loss["rmse_ae"] = self.display_if_exist( + rmse_ae.detach(), find_atom_ener + ) + if not self.inference: more_loss["rmse"] = torch.sqrt(loss.detach()) return model_pred, loss, more_loss @@ -288,4 +390,25 @@ def label_requirement(self) -> List[DataRequirementItem]: repeat=3, ) ) + if self.has_gf > 0: + label_requirement.append( + DataRequirementItem( + "drdq", + ndof=self.numb_generalized_coord * 3, + atomic=True, + must=False, + high_prec=False, + ) + ) + if self.enable_atom_ener_coeff: + label_requirement.append( + DataRequirementItem( + "atom_ener_coeff", + ndof=1, + atomic=True, + must=False, + high_prec=False, + default=1.0, + ) + ) return label_requirement diff --git a/deepmd/pt/loss/ener_spin.py b/deepmd/pt/loss/ener_spin.py index 3bd81adf77..78210a778b 100644 --- 
a/deepmd/pt/loss/ener_spin.py +++ b/deepmd/pt/loss/ener_spin.py @@ -34,23 +34,53 @@ def __init__( limit_pref_v=0.0, start_pref_ae: float = 0.0, limit_pref_ae: float = 0.0, - start_pref_pf: float = 0.0, - limit_pref_pf: float = 0.0, + enable_atom_ener_coeff: bool = False, use_l1_all: bool = False, inference=False, **kwargs, ): - """Construct a layer to compute loss on energy, real force, magnetic force and virial.""" + r"""Construct a layer to compute loss on energy, real force, magnetic force and virial. + + Parameters + ---------- + starter_learning_rate : float + The learning rate at the start of the training. + start_pref_e : float + The prefactor of energy loss at the start of the training. + limit_pref_e : float + The prefactor of energy loss at the end of the training. + start_pref_fr : float + The prefactor of real force loss at the start of the training. + limit_pref_fr : float + The prefactor of real force loss at the end of the training. + start_pref_fm : float + The prefactor of magnetic force loss at the start of the training. + limit_pref_fm : float + The prefactor of magnetic force loss at the end of the training. + start_pref_v : float + The prefactor of virial loss at the start of the training. + limit_pref_v : float + The prefactor of virial loss at the end of the training. + start_pref_ae : float + The prefactor of atomic energy loss at the start of the training. + limit_pref_ae : float + The prefactor of atomic energy loss at the end of the training. + enable_atom_ener_coeff : bool + if true, the energy will be computed as \sum_i c_i E_i + use_l1_all : bool + Whether to use L1 loss, if False (default), it will use L2 loss. + inference : bool + If true, it will output all losses found in output, ignoring the pre-factors. + **kwargs + Other keyword arguments. 
+ """ super().__init__() self.starter_learning_rate = starter_learning_rate self.has_e = (start_pref_e != 0.0 and limit_pref_e != 0.0) or inference self.has_fr = (start_pref_fr != 0.0 and limit_pref_fr != 0.0) or inference self.has_fm = (start_pref_fm != 0.0 and limit_pref_fm != 0.0) or inference - - # TODO EnergySpinLoss needs support for virial, atomic energy and atomic pref self.has_v = (start_pref_v != 0.0 and limit_pref_v != 0.0) or inference self.has_ae = (start_pref_ae != 0.0 and limit_pref_ae != 0.0) or inference - self.has_pf = (start_pref_pf != 0.0 and limit_pref_pf != 0.0) or inference self.start_pref_e = start_pref_e self.limit_pref_e = limit_pref_e @@ -60,6 +90,9 @@ def __init__( self.limit_pref_fm = limit_pref_fm self.start_pref_v = start_pref_v self.limit_pref_v = limit_pref_v + self.start_pref_ae = start_pref_ae + self.limit_pref_ae = limit_pref_ae + self.enable_atom_ener_coeff = enable_atom_ener_coeff self.use_l1_all = use_l1_all self.inference = inference @@ -92,18 +125,32 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False): pref_fr = self.limit_pref_fr + (self.start_pref_fr - self.limit_pref_fr) * coef pref_fm = self.limit_pref_fm + (self.start_pref_fm - self.limit_pref_fm) * coef pref_v = self.limit_pref_v + (self.start_pref_v - self.limit_pref_v) * coef + pref_ae = self.limit_pref_ae + (self.start_pref_ae - self.limit_pref_ae) * coef loss = torch.tensor(0.0, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE) more_loss = {} # more_loss['log_keys'] = [] # showed when validation on the fly # more_loss['test_keys'] = [] # showed when doing dp test atom_norm = 1.0 / natoms if self.has_e and "energy" in model_pred and "energy" in label: + energy_pred = model_pred["energy"] + energy_label = label["energy"] + if self.enable_atom_ener_coeff and "atom_energy" in model_pred: + atom_ener_pred = model_pred["atom_energy"] + # when ener_coeff (\nu) is defined, the energy is defined as + # E = \sum_i \nu_i E_i + # instead of 
the sum of atomic energies. + # + # A case is that we want to train reaction energy + # A + B -> C + D + # E = - E(A) - E(B) + E(C) + E(D) + # A, B, C, D could be put far away from each other + atom_ener_coeff = label["atom_ener_coeff"] + atom_ener_coeff = atom_ener_coeff.reshape(atom_ener_pred.shape) + energy_pred = torch.sum(atom_ener_coeff * atom_ener_pred, dim=1) find_energy = label.get("find_energy", 0.0) pref_e = pref_e * find_energy if not self.use_l1_all: - l2_ener_loss = torch.mean( - torch.square(model_pred["energy"] - label["energy"]) - ) + l2_ener_loss = torch.mean(torch.square(energy_pred - energy_label)) if not self.inference: more_loss["l2_ener_loss"] = self.display_if_exist( l2_ener_loss.detach(), find_energy @@ -116,29 +163,24 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False): # more_loss['log_keys'].append('rmse_e') else: # use l1 and for all atoms l1_ener_loss = F.l1_loss( - model_pred["energy"].reshape(-1), - label["energy"].reshape(-1), + energy_pred.reshape(-1), + energy_label.reshape(-1), reduction="sum", ) loss += pref_e * l1_ener_loss more_loss["mae_e"] = self.display_if_exist( F.l1_loss( - model_pred["energy"].reshape(-1), - label["energy"].reshape(-1), + energy_pred.reshape(-1), + energy_label.reshape(-1), reduction="mean", ).detach(), find_energy, ) # more_loss['log_keys'].append('rmse_e') if mae: - mae_e = ( - torch.mean(torch.abs(model_pred["energy"] - label["energy"])) - * atom_norm - ) + mae_e = torch.mean(torch.abs(energy_pred - energy_label)) * atom_norm more_loss["mae_e"] = self.display_if_exist(mae_e.detach(), find_energy) - mae_e_all = torch.mean( - torch.abs(model_pred["energy"] - label["energy"]) - ) + mae_e_all = torch.mean(torch.abs(energy_pred - energy_label)) more_loss["mae_e_all"] = self.display_if_exist( mae_e_all.detach(), find_energy ) @@ -209,6 +251,26 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False): l1_force_mag_loss = 
l1_force_mag_loss.sum(-1).mean(-1).sum() loss += (pref_fm * l1_force_mag_loss).to(GLOBAL_PT_FLOAT_PRECISION) + if self.has_ae and "atom_energy" in model_pred and "atom_ener" in label: + atom_ener = model_pred["atom_energy"] + atom_ener_label = label["atom_ener"] + find_atom_ener = label.get("find_atom_ener", 0.0) + pref_ae = pref_ae * find_atom_ener + atom_ener_reshape = atom_ener.reshape(-1) + atom_ener_label_reshape = atom_ener_label.reshape(-1) + l2_atom_ener_loss = torch.square( + atom_ener_label_reshape - atom_ener_reshape + ).mean() + if not self.inference: + more_loss["l2_atom_ener_loss"] = self.display_if_exist( + l2_atom_ener_loss.detach(), find_atom_ener + ) + loss += (pref_ae * l2_atom_ener_loss).to(GLOBAL_PT_FLOAT_PRECISION) + rmse_ae = l2_atom_ener_loss.sqrt() + more_loss["rmse_ae"] = self.display_if_exist( + rmse_ae.detach(), find_atom_ener + ) + if not self.inference: more_loss["rmse"] = torch.sqrt(loss.detach()) return model_pred, loss, more_loss @@ -267,15 +329,4 @@ def label_requirement(self) -> List[DataRequirementItem]: high_prec=False, ) ) - if self.has_pf: - label_requirement.append( - DataRequirementItem( - "atom_pref", - ndof=1, - atomic=True, - must=False, - high_prec=False, - repeat=3, - ) - ) return label_requirement diff --git a/source/tests/pt/test_loss.py b/source/tests/pt/test_loss.py index dcd59cd56e..72ea961c37 100644 --- a/source/tests/pt/test_loss.py +++ b/source/tests/pt/test_loss.py @@ -7,9 +7,6 @@ import torch tf.disable_eager_execution() -from copy import ( - deepcopy, -) from pathlib import ( Path, ) @@ -46,48 +43,102 @@ def get_batch(system, type_map, data_requirement): return np_batch, pt_batch -class TestEnerStdLoss(unittest.TestCase): +class LossCommonTest(unittest.TestCase): def setUp(self): - self.system = str(Path(__file__).parent / "water/data/data_0") - self.type_map = ["H", "O"] - self.start_lr = 1.1 - self.start_pref_e = 0.02 - self.limit_pref_e = 1.0 - self.start_pref_f = 1000.0 - self.limit_pref_f = 1.0 - 
self.start_pref_v = 0.02 - self.limit_pref_v = 1.0 self.cur_lr = 1.2 + if not self.spin: + self.system = str(Path(__file__).parent / "water/data/data_0") + self.type_map = ["H", "O"] + else: + self.system = str(Path(__file__).parent / "NiO/data/data_0") + self.type_map = ["Ni", "O"] + energy_data_requirement.append( + DataRequirementItem( + "force_mag", + ndof=3, + atomic=True, + must=False, + high_prec=False, + ) + ) # data np_batch, pt_batch = get_batch( self.system, self.type_map, energy_data_requirement ) natoms = np_batch["natoms"] self.nloc = natoms[0] - l_energy, l_force, l_virial = ( - np_batch["energy"], - np_batch["force"], - np_batch["virial"], - ) - p_energy, p_force, p_virial = ( - np.ones_like(l_energy), - np.ones_like(l_force), - np.ones_like(l_virial), - ) - nloc = natoms[0] - batch_size = pt_batch["coord"].shape[0] - atom_energy = np.zeros(shape=[batch_size, nloc]) - atom_pref = np.zeros(shape=[batch_size, nloc * 3]) + nframes = np_batch["energy"].shape[0] + rng = np.random.default_rng() + + if not self.spin: + l_energy, l_force, l_virial = ( + np_batch["energy"], + np_batch["force"], + np_batch["virial"], + ) + p_energy, p_force, p_virial = ( + np.ones_like(l_energy), + np.ones_like(l_force), + np.ones_like(l_virial), + ) + nloc = natoms[0] + batch_size = pt_batch["coord"].shape[0] + p_atom_energy = rng.random(size=[batch_size, nloc]) + l_atom_energy = rng.random(size=[batch_size, nloc]) + atom_pref = rng.random(size=[batch_size, nloc * 3]) + drdq = rng.random(size=[batch_size, nloc * 2 * 3]) + atom_ener_coeff = rng.random(size=[batch_size, nloc]) + # placeholders + l_force_real = l_force + l_force_mag = l_force + p_force_real = p_force + p_force_mag = p_force + else: + # data + np_batch, pt_batch = get_batch( + self.system, self.type_map, energy_data_requirement + ) + natoms = np_batch["natoms"] + self.nloc = natoms[0] + l_energy, l_force_real, l_force_mag, l_virial = ( + np_batch["energy"], + np_batch["force"], + np_batch["force_mag"], + 
np_batch["virial"], + ) + # merged force for tf old implement + l_force_merge_tf = np.concatenate( + [ + l_force_real.reshape(nframes, self.nloc, 3), + l_force_mag.reshape(nframes, self.nloc, 3)[ + np_batch["atype"] == 0 + ].reshape(nframes, -1, 3), + ], + axis=1, + ).reshape(nframes, -1) + p_energy, p_force_real, p_force_mag, p_force_merge_tf, p_virial = ( + np.ones_like(l_energy), + np.ones_like(l_force_real), + np.ones_like(l_force_mag), + np.ones_like(l_force_merge_tf), + np.ones_like(l_virial), + ) + virt_nloc = (np_batch["atype"] == 0).sum(-1) + natoms_tf = np.concatenate([natoms, virt_nloc], axis=0) + natoms_tf[:2] += virt_nloc + nloc = natoms_tf[0] + batch_size = pt_batch["coord"].shape[0] + p_atom_energy = rng.random(size=[batch_size, nloc]) + l_atom_energy = rng.random(size=[batch_size, nloc]) + atom_pref = rng.random(size=[batch_size, nloc * 3]) + drdq = rng.random(size=[batch_size, nloc * 2 * 3]) + atom_ener_coeff = rng.random(size=[batch_size, nloc]) + self.nloc_tf = nloc + natoms = natoms_tf + l_force = l_force_merge_tf + p_force = p_force_merge_tf + # tf - base = EnerStdLoss( - self.start_lr, - self.start_pref_e, - self.limit_pref_e, - self.start_pref_f, - self.limit_pref_f, - self.start_pref_v, - self.limit_pref_v, - ) self.g = tf.Graph() with self.g.as_default(): t_cur_lr = tf.placeholder(shape=[], dtype=tf.float64) @@ -101,11 +152,15 @@ def setUp(self): t_lvirial = tf.placeholder(shape=[None, 9], dtype=tf.float64) t_latom_energy = tf.placeholder(shape=[None, None], dtype=tf.float64) t_atom_pref = tf.placeholder(shape=[None, None], dtype=tf.float64) + t_atom_ener_coeff = tf.placeholder(shape=[None, None], dtype=tf.float64) + t_drdq = tf.placeholder(shape=[None, None], dtype=tf.float64) find_energy = tf.constant(1.0, dtype=tf.float64) find_force = tf.constant(1.0, dtype=tf.float64) - find_virial = tf.constant(1.0, dtype=tf.float64) - find_atom_energy = tf.constant(0.0, dtype=tf.float64) - find_atom_pref = tf.constant(0.0, dtype=tf.float64) + 
find_virial = tf.constant(1.0 if not self.spin else 0.0, dtype=tf.float64) + find_atom_energy = tf.constant(1.0, dtype=tf.float64) + find_atom_pref = tf.constant(1.0, dtype=tf.float64) + find_drdq = tf.constant(1.0, dtype=tf.float64) + find_atom_ener_coeff = tf.constant(1.0, dtype=tf.float64) model_dict = { "energy": t_penergy, "force": t_pforce, @@ -118,59 +173,359 @@ def setUp(self): "virial": t_lvirial, "atom_ener": t_latom_energy, "atom_pref": t_atom_pref, + "drdq": t_drdq, + "atom_ener_coeff": t_atom_ener_coeff, "find_energy": find_energy, "find_force": find_force, "find_virial": find_virial, "find_atom_ener": find_atom_energy, "find_atom_pref": find_atom_pref, + "find_drdq": find_drdq, + "find_atom_ener_coeff": find_atom_ener_coeff, } - self.base_loss_sess = base.build( + self.tf_loss_sess = self.tf_loss.build( t_cur_lr, t_natoms, model_dict, label_dict, "" ) - # torch + self.feed_dict = { t_cur_lr: self.cur_lr, t_natoms: natoms, t_penergy: p_energy, t_pforce: p_force, t_pvirial: p_virial.reshape(-1, 9), - t_patom_energy: atom_energy, + t_patom_energy: p_atom_energy, t_lenergy: l_energy, t_lforce: l_force, t_lvirial: l_virial.reshape(-1, 9), - t_latom_energy: atom_energy, + t_latom_energy: l_atom_energy, t_atom_pref: atom_pref, + t_drdq: drdq, + t_atom_ener_coeff: atom_ener_coeff, } - self.model_pred = { - "energy": torch.from_numpy(p_energy), - "force": torch.from_numpy(p_force), - "virial": torch.from_numpy(p_virial), - } - self.label = { - "energy": torch.from_numpy(l_energy), - "find_energy": 1.0, - "force": torch.from_numpy(l_force), - "find_force": 1.0, - "virial": torch.from_numpy(l_virial), - "find_virial": 1.0, - } - self.label_absent = { - "energy": torch.from_numpy(l_energy), - "force": torch.from_numpy(l_force), - "virial": torch.from_numpy(l_virial), - } + # pt + if not self.spin: + self.model_pred = { + "energy": torch.from_numpy(p_energy), + "force": torch.from_numpy(p_force), + "virial": torch.from_numpy(p_virial), + "atom_energy": 
torch.from_numpy(p_atom_energy), + } + self.label = { + "energy": torch.from_numpy(l_energy), + "find_energy": 1.0, + "force": torch.from_numpy(l_force), + "find_force": 1.0, + "virial": torch.from_numpy(l_virial), + "find_virial": 1.0, + "atom_ener": torch.from_numpy(l_atom_energy), + "find_atom_ener": 1.0, + "atom_pref": torch.from_numpy(atom_pref), + "find_atom_pref": 1.0, + "drdq": torch.from_numpy(drdq), + "find_drdq": 1.0, + "atom_ener_coeff": torch.from_numpy(atom_ener_coeff), + "find_atom_ener_coeff": 1.0, + } + self.label_absent = { + "energy": torch.from_numpy(l_energy), + "force": torch.from_numpy(l_force), + "virial": torch.from_numpy(l_virial), + "atom_ener": torch.from_numpy(l_atom_energy), + "atom_pref": torch.from_numpy(atom_pref), + "drdq": torch.from_numpy(drdq), + "atom_ener_coeff": torch.from_numpy(atom_ener_coeff), + } + else: + self.model_pred = { + "energy": torch.from_numpy(p_energy), + "force": torch.from_numpy(p_force_real).reshape(nframes, self.nloc, 3), + "force_mag": torch.from_numpy(p_force_mag).reshape( + nframes, self.nloc, 3 + ), + "mask_mag": torch.from_numpy(np_batch["atype"] == 0).reshape( + nframes, self.nloc, 1 + ), + "atom_energy": torch.from_numpy(p_atom_energy), + } + self.label = { + "energy": torch.from_numpy(l_energy), + "find_energy": 1.0, + "force": torch.from_numpy(l_force_real).reshape(nframes, self.nloc, 3), + "find_force": 1.0, + "force_mag": torch.from_numpy(l_force_mag).reshape( + nframes, self.nloc, 3 + ), + "find_force_mag": 1.0, + "atom_ener": torch.from_numpy(l_atom_energy), + "find_atom_ener": 1.0, + "atom_ener_coeff": torch.from_numpy(atom_ener_coeff), + "find_atom_ener_coeff": 1.0, + } + self.label_absent = { + "energy": torch.from_numpy(l_energy), + "force": torch.from_numpy(l_force_real).reshape(nframes, self.nloc, 3), + "force_mag": torch.from_numpy(l_force_mag).reshape( + nframes, self.nloc, 3 + ), + "atom_ener": torch.from_numpy(l_atom_energy), + "atom_ener_coeff": torch.from_numpy(atom_ener_coeff), + 
} self.natoms = pt_batch["natoms"] def tearDown(self) -> None: tf.reset_default_graph() return super().tearDown() + +class TestEnerStdLoss(LossCommonTest): + def setUp(self): + self.start_lr = 1.1 + self.start_pref_e = 0.02 + self.limit_pref_e = 1.0 + self.start_pref_f = 1000.0 + self.limit_pref_f = 1.0 + self.start_pref_v = 0.02 + self.limit_pref_v = 1.0 + # tf + self.tf_loss = EnerStdLoss( + self.start_lr, + self.start_pref_e, + self.limit_pref_e, + self.start_pref_f, + self.limit_pref_f, + self.start_pref_v, + self.limit_pref_v, + ) + # pt + self.pt_loss = EnergyStdLoss( + self.start_lr, + self.start_pref_e, + self.limit_pref_e, + self.start_pref_f, + self.limit_pref_f, + self.start_pref_v, + self.limit_pref_v, + ) + self.spin = False + super().setUp() + def test_consistency(self): with tf.Session(graph=self.g) as sess: - base_loss, base_more_loss = sess.run( - self.base_loss_sess, feed_dict=self.feed_dict + tf_loss, tf_more_loss = sess.run( + self.tf_loss_sess, feed_dict=self.feed_dict + ) + + def fake_model(): + return self.model_pred + + _, pt_loss, pt_more_loss = self.pt_loss( + {}, + fake_model, + self.label, + self.nloc, + self.cur_lr, + ) + _, pt_loss_absent, pt_more_loss_absent = self.pt_loss( + {}, + fake_model, + self.label_absent, + self.nloc, + self.cur_lr, + ) + pt_loss = pt_loss.detach().cpu() + pt_loss_absent = pt_loss_absent.detach().cpu() + self.assertTrue(np.allclose(tf_loss, pt_loss.numpy())) + self.assertTrue(np.allclose(0.0, pt_loss_absent.numpy())) + for key in ["ener", "force", "virial"]: + self.assertTrue( + np.allclose( + tf_more_loss[f"l2_{key}_loss"], pt_more_loss[f"l2_{key}_loss"] + ) + ) + self.assertTrue(np.isnan(pt_more_loss_absent[f"l2_{key}_loss"])) + + +class TestEnerStdLossAePfGf(LossCommonTest): + def setUp(self): + self.start_lr = 1.1 + self.start_pref_e = 0.02 + self.limit_pref_e = 1.0 + self.start_pref_f = 1000.0 + self.limit_pref_f = 1.0 + self.start_pref_v = 0.02 + self.limit_pref_v = 1.0 + self.start_pref_ae = 0.02 + 
self.limit_pref_ae = 1.0 + self.start_pref_pf = 0.02 + self.limit_pref_pf = 1.0 + self.start_pref_gf = 0.02 + self.limit_pref_gf = 1.0 + self.numb_generalized_coord = 2 + # tf + self.tf_loss = EnerStdLoss( + self.start_lr, + self.start_pref_e, + self.limit_pref_e, + self.start_pref_f, + self.limit_pref_f, + self.start_pref_v, + self.limit_pref_v, + self.start_pref_ae, + self.limit_pref_ae, + self.start_pref_pf, + self.limit_pref_pf, + start_pref_gf=self.start_pref_gf, + limit_pref_gf=self.limit_pref_gf, + numb_generalized_coord=self.numb_generalized_coord, + ) + # pt + self.pt_loss = EnergyStdLoss( + self.start_lr, + self.start_pref_e, + self.limit_pref_e, + self.start_pref_f, + self.limit_pref_f, + self.start_pref_v, + self.limit_pref_v, + self.start_pref_ae, + self.limit_pref_ae, + self.start_pref_pf, + self.limit_pref_pf, + start_pref_gf=self.start_pref_gf, + limit_pref_gf=self.limit_pref_gf, + numb_generalized_coord=self.numb_generalized_coord, + ) + self.spin = False + super().setUp() + + def test_consistency(self): + with tf.Session(graph=self.g) as sess: + tf_loss, tf_more_loss = sess.run( + self.tf_loss_sess, feed_dict=self.feed_dict + ) + + def fake_model(): + return self.model_pred + + _, pt_loss, pt_more_loss = self.pt_loss( + {}, + fake_model, + self.label, + self.nloc, + self.cur_lr, + ) + _, pt_loss_absent, pt_more_loss_absent = self.pt_loss( + {}, + fake_model, + self.label_absent, + self.nloc, + self.cur_lr, + ) + pt_loss = pt_loss.detach().cpu() + pt_loss_absent = pt_loss_absent.detach().cpu() + self.assertTrue(np.allclose(tf_loss, pt_loss.numpy())) + self.assertTrue(np.allclose(0.0, pt_loss_absent.numpy())) + for key in ["ener", "force", "virial", "atom_ener", "pref_force", "gen_force"]: + self.assertTrue( + np.allclose( + tf_more_loss[f"l2_{key}_loss"], pt_more_loss[f"l2_{key}_loss"] + ) + ) + self.assertTrue(np.isnan(pt_more_loss_absent[f"l2_{key}_loss"])) + + +class TestEnerStdLossAecoeff(LossCommonTest): + def setUp(self): + self.start_lr = 
1.1 + self.start_pref_e = 0.02 + self.limit_pref_e = 1.0 + self.start_pref_f = 1000.0 + self.limit_pref_f = 1.0 + self.start_pref_v = 0.02 + self.limit_pref_v = 1.0 + # tf + self.tf_loss = EnerStdLoss( + self.start_lr, + self.start_pref_e, + self.limit_pref_e, + self.start_pref_f, + self.limit_pref_f, + self.start_pref_v, + self.limit_pref_v, + enable_atom_ener_coeff=True, + ) + # pt + self.pt_loss = EnergyStdLoss( + self.start_lr, + self.start_pref_e, + self.limit_pref_e, + self.start_pref_f, + self.limit_pref_f, + self.start_pref_v, + self.limit_pref_v, + enable_atom_ener_coeff=True, + ) + self.spin = False + super().setUp() + + def test_consistency(self): + with tf.Session(graph=self.g) as sess: + tf_loss, tf_more_loss = sess.run( + self.tf_loss_sess, feed_dict=self.feed_dict + ) + + def fake_model(): + return self.model_pred + + _, pt_loss, pt_more_loss = self.pt_loss( + {}, + fake_model, + self.label, + self.nloc, + self.cur_lr, + ) + _, pt_loss_absent, pt_more_loss_absent = self.pt_loss( + {}, + fake_model, + self.label_absent, + self.nloc, + self.cur_lr, + ) + pt_loss = pt_loss.detach().cpu() + pt_loss_absent = pt_loss_absent.detach().cpu() + self.assertTrue(np.allclose(tf_loss, pt_loss.numpy())) + self.assertTrue(np.allclose(0.0, pt_loss_absent.numpy())) + for key in ["ener", "force", "virial"]: + self.assertTrue( + np.allclose( + tf_more_loss[f"l2_{key}_loss"], pt_more_loss[f"l2_{key}_loss"] + ) ) - mine = EnergyStdLoss( + self.assertTrue(np.isnan(pt_more_loss_absent[f"l2_{key}_loss"])) + + +class TestEnerStdLossRelativeF(LossCommonTest): + def setUp(self): + self.start_lr = 1.1 + self.start_pref_e = 0.02 + self.limit_pref_e = 1.0 + self.start_pref_f = 1000.0 + self.limit_pref_f = 1.0 + self.start_pref_v = 0.02 + self.limit_pref_v = 1.0 + # tf + self.tf_loss = EnerStdLoss( + self.start_lr, + self.start_pref_e, + self.limit_pref_e, + self.start_pref_f, + self.limit_pref_f, + self.start_pref_v, + self.limit_pref_v, + relative_f=0.1, + ) + # pt + self.pt_loss 
= EnergyStdLoss( self.start_lr, self.start_pref_e, self.limit_pref_e, @@ -178,42 +533,49 @@ def test_consistency(self): self.limit_pref_f, self.start_pref_v, self.limit_pref_v, + relative_f=0.1, ) + self.spin = False + super().setUp() + + def test_consistency(self): + with tf.Session(graph=self.g) as sess: + tf_loss, tf_more_loss = sess.run( + self.tf_loss_sess, feed_dict=self.feed_dict + ) def fake_model(): return self.model_pred - _, my_loss, my_more_loss = mine( + _, pt_loss, pt_more_loss = self.pt_loss( {}, fake_model, self.label, self.nloc, self.cur_lr, ) - _, my_loss_absent, my_more_loss_absent = mine( + _, pt_loss_absent, pt_more_loss_absent = self.pt_loss( {}, fake_model, self.label_absent, self.nloc, self.cur_lr, ) - my_loss = my_loss.detach().cpu() - my_loss_absent = my_loss_absent.detach().cpu() - self.assertTrue(np.allclose(base_loss, my_loss.numpy())) - self.assertTrue(np.allclose(0.0, my_loss_absent.numpy())) + pt_loss = pt_loss.detach().cpu() + pt_loss_absent = pt_loss_absent.detach().cpu() + self.assertTrue(np.allclose(tf_loss, pt_loss.numpy())) + self.assertTrue(np.allclose(0.0, pt_loss_absent.numpy())) for key in ["ener", "force", "virial"]: self.assertTrue( np.allclose( - base_more_loss[f"l2_{key}_loss"], my_more_loss[f"l2_{key}_loss"] + tf_more_loss[f"l2_{key}_loss"], pt_more_loss[f"l2_{key}_loss"] ) ) - self.assertTrue(np.isnan(my_more_loss_absent[f"l2_{key}_loss"])) + self.assertTrue(np.isnan(pt_more_loss_absent[f"l2_{key}_loss"])) -class TestEnerSpinLoss(unittest.TestCase): +class TestEnerSpinLoss(LossCommonTest): def setUp(self): - self.system = str(Path(__file__).parent / "NiO/data/data_0") - self.type_map = ["Ni", "O"] self.start_lr = 1.1 self.start_pref_e = 0.02 self.limit_pref_e = 1.0 @@ -223,56 +585,81 @@ def setUp(self): self.limit_pref_fm = 1.0 self.cur_lr = 1.2 self.use_spin = [1, 0] - # data - spin_data_requirement = deepcopy(energy_data_requirement) - spin_data_requirement.append( - DataRequirementItem( - "force_mag", - ndof=3, - 
atomic=True, - must=False, - high_prec=False, - ) + # tf + self.tf_loss = EnerSpinLoss( + self.start_lr, + self.start_pref_e, + self.limit_pref_e, + self.start_pref_fr, + self.limit_pref_fr, + self.start_pref_fm, + self.limit_pref_fm, + use_spin=self.use_spin, ) - np_batch, pt_batch = get_batch( - self.system, self.type_map, spin_data_requirement + # pt + self.pt_loss = EnergySpinLoss( + self.start_lr, + self.start_pref_e, + self.limit_pref_e, + self.start_pref_fr, + self.limit_pref_fr, + self.start_pref_fm, + self.limit_pref_fm, ) - natoms = np_batch["natoms"] - self.nloc = natoms[0] - nframes = np_batch["energy"].shape[0] - l_energy, l_force_real, l_force_mag, l_virial = ( - np_batch["energy"], - np_batch["force"], - np_batch["force_mag"], - np_batch["virial"], + self.spin = True + super().setUp() + + def test_consistency(self): + with tf.Session(graph=self.g) as sess: + tf_loss, tf_more_loss = sess.run( + self.tf_loss_sess, feed_dict=self.feed_dict + ) + + def fake_model(): + return self.model_pred + + _, pt_loss, pt_more_loss = self.pt_loss( + {}, + fake_model, + self.label, + self.nloc_tf, # use tf natoms pref + self.cur_lr, ) - # merged force for tf old implement - l_force_merge_tf = np.concatenate( - [ - l_force_real.reshape(nframes, self.nloc, 3), - l_force_mag.reshape(nframes, self.nloc, 3)[ - np_batch["atype"] == 0 - ].reshape(nframes, -1, 3), - ], - axis=1, - ).reshape(nframes, -1) - p_energy, p_force_real, p_force_mag, p_force_merge_tf, p_virial = ( - np.ones_like(l_energy), - np.ones_like(l_force_real), - np.ones_like(l_force_mag), - np.ones_like(l_force_merge_tf), - np.ones_like(l_virial), + _, pt_loss_absent, pt_more_loss_absent = self.pt_loss( + {}, + fake_model, + self.label_absent, + self.nloc_tf, # use tf natoms pref + self.cur_lr, ) - virt_nloc = (np_batch["atype"] == 0).sum(-1) - natoms_tf = np.concatenate([natoms, virt_nloc], axis=0) - natoms_tf[:2] += virt_nloc - nloc = natoms_tf[0] - batch_size = pt_batch["coord"].shape[0] - atom_energy = 
np.zeros(shape=[batch_size, nloc]) - atom_pref = np.zeros(shape=[batch_size, nloc * 3]) - self.nloc_tf = nloc + pt_loss = pt_loss.detach().cpu() + pt_loss_absent = pt_loss_absent.detach().cpu() + self.assertTrue(np.allclose(tf_loss, pt_loss.numpy())) + self.assertTrue(np.allclose(0.0, pt_loss_absent.numpy())) + for key in ["ener", "force_r", "force_m"]: + self.assertTrue( + np.allclose( + tf_more_loss[f"l2_{key}_loss"], pt_more_loss[f"l2_{key}_loss"] + ) + ) + self.assertTrue(np.isnan(pt_more_loss_absent[f"l2_{key}_loss"])) + + +class TestEnerSpinLossAe(LossCommonTest): + def setUp(self): + self.start_lr = 1.1 + self.start_pref_e = 0.02 + self.limit_pref_e = 1.0 + self.start_pref_fr = 1000.0 + self.limit_pref_fr = 1.0 + self.start_pref_fm = 1000.0 + self.limit_pref_fm = 1.0 + self.start_pref_ae = 0.02 + self.limit_pref_ae = 1.0 + self.cur_lr = 1.2 + self.use_spin = [1, 0] # tf - base = EnerSpinLoss( + self.tf_loss = EnerSpinLoss( self.start_lr, self.start_pref_e, self.limit_pref_e, @@ -280,94 +667,74 @@ def setUp(self): self.limit_pref_fr, self.start_pref_fm, self.limit_pref_fm, + start_pref_ae=self.start_pref_ae, + limit_pref_ae=self.limit_pref_ae, use_spin=self.use_spin, ) - self.g = tf.Graph() - with self.g.as_default(): - t_cur_lr = tf.placeholder(shape=[], dtype=tf.float64) - t_natoms = tf.placeholder(shape=[None], dtype=tf.int32) - t_penergy = tf.placeholder(shape=[None, 1], dtype=tf.float64) - t_pforce = tf.placeholder(shape=[None, None], dtype=tf.float64) - t_pvirial = tf.placeholder(shape=[None, 9], dtype=tf.float64) - t_patom_energy = tf.placeholder(shape=[None, None], dtype=tf.float64) - t_lenergy = tf.placeholder(shape=[None, 1], dtype=tf.float64) - t_lforce = tf.placeholder(shape=[None, None], dtype=tf.float64) - t_lvirial = tf.placeholder(shape=[None, 9], dtype=tf.float64) - t_latom_energy = tf.placeholder(shape=[None, None], dtype=tf.float64) - t_atom_pref = tf.placeholder(shape=[None, None], dtype=tf.float64) - find_energy = tf.constant(1.0, 
dtype=tf.float64) - find_force = tf.constant(1.0, dtype=tf.float64) - find_virial = tf.constant(0.0, dtype=tf.float64) - find_atom_energy = tf.constant(0.0, dtype=tf.float64) - find_atom_pref = tf.constant(0.0, dtype=tf.float64) - model_dict = { - "energy": t_penergy, - "force": t_pforce, - "virial": t_pvirial, - "atom_ener": t_patom_energy, - } - label_dict = { - "energy": t_lenergy, - "force": t_lforce, - "virial": t_lvirial, - "atom_ener": t_latom_energy, - "atom_pref": t_atom_pref, - "find_energy": find_energy, - "find_force": find_force, - "find_virial": find_virial, - "find_atom_ener": find_atom_energy, - "find_atom_pref": find_atom_pref, - } - self.base_loss_sess = base.build( - t_cur_lr, t_natoms, model_dict, label_dict, "" - ) - # torch - self.feed_dict = { - t_cur_lr: self.cur_lr, - t_natoms: natoms_tf, - t_penergy: p_energy, - t_pforce: p_force_merge_tf, - t_pvirial: p_virial.reshape(-1, 9), - t_patom_energy: atom_energy, - t_lenergy: l_energy, - t_lforce: l_force_merge_tf, - t_lvirial: l_virial.reshape(-1, 9), - t_latom_energy: atom_energy, - t_atom_pref: atom_pref, - } - self.model_pred = { - "energy": torch.from_numpy(p_energy), - "force": torch.from_numpy(p_force_real).reshape(nframes, self.nloc, 3), - "force_mag": torch.from_numpy(p_force_mag).reshape(nframes, self.nloc, 3), - "mask_mag": torch.from_numpy(np_batch["atype"] == 0).reshape( - nframes, self.nloc, 1 - ), - } - self.label = { - "energy": torch.from_numpy(l_energy), - "find_energy": 1.0, - "force": torch.from_numpy(l_force_real).reshape(nframes, self.nloc, 3), - "find_force": 1.0, - "force_mag": torch.from_numpy(l_force_mag).reshape(nframes, self.nloc, 3), - "find_force_mag": 1.0, - } - self.label_absent = { - "energy": torch.from_numpy(l_energy), - "force": torch.from_numpy(l_force_real).reshape(nframes, self.nloc, 3), - "force_mag": torch.from_numpy(l_force_mag).reshape(nframes, self.nloc, 3), - } - self.natoms = pt_batch["natoms"] - - def tearDown(self) -> None: - 
tf.reset_default_graph() - return super().tearDown() + # pt + self.pt_loss = EnergySpinLoss( + self.start_lr, + self.start_pref_e, + self.limit_pref_e, + self.start_pref_fr, + self.limit_pref_fr, + self.start_pref_fm, + self.limit_pref_fm, + start_pref_ae=self.start_pref_ae, + limit_pref_ae=self.limit_pref_ae, + ) + self.spin = True + super().setUp() def test_consistency(self): with tf.Session(graph=self.g) as sess: - base_loss, base_more_loss = sess.run( - self.base_loss_sess, feed_dict=self.feed_dict + tf_loss, tf_more_loss = sess.run( + self.tf_loss_sess, feed_dict=self.feed_dict ) - mine = EnergySpinLoss( + + def fake_model(): + return self.model_pred + + _, pt_loss, pt_more_loss = self.pt_loss( + {}, + fake_model, + self.label, + self.nloc_tf, # use tf natoms pref + self.cur_lr, + ) + _, pt_loss_absent, pt_more_loss_absent = self.pt_loss( + {}, + fake_model, + self.label_absent, + self.nloc_tf, # use tf natoms pref + self.cur_lr, + ) + pt_loss = pt_loss.detach().cpu() + pt_loss_absent = pt_loss_absent.detach().cpu() + self.assertTrue(np.allclose(tf_loss, pt_loss.numpy())) + self.assertTrue(np.allclose(0.0, pt_loss_absent.numpy())) + for key in ["ener", "force_r", "force_m", "atom_ener"]: + self.assertTrue( + np.allclose( + tf_more_loss[f"l2_{key}_loss"], pt_more_loss[f"l2_{key}_loss"] + ) + ) + self.assertTrue(np.isnan(pt_more_loss_absent[f"l2_{key}_loss"])) + + +class TestEnerSpinLossAecoeff(LossCommonTest): + def setUp(self): + self.start_lr = 1.1 + self.start_pref_e = 0.02 + self.limit_pref_e = 1.0 + self.start_pref_fr = 1000.0 + self.limit_pref_fr = 1.0 + self.start_pref_fm = 1000.0 + self.limit_pref_fm = 1.0 + self.cur_lr = 1.2 + self.use_spin = [1, 0] + # tf + self.tf_loss = EnerSpinLoss( self.start_lr, self.start_pref_e, self.limit_pref_e, @@ -375,36 +742,57 @@ def test_consistency(self): self.limit_pref_fr, self.start_pref_fm, self.limit_pref_fm, + use_spin=self.use_spin, + enable_atom_ener_coeff=True, ) + # pt + self.pt_loss = EnergySpinLoss( + 
self.start_lr, + self.start_pref_e, + self.limit_pref_e, + self.start_pref_fr, + self.limit_pref_fr, + self.start_pref_fm, + self.limit_pref_fm, + enable_atom_ener_coeff=True, + ) + self.spin = True + super().setUp() + + def test_consistency(self): + with tf.Session(graph=self.g) as sess: + tf_loss, tf_more_loss = sess.run( + self.tf_loss_sess, feed_dict=self.feed_dict + ) def fake_model(): return self.model_pred - _, my_loss, my_more_loss = mine( + _, pt_loss, pt_more_loss = self.pt_loss( {}, fake_model, self.label, self.nloc_tf, # use tf natoms pref self.cur_lr, ) - _, my_loss_absent, my_more_loss_absent = mine( + _, pt_loss_absent, pt_more_loss_absent = self.pt_loss( {}, fake_model, self.label_absent, self.nloc_tf, # use tf natoms pref self.cur_lr, ) - my_loss = my_loss.detach().cpu() - my_loss_absent = my_loss_absent.detach().cpu() - self.assertTrue(np.allclose(base_loss, my_loss.numpy())) - self.assertTrue(np.allclose(0.0, my_loss_absent.numpy())) + pt_loss = pt_loss.detach().cpu() + pt_loss_absent = pt_loss_absent.detach().cpu() + self.assertTrue(np.allclose(tf_loss, pt_loss.numpy())) + self.assertTrue(np.allclose(0.0, pt_loss_absent.numpy())) for key in ["ener", "force_r", "force_m"]: self.assertTrue( np.allclose( - base_more_loss[f"l2_{key}_loss"], my_more_loss[f"l2_{key}_loss"] + tf_more_loss[f"l2_{key}_loss"], pt_more_loss[f"l2_{key}_loss"] ) ) - self.assertTrue(np.isnan(my_more_loss_absent[f"l2_{key}_loss"])) + self.assertTrue(np.isnan(pt_more_loss_absent[f"l2_{key}_loss"])) if __name__ == "__main__":