Skip to content

Commit

Permalink
input order of env_mat changed to be consistent with descriptor (#3125)
Browse files Browse the repository at this point in the history
Co-authored-by: Han Wang <[email protected]>
  • Loading branch information
wanghan-iapcm and Han Wang authored Jan 10, 2024
1 parent 438bc78 commit dac64cf
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion deepmd_utils/model_format/env_mat.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ def __init__(

def call(
self,
nlist: np.ndarray,
coord_ext: np.ndarray,
atype_ext: np.ndarray,
nlist: np.ndarray,
davg: Optional[np.ndarray] = None,
dstd: Optional[np.ndarray] = None,
) -> Union[np.ndarray, np.ndarray]:
Expand Down
2 changes: 1 addition & 1 deletion deepmd_utils/model_format/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def call(
The descriptor. shape: nf x nloc x ng x axis_neuron
"""
# nf x nloc x nnei x 4
rr, ww = self.env_mat.call(nlist, coord_ext, atype_ext, self.davg, self.dstd)
rr, ww = self.env_mat.call(coord_ext, atype_ext, nlist, self.davg, self.dstd)
nf, nloc, nnei, _ = rr.shape
sec = np.append([0], np.cumsum(self.sel))

Expand Down
4 changes: 2 additions & 2 deletions source/tests/test_model_format_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,8 @@ def test_self_consistency(
dstd = 0.1 + np.abs(dstd)
em0 = EnvMat(self.rcut, self.rcut_smth)
em1 = EnvMat.deserialize(em0.serialize())
mm0, ww0 = em0.call(self.nlist, self.coord_ext, self.atype_ext, davg, dstd)
mm1, ww1 = em1.call(self.nlist, self.coord_ext, self.atype_ext, davg, dstd)
mm0, ww0 = em0.call(self.coord_ext, self.atype_ext, self.nlist, davg, dstd)
mm1, ww1 = em1.call(self.coord_ext, self.atype_ext, self.nlist, davg, dstd)
np.testing.assert_allclose(mm0, mm1)
np.testing.assert_allclose(ww0, ww1)

Expand Down

0 comments on commit dac64cf

Please sign in to comment.