Skip to content

Commit

Permalink
add test for lower
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Nov 5, 2024
1 parent 71255cb commit 0f30e01
Showing 1 changed file with 219 additions and 1 deletion.
220 changes: 219 additions & 1 deletion source/tests/consistent/model/test_ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,18 @@

import numpy as np

from deepmd.dpmodel.common import (
to_numpy_array,
)
from deepmd.dpmodel.model.ener_model import EnergyModel as EnergyModelDP
from deepmd.dpmodel.model.model import get_model as get_model_dp
from deepmd.dpmodel.utils.nlist import (
build_neighbor_list,
extend_coord_with_ghosts,
)
from deepmd.dpmodel.utils.region import (
normalize_coord,
)
from deepmd.env import (
GLOBAL_NP_FLOAT_PRECISION,
)
Expand All @@ -27,7 +37,8 @@
if INSTALLED_PT:
from deepmd.pt.model.model import get_model as get_model_pt
from deepmd.pt.model.model.ener_model import EnergyModel as EnergyModelPT

from deepmd.pt.utils.utils import to_numpy_array as torch_to_numpy
from deepmd.pt.utils.utils import to_torch_tensor as numpy_to_torch
else:
EnergyModelPT = None
if INSTALLED_TF:
Expand All @@ -39,6 +50,9 @@
)

if INSTALLED_JAX:
from deepmd.jax.common import (
to_jax_array,
)
from deepmd.jax.model.ener_model import EnergyModel as EnergyModelJAX
from deepmd.jax.model.model import get_model as get_model_jax
else:
Expand Down Expand Up @@ -243,3 +257,207 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
ret["energy_derv_c"].ravel(),
)
raise ValueError(f"Unknown backend: {backend}")


@parameterized(
(
[],
[[0, 1]],
),
(
[],
[1],
),
)
class TestEnerLower(CommonTest, ModelTest, unittest.TestCase):
@property
def data(self) -> dict:
pair_exclude_types, atom_exclude_types = self.param
return {
"type_map": ["O", "H"],
"pair_exclude_types": pair_exclude_types,
"atom_exclude_types": atom_exclude_types,
"descriptor": {
"type": "se_e2_a",
"sel": [20, 20],
"rcut_smth": 0.50,
"rcut": 6.00,
"neuron": [
3,
6,
],
"resnet_dt": False,
"axis_neuron": 2,
"precision": "float64",
"type_one_side": True,
"seed": 1,
},
"fitting_net": {
"neuron": [
5,
5,
],
"resnet_dt": True,
"precision": "float64",
"seed": 1,
},
}

tf_class = EnergyModelTF
dp_class = EnergyModelDP
pt_class = EnergyModelPT
jax_class = EnergyModelJAX
args = model_args()

def get_reference_backend(self):
"""Get the reference backend.
We need a reference backend that can reproduce forces.
"""
if not self.skip_pt:
return self.RefBackend.PT
if not self.skip_jax:
return self.RefBackend.JAX
if not self.skip_dp:
return self.RefBackend.DP
raise ValueError("No available reference")

@property
def skip_tf(self):
# TF does not have lower interface
return True

@property
def skip_jax(self):
return not INSTALLED_JAX

def pass_data_to_cls(self, cls, data) -> Any:
"""Pass data to the class."""
data = data.copy()
if cls is EnergyModelDP:
return get_model_dp(data)
elif cls is EnergyModelPT:
return get_model_pt(data)
elif cls is EnergyModelJAX:
return get_model_jax(data)
return cls(**data, **self.additional_data)

def setUp(self):
CommonTest.setUp(self)

self.ntypes = 2
coords = np.array(
[
12.83,
2.56,
2.18,
12.09,
2.87,
2.74,
00.25,
3.32,
1.68,
3.36,
3.00,
1.81,
3.51,
2.51,
2.60,
4.27,
3.22,
1.56,
],
dtype=GLOBAL_NP_FLOAT_PRECISION,
).reshape(1, -1, 3)
atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32).reshape(1, -1)
box = np.array(
[13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0],
dtype=GLOBAL_NP_FLOAT_PRECISION,
).reshape(1, 9)

rcut = 6.0
nframes, nloc = atype.shape[:2]
coord_normalized = normalize_coord(
coords.reshape(nframes, nloc, 3),
box.reshape(nframes, 3, 3),
)
extended_coord, extended_atype, mapping = extend_coord_with_ghosts(
coord_normalized, atype, box, rcut
)
nlist = build_neighbor_list(
extended_coord,
extended_atype,
nloc,
6.0,
[20, 20],
distinguish_types=True,
)
extended_coord = extended_coord.reshape(nframes, -1, 3)
self.nlist = nlist
self.extended_coord = extended_coord
self.extended_atype = extended_atype
self.mapping = mapping

def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]:
raise NotImplementedError("no TF in this test")

def eval_dp(self, dp_obj: Any) -> Any:
return dp_obj.call_lower(
self.extended_coord,
self.extended_atype,
self.nlist,
self.mapping,
do_atomic_virial=True,
)

def eval_pt(self, pt_obj: Any) -> Any:
return {
kk: torch_to_numpy(vv)
for kk, vv in pt_obj.forward_lower(
numpy_to_torch(self.extended_coord),
numpy_to_torch(self.extended_atype),
numpy_to_torch(self.nlist),
numpy_to_torch(self.mapping),
do_atomic_virial=True,
).items()
}

def eval_jax(self, jax_obj: Any) -> Any:
return {
kk: to_numpy_array(vv)
for kk, vv in jax_obj.call_lower(
to_jax_array(self.extended_coord),
to_jax_array(self.extended_atype),
to_jax_array(self.nlist),
to_jax_array(self.mapping),
do_atomic_virial=True,
).items()
}

def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
# shape not matched. ravel...
if backend is self.RefBackend.DP:
return (
ret["energy_redu"].ravel(),
ret["energy"].ravel(),
SKIP_FLAG,
SKIP_FLAG,
SKIP_FLAG,
)
elif backend is self.RefBackend.PT:
return (
ret["energy"].ravel(),
ret["atom_energy"].ravel(),
ret["extended_force"].ravel(),
ret["virial"].ravel(),
ret["extended_virial"].ravel(),
)
elif backend is self.RefBackend.JAX:
return (
ret["energy_redu"].ravel(),
ret["energy"].ravel(),
ret["energy_derv_r"].ravel(),
ret["energy_derv_c_redu"].ravel(),
ret["energy_derv_c"].ravel(),
)
raise ValueError(f"Unknown backend: {backend}")

0 comments on commit 0f30e01

Please sign in to comment.