diff --git a/deepmd/pt/model/model/pairtab_atomic_model.py b/deepmd/pt/model/model/pairtab_atomic_model.py index 98215191c1..2837aaffe7 100644 --- a/deepmd/pt/model/model/pairtab_atomic_model.py +++ b/deepmd/pt/model/model/pairtab_atomic_model.py @@ -7,9 +7,6 @@ ) import torch -from torch import ( - nn, -) from deepmd.dpmodel import ( FittingOutputDef, @@ -22,9 +19,12 @@ from .base_atomic_model import ( BaseAtomicModel, ) +from .model import ( + BaseModel, +) -class PairTabModel(nn.Module, BaseAtomicModel): +class PairTabModel(BaseModel, BaseAtomicModel): """Pairwise tabulation energy model. This model can be used to tabulate the pairwise energy between atoms for either @@ -62,11 +62,11 @@ def __init__( tab_info, tab_data, ) = self.tab.get() # this returns -> Tuple[np.array, np.array] - self.tab_info = torch.from_numpy(tab_info) - self.tab_data = torch.from_numpy(tab_data) + self.register_buffer("tab_info", torch.from_numpy(tab_info)) + self.register_buffer("tab_data", torch.from_numpy(tab_data)) else: - self.tab_info = None - self.tab_data = None + self.register_buffer("tab_info", None) + self.register_buffer("tab_data", None) # self.model_type = "ener" # self.model_version = MODEL_VERSION ## this shoud be in the parent class @@ -118,8 +118,8 @@ def deserialize(cls, data) -> "PairTabModel": tab = PairTab.deserialize(data["tab"]) tab_model = cls(None, rcut, sel) tab_model.tab = tab - tab_model.tab_info = torch.from_numpy(tab_model.tab.tab_info) - tab_model.tab_data = torch.from_numpy(tab_model.tab.tab_data) + tab_model.register_buffer("tab_info", torch.from_numpy(tab_model.tab.tab_info)) + tab_model.register_buffer("tab_data", torch.from_numpy(tab_model.tab.tab_data)) return tab_model def forward_atomic( diff --git a/source/tests/pt/model/test_linear_atomic_model.py b/source/tests/pt/model/test_linear_atomic_model.py index 211b1f8215..e9090de86a 100644 --- a/source/tests/pt/model/test_linear_atomic_model.py +++ b/source/tests/pt/model/test_linear_atomic_model.py @@ -56,8 +56,8 @@ def test_pairwise(self, mock_loadtxt): [0.25, 0.0, 0.0, 0.0], ] ) - extended_atype = torch.tensor([[0, 0]]) - nlist = torch.tensor([[[1], [-1]]]) + extended_atype = torch.tensor([[0, 0]]).to(env.DEVICE) + nlist = torch.tensor([[[1], [-1]]]).to(env.DEVICE) ds = DescrptSeA( rcut=0.3, @@ -82,7 +82,7 @@ def test_pairwise(self, mock_loadtxt): zbl_model, sw_rmin=0.1, sw_rmax=0.25, - ) + ).to(env.DEVICE) wgt_res = [] for dist in np.linspace(0.05, 0.3, 10): extended_coord = torch.tensor( @@ -92,7 +92,7 @@ def test_pairwise(self, mock_loadtxt): [0.0, dist, 0.0], ], ] - ) + ).to(env.DEVICE) wgt_model.forward_atomic(extended_coord, extended_atype, nlist) @@ -112,7 +112,7 @@ def test_pairwise(self, mock_loadtxt): [0.0, 0.0], ], dtype=torch.float64, - ) + ).to(env.DEVICE) torch.testing.assert_close(results, excepted_res, rtol=0.0001, atol=0.0001)