From 8e03962e1cc4fcf7ad54aae5737822a82f338977 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 10 Jan 2024 10:50:57 +0800 Subject: [PATCH] input order of env_mat changed to be consistent with descriptor --- deepmd_utils/model_format/env_mat.py | 2 +- deepmd_utils/model_format/se_e2_a.py | 2 +- source/tests/test_model_format_utils.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/deepmd_utils/model_format/env_mat.py b/deepmd_utils/model_format/env_mat.py index 84771135a6..7822bd7d0c 100644 --- a/deepmd_utils/model_format/env_mat.py +++ b/deepmd_utils/model_format/env_mat.py @@ -68,9 +68,9 @@ def __init__( def call( self, - nlist: np.ndarray, coord_ext: np.ndarray, atype_ext: np.ndarray, + nlist: np.ndarray, davg: Optional[np.ndarray] = None, dstd: Optional[np.ndarray] = None, ) -> Union[np.ndarray, np.ndarray]: diff --git a/deepmd_utils/model_format/se_e2_a.py b/deepmd_utils/model_format/se_e2_a.py index 114f9df915..5a4fe15a2d 100644 --- a/deepmd_utils/model_format/se_e2_a.py +++ b/deepmd_utils/model_format/se_e2_a.py @@ -137,7 +137,7 @@ def call( The descriptor. shape: nf x nloc x ng x axis_neuron """ # nf x nloc x nnei x 4 - rr, ww = self.env_mat.call(nlist, coord_ext, atype_ext, self.davg, self.dstd) + rr, ww = self.env_mat.call(coord_ext, atype_ext, nlist, self.davg, self.dstd) nf, nloc, nnei, _ = rr.shape sec = np.append([0], np.cumsum(self.sel)) diff --git a/source/tests/test_model_format_utils.py b/source/tests/test_model_format_utils.py index 7fd93c1366..ac37120e83 100644 --- a/source/tests/test_model_format_utils.py +++ b/source/tests/test_model_format_utils.py @@ -192,8 +192,8 @@ def test_self_consistency( dstd = 0.1 + np.abs(dstd) em0 = EnvMat(self.rcut, self.rcut_smth) em1 = EnvMat.deserialize(em0.serialize()) - mm0, ww0 = em0.call(self.nlist, self.coord_ext, self.atype_ext, davg, dstd) - mm1, ww1 = em1.call(self.nlist, self.coord_ext, self.atype_ext, davg, dstd) + mm0, ww0 = em0.call(self.coord_ext, self.atype_ext, self.nlist, davg, dstd) + mm1, ww1 = em1.call(self.coord_ext, self.atype_ext, self.nlist, davg, dstd) np.testing.assert_allclose(mm0, mm1) np.testing.assert_allclose(ww0, ww1)