From fdb6a3d67f55755150212b849b95fc49cbce80f4 Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Thu, 21 Mar 2024 16:17:46 +0800 Subject: [PATCH] feat: jit export in linear model --- .../pt/model/atomic_model/linear_atomic_model.py | 16 +++++++++++++--- .../tests/pt/model/test_linear_atomic_model.py | 10 ++++------ 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/deepmd/pt/model/atomic_model/linear_atomic_model.py b/deepmd/pt/model/atomic_model/linear_atomic_model.py index f7216f46ef..2cf1cc556c 100644 --- a/deepmd/pt/model/atomic_model/linear_atomic_model.py +++ b/deepmd/pt/model/atomic_model/linear_atomic_model.py @@ -92,25 +92,31 @@ def mixed_types(self) -> bool: """ return True + @torch.jit.export def get_rcut(self) -> float: """Get the cut-off radius.""" return max(self.get_model_rcuts()) - + + @torch.jit.export def get_type_map(self) -> List[str]: """Get the type map.""" return self.type_map - + + @torch.jit.export def get_model_rcuts(self) -> List[float]: """Get the cut-off radius for each individual models.""" return [model.get_rcut() for model in self.models] - + + @torch.jit.export def get_sel(self) -> List[int]: return [max([model.get_nsel() for model in self.models])] + @torch.jit.export def get_model_nsels(self) -> List[int]: """Get the processed sels for each individual models. Not distinguishing types.""" return [model.get_nsel() for model in self.models] + @torch.jit.export def get_model_sels(self) -> List[List[int]]: """Get the sels for each individual models.""" return [model.get_sel() for model in self.models] @@ -289,15 +295,18 @@ def _compute_weight( for _ in range(nmodels) ] + @torch.jit.export def get_dim_fparam(self) -> int: """Get the number (dimension) of frame parameters of this atomic model.""" # tricky... return max([model.get_dim_fparam() for model in self.models]) + @torch.jit.export def get_dim_aparam(self) -> int: """Get the number (dimension) of atomic parameters of this atomic model.""" return max([model.get_dim_aparam() for model in self.models]) + @torch.jit.export def get_sel_type(self) -> List[int]: """Get the selected atom types of this model. @@ -318,6 +327,7 @@ def get_sel_type(self) -> List[int]: ) ).tolist() + @torch.jit.export def is_aparam_nall(self) -> bool: """Check whether the shape of atomic parameters is (nframes, nall, ndim). diff --git a/source/tests/pt/model/test_linear_atomic_model.py b/source/tests/pt/model/test_linear_atomic_model.py index 7f24ffdc53..adc682a41f 100644 --- a/source/tests/pt/model/test_linear_atomic_model.py +++ b/source/tests/pt/model/test_linear_atomic_model.py @@ -178,13 +178,11 @@ def test_self_consistency(self): def test_jit(self): md1 = torch.jit.script(self.md1) - # atomic model no more export methods - # self.assertEqual(md1.get_rcut(), self.rcut) - # self.assertEqual(md1.get_type_map(), ["foo", "bar"]) + self.assertEqual(md1.get_rcut(), self.rcut) + self.assertEqual(md1.get_type_map(), ["foo", "bar"]) md3 = torch.jit.script(self.md3) - # atomic model no more export methods - # self.assertEqual(md3.get_rcut(), self.rcut) - # self.assertEqual(md3.get_type_map(), ["foo", "bar"]) + self.assertEqual(md3.get_rcut(), self.rcut) + self.assertEqual(md3.get_type_map(), ["foo", "bar"]) class TestRemmapMethod(unittest.TestCase):