From fd600d729fa62f356bcd3fab6689fc910d0f3847 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 29 Feb 2024 19:28:57 -0500 Subject: [PATCH] Hybrid descriptor (#3365) Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/descriptor/__init__.py | 4 + deepmd/dpmodel/descriptor/hybrid.py | 242 ++++++++++++++++++ deepmd/pt/model/descriptor/__init__.py | 2 + deepmd/pt/model/descriptor/hybrid.py | 239 +++++++++++++++++ deepmd/pt/model/descriptor/se_atten.py | 2 +- deepmd/pt/utils/utils.py | 2 +- deepmd/tf/descriptor/hybrid.py | 38 ++- deepmd/utils/argcheck.py | 2 +- doc/model/train-hybrid.md | 4 +- .../consistent/descriptor/test_hybrid.py | 137 ++++++++++ .../tests/pt/model/test_descriptor_hybrid.py | 93 +++++++ 11 files changed, 758 insertions(+), 7 deletions(-) create mode 100644 deepmd/dpmodel/descriptor/hybrid.py create mode 100644 source/tests/consistent/descriptor/test_hybrid.py create mode 100644 source/tests/pt/model/test_descriptor_hybrid.py diff --git a/deepmd/dpmodel/descriptor/__init__.py b/deepmd/dpmodel/descriptor/__init__.py index 08f8eb5052..a19a2aa034 100644 --- a/deepmd/dpmodel/descriptor/__init__.py +++ b/deepmd/dpmodel/descriptor/__init__.py @@ -1,4 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from .hybrid import ( + DescrptHybrid, +) from .make_base_descriptor import ( make_base_descriptor, ) @@ -12,5 +15,6 @@ __all__ = [ "DescrptSeA", "DescrptSeR", + "DescrptHybrid", "make_base_descriptor", ] diff --git a/deepmd/dpmodel/descriptor/hybrid.py b/deepmd/dpmodel/descriptor/hybrid.py new file mode 100644 index 0000000000..d2620fdcf7 --- /dev/null +++ b/deepmd/dpmodel/descriptor/hybrid.py @@ -0,0 +1,242 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, + Dict, + List, + Optional, + Union, +) + +import numpy as np + +from deepmd.dpmodel.common import ( + NativeOP, +) +from deepmd.dpmodel.descriptor.base_descriptor import ( + BaseDescriptor, +) +from deepmd.dpmodel.utils.nlist import ( + nlist_distinguish_types, +) +from deepmd.utils.path import ( + DPPath, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + + +@BaseDescriptor.register("hybrid") +class DescrptHybrid(BaseDescriptor, NativeOP): + """Concate a list of descriptors to form a new descriptor. + + Parameters + ---------- + list : list : List[Union[BaseDescriptor, Dict[str, Any]]] + Build a descriptor from the concatenation of the list of descriptors. + The descriptor can be either an object or a dictionary. + """ + + def __init__( + self, + list: List[Union[BaseDescriptor, Dict[str, Any]]], + ) -> None: + super().__init__() + # warning: list is conflict with built-in list + descrpt_list = list + if descrpt_list == [] or descrpt_list is None: + raise RuntimeError( + "cannot build descriptor from an empty list of descriptors." + ) + formatted_descript_list = [] + for ii in descrpt_list: + if isinstance(ii, BaseDescriptor): + formatted_descript_list.append(ii) + elif isinstance(ii, dict): + formatted_descript_list.append(BaseDescriptor(**ii)) + else: + raise NotImplementedError + self.descrpt_list = formatted_descript_list + self.numb_descrpt = len(self.descrpt_list) + for ii in range(1, self.numb_descrpt): + assert ( + self.descrpt_list[ii].get_ntypes() == self.descrpt_list[0].get_ntypes() + ), f"number of atom types in {ii}th descrptor {self.descrpt_list[0].__class__.__name__} does not match others" + # if hybrid sel is larger than sub sel, the nlist needs to be cut for each type + hybrid_sel = self.get_sel() + self.nlist_cut_idx: List[np.ndarray] = [] + if self.mixed_types() and not all( + descrpt.mixed_types() for descrpt in self.descrpt_list + ): + self.sel_no_mixed_types = np.max( + [ + descrpt.get_sel() + for descrpt in self.descrpt_list + if not descrpt.mixed_types() + ], + axis=0, + ).tolist() + else: + self.sel_no_mixed_types = None + for ii in range(self.numb_descrpt): + if self.mixed_types() == self.descrpt_list[ii].mixed_types(): + hybrid_sel = self.get_sel() + else: + assert self.sel_no_mixed_types is not None + hybrid_sel = self.sel_no_mixed_types + sub_sel = self.descrpt_list[ii].get_sel() + start_idx = np.cumsum(np.pad(hybrid_sel, (1, 0), "constant"))[:-1] + end_idx = start_idx + np.array(sub_sel) + cut_idx = np.concatenate( + [range(ss, ee) for ss, ee in zip(start_idx, end_idx)] + ) + self.nlist_cut_idx.append(cut_idx) + + def get_rcut(self) -> float: + """Returns the cut-off radius.""" + return np.max([descrpt.get_rcut() for descrpt in self.descrpt_list]).item() + + def get_sel(self) -> List[int]: + """Returns the number of selected atoms for each type.""" + if self.mixed_types(): + return [ + np.max( + [descrpt.get_nsel() for descrpt in self.descrpt_list], axis=0 + ).item() + ] + else: + return np.max( + [descrpt.get_sel() for descrpt in self.descrpt_list], axis=0 + ).tolist() + + def get_ntypes(self) -> int: + """Returns the number of element types.""" + return self.descrpt_list[0].get_ntypes() + + def get_dim_out(self) -> int: + """Returns the output dimension.""" + return np.sum([descrpt.get_dim_out() for descrpt in self.descrpt_list]).item() + + def get_dim_emb(self) -> int: + """Returns the output dimension.""" + return np.sum([descrpt.get_dim_emb() for descrpt in self.descrpt_list]).item() + + def mixed_types(self): + """Returns if the descriptor requires a neighbor list that distinguish different + atomic types or not. + """ + return any(descrpt.mixed_types() for descrpt in self.descrpt_list) + + def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None): + """Update mean and stddev for descriptor elements.""" + for descrpt in self.descrpt_list: + descrpt.compute_input_stats(merged, path) + + def call( + self, + coord_ext, + atype_ext, + nlist, + mapping: Optional[np.ndarray] = None, + ): + """Compute the descriptor. + + Parameters + ---------- + coord_ext + The extended coordinates of atoms. shape: nf x (nallx3) + atype_ext + The extended aotm types. shape: nf x nall + nlist + The neighbor list. shape: nf x nloc x nnei + mapping + The index mapping, not required by this descriptor. + + Returns + ------- + descriptor + The descriptor. shape: nf x nloc x (ng x axis_neuron) + gr + The rotationally equivariant and permutationally invariant single particle + representation. shape: nf x nloc x ng x 3. + g2 + The rotationally invariant pair-partical representation. + h2 + The rotationally equivariant pair-partical representation. + sw + The smooth switch function. + """ + out_descriptor = [] + out_gr = [] + out_g2 = [] + out_h2 = None + out_sw = None + if self.sel_no_mixed_types is not None: + nl_distinguish_types = nlist_distinguish_types( + nlist, + atype_ext, + self.sel_no_mixed_types, + ) + else: + nl_distinguish_types = None + for descrpt, nci in zip(self.descrpt_list, self.nlist_cut_idx): + # cut the nlist to the correct length + if self.mixed_types() == descrpt.mixed_types(): + nl = nlist[:, :, nci] + else: + # mixed_types is True, but descrpt.mixed_types is False + assert nl_distinguish_types is not None + nl = nl_distinguish_types[:, :, nci] + odescriptor, gr, g2, h2, sw = descrpt(coord_ext, atype_ext, nl, mapping) + out_descriptor.append(odescriptor) + if gr is not None: + out_gr.append(gr) + if g2 is not None: + out_g2.append(g2) + if self.get_rcut() == descrpt.get_rcut(): + out_h2 = h2 + out_sw = sw + + out_descriptor = np.concatenate(out_descriptor, axis=-1) + out_gr = np.concatenate(out_gr, axis=-2) if out_gr else None + out_g2 = np.concatenate(out_g2, axis=-1) if out_g2 else None + return out_descriptor, out_gr, out_g2, out_h2, out_sw + + @classmethod + def update_sel(cls, global_jdata: dict, local_jdata: dict) -> dict: + """Update the selection and perform neighbor statistics. + + Parameters + ---------- + global_jdata : dict + The global data, containing the training section + local_jdata : dict + The local data refer to the current class + """ + local_jdata_cpy = local_jdata.copy() + local_jdata_cpy["list"] = [ + BaseDescriptor.update_sel(global_jdata, sub_jdata) + for sub_jdata in local_jdata["list"] + ] + return local_jdata_cpy + + def serialize(self) -> dict: + return { + "@class": "Descriptor", + "type": "hybrid", + "@version": 1, + "list": [descrpt.serialize() for descrpt in self.descrpt_list], + } + + @classmethod + def deserialize(cls, data: dict) -> "DescrptHybrid": + data = data.copy() + class_name = data.pop("@class") + assert class_name == "Descriptor" + class_type = data.pop("type") + assert class_type == "hybrid" + check_version_compatibility(data.pop("@version"), 1, 1) + obj = cls( + list=[BaseDescriptor.deserialize(ii) for ii in data["list"]], + ) + return obj diff --git a/deepmd/pt/model/descriptor/__init__.py b/deepmd/pt/model/descriptor/__init__.py index 5fd644f149..72f734de04 100644 --- a/deepmd/pt/model/descriptor/__init__.py +++ b/deepmd/pt/model/descriptor/__init__.py @@ -18,6 +18,7 @@ ) from .hybrid import ( DescrptBlockHybrid, + DescrptHybrid, ) from .repformers import ( DescrptBlockRepformers, @@ -39,6 +40,7 @@ "DescrptSeR", "DescrptDPA1", "DescrptDPA2", + "DescrptHybrid", "prod_env_mat", "DescrptGaussianLcc", "DescrptBlockHybrid", diff --git a/deepmd/pt/model/descriptor/hybrid.py b/deepmd/pt/model/descriptor/hybrid.py index 688d448b81..5aa83ef534 100644 --- a/deepmd/pt/model/descriptor/hybrid.py +++ b/deepmd/pt/model/descriptor/hybrid.py @@ -1,21 +1,260 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( + Any, + Dict, List, Optional, + Union, ) +import numpy as np import torch from deepmd.pt.model.descriptor import ( DescriptorBlock, ) +from deepmd.pt.model.descriptor.base_descriptor import ( + BaseDescriptor, +) from deepmd.pt.model.network.network import ( Identity, Linear, ) +from deepmd.pt.utils.nlist import ( + nlist_distinguish_types, +) +from deepmd.pt.utils.utils import ( + to_torch_tensor, +) from deepmd.utils.path import ( DPPath, ) +from deepmd.utils.version import ( + check_version_compatibility, +) + + +@BaseDescriptor.register("hybrid") +class DescrptHybrid(BaseDescriptor, torch.nn.Module): + """Concate a list of descriptors to form a new descriptor. + + Parameters + ---------- + list : list : List[Union[BaseDescriptor, Dict[str, Any]]] + Build a descriptor from the concatenation of the list of descriptors. + The descriptor can be either an object or a dictionary. + """ + + def __init__( + self, + list: List[Union[BaseDescriptor, Dict[str, Any]]], + **kwargs, + ) -> None: + super().__init__() + # warning: list is conflict with built-in list + descrpt_list = list + if descrpt_list == [] or descrpt_list is None: + raise RuntimeError( + "cannot build descriptor from an empty list of descriptors." + ) + formatted_descript_list: List[BaseDescriptor] = [] + for ii in descrpt_list: + if isinstance(ii, BaseDescriptor): + formatted_descript_list.append(ii) + elif isinstance(ii, dict): + formatted_descript_list.append( + # pass other arguments (e.g. ntypes) to the descriptor + BaseDescriptor(**ii, **kwargs) + ) + else: + raise NotImplementedError + self.descrpt_list = torch.nn.ModuleList(formatted_descript_list) + self.numb_descrpt = len(self.descrpt_list) + for ii in range(1, self.numb_descrpt): + assert ( + self.descrpt_list[ii].get_ntypes() == self.descrpt_list[0].get_ntypes() + ), f"number of atom types in {ii}th descrptor does not match others" + # if hybrid sel is larger than sub sel, the nlist needs to be cut for each type + self.nlist_cut_idx: List[torch.Tensor] = [] + if self.mixed_types() and not all( + descrpt.mixed_types() for descrpt in self.descrpt_list + ): + self.sel_no_mixed_types = np.max( + [ + descrpt.get_sel() + for descrpt in self.descrpt_list + if not descrpt.mixed_types() + ], + axis=0, + ).tolist() + else: + self.sel_no_mixed_types = None + for ii in range(self.numb_descrpt): + if self.mixed_types() == self.descrpt_list[ii].mixed_types(): + hybrid_sel = self.get_sel() + else: + assert self.sel_no_mixed_types is not None + hybrid_sel = self.sel_no_mixed_types + sub_sel = self.descrpt_list[ii].get_sel() + start_idx = np.cumsum(np.pad(hybrid_sel, (1, 0), "constant"))[:-1] + end_idx = start_idx + np.array(sub_sel) + cut_idx = np.concatenate( + [range(ss, ee) for ss, ee in zip(start_idx, end_idx)] + ).astype(np.int64) + self.nlist_cut_idx.append(to_torch_tensor(cut_idx)) + + def get_rcut(self) -> float: + """Returns the cut-off radius.""" + # do not use numpy here - jit is not happy + return max([descrpt.get_rcut() for descrpt in self.descrpt_list]) + + def get_sel(self) -> List[int]: + """Returns the number of selected atoms for each type.""" + if self.mixed_types(): + return [ + np.max( + [descrpt.get_nsel() for descrpt in self.descrpt_list], axis=0 + ).item() + ] + else: + return np.max( + [descrpt.get_sel() for descrpt in self.descrpt_list], axis=0 + ).tolist() + + def get_ntypes(self) -> int: + """Returns the number of element types.""" + return self.descrpt_list[0].get_ntypes() + + def get_dim_out(self) -> int: + """Returns the output dimension.""" + return sum([descrpt.get_dim_out() for descrpt in self.descrpt_list]) + + def get_dim_emb(self) -> int: + """Returns the output dimension.""" + return sum([descrpt.get_dim_emb() for descrpt in self.descrpt_list]) + + def mixed_types(self): + """Returns if the descriptor requires a neighbor list that distinguish different + atomic types or not. + """ + return any(descrpt.mixed_types() for descrpt in self.descrpt_list) + + def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None): + """Update mean and stddev for descriptor elements.""" + for descrpt in self.descrpt_list: + descrpt.compute_input_stats(merged, path) + + def forward( + self, + coord_ext: torch.Tensor, + atype_ext: torch.Tensor, + nlist: torch.Tensor, + mapping: Optional[torch.Tensor] = None, + ): + """Compute the descriptor. + + Parameters + ---------- + coord_ext + The extended coordinates of atoms. shape: nf x (nallx3) + atype_ext + The extended aotm types. shape: nf x nall + nlist + The neighbor list. shape: nf x nloc x nnei + mapping + The index mapping, not required by this descriptor. + + Returns + ------- + descriptor + The descriptor. shape: nf x nloc x (ng x axis_neuron) + gr + The rotationally equivariant and permutationally invariant single particle + representation. shape: nf x nloc x ng x 3. This descriptor returns None + g2 + The rotationally invariant pair-partical representation. + this descriptor returns None + h2 + The rotationally equivariant pair-partical representation. + this descriptor returns None + sw + The smooth switch function. this descriptor returns None + """ + out_descriptor = [] + out_gr = [] + out_g2 = [] + out_h2: Optional[torch.Tensor] = None + out_sw: Optional[torch.Tensor] = None + if self.sel_no_mixed_types is not None: + nl_distinguish_types = nlist_distinguish_types( + nlist, + atype_ext, + self.sel_no_mixed_types, + ) + else: + nl_distinguish_types = None + # make jit happy + # for descrpt, nci in zip(self.descrpt_list, self.nlist_cut_idx): + for ii, descrpt in enumerate(self.descrpt_list): + # cut the nlist to the correct length + if self.mixed_types() == descrpt.mixed_types(): + nl = nlist[:, :, self.nlist_cut_idx[ii]] + else: + # mixed_types is True, but descrpt.mixed_types is False + assert nl_distinguish_types is not None + nl = nl_distinguish_types[:, :, self.nlist_cut_idx[ii]] + odescriptor, gr, g2, h2, sw = descrpt(coord_ext, atype_ext, nl, mapping) + out_descriptor.append(odescriptor) + if gr is not None: + out_gr.append(gr) + if g2 is not None: + out_g2.append(g2) + if self.get_rcut() == descrpt.get_rcut(): + out_h2 = h2 + out_sw = sw + out_descriptor = torch.cat(out_descriptor, dim=-1) + out_gr = torch.cat(out_gr, dim=-2) if out_gr else None + out_g2 = torch.cat(out_g2, dim=-1) if out_g2 else None + return out_descriptor, out_gr, out_g2, out_h2, out_sw + + @classmethod + def update_sel(cls, global_jdata: dict, local_jdata: dict) -> dict: + """Update the selection and perform neighbor statistics. + + Parameters + ---------- + global_jdata : dict + The global data, containing the training section + local_jdata : dict + The local data refer to the current class + """ + local_jdata_cpy = local_jdata.copy() + local_jdata_cpy["list"] = [ + BaseDescriptor.update_sel(global_jdata, sub_jdata) + for sub_jdata in local_jdata["list"] + ] + return local_jdata_cpy + + def serialize(self) -> dict: + return { + "@class": "Descriptor", + "type": "hybrid", + "@version": 1, + "list": [descrpt.serialize() for descrpt in self.descrpt_list], + } + + @classmethod + def deserialize(cls, data: dict) -> "DescrptHybrid": + data = data.copy() + class_name = data.pop("@class") + assert class_name == "Descriptor" + class_type = data.pop("type") + assert class_type == "hybrid" + check_version_compatibility(data.pop("@version"), 1, 1) + obj = cls( + list=[BaseDescriptor.deserialize(ii) for ii in data["list"]], + ) + return obj @DescriptorBlock.register("hybrid") diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index a2197213ad..c815cda013 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -303,7 +303,7 @@ def forward( result.view(-1, nloc, self.filter_neuron[-1] * self.axis_neuron), ret.view(-1, nloc, self.nnei, self.filter_neuron[-1]), dmatrix.view(-1, nloc, self.nnei, 4)[..., 1:], - rot_mat.view(-1, self.filter_neuron[-1], 3), + rot_mat.view(-1, nloc, self.filter_neuron[-1], 3), sw, ) diff --git a/deepmd/pt/utils/utils.py b/deepmd/pt/utils/utils.py index 852c42cd0c..f5a4cd84b6 100644 --- a/deepmd/pt/utils/utils.py +++ b/deepmd/pt/utils/utils.py @@ -97,7 +97,7 @@ def to_torch_tensor( # Create a reverse mapping of NP_PRECISION_DICT reverse_precision_dict = {v: k for k, v in NP_PRECISION_DICT.items()} # Use the reverse mapping to find keys with the desired value - prec = reverse_precision_dict.get(type(xx.flat[0]), None) + prec = reverse_precision_dict.get(xx.dtype.type, None) prec = PT_PRECISION_DICT.get(prec, None) if prec is None: raise ValueError(f"unknown precision {xx.dtype}") diff --git a/deepmd/tf/descriptor/hybrid.py b/deepmd/tf/descriptor/hybrid.py index 8ce8acc4db..4e7eaa2c92 100644 --- a/deepmd/tf/descriptor/hybrid.py +++ b/deepmd/tf/descriptor/hybrid.py @@ -1,8 +1,11 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( + Any, + Dict, List, Optional, Tuple, + Union, ) import numpy as np @@ -14,6 +17,9 @@ from deepmd.tf.utils.spin import ( Spin, ) +from deepmd.utils.version import ( + check_version_compatibility, +) # from deepmd.tf.descriptor import DescrptLocFrame # from deepmd.tf.descriptor import DescrptSeA @@ -32,13 +38,14 @@ class DescrptHybrid(Descriptor): Parameters ---------- - list : list + list : list : List[Union[Descriptor, Dict[str, Any]]] Build a descriptor from the concatenation of the list of descriptors. + The descriptor can be either an object or a dictionary. """ def __init__( self, - list: list, + list: List[Union[Descriptor, Dict[str, Any]]], multi_task: bool = False, ntypes: Optional[int] = None, spin: Optional[Spin] = None, @@ -434,3 +441,30 @@ def update_sel(cls, global_jdata: dict, local_jdata: dict): for sub_jdata in local_jdata["list"] ] return local_jdata_cpy + + def serialize(self, suffix: str = "") -> dict: + return { + "@class": "Descriptor", + "type": "hybrid", + "@version": 1, + "list": [ + descrpt.serialize(suffix=f"{suffix}_{idx}") + for idx, descrpt in enumerate(self.descrpt_list) + ], + } + + @classmethod + def deserialize(cls, data: dict, suffix: str = "") -> "DescrptHybrid": + data = data.copy() + class_name = data.pop("@class") + assert class_name == "Descriptor" + class_type = data.pop("type") + assert class_type == "hybrid" + check_version_compatibility(data.pop("@version"), 1, 1) + obj = cls( + list=[ + Descriptor.deserialize(ii, suffix=f"{suffix}_{idx}") + for idx, ii in enumerate(data["list"]) + ], + ) + return obj diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 8e3196cba1..89b341491e 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -353,7 +353,7 @@ def descrpt_se_r_args(): ] -@descrpt_args_plugin.register("hybrid", doc=doc_only_tf_supported) +@descrpt_args_plugin.register("hybrid") def descrpt_hybrid_args(): doc_list = "A list of descriptor definitions" diff --git a/doc/model/train-hybrid.md b/doc/model/train-hybrid.md index 1db3f49a1f..3014aa869f 100644 --- a/doc/model/train-hybrid.md +++ b/doc/model/train-hybrid.md @@ -1,7 +1,7 @@ -# Descriptor `"hybrid"` {{ tensorflow_icon }} +# Descriptor `"hybrid"` {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }} :::{note} -**Supported backends**: TensorFlow {{ tensorflow_icon }} +**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, DPModel {{ dpmodel_icon }} ::: This descriptor hybridizes multiple descriptors to form a new descriptor. For example, we have a list of descriptors denoted by $\mathcal D_1$, $\mathcal D_2$, ..., $\mathcal D_N$, the hybrid descriptor this the concatenation of the list, i.e. $\mathcal D = (\mathcal D_1, \mathcal D_2, \cdots, \mathcal D_N)$. diff --git a/source/tests/consistent/descriptor/test_hybrid.py b/source/tests/consistent/descriptor/test_hybrid.py new file mode 100644 index 0000000000..7cfb627d54 --- /dev/null +++ b/source/tests/consistent/descriptor/test_hybrid.py @@ -0,0 +1,137 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest +from typing import ( + Any, + Tuple, +) + +import numpy as np + +from deepmd.dpmodel.descriptor.hybrid import DescrptHybrid as DescrptHybridDP +from deepmd.env import ( + GLOBAL_NP_FLOAT_PRECISION, +) + +from ..common import ( + INSTALLED_PT, + INSTALLED_TF, + CommonTest, +) +from .common import ( + DescriptorTest, +) + +if INSTALLED_PT: + from deepmd.pt.model.descriptor.hybrid import DescrptHybrid as DescrptHybridPT +else: + DescrptHybridPT = None +if INSTALLED_TF: + from deepmd.tf.descriptor.hybrid import DescrptHybrid as DescrptHybridTF +else: + DescrptHybridTF = None +from deepmd.utils.argcheck import ( + descrpt_hybrid_args, +) + + +class TestHybrid(CommonTest, DescriptorTest, unittest.TestCase): + @property + def data(self) -> dict: + return { + "list": [ + { + "type": "se_e2_r", + # test the case that sel are different! + "sel": [10, 10], + "rcut_smth": 5.80, + "rcut": 6.00, + "neuron": [6, 12, 24], + "resnet_dt": False, + "type_one_side": True, + "precision": "float64", + "seed": 20240229, + }, + { + "type": "se_e2_a", + "sel": [9, 11], + "rcut_smth": 2.80, + "rcut": 3.00, + "neuron": [6, 12, 24], + "axis_neuron": 3, + "resnet_dt": True, + "type_one_side": True, + "precision": "float64", + "seed": 20240229, + }, + ] + } + + tf_class = DescrptHybridTF + dp_class = DescrptHybridDP + pt_class = DescrptHybridPT + args = descrpt_hybrid_args() + + def setUp(self): + CommonTest.setUp(self) + + self.ntypes = 2 + self.coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 00.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ) + self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32) + self.box = np.array( + [13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ) + self.natoms = np.array([6, 6, 2, 4], dtype=np.int32) + + def build_tf(self, obj: Any, suffix: str) -> Tuple[list, dict]: + return self.build_tf_descriptor( + obj, + self.natoms, + self.coords, + self.atype, + self.box, + suffix, + ) + + def eval_dp(self, dp_obj: Any) -> Any: + return self.eval_dp_descriptor( + dp_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + def eval_pt(self, pt_obj: Any) -> Any: + return self.eval_pt_descriptor( + pt_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + def extract_ret(self, ret: Any, backend) -> Tuple[np.ndarray, ...]: + return (ret[0],) diff --git a/source/tests/pt/model/test_descriptor_hybrid.py b/source/tests/pt/model/test_descriptor_hybrid.py new file mode 100644 index 0000000000..6742388bd9 --- /dev/null +++ b/source/tests/pt/model/test_descriptor_hybrid.py @@ -0,0 +1,93 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import numpy as np +import torch + +from deepmd.pt.model.descriptor.dpa1 import ( + DescrptDPA1, +) +from deepmd.pt.model.descriptor.hybrid import ( + DescrptHybrid, +) +from deepmd.pt.model.descriptor.se_a import ( + DescrptSeA, +) +from deepmd.pt.model.descriptor.se_r import ( + DescrptSeR, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.utils import ( + to_torch_tensor, +) + +from .test_env_mat import ( + TestCaseSingleFrameWithNlist, +) + +dtype = env.GLOBAL_PT_FLOAT_PRECISION + + +class TestDescrptHybrid(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self): + TestCaseSingleFrameWithNlist.setUp(self) + + def test_jit( + self, + ): + ddsub0 = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + old_impl=False, + ) + ddsub1 = DescrptSeR( + self.rcut, + self.rcut_smth, + self.sel, + ) + dd0 = DescrptHybrid(list=[ddsub0, ddsub1]) + dd1 = DescrptHybrid.deserialize(dd0.serialize()) + dd0 = torch.jit.script(dd0) + dd1 = torch.jit.script(dd1) + + def test_hybrid_mixed_and_no_mixed(self): + coord_ext = to_torch_tensor(self.coord_ext) + atype_ext = to_torch_tensor(self.atype_ext) + nlist1 = to_torch_tensor(self.nlist) + nlist2 = to_torch_tensor(-np.sort(-self.nlist, axis=-1)) + ddsub0 = DescrptSeA( + rcut=self.rcut, + rcut_smth=self.rcut_smth, + sel=self.sel, + ) + ddsub1 = DescrptDPA1( + rcut=self.rcut, + rcut_smth=self.rcut_smth, + sel=np.sum(self.sel).item() - 1, + ntypes=len(self.sel), + ) + ddsub2 = DescrptSeR( + rcut=self.rcut / 2, + rcut_smth=self.rcut_smth, + sel=[3, 1], + ) + dd = DescrptHybrid(list=[ddsub0, ddsub1, ddsub2]) + ret = dd( + coord_ext, + atype_ext, + nlist2, + ) + ret0 = ddsub0( + coord_ext, + atype_ext, + nlist1, + ) + ret1 = ddsub1(coord_ext, atype_ext, nlist2[:, :, :-1]) + ret2 = ddsub2(coord_ext, atype_ext, nlist1[:, :, [0, 1, 2, self.sel[0]]]) + torch.testing.assert_close( + ret[0], + torch.cat([ret0[0], ret1[0], ret2[0]], dim=2), + )