From 0f30e01e8154783260f866aff157c112ecff0b5b Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 4 Nov 2024 21:15:53 -0500 Subject: [PATCH] add test for lower Signed-off-by: Jinzhe Zeng --- source/tests/consistent/model/test_ener.py | 220 ++++++++++++++++++++- 1 file changed, 219 insertions(+), 1 deletion(-) diff --git a/source/tests/consistent/model/test_ener.py b/source/tests/consistent/model/test_ener.py index ec73c57fa8..5d0253c5e8 100644 --- a/source/tests/consistent/model/test_ener.py +++ b/source/tests/consistent/model/test_ener.py @@ -6,8 +6,18 @@ import numpy as np +from deepmd.dpmodel.common import ( + to_numpy_array, +) from deepmd.dpmodel.model.ener_model import EnergyModel as EnergyModelDP from deepmd.dpmodel.model.model import get_model as get_model_dp +from deepmd.dpmodel.utils.nlist import ( + build_neighbor_list, + extend_coord_with_ghosts, +) +from deepmd.dpmodel.utils.region import ( + normalize_coord, +) from deepmd.env import ( GLOBAL_NP_FLOAT_PRECISION, ) @@ -27,7 +37,8 @@ if INSTALLED_PT: from deepmd.pt.model.model import get_model as get_model_pt from deepmd.pt.model.model.ener_model import EnergyModel as EnergyModelPT - + from deepmd.pt.utils.utils import to_numpy_array as torch_to_numpy + from deepmd.pt.utils.utils import to_torch_tensor as numpy_to_torch else: EnergyModelPT = None if INSTALLED_TF: @@ -39,6 +50,9 @@ ) if INSTALLED_JAX: + from deepmd.jax.common import ( + to_jax_array, + ) from deepmd.jax.model.ener_model import EnergyModel as EnergyModelJAX from deepmd.jax.model.model import get_model as get_model_jax else: @@ -243,3 +257,207 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: ret["energy_derv_c"].ravel(), ) raise ValueError(f"Unknown backend: {backend}") + + +@parameterized( + ( + [], + [[0, 1]], + ), + ( + [], + [1], + ), +) +class TestEnerLower(CommonTest, ModelTest, unittest.TestCase): + @property + def data(self) -> dict: + pair_exclude_types, atom_exclude_types = self.param + return { + "type_map": ["O", "H"], + "pair_exclude_types": pair_exclude_types, + "atom_exclude_types": atom_exclude_types, + "descriptor": { + "type": "se_e2_a", + "sel": [20, 20], + "rcut_smth": 0.50, + "rcut": 6.00, + "neuron": [ + 3, + 6, + ], + "resnet_dt": False, + "axis_neuron": 2, + "precision": "float64", + "type_one_side": True, + "seed": 1, + }, + "fitting_net": { + "neuron": [ + 5, + 5, + ], + "resnet_dt": True, + "precision": "float64", + "seed": 1, + }, + } + + tf_class = EnergyModelTF + dp_class = EnergyModelDP + pt_class = EnergyModelPT + jax_class = EnergyModelJAX + args = model_args() + + def get_reference_backend(self): + """Get the reference backend. + + We need a reference backend that can reproduce forces. + """ + if not self.skip_pt: + return self.RefBackend.PT + if not self.skip_jax: + return self.RefBackend.JAX + if not self.skip_dp: + return self.RefBackend.DP + raise ValueError("No available reference") + + @property + def skip_tf(self): + # TF does not have lower interface + return True + + @property + def skip_jax(self): + return not INSTALLED_JAX + + def pass_data_to_cls(self, cls, data) -> Any: + """Pass data to the class.""" + data = data.copy() + if cls is EnergyModelDP: + return get_model_dp(data) + elif cls is EnergyModelPT: + return get_model_pt(data) + elif cls is EnergyModelJAX: + return get_model_jax(data) + return cls(**data, **self.additional_data) + + def setUp(self): + CommonTest.setUp(self) + + self.ntypes = 2 + coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 00.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ).reshape(1, -1, 3) + atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32).reshape(1, -1) + box = np.array( + [13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ).reshape(1, 9) + + rcut = 6.0 + nframes, nloc = atype.shape[:2] + coord_normalized = normalize_coord( + coords.reshape(nframes, nloc, 3), + box.reshape(nframes, 3, 3), + ) + extended_coord, extended_atype, mapping = extend_coord_with_ghosts( + coord_normalized, atype, box, rcut + ) + nlist = build_neighbor_list( + extended_coord, + extended_atype, + nloc, + 6.0, + [20, 20], + distinguish_types=True, + ) + extended_coord = extended_coord.reshape(nframes, -1, 3) + self.nlist = nlist + self.extended_coord = extended_coord + self.extended_atype = extended_atype + self.mapping = mapping + + def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]: + raise NotImplementedError("no TF in this test") + + def eval_dp(self, dp_obj: Any) -> Any: + return dp_obj.call_lower( + self.extended_coord, + self.extended_atype, + self.nlist, + self.mapping, + do_atomic_virial=True, + ) + + def eval_pt(self, pt_obj: Any) -> Any: + return { + kk: torch_to_numpy(vv) + for kk, vv in pt_obj.forward_lower( + numpy_to_torch(self.extended_coord), + numpy_to_torch(self.extended_atype), + numpy_to_torch(self.nlist), + numpy_to_torch(self.mapping), + do_atomic_virial=True, + ).items() + } + + def eval_jax(self, jax_obj: Any) -> Any: + return { + kk: to_numpy_array(vv) + for kk, vv in jax_obj.call_lower( + to_jax_array(self.extended_coord), + to_jax_array(self.extended_atype), + to_jax_array(self.nlist), + to_jax_array(self.mapping), + do_atomic_virial=True, + ).items() + } + + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: + # shape not matched. ravel... + if backend is self.RefBackend.DP: + return ( + ret["energy_redu"].ravel(), + ret["energy"].ravel(), + SKIP_FLAG, + SKIP_FLAG, + SKIP_FLAG, + ) + elif backend is self.RefBackend.PT: + return ( + ret["energy"].ravel(), + ret["atom_energy"].ravel(), + ret["extended_force"].ravel(), + ret["virial"].ravel(), + ret["extended_virial"].ravel(), + ) + elif backend is self.RefBackend.JAX: + return ( + ret["energy_redu"].ravel(), + ret["energy"].ravel(), + ret["energy_derv_r"].ravel(), + ret["energy_derv_c_redu"].ravel(), + ret["energy_derv_c"].ravel(), + ) + raise ValueError(f"Unknown backend: {backend}")