Skip to content

Commit

Permalink
test(pt): add common test case for model/atomic model (#3767)
Browse files Browse the repository at this point in the history
Fix #3501. Fix #3517. Fix #3518.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit


- **Tests**
- Expanded testing capabilities for atomic and energy models to improve
accuracy and reliability in energy calculations.
- Implemented new test cases for atomic and energy models, along with
common model test cases, to validate diverse functionalities and
calculations.
- Introduced test case classes for atomic and energy models with methods
to assess parameters, types, outputs, and forward computations.
- Added utility functions for testing PyTorch-based deep learning models
with a custom backend.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
  • Loading branch information
njzjz and coderabbitai[bot] authored May 23, 2024
1 parent 591b94b commit dd97895
Show file tree
Hide file tree
Showing 27 changed files with 676 additions and 1 deletion.
23 changes: 22 additions & 1 deletion deepmd/dpmodel/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@

import numpy as np

from deepmd.dpmodel.common import (
NativeOP,
)
from deepmd.dpmodel.output_def import (
FittingOutputDef,
OutputVariableDef,
Expand All @@ -25,7 +28,7 @@
BaseAtomicModel_ = make_base_atomic_model(np.ndarray)


class BaseAtomicModel(BaseAtomicModel_):
class BaseAtomicModel(BaseAtomicModel_, NativeOP):
def __init__(
self,
type_map: List[str],
Expand Down Expand Up @@ -183,6 +186,24 @@ def forward_common_atomic(

return ret_dict

def call(
self,
extended_coord: np.ndarray,
extended_atype: np.ndarray,
nlist: np.ndarray,
mapping: Optional[np.ndarray] = None,
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
) -> Dict[str, np.ndarray]:
return self.forward_common_atomic(
extended_coord,
extended_atype,
nlist,
mapping=mapping,
fparam=fparam,
aparam=aparam,
)

def serialize(self) -> dict:
return {
"type_map": self.type_map,
Expand Down
6 changes: 6 additions & 0 deletions deepmd/dpmodel/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,8 @@ def call_lower(
model_predict = self.output_type_cast(model_predict, input_prec)
return model_predict

forward_lower = call_lower

def input_type_cast(
self,
coord: np.ndarray,
Expand Down Expand Up @@ -473,4 +475,8 @@ def atomic_output_def(self) -> FittingOutputDef:
"""Get the output def of the atomic model."""
return self.atomic_model.atomic_output_def()

def get_ntypes(self) -> int:
"""Get the number of types."""
return len(self.get_type_map())

return CM
32 changes: 32 additions & 0 deletions deepmd/dpmodel/utils/nlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,42 @@
import numpy as np

from .region import (
normalize_coord,
to_face_distance,
)


def extend_input_and_build_neighbor_list(
coord,
atype,
rcut: float,
sel: List[int],
mixed_types: bool = False,
box: Optional[np.ndarray] = None,
):
nframes, nloc = atype.shape[:2]
if box is not None:
coord_normalized = normalize_coord(
coord.reshape(nframes, nloc, 3),
box.reshape(nframes, 3, 3),
)
else:
coord_normalized = coord
extended_coord, extended_atype, mapping = extend_coord_with_ghosts(
coord_normalized, atype, box, rcut
)
nlist = build_neighbor_list(
extended_coord,
extended_atype,
nloc,
rcut,
sel,
distinguish_types=(not mixed_types),
)
extended_coord = extended_coord.reshape(nframes, -1, 3)
return extended_coord, extended_atype, mapping, nlist


## translated from torch implemantation by chatgpt
def build_neighbor_list(
coord: np.ndarray,
Expand Down
20 changes: 20 additions & 0 deletions deepmd/pt/model/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,26 @@ def forward_common_atomic(

return ret_dict

def forward(
self,
extended_coord: torch.Tensor,
extended_atype: torch.Tensor,
nlist: torch.Tensor,
mapping: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
) -> Dict[str, torch.Tensor]:
return self.forward_common_atomic(
extended_coord,
extended_atype,
nlist,
mapping=mapping,
fparam=fparam,
aparam=aparam,
comm_dict=comm_dict,
)

def serialize(self) -> dict:
return {
"type_map": self.type_map,
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ banned-module-level-imports = [
"deepmd/pt/**" = ["TID253"]
"source/tests/tf/**" = ["TID253"]
"source/tests/pt/**" = ["TID253"]
"source/tests/universal/pt/**" = ["TID253"]
"source/ipi/tests/**" = ["TID253"]
"source/lmp/tests/**" = ["TID253"]
"**/*.ipynb" = ["T20"] # printing in a nb file is expected
Expand Down
2 changes: 2 additions & 0 deletions source/tests/universal/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Universal tests for the project."""
1 change: 1 addition & 0 deletions source/tests/universal/common/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
23 changes: 23 additions & 0 deletions source/tests/universal/common/backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Common test case."""

from abc import (
ABC,
abstractmethod,
)


class BackendTestCase(ABC):
"""Backend test case."""

module: object
"""Module to test."""

@property
@abstractmethod
def modules_to_test(self) -> list:
pass

@abstractmethod
def forward_wrapper(self, x):
pass
1 change: 1 addition & 0 deletions source/tests/universal/common/cases/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
18 changes: 18 additions & 0 deletions source/tests/universal/common/cases/atomic_model/ener_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# SPDX-License-Identifier: LGPL-3.0-or-later


from .utils import (
AtomicModelTestCase,
)


class EnerAtomicModelTest(AtomicModelTestCase):
def setUp(self) -> None:
self.expected_rcut = 5.0
self.expected_type_map = ["foo", "bar"]
self.expected_dim_fparam = 0
self.expected_dim_aparam = 0
self.expected_sel_type = [0, 1]
self.expected_aparam_nall = False
self.expected_model_output_type = ["energy", "mask"]
self.expected_sel = [8, 12]
116 changes: 116 additions & 0 deletions source/tests/universal/common/cases/atomic_model/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
Callable,
List,
)

import numpy as np

from deepmd.dpmodel.utils.nlist import (
extend_input_and_build_neighbor_list,
)


class AtomicModelTestCase:
"""Common test case for atomic model."""

expected_type_map: List[str]
"""Expected type map."""
expected_rcut: float
"""Expected cut-off radius."""
expected_dim_fparam: int
"""Expected number (dimension) of frame parameters."""
expected_dim_aparam: int
"""Expected number (dimension) of atomic parameters."""
expected_sel_type: List[int]
"""Expected selected atom types."""
expected_aparam_nall: bool
"""Expected shape of atomic parameters."""
expected_model_output_type: List[str]
"""Expected output type for the model."""
expected_sel: List[int]
"""Expected number of neighbors."""
forward_wrapper: Callable[[Any], Any]
"""Calss wrapper for forward method."""

def test_get_type_map(self):
"""Test get_type_map."""
for module in self.modules_to_test:
self.assertEqual(module.get_type_map(), self.expected_type_map)

def test_get_rcut(self):
"""Test get_rcut."""
for module in self.modules_to_test:
self.assertAlmostEqual(module.get_rcut(), self.expected_rcut)

def test_get_dim_fparam(self):
"""Test get_dim_fparam."""
for module in self.modules_to_test:
self.assertEqual(module.get_dim_fparam(), self.expected_dim_fparam)

def test_get_dim_aparam(self):
"""Test get_dim_aparam."""
for module in self.modules_to_test:
self.assertEqual(module.get_dim_aparam(), self.expected_dim_aparam)

def test_get_sel_type(self):
"""Test get_sel_type."""
for module in self.modules_to_test:
self.assertEqual(module.get_sel_type(), self.expected_sel_type)

def test_is_aparam_nall(self):
"""Test is_aparam_nall."""
for module in self.modules_to_test:
self.assertEqual(module.is_aparam_nall(), self.expected_aparam_nall)

def test_get_nnei(self):
"""Test get_nnei."""
expected_nnei = sum(self.expected_sel)
for module in self.modules_to_test:
self.assertEqual(module.get_nnei(), expected_nnei)

def test_get_ntypes(self):
"""Test get_ntypes."""
for module in self.modules_to_test:
self.assertEqual(module.get_ntypes(), len(self.expected_type_map))

def test_forward(self):
"""Test forward."""
nf = 1
coord = np.array(
[
[0, 0, 0],
[0, 1, 0],
[0, 0, 1],
],
dtype=np.float64,
).reshape([nf, -1])
atype = np.array([0, 0, 1], dtype=int).reshape([nf, -1])
cell = 6.0 * np.eye(3).reshape([nf, 9])
coord_ext, atype_ext, mapping, nlist = extend_input_and_build_neighbor_list(
coord,
atype,
self.expected_rcut,
self.expected_sel,
mixed_types=True,
box=cell,
)
ret_lower = []
for module in self.modules_to_test:
module = self.forward_wrapper(module)

ret_lower.append(module(coord_ext, atype_ext, nlist))
for kk in ret_lower[0].keys():
subret = []
for rr in ret_lower:
if rr is not None:
subret.append(rr[kk])
if len(subret):
for ii, rr in enumerate(subret[1:]):
if subret[0] is None:
assert rr is None
else:
np.testing.assert_allclose(
subret[0], rr, err_msg=f"compare {kk} between 0 and {ii}"
)
1 change: 1 addition & 0 deletions source/tests/universal/common/cases/model/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
18 changes: 18 additions & 0 deletions source/tests/universal/common/cases/model/ener_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# SPDX-License-Identifier: LGPL-3.0-or-later


from .utils import (
ModelTestCase,
)


class EnerModelTest(ModelTestCase):
def setUp(self) -> None:
self.expected_rcut = 5.0
self.expected_type_map = ["foo", "bar"]
self.expected_dim_fparam = 0
self.expected_dim_aparam = 0
self.expected_sel_type = [0, 1]
self.expected_aparam_nall = False
self.expected_model_output_type = ["energy", "mask"]
self.expected_sel = [8, 12]
Loading

0 comments on commit dd97895

Please sign in to comment.