Skip to content

Commit

Permalink
fix: revert atomic change
Browse files Browse the repository at this point in the history
  • Loading branch information
anyangml committed Mar 22, 2024
1 parent a4201fe commit c0c1dba
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 12 deletions.
11 changes: 1 addition & 10 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,10 @@ def mixed_types(self) -> bool:
"""
return True

@torch.jit.export
def get_rcut(self) -> float:
"""Get the cut-off radius."""
return max(self.get_model_rcuts())

@torch.jit.export

def get_type_map(self) -> List[str]:
"""Get the type map."""
return self.type_map
Expand All @@ -107,16 +105,13 @@ def get_model_rcuts(self) -> List[float]:
"""Get the cut-off radius for each individual models."""
return [model.get_rcut() for model in self.models]

@torch.jit.export
def get_sel(self) -> List[int]:
return [max([model.get_nsel() for model in self.models])]

@torch.jit.export
def get_model_nsels(self) -> List[int]:
"""Get the processed sels for each individual models. Not distinguishing types."""
return [model.get_nsel() for model in self.models]

@torch.jit.export
def get_model_sels(self) -> List[List[int]]:
"""Get the sels for each individual models."""
return [model.get_sel() for model in self.models]
Expand Down Expand Up @@ -295,18 +290,15 @@ def _compute_weight(
for _ in range(nmodels)
]

@torch.jit.export
def get_dim_fparam(self) -> int:
"""Get the number (dimension) of frame parameters of this atomic model."""
# tricky...
return max([model.get_dim_fparam() for model in self.models])

@torch.jit.export
def get_dim_aparam(self) -> int:
"""Get the number (dimension) of atomic parameters of this atomic model."""
return max([model.get_dim_aparam() for model in self.models])

@torch.jit.export
def get_sel_type(self) -> List[int]:
"""Get the selected atom types of this model.
Expand All @@ -327,7 +319,6 @@ def get_sel_type(self) -> List[int]:
)
).tolist()

@torch.jit.export
def is_aparam_nall(self) -> bool:
"""Check whether the shape of atomic parameters is (nframes, nall, ndim).
Expand Down
10 changes: 8 additions & 2 deletions source/tests/pt/model/test_linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,11 +178,17 @@ def test_self_consistency(self):

def test_jit(self):
md1 = torch.jit.script(self.md1)
self.assertEqual(md1.get_rcut(), self.rcut)
self.assertEqual(md1.get_type_map(), ["foo", "bar"])
# atomic model no more export methods
# self.assertEqual(md1.get_rcut(), self.rcut)
# self.assertEqual(md1.get_type_map(), ["foo", "bar"])
# md3 is the model, not atomic model
md3 = torch.jit.script(self.md3)
self.assertEqual(md3.get_rcut(), self.rcut)
self.assertEqual(md3.get_type_map(), ["foo", "bar"])
self.assertEqual(md3.get_sel(), [sum(self.sel)])
self.assertEqual(md3.get_dim_aparam(), 0)
self.assertEqual(md3.get_dim_fparam(), 0)
self.assertEqual(md3.is_aparam_nall(), False)


class TestRemmapMethod(unittest.TestCase):
Expand Down

0 comments on commit c0c1dba

Please sign in to comment.