Skip to content

Commit

Permalink
feat: jit export in linear model
Browse files Browse the repository at this point in the history
  • Loading branch information
anyangml committed Mar 21, 2024
1 parent 5aa1b89 commit fdb6a3d
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 9 deletions.
16 changes: 13 additions & 3 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,25 +92,31 @@ def mixed_types(self) -> bool:
"""
return True

@torch.jit.export
def get_rcut(self) -> float:
"""Get the cut-off radius."""
return max(self.get_model_rcuts())


@torch.jit.export
def get_type_map(self) -> List[str]:
"""Get the type map."""
return self.type_map


@torch.jit.export
def get_model_rcuts(self) -> List[float]:
"""Get the cut-off radius for each individual models."""
return [model.get_rcut() for model in self.models]


@torch.jit.export
def get_sel(self) -> List[int]:
return [max([model.get_nsel() for model in self.models])]

@torch.jit.export
def get_model_nsels(self) -> List[int]:
"""Get the processed sels for each individual models. Not distinguishing types."""
return [model.get_nsel() for model in self.models]

@torch.jit.export
def get_model_sels(self) -> List[List[int]]:
"""Get the sels for each individual models."""
return [model.get_sel() for model in self.models]
Expand Down Expand Up @@ -289,15 +295,18 @@ def _compute_weight(
for _ in range(nmodels)
]

@torch.jit.export
def get_dim_fparam(self) -> int:
"""Get the number (dimension) of frame parameters of this atomic model."""
# tricky...
return max([model.get_dim_fparam() for model in self.models])

@torch.jit.export
def get_dim_aparam(self) -> int:
"""Get the number (dimension) of atomic parameters of this atomic model."""
return max([model.get_dim_aparam() for model in self.models])

@torch.jit.export
def get_sel_type(self) -> List[int]:
"""Get the selected atom types of this model.
Expand All @@ -318,6 +327,7 @@ def get_sel_type(self) -> List[int]:
)
).tolist()

@torch.jit.export
def is_aparam_nall(self) -> bool:
"""Check whether the shape of atomic parameters is (nframes, nall, ndim).
Expand Down
10 changes: 4 additions & 6 deletions source/tests/pt/model/test_linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,13 +178,11 @@ def test_self_consistency(self):

def test_jit(self):
md1 = torch.jit.script(self.md1)
# atomic model no more export methods
# self.assertEqual(md1.get_rcut(), self.rcut)
# self.assertEqual(md1.get_type_map(), ["foo", "bar"])
self.assertEqual(md1.get_rcut(), self.rcut)
self.assertEqual(md1.get_type_map(), ["foo", "bar"])
md3 = torch.jit.script(self.md3)
# atomic model no more export methods
# self.assertEqual(md3.get_rcut(), self.rcut)
# self.assertEqual(md3.get_type_map(), ["foo", "bar"])
self.assertEqual(md3.get_rcut(), self.rcut)
self.assertEqual(md3.get_type_map(), ["foo", "bar"])


class TestRemmapMethod(unittest.TestCase):
Expand Down

0 comments on commit fdb6a3d

Please sign in to comment.