From 9d18dc4f043e9af82dcbef748de6627a95737928 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Thu, 29 Feb 2024 18:54:44 +0800 Subject: [PATCH] make label_requirement dynamic --- deepmd/pt/loss/ener.py | 135 +++++++++++++++++++++++++++++------------ 1 file changed, 95 insertions(+), 40 deletions(-) diff --git a/deepmd/pt/loss/ener.py b/deepmd/pt/loss/ener.py index 648e954401..2834733112 100644 --- a/deepmd/pt/loss/ener.py +++ b/deepmd/pt/loss/ener.py @@ -30,16 +30,57 @@ def __init__( limit_pref_f=0.0, start_pref_v=0.0, limit_pref_v=0.0, + start_pref_ae: float = 0.0, + limit_pref_ae: float = 0.0, + start_pref_pf: float = 0.0, + limit_pref_pf: float = 0.0, use_l1_all: bool = False, inference=False, **kwargs, ): - """Construct a layer to compute loss on energy, force and virial.""" + r"""Construct a layer to compute loss on energy, force and virial. + + Parameters + ---------- + starter_learning_rate : float + The learning rate at the start of the training. + start_pref_e : float + The prefactor of energy loss at the start of the training. + limit_pref_e : float + The prefactor of energy loss at the end of the training. + start_pref_f : float + The prefactor of force loss at the start of the training. + limit_pref_f : float + The prefactor of force loss at the end of the training. + start_pref_v : float + The prefactor of virial loss at the start of the training. + limit_pref_v : float + The prefactor of virial loss at the end of the training. + start_pref_ae : float + The prefactor of atomic energy loss at the start of the training. + limit_pref_ae : float + The prefactor of atomic energy loss at the end of the training. + start_pref_pf : float + The prefactor of atomic prefactor force loss at the start of the training. + limit_pref_pf : float + The prefactor of atomic prefactor force loss at the end of the training. + use_l1_all : bool + Whether to use L1 loss, if False (default), it will use L2 loss. + inference : bool + If true, it will output all losses found in output, ignoring the pre-factors. + **kwargs + Other keyword arguments. + """ super().__init__() self.starter_learning_rate = starter_learning_rate self.has_e = (start_pref_e != 0.0 and limit_pref_e != 0.0) or inference self.has_f = (start_pref_f != 0.0 and limit_pref_f != 0.0) or inference self.has_v = (start_pref_v != 0.0 and limit_pref_v != 0.0) or inference + + # TODO need support for atomic energy and atomic pref + self.has_ae = (start_pref_ae != 0.0 and limit_pref_ae != 0.0) or inference + self.has_pf = (start_pref_pf != 0.0 and limit_pref_pf != 0.0) or inference + self.start_pref_e = start_pref_e self.limit_pref_e = limit_pref_e self.start_pref_f = start_pref_f @@ -164,42 +205,56 @@ def forward(self, model_pred, label, natoms, learning_rate, mae=False): @property def label_requirement(self) -> List[DataRequirementItem]: """Return data label requirements needed for this loss calculation.""" - data_requirement = [ - DataRequirementItem( - "energy", - ndof=1, - atomic=False, - must=False, - high_prec=True, - ), - DataRequirementItem( - "force", - ndof=3, - atomic=True, - must=False, - high_prec=False, - ), - DataRequirementItem( - "virial", - ndof=9, - atomic=False, - must=False, - high_prec=False, - ), - DataRequirementItem( - "atom_ener", - ndof=1, - atomic=True, - must=False, - high_prec=False, - ), - DataRequirementItem( - "atom_pref", - ndof=1, - atomic=True, - must=False, - high_prec=False, - repeat=3, - ), - ] - return data_requirement + label_requirement = [] + if self.has_e: + label_requirement.append( + DataRequirementItem( + "energy", + ndof=1, + atomic=False, + must=False, + high_prec=True, + ) + ) + if self.has_f: + label_requirement.append( + DataRequirementItem( + "force", + ndof=3, + atomic=True, + must=False, + high_prec=False, + ) + ) + if self.has_v: + label_requirement.append( + DataRequirementItem( + "virial", + ndof=9, + atomic=False, + must=False, + high_prec=False, + ) + ) + if self.has_ae: + label_requirement.append( + DataRequirementItem( + "atom_ener", + ndof=1, + atomic=True, + must=False, + high_prec=False, + ) + ) + if self.has_pf: + label_requirement.append( + DataRequirementItem( + "atom_pref", + ndof=1, + atomic=True, + must=False, + high_prec=False, + repeat=3, + ) + ) + return label_requirement