Skip to content

Commit

Permalink
Solve alert
Browse files Browse the repository at this point in the history
  • Loading branch information
Chengqian-Zhang committed Jun 4, 2024
1 parent d25993f commit 001d21c
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 7 deletions.
2 changes: 1 addition & 1 deletion deepmd/dpmodel/descriptor/se_atten_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def deserialize(cls, data: dict) -> "DescrptDPA1":
embeddings = data.pop("embeddings")
type_embedding = data.pop("type_embedding")
attention_layers = data.pop("attention_layers")
env_mat = data.pop("env_mat")
data.pop("env_mat")
embeddings_strip = data.pop("embeddings_strip")
obj = cls(**data)

Expand Down
8 changes: 2 additions & 6 deletions source/tests/pt/model/test_se_atten_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def test_jit(
self,
):
rng = np.random.default_rng()
nf, nloc, nnei = self.nlist.shape
_, _, nnei = self.nlist.shape
davg = rng.normal(size=(self.nt, nnei, 4))
dstd = rng.normal(size=(self.nt, nnei, 4))
dstd = 0.1 + np.abs(dstd)
Expand All @@ -123,8 +123,6 @@ def test_jit(
[False, True], # use_econf_tebd
):
dtype = PRECISION_DICT[prec]
rtol, atol = get_tols(prec)
err_msg = f"idt={idt} prec={prec}"
# dpa1 new impl
dd0 = DescrptSeAttenV2(
self.rcut,
Expand All @@ -140,6 +138,4 @@ def test_jit(
)
dd0.se_atten.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE)
dd0.se_atten.dstd = torch.tensor(dstd, dtype=dtype, device=env.DEVICE)
# dd1 = DescrptDPA1.deserialize(dd0.serialize())
model = torch.jit.script(dd0)
# model = torch.jit.script(dd1)
_ = torch.jit.script(dd0)

0 comments on commit 001d21c

Please sign in to comment.