diff --git a/deepmd/dpmodel/atomic_model/base_atomic_model.py b/deepmd/dpmodel/atomic_model/base_atomic_model.py index 9e43851157..c16749405d 100644 --- a/deepmd/dpmodel/atomic_model/base_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/base_atomic_model.py @@ -9,6 +9,9 @@ import numpy as np +from deepmd.dpmodel.common import ( + NativeOP, +) from deepmd.dpmodel.output_def import ( FittingOutputDef, OutputVariableDef, @@ -25,7 +28,7 @@ BaseAtomicModel_ = make_base_atomic_model(np.ndarray) -class BaseAtomicModel(BaseAtomicModel_): +class BaseAtomicModel(BaseAtomicModel_, NativeOP): def __init__( self, type_map: List[str], @@ -183,6 +186,24 @@ def forward_common_atomic( return ret_dict + def call( + self, + extended_coord: np.ndarray, + extended_atype: np.ndarray, + nlist: np.ndarray, + mapping: Optional[np.ndarray] = None, + fparam: Optional[np.ndarray] = None, + aparam: Optional[np.ndarray] = None, + ) -> Dict[str, np.ndarray]: + return self.forward_common_atomic( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + fparam=fparam, + aparam=aparam, + ) + def serialize(self) -> dict: return { "type_map": self.type_map, diff --git a/deepmd/dpmodel/model/make_model.py b/deepmd/dpmodel/model/make_model.py index 68889ad331..7993f10abd 100644 --- a/deepmd/dpmodel/model/make_model.py +++ b/deepmd/dpmodel/model/make_model.py @@ -236,6 +236,8 @@ def call_lower( model_predict = self.output_type_cast(model_predict, input_prec) return model_predict + forward_lower = call_lower + def input_type_cast( self, coord: np.ndarray, @@ -473,4 +475,8 @@ def atomic_output_def(self) -> FittingOutputDef: """Get the output def of the atomic model.""" return self.atomic_model.atomic_output_def() + def get_ntypes(self) -> int: + """Get the number of types.""" + return len(self.get_type_map()) + return CM diff --git a/deepmd/dpmodel/utils/nlist.py b/deepmd/dpmodel/utils/nlist.py index ca8b18023b..018f50f1a5 100644 --- a/deepmd/dpmodel/utils/nlist.py +++ b/deepmd/dpmodel/utils/nlist.py @@ -9,10 +9,42 @@ import numpy as np from .region import ( + normalize_coord, to_face_distance, ) +def extend_input_and_build_neighbor_list( + coord, + atype, + rcut: float, + sel: List[int], + mixed_types: bool = False, + box: Optional[np.ndarray] = None, +): + nframes, nloc = atype.shape[:2] + if box is not None: + coord_normalized = normalize_coord( + coord.reshape(nframes, nloc, 3), + box.reshape(nframes, 3, 3), + ) + else: + coord_normalized = coord + extended_coord, extended_atype, mapping = extend_coord_with_ghosts( + coord_normalized, atype, box, rcut + ) + nlist = build_neighbor_list( + extended_coord, + extended_atype, + nloc, + rcut, + sel, + distinguish_types=(not mixed_types), + ) + extended_coord = extended_coord.reshape(nframes, -1, 3) + return extended_coord, extended_atype, mapping, nlist + + ## translated from torch implemantation by chatgpt def build_neighbor_list( coord: np.ndarray, diff --git a/deepmd/pt/model/atomic_model/base_atomic_model.py b/deepmd/pt/model/atomic_model/base_atomic_model.py index 3be052919d..1340028425 100644 --- a/deepmd/pt/model/atomic_model/base_atomic_model.py +++ b/deepmd/pt/model/atomic_model/base_atomic_model.py @@ -256,6 +256,26 @@ def forward_common_atomic( return ret_dict + def forward( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + comm_dict: Optional[Dict[str, torch.Tensor]] = None, + ) -> Dict[str, torch.Tensor]: + return self.forward_common_atomic( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + fparam=fparam, + aparam=aparam, + comm_dict=comm_dict, + ) + def serialize(self) -> dict: return { "type_map": self.type_map, diff --git a/pyproject.toml b/pyproject.toml index 80d5ad9ee9..7703ce71f9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -351,6 +351,7 @@ banned-module-level-imports = [ "deepmd/pt/**" = ["TID253"] "source/tests/tf/**" = ["TID253"] "source/tests/pt/**" = ["TID253"] +"source/tests/universal/pt/**" = ["TID253"] "source/ipi/tests/**" = ["TID253"] "source/lmp/tests/**" = ["TID253"] "**/*.ipynb" = ["T20"] # printing in a nb file is expected diff --git a/source/tests/universal/__init__.py b/source/tests/universal/__init__.py new file mode 100644 index 0000000000..3c8d925dcc --- /dev/null +++ b/source/tests/universal/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Universal tests for the project.""" diff --git a/source/tests/universal/common/__init__.py b/source/tests/universal/common/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/universal/common/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/universal/common/backend.py b/source/tests/universal/common/backend.py new file mode 100644 index 0000000000..d5747b77b7 --- /dev/null +++ b/source/tests/universal/common/backend.py @@ -0,0 +1,23 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Common test case.""" + +from abc import ( + ABC, + abstractmethod, +) + + +class BackendTestCase(ABC): + """Backend test case.""" + + module: object + """Module to test.""" + + @property + @abstractmethod + def modules_to_test(self) -> list: + pass + + @abstractmethod + def forward_wrapper(self, x): + pass diff --git a/source/tests/universal/common/cases/__init__.py b/source/tests/universal/common/cases/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/universal/common/cases/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/universal/common/cases/atomic_model/__init__.py b/source/tests/universal/common/cases/atomic_model/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/universal/common/cases/atomic_model/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/universal/common/cases/atomic_model/ener_model.py b/source/tests/universal/common/cases/atomic_model/ener_model.py new file mode 100644 index 0000000000..0f1daaf87b --- /dev/null +++ b/source/tests/universal/common/cases/atomic_model/ener_model.py @@ -0,0 +1,18 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + + +from .utils import ( + AtomicModelTestCase, +) + + +class EnerAtomicModelTest(AtomicModelTestCase): + def setUp(self) -> None: + self.expected_rcut = 5.0 + self.expected_type_map = ["foo", "bar"] + self.expected_dim_fparam = 0 + self.expected_dim_aparam = 0 + self.expected_sel_type = [0, 1] + self.expected_aparam_nall = False + self.expected_model_output_type = ["energy", "mask"] + self.expected_sel = [8, 12] diff --git a/source/tests/universal/common/cases/atomic_model/utils.py b/source/tests/universal/common/cases/atomic_model/utils.py new file mode 100644 index 0000000000..3b5fc64fda --- /dev/null +++ b/source/tests/universal/common/cases/atomic_model/utils.py @@ -0,0 +1,116 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, + Callable, + List, +) + +import numpy as np + +from deepmd.dpmodel.utils.nlist import ( + extend_input_and_build_neighbor_list, +) + + +class AtomicModelTestCase: + """Common test case for atomic model.""" + + expected_type_map: List[str] + """Expected type map.""" + expected_rcut: float + """Expected cut-off radius.""" + expected_dim_fparam: int + """Expected number (dimension) of frame parameters.""" + expected_dim_aparam: int + """Expected number (dimension) of atomic parameters.""" + expected_sel_type: List[int] + """Expected selected atom types.""" + expected_aparam_nall: bool + """Expected shape of atomic parameters.""" + expected_model_output_type: List[str] + """Expected output type for the model.""" + expected_sel: List[int] + """Expected number of neighbors.""" + forward_wrapper: Callable[[Any], Any] + """Calss wrapper for forward method.""" + + def test_get_type_map(self): + """Test get_type_map.""" + for module in self.modules_to_test: + self.assertEqual(module.get_type_map(), self.expected_type_map) + + def test_get_rcut(self): + """Test get_rcut.""" + for module in self.modules_to_test: + self.assertAlmostEqual(module.get_rcut(), self.expected_rcut) + + def test_get_dim_fparam(self): + """Test get_dim_fparam.""" + for module in self.modules_to_test: + self.assertEqual(module.get_dim_fparam(), self.expected_dim_fparam) + + def test_get_dim_aparam(self): + """Test get_dim_aparam.""" + for module in self.modules_to_test: + self.assertEqual(module.get_dim_aparam(), self.expected_dim_aparam) + + def test_get_sel_type(self): + """Test get_sel_type.""" + for module in self.modules_to_test: + self.assertEqual(module.get_sel_type(), self.expected_sel_type) + + def test_is_aparam_nall(self): + """Test is_aparam_nall.""" + for module in self.modules_to_test: + self.assertEqual(module.is_aparam_nall(), self.expected_aparam_nall) + + def test_get_nnei(self): + """Test get_nnei.""" + expected_nnei = sum(self.expected_sel) + for module in self.modules_to_test: + self.assertEqual(module.get_nnei(), expected_nnei) + + def test_get_ntypes(self): + """Test get_ntypes.""" + for module in self.modules_to_test: + self.assertEqual(module.get_ntypes(), len(self.expected_type_map)) + + def test_forward(self): + """Test forward.""" + nf = 1 + coord = np.array( + [ + [0, 0, 0], + [0, 1, 0], + [0, 0, 1], + ], + dtype=np.float64, + ).reshape([nf, -1]) + atype = np.array([0, 0, 1], dtype=int).reshape([nf, -1]) + cell = 6.0 * np.eye(3).reshape([nf, 9]) + coord_ext, atype_ext, mapping, nlist = extend_input_and_build_neighbor_list( + coord, + atype, + self.expected_rcut, + self.expected_sel, + mixed_types=True, + box=cell, + ) + ret_lower = [] + for module in self.modules_to_test: + module = self.forward_wrapper(module) + + ret_lower.append(module(coord_ext, atype_ext, nlist)) + for kk in ret_lower[0].keys(): + subret = [] + for rr in ret_lower: + if rr is not None: + subret.append(rr[kk]) + if len(subret): + for ii, rr in enumerate(subret[1:]): + if subret[0] is None: + assert rr is None + else: + np.testing.assert_allclose( + subret[0], rr, err_msg=f"compare {kk} between 0 and {ii}" + ) diff --git a/source/tests/universal/common/cases/model/__init__.py b/source/tests/universal/common/cases/model/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/universal/common/cases/model/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/universal/common/cases/model/ener_model.py b/source/tests/universal/common/cases/model/ener_model.py new file mode 100644 index 0000000000..35d44f9784 --- /dev/null +++ b/source/tests/universal/common/cases/model/ener_model.py @@ -0,0 +1,18 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + + +from .utils import ( + ModelTestCase, +) + + +class EnerModelTest(ModelTestCase): + def setUp(self) -> None: + self.expected_rcut = 5.0 + self.expected_type_map = ["foo", "bar"] + self.expected_dim_fparam = 0 + self.expected_dim_aparam = 0 + self.expected_sel_type = [0, 1] + self.expected_aparam_nall = False + self.expected_model_output_type = ["energy", "mask"] + self.expected_sel = [8, 12] diff --git a/source/tests/universal/common/cases/model/utils.py b/source/tests/universal/common/cases/model/utils.py new file mode 100644 index 0000000000..d67ac8e80d --- /dev/null +++ b/source/tests/universal/common/cases/model/utils.py @@ -0,0 +1,154 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, + Callable, + List, +) + +import numpy as np + +from deepmd.dpmodel.utils.nlist import ( + extend_input_and_build_neighbor_list, +) + + +class ModelTestCase: + """Common test case for model.""" + + expected_type_map: List[str] + """Expected type map.""" + expected_rcut: float + """Expected cut-off radius.""" + expected_dim_fparam: int + """Expected number (dimension) of frame parameters.""" + expected_dim_aparam: int + """Expected number (dimension) of atomic parameters.""" + expected_sel_type: List[int] + """Expected selected atom types.""" + expected_aparam_nall: bool + """Expected shape of atomic parameters.""" + expected_model_output_type: List[str] + """Expected output type for the model.""" + expected_sel: List[int] + """Expected number of neighbors.""" + forward_wrapper: Callable[[Any], Any] + """Calss wrapper for forward method.""" + + def test_get_type_map(self): + """Test get_type_map.""" + for module in self.modules_to_test: + self.assertEqual(module.get_type_map(), self.expected_type_map) + + def test_get_rcut(self): + """Test get_rcut.""" + for module in self.modules_to_test: + self.assertAlmostEqual(module.get_rcut(), self.expected_rcut) + + def test_get_dim_fparam(self): + """Test get_dim_fparam.""" + for module in self.modules_to_test: + self.assertEqual(module.get_dim_fparam(), self.expected_dim_fparam) + + def test_get_dim_aparam(self): + """Test get_dim_aparam.""" + for module in self.modules_to_test: + self.assertEqual(module.get_dim_aparam(), self.expected_dim_aparam) + + def test_get_sel_type(self): + """Test get_sel_type.""" + for module in self.modules_to_test: + self.assertEqual(module.get_sel_type(), self.expected_sel_type) + + def test_is_aparam_nall(self): + """Test is_aparam_nall.""" + for module in self.modules_to_test: + self.assertEqual(module.is_aparam_nall(), self.expected_aparam_nall) + + def test_model_output_type(self): + """Test model_output_type.""" + for module in self.modules_to_test: + self.assertEqual( + module.model_output_type(), self.expected_model_output_type + ) + + def test_get_nnei(self): + """Test get_nnei.""" + expected_nnei = sum(self.expected_sel) + for module in self.modules_to_test: + self.assertEqual(module.get_nnei(), expected_nnei) + + def test_get_ntypes(self): + """Test get_ntypes.""" + for module in self.modules_to_test: + self.assertEqual(module.get_ntypes(), len(self.expected_type_map)) + + def test_forward(self): + """Test forward and forward_lower.""" + nf = 1 + coord = np.array( + [ + [0, 0, 0], + [0, 1, 0], + [0, 0, 1], + ], + dtype=np.float64, + ).reshape([nf, -1]) + atype = np.array([0, 0, 1], dtype=int).reshape([nf, -1]) + cell = 6.0 * np.eye(3).reshape([nf, 9]) + coord_ext, atype_ext, mapping, nlist = extend_input_and_build_neighbor_list( + coord, + atype, + self.expected_rcut, + self.expected_sel, + mixed_types=True, + box=cell, + ) + ret = [] + ret_lower = [] + for module in self.modules_to_test: + module = self.forward_wrapper(module) + ret.append(module(coord, atype, cell)) + + ret_lower.append(module.forward_lower(coord_ext, atype_ext, nlist)) + for kk in ret[0].keys(): + subret = [] + for rr in ret: + if rr is not None: + subret.append(rr[kk]) + if len(subret): + for ii, rr in enumerate(subret[1:]): + if subret[0] is None: + assert rr is None + else: + np.testing.assert_allclose( + subret[0], rr, err_msg=f"compare {kk} between 0 and {ii}" + ) + for kk in ret_lower[0].keys(): + subret = [] + for rr in ret_lower: + if rr is not None: + subret.append(rr[kk]) + if len(subret): + for ii, rr in enumerate(subret[1:]): + if subret[0] is None: + assert rr is None + else: + np.testing.assert_allclose( + subret[0], rr, err_msg=f"compare {kk} between 0 and {ii}" + ) + same_keys = set(ret[0].keys()) & set(ret_lower[0].keys()) + self.assertTrue(same_keys) + for key in same_keys: + for rr in ret: + if rr[key] is not None: + rr1 = rr[key] + break + else: + continue + for rr in ret_lower: + if rr[key] is not None: + rr2 = rr[key] + break + else: + continue + np.testing.assert_allclose(rr1, rr2) diff --git a/source/tests/universal/dpmodel/__init__.py b/source/tests/universal/dpmodel/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/universal/dpmodel/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/universal/dpmodel/atomc_model/__init__.py b/source/tests/universal/dpmodel/atomc_model/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/universal/dpmodel/atomc_model/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/universal/dpmodel/atomc_model/test_ener_atomic_model.py b/source/tests/universal/dpmodel/atomc_model/test_ener_atomic_model.py new file mode 100644 index 0000000000..6cf4598646 --- /dev/null +++ b/source/tests/universal/dpmodel/atomc_model/test_ener_atomic_model.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +from deepmd.dpmodel.atomic_model.dp_atomic_model import ( + DPAtomicModel, +) +from deepmd.dpmodel.descriptor.se_e2_a import ( + DescrptSeA, +) +from deepmd.dpmodel.fitting.ener_fitting import ( + EnergyFittingNet, +) + +from ...common.cases.atomic_model.ener_model import ( + EnerAtomicModelTest, +) +from ..backend import ( + DPTestCase, +) + + +class TestEnergyAtomicModelDP(unittest.TestCase, EnerAtomicModelTest, DPTestCase): + def setUp(self): + EnerAtomicModelTest.setUp(self) + ds = DescrptSeA( + rcut=self.expected_rcut, + rcut_smth=self.expected_rcut / 2, + sel=self.expected_sel, + ) + ft = EnergyFittingNet( + ntypes=len(self.expected_type_map), + dim_descrpt=ds.get_dim_out(), + mixed_types=ds.mixed_types(), + ) + self.module = DPAtomicModel( + ds, + ft, + type_map=self.expected_type_map, + ) diff --git a/source/tests/universal/dpmodel/backend.py b/source/tests/universal/dpmodel/backend.py new file mode 100644 index 0000000000..61982fea98 --- /dev/null +++ b/source/tests/universal/dpmodel/backend.py @@ -0,0 +1,30 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.dpmodel.common import ( + NativeOP, +) + +from ..common.backend import ( + BackendTestCase, +) + + +class DPTestCase(BackendTestCase): + """Common test case.""" + + module: NativeOP + """DP module to test.""" + + def forward_wrapper(self, x): + return x + + @property + def deserialized_module(self): + return self.module.deserialize(self.module.serialize()) + + @property + def modules_to_test(self): + modules = [ + self.module, + self.deserialized_module, + ] + return modules diff --git a/source/tests/universal/dpmodel/model/__init__.py b/source/tests/universal/dpmodel/model/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/universal/dpmodel/model/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/universal/dpmodel/model/test_ener_model.py b/source/tests/universal/dpmodel/model/test_ener_model.py new file mode 100644 index 0000000000..506564260f --- /dev/null +++ b/source/tests/universal/dpmodel/model/test_ener_model.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +from deepmd.dpmodel.descriptor.se_e2_a import ( + DescrptSeA, +) +from deepmd.dpmodel.fitting.ener_fitting import ( + EnergyFittingNet, +) +from deepmd.dpmodel.model.ener_model import ( + EnergyModel, +) + +from ...common.cases.model.ener_model import ( + EnerModelTest, +) +from ..backend import ( + DPTestCase, +) + + +class TestEnergyModelDP(unittest.TestCase, EnerModelTest, DPTestCase): + def setUp(self): + EnerModelTest.setUp(self) + ds = DescrptSeA( + rcut=self.expected_rcut, + rcut_smth=self.expected_rcut / 2, + sel=self.expected_sel, + ) + ft = EnergyFittingNet( + ntypes=len(self.expected_type_map), + dim_descrpt=ds.get_dim_out(), + mixed_types=ds.mixed_types(), + ) + self.module = EnergyModel( + ds, + ft, + type_map=self.expected_type_map, + ) diff --git a/source/tests/universal/pt/__init__.py b/source/tests/universal/pt/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/universal/pt/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/universal/pt/atomc_model/__init__.py b/source/tests/universal/pt/atomc_model/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/universal/pt/atomc_model/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/universal/pt/atomc_model/test_ener_atomic_model.py b/source/tests/universal/pt/atomc_model/test_ener_atomic_model.py new file mode 100644 index 0000000000..5ba3be0fad --- /dev/null +++ b/source/tests/universal/pt/atomc_model/test_ener_atomic_model.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +from deepmd.pt.model.atomic_model.dp_atomic_model import ( + DPAtomicModel, +) +from deepmd.pt.model.descriptor.se_a import ( + DescrptSeA, +) +from deepmd.pt.model.task.ener import ( + EnergyFittingNet, +) + +from ...common.cases.atomic_model.ener_model import ( + EnerAtomicModelTest, +) +from ..backend import ( + PTTestCase, +) + + +class TestEnergyAtomicModelDP(unittest.TestCase, EnerAtomicModelTest, PTTestCase): + def setUp(self): + EnerAtomicModelTest.setUp(self) + ds = DescrptSeA( + rcut=self.expected_rcut, + rcut_smth=self.expected_rcut / 2, + sel=self.expected_sel, + ) + ft = EnergyFittingNet( + ntypes=len(self.expected_type_map), + dim_descrpt=ds.get_dim_out(), + mixed_types=ds.mixed_types(), + ) + self.module = DPAtomicModel( + ds, + ft, + type_map=self.expected_type_map, + ) diff --git a/source/tests/universal/pt/backend.py b/source/tests/universal/pt/backend.py new file mode 100644 index 0000000000..61110a0cc6 --- /dev/null +++ b/source/tests/universal/pt/backend.py @@ -0,0 +1,59 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import torch + +from deepmd.pt.utils.utils import ( + to_numpy_array, + to_torch_tensor, +) + + +class PTTestCase: + """Common test case.""" + + module: "torch.nn.Module" + """PT module to test.""" + + @property + def script_module(self): + return torch.jit.script(self.module) + + @property + def deserialized_module(self): + return self.module.deserialize(self.module.serialize()) + + @property + def modules_to_test(self): + modules = [ + self.module, + self.deserialized_module, + ] + return modules + + def test_jit(self): + self.script_module + + def forward_wrapper(self, module): + def create_wrapper_method(method): + def wrapper_method(self, *args, **kwargs): + # convert to torch tensor + args = [to_torch_tensor(arg) for arg in args] + kwargs = {k: to_torch_tensor(v) for k, v in kwargs.items()} + # forward + output = method(*args, **kwargs) + # convert to numpy array + if isinstance(output, tuple): + output = tuple(to_numpy_array(o) for o in output) + elif isinstance(output, dict): + output = {k: to_numpy_array(v) for k, v in output.items()} + else: + output = to_numpy_array(output) + return output + + return wrapper_method + + class wrapper_module: + __call__ = create_wrapper_method(module.__call__) + if hasattr(module, "forward_lower"): + forward_lower = create_wrapper_method(module.forward_lower) + + return wrapper_module() diff --git a/source/tests/universal/pt/model/__init__.py b/source/tests/universal/pt/model/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/universal/pt/model/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/universal/pt/model/test_ener_model.py b/source/tests/universal/pt/model/test_ener_model.py new file mode 100644 index 0000000000..af5d77d5b4 --- /dev/null +++ b/source/tests/universal/pt/model/test_ener_model.py @@ -0,0 +1,48 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +from deepmd.pt.model.descriptor.se_a import ( + DescrptSeA, +) +from deepmd.pt.model.model.ener_model import ( + EnergyModel, +) +from deepmd.pt.model.task.ener import ( + EnergyFittingNet, +) + +from ...common.cases.model.ener_model import ( + EnerModelTest, +) +from ..backend import ( + PTTestCase, +) + + +class TestEnergyModelDP(unittest.TestCase, EnerModelTest, PTTestCase): + @property + def modules_to_test(self): + # for Model, we can test script module API + modules = [ + *PTTestCase.modules_to_test.fget(self), + self.script_module, + ] + return modules + + def setUp(self): + EnerModelTest.setUp(self) + ds = DescrptSeA( + rcut=self.expected_rcut, + rcut_smth=self.expected_rcut / 2, + sel=self.expected_sel, + ) + ft = EnergyFittingNet( + ntypes=len(self.expected_type_map), + dim_descrpt=ds.get_dim_out(), + mixed_types=ds.mixed_types(), + ) + self.module = EnergyModel( + ds, + ft, + type_map=self.expected_type_map, + )