Skip to content

Commit

Permalink
feat: add serialize and deserialize to pt pairtab
Browse files Browse the repository at this point in the history
  • Loading branch information
Anyang Peng authored and Anyang Peng committed Feb 1, 2024
1 parent fb4ae7d commit e423e68
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
14 changes: 14 additions & 0 deletions deepmd/pt/model/model/pair_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,20 @@ def get_sel(self) -> int:
def distinguish_types(self) -> bool:
# to match DPA1 and DPA2.
return False

def serialize(self) -> dict:
return {
"tab_file": self.tab_file,
"rcut": self.rcut,
"sel": self.sel
}

@classmethod
def deserialize(cls, data) -> "PairTabModel":
tab_file = data["tab_file"]
rcut = data["rcut"]
sel = data["sel"]
return cls(tab_file, rcut, sel)

def forward_atomic(
self,
Expand Down
14 changes: 14 additions & 0 deletions source/tests/pt/test_pairtab.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,20 @@ def test_with_mask(self):

def test_jit(self):
model = torch.jit.script(self.model)

@patch("numpy.loadtxt")
def test_deserialize(self, mock_loadtxt):
file_path = "dummy_path"

Check notice

Code scanning / CodeQL

Unused local variable Note test

Variable file_path is not used.
mock_loadtxt.return_value = np.array(
[
[0.005, 1.0, 2.0, 3.0],
[0.01, 0.8, 1.6, 2.4],
[0.015, 0.5, 1.0, 1.5],
[0.02, 0.25, 0.4, 0.75],
]
)
model1 = PairTabModel.deserialize(self.model.serialize())
model1 = torch.jit.script(model1)

Check notice

Code scanning / CodeQL

Unused local variable Note test

Variable model1 is not used.


class TestPairTabTwoAtoms(unittest.TestCase):
Expand Down

0 comments on commit e423e68

Please sign in to comment.