From 80d73d29eb1d9116cf362fc6ef259a5c95589ce2 Mon Sep 17 00:00:00 2001
From: Chenqqian Zhang <100290172+Chengqian-Zhang@users.noreply.github.com>
Date: Fri, 7 Jun 2024 01:09:10 +0800
Subject: [PATCH] Feat: add `se_atten_v2` to PyTorch and DP (#3840)

Solve #3831 and #3139

- add `se_atten_v2` to PyTorch and DP
- add document equation for `se_atten_v2`

## Summary by CodeRabbit

- **New Features**
  - Introduced a new descriptor class with enhanced configuration options and methods for serialization and deserialization.
  - Added new configurable parameters to the descriptor setup for improved flexibility.
- **Documentation**
  - Updated function documentation to reflect new arguments and usage instructions.
- **Bug Fixes**
  - Refined serialization logic to handle new parameters and class types more accurately.
  - Improved error messages for better clarity during serialization processes.

---------

Signed-off-by: Chenqqian Zhang <100290172+Chengqian-Zhang@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Duo <50307526+iProzd@users.noreply.github.com>
---
 deepmd/dpmodel/descriptor/__init__.py         |   4 +
 deepmd/dpmodel/descriptor/se_atten_v2.py      | 180 +++++++++++
 deepmd/pt/model/descriptor/__init__.py        |   4 +
 deepmd/pt/model/descriptor/se_atten_v2.py     | 253 +++++++++++++++
 deepmd/tf/descriptor/se_atten.py              |  12 +-
 deepmd/tf/descriptor/se_atten_v2.py           |  69 ++++
 deepmd/utils/argcheck.py                      |  66 +++-
 doc/model/train-se-atten.md                   |   2 +-
 .../consistent/descriptor/test_se_atten_v2.py | 300 ++++++++++++++++++
 source/tests/pt/model/test_se_atten_v2.py     | 141 ++++++++
 10 files changed, 1021 insertions(+), 10 deletions(-)
 create mode 100644 deepmd/dpmodel/descriptor/se_atten_v2.py
 create mode 100644 deepmd/pt/model/descriptor/se_atten_v2.py
 create mode 100644 source/tests/consistent/descriptor/test_se_atten_v2.py
 create mode 100644 source/tests/pt/model/test_se_atten_v2.py

diff --git a/deepmd/dpmodel/descriptor/__init__.py b/deepmd/dpmodel/descriptor/__init__.py
index 1a7b376a36..5c3987e1c5 100644
--- a/deepmd/dpmodel/descriptor/__init__.py
+++ b/deepmd/dpmodel/descriptor/__init__.py
@@ -11,6 +11,9 @@
 from .make_base_descriptor import (
     make_base_descriptor,
 )
+from .se_atten_v2 import (
+    DescrptSeAttenV2,
+)
 from .se_e2_a import (
     DescrptSeA,
 )
@@ -26,6 +29,7 @@
     "DescrptSeR",
     "DescrptSeT",
     "DescrptDPA1",
+    "DescrptSeAttenV2",
     "DescrptDPA2",
     "DescrptHybrid",
     "make_base_descriptor",
diff --git a/deepmd/dpmodel/descriptor/se_atten_v2.py b/deepmd/dpmodel/descriptor/se_atten_v2.py
new file mode 100644
index 0000000000..1375d2265f
--- /dev/null
+++ b/deepmd/dpmodel/descriptor/se_atten_v2.py
@@ -0,0 +1,180 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+from typing import (
+    Any,
+    List,
+    Optional,
+    Tuple,
+    Union,
+)
+
+import numpy as np
+
+from deepmd.dpmodel import (
+    DEFAULT_PRECISION,
+    PRECISION_DICT,
+)
+from deepmd.dpmodel.utils import (
+    NetworkCollection,
+)
+from deepmd.dpmodel.utils.type_embed import (
+    TypeEmbedNet,
+)
+from deepmd.utils.version import (
+    check_version_compatibility,
+)
+
+from .base_descriptor import (
+    BaseDescriptor,
+)
+from .dpa1 import (
+    DescrptDPA1,
+    NeighborGatedAttention,
+)
+
+
+@BaseDescriptor.register("se_atten_v2")
+class DescrptSeAttenV2(DescrptDPA1):
+    def __init__(
+        self,
+        rcut: float,
+        rcut_smth: float,
+        sel: Union[List[int], int],
+        ntypes: int,
+        neuron: List[int] = [25, 50, 100],
+        axis_neuron: int = 8,
+        tebd_dim: int = 8,
+        resnet_dt: bool = False,
+        trainable: bool = True,
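+        # note: `tebd_input_mode` and `smooth_type_embedding` are deliberately
+        # not exposed as arguments; they are fixed to "strip" and True in the
+        # `DescrptDPA1.__init__` call below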
type_one_side: bool = False, + attn: int = 128, + attn_layer: int = 2, + attn_dotr: bool = True, + attn_mask: bool = False, + exclude_types: List[Tuple[int, int]] = [], + env_protection: float = 0.0, + set_davg_zero: bool = False, + activation_function: str = "tanh", + precision: str = DEFAULT_PRECISION, + scaling_factor=1.0, + normalize: bool = True, + temperature: Optional[float] = None, + trainable_ln: bool = True, + ln_eps: Optional[float] = 1e-5, + concat_output_tebd: bool = True, + spin: Optional[Any] = None, + stripped_type_embedding: Optional[bool] = None, + use_econf_tebd: bool = False, + type_map: Optional[List[str]] = None, + # consistent with argcheck, not used though + seed: Optional[int] = None, + ) -> None: + DescrptDPA1.__init__( + self, + rcut, + rcut_smth, + sel, + ntypes, + neuron=neuron, + axis_neuron=axis_neuron, + tebd_dim=tebd_dim, + tebd_input_mode="strip", + resnet_dt=resnet_dt, + trainable=trainable, + type_one_side=type_one_side, + attn=attn, + attn_layer=attn_layer, + attn_dotr=attn_dotr, + attn_mask=attn_mask, + exclude_types=exclude_types, + env_protection=env_protection, + set_davg_zero=set_davg_zero, + activation_function=activation_function, + precision=precision, + scaling_factor=scaling_factor, + normalize=normalize, + temperature=temperature, + trainable_ln=trainable_ln, + ln_eps=ln_eps, + smooth_type_embedding=True, + concat_output_tebd=concat_output_tebd, + spin=spin, + stripped_type_embedding=stripped_type_embedding, + use_econf_tebd=use_econf_tebd, + type_map=type_map, + # consistent with argcheck, not used though + seed=seed, + ) + + def serialize(self) -> dict: + """Serialize the descriptor to dict.""" + obj = self.se_atten + data = { + "@class": "Descriptor", + "type": "se_atten_v2", + "@version": 1, + "rcut": obj.rcut, + "rcut_smth": obj.rcut_smth, + "sel": obj.sel, + "ntypes": obj.ntypes, + "neuron": obj.neuron, + "axis_neuron": obj.axis_neuron, + "tebd_dim": obj.tebd_dim, + "set_davg_zero": obj.set_davg_zero, + "attn": obj.attn, + "attn_layer": obj.attn_layer, + "attn_dotr": obj.attn_dotr, + "attn_mask": False, + "activation_function": obj.activation_function, + "resnet_dt": obj.resnet_dt, + "scaling_factor": obj.scaling_factor, + "normalize": obj.normalize, + "temperature": obj.temperature, + "trainable_ln": obj.trainable_ln, + "ln_eps": obj.ln_eps, + "type_one_side": obj.type_one_side, + "concat_output_tebd": self.concat_output_tebd, + "use_econf_tebd": self.use_econf_tebd, + "type_map": self.type_map, + # make deterministic + "precision": np.dtype(PRECISION_DICT[obj.precision]).name, + "embeddings": obj.embeddings.serialize(), + "embeddings_strip": obj.embeddings_strip.serialize(), + "attention_layers": obj.dpa1_attention.serialize(), + "env_mat": obj.env_mat.serialize(), + "type_embedding": self.type_embedding.serialize(), + "exclude_types": obj.exclude_types, + "env_protection": obj.env_protection, + "@variables": { + "davg": obj["davg"], + "dstd": obj["dstd"], + }, + ## to be updated when the options are supported. 
+ "trainable": self.trainable, + "spin": None, + } + return data + + @classmethod + def deserialize(cls, data: dict) -> "DescrptSeAttenV2": + """Deserialize from dict.""" + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + data.pop("type") + variables = data.pop("@variables") + embeddings = data.pop("embeddings") + type_embedding = data.pop("type_embedding") + attention_layers = data.pop("attention_layers") + data.pop("env_mat") + embeddings_strip = data.pop("embeddings_strip") + obj = cls(**data) + + obj.se_atten["davg"] = variables["davg"] + obj.se_atten["dstd"] = variables["dstd"] + obj.se_atten.embeddings = NetworkCollection.deserialize(embeddings) + obj.se_atten.embeddings_strip = NetworkCollection.deserialize(embeddings_strip) + obj.type_embedding = TypeEmbedNet.deserialize(type_embedding) + obj.se_atten.dpa1_attention = NeighborGatedAttention.deserialize( + attention_layers + ) + return obj diff --git a/deepmd/pt/model/descriptor/__init__.py b/deepmd/pt/model/descriptor/__init__.py index e5298ba3ef..b42aa98380 100644 --- a/deepmd/pt/model/descriptor/__init__.py +++ b/deepmd/pt/model/descriptor/__init__.py @@ -29,6 +29,9 @@ DescrptBlockSeA, DescrptSeA, ) +from .se_atten_v2 import ( + DescrptSeAttenV2, +) from .se_r import ( DescrptSeR, ) @@ -42,6 +45,7 @@ "make_default_type_embedding", "DescrptBlockSeA", "DescrptBlockSeAtten", + "DescrptSeAttenV2", "DescrptSeA", "DescrptSeR", "DescrptSeT", diff --git a/deepmd/pt/model/descriptor/se_atten_v2.py b/deepmd/pt/model/descriptor/se_atten_v2.py new file mode 100644 index 0000000000..3b350ded98 --- /dev/null +++ b/deepmd/pt/model/descriptor/se_atten_v2.py @@ -0,0 +1,253 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + List, + Optional, + Tuple, + Union, +) + +import torch + +from deepmd.dpmodel.utils import EnvMat as DPEnvMat +from deepmd.pt.model.descriptor.dpa1 import ( + DescrptDPA1, +) +from deepmd.pt.model.network.mlp import ( + NetworkCollection, +) +from deepmd.pt.model.network.network import ( + TypeEmbedNetConsistent, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + RESERVED_PRECISON_DICT, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + +from .base_descriptor import ( + BaseDescriptor, +) +from .se_atten import ( + NeighborGatedAttention, +) + + +@BaseDescriptor.register("se_atten_v2") +class DescrptSeAttenV2(DescrptDPA1): + def __init__( + self, + rcut: float, + rcut_smth: float, + sel: Union[List[int], int], + ntypes: int, + neuron: list = [25, 50, 100], + axis_neuron: int = 16, + tebd_dim: int = 8, + set_davg_zero: bool = True, + attn: int = 128, + attn_layer: int = 2, + attn_dotr: bool = True, + attn_mask: bool = False, + activation_function: str = "tanh", + precision: str = "float64", + resnet_dt: bool = False, + exclude_types: List[Tuple[int, int]] = [], + env_protection: float = 0.0, + scaling_factor: int = 1.0, + normalize=True, + temperature=None, + concat_output_tebd: bool = True, + trainable: bool = True, + trainable_ln: bool = True, + ln_eps: Optional[float] = 1e-5, + type_one_side: bool = False, + stripped_type_embedding: Optional[bool] = None, + seed: Optional[int] = None, + use_econf_tebd: bool = False, + type_map: Optional[List[str]] = None, + # not implemented + spin=None, + type: Optional[str] = None, + old_impl: bool = False, + ) -> None: + r"""Construct smooth version of embedding net of type `se_atten_v2`. 
+
+        Parameters
+        ----------
+        rcut : float
+            The cut-off radius :math:`r_c`
+        rcut_smth : float
+            From where the environment matrix should be smoothed :math:`r_s`
+        sel : list[int], int
+            list[int]: sel[i] specifies the maximum number of type i atoms in the cut-off radius
+            int: the total maximum number of atoms in the cut-off radius
+        ntypes : int
+            Number of element types
+        neuron : list[int]
+            Number of neurons in each hidden layer of the embedding net :math:`\mathcal{N}`
+        axis_neuron : int
+            Number of the axis neuron :math:`M_2` (number of columns of the sub-matrix of the embedding matrix)
+        tebd_dim : int
+            Dimension of the type embedding
+        set_davg_zero : bool
+            Set the shift of embedding net input to zero.
+        attn : int
+            Hidden dimension of the attention vectors
+        attn_layer : int
+            Number of attention layers
+        attn_dotr : bool
+            Whether to dot the angular gate to the attention weights
+        attn_mask : bool
+            (Only support False to keep consistent with other backend references.)
+            (Not used in this version.)
+            Whether to mask the diagonal of the attention weights
+        activation_function : str
+            The activation function in the embedding net. Supported options are |ACTIVATION_FN|
+        precision : str
+            The precision of the embedding net parameters. Supported options are |PRECISION|
+        resnet_dt : bool
+            Time-step `dt` in the resnet construction:
+            y = x + dt * \phi (Wx + b)
+        exclude_types : List[Tuple[int, int]]
+            The excluded pairs of types which have no interaction with each other.
+            For example, `[[0, 1]]` means no interaction between type 0 and type 1.
+        env_protection : float
+            Protection parameter to prevent division by zero errors during environment matrix calculations.
+        scaling_factor : float
+            The scaling factor of normalization in calculations of attention weights.
+            If `temperature` is None, the scaling of attention weights is (N_dim * scaling_factor)**0.5
+        normalize : bool
+            Whether to normalize the hidden vectors in attention weights calculation.
+        temperature : float
+            If not None, the scaling of attention weights is `temperature` itself.
+        trainable_ln : bool
+            Whether to use trainable shift and scale weights in layer normalization.
+        ln_eps : float, optional
+            The epsilon value for layer normalization.
+        type_one_side : bool
+            If 'False', type embeddings of both neighbor and central atoms are considered.
+            If 'True', only type embeddings of neighbor atoms are considered.
+            Default is 'False'.
+        seed : int, optional
+            Random seed for parameter initialization.
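+
+        Examples
+        --------
+        A minimal construction sketch (the argument values here are
+        illustrative only):
+
+        >>> descrpt = DescrptSeAttenV2(
+        ...     rcut=6.0,
+        ...     rcut_smth=0.5,
+        ...     sel=120,
+        ...     ntypes=2,
+        ... )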
+ """ + DescrptDPA1.__init__( + self, + rcut, + rcut_smth, + sel, + ntypes, + neuron=neuron, + axis_neuron=axis_neuron, + tebd_dim=tebd_dim, + tebd_input_mode="strip", + set_davg_zero=set_davg_zero, + attn=attn, + attn_layer=attn_layer, + attn_dotr=attn_dotr, + attn_mask=attn_mask, + activation_function=activation_function, + precision=precision, + resnet_dt=resnet_dt, + exclude_types=exclude_types, + env_protection=env_protection, + scaling_factor=scaling_factor, + normalize=normalize, + temperature=temperature, + concat_output_tebd=concat_output_tebd, + trainable=trainable, + trainable_ln=trainable_ln, + ln_eps=ln_eps, + smooth_type_embedding=True, + type_one_side=type_one_side, + stripped_type_embedding=stripped_type_embedding, + seed=seed, + use_econf_tebd=use_econf_tebd, + type_map=type_map, + # not implemented + spin=spin, + type=type, + old_impl=old_impl, + ) + + def serialize(self) -> dict: + obj = self.se_atten + data = { + "@class": "Descriptor", + "type": "se_atten_v2", + "@version": 1, + "rcut": obj.rcut, + "rcut_smth": obj.rcut_smth, + "sel": obj.sel, + "ntypes": obj.ntypes, + "neuron": obj.neuron, + "axis_neuron": obj.axis_neuron, + "tebd_dim": obj.tebd_dim, + "set_davg_zero": obj.set_davg_zero, + "attn": obj.attn_dim, + "attn_layer": obj.attn_layer, + "attn_dotr": obj.attn_dotr, + "attn_mask": False, + "activation_function": obj.activation_function, + "resnet_dt": obj.resnet_dt, + "scaling_factor": obj.scaling_factor, + "normalize": obj.normalize, + "temperature": obj.temperature, + "trainable_ln": obj.trainable_ln, + "ln_eps": obj.ln_eps, + "type_one_side": obj.type_one_side, + "concat_output_tebd": self.concat_output_tebd, + "use_econf_tebd": self.use_econf_tebd, + "type_map": self.type_map, + # make deterministic + "precision": RESERVED_PRECISON_DICT[obj.prec], + "embeddings": obj.filter_layers.serialize(), + "embeddings_strip": obj.filter_layers_strip.serialize(), + "attention_layers": obj.dpa1_attention.serialize(), + "env_mat": DPEnvMat(obj.rcut, obj.rcut_smth).serialize(), + "type_embedding": self.type_embedding.embedding.serialize(), + "exclude_types": obj.exclude_types, + "env_protection": obj.env_protection, + "@variables": { + "davg": obj["davg"].detach().cpu().numpy(), + "dstd": obj["dstd"].detach().cpu().numpy(), + }, + "trainable": self.trainable, + "spin": None, + } + return data + + @classmethod + def deserialize(cls, data: dict) -> "DescrptSeAttenV2": + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + data.pop("type") + variables = data.pop("@variables") + embeddings = data.pop("embeddings") + type_embedding = data.pop("type_embedding") + attention_layers = data.pop("attention_layers") + data.pop("env_mat") + embeddings_strip = data.pop("embeddings_strip") + obj = cls(**data) + + def t_cvt(xx): + return torch.tensor(xx, dtype=obj.se_atten.prec, device=env.DEVICE) + + obj.type_embedding.embedding = TypeEmbedNetConsistent.deserialize( + type_embedding + ) + obj.se_atten["davg"] = t_cvt(variables["davg"]) + obj.se_atten["dstd"] = t_cvt(variables["dstd"]) + obj.se_atten.filter_layers = NetworkCollection.deserialize(embeddings) + obj.se_atten.filter_layers_strip = NetworkCollection.deserialize( + embeddings_strip + ) + obj.se_atten.dpa1_attention = NeighborGatedAttention.deserialize( + attention_layers + ) + return obj diff --git a/deepmd/tf/descriptor/se_atten.py b/deepmd/tf/descriptor/se_atten.py index b240f00647..2bfe71fcf8 100644 --- a/deepmd/tf/descriptor/se_atten.py +++ b/deepmd/tf/descriptor/se_atten.py @@ 
-1878,12 +1878,8 @@ def serialize(self, suffix: str = "") -> dict:
         dict
             The serialized data
         """
-        if type(self) not in [DescrptSeAtten, DescrptDPA1Compat]:
-            raise NotImplementedError(
-                f"Not implemented in class {self.__class__.__name__}"
-            )
-        if self.stripped_type_embedding and type(self) is not DescrptDPA1Compat:
-            # only DescrptDPA1Compat can serialize when tebd_input_mode=='strip'
+        if self.stripped_type_embedding and type(self) is DescrptSeAtten:
+            # only DescrptDPA1Compat and DescrptSeAttenV2 can serialize when tebd_input_mode=='strip'
             raise NotImplementedError(
                 "serialization is unsupported by the native model when tebd_input_mode=='strip'"
             )
@@ -1963,8 +1959,8 @@ def serialize(self, suffix: str = "") -> dict:
         }
         if self.tebd_input_mode in ["strip"]:
             assert (
-                type(self) is DescrptDPA1Compat
-            ), "only DescrptDPA1Compat can serialize when tebd_input_mode=='strip'"
+                type(self) is not DescrptSeAtten
+            ), "only DescrptDPA1Compat and DescrptSeAttenV2 can serialize when tebd_input_mode=='strip'"
             data.update(
                 {
                     "embeddings_strip": self.serialize_network_strip(
diff --git a/deepmd/tf/descriptor/se_atten_v2.py b/deepmd/tf/descriptor/se_atten_v2.py
index 6204f27855..a4fdf24a55 100644
--- a/deepmd/tf/descriptor/se_atten_v2.py
+++ b/deepmd/tf/descriptor/se_atten_v2.py
@@ -5,6 +5,10 @@
     Optional,
 )
 
+from deepmd.utils.version import (
+    check_version_compatibility,
+)
+
 from .descriptor import (
     Descriptor,
 )
@@ -109,3 +113,68 @@ def __init__(
             smooth_type_embedding=True,
             **kwargs,
         )
+
+    @classmethod
+    def deserialize(cls, data: dict, suffix: str = ""):
+        """Deserialize the model.
+
+        Parameters
+        ----------
+        data : dict
+            The serialized data
+        suffix : str, optional
+            The suffix of the scope
+
+        Returns
+        -------
+        Model
+            The deserialized model
+        """
+        if cls is not DescrptSeAttenV2:
+            raise NotImplementedError(f"Not implemented in class {cls.__name__}")
+        data = data.copy()
+        check_version_compatibility(data.pop("@version"), 1, 1)
+        data.pop("@class")
+        data.pop("type")
+        embedding_net_variables = cls.deserialize_network(
+            data.pop("embeddings"), suffix=suffix
+        )
+        attention_layer_variables = cls.deserialize_attention_layers(
+            data.pop("attention_layers"), suffix=suffix
+        )
+        data.pop("env_mat")
+        variables = data.pop("@variables")
+        type_one_side = data["type_one_side"]
+        two_side_embeeding_net_variables = cls.deserialize_network_strip(
+            data.pop("embeddings_strip"),
+            suffix=suffix,
+            type_one_side=type_one_side,
+        )
+        descriptor = cls(**data)
+        descriptor.embedding_net_variables = embedding_net_variables
+        descriptor.attention_layer_variables = attention_layer_variables
+        descriptor.two_side_embeeding_net_variables = two_side_embeeding_net_variables
+        descriptor.davg = variables["davg"].reshape(
+            descriptor.ntypes, descriptor.ndescrpt
+        )
+        descriptor.dstd = variables["dstd"].reshape(
+            descriptor.ntypes, descriptor.ndescrpt
+        )
+        return descriptor
+
+    def serialize(self, suffix: str = "") -> dict:
+        """Serialize the model.
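+
+        The dict is obtained from `DescrptSeAtten.serialize`; the keys
+        `smooth_type_embedding` and `tebd_input_mode` are then removed, since
+        both are fixed by this class, and `"type"` is overwritten with
+        `"se_atten_v2"`.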
+
+        Parameters
+        ----------
+        suffix : str, optional
+            The suffix of the scope
+
+        Returns
+        -------
+        dict
+            The serialized data
+        """
+        data = super().serialize(suffix)
+        data.pop("smooth_type_embedding")
+        data.pop("tebd_input_mode")
+        data.update({"type": "se_atten_v2"})
+        return data
diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py
index fadec096eb..bbb203eea9 100644
--- a/deepmd/utils/argcheck.py
+++ b/deepmd/utils/argcheck.py
@@ -637,15 +637,79 @@ def descrpt_se_atten_args():
     ]
 
 
-@descrpt_args_plugin.register("se_atten_v2", doc=doc_only_tf_supported)
+@descrpt_args_plugin.register("se_atten_v2")
 def descrpt_se_atten_v2_args():
     doc_set_davg_zero = "Set the normalization average to zero. This option should be set when `se_atten` descriptor or `atom_ener` in the energy fitting is used"
+    doc_trainable_ln = (
+        "Whether to use trainable shift and scale weights in layer normalization."
+    )
+    doc_ln_eps = "The epsilon value for layer normalization. The default value for TensorFlow is set to 1e-3 to keep consistent with Keras, while it is set to 1e-5 in the PyTorch and DP implementations."
+    doc_tebd_dim = "The dimension of atom type embedding."
+    doc_use_econf_tebd = r"Whether to use electronic configuration type embedding. For TensorFlow backend, please set `use_econf_tebd` in `type_embedding` block instead."
+    doc_temperature = "The scaling factor of normalization in calculations of attention weights, which is used to scale the matmul(Q, K)."
+    doc_scaling_factor = (
+        "The scaling factor of normalization in calculations of attention weights, which is used to scale the matmul(Q, K). "
+        "If `temperature` is None, the scaling of attention weights is (N_hidden_dim * scaling_factor)**0.5. "
+        "Otherwise, the scaling of attention weights is set to `temperature`."
+    )
+    doc_normalize = (
+        "Whether to normalize the hidden vectors during attention calculation."
+    )
+    doc_concat_output_tebd = (
+        "Whether to concat type embedding at the output of the descriptor."
+    )
     return [
         *descrpt_se_atten_common_args(),
         Argument(
             "set_davg_zero", bool, optional=True, default=False, doc=doc_set_davg_zero
         ),
+        Argument(
+            "trainable_ln", bool, optional=True, default=True, doc=doc_trainable_ln
+        ),
+        Argument("ln_eps", float, optional=True, default=None, doc=doc_ln_eps),
+        # pt only
+        Argument(
+            "tebd_dim",
+            int,
+            optional=True,
+            default=8,
+            doc=doc_only_pt_supported + doc_tebd_dim,
+        ),
+        Argument(
+            "use_econf_tebd",
+            bool,
+            optional=True,
+            default=False,
+            doc=doc_only_pt_supported + doc_use_econf_tebd,
+        ),
+        Argument(
+            "scaling_factor",
+            float,
+            optional=True,
+            default=1.0,
+            doc=doc_only_pt_supported + doc_scaling_factor,
+        ),
+        Argument(
+            "normalize",
+            bool,
+            optional=True,
+            default=True,
+            doc=doc_only_pt_supported + doc_normalize,
+        ),
+        Argument(
+            "temperature",
+            float,
+            optional=True,
+            doc=doc_only_pt_supported + doc_temperature,
+        ),
+        Argument(
+            "concat_output_tebd",
+            bool,
+            optional=True,
+            default=True,
+            doc=doc_only_pt_supported + doc_concat_output_tebd,
+        ),
     ]
diff --git a/doc/model/train-se-atten.md b/doc/model/train-se-atten.md
index acd1a500a7..24950d9595 100644
--- a/doc/model/train-se-atten.md
+++ b/doc/model/train-se-atten.md
@@ -126,7 +126,7 @@ We highly recommend using the version 2.0 of the attention-based descriptor
   "set_davg_zero": false
 ```
 
-When using PyTorch backend, you must continue to use descriptor `"se_atten"` and specify `tebd_input_mode` as `"strip"` and `smooth_type_embedding` as `"true"`, which achieves the effect of `"se_atten_v2"`. The `tebd_input_mode` can take `"concat"` and `"strip"` as values. When using TensorFlow backend, you need to use descriptor `"se_atten_v2"` and do not need to set `tebd_input_mode` and `smooth_type_embedding` because the default value of `tebd_input_mode` is `"strip"`, and the default value of `smooth_type_embedding` is `"true"` in TensorFlow backend. When `tebd_input_mode` is set to `"strip"`, the embedding matrix $\mathcal{G}^i$ is constructed as:
+You can use the descriptor `"se_atten_v2"` directly and do not need to set `tebd_input_mode` or `smooth_type_embedding`: in `"se_atten_v2"`, `tebd_input_mode` is fixed to `"strip"` and `smooth_type_embedding` is fixed to `"true"`. When `tebd_input_mode` is `"strip"`, the embedding matrix $\mathcal{G}^i$ is constructed as:
 
 ```math
 (\mathcal{G}^i)_j = \mathcal{N}_{e,2}(s(r_{ij})) + \mathcal{N}_{e,2}(s(r_{ij})) \odot (\mathcal{N}_{e,2}(\{\mathcal{A}^i, \mathcal{A}^j\}) \odot s(r_{ij})) \quad \mathrm{or}
diff --git a/source/tests/consistent/descriptor/test_se_atten_v2.py b/source/tests/consistent/descriptor/test_se_atten_v2.py
new file mode 100644
index 0000000000..54f3cb5826
--- /dev/null
+++ b/source/tests/consistent/descriptor/test_se_atten_v2.py
@@ -0,0 +1,300 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+import unittest
+from typing import (
+    Any,
+    Optional,
+    Tuple,
+)
+
+import numpy as np
+from dargs import (
+    Argument,
+)
+
+from deepmd.dpmodel.descriptor.se_atten_v2 import DescrptSeAttenV2 as DescrptSeAttenV2DP
+from deepmd.env import (
+    GLOBAL_NP_FLOAT_PRECISION,
+)
+
+from ..common import (
+    INSTALLED_PT,
+    CommonTest,
+    parameterized,
+)
+from .common import (
+    DescriptorTest,
+)
+
+if INSTALLED_PT:
+    from deepmd.pt.model.descriptor.se_atten_v2 import (
+        DescrptSeAttenV2 as DescrptSeAttenV2PT,
+    )
+else:
+    DescrptSeAttenV2PT = None
+DescrptSeAttenV2TF = None
+from deepmd.utils.argcheck import (
+    descrpt_se_atten_args,
+)
+
+
+@parameterized(
+    (4,),  # tebd_dim
+    (True,),  # resnet_dt
+    (True, False),  # type_one_side
+    (20,),  # attn
+    (0, 2),  # attn_layer
+    (True, False),  # attn_dotr
+    ([], [[0, 1]]),  # excluded_types
+    (0.0,),  # env_protection
+    (True, False),  # set_davg_zero
+    (1.0,),  # scaling_factor
+    (True, False),  # normalize
+    (None, 1.0),  # temperature
+    (1e-5,),  # ln_eps
+    (True,),  # concat_output_tebd
+    ("float64",),  # precision
+    (True, False),  # use_econf_tebd
+)
+class TestSeAttenV2(CommonTest, DescriptorTest, unittest.TestCase):
+    @property
+    def data(self) -> dict:
+        (
+            tebd_dim,
+            resnet_dt,
+            type_one_side,
+            attn,
+            attn_layer,
+            attn_dotr,
+            excluded_types,
+            env_protection,
+            set_davg_zero,
+            scaling_factor,
+            normalize,
+            temperature,
+            ln_eps,
+            concat_output_tebd,
+            precision,
+            use_econf_tebd,
+        ) = self.param
+        return {
+            "sel": [10],
+            "rcut_smth": 5.80,
+            "rcut": 6.00,
+            "neuron": [6, 12, 24],
+            "ntypes": self.ntypes,
+            "axis_neuron": 3,
+            "tebd_dim": tebd_dim,
+            "attn": attn,
+            "attn_layer": attn_layer,
+            "attn_dotr": attn_dotr,
+            "attn_mask": False,
+            "scaling_factor": scaling_factor,
+            "normalize": normalize,
+            "temperature": temperature,
+            "ln_eps": ln_eps,
+            "concat_output_tebd": concat_output_tebd,
+            "resnet_dt": resnet_dt,
+            "type_one_side": type_one_side,
+            "exclude_types": excluded_types,
+            "env_protection": env_protection,
+            "precision": precision,
+            "set_davg_zero": set_davg_zero,
+            "use_econf_tebd": use_econf_tebd,
+            "type_map": ["O", "H"] if use_econf_tebd else None,
+            "seed": 1145141919810,
+        }
+
+    def is_meaningless_zero_attention_layer_tests(
+        self,
+        attn_layer: int,
+        attn_dotr: bool,
+        normalize: bool,
+        temperature: Optional[float],
+    ) -> bool:
+        return attn_layer == 0 and (attn_dotr or normalize or temperature is not None)
+
+    @property
+    def skip_pt(self) -> bool:
+        (
+            tebd_dim,
+            resnet_dt,
+            type_one_side,
+            attn,
+            attn_layer,
+            attn_dotr,
+            excluded_types,
+            env_protection,
+            set_davg_zero,
+            scaling_factor,
+            normalize,
+            temperature,
+            ln_eps,
+            concat_output_tebd,
+            precision,
+            use_econf_tebd,
+        ) = self.param
+        return CommonTest.skip_pt or self.is_meaningless_zero_attention_layer_tests(
+            attn_layer,
+            attn_dotr,
+            normalize,
+            temperature,
+        )
+
+    @property
+    def skip_dp(self) -> bool:
+        (
+            tebd_dim,
+            resnet_dt,
+            type_one_side,
+            attn,
+            attn_layer,
+            attn_dotr,
+            excluded_types,
+            env_protection,
+            set_davg_zero,
+            scaling_factor,
+            normalize,
+            temperature,
+            ln_eps,
+            concat_output_tebd,
+            precision,
+            use_econf_tebd,
+        ) = self.param
+        return CommonTest.skip_dp or self.is_meaningless_zero_attention_layer_tests(
+            attn_layer,
+            attn_dotr,
+            normalize,
+            temperature,
+        )
+
+    @property
+    def skip_tf(self) -> bool:
+        return True
+
+    tf_class = DescrptSeAttenV2TF
+    dp_class = DescrptSeAttenV2DP
+    pt_class = DescrptSeAttenV2PT
+    args = descrpt_se_atten_args().append(Argument("ntypes", int, optional=False))
+
+    def setUp(self):
+        CommonTest.setUp(self)
+
+        self.ntypes = 2
+        self.coords = np.array(
+            [
+                12.83,
+                2.56,
+                2.18,
+                12.09,
+                2.87,
+                2.74,
+                00.25,
+                3.32,
+                1.68,
+                3.36,
+                3.00,
+                1.81,
+                3.51,
+                2.51,
+                2.60,
+                4.27,
+                3.22,
+                1.56,
+            ],
+            dtype=GLOBAL_NP_FLOAT_PRECISION,
+        )
+        self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32)
+        self.box = np.array(
+            [13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0],
+            dtype=GLOBAL_NP_FLOAT_PRECISION,
+        )
+        self.natoms = np.array([6, 6, 2, 4], dtype=np.int32)
+
+    def build_tf(self, obj: Any, suffix: str) -> Tuple[list, dict]:
+        return self.build_tf_descriptor(
+            obj,
+            self.natoms,
+            self.coords,
+            self.atype,
+            self.box,
+            suffix,
+        )
+
+    def eval_dp(self, dp_obj: Any) -> Any:
+        return self.eval_dp_descriptor(
+            dp_obj,
+            self.natoms,
+            self.coords,
+            self.atype,
+            self.box,
+            mixed_types=True,
+        )
+
+    def eval_pt(self, pt_obj: Any) -> Any:
+        return self.eval_pt_descriptor(
+            pt_obj,
+            self.natoms,
+            self.coords,
+            self.atype,
+            self.box,
+            mixed_types=True,
+        )
+
+    def extract_ret(self, ret: Any, backend) -> Tuple[np.ndarray, ...]:
+        return (ret[0],)
+
+    @property
+    def rtol(self) -> float:
+        """Relative tolerance for comparing the return value."""
+        (
+            tebd_dim,
+            resnet_dt,
+            type_one_side,
+            attn,
+            attn_layer,
+            attn_dotr,
+            excluded_types,
+            env_protection,
+            set_davg_zero,
+            scaling_factor,
+            normalize,
+            temperature,
+            ln_eps,
+            concat_output_tebd,
+            precision,
+            use_econf_tebd,
+        ) = self.param
+        if precision == "float64":
+            return 1e-10
+        elif precision == "float32":
+            return 1e-4
+        else:
+            raise ValueError(f"Unknown precision: {precision}")
+
+    @property
+    def atol(self) -> float:
+        """Absolute tolerance for comparing the return value."""
+        (
+            tebd_dim,
+            resnet_dt,
+            type_one_side,
+            attn,
+            attn_layer,
+            attn_dotr,
+            excluded_types,
+            env_protection,
+            set_davg_zero,
+            scaling_factor,
+            normalize,
+            temperature,
+            ln_eps,
+            concat_output_tebd,
+            precision,
+            use_econf_tebd,
+        ) = self.param
+        if precision == "float64":
+            return 1e-10
+        elif precision == "float32":
+            return 1e-4
+        else:
+            raise ValueError(f"Unknown precision: {precision}")
diff --git a/source/tests/pt/model/test_se_atten_v2.py b/source/tests/pt/model/test_se_atten_v2.py
new file mode 100644
index 0000000000..caecd0a118
--- /dev/null +++ b/source/tests/pt/model/test_se_atten_v2.py @@ -0,0 +1,141 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import itertools +import unittest + +import numpy as np +import torch + +from deepmd.dpmodel.descriptor.se_atten_v2 import DescrptSeAttenV2 as DPDescrptSeAttenV2 +from deepmd.pt.model.descriptor.se_atten_v2 import ( + DescrptSeAttenV2, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + PRECISION_DICT, +) + +from .test_env_mat import ( + TestCaseSingleFrameWithNlist, +) +from .test_mlp import ( + get_tols, +) + +dtype = env.GLOBAL_PT_FLOAT_PRECISION + + +class TestDescrptSeAttenV2(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self): + TestCaseSingleFrameWithNlist.setUp(self) + + def test_consistency( + self, + ): + rng = np.random.default_rng(100) + nf, nloc, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + for idt, to, prec, ect in itertools.product( + [False, True], # resnet_dt + [False, True], # type_one_side + [ + "float64", + ], # precision + [False, True], # use_econf_tebd + ): + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + err_msg = f"idt={idt} prec={prec}" + + # dpa1 new impl + dd0 = DescrptSeAttenV2( + self.rcut, + self.rcut_smth, + self.sel_mix, + self.nt, + attn_layer=2, + precision=prec, + resnet_dt=idt, + type_one_side=to, + use_econf_tebd=ect, + type_map=["O", "H"] if ect else None, + old_impl=False, + ).to(env.DEVICE) + dd0.se_atten.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE) + dd0.se_atten.stddev = torch.tensor(dstd, dtype=dtype, device=env.DEVICE) + rd0, _, _, _, _ = dd0( + torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), + torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), + torch.tensor(self.nlist, dtype=int, device=env.DEVICE), + ) + # serialization + dd1 = DescrptSeAttenV2.deserialize(dd0.serialize()) + rd1, _, _, _, _ = dd1( + torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), + torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), + torch.tensor(self.nlist, dtype=int, device=env.DEVICE), + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd1.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + # dp impl + dd2 = DPDescrptSeAttenV2.deserialize(dd0.serialize()) + rd2, _, _, _, _ = dd2.call( + self.coord_ext, + self.atype_ext, + self.nlist, + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd2, + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + + def test_jit( + self, + ): + rng = np.random.default_rng() + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + for idt, prec, to, ect in itertools.product( + [ + False, + ], # resnet_dt + [ + "float64", + ], # precision + [ + False, + ], # type_one_side + [False, True], # use_econf_tebd + ): + dtype = PRECISION_DICT[prec] + # dpa1 new impl + dd0 = DescrptSeAttenV2( + self.rcut, + self.rcut_smth, + self.sel, + self.nt, + precision=prec, + resnet_dt=idt, + type_one_side=to, + use_econf_tebd=ect, + type_map=["O", "H"] if ect else None, + old_impl=False, + ) + dd0.se_atten.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE) + dd0.se_atten.dstd = torch.tensor(dstd, dtype=dtype, device=env.DEVICE) + _ = torch.jit.script(dd0)
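
For reviewers, a minimal round-trip sketch of the new descriptor (values are illustrative; it assumes a deepmd-kit build that contains this patch, with the native `deepmd.dpmodel` backend importable):

```python
from deepmd.dpmodel.descriptor.se_atten_v2 import DescrptSeAttenV2

# Construct the native (DP) descriptor registered above under the
# "se_atten_v2" key; `sel` may be an int or a per-type list, as in DescrptDPA1.
descrpt = DescrptSeAttenV2(
    rcut=6.0,
    rcut_smth=0.5,
    sel=[10],
    ntypes=2,
    attn_layer=2,
)

# Round-trip through the new serialize/deserialize pair, as the consistency
# tests above do; the serialized dict carries the registered type tag.
data = descrpt.serialize()
assert data["type"] == "se_atten_v2"
descrpt2 = DescrptSeAttenV2.deserialize(data)
```

In a training input, the same class is selected by putting `"type": "se_atten_v2"` in the `descriptor` block; with this patch the key resolves through `@BaseDescriptor.register("se_atten_v2")` for the DP and PyTorch backends instead of being TensorFlow-only.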