Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Feb 28, 2024
1 parent e8ff097 commit c0f2c56
Showing 1 changed file with 23 additions and 9 deletions.
32 changes: 23 additions & 9 deletions source/tests/pt/model/test_descriptor_se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
from deepmd.pt.utils.env import (
PRECISION_DICT,
)
from deepmd.pt.utils.env_mat_stat import (
EnvMatStatSe,
)

from .test_env_mat import (
TestCaseSingleFrameWithNlist,
Expand All @@ -23,7 +26,6 @@
get_tols,
)

from deepmd.pt.utils.env_mat_stat import EnvMatStatSe
dtype = env.GLOBAL_PT_FLOAT_PRECISION


Expand Down Expand Up @@ -103,7 +105,7 @@ def test_consistency(
atol=atol,
err_msg=err_msg,
)

def test_load_stat(self):
rng = np.random.default_rng()
_, _, nnei = self.nlist.shape
Expand All @@ -129,17 +131,29 @@ def test_load_stat(self):
dd0.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE)
dd0.dstd = torch.tensor(dstd, dtype=dtype, device=env.DEVICE)
dd1 = DescrptSeR.deserialize(dd0.serialize())
dd1.compute_input_stats([{'r0':None, 'coord':torch.from_numpy(self.coord_ext).reshape(-1, self.nall, 3), "atype": torch.from_numpy(self.atype_ext),"box":None, "natoms":self.nall}])

dd1.compute_input_stats(
[
{
"r0": None,
"coord": torch.from_numpy(self.coord_ext).reshape(
-1, self.nall, 3
),
"atype": torch.from_numpy(self.atype_ext),
"box": None,
"natoms": self.nall,
}
]
)

with self.assertRaises(ValueError) as cm:
ev = EnvMatStatSe(dd1)
ev.last_dim =3
ev.last_dim = 3
ev.load_or_compute_stats([])
self.assertEqual(
'last_dim should be 1 for raial-only or 4 for full descriptor.',
str(cm.exception)
)
"last_dim should be 1 for raial-only or 4 for full descriptor.",
str(cm.exception),
)

def test_jit(
self,
):
Expand Down

0 comments on commit c0f2c56

Please sign in to comment.