From df41bd0998ddeb7c315f1e3c9ac8bcb7f021b82f Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Wed, 28 Feb 2024 13:45:56 +0800 Subject: [PATCH] Fix: se_r prod_env_mat (#3351) This should fix the bug. image --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- deepmd/pt/utils/env_mat_stat.py | 9 +++ source/tests/pt/model/test_descriptor_se_r.py | 55 ++++++++++++++++++- 2 files changed, 62 insertions(+), 2 deletions(-) diff --git a/deepmd/pt/utils/env_mat_stat.py b/deepmd/pt/utils/env_mat_stat.py index 3af03bda97..70b7228440 100644 --- a/deepmd/pt/utils/env_mat_stat.py +++ b/deepmd/pt/utils/env_mat_stat.py @@ -101,6 +101,14 @@ def iter( dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE, ) + if self.last_dim == 4: + radial_only = False + elif self.last_dim == 1: + radial_only = True + else: + raise ValueError( + "last_dim should be 1 for raial-only or 4 for full descriptor." + ) for system in data: coord, atype, box, natoms = ( system["coord"], @@ -130,6 +138,7 @@ def iter( self.descriptor.get_rcut(), # TODO: export rcut_smth from DescriptorBlock self.descriptor.rcut_smth, + radial_only, ) # reshape to nframes * nloc at the atom level, # so nframes/mixed_type do not matter diff --git a/source/tests/pt/model/test_descriptor_se_r.py b/source/tests/pt/model/test_descriptor_se_r.py index c999f06863..32270e263b 100644 --- a/source/tests/pt/model/test_descriptor_se_r.py +++ b/source/tests/pt/model/test_descriptor_se_r.py @@ -15,6 +15,9 @@ from deepmd.pt.utils.env import ( PRECISION_DICT, ) +from deepmd.pt.utils.env_mat_stat import ( + EnvMatStatSe, +) from .test_env_mat import ( TestCaseSingleFrameWithNlist, @@ -103,13 +106,61 @@ def test_consistency( err_msg=err_msg, ) + def test_load_stat(self): + rng = np.random.default_rng() + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 1)) + dstd = rng.normal(size=(self.nt, nnei, 1)) + dstd = 0.1 + np.abs(dstd) + + for idt, prec in itertools.product( + [False, True], + ["float64", "float32"], + ): + dtype = PRECISION_DICT[prec] + + # sea new impl + dd0 = DescrptSeR( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + resnet_dt=idt, + old_impl=False, + ) + dd0.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=env.DEVICE) + dd1 = DescrptSeR.deserialize(dd0.serialize()) + dd1.compute_input_stats( + [ + { + "r0": None, + "coord": torch.from_numpy(self.coord_ext).reshape( + -1, self.nall, 3 + ), + "atype": torch.from_numpy(self.atype_ext), + "box": None, + "natoms": self.nall, + } + ] + ) + + with self.assertRaises(ValueError) as cm: + ev = EnvMatStatSe(dd1) + ev.last_dim = 3 + ev.load_or_compute_stats([]) + self.assertEqual( + "last_dim should be 1 for raial-only or 4 for full descriptor.", + str(cm.exception), + ) + def test_jit( self, ): rng = np.random.default_rng() _, _, nnei = self.nlist.shape - davg = rng.normal(size=(self.nt, nnei, 4)) - dstd = rng.normal(size=(self.nt, nnei, 4)) + davg = rng.normal(size=(self.nt, nnei, 1)) + dstd = rng.normal(size=(self.nt, nnei, 1)) dstd = 0.1 + np.abs(dstd) for idt, prec in itertools.product(