diff --git a/deepmd/pt/loss/tensor.py b/deepmd/pt/loss/tensor.py
index 8f2f937a07..69b133de58 100644
--- a/deepmd/pt/loss/tensor.py
+++ b/deepmd/pt/loss/tensor.py
@@ -22,6 +22,7 @@ def __init__(
         pref_atomic: float = 0.0,
         pref: float = 0.0,
         inference=False,
+        enable_atomic_weight: bool = False,
         **kwargs,
     ) -> None:
         r"""Construct a loss for local and global tensors.
@@ -40,6 +41,8 @@ def __init__(
             The prefactor of the weight of global loss. It should be larger than or equal to 0.
         inference : bool
             If true, it will output all losses found in output, ignoring the pre-factors.
+        enable_atomic_weight : bool
+            If true, atomic weight will be used in the loss calculation.
         **kwargs
             Other keyword arguments.
         """
@@ -50,6 +53,7 @@ def __init__(
         self.local_weight = pref_atomic
         self.global_weight = pref
         self.inference = inference
+        self.enable_atomic_weight = enable_atomic_weight

         assert (
             self.local_weight >= 0.0 and self.global_weight >= 0.0
@@ -85,6 +89,12 @@ def forward(self, input_dict, model, label, natoms, learning_rate=0.0, mae=False
         """
         model_pred = model(**input_dict)
         del learning_rate, mae
+
+        if self.enable_atomic_weight:
+            atomic_weight = label["atom_weight"].reshape([-1, 1])
+        else:
+            atomic_weight = 1.0
+
         loss = torch.zeros(1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)[0]
         more_loss = {}
         if (
@@ -103,6 +113,7 @@ def forward(self, input_dict, model, label, natoms, learning_rate=0.0, mae=False
             diff = (local_tensor_pred - local_tensor_label).reshape(
                 [-1, self.tensor_size]
             )
+            diff = diff * atomic_weight
             if "mask" in model_pred:
                 diff = diff[model_pred["mask"].reshape([-1]).bool()]
             l2_local_loss = torch.mean(torch.square(diff))
@@ -171,4 +182,15 @@ def label_requirement(self) -> list[DataRequirementItem]:
                 high_prec=False,
             )
         )
+        if self.enable_atomic_weight:
+            label_requirement.append(
+                DataRequirementItem(
+                    "atomic_weight",
+                    ndof=1,
+                    atomic=True,
+                    must=False,
+                    high_prec=False,
+                    default=1.0,
+                )
+            )
         return label_requirement
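Note: the PT change above amounts to rescaling each atom's residual before the mean-square reduction. A minimal standalone sketch of that idea (hypothetical helper, not the deepmd-kit API):

    import torch

    def weighted_local_tensor_loss(pred, label, atomic_weight=None):
        # pred, label: (nframes * nloc, tensor_size)
        # atomic_weight: (nframes * nloc, 1) or None
        diff = pred - label
        if atomic_weight is not None:
            # one weight per atom, broadcast over the tensor components
            diff = diff * atomic_weight
        return torch.mean(torch.square(diff))

With atomic_weight=None this reduces to the previous unweighted behavior, which is why the default of 1.0 keeps existing training inputs unchanged.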
diff --git a/deepmd/tf/loss/tensor.py b/deepmd/tf/loss/tensor.py
index aca9182ff6..d7f879b4b4 100644
--- a/deepmd/tf/loss/tensor.py
+++ b/deepmd/tf/loss/tensor.py
@@ -40,6 +40,7 @@ def __init__(self, jdata, **kwarg) -> None:
         # YWolfeee: modify, use pref / pref_atomic, instead of pref_weight / pref_atomic_weight
         self.local_weight = jdata.get("pref_atomic", None)
         self.global_weight = jdata.get("pref", None)
+        self.enable_atomic_weight = jdata.get("enable_atomic_weight", False)

         assert (
             self.local_weight is not None and self.global_weight is not None
@@ -66,9 +67,18 @@ def build(self, learning_rate, natoms, model_dict, label_dict, suffix):
             "global_loss": global_cvt_2_tf_float(0.0),
         }

+        if self.enable_atomic_weight:
+            atomic_weight = tf.reshape(label_dict["atom_weight"], [-1, 1])
+        else:
+            atomic_weight = global_cvt_2_tf_float(1.0)
+
         if self.local_weight > 0.0:
+            diff = tf.reshape(polar, [-1, self.tensor_size]) - tf.reshape(
+                atomic_polar_hat, [-1, self.tensor_size]
+            )
+            diff = diff * atomic_weight
             local_loss = global_cvt_2_tf_float(find_atomic) * tf.reduce_mean(
-                tf.square(self.scale * (polar - atomic_polar_hat)), name="l2_" + suffix
+                tf.square(self.scale * diff), name="l2_" + suffix
             )
             more_loss["local_loss"] = self.display_if_exist(local_loss, find_atomic)
             l2_loss += self.local_weight * local_loss
@@ -163,4 +173,16 @@ def label_requirement(self) -> list[DataRequirementItem]:
                 type_sel=self.type_sel,
             )
         )
+        if self.enable_atomic_weight:
+            data_requirements.append(
+                DataRequirementItem(
+                    "atom_weight",
+                    1,
+                    atomic=True,
+                    must=False,
+                    high_prec=False,
+                    default=1.0,
+                    type_sel=self.type_sel,
+                )
+            )
         return data_requirements
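Note: the TF branch mirrors the PT logic in graph mode, and the default weight of global_cvt_2_tf_float(1.0) leaves the loss value unchanged. A quick numpy sanity check of that identity (illustrative sketch only):

    import numpy as np

    rng = np.random.default_rng(0)
    diff = rng.normal(size=(8, 3))  # residuals for 8 atoms x 3 dipole components
    weight = np.ones((8, 1))        # the default all-ones per-atom weight
    assert np.allclose(np.mean((diff * weight) ** 2), np.mean(diff**2))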
diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py
index 7f107ce64a..50ef07b2af 100644
--- a/deepmd/utils/argcheck.py
+++ b/deepmd/utils/argcheck.py
@@ -2514,8 +2514,9 @@ def loss_property():
 def loss_tensor():
     # doc_global_weight = "The prefactor of the weight of global loss. It should be larger than or equal to 0. If only `pref` is provided or both are not provided, training will be global mode, i.e. the shape of 'polarizability.npy` or `dipole.npy` should be #frams x [9 or 3]."
     # doc_local_weight = "The prefactor of the weight of atomic loss. It should be larger than or equal to 0. If only `pref_atomic` is provided, training will be atomic mode, i.e. the shape of `polarizability.npy` or `dipole.npy` should be #frames x ([9 or 3] x #selected atoms). If both `pref` and `pref_atomic` are provided, training will be combined mode, and atomic label should be provided as well."
-    doc_global_weight = "The prefactor of the weight of global loss. It should be larger than or equal to 0. If controls the weight of loss corresponding to global label, i.e. 'polarizability.npy` or `dipole.npy`, whose shape should be #frames x [9 or 3]. If it's larger than 0.0, this npy should be included."
-    doc_local_weight = "The prefactor of the weight of atomic loss. It should be larger than or equal to 0. If controls the weight of loss corresponding to atomic label, i.e. `atomic_polarizability.npy` or `atomic_dipole.npy`, whose shape should be #frames x ([9 or 3] x #selected atoms). If it's larger than 0.0, this npy should be included. Both `pref` and `pref_atomic` should be provided, and either can be set to 0.0."
+    doc_global_weight = "The prefactor of the weight of global loss. It should be larger than or equal to 0. It controls the weight of the loss corresponding to the global label, i.e. `polarizability.npy` or `dipole.npy`, whose shape should be #frames x [9 or 3]. If it's larger than 0.0, this npy should be included."
+    doc_local_weight = "The prefactor of the weight of atomic loss. It should be larger than or equal to 0. It controls the weight of the loss corresponding to the atomic label, i.e. `atomic_polarizability.npy` or `atomic_dipole.npy`, whose shape should be #frames x ([9 or 3] x #atoms). If it's larger than 0.0, this npy should be included. Both `pref` and `pref_atomic` should be provided, and either can be set to 0.0."
+    doc_enable_atomic_weight = "If true, the atomic loss will be reweighted."
     return [
         Argument(
             "pref", [float, int], optional=False, default=None, doc=doc_global_weight
@@ -2527,6 +2528,13 @@ def loss_tensor():
             default=None,
             doc=doc_local_weight,
         ),
+        Argument(
+            "enable_atomic_weight",
+            bool,
+            optional=True,
+            default=False,
+            doc=doc_enable_atomic_weight,
+        ),
     ]
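Note: with the argument registered above, the feature is switched on from the loss section of the training input. A sketch in dict form (argument names as registered above; the weights themselves are supplied as an atomic label file, e.g. atomic_weight.npy under the PT data requirement added earlier):

    loss_config = {
        "type": "tensor",
        "pref": 0.0,                   # disable the global term
        "pref_atomic": 1.0,            # train on atomic labels
        "enable_atomic_weight": True,  # reweight the atomic term per atom
    }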
diff --git a/source/tests/pt/test_loss_tensor.py b/source/tests/pt/test_loss_tensor.py
new file mode 100644
index 0000000000..5802c0b775
--- /dev/null
+++ b/source/tests/pt/test_loss_tensor.py
@@ -0,0 +1,464 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+import os
+import unittest
+
+import numpy as np
+import tensorflow.compat.v1 as tf
+import torch
+
+tf.disable_eager_execution()
+from pathlib import (
+    Path,
+)
+
+from deepmd.pt.loss import TensorLoss as PTTensorLoss
+from deepmd.pt.utils import (
+    dp_random,
+    env,
+)
+from deepmd.pt.utils.dataset import (
+    DeepmdDataSetForLoader,
+)
+from deepmd.tf.loss.tensor import TensorLoss as TFTensorLoss
+from deepmd.utils.data import (
+    DataRequirementItem,
+)
+
+from ..seed import (
+    GLOBAL_SEED,
+)
+
+CUR_DIR = os.path.dirname(__file__)
+
+
+def get_batch(system, type_map, data_requirement):
+    dataset = DeepmdDataSetForLoader(system, type_map)
+    dataset.add_data_requirement(data_requirement)
+    np_batch, pt_batch = get_single_batch(dataset)
+    return np_batch, pt_batch
+
+
+def get_single_batch(dataset, index=None):
+    if index is None:
+        index = dp_random.choice(np.arange(len(dataset)))
+    np_batch = dataset[index]
+    pt_batch = {}
+
+    for key in [
+        "coord",
+        "box",
+        "atom_dipole",
+        "dipole",
+        "atom_polarizability",
+        "polarizability",
+        "atype",
+        "natoms",
+    ]:
+        if key in np_batch.keys():
+            np_batch[key] = np.expand_dims(np_batch[key], axis=0)
+            pt_batch[key] = torch.as_tensor(np_batch[key], device=env.DEVICE)
+            if key in ["coord", "atom_dipole"]:
+                np_batch[key] = np_batch[key].reshape(1, -1)
+    np_batch["natoms"] = np_batch["natoms"][0]
+    return np_batch, pt_batch
+
+
+class LossCommonTest(unittest.TestCase):
+    def setUp(self) -> None:
+        self.cur_lr = 1.2
+        self.type_map = ["H", "O"]
+
+        # data
+        tensor_data_requirement = [
+            DataRequirementItem(
+                "atomic_" + self.label_name,
+                ndof=self.tensor_size,
+                atomic=True,
+                must=False,
+                high_prec=False,
+            ),
+            DataRequirementItem(
+                self.label_name,
+                ndof=self.tensor_size,
+                atomic=False,
+                must=False,
+                high_prec=False,
+            ),
+            DataRequirementItem(
+                "atomic_weight",
+                ndof=1,
+                atomic=True,
+                must=False,
+                high_prec=False,
+                default=1.0,
+            ),
+        ]
+        np_batch, pt_batch = get_batch(
+            self.system, self.type_map, tensor_data_requirement
+        )
+        natoms = np_batch["natoms"]
+        self.nloc = natoms[0]
+        self.nframes = np_batch["atom_" + self.label_name].shape[0]
+        rng = np.random.default_rng(GLOBAL_SEED)
+
+        l_atomic_tensor, l_global_tensor = (
+            np_batch["atom_" + self.label_name],
+            np_batch[self.label_name],
+        )
+        p_atomic_tensor, p_global_tensor = (
+            np.ones_like(l_atomic_tensor),
+            np.ones_like(l_global_tensor),
+        )
+
+        batch_size = pt_batch["coord"].shape[0]
+
+        # atom_pref = rng.random(size=[batch_size, nloc * 3])
+        # drdq = rng.random(size=[batch_size, nloc * 2 * 3])
+        atom_weight = rng.random(size=[batch_size, self.nloc])
+
+        # tf
+        self.g = tf.Graph()
+        with self.g.as_default():
+            t_cur_lr = tf.placeholder(shape=[], dtype=tf.float64)
+            t_natoms = tf.placeholder(shape=[None], dtype=tf.int32)
+            t_patomic_tensor = tf.placeholder(shape=[None, None], dtype=tf.float64)
+            t_pglobal_tensor = tf.placeholder(shape=[None, None], dtype=tf.float64)
+            t_latomic_tensor = tf.placeholder(shape=[None, None], dtype=tf.float64)
+            t_lglobal_tensor = tf.placeholder(shape=[None, None], dtype=tf.float64)
+            t_atom_weight = tf.placeholder(shape=[None, None], dtype=tf.float64)
+            find_atomic = tf.constant(1.0, dtype=tf.float64)
+            find_global = tf.constant(1.0, dtype=tf.float64)
+            find_atom_weight = tf.constant(1.0, dtype=tf.float64)
+            model_dict = {
+                self.tensor_name: t_patomic_tensor,
+            }
+            label_dict = {
+                "atom_" + self.label_name: t_latomic_tensor,
+                "find_atom_" + self.label_name: find_atomic,
+                self.label_name: t_lglobal_tensor,
+                "find_" + self.label_name: find_global,
+                "atom_weight": t_atom_weight,
+                "find_atom_weight": find_atom_weight,
+            }
+            self.tf_loss_sess = self.tf_loss.build(
+                t_cur_lr, t_natoms, model_dict, label_dict, ""
+            )
+
+            self.feed_dict = {
+                t_cur_lr: self.cur_lr,
+                t_natoms: natoms,
+                t_patomic_tensor: p_atomic_tensor,
+                t_pglobal_tensor: p_global_tensor,
+                t_latomic_tensor: l_atomic_tensor,
+                t_lglobal_tensor: l_global_tensor,
+                t_atom_weight: atom_weight,
+            }
+        # pt
+        self.model_pred = {
+            self.tensor_name: torch.from_numpy(p_atomic_tensor),
+            "global_" + self.tensor_name: torch.from_numpy(p_global_tensor),
+        }
+        self.label = {
+            "atom_" + self.label_name: torch.from_numpy(l_atomic_tensor),
+            "find_" + "atom_" + self.label_name: 1.0,
+            self.label_name: torch.from_numpy(l_global_tensor),
+            "find_" + self.label_name: 1.0,
+            "atom_weight": torch.from_numpy(atom_weight),
+            "find_atom_weight": 1.0,
+        }
+        self.label_absent = {
+            "atom_" + self.label_name: torch.from_numpy(l_atomic_tensor),
+            self.label_name: torch.from_numpy(l_global_tensor),
+            "atom_weight": torch.from_numpy(atom_weight),
+        }
+        self.natoms = pt_batch["natoms"]
+
+    def tearDown(self) -> None:
+        tf.reset_default_graph()
+        return super().tearDown()
+
+
+class TestAtomicDipoleLoss(LossCommonTest):
+    def setUp(self) -> None:
+        self.tensor_name = "dipole"
+        self.tensor_size = 3
+        self.label_name = "dipole"
+        self.system = str(Path(__file__).parent / "water_tensor/dipole/O78H156")
+
+        self.pref_atomic = 1.0
+        self.pref = 0.0
+        # tf
+        self.tf_loss = TFTensorLoss(
+            {
+                "pref_atomic": self.pref_atomic,
+                "pref": self.pref,
+            },
+            tensor_name=self.tensor_name,
+            tensor_size=self.tensor_size,
+            label_name=self.label_name,
+        )
+        # pt
+        self.pt_loss = PTTensorLoss(
+            self.tensor_name,
+            self.tensor_size,
+            self.label_name,
+            self.pref_atomic,
+            self.pref,
+        )
+
+        super().setUp()
+
+    def test_consistency(self) -> None:
+        with tf.Session(graph=self.g) as sess:
+            tf_loss, tf_more_loss = sess.run(
+                self.tf_loss_sess, feed_dict=self.feed_dict
+            )
+
+        def fake_model():
+            return self.model_pred
+
+        _, pt_loss, pt_more_loss = self.pt_loss(
+            {},
+            fake_model,
+            self.label,
+            self.nloc,
+            self.cur_lr,
+        )
+        _, pt_loss_absent, pt_more_loss_absent = self.pt_loss(
+            {},
+            fake_model,
+            self.label_absent,
+            self.nloc,
+            self.cur_lr,
+        )
+        pt_loss = pt_loss.detach().cpu()
+        pt_loss_absent = pt_loss_absent.detach().cpu()
+        self.assertTrue(np.allclose(tf_loss, pt_loss.numpy()))
+        self.assertTrue(np.allclose(0.0, pt_loss_absent.numpy()))
+        for key in ["local"]:
+            self.assertTrue(
+                np.allclose(
+                    tf_more_loss[f"{key}_loss"],
+                    pt_more_loss[f"l2_{key}_{self.tensor_name}_loss"],
+                )
+            )
+            self.assertTrue(
+                np.isnan(pt_more_loss_absent[f"l2_{key}_{self.tensor_name}_loss"])
+            )
+
+
+class TestAtomicDipoleAWeightLoss(LossCommonTest):
+    def setUp(self) -> None:
+        self.tensor_name = "dipole"
+        self.tensor_size = 3
+        self.label_name = "dipole"
+        self.system = str(Path(__file__).parent / "water_tensor/dipole/O78H156")
+
+        self.pref_atomic = 1.0
+        self.pref = 0.0
+        # tf
+        self.tf_loss = TFTensorLoss(
+            {
+                "pref_atomic": self.pref_atomic,
+                "pref": self.pref,
+                "enable_atomic_weight": True,
+            },
+            tensor_name=self.tensor_name,
+            tensor_size=self.tensor_size,
+            label_name=self.label_name,
+        )
+        # pt
+        self.pt_loss = PTTensorLoss(
+            self.tensor_name,
+            self.tensor_size,
+            self.label_name,
+            self.pref_atomic,
+            self.pref,
+            enable_atomic_weight=True,
+        )
+
+        super().setUp()
+
+    def test_consistency(self) -> None:
+        with tf.Session(graph=self.g) as sess:
+            tf_loss, tf_more_loss = sess.run(
+                self.tf_loss_sess, feed_dict=self.feed_dict
+            )
+
+        def fake_model():
+            return self.model_pred
+
+        _, pt_loss, pt_more_loss = self.pt_loss(
+            {},
+            fake_model,
+            self.label,
+            self.nloc,
+            self.cur_lr,
+        )
+        _, pt_loss_absent, pt_more_loss_absent = self.pt_loss(
+            {},
+            fake_model,
+            self.label_absent,
+            self.nloc,
+            self.cur_lr,
+        )
+        pt_loss = pt_loss.detach().cpu()
+        pt_loss_absent = pt_loss_absent.detach().cpu()
+        self.assertTrue(np.allclose(tf_loss, pt_loss.numpy()))
+        self.assertTrue(np.allclose(0.0, pt_loss_absent.numpy()))
+        for key in ["local"]:
+            self.assertTrue(
+                np.allclose(
+                    tf_more_loss[f"{key}_loss"],
+                    pt_more_loss[f"l2_{key}_{self.tensor_name}_loss"],
+                )
+            )
+            self.assertTrue(
+                np.isnan(pt_more_loss_absent[f"l2_{key}_{self.tensor_name}_loss"])
+            )
+
+
+class TestAtomicPolarLoss(LossCommonTest):
+    def setUp(self) -> None:
+        self.tensor_name = "polar"
+        self.tensor_size = 9
+        self.label_name = "polarizability"
+
+        self.system = str(Path(__file__).parent / "water_tensor/polar/atomic_system")
+
+        self.pref_atomic = 1.0
+        self.pref = 0.0
+        # tf
+        self.tf_loss = TFTensorLoss(
+            {
+                "pref_atomic": self.pref_atomic,
+                "pref": self.pref,
+            },
+            tensor_name=self.tensor_name,
+            tensor_size=self.tensor_size,
+            label_name=self.label_name,
+        )
+        # pt
+        self.pt_loss = PTTensorLoss(
+            self.tensor_name,
+            self.tensor_size,
+            self.label_name,
+            self.pref_atomic,
+            self.pref,
+        )
+
+        super().setUp()
+
+    def test_consistency(self) -> None:
+        with tf.Session(graph=self.g) as sess:
+            tf_loss, tf_more_loss = sess.run(
+                self.tf_loss_sess, feed_dict=self.feed_dict
+            )
+
+        def fake_model():
+            return self.model_pred
+
+        _, pt_loss, pt_more_loss = self.pt_loss(
+            {},
+            fake_model,
+            self.label,
+            self.nloc,
+            self.cur_lr,
+        )
+        _, pt_loss_absent, pt_more_loss_absent = self.pt_loss(
+            {},
+            fake_model,
+            self.label_absent,
+            self.nloc,
+            self.cur_lr,
+        )
+        pt_loss = pt_loss.detach().cpu()
+        pt_loss_absent = pt_loss_absent.detach().cpu()
+        self.assertTrue(np.allclose(tf_loss, pt_loss.numpy()))
+        self.assertTrue(np.allclose(0.0, pt_loss_absent.numpy()))
+        for key in ["local"]:
+            self.assertTrue(
+                np.allclose(
+                    tf_more_loss[f"{key}_loss"],
+                    pt_more_loss[f"l2_{key}_{self.tensor_name}_loss"],
+                )
+            )
+            self.assertTrue(
+                np.isnan(pt_more_loss_absent[f"l2_{key}_{self.tensor_name}_loss"])
+            )
+
+
+class TestAtomicPolarAWeightLoss(LossCommonTest):
+    def setUp(self) -> None:
+        self.tensor_name = "polar"
+        self.tensor_size = 9
+        self.label_name = "polarizability"
+
+        self.system = str(Path(__file__).parent / "water_tensor/polar/atomic_system")
+
+        self.pref_atomic = 1.0
+        self.pref = 0.0
+        # tf
+        self.tf_loss = TFTensorLoss(
+            {
+                "pref_atomic": self.pref_atomic,
+                "pref": self.pref,
+                "enable_atomic_weight": True,
+            },
+            tensor_name=self.tensor_name,
+            tensor_size=self.tensor_size,
+            label_name=self.label_name,
+        )
+        # pt
+        self.pt_loss = PTTensorLoss(
+            self.tensor_name,
+            self.tensor_size,
+            self.label_name,
+            self.pref_atomic,
+            self.pref,
+            enable_atomic_weight=True,
+        )
+
+        super().setUp()
+
+    def test_consistency(self) -> None:
+        with tf.Session(graph=self.g) as sess:
+            tf_loss, tf_more_loss = sess.run(
+                self.tf_loss_sess, feed_dict=self.feed_dict
+            )
+
+        def fake_model():
+            return self.model_pred
+
+        _, pt_loss, pt_more_loss = self.pt_loss(
+            {},
+            fake_model,
+            self.label,
+            self.nloc,
+            self.cur_lr,
+        )
+        _, pt_loss_absent, pt_more_loss_absent = self.pt_loss(
+            {},
+            fake_model,
+            self.label_absent,
+            self.nloc,
+            self.cur_lr,
+        )
+        pt_loss = pt_loss.detach().cpu()
+        pt_loss_absent = pt_loss_absent.detach().cpu()
+        self.assertTrue(np.allclose(tf_loss, pt_loss.numpy()))
+        self.assertTrue(np.allclose(0.0, pt_loss_absent.numpy()))
+        for key in ["local"]:
+            self.assertTrue(
+                np.allclose(
+                    tf_more_loss[f"{key}_loss"],
+                    pt_more_loss[f"l2_{key}_{self.tensor_name}_loss"],
+                )
+            )
+            self.assertTrue(
+                np.isnan(pt_more_loss_absent[f"l2_{key}_{self.tensor_name}_loss"])
+            )
+
+
+if __name__ == "__main__":
+    unittest.main()
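Note: to exercise the feature end to end, a per-atom weight label must exist alongside the other labels of a data system. A hypothetical sketch of writing one (path and sizes are illustrative only; the shape follows the PT requirement registered above, ndof=1 and atomic=True, i.e. nframes x natoms):

    import numpy as np

    # Hypothetical data-system paths and sizes; not part of this patch.
    nframes, natoms = 10, 234
    weights = np.ones((nframes, natoms))
    weights[:, :78] = 2.0  # e.g. emphasize the first 78 atoms
    np.save("training_data/set.000/atomic_weight.npy", weights)

Because the requirement is registered with must=False and default=1.0, systems that do not provide this file fall back to the unweighted loss.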