Skip to content

Commit

Permalink
fix: cuda tests of linear and pair atomic model (#3248)
Browse files Browse the repository at this point in the history
This PR is to fix LinearAtomicModel GPU compatibility.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
anyangml and pre-commit-ci[bot] authored Feb 9, 2024
1 parent c235099 commit 9181a02
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 15 deletions.
20 changes: 10 additions & 10 deletions deepmd/pt/model/model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@
)

import torch
from torch import (
nn,
)

from deepmd.dpmodel import (
FittingOutputDef,
Expand All @@ -22,9 +19,12 @@
from .base_atomic_model import (
BaseAtomicModel,
)
from .model import (
BaseModel,
)


class PairTabModel(nn.Module, BaseAtomicModel):
class PairTabModel(BaseModel, BaseAtomicModel):
"""Pairwise tabulation energy model.
This model can be used to tabulate the pairwise energy between atoms for either
Expand Down Expand Up @@ -62,11 +62,11 @@ def __init__(
tab_info,
tab_data,
) = self.tab.get() # this returns -> Tuple[np.array, np.array]
self.tab_info = torch.from_numpy(tab_info)
self.tab_data = torch.from_numpy(tab_data)
self.register_buffer("tab_info", torch.from_numpy(tab_info))
self.register_buffer("tab_data", torch.from_numpy(tab_data))
else:
self.tab_info = None
self.tab_data = None
self.register_buffer("tab_info", None)
self.register_buffer("tab_data", None)

# self.model_type = "ener"
# self.model_version = MODEL_VERSION ## this shoud be in the parent class
Expand Down Expand Up @@ -118,8 +118,8 @@ def deserialize(cls, data) -> "PairTabModel":
tab = PairTab.deserialize(data["tab"])
tab_model = cls(None, rcut, sel)
tab_model.tab = tab
tab_model.tab_info = torch.from_numpy(tab_model.tab.tab_info)
tab_model.tab_data = torch.from_numpy(tab_model.tab.tab_data)
tab_model.register_buffer("tab_info", torch.from_numpy(tab_model.tab.tab_info))
tab_model.register_buffer("tab_data", torch.from_numpy(tab_model.tab.tab_data))
return tab_model

def forward_atomic(
Expand Down
10 changes: 5 additions & 5 deletions source/tests/pt/model/test_linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def test_pairwise(self, mock_loadtxt):
[0.25, 0.0, 0.0, 0.0],
]
)
extended_atype = torch.tensor([[0, 0]])
nlist = torch.tensor([[[1], [-1]]])
extended_atype = torch.tensor([[0, 0]]).to(env.DEVICE)
nlist = torch.tensor([[[1], [-1]]]).to(env.DEVICE)

ds = DescrptSeA(
rcut=0.3,
Expand All @@ -82,7 +82,7 @@ def test_pairwise(self, mock_loadtxt):
zbl_model,
sw_rmin=0.1,
sw_rmax=0.25,
)
).to(env.DEVICE)
wgt_res = []
for dist in np.linspace(0.05, 0.3, 10):
extended_coord = torch.tensor(
Expand All @@ -92,7 +92,7 @@ def test_pairwise(self, mock_loadtxt):
[0.0, dist, 0.0],
],
]
)
).to(env.DEVICE)

wgt_model.forward_atomic(extended_coord, extended_atype, nlist)

Expand All @@ -112,7 +112,7 @@ def test_pairwise(self, mock_loadtxt):
[0.0, 0.0],
],
dtype=torch.float64,
)
).to(env.DEVICE)
torch.testing.assert_close(results, excepted_res, rtol=0.0001, atol=0.0001)


Expand Down

0 comments on commit 9181a02

Please sign in to comment.