diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 38f5abf616..9058decc21 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -29,7 +29,7 @@ repos: exclude: ^source/3rdparty - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.8.2 + rev: v0.8.3 hooks: - id: ruff args: ["--fix"] @@ -60,7 +60,7 @@ repos: - id: blacken-docs # C++ - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v19.1.4 + rev: v19.1.5 hooks: - id: clang-format exclude: ^(source/3rdparty|source/lib/src/gpu/cudart/.+\.inc|.+\.ipynb$) diff --git a/deepmd/pd/entrypoints/main.py b/deepmd/pd/entrypoints/main.py index 19653d6ea7..3fa66312e7 100644 --- a/deepmd/pd/entrypoints/main.py +++ b/deepmd/pd/entrypoints/main.py @@ -230,7 +230,7 @@ def train( use_pretrain_script: bool = False, force_load: bool = False, output: str = "out.json", -): +) -> None: log.info("Configuration path: %s", input_file) SummaryPrinter()() with open(input_file) as fin: @@ -321,10 +321,18 @@ def train( # save min_nbor_dist if min_nbor_dist is not None: if not multi_task: - trainer.model.min_nbor_dist = min_nbor_dist + trainer.model.min_nbor_dist = paddle.to_tensor( + min_nbor_dist, + dtype=paddle.float64, + place=DEVICE, + ) else: for model_item in min_nbor_dist: - trainer.model[model_item].min_nbor_dist = min_nbor_dist[model_item] + trainer.model[model_item].min_nbor_dist = paddle.to_tensor( + min_nbor_dist[model_item], + dtype=paddle.float64, + place=DEVICE, + ) trainer.run() @@ -332,7 +340,7 @@ def freeze( model: str, output: str = "frozen_model.json", head: Optional[str] = None, -): +) -> None: paddle.set_flags( { "FLAGS_save_cf_stack_op": 1, @@ -383,7 +391,7 @@ def change_bias( numb_batch: int = 0, model_branch: Optional[str] = None, output: Optional[str] = None, -): +) -> None: if input_file.endswith(".pd"): old_state_dict = paddle.load(input_file) model_state_dict = copy.deepcopy(old_state_dict.get("model", old_state_dict)) diff --git a/deepmd/pd/loss/ener.py b/deepmd/pd/loss/ener.py index 7c5d848b45..73ad53601a 100644 --- a/deepmd/pd/loss/ener.py +++ b/deepmd/pd/loss/ener.py @@ -10,7 +10,6 @@ TaskLoss, ) from deepmd.pd.utils import ( - decomp, env, ) from deepmd.pd.utils.env import ( @@ -224,10 +223,7 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False): if self.relative_f is not None: force_label_3 = force_label.reshape([-1, 3]) - # norm_f = force_label_3.norm(axis=1, keepdim=True) + self.relative_f - norm_f = ( - decomp.norm(force_label_3, axis=1, keepdim=True) + self.relative_f - ) + norm_f = force_label_3.norm(axis=1, keepdim=True) + self.relative_f diff_f_3 = diff_f.reshape([-1, 3]) diff_f_3 = diff_f_3 / norm_f diff_f = diff_f_3.reshape([-1]) diff --git a/deepmd/pd/model/atomic_model/dp_atomic_model.py b/deepmd/pd/model/atomic_model/dp_atomic_model.py index 25a0f89d77..1089b93a68 100644 --- a/deepmd/pd/model/atomic_model/dp_atomic_model.py +++ b/deepmd/pd/model/atomic_model/dp_atomic_model.py @@ -1,5 +1,4 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import copy import functools import logging from typing import ( @@ -52,7 +51,7 @@ def __init__( fitting, type_map: list[str], **kwargs, - ): + ) -> None: super().__init__(type_map, **kwargs) ntypes = len(type_map) self.type_map = type_map @@ -201,7 +200,7 @@ def serialize(self) -> dict: @classmethod def deserialize(cls, data) -> "DPAtomicModel": - data = copy.deepcopy(data) + data = data.copy() check_version_compatibility(data.pop("@version", 1), 2, 1) data.pop("@class", None) data.pop("type", None) @@ -212,6 +211,37 @@ def deserialize(cls, data) -> "DPAtomicModel": obj = super().deserialize(data) return obj + def enable_compression( + self, + min_nbor_dist: float, + table_extrapolate: float = 5, + table_stride_1: float = 0.01, + table_stride_2: float = 0.1, + check_frequency: int = -1, + ) -> None: + """Call descriptor enable_compression(). + + Parameters + ---------- + min_nbor_dist + The nearest distance between atoms + table_extrapolate + The scale of model extrapolation + table_stride_1 + The uniform stride of the first table + table_stride_2 + The uniform stride of the second table + check_frequency + The overflow check frequency + """ + self.descriptor.enable_compression( + min_nbor_dist, + table_extrapolate, + table_stride_1, + table_stride_2, + check_frequency, + ) + def forward_atomic( self, extended_coord, @@ -278,7 +308,7 @@ def compute_or_load_stat( self, sampled_func, stat_file_path: Optional[DPPath] = None, - ): + ) -> None: """ Compute or load the statistics parameters of the model, such as mean and standard deviation of descriptors or the energy bias of the fitting net. diff --git a/deepmd/pd/model/descriptor/__init__.py b/deepmd/pd/model/descriptor/__init__.py index 654643959b..7eaa0df85b 100644 --- a/deepmd/pd/model/descriptor/__init__.py +++ b/deepmd/pd/model/descriptor/__init__.py @@ -5,6 +5,10 @@ from .descriptor import ( DescriptorBlock, ) +from .dpa1 import ( + DescrptBlockSeAtten, + DescrptDPA1, +) from .env_mat import ( prod_env_mat, ) @@ -17,6 +21,8 @@ "BaseDescriptor", "DescriptorBlock", "DescrptBlockSeA", + "DescrptBlockSeAtten", + "DescrptDPA1", "DescrptSeA", "prod_env_mat", ] diff --git a/deepmd/pd/model/descriptor/dpa1.py b/deepmd/pd/model/descriptor/dpa1.py new file mode 100644 index 0000000000..f3f1ea26d6 --- /dev/null +++ b/deepmd/pd/model/descriptor/dpa1.py @@ -0,0 +1,689 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Callable, + Optional, + Union, +) + +import paddle + +from deepmd.dpmodel.utils import EnvMat as DPEnvMat +from deepmd.dpmodel.utils.seed import ( + child_seed, +) +from deepmd.pd.model.network.mlp import ( + NetworkCollection, +) +from deepmd.pd.model.network.network import ( + TypeEmbedNet, + TypeEmbedNetConsistent, +) +from deepmd.pd.utils import ( + env, +) +from deepmd.pd.utils.env import ( + PRECISION_DICT, + RESERVED_PRECISON_DICT, +) +from deepmd.pd.utils.update_sel import ( + UpdateSel, +) +from deepmd.utils.data_system import ( + DeepmdDataSystem, +) +from deepmd.utils.finetune import ( + get_index_between_two_maps, + map_pair_exclude_types, +) +from deepmd.utils.path import ( + DPPath, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + +from .base_descriptor import ( + BaseDescriptor, +) +from .descriptor import ( + extend_descrpt_stat, +) +from .se_atten import ( + DescrptBlockSeAtten, + NeighborGatedAttention, +) + + +@BaseDescriptor.register("dpa1") +@BaseDescriptor.register("se_atten") +class DescrptDPA1(BaseDescriptor, paddle.nn.Layer): + r"""Attention-based descriptor which is proposed in the pretrainable DPA-1[1] model. + + This descriptor, :math:`\mathcal{D}^i \in \mathbb{R}^{M \times M_{<}}`, is given by + + .. math:: + \mathcal{D}^i = \frac{1}{N_c^2}(\hat{\mathcal{G}}^i)^T \mathcal{R}^i (\mathcal{R}^i)^T \hat{\mathcal{G}}^i_<, + + where :math:`\hat{\mathcal{G}}^i` represents the embedding matrix:math:`\mathcal{G}^i` + after additional self-attention mechanism and :math:`\mathcal{R}^i` is defined by the full case in the se_e2_a descriptor. + Note that we obtain :math:`\mathcal{G}^i` using the type embedding method by default in this descriptor. + + To perform the self-attention mechanism, the queries :math:`\mathcal{Q}^{i,l} \in \mathbb{R}^{N_c\times d_k}`, + keys :math:`\mathcal{K}^{i,l} \in \mathbb{R}^{N_c\times d_k}`, + and values :math:`\mathcal{V}^{i,l} \in \mathbb{R}^{N_c\times d_v}` are first obtained: + + .. math:: + \left(\mathcal{Q}^{i,l}\right)_{j}=Q_{l}\left(\left(\mathcal{G}^{i,l-1}\right)_{j}\right), + + .. math:: + \left(\mathcal{K}^{i,l}\right)_{j}=K_{l}\left(\left(\mathcal{G}^{i,l-1}\right)_{j}\right), + + .. math:: + \left(\mathcal{V}^{i,l}\right)_{j}=V_{l}\left(\left(\mathcal{G}^{i,l-1}\right)_{j}\right), + + where :math:`Q_{l}`, :math:`K_{l}`, :math:`V_{l}` represent three trainable linear transformations + that output the queries and keys of dimension :math:`d_k` and values of dimension :math:`d_v`, and :math:`l` + is the index of the attention layer. + The input embedding matrix to the attention layers, denoted by :math:`\mathcal{G}^{i,0}`, + is chosen as the two-body embedding matrix. + + Then the scaled dot-product attention method is adopted: + + .. math:: + A(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l}, \mathcal{V}^{i,l}, \mathcal{R}^{i,l})=\varphi\left(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l},\mathcal{R}^{i,l}\right)\mathcal{V}^{i,l}, + + where :math:`\varphi\left(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l},\mathcal{R}^{i,l}\right) \in \mathbb{R}^{N_c\times N_c}` is attention weights. + In the original attention method, + one typically has :math:`\varphi\left(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l}\right)=\mathrm{softmax}\left(\frac{\mathcal{Q}^{i,l} (\mathcal{K}^{i,l})^{T}}{\sqrt{d_{k}}}\right)`, + with :math:`\sqrt{d_{k}}` being the normalization temperature. + This is slightly modified to incorporate the angular information: + + .. math:: + \varphi\left(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l},\mathcal{R}^{i,l}\right) = \mathrm{softmax}\left(\frac{\mathcal{Q}^{i,l} (\mathcal{K}^{i,l})^{T}}{\sqrt{d_{k}}}\right) \odot \hat{\mathcal{R}}^{i}(\hat{\mathcal{R}}^{i})^{T}, + + where :math:`\hat{\mathcal{R}}^{i} \in \mathbb{R}^{N_c\times 3}` denotes normalized relative coordinates, + :math:`\hat{\mathcal{R}}^{i}_{j} = \frac{\boldsymbol{r}_{ij}}{\lVert \boldsymbol{r}_{ij} \lVert}` + and :math:`\odot` means element-wise multiplication. + + Then layer normalization is added in a residual way to finally obtain the self-attention local embedding matrix + :math:`\hat{\mathcal{G}}^{i} = \mathcal{G}^{i,L_a}` after :math:`L_a` attention layers:[^1] + + .. math:: + \mathcal{G}^{i,l} = \mathcal{G}^{i,l-1} + \mathrm{LayerNorm}(A(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l}, \mathcal{V}^{i,l}, \mathcal{R}^{i,l})). + + Parameters + ---------- + rcut: float + The cut-off radius :math:`r_c` + rcut_smth: float + From where the environment matrix should be smoothed :math:`r_s` + sel : list[int], int + list[int]: sel[i] specifies the maxmum number of type i atoms in the cut-off radius + int: the total maxmum number of atoms in the cut-off radius + ntypes : int + Number of element types + neuron : list[int] + Number of neurons in each hidden layers of the embedding net :math:`\mathcal{N}` + axis_neuron: int + Number of the axis neuron :math:`M_2` (number of columns of the sub-matrix of the embedding matrix) + tebd_dim: int + Dimension of the type embedding + tebd_input_mode: str + The input mode of the type embedding. Supported modes are ["concat", "strip"]. + - "concat": Concatenate the type embedding with the smoothed radial information as the union input for the embedding network. + - "strip": Use a separated embedding network for the type embedding and combine the output with the radial embedding network output. + resnet_dt: bool + Time-step `dt` in the resnet construction: + y = x + dt * \phi (Wx + b) + trainable: bool + If the weights of this descriptors are trainable. + trainable_ln: bool + Whether to use trainable shift and scale weights in layer normalization. + ln_eps: float, Optional + The epsilon value for layer normalization. + type_one_side: bool + If 'False', type embeddings of both neighbor and central atoms are considered. + If 'True', only type embeddings of neighbor atoms are considered. + Default is 'False'. + attn: int + Hidden dimension of the attention vectors + attn_layer: int + Number of attention layers + attn_dotr: bool + If dot the angular gate to the attention weights + attn_mask: bool + (Only support False to keep consistent with other backend references.) + (Not used in this version. True option is not implemented.) + If mask the diagonal of attention weights + exclude_types : list[list[int]] + The excluded pairs of types which have no interaction with each other. + For example, `[[0, 1]]` means no interaction between type 0 and type 1. + env_protection: float + Protection parameter to prevent division by zero errors during environment matrix calculations. + set_davg_zero: bool + Set the shift of embedding net input to zero. + activation_function: str + The activation function in the embedding net. Supported options are |ACTIVATION_FN| + precision: str + The precision of the embedding net parameters. Supported options are |PRECISION| + scaling_factor: float + The scaling factor of normalization in calculations of attention weights. + If `temperature` is None, the scaling of attention weights is (N_dim * scaling_factor)**0.5 + normalize: bool + Whether to normalize the hidden vectors in attention weights calculation. + temperature: float + If not None, the scaling of attention weights is `temperature` itself. + smooth_type_embedding: bool + Whether to use smooth process in attention weights calculation. + concat_output_tebd: bool + Whether to concat type embedding at the output of the descriptor. + stripped_type_embedding: bool, Optional + (Deprecated, kept only for compatibility.) + Whether to strip the type embedding into a separate embedding network. + Setting this parameter to `True` is equivalent to setting `tebd_input_mode` to 'strip'. + Setting it to `False` is equivalent to setting `tebd_input_mode` to 'concat'. + The default value is `None`, which means the `tebd_input_mode` setting will be used instead. + seed: int, Optional + Random seed for parameter initialization. + use_econf_tebd: bool, Optional + Whether to use electronic configuration type embedding. + use_tebd_bias : bool, Optional + Whether to use bias in the type embedding layer. + type_map: list[str], Optional + A list of strings. Give the name to each type of atoms. + spin + (Only support None to keep consistent with other backend references.) + (Not used in this version. Not-none option is not implemented.) + The old implementation of deepspin. + + Limitations + ----------- + The currently implementation will not support the following deprecated features + 1. spin is not None + 2. attn_mask == True + + References + ---------- + .. [1] Duo Zhang, Hangrui Bi, Fu-Zhi Dai, Wanrun Jiang, Linfeng Zhang, and Han Wang. 2022. + DPA-1: Pretraining of Attention-based Deep Potential Model for Molecular Simulation. + arXiv preprint arXiv:2208.08236. + """ + + def __init__( + self, + rcut: float, + rcut_smth: float, + sel: Union[list[int], int], + ntypes: int, + neuron: list = [25, 50, 100], + axis_neuron: int = 16, + tebd_dim: int = 8, + tebd_input_mode: str = "concat", + set_davg_zero: bool = True, + attn: int = 128, + attn_layer: int = 2, + attn_dotr: bool = True, + attn_mask: bool = False, + activation_function: str = "tanh", + precision: str = "float64", + resnet_dt: bool = False, + exclude_types: list[tuple[int, int]] = [], + env_protection: float = 0.0, + scaling_factor: int = 1.0, + normalize=True, + temperature=None, + concat_output_tebd: bool = True, + trainable: bool = True, + trainable_ln: bool = True, + ln_eps: Optional[float] = 1e-5, + smooth_type_embedding: bool = True, + type_one_side: bool = False, + stripped_type_embedding: Optional[bool] = None, + seed: Optional[Union[int, list[int]]] = None, + use_econf_tebd: bool = False, + use_tebd_bias: bool = False, + type_map: Optional[list[str]] = None, + # not implemented + spin=None, + type: Optional[str] = None, + ) -> None: + super().__init__() + # Ensure compatibility with the deprecated stripped_type_embedding option. + if stripped_type_embedding is not None: + # Use the user-set stripped_type_embedding parameter first + tebd_input_mode = "strip" if stripped_type_embedding else "concat" + if spin is not None: + raise NotImplementedError("old implementation of spin is not supported.") + if attn_mask: + raise NotImplementedError( + "old implementation of attn_mask is not supported." + ) + # to keep consistent with default value in this backends + if ln_eps is None: + ln_eps = 1e-5 + + self.tebd_input_mode = tebd_input_mode + + del type, spin, attn_mask + self.se_atten = DescrptBlockSeAtten( + rcut, + rcut_smth, + sel, + ntypes, + neuron=neuron, + axis_neuron=axis_neuron, + tebd_dim=tebd_dim, + tebd_input_mode=tebd_input_mode, + set_davg_zero=set_davg_zero, + attn=attn, + attn_layer=attn_layer, + attn_dotr=attn_dotr, + attn_mask=False, + activation_function=activation_function, + precision=precision, + resnet_dt=resnet_dt, + scaling_factor=scaling_factor, + normalize=normalize, + temperature=temperature, + smooth=smooth_type_embedding, + type_one_side=type_one_side, + exclude_types=exclude_types, + env_protection=env_protection, + trainable_ln=trainable_ln, + ln_eps=ln_eps, + seed=child_seed(seed, 1), + ) + self.use_econf_tebd = use_econf_tebd + self.use_tebd_bias = use_tebd_bias + self.type_map = type_map + self.compress = False + self.type_embedding = TypeEmbedNet( + ntypes, + tebd_dim, + precision=precision, + seed=child_seed(seed, 2), + use_econf_tebd=use_econf_tebd, + use_tebd_bias=use_tebd_bias, + type_map=type_map, + ) + self.prec = PRECISION_DICT[precision] + self.tebd_dim = tebd_dim + self.concat_output_tebd = concat_output_tebd + self.trainable = trainable + # set trainable + for param in self.parameters(): + param.stop_gradient = not trainable + + def get_rcut(self) -> float: + """Returns the cut-off radius.""" + return self.se_atten.get_rcut() + + def get_rcut_smth(self) -> float: + """Returns the radius where the neighbor information starts to smoothly decay to 0.""" + return self.se_atten.get_rcut_smth() + + def get_nsel(self) -> int: + """Returns the number of selected atoms in the cut-off radius.""" + return self.se_atten.get_nsel() + + def get_sel(self) -> list[int]: + """Returns the number of selected atoms for each type.""" + return self.se_atten.get_sel() + + def get_ntypes(self) -> int: + """Returns the number of element types.""" + return self.se_atten.get_ntypes() + + def get_type_map(self) -> list[str]: + """Get the name to each type of atoms.""" + return self.type_map + + def get_dim_out(self) -> int: + """Returns the output dimension.""" + ret = self.se_atten.get_dim_out() + if self.concat_output_tebd: + ret += self.tebd_dim + return ret + + def get_dim_emb(self) -> int: + return self.se_atten.dim_emb + + def mixed_types(self) -> bool: + """If true, the descriptor + 1. assumes total number of atoms aligned across frames; + 2. requires a neighbor list that does not distinguish different atomic types. + + If false, the descriptor + 1. assumes total number of atoms of each atom type aligned across frames; + 2. requires a neighbor list that distinguishes different atomic types. + + """ + return self.se_atten.mixed_types() + + def has_message_passing(self) -> bool: + """Returns whether the descriptor has message passing.""" + return self.se_atten.has_message_passing() + + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the descriptor needs sorted nlist when using `forward_lower`.""" + return self.se_atten.need_sorted_nlist_for_lower() + + def get_env_protection(self) -> float: + """Returns the protection of building environment matrix.""" + return self.se_atten.get_env_protection() + + def share_params(self, base_class, shared_level, resume=False) -> None: + """ + Share the parameters of self to the base_class with shared_level during multitask training. + If not start from checkpoint (resume is False), + some separated parameters (e.g. mean and stddev) will be re-calculated across different classes. + """ + assert ( + self.__class__ == base_class.__class__ + ), "Only descriptors of the same type can share params!" + # For DPA1 descriptors, the user-defined share-level + # shared_level: 0 + # share all parameters in both type_embedding and se_atten + if shared_level == 0: + self._sub_layers["type_embedding"] = base_class._sub_layers[ + "type_embedding" + ] + self.se_atten.share_params(base_class.se_atten, 0, resume=resume) + # shared_level: 1 + # share all parameters in type_embedding + elif shared_level == 1: + self._sub_layers["type_embedding"] = base_class._sub_layers[ + "type_embedding" + ] + # Other shared levels + else: + raise NotImplementedError + + @property + def dim_out(self): + return self.get_dim_out() + + @property + def dim_emb(self): + return self.get_dim_emb() + + def compute_input_stats( + self, + merged: Union[Callable[[], list[dict]], list[dict]], + path: Optional[DPPath] = None, + ): + """ + Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data. + + Parameters + ---------- + merged : Union[Callable[[], list[dict]], list[dict]] + - list[dict]: A list of data samples from various data systems. + Each element, `merged[i]`, is a data dictionary containing `keys`: `paddle.Tensor` + originating from the `i`-th data system. + - Callable[[], list[dict]]: A lazy function that returns data samples in the above format + only when needed. Since the sampling process can be slow and memory-intensive, + the lazy function helps by only sampling once. + path : Optional[DPPath] + The path to the stat file. + + """ + return self.se_atten.compute_input_stats(merged, path) + + def set_stat_mean_and_stddev( + self, + mean: paddle.Tensor, + stddev: paddle.Tensor, + ) -> None: + """Update mean and stddev for descriptor.""" + self.se_atten.mean = mean + self.se_atten.stddev = stddev + + def get_stat_mean_and_stddev(self) -> tuple[paddle.Tensor, paddle.Tensor]: + """Get mean and stddev for descriptor.""" + return self.se_atten.mean, self.se_atten.stddev + + def change_type_map( + self, type_map: list[str], model_with_new_type_stat=None + ) -> None: + """Change the type related params to new ones, according to `type_map` and the original one in the model. + If there are new types in `type_map`, statistics will be updated accordingly to `model_with_new_type_stat` for these new types. + """ + assert ( + self.type_map is not None + ), "'type_map' must be defined when performing type changing!" + remap_index, has_new_type = get_index_between_two_maps(self.type_map, type_map) + obj = self.se_atten + obj.ntypes = len(type_map) + self.type_map = type_map + self.type_embedding.change_type_map(type_map=type_map) + obj.reinit_exclude(map_pair_exclude_types(obj.exclude_types, remap_index)) + if has_new_type: + # the avg and std of new types need to be updated + extend_descrpt_stat( + obj, + type_map, + des_with_stat=model_with_new_type_stat.se_atten + if model_with_new_type_stat is not None + else None, + ) + obj["davg"] = obj["davg"][remap_index] + obj["dstd"] = obj["dstd"][remap_index] + + def serialize(self) -> dict: + obj = self.se_atten + data = { + "@class": "Descriptor", + "type": "dpa1", + "@version": 2, + "rcut": obj.rcut, + "rcut_smth": obj.rcut_smth, + "sel": obj.sel, + "ntypes": obj.ntypes, + "neuron": obj.neuron, + "axis_neuron": obj.axis_neuron, + "tebd_dim": obj.tebd_dim, + "tebd_input_mode": obj.tebd_input_mode, + "set_davg_zero": obj.set_davg_zero, + "attn": obj.attn_dim, + "attn_layer": obj.attn_layer, + "attn_dotr": obj.attn_dotr, + "attn_mask": False, + "activation_function": obj.activation_function, + "resnet_dt": obj.resnet_dt, + "scaling_factor": obj.scaling_factor, + "normalize": obj.normalize, + "temperature": obj.temperature, + "trainable_ln": obj.trainable_ln, + "ln_eps": obj.ln_eps, + "smooth_type_embedding": obj.smooth, + "type_one_side": obj.type_one_side, + "concat_output_tebd": self.concat_output_tebd, + "use_econf_tebd": self.use_econf_tebd, + "use_tebd_bias": self.use_tebd_bias, + "type_map": self.type_map, + # make deterministic + "precision": RESERVED_PRECISON_DICT[obj.prec], + "embeddings": obj.filter_layers.serialize(), + "attention_layers": obj.dpa1_attention.serialize(), + "env_mat": DPEnvMat(obj.rcut, obj.rcut_smth).serialize(), + "type_embedding": self.type_embedding.embedding.serialize(), + "exclude_types": obj.exclude_types, + "env_protection": obj.env_protection, + "@variables": { + "davg": obj["davg"].numpy(), + "dstd": obj["dstd"].numpy(), + }, + "trainable": self.trainable, + "spin": None, + } + if obj.tebd_input_mode in ["strip"]: + data.update({"embeddings_strip": obj.filter_layers_strip.serialize()}) + return data + + @classmethod + def deserialize(cls, data: dict) -> "DescrptDPA1": + data = data.copy() + check_version_compatibility(data.pop("@version"), 2, 1) + data.pop("@class") + data.pop("type") + variables = data.pop("@variables") + embeddings = data.pop("embeddings") + type_embedding = data.pop("type_embedding") + attention_layers = data.pop("attention_layers") + env_mat = data.pop("env_mat") + tebd_input_mode = data["tebd_input_mode"] + if tebd_input_mode in ["strip"]: + embeddings_strip = data.pop("embeddings_strip") + else: + embeddings_strip = None + # compat with version 1 + if "use_tebd_bias" not in data: + data["use_tebd_bias"] = True + obj = cls(**data) + + def t_cvt(xx): + return paddle.to_tensor(xx, dtype=obj.se_atten.prec).to(device=env.DEVICE) + + obj.type_embedding.embedding = TypeEmbedNetConsistent.deserialize( + type_embedding + ) + obj.se_atten["davg"] = t_cvt(variables["davg"]) + obj.se_atten["dstd"] = t_cvt(variables["dstd"]) + obj.se_atten.filter_layers = NetworkCollection.deserialize(embeddings) + if tebd_input_mode in ["strip"]: + obj.se_atten.filter_layers_strip = NetworkCollection.deserialize( + embeddings_strip + ) + obj.se_atten.dpa1_attention = NeighborGatedAttention.deserialize( + attention_layers + ) + return obj + + def enable_compression( + self, + min_nbor_dist: float, + table_extrapolate: float = 5, + table_stride_1: float = 0.01, + table_stride_2: float = 0.1, + check_frequency: int = -1, + ) -> None: + """Receive the statisitcs (distance, max_nbor_size and env_mat_range) of the training data. + + Parameters + ---------- + min_nbor_dist + The nearest distance between atoms + table_extrapolate + The scale of model extrapolation + table_stride_1 + The uniform stride of the first table + table_stride_2 + The uniform stride of the second table + check_frequency + The overflow check frequency + """ + # do some checks before the mocel compression process + raise NotImplementedError("Model compression is not supported in paddle yet.") + + def forward( + self, + extended_coord: paddle.Tensor, + extended_atype: paddle.Tensor, + nlist: paddle.Tensor, + mapping: Optional[paddle.Tensor] = None, + comm_dict: Optional[dict[str, paddle.Tensor]] = None, + ): + """Compute the descriptor. + + Parameters + ---------- + extended_coord + The extended coordinates of atoms. shape: nf x (nallx3) + extended_atype + The extended aotm types. shape: nf x nall + nlist + The neighbor list. shape: nf x nloc x nnei + mapping + The index mapping, not required by this descriptor. + comm_dict + The data needed for communication for parallel inference. + + Returns + ------- + descriptor + The descriptor. shape: nf x nloc x (ng x axis_neuron) + gr + The rotationally equivariant and permutationally invariant single particle + representation. shape: nf x nloc x ng x 3 + g2 + The rotationally invariant pair-partical representation. + shape: nf x nloc x nnei x ng + h2 + The rotationally equivariant pair-partical representation. + shape: nf x nloc x nnei x 3 + sw + The smooth switch function. shape: nf x nloc x nnei + + """ + # cast the input to internal precsion + extended_coord = extended_coord.to(dtype=self.prec) + del mapping + nframes, nloc, nnei = nlist.shape + nall = extended_coord.reshape([nframes, -1]).shape[1] // 3 + g1_ext = self.type_embedding(extended_atype) + g1_inp = g1_ext[:, :nloc, :] + if self.tebd_input_mode in ["strip"]: + type_embedding = self.type_embedding.get_full_embedding(g1_ext.place) + else: + type_embedding = None + g1, g2, h2, rot_mat, sw = self.se_atten( + nlist, + extended_coord, + extended_atype, + g1_ext, + mapping=None, + type_embedding=type_embedding, + ) + if self.concat_output_tebd: + g1 = paddle.concat([g1, g1_inp], axis=-1) + + return ( + g1.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION), + rot_mat.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION), + g2.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION) if g2 is not None else None, + h2.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION), + sw.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION), + ) + + @classmethod + def update_sel( + cls, + train_data: DeepmdDataSystem, + type_map: Optional[list[str]], + local_jdata: dict, + ) -> tuple[dict, Optional[float]]: + """Update the selection and perform neighbor statistics. + + Parameters + ---------- + train_data : DeepmdDataSystem + data used to do neighbor statistics + type_map : list[str], optional + The name of each type of atoms + local_jdata : dict + The local data refer to the current class + + Returns + ------- + dict + The updated local data + float + The minimum distance between two atoms + """ + local_jdata_cpy = local_jdata.copy() + min_nbor_dist, sel = UpdateSel().update_one_sel( + train_data, type_map, local_jdata_cpy["rcut"], local_jdata_cpy["sel"], True + ) + local_jdata_cpy["sel"] = sel[0] + return local_jdata_cpy, min_nbor_dist diff --git a/deepmd/pd/model/descriptor/env_mat.py b/deepmd/pd/model/descriptor/env_mat.py index 3a9daec1e8..9b72da0b16 100644 --- a/deepmd/pd/model/descriptor/env_mat.py +++ b/deepmd/pd/model/descriptor/env_mat.py @@ -2,9 +2,6 @@ import paddle -from deepmd.pd.utils import ( - decomp, -) from deepmd.pd.utils.preprocess import ( compute_smooth_weight, ) @@ -27,12 +24,10 @@ def _make_env_mat( nlist = paddle.where(mask, nlist, nall - 1) coord_l = coord[:, :natoms].reshape([bsz, -1, 1, 3]) index = nlist.reshape([bsz, -1]).unsqueeze(-1).expand([-1, -1, 3]) - # coord_r = paddle.take_along_axis(coord, axis=1, indices=index) - coord_r = decomp.take_along_axis(coord, axis=1, indices=index) + coord_r = paddle.take_along_axis(coord, axis=1, indices=index) coord_r = coord_r.reshape([bsz, natoms, nnei, 3]) diff = coord_r - coord_l - # length = paddle.linalg.norm(diff, axis=-1, keepdim=True) - length = decomp.norm(diff, axis=-1, keepdim=True) + length = paddle.linalg.norm(diff, axis=-1, keepdim=True) # for index 0 nloc atom length = length + (~mask.unsqueeze(-1)).astype(length.dtype) t0 = 1 / (length + protection) diff --git a/deepmd/pd/model/descriptor/se_a.py b/deepmd/pd/model/descriptor/se_a.py index 180d6f0a3f..0af6d082b8 100644 --- a/deepmd/pd/model/descriptor/se_a.py +++ b/deepmd/pd/model/descriptor/se_a.py @@ -9,6 +9,7 @@ import numpy as np import paddle +import paddle.nn as nn from deepmd.dpmodel.utils.seed import ( child_seed, @@ -87,13 +88,14 @@ def __init__( type_map: Optional[list[str]] = None, # not implemented spin=None, - ): + ) -> None: del ntypes if spin is not None: raise NotImplementedError("old implementation of spin is not supported.") super().__init__() self.type_map = type_map self.compress = False + self.prec = PRECISION_DICT[precision] self.sea = DescrptBlockSeA( rcut, rcut_smth, @@ -161,7 +163,7 @@ def get_env_protection(self) -> float: """Returns the protection of building environment matrix.""" return self.sea.get_env_protection() - def share_params(self, base_class, shared_level, resume=False): + def share_params(self, base_class, shared_level, resume=False) -> None: """ Share the parameters of self to the base_class with shared_level during multitask training. If not start from checkpoint (resume is False), @@ -222,10 +224,35 @@ def compute_input_stats( def reinit_exclude( self, exclude_types: list[tuple[int, int]] = [], - ): + ) -> None: """Update the type exclusions.""" self.sea.reinit_exclude(exclude_types) + def enable_compression( + self, + min_nbor_dist: float, + table_extrapolate: float = 5, + table_stride_1: float = 0.01, + table_stride_2: float = 0.1, + check_frequency: int = -1, + ) -> None: + """Receive the statisitcs (distance, max_nbor_size and env_mat_range) of the training data. + + Parameters + ---------- + min_nbor_dist + The nearest distance between atoms + table_extrapolate + The scale of model extrapolation + table_stride_1 + The uniform stride of the first table + table_stride_2 + The uniform stride of the second table + check_frequency + The overflow check frequency + """ + raise ValueError("Enable compression is not supported.") + def forward( self, coord_ext: paddle.Tensor, @@ -266,7 +293,18 @@ def forward( The smooth switch function. """ - return self.sea.forward(nlist, coord_ext, atype_ext, None, mapping) + # cast the input to internal precsion + coord_ext = coord_ext.to(dtype=self.prec) + g1, rot_mat, g2, h2, sw = self.sea.forward( + nlist, coord_ext, atype_ext, None, mapping + ) + return ( + g1.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION), + rot_mat.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION), + None, + None, + sw.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION), + ) def set_stat_mean_and_stddev( self, @@ -367,10 +405,6 @@ def update_sel( class DescrptBlockSeA(DescriptorBlock): ndescrpt: Final[int] __constants__: ClassVar[list] = ["ndescrpt"] - lower: dict[str, int] - upper: dict[str, int] - table_data: dict[str, paddle.Tensor] - table_config: list[Union[int, float]] def __init__( self, @@ -389,7 +423,7 @@ def __init__( trainable: bool = True, seed: Optional[Union[int, list[int]]] = None, **kwargs, - ): + ) -> None: """Construct an embedding net of type `se_a`. Args: @@ -430,13 +464,6 @@ def __init__( self.register_buffer("mean", mean) self.register_buffer("stddev", stddev) - # add for compression - self.compress = False - self.lower = {} - self.upper = {} - self.table_data = {} - self.table_config = [] - ndim = 1 if self.type_one_side else 2 filter_layers = NetworkCollection( ndim=ndim, ntypes=len(sel), network_type="embedding_network" @@ -459,6 +486,21 @@ def __init__( for param in self.parameters(): param.stop_gradient = not trainable + # add for compression + self.compress = False + self.compress_info = nn.ParameterList( + [ + self.create_parameter([], dtype=self.prec).to(device="cpu") + for _ in range(len(self.filter_layers.networks)) + ] + ) + self.compress_data = nn.ParameterList( + [ + self.create_parameter([], dtype=self.prec).to(device=env.DEVICE) + for _ in range(len(self.filter_layers.networks)) + ] + ) + def get_rcut(self) -> float: """Returns the cut-off radius.""" return self.rcut @@ -517,11 +559,11 @@ def dim_out(self): return self.filter_neuron[-1] * self.axis_neuron @property - def dim_in(self): + def dim_in(self) -> int: """Returns the atomic input dimension of this descriptor.""" return 0 - def __setitem__(self, key, value): + def __setitem__(self, key, value) -> None: if key in ("avg", "data_avg", "davg"): self.mean = value elif key in ("std", "data_std", "dstd"): @@ -541,7 +583,7 @@ def compute_input_stats( self, merged: Union[Callable[[], list[dict]], list[dict]], path: Optional[DPPath] = None, - ): + ) -> None: """ Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data. @@ -587,22 +629,45 @@ def get_stats(self) -> dict[str, StatItem]: def reinit_exclude( self, exclude_types: list[tuple[int, int]] = [], - ): + ) -> None: self.exclude_types = exclude_types self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types) def enable_compression( self, - table_data, - table_config, - lower, - upper, + table_data: dict[str, paddle.Tensor], + table_config: list[Union[int, float]], + lower: dict[str, int], + upper: dict[str, int], ) -> None: + for embedding_idx, ll in enumerate(self.filter_layers.networks): + if self.type_one_side: + ii = embedding_idx + ti = -1 + else: + # ti: center atom type, ii: neighbor type... + ii = embedding_idx // self.ntypes + ti = embedding_idx % self.ntypes + if self.type_one_side: + net = "filter_-1_net_" + str(ii) + else: + net = "filter_" + str(ti) + "_net_" + str(ii) + info_ii = paddle.to_tensor( + [ + lower[net], + upper[net], + upper[net] * table_config[0], + table_config[1], + table_config[2], + table_config[3], + ], + dtype=self.prec, + place="cpu", + ) + tensor_data_ii = table_data[net].to(device=env.DEVICE, dtype=self.prec) + self.compress_data[embedding_idx] = tensor_data_ii + self.compress_info[embedding_idx] = info_ii self.compress = True - self.table_data = table_data - self.table_config = table_config - self.lower = lower - self.upper = upper def forward( self, @@ -611,6 +676,7 @@ def forward( extended_atype: paddle.Tensor, extended_atype_embd: Optional[paddle.Tensor] = None, mapping: Optional[paddle.Tensor] = None, + type_embedding: Optional[paddle.Tensor] = None, ): """Calculate decoded embedding for each atom. @@ -627,7 +693,7 @@ def forward( del extended_atype_embd, mapping nf = nlist.shape[0] nloc = nlist.shape[1] - atype: paddle.Tensor = extended_atype[:, :nloc] + atype = extended_atype[:, :nloc] dmatrix, diff, sw = prod_env_mat( extended_coord, nlist, @@ -640,7 +706,6 @@ def forward( ) dmatrix = dmatrix.reshape([-1, self.nnei, 4]) - dmatrix = dmatrix.astype(self.prec) nfnl = dmatrix.shape[0] # pre-allocate a shape to pass jit xyz_scatter = paddle.zeros( @@ -649,7 +714,9 @@ def forward( ).to(extended_coord.place) # nfnl x nnei exclude_mask = self.emask(nlist, extended_atype).reshape([nfnl, self.nnei]) - for embedding_idx, ll in enumerate(self.filter_layers.networks): + for embedding_idx, (ll, compress_data_ii, compress_info_ii) in enumerate( + zip(self.filter_layers.networks, self.compress_data, self.compress_info) + ): if self.type_one_side: ii = embedding_idx ti = -1 @@ -680,10 +747,16 @@ def forward( if rr.numel() > 0: rr = rr * mm.unsqueeze(2).astype(rr.dtype) ss = rr[:, :, :1] - # nfnl x nt x ng - gg = ll.forward(ss) - # nfnl x 4 x ng - gr = paddle.matmul(rr.transpose([0, 2, 1]), gg) + if self.compress: + raise NotImplementedError( + "Compressed environment is not implemented yet." + ) + else: + # nfnl x nt x ng + gg = ll.forward(ss) + # nfnl x 4 x ng + gr = paddle.matmul(rr.transpose([0, 2, 1]), gg) + if ti_mask is not None: xyz_scatter[ti_mask] += gr else: @@ -699,8 +772,8 @@ def forward( result = result.reshape([nf, nloc, self.filter_neuron[-1] * self.axis_neuron]) rot_mat = rot_mat.reshape([nf, nloc] + list(rot_mat.shape[1:])) # noqa:RUF005 return ( - result.astype(env.GLOBAL_PD_FLOAT_PRECISION), - rot_mat.astype(env.GLOBAL_PD_FLOAT_PRECISION), + result, + rot_mat, None, None, sw, diff --git a/deepmd/pd/model/descriptor/se_atten.py b/deepmd/pd/model/descriptor/se_atten.py new file mode 100644 index 0000000000..1ebf8c6717 --- /dev/null +++ b/deepmd/pd/model/descriptor/se_atten.py @@ -0,0 +1,1073 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Callable, + Optional, + Union, +) + +import paddle +import paddle.nn as nn +import paddle.nn.functional as paddle_func + +from deepmd.dpmodel.utils.seed import ( + child_seed, +) +from deepmd.pd.model.descriptor.descriptor import ( + DescriptorBlock, +) +from deepmd.pd.model.descriptor.env_mat import ( + prod_env_mat, +) +from deepmd.pd.model.network.layernorm import ( + LayerNorm, +) +from deepmd.pd.model.network.mlp import ( + EmbeddingNet, + MLPLayer, + NetworkCollection, +) +from deepmd.pd.utils import ( + env, +) +from deepmd.pd.utils.env import ( + DEFAULT_PRECISION, + PRECISION_DICT, +) +from deepmd.pd.utils.env_mat_stat import ( + EnvMatStatSe, +) +from deepmd.pd.utils.exclude_mask import ( + PairExcludeMask, +) +from deepmd.utils.env_mat_stat import ( + StatItem, +) +from deepmd.utils.path import ( + DPPath, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + + +@DescriptorBlock.register("se_atten") +class DescrptBlockSeAtten(DescriptorBlock): + def __init__( + self, + rcut: float, + rcut_smth: float, + sel: Union[list[int], int], + ntypes: int, + neuron: list = [25, 50, 100], + axis_neuron: int = 16, + tebd_dim: int = 8, + tebd_input_mode: str = "concat", + set_davg_zero: bool = True, + attn: int = 128, + attn_layer: int = 2, + attn_dotr: bool = True, + attn_mask: bool = False, + activation_function="tanh", + precision: str = "float64", + resnet_dt: bool = False, + scaling_factor=1.0, + normalize=True, + temperature=None, + smooth: bool = True, + type_one_side: bool = False, + exclude_types: list[tuple[int, int]] = [], + env_protection: float = 0.0, + trainable_ln: bool = True, + ln_eps: Optional[float] = 1e-5, + seed: Optional[Union[int, list[int]]] = None, + type: Optional[str] = None, + ) -> None: + r"""Construct an embedding net of type `se_atten`. + + Parameters + ---------- + rcut : float + The cut-off radius :math:`r_c` + rcut_smth : float + From where the environment matrix should be smoothed :math:`r_s` + sel : list[int], int + list[int]: sel[i] specifies the maxmum number of type i atoms in the cut-off radius + int: the total maxmum number of atoms in the cut-off radius + ntypes : int + Number of element types + neuron : list[int] + Number of neurons in each hidden layers of the embedding net :math:`\mathcal{N}` + axis_neuron : int + Number of the axis neuron :math:`M_2` (number of columns of the sub-matrix of the embedding matrix) + tebd_dim : int + Dimension of the type embedding + tebd_input_mode : str + The input mode of the type embedding. Supported modes are ["concat", "strip"]. + - "concat": Concatenate the type embedding with the smoothed radial information as the union input for the embedding network. + - "strip": Use a separated embedding network for the type embedding and combine the output with the radial embedding network output. + resnet_dt : bool + Time-step `dt` in the resnet construction: + y = x + dt * \phi (Wx + b) + trainable_ln : bool + Whether to use trainable shift and scale weights in layer normalization. + ln_eps : float, Optional + The epsilon value for layer normalization. + type_one_side : bool + If 'False', type embeddings of both neighbor and central atoms are considered. + If 'True', only type embeddings of neighbor atoms are considered. + Default is 'False'. + attn : int + Hidden dimension of the attention vectors + attn_layer : int + Number of attention layers + attn_dotr : bool + If dot the angular gate to the attention weights + attn_mask : bool + (Only support False to keep consistent with other backend references.) + (Not used in this version.) + If mask the diagonal of attention weights + exclude_types : list[list[int]] + The excluded pairs of types which have no interaction with each other. + For example, `[[0, 1]]` means no interaction between type 0 and type 1. + env_protection : float + Protection parameter to prevent division by zero errors during environment matrix calculations. + set_davg_zero : bool + Set the shift of embedding net input to zero. + activation_function : str + The activation function in the embedding net. Supported options are |ACTIVATION_FN| + precision : str + The precision of the embedding net parameters. Supported options are |PRECISION| + scaling_factor : float + The scaling factor of normalization in calculations of attention weights. + If `temperature` is None, the scaling of attention weights is (N_dim * scaling_factor)**0.5 + normalize : bool + Whether to normalize the hidden vectors in attention weights calculation. + temperature : float + If not None, the scaling of attention weights is `temperature` itself. + seed : int, Optional + Random seed for parameter initialization. + """ + super().__init__() + del type + self.rcut = float(rcut) + self.rcut_smth = float(rcut_smth) + self.neuron = neuron + self.filter_neuron = self.neuron + self.axis_neuron = axis_neuron + self.tebd_dim = tebd_dim + self.tebd_input_mode = tebd_input_mode + self.set_davg_zero = set_davg_zero + self.attn_dim = attn + self.attn_layer = attn_layer + self.attn_dotr = attn_dotr + self.attn_mask = attn_mask + self.activation_function = activation_function + self.precision = precision + self.prec = PRECISION_DICT[self.precision] + self.resnet_dt = resnet_dt + self.scaling_factor = scaling_factor + self.normalize = normalize + self.temperature = temperature + self.smooth = smooth + self.type_one_side = type_one_side + self.env_protection = env_protection + self.trainable_ln = trainable_ln + self.seed = seed + # to keep consistent with default value in this backends + if ln_eps is None: + ln_eps = 1e-5 + self.ln_eps = ln_eps + + if isinstance(sel, int): + sel = [sel] + + self.ntypes = ntypes + self.sel = sel + self.sec = self.sel + self.split_sel = self.sel + self.nnei = sum(sel) + self.ndescrpt = self.nnei * 4 + # order matters, placed after the assignment of self.ntypes + self.reinit_exclude(exclude_types) + + self.dpa1_attention = NeighborGatedAttention( + self.attn_layer, + self.nnei, + self.filter_neuron[-1], + self.attn_dim, + dotr=self.attn_dotr, + do_mask=self.attn_mask, + scaling_factor=self.scaling_factor, + normalize=self.normalize, + temperature=self.temperature, + trainable_ln=self.trainable_ln, + ln_eps=self.ln_eps, + smooth=self.smooth, + precision=self.precision, + seed=child_seed(self.seed, 0), + ) + + wanted_shape = (self.ntypes, self.nnei, 4) + mean = paddle.zeros(wanted_shape, dtype=self.prec).to(device=env.DEVICE) + stddev = paddle.ones(wanted_shape, dtype=self.prec).to(device=env.DEVICE) + self.register_buffer("mean", mean) + self.register_buffer("stddev", stddev) + self.tebd_dim_input = self.tebd_dim if self.type_one_side else self.tebd_dim * 2 + if self.tebd_input_mode in ["concat"]: + self.embd_input_dim = 1 + self.tebd_dim_input + else: + self.embd_input_dim = 1 + + self.filter_layers_strip = None + filter_layers = NetworkCollection( + ndim=0, ntypes=self.ntypes, network_type="embedding_network" + ) + filter_layers[0] = EmbeddingNet( + self.embd_input_dim, + self.filter_neuron, + activation_function=self.activation_function, + precision=self.precision, + resnet_dt=self.resnet_dt, + seed=child_seed(self.seed, 1), + ) + self.filter_layers = filter_layers + if self.tebd_input_mode in ["strip"]: + filter_layers_strip = NetworkCollection( + ndim=0, ntypes=self.ntypes, network_type="embedding_network" + ) + filter_layers_strip[0] = EmbeddingNet( + self.tebd_dim_input, + self.filter_neuron, + activation_function=self.activation_function, + precision=self.precision, + resnet_dt=self.resnet_dt, + seed=child_seed(self.seed, 2), + ) + self.filter_layers_strip = filter_layers_strip + self.stats = None + + # add for compression + self.compress = False + self.is_sorted = False + self.compress_info = nn.ParameterList( + [ + self.create_parameter( + [], default_initializer=nn.initializer.Constant(0), dtype=self.prec + ).to("cpu") + ] + ) + self.compress_data = nn.ParameterList( + [ + self.create_parameter( + [], default_initializer=nn.initializer.Constant(0), dtype=self.prec + ).to(env.DEVICE) + ] + ) + + def get_rcut(self) -> float: + """Returns the cut-off radius.""" + return self.rcut + + def get_rcut_smth(self) -> float: + """Returns the radius where the neighbor information starts to smoothly decay to 0.""" + return self.rcut_smth + + def get_nsel(self) -> int: + """Returns the number of selected atoms in the cut-off radius.""" + return sum(self.sel) + + def get_sel(self) -> list[int]: + """Returns the number of selected atoms for each type.""" + return self.sel + + def get_ntypes(self) -> int: + """Returns the number of element types.""" + return self.ntypes + + def get_dim_in(self) -> int: + """Returns the input dimension.""" + return self.dim_in + + def get_dim_out(self) -> int: + """Returns the output dimension.""" + return self.dim_out + + def get_dim_rot_mat_1(self) -> int: + """Returns the first dimension of the rotation matrix. The rotation is of shape dim_1 x 3.""" + return self.filter_neuron[-1] + + def get_dim_emb(self) -> int: + """Returns the output dimension of embedding.""" + return self.filter_neuron[-1] + + def __setitem__(self, key, value) -> None: + if key in ("avg", "data_avg", "davg"): + self.mean = value + elif key in ("std", "data_std", "dstd"): + self.stddev = value + else: + raise KeyError(key) + + def __getitem__(self, key): + if key in ("avg", "data_avg", "davg"): + return self.mean + elif key in ("std", "data_std", "dstd"): + return self.stddev + else: + raise KeyError(key) + + def mixed_types(self) -> bool: + """If true, the descriptor + 1. assumes total number of atoms aligned across frames; + 2. requires a neighbor list that does not distinguish different atomic types. + + If false, the descriptor + 1. assumes total number of atoms of each atom type aligned across frames; + 2. requires a neighbor list that distinguishes different atomic types. + + """ + return True + + def get_env_protection(self) -> float: + """Returns the protection of building environment matrix.""" + return self.env_protection + + @property + def dim_out(self): + """Returns the output dimension of this descriptor.""" + return self.filter_neuron[-1] * self.axis_neuron + + @property + def dim_in(self): + """Returns the atomic input dimension of this descriptor.""" + return self.tebd_dim + + @property + def dim_emb(self): + """Returns the output dimension of embedding.""" + return self.get_dim_emb() + + def compute_input_stats( + self, + merged: Union[Callable[[], list[dict]], list[dict]], + path: Optional[DPPath] = None, + ) -> None: + """ + Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data. + + Parameters + ---------- + merged : Union[Callable[[], list[dict]], list[dict]] + - list[dict]: A list of data samples from various data systems. + Each element, `merged[i]`, is a data dictionary containing `keys`: `paddle.Tensor` + originating from the `i`-th data system. + - Callable[[], list[dict]]: A lazy function that returns data samples in the above format + only when needed. Since the sampling process can be slow and memory-intensive, + the lazy function helps by only sampling once. + path : Optional[DPPath] + The path to the stat file. + + """ + env_mat_stat = EnvMatStatSe(self) + if path is not None: + path = path / env_mat_stat.get_hash() + if path is None or not path.is_dir(): + if callable(merged): + # only get data for once + sampled = merged() + else: + sampled = merged + else: + sampled = [] + env_mat_stat.load_or_compute_stats(sampled, path) + self.stats = env_mat_stat.stats + mean, stddev = env_mat_stat() + if not self.set_davg_zero: + paddle.assign(paddle.to_tensor(mean).to(device=env.DEVICE), self.mean) # pylint: disable=no-explicit-dtype + paddle.assign(paddle.to_tensor(stddev).to(device=env.DEVICE), self.stddev) # pylint: disable=no-explicit-dtype + + def get_stats(self) -> dict[str, StatItem]: + """Get the statistics of the descriptor.""" + if self.stats is None: + raise RuntimeError( + "The statistics of the descriptor has not been computed." + ) + return self.stats + + def reinit_exclude( + self, + exclude_types: list[tuple[int, int]] = [], + ) -> None: + self.exclude_types = exclude_types + self.is_sorted = len(self.exclude_types) == 0 + self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types) + + def enable_compression( + self, + table_data, + table_config, + lower, + upper, + ) -> None: + net = "filter_net" + self.compress_info[0] = paddle.to_tensor( + [ + lower[net], + upper[net], + upper[net] * table_config[0], + table_config[1], + table_config[2], + table_config[3], + ], + dtype=self.prec, + place="cpu", + ) + self.compress_data[0] = table_data[net].to(device=env.DEVICE, dtype=self.prec) + self.compress = True + + def forward( + self, + nlist: paddle.Tensor, + extended_coord: paddle.Tensor, + extended_atype: paddle.Tensor, + extended_atype_embd: Optional[paddle.Tensor] = None, + mapping: Optional[paddle.Tensor] = None, + type_embedding: Optional[paddle.Tensor] = None, + ): + """Compute the descriptor. + + Parameters + ---------- + nlist + The neighbor list. shape: nf x nloc x nnei + extended_coord + The extended coordinates of atoms. shape: nf x (nallx3) + extended_atype + The extended aotm types. shape: nf x nall x nt + extended_atype_embd + The extended type embedding of atoms. shape: nf x nall + mapping + The index mapping, not required by this descriptor. + type_embedding + Full type embeddings. shape: (ntypes+1) x nt + Required for stripped type embeddings. + + Returns + ------- + result + The descriptor. shape: nf x nloc x (ng x axis_neuron) + g2 + The rotationally invariant pair-partical representation. + shape: nf x nloc x nnei x ng + h2 + The rotationally equivariant pair-partical representation. + shape: nf x nloc x nnei x 3 + gr + The rotationally equivariant and permutationally invariant single particle + representation. shape: nf x nloc x ng x 3 + sw + The smooth switch function. shape: nf x nloc x nnei + + """ + del mapping + assert extended_atype_embd is not None + nframes, nloc, nnei = nlist.shape + atype = extended_atype[:, :nloc] + nb = nframes + nall = extended_coord.reshape([nb, -1, 3]).shape[1] + dmatrix, diff, sw = prod_env_mat( + extended_coord, + nlist, + atype, + self.mean, + self.stddev, + self.rcut, + self.rcut_smth, + protection=self.env_protection, + ) + # nb x nloc x nnei + exclude_mask = self.emask(nlist, extended_atype) + nlist = paddle.where(exclude_mask != 0, nlist, paddle.full_like(nlist, -1)) + nlist_mask = nlist != -1 + nlist = paddle.where(nlist == -1, paddle.zeros_like(nlist), nlist) + sw = paddle.squeeze(sw, -1) + # nf x nall x nt + nt = extended_atype_embd.shape[-1] + # beyond the cutoff sw should be 0.0 + sw = sw.masked_fill(~nlist_mask, 0.0) + # (nb x nloc) x nnei + exclude_mask = exclude_mask.reshape([nb * nloc, nnei]) + # nfnl x nnei x 4 + dmatrix = dmatrix.reshape([-1, self.nnei, 4]) + nfnl = dmatrix.shape[0] + # nfnl x nnei x 4 + rr = dmatrix + rr = rr * exclude_mask[:, :, None].astype(rr.dtype) + ss = rr[:, :, :1] + if self.tebd_input_mode in ["concat"]: + atype_tebd_ext = extended_atype_embd + # nb x (nloc x nnei) x nt + index = nlist.reshape([nb, nloc * nnei]).unsqueeze(-1).expand([-1, -1, nt]) + # nb x (nloc x nnei) x nt + atype_tebd_nlist = paddle.take_along_axis( + atype_tebd_ext, axis=1, indices=index + ) # j + # nb x nloc x nnei x nt + atype_tebd_nlist = atype_tebd_nlist.reshape([nb, nloc, nnei, nt]) + + # nf x nloc x nt -> nf x nloc x nnei x nt + atype_tebd = extended_atype_embd[:, :nloc, :] + atype_tebd_nnei = atype_tebd.unsqueeze(2).expand( + [-1, -1, self.nnei, -1] + ) # i + + nlist_tebd = atype_tebd_nlist.reshape([nfnl, nnei, self.tebd_dim]) + atype_tebd = atype_tebd_nnei.reshape([nfnl, nnei, self.tebd_dim]) + if not self.type_one_side: + # nfnl x nnei x (1 + tebd_dim * 2) + ss = paddle.concat([ss, nlist_tebd, atype_tebd], axis=2) + else: + # nfnl x nnei x (1 + tebd_dim) + ss = paddle.concat([ss, nlist_tebd], axis=2) + # nfnl x nnei x ng + gg = self.filter_layers.networks[0](ss) + input_r = paddle.nn.functional.normalize( + rr.reshape([-1, self.nnei, 4])[:, :, 1:4], axis=-1 + ) + gg = self.dpa1_attention( + gg, nlist_mask, input_r=input_r, sw=sw + ) # shape is [nframes*nloc, self.neei, out_size] + # nfnl x 4 x ng + xyz_scatter = paddle.matmul(rr.transpose([0, 2, 1]), gg) + elif self.tebd_input_mode in ["strip"]: + assert self.filter_layers_strip is not None + assert type_embedding is not None + ng = self.filter_neuron[-1] + ntypes_with_padding = type_embedding.shape[0] + # nf x (nl x nnei) + nlist_index = nlist.reshape([nb, nloc * nnei]) + # nf x (nl x nnei) + nei_type = paddle.take_along_axis( + extended_atype, indices=nlist_index, axis=1 + ) + # (nf x nl x nnei) x ng + nei_type_index = nei_type.reshape([-1, 1]).expand([-1, ng]).to(paddle.int64) + if self.type_one_side: + tt_full = self.filter_layers_strip.networks[0](type_embedding) + # (nf x nl x nnei) x ng + gg_t = paddle.take_along_axis(tt_full, indices=nei_type_index, axis=0) + else: + idx_i = paddle.tile( + atype.reshape([-1, 1]) * ntypes_with_padding, [1, nnei] + ).reshape([-1]) + idx_j = nei_type.reshape([-1]) + # (nf x nl x nnei) x ng + idx = (idx_i + idx_j).reshape([-1, 1]).expand([-1, ng]).to(paddle.int64) + # (ntypes) * ntypes * nt + type_embedding_nei = paddle.tile( + type_embedding.reshape([1, ntypes_with_padding, nt]), + [ntypes_with_padding, 1, 1], + ) + # ntypes * (ntypes) * nt + type_embedding_center = paddle.tile( + type_embedding.reshape([ntypes_with_padding, 1, nt]), + [1, ntypes_with_padding, 1], + ) + # (ntypes * ntypes) * (nt+nt) + two_side_type_embedding = paddle.concat( + [type_embedding_nei, type_embedding_center], -1 + ).reshape([-1, nt * 2]) + tt_full = self.filter_layers_strip.networks[0](two_side_type_embedding) + # (nf x nl x nnei) x ng + gg_t = paddle.take_along_axis(tt_full, axis=0, indices=idx) + # (nf x nl) x nnei x ng + gg_t = gg_t.reshape([nfnl, nnei, ng]) + if self.smooth: + gg_t = gg_t * sw.reshape([-1, self.nnei, 1]) + if self.compress: + raise NotImplementedError("Compression is not implemented yet.") + else: + # nfnl x nnei x ng + gg_s = self.filter_layers.networks[0](ss) + # nfnl x nnei x ng + gg = gg_s * gg_t + gg_s + input_r = paddle_func.normalize( + rr.reshape([-1, self.nnei, 4])[:, :, 1:4], axis=-1 + ) + gg = self.dpa1_attention( + gg, nlist_mask, input_r=input_r, sw=sw + ) # shape is [nframes*nloc, self.neei, out_size] + # nfnl x 4 x ng + xyz_scatter = paddle.matmul(rr.transpose([0, 2, 1]), gg) + else: + raise NotImplementedError + + xyz_scatter = xyz_scatter / self.nnei + xyz_scatter_1 = xyz_scatter.transpose([0, 2, 1]) + rot_mat = xyz_scatter_1[:, :, 1:4] + xyz_scatter_2 = xyz_scatter[:, :, 0 : self.axis_neuron] + result = paddle.matmul( + xyz_scatter_1, xyz_scatter_2 + ) # shape is [nframes*nloc, self.filter_neuron[-1], self.axis_neuron] + + return ( + result.reshape([nframes, nloc, self.filter_neuron[-1] * self.axis_neuron]), + gg.reshape([nframes, nloc, self.nnei, self.filter_neuron[-1]]) + if not self.compress + else None, + dmatrix.reshape([nframes, nloc, self.nnei, 4])[..., 1:], + rot_mat.reshape([nframes, nloc, self.filter_neuron[-1], 3]), + sw, + ) + + def has_message_passing(self) -> bool: + """Returns whether the descriptor block has message passing.""" + return False + + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the descriptor block needs sorted nlist when using `forward_lower`.""" + return False + + +class NeighborGatedAttention(nn.Layer): + def __init__( + self, + layer_num: int, + nnei: int, + embed_dim: int, + hidden_dim: int, + dotr: bool = False, + do_mask: bool = False, + scaling_factor: float = 1.0, + normalize: bool = True, + temperature: Optional[float] = None, + trainable_ln: bool = True, + ln_eps: float = 1e-5, + smooth: bool = True, + precision: str = DEFAULT_PRECISION, + seed: Optional[Union[int, list[int]]] = None, + ) -> None: + """Construct a neighbor-wise attention net.""" + super().__init__() + self.layer_num = layer_num + self.nnei = nnei + self.embed_dim = embed_dim + self.hidden_dim = hidden_dim + self.dotr = dotr + self.do_mask = do_mask + self.scaling_factor = scaling_factor + self.normalize = normalize + self.temperature = temperature + self.trainable_ln = trainable_ln + self.ln_eps = ln_eps + self.smooth = smooth + self.precision = precision + self.seed = seed + self.network_type = NeighborGatedAttentionLayer + attention_layers = [] + for i in range(self.layer_num): + attention_layers.append( + NeighborGatedAttentionLayer( + nnei, + embed_dim, + hidden_dim, + dotr=dotr, + do_mask=do_mask, + scaling_factor=scaling_factor, + normalize=normalize, + temperature=temperature, + trainable_ln=trainable_ln, + ln_eps=ln_eps, + smooth=smooth, + precision=precision, + seed=child_seed(seed, i), + ) + ) + self.attention_layers = nn.LayerList(attention_layers) + + def forward( + self, + input_G, + nei_mask, + input_r: Optional[paddle.Tensor] = None, + sw: Optional[paddle.Tensor] = None, + ): + """Compute the multi-layer gated self-attention. + + Parameters + ---------- + input_G + inputs with shape: (nf x nloc) x nnei x embed_dim. + nei_mask + neighbor mask, with paddings being 0. shape: (nf x nloc) x nnei. + input_r + normalized radial. shape: (nf x nloc) x nnei x 3. + sw + The smooth switch function. shape: nf x nloc x nnei + """ + out = input_G + for layer in self.attention_layers: + out = layer(out, nei_mask, input_r=input_r, sw=sw) + return out + + def __getitem__(self, key): + if isinstance(key, int): + return self.attention_layers[key] + else: + raise TypeError(key) + + def __setitem__(self, key, value) -> None: + if not isinstance(key, int): + raise TypeError(key) + if isinstance(value, self.network_type): + pass + elif isinstance(value, dict): + value = self.network_type.deserialize(value) + else: + raise TypeError(value) + self.attention_layers[key] = value + + def serialize(self) -> dict: + """Serialize the networks to a dict. + + Returns + ------- + dict + The serialized networks. + """ + return { + "@class": "NeighborGatedAttention", + "@version": 1, + "layer_num": self.layer_num, + "nnei": self.nnei, + "embed_dim": self.embed_dim, + "hidden_dim": self.hidden_dim, + "dotr": self.dotr, + "do_mask": self.do_mask, + "scaling_factor": self.scaling_factor, + "normalize": self.normalize, + "temperature": self.temperature, + "trainable_ln": self.trainable_ln, + "ln_eps": self.ln_eps, + "precision": self.precision, + "attention_layers": [layer.serialize() for layer in self.attention_layers], + } + + @classmethod + def deserialize(cls, data: dict) -> "NeighborGatedAttention": + """Deserialize the networks from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + attention_layers = data.pop("attention_layers") + obj = cls(**data) + for ii, network in enumerate(attention_layers): + obj[ii] = network + return obj + + +class NeighborGatedAttentionLayer(nn.Layer): + def __init__( + self, + nnei: int, + embed_dim: int, + hidden_dim: int, + dotr: bool = False, + do_mask: bool = False, + scaling_factor: float = 1.0, + normalize: bool = True, + temperature: Optional[float] = None, + smooth: bool = True, + trainable_ln: bool = True, + ln_eps: float = 1e-5, + precision: str = DEFAULT_PRECISION, + seed: Optional[Union[int, list[int]]] = None, + ) -> None: + """Construct a neighbor-wise attention layer.""" + super().__init__() + self.nnei = nnei + self.embed_dim = embed_dim + self.hidden_dim = hidden_dim + self.dotr = dotr + self.do_mask = do_mask + self.scaling_factor = scaling_factor + self.normalize = normalize + self.temperature = temperature + self.precision = precision + self.trainable_ln = trainable_ln + self.ln_eps = ln_eps + self.seed = seed + self.attention_layer = GatedAttentionLayer( + nnei, + embed_dim, + hidden_dim, + dotr=dotr, + do_mask=do_mask, + scaling_factor=scaling_factor, + normalize=normalize, + temperature=temperature, + smooth=smooth, + precision=precision, + seed=child_seed(seed, 0), + ) + self.attn_layer_norm = LayerNorm( + self.embed_dim, + eps=ln_eps, + trainable=trainable_ln, + precision=precision, + seed=child_seed(seed, 1), + ) + + def forward( + self, + x, + nei_mask, + input_r: Optional[paddle.Tensor] = None, + sw: Optional[paddle.Tensor] = None, + ): + residual = x + x, _ = self.attention_layer(x, nei_mask, input_r=input_r, sw=sw) + x = residual + x + x = self.attn_layer_norm(x) + return x + + def serialize(self) -> dict: + """Serialize the networks to a dict. + + Returns + ------- + dict + The serialized networks. + """ + return { + "nnei": self.nnei, + "embed_dim": self.embed_dim, + "hidden_dim": self.hidden_dim, + "dotr": self.dotr, + "do_mask": self.do_mask, + "scaling_factor": self.scaling_factor, + "normalize": self.normalize, + "temperature": self.temperature, + "trainable_ln": self.trainable_ln, + "ln_eps": self.ln_eps, + "precision": self.precision, + "attention_layer": self.attention_layer.serialize(), + "attn_layer_norm": self.attn_layer_norm.serialize(), + } + + @classmethod + def deserialize(cls, data: dict) -> "NeighborGatedAttentionLayer": + """Deserialize the networks from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + data = data.copy() + attention_layer = data.pop("attention_layer") + attn_layer_norm = data.pop("attn_layer_norm") + obj = cls(**data) + obj.attention_layer = GatedAttentionLayer.deserialize(attention_layer) + obj.attn_layer_norm = LayerNorm.deserialize(attn_layer_norm) + return obj + + +class GatedAttentionLayer(nn.Layer): + def __init__( + self, + nnei: int, + embed_dim: int, + hidden_dim: int, + num_heads: int = 1, + dotr: bool = False, + do_mask: bool = False, + scaling_factor: float = 1.0, + normalize: bool = True, + temperature: Optional[float] = None, + bias: bool = True, + smooth: bool = True, + precision: str = DEFAULT_PRECISION, + seed: Optional[Union[int, list[int]]] = None, + ) -> None: + """Construct a multi-head neighbor-wise attention net.""" + super().__init__() + assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads" + self.nnei = nnei + self.embed_dim = embed_dim + self.hidden_dim = hidden_dim + self.num_heads = num_heads + self.head_dim = hidden_dim // num_heads + self.dotr = dotr + self.do_mask = do_mask + self.bias = bias + self.smooth = smooth + self.scaling_factor = scaling_factor + self.temperature = temperature + self.precision = precision + self.seed = seed + self.scaling = ( + (self.head_dim * scaling_factor) ** -0.5 + if temperature is None + else temperature + ) + self.normalize = normalize + self.in_proj = MLPLayer( + embed_dim, + hidden_dim * 3, + bias=bias, + use_timestep=False, + bavg=0.0, + stddev=1.0, + precision=precision, + seed=child_seed(seed, 0), + ) + self.out_proj = MLPLayer( + hidden_dim, + embed_dim, + bias=bias, + use_timestep=False, + bavg=0.0, + stddev=1.0, + precision=precision, + seed=child_seed(seed, 1), + ) + + def forward( + self, + query, + nei_mask, + input_r: Optional[paddle.Tensor] = None, + sw: Optional[paddle.Tensor] = None, + attnw_shift: float = 20.0, + ): + """Compute the multi-head gated self-attention. + + Parameters + ---------- + query + inputs with shape: (nf x nloc) x nnei x embed_dim. + nei_mask + neighbor mask, with paddings being 0. shape: (nf x nloc) x nnei. + input_r + normalized radial. shape: (nf x nloc) x nnei x 3. + sw + The smooth switch function. shape: (nf x nloc) x nnei + attnw_shift : float + The attention weight shift to preserve smoothness when doing padding before softmax. + """ + q, k, v = self.in_proj(query).chunk(3, axis=-1) + + # Reshape for multi-head attention: (nf x nloc) x num_heads x nnei x head_dim + q = q.reshape([-1, self.nnei, self.num_heads, self.head_dim]).transpose( + [0, 2, 1, 3] + ) + k = k.reshape([-1, self.nnei, self.num_heads, self.head_dim]).transpose( + [0, 2, 1, 3] + ) + v = v.reshape([-1, self.nnei, self.num_heads, self.head_dim]).transpose( + [0, 2, 1, 3] + ) + + if self.normalize: + q = paddle_func.normalize(q, axis=-1) + k = paddle_func.normalize(k, axis=-1) + v = paddle_func.normalize(v, axis=-1) + + q = q * self.scaling + # (nf x nloc) x num_heads x head_dim x nnei + k = k.transpose([0, 1, 3, 2]) + + # Compute attention scores + # (nf x nloc) x num_heads x nnei x nnei + attn_weights = paddle.matmul(q, k) + # (nf x nloc) x nnei + nei_mask = nei_mask.reshape([-1, self.nnei]) + + if self.smooth: + assert sw is not None + # (nf x nloc) x 1 x nnei + sw = sw.reshape([-1, 1, self.nnei]) + attn_weights = (attn_weights + attnw_shift) * sw[:, :, :, None] * sw[ + :, :, None, : + ] - attnw_shift + else: + # (nf x nloc) x 1 x 1 x nnei + attn_weights = attn_weights.masked_fill( + ~nei_mask.unsqueeze(1).unsqueeze(1), float("-inf") + ) + + attn_weights = paddle_func.softmax(attn_weights, axis=-1) + attn_weights = attn_weights.masked_fill( + ~nei_mask.unsqueeze(1).unsqueeze(-1), 0.0 + ) + if self.smooth: + assert sw is not None + attn_weights = attn_weights * sw[:, :, :, None] * sw[:, :, None, :] + + if self.dotr: + # (nf x nloc) x nnei x 3 + assert input_r is not None, "input_r must be provided when dotr is True!" + # (nf x nloc) x 1 x nnei x nnei + angular_weight = paddle.matmul( + input_r, input_r.transpose([0, 2, 1]) + ).reshape([-1, 1, self.nnei, self.nnei]) + attn_weights = attn_weights * angular_weight + + # Apply attention to values + # (nf x nloc) x nnei x (num_heads x head_dim) + o = ( + paddle.matmul(attn_weights, v) + .transpose([0, 2, 1, 3]) + .reshape([-1, self.nnei, self.hidden_dim]) + ) + output = self.out_proj(o) + return output, attn_weights + + def serialize(self) -> dict: + """Serialize the networks to a dict. + + Returns + ------- + dict + The serialized networks. + """ + return { + "nnei": self.nnei, + "embed_dim": self.embed_dim, + "hidden_dim": self.hidden_dim, + "num_heads": self.num_heads, + "dotr": self.dotr, + "do_mask": self.do_mask, + "scaling_factor": self.scaling_factor, + "normalize": self.normalize, + "temperature": self.temperature, + "bias": self.bias, + "smooth": self.smooth, + "precision": self.precision, + "in_proj": self.in_proj.serialize(), + "out_proj": self.out_proj.serialize(), + } + + @classmethod + def deserialize(cls, data: dict) -> "GatedAttentionLayer": + """Deserialize the networks from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + data = data.copy() + in_proj = data.pop("in_proj") + out_proj = data.pop("out_proj") + obj = cls(**data) + obj.in_proj = MLPLayer.deserialize(in_proj) + obj.out_proj = MLPLayer.deserialize(out_proj) + return obj diff --git a/deepmd/pd/model/model/ener_model.py b/deepmd/pd/model/model/ener_model.py index 3f3db4a527..a5b1b9d4b3 100644 --- a/deepmd/pd/model/model/ener_model.py +++ b/deepmd/pd/model/model/ener_model.py @@ -1,7 +1,4 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -from copy import ( - deepcopy, -) from typing import ( Optional, ) @@ -33,26 +30,26 @@ def __init__( self, *args, **kwargs, - ): + ) -> None: DPModelCommon.__init__(self) DPEnergyModel_.__init__(self, *args, **kwargs) def translated_output_def(self): out_def_data = self.model_output_def().get_data() output_def = { - "atom_energy": deepcopy(out_def_data["energy"]), - "energy": deepcopy(out_def_data["energy_redu"]), + "atom_energy": out_def_data["energy"], + "energy": out_def_data["energy_redu"], } if self.do_grad_r("energy"): - output_def["force"] = deepcopy(out_def_data["energy_derv_r"]) + output_def["force"] = out_def_data["energy_derv_r"] output_def["force"].squeeze(-2) if self.do_grad_c("energy"): - output_def["virial"] = deepcopy(out_def_data["energy_derv_c_redu"]) + output_def["virial"] = out_def_data["energy_derv_c_redu"] output_def["virial"].squeeze(-2) - output_def["atom_virial"] = deepcopy(out_def_data["energy_derv_c"]) + output_def["atom_virial"] = out_def_data["energy_derv_c"] output_def["atom_virial"].squeeze(-3) if "mask" in out_def_data: - output_def["mask"] = deepcopy(out_def_data["mask"]) + output_def["mask"] = out_def_data["mask"] return output_def def forward( diff --git a/deepmd/pd/model/model/make_model.py b/deepmd/pd/model/model/make_model.py index d5c5c6bd41..2b9a4b5bec 100644 --- a/deepmd/pd/model/model/make_model.py +++ b/deepmd/pd/model/model/make_model.py @@ -24,9 +24,6 @@ communicate_extended_output, fit_output_to_model_output, ) -from deepmd.pd.utils import ( - decomp, -) from deepmd.pd.utils.env import ( GLOBAL_PD_ENER_FLOAT_PRECISION, GLOBAL_PD_FLOAT_PRECISION, @@ -72,7 +69,7 @@ def __init__( # underscore to prevent conflict with normal inputs atomic_model_: Optional[T_AtomicModel] = None, **kwargs, - ): + ) -> None: super().__init__(*args, **kwargs) if atomic_model_ is not None: self.atomic_model: T_AtomicModel = atomic_model_ @@ -176,7 +173,9 @@ def forward_common( atype, self.get_rcut(), self.get_sel(), - mixed_types=self.mixed_types(), + # types will be distinguished in the lower interface, + # so it doesn't need to be distinguished here + mixed_types=True, box=bb, ) model_predict_lower = self.forward_common_lower( @@ -411,7 +410,7 @@ def format_nlist( Returns ------- - formatted_nlist + formated_nlist the formatted nlist. """ @@ -459,18 +458,17 @@ def _format_nlist( coord0 = extended_coord[:, :n_nloc, :] # nf x (nloc x nnei) x 3 index = nlist.reshape([n_nf, n_nloc * n_nnei, 1]).expand([-1, -1, 3]) - coord1 = decomp.take_along_axis(extended_coord, axis=1, indices=index) + coord1 = paddle.take_along_axis(extended_coord, axis=1, indices=index) # nf x nloc x nnei x 3 coord1 = coord1.reshape([n_nf, n_nloc, n_nnei, 3]) # nf x nloc x nnei - # rr = paddle.linalg.norm(coord0[:, :, None, :] - coord1, axis=-1) - rr = decomp.norm(coord0[:, :, None, :] - coord1, axis=-1) + rr = paddle.linalg.norm(coord0[:, :, None, :] - coord1, axis=-1) rr = paddle.where(m_real_nei, rr, float("inf")) rr, nlist_mapping = ( paddle.sort(rr, axis=-1), paddle.argsort(rr, axis=-1), ) - nlist = decomp.take_along_axis(nlist, axis=2, indices=nlist_mapping) + nlist = paddle.take_along_axis(nlist, axis=2, indices=nlist_mapping) nlist = paddle.where(rr > rcut, paddle.full_like(nlist, -1), nlist) nlist = nlist[..., :nnei] else: # not extra_nlist_sort and n_nnei <= nnei: diff --git a/deepmd/pd/model/network/layernorm.py b/deepmd/pd/model/network/layernorm.py new file mode 100644 index 0000000000..4d37b208f9 --- /dev/null +++ b/deepmd/pd/model/network/layernorm.py @@ -0,0 +1,165 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Optional, + Union, +) + +import numpy as np +import paddle +import paddle.nn as nn + +from deepmd.dpmodel.utils.network import LayerNorm as DPLayerNorm +from deepmd.pd.model.network.init import ( + normal_, + ones_, + zeros_, +) +from deepmd.pd.utils import ( + decomp, + env, +) +from deepmd.pd.utils.env import ( + DEFAULT_PRECISION, + PRECISION_DICT, +) +from deepmd.pd.utils.utils import ( + get_generator, + to_numpy_array, + to_paddle_tensor, +) + +device = env.DEVICE + + +def empty_t(shape, precision): + return paddle.empty(shape, dtype=precision).to(device=device) + + +class LayerNorm(nn.Layer): + def __init__( + self, + num_in, + eps: float = 1e-5, + uni_init: bool = True, + bavg: float = 0.0, + stddev: float = 1.0, + precision: str = DEFAULT_PRECISION, + trainable: bool = True, + seed: Optional[Union[int, list[int]]] = None, + ): + super().__init__() + self.eps = eps + self.uni_init = uni_init + self.num_in = num_in + self.precision = precision + self.prec = PRECISION_DICT[self.precision] + self.matrix = self.create_parameter( + shape=[num_in], + dtype=self.prec, + default_initializer=nn.initializer.Assign( + empty_t((num_in,), self.prec), + ), + ) + self.bias = self.create_parameter( + shape=[num_in], + dtype=self.prec, + default_initializer=nn.initializer.Assign(empty_t([num_in], self.prec)), + ) + random_generator = get_generator(seed) + if self.uni_init: + ones_(self.matrix.data) + zeros_(self.bias.data) + else: + normal_(self.bias.data, mean=bavg, std=stddev, generator=random_generator) + normal_( + self.matrix.data, + std=stddev / np.sqrt(self.num_in), + generator=random_generator, + ) + self.trainable = trainable + if not self.trainable: + self.matrix.stop_gradient = True + self.bias.stop_gradient = True + + def dim_out(self) -> int: + return self.matrix.shape[0] + + def forward( + self, + xx: paddle.Tensor, + ) -> paddle.Tensor: + """One Layer Norm used by DP model. + + Parameters + ---------- + xx : paddle.Tensor + The input of index. + + Returns + ------- + yy: paddle.Tensor + The output. + """ + # if xx.numel() > 0: + if decomp.numel(xx): + variance, mean = ( + paddle.var(xx, axis=-1, unbiased=False, keepdim=True), + paddle.mean(xx, axis=-1, keepdim=True), + ) + yy = (xx - mean) / paddle.sqrt(variance + self.eps) + else: + yy = xx + if self.matrix is not None and self.bias is not None: + yy = yy * self.matrix + self.bias + return yy + + def serialize(self) -> dict: + """Serialize the layer to a dict. + + Returns + ------- + dict + The serialized layer. + """ + nl = DPLayerNorm( + self.matrix.shape[0], + eps=self.eps, + trainable=self.trainable, + precision=self.precision, + ) + nl.w = to_numpy_array(self.matrix) + nl.b = to_numpy_array(self.bias) + data = nl.serialize() + return data + + @classmethod + def deserialize(cls, data: dict) -> "LayerNorm": + """Deserialize the layer from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + nl = DPLayerNorm.deserialize(data) + obj = cls( + nl["matrix"].shape[0], + eps=nl["eps"], + trainable=nl["trainable"], + precision=nl["precision"], + ) + prec = PRECISION_DICT[obj.precision] + + def check_load_param(ss): + if nl[ss] is not None: + tensor = to_paddle_tensor(nl[ss]) + return paddle.create_parameter( + tensor.shape, + dtype=tensor.dtype, + default_initializer=nn.initializer.Assign(tensor), + ) + return None + + obj.matrix = check_load_param("matrix") + obj.bias = check_load_param("bias") + return obj diff --git a/deepmd/pd/model/network/network.py b/deepmd/pd/model/network/network.py index f118c234ab..1974e526a0 100644 --- a/deepmd/pd/model/network/network.py +++ b/deepmd/pd/model/network/network.py @@ -45,7 +45,7 @@ def __init__( use_econf_tebd=False, use_tebd_bias: bool = False, type_map=None, - ): + ) -> None: """Construct a type embedding net.""" super().__init__() self.type_nums = type_nums @@ -80,11 +80,28 @@ def forward(self, atype): """ return self.embedding(atype.place)[atype] - def share_params(self, base_class, shared_level, resume=False): + def get_full_embedding(self, device: str): + """ + Get the type embeddings of all types. + + Parameters + ---------- + device : str + The device on which to perform the computation. + + Returns + ------- + type_embedding : paddle.Tensor + The full type embeddings of all types. The last index corresponds to the zero padding. + Shape: (ntypes + 1) x tebd_dim + """ + return self.embedding(device) + + def share_params(self, base_class, shared_level, resume=False) -> None: """ Share the parameters of self to the base_class with shared_level during multitask training. If not start from checkpoint (resume is False), - some seperated parameters (e.g. mean and stddev) will be re-calculated across different classes. + some separated parameters (e.g. mean and stddev) will be re-calculated across different classes. """ assert ( self.__class__ == base_class.__class__ @@ -148,7 +165,7 @@ def __init__( use_econf_tebd: bool = False, use_tebd_bias: bool = False, type_map: Optional[list[str]] = None, - ): + ) -> None: """Construct a type embedding net.""" super().__init__() self.ntypes = ntypes diff --git a/deepmd/pd/model/task/fitting.py b/deepmd/pd/model/task/fitting.py index 375cf834cc..d9db44aff5 100644 --- a/deepmd/pd/model/task/fitting.py +++ b/deepmd/pd/model/task/fitting.py @@ -1,5 +1,4 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import copy import logging from abc import ( abstractmethod, @@ -55,7 +54,7 @@ def __new__(cls, *args, **kwargs): return BaseFitting.__new__(BaseFitting, *args, **kwargs) return super().__new__(cls) - def share_params(self, base_class, shared_level, resume=False): + def share_params(self, base_class, shared_level, resume=False) -> None: """ Share the parameters of self to the base_class with shared_level during multitask training. If not start from checkpoint (resume is False), @@ -65,14 +64,7 @@ def share_params(self, base_class, shared_level, resume=False): self.__class__ == base_class.__class__ ), "Only fitting nets of the same type can share params!" if shared_level == 0: - # link buffers - if hasattr(self, "bias_atom_e"): - self.bias_atom_e = base_class.bias_atom_e - # the following will successfully link all the params except buffers, which need manually link. - for item in self._sub_layers: - self._sub_layers[item] = base_class._sub_layers[item] - elif shared_level == 1: - # only not share the bias_atom_e + # only not share the bias_atom_e and the case_embd # the following will successfully link all the params except buffers, which need manually link. for item in self._sub_layers: self._sub_layers[item] = base_class._sub_layers[item] @@ -104,7 +96,6 @@ class GeneralFitting(Fitting): numb_aparam : int Number of atomic parameters. dim_case_embd : int - (Not supported yet) Dimension of case specific embedding. activation_function : str Activation function. @@ -155,7 +146,7 @@ def __init__( type_map: Optional[list[str]] = None, use_aparam_as_mask: bool = False, **kwargs, - ): + ) -> None: super().__init__() self.var_name = var_name self.ntypes = ntypes @@ -166,9 +157,6 @@ def __init__( self.numb_fparam = numb_fparam self.numb_aparam = numb_aparam self.dim_case_embd = dim_case_embd - if dim_case_embd > 0: - raise ValueError("dim_case_embd is not supported yet in PaddlePaddle.") - self.case_embd = None self.activation_function = activation_function self.precision = precision self.prec = PRECISION_DICT[self.precision] @@ -189,7 +177,9 @@ def __init__( # init constants if bias_atom_e is None: bias_atom_e = np.zeros([self.ntypes, net_dim_out], dtype=np.float64) - bias_atom_e = paddle.to_tensor(bias_atom_e, dtype=self.prec).to(device=device) + bias_atom_e = paddle.to_tensor( + bias_atom_e, dtype=env.GLOBAL_PD_FLOAT_PRECISION, place=device + ) bias_atom_e = bias_atom_e.reshape([self.ntypes, net_dim_out]) if not self.mixed_types: assert self.ntypes == bias_atom_e.shape[0], "Element count mismatches!" @@ -218,10 +208,20 @@ def __init__( else: self.aparam_avg, self.aparam_inv_std = None, None + if self.dim_case_embd > 0: + self.register_buffer( + "case_embd", + paddle.zeros(self.dim_case_embd, dtype=self.prec, place=device), + # paddle.eye(self.dim_case_embd, dtype=self.prec, place=device)[0], + ) + else: + self.case_embd = None + in_dim = ( self.dim_descrpt + self.numb_fparam + (0 if self.use_aparam_as_mask else self.numb_aparam) + + self.dim_case_embd ) self.filter_layers = NetworkCollection( @@ -249,7 +249,7 @@ def __init__( def reinit_exclude( self, exclude_types: list[int] = [], - ): + ) -> None: self.exclude_types = exclude_types self.emask = AtomExcludeMask(self.ntypes, self.exclude_types) @@ -299,7 +299,7 @@ def serialize(self) -> dict: "exclude_types": self.exclude_types, "@variables": { "bias_atom_e": to_numpy_array(self.bias_atom_e), - "case_embd": None, + "case_embd": to_numpy_array(self.case_embd), "fparam_avg": to_numpy_array(self.fparam_avg), "fparam_inv_std": to_numpy_array(self.fparam_inv_std), "aparam_avg": to_numpy_array(self.aparam_avg), @@ -321,7 +321,7 @@ def serialize(self) -> dict: @classmethod def deserialize(cls, data: dict) -> "GeneralFitting": - data = copy.deepcopy(data) + data = data.copy() variables = data.pop("@variables") nets = data.pop("nets") obj = cls(**data) @@ -364,9 +364,11 @@ def set_case_embd(self, case_idx: int): Set the case embedding of this fitting net by the given case_idx, typically concatenated with the output of the descriptor and fed into the fitting net. """ - raise NotImplementedError("set_case_embd is not supported yet in PaddlePaddle.") + self.case_embd = paddle.eye(self.dim_case_embd, dtype=self.prec).to(device)[ + case_idx + ] - def __setitem__(self, key, value): + def __setitem__(self, key, value) -> None: if key in ["bias_atom_e"]: value = value.reshape([self.ntypes, self._net_out_dim()]) self.bias_atom_e = value @@ -424,7 +426,11 @@ def _forward_common( fparam: Optional[paddle.Tensor] = None, aparam: Optional[paddle.Tensor] = None, ): - xx = descriptor + # cast the input to internal precsion + xx = descriptor.to(self.prec) + fparam = fparam.to(self.prec) if fparam is not None else None + aparam = aparam.to(self.prec) if aparam is not None else None + if self.remove_vaccum_contribution is not None: # TODO: compute the input for vaccm when remove_vaccum_contribution is set # Ideally, the input for vacuum should be computed; @@ -492,15 +498,30 @@ def _forward_common( axis=-1, ) + if self.dim_case_embd > 0: + assert self.case_embd is not None + case_embd = paddle.tile(self.case_embd.reshape([1, 1, -1]), [nf, nloc, 1]) + xx = paddle.concat( + [xx, case_embd], + axis=-1, + ) + if xx_zeros is not None: + xx_zeros = paddle.concat( + [xx_zeros, case_embd], + axis=-1, + ) + outs = paddle.zeros( (nf, nloc, net_dim_out), dtype=env.GLOBAL_PD_FLOAT_PRECISION, - ).to(device=descriptor.place) # jit assertion + ).to(device=descriptor.place) if self.mixed_types: atom_property = self.filter_layers.networks[0](xx) + self.bias_atom_e[atype] if xx_zeros is not None: atom_property -= self.filter_layers.networks[0](xx_zeros) - outs = outs + atom_property # Shape is [nframes, natoms[0], net_dim_out] + outs = ( + outs + atom_property + self.bias_atom_e[atype].to(self.prec) + ) # Shape is [nframes, natoms[0], net_dim_out] else: for type_i, ll in enumerate(self.filter_layers.networks): mask = (atype == type_i).unsqueeze(-1) @@ -516,12 +537,12 @@ def _forward_common( ): atom_property -= ll(xx_zeros) atom_property = atom_property + self.bias_atom_e[type_i] - atom_property = atom_property * mask.astype(atom_property.dtype) + atom_property = paddle.where(mask, atom_property, 0.0) outs = ( outs + atom_property ) # Shape is [nframes, natoms[0], net_dim_out] # nf x nloc - mask = self.emask(atype) + mask = self.emask(atype).to("bool") # nf x nloc x nod - outs = outs * mask[:, :, None].astype(outs.dtype) + outs = paddle.where(mask[:, :, None], outs, 0.0) return {self.var_name: outs.astype(env.GLOBAL_PD_FLOAT_PRECISION)} diff --git a/deepmd/pd/train/training.py b/deepmd/pd/train/training.py index 17d369751f..65e35a1c4b 100644 --- a/deepmd/pd/train/training.py +++ b/deepmd/pd/train/training.py @@ -3,9 +3,6 @@ import functools import logging import time -from contextlib import ( - contextmanager, -) from copy import ( deepcopy, ) @@ -53,7 +50,7 @@ ) from deepmd.pd.utils.dataloader import ( BufferedIterator, - get_weighted_sampler, + get_sampler_from_params, ) from deepmd.pd.utils.env import ( DEVICE, @@ -66,6 +63,7 @@ make_stat_input, ) from deepmd.pd.utils.utils import ( + nvprof_context, to_numpy_array, ) from deepmd.utils.data import ( @@ -87,6 +85,7 @@ def format_training_message( wall_time: float, eta: Optional[int] = None, ): + """Format a training message.""" msg = f"batch {batch:7d}: " f"total wall time = {wall_time:.2f} s" if isinstance(eta, int): msg += f", eta = {datetime.timedelta(seconds=int(eta))!s}" @@ -107,7 +106,7 @@ def __init__( shared_links=None, finetune_links=None, init_frz_model=None, - ): + ) -> None: """Construct a DeePMD trainer. Args: @@ -169,19 +168,7 @@ def get_opt_param(params): def get_data_loader(_training_data, _validation_data, _training_params): def get_dataloader_and_buffer(_data, _params): - if "auto_prob" in _training_params["training_data"]: - _sampler = get_weighted_sampler( - _data, _params["training_data"]["auto_prob"] - ) - elif "sys_probs" in _training_params["training_data"]: - _sampler = get_weighted_sampler( - _data, - _params["training_data"]["sys_probs"], - sys_prob=True, - ) - else: - _sampler = get_weighted_sampler(_data, "prob_sys_size") - + _sampler = get_sampler_from_params(_data, _params) if _sampler is None: log.warning( "Sampler not specified!" @@ -202,14 +189,16 @@ def get_dataloader_and_buffer(_data, _params): return _dataloader, _data_buffered training_dataloader, training_data_buffered = get_dataloader_and_buffer( - _training_data, _training_params + _training_data, _training_params["training_data"] ) if _validation_data is not None: ( validation_dataloader, validation_data_buffered, - ) = get_dataloader_and_buffer(_validation_data, _training_params) + ) = get_dataloader_and_buffer( + _validation_data, _training_params["validation_data"] + ) valid_numb_batch = _training_params["validation_data"].get( "numb_btch", 1 ) @@ -284,7 +273,7 @@ def get_lr(lr_params): self.opt_type, self.opt_param = get_opt_param(training_params) # Model - self.model = get_model_for_wrapper(model_params) + self.model = get_model_for_wrapper(model_params, resuming=resuming) # Loss if not self.multi_task: @@ -496,7 +485,7 @@ def collect_single_finetune_params( _new_state_dict, _origin_state_dict, _random_state_dict, - ): + ) -> None: _new_fitting = _finetune_rule_single.get_random_fitting() _model_key_from = _finetune_rule_single.get_model_branch() target_keys = [ @@ -669,10 +658,11 @@ def run(self): core.nvprof_start() core.nvprof_enable_record_event() - def step(_step_id, task_key="Default"): + def step(_step_id, task_key="Default") -> None: # Paddle Profiler if enable_profiling: core.nvprof_nvtx_push(f"Training step {_step_id}") + self.wrapper.train() if isinstance(self.lr_exp, dict): _lr = self.lr_exp[task_key] @@ -707,20 +697,17 @@ def step(_step_id, task_key="Default"): if self.gradient_max_norm > 0.0: with nvprof_context(enable_profiling, "Gradient clip"): - grad_norm = paddle.nn.utils.clip_grad_norm_( - self.wrapper.parameters(), self.gradient_max_norm + paddle.nn.utils.clip_grad_norm_( + self.wrapper.parameters(), + self.gradient_max_norm, + error_if_nonfinite=True, ) - if not paddle.isfinite(grad_norm).all(): - # check local gradnorm single GPU case, trigger NanDetector - raise FloatingPointError("gradients are Nan/Inf") with nvprof_context(enable_profiling, "Adam update"): self.optimizer.step() self.scheduler.step() - if enable_profiling: - core.nvprof_nvtx_pop() else: raise ValueError(f"Not supported optimizer type '{self.opt_type}'") @@ -729,7 +716,7 @@ def step(_step_id, task_key="Default"): if self.display_in_training and ( display_step_id % self.disp_freq == 0 or display_step_id == 1 ): - self.wrapper.eval() + self.wrapper.eval() # Will set to train mode before fininshing validation def log_loss_train(_loss, _more_loss, _task_key="Default"): results = {} @@ -835,6 +822,7 @@ def log_loss_valid(_task_key="Default"): learning_rate=None, ) ) + self.wrapper.train() current_time = time.time() train_time = current_time - self.t0 @@ -888,12 +876,16 @@ def log_loss_valid(_task_key="Default"): display_step_id % self.tensorboard_freq == 0 or display_step_id == 1 ): writer.add_scalar(f"{task_key}/lr", cur_lr, display_step_id) - writer.add_scalar(f"{task_key}/loss", loss, display_step_id) + writer.add_scalar(f"{task_key}/loss", loss.item(), display_step_id) for item in more_loss: writer.add_scalar( - f"{task_key}/{item}", more_loss[item].item(), _step_id + f"{task_key}/{item}", more_loss[item].item(), display_step_id ) + if enable_profiling: + core.nvprof_nvtx_pop() + + self.wrapper.train() self.t0 = time.time() self.total_train_time = 0.0 for step_id in range(self.num_steps): @@ -989,7 +981,7 @@ def log_loss_valid(_task_key="Default"): "files, which can be viewd in NVIDIA Nsight Systems software" ) - def save_model(self, save_path, lr=0.0, step=0): + def save_model(self, save_path, lr=0.0, step=0) -> None: module = ( self.wrapper.module if dist.is_available() and dist.is_initialized() @@ -1085,7 +1077,7 @@ def get_data(self, is_train=True, task_key="Default"): log_dict["sid"] = batch_data["sid"] return input_dict, label_dict, log_dict - def print_header(self, fout, train_results, valid_results): + def print_header(self, fout, train_results, valid_results) -> None: train_keys = sorted(train_results.keys()) print_str = "" print_str += "# {:5s}".format("step") @@ -1116,7 +1108,9 @@ def print_header(self, fout, train_results, valid_results): fout.write(print_str) fout.flush() - def print_on_training(self, fout, step_id, cur_lr, train_results, valid_results): + def print_on_training( + self, fout, step_id, cur_lr, train_results, valid_results + ) -> None: train_keys = sorted(train_results.keys()) print_str = "" print_str += f"{step_id:7d}" @@ -1191,7 +1185,7 @@ def get_single_model( return model -def get_model_for_wrapper(_model_params): +def get_model_for_wrapper(_model_params, resuming=False): if "model_dict" not in _model_params: _model = get_single_model( _model_params, @@ -1199,13 +1193,41 @@ def get_model_for_wrapper(_model_params): else: _model = {} model_keys = list(_model_params["model_dict"]) + do_case_embd, case_embd_index = get_case_embd_config(_model_params) for _model_key in model_keys: _model[_model_key] = get_single_model( _model_params["model_dict"][_model_key], ) + if do_case_embd and not resuming: + # only set case_embd when from scratch multitask training + _model[_model_key].set_case_embd(case_embd_index[_model_key]) return _model +def get_case_embd_config(_model_params): + assert ( + "model_dict" in _model_params + ), "Only support setting case embedding for multi-task model!" + model_keys = list(_model_params["model_dict"]) + sorted_model_keys = sorted(model_keys) + numb_case_embd_list = [ + _model_params["model_dict"][model_key] + .get("fitting_net", {}) + .get("dim_case_embd", 0) + for model_key in sorted_model_keys + ] + if not all(item == numb_case_embd_list[0] for item in numb_case_embd_list): + raise ValueError( + f"All models must have the same dimension of case embedding, while the settings are: {numb_case_embd_list}" + ) + if numb_case_embd_list[0] == 0: + return False, {} + case_embd_index = { + model_key: idx for idx, model_key in enumerate(sorted_model_keys) + } + return True, case_embd_index + + def model_change_out_bias( _model, _sample_func, @@ -1225,16 +1247,3 @@ def model_change_out_bias( f"to {to_numpy_array(new_bias).reshape(-1)!s}." ) return _model - - -@contextmanager -def nvprof_context(enable_profiler: bool, name: str): - if enable_profiler: - core.nvprof_nvtx_push(name) - - try: - yield - - finally: - if enable_profiler: - core.nvprof_nvtx_pop() diff --git a/deepmd/pd/train/wrapper.py b/deepmd/pd/train/wrapper.py index c3643f8372..2263a6e9b9 100644 --- a/deepmd/pd/train/wrapper.py +++ b/deepmd/pd/train/wrapper.py @@ -26,7 +26,7 @@ def __init__( loss: paddle.nn.Layer | dict = None, model_params=None, shared_links=None, - ): + ) -> None: """Construct a DeePMD model wrapper. Args: @@ -64,7 +64,7 @@ def __init__( self.loss[task_key] = loss[task_key] self.inference_only = self.loss is None - def share_params(self, shared_links, resume=False): + def share_params(self, shared_links, resume=False) -> None: """ Share the parameters of classes following rules defined in shared_links during multitask training. If not start from checkpoint (resume is False), @@ -111,8 +111,10 @@ def share_params(self, shared_links, resume=False): f"Shared params of {model_key_base}.{class_type_base} and {model_key_link}.{class_type_link}!" ) else: - if hasattr(self.model[model_key_base], class_type_base): - base_class = self.model[model_key_base].__getattr__(class_type_base) + if hasattr(self.model[model_key_base].atomic_model, class_type_base): + base_class = self.model[model_key_base].atomic_model.__getattr__( + class_type_base + ) for link_item in shared_links[shared_item]["links"][1:]: class_type_link = link_item["shared_type"] model_key_link = link_item["model_key"] @@ -123,9 +125,9 @@ def share_params(self, shared_links, resume=False): assert ( class_type_base == class_type_link ), f"Class type mismatched: {class_type_base} vs {class_type_link}!" - link_class = self.model[model_key_link].__getattr__( - class_type_link - ) + link_class = self.model[ + model_key_link + ].atomic_model.__getattr__(class_type_link) link_class.share_params( base_class, shared_level_link, resume=resume ) diff --git a/deepmd/pd/utils/dataloader.py b/deepmd/pd/utils/dataloader.py index 7a2bf4fe9c..9d59ea0da7 100644 --- a/deepmd/pd/utils/dataloader.py +++ b/deepmd/pd/utils/dataloader.py @@ -183,6 +183,7 @@ def __next__(self): return next(self.item) self.iters = [] + for item in self.dataloaders: self.iters.append(LazyIter(item)) @@ -196,7 +197,7 @@ def set_noise(self, noise_settings): for system in self.systems: system.set_noise(noise_settings) - def __len__(self): + def __len__(self) -> int: return len(self.dataloaders) def __getitem__(self, idx): @@ -219,19 +220,21 @@ def print_summary( name: str, prob: list[float], ): - print_summary( - name, - len(self.systems), - [ss.system for ss in self.systems], - [ss._natoms for ss in self.systems], - self.batch_sizes, - [ - ss._data_system.get_sys_numb_batch(self.batch_sizes[ii]) - for ii, ss in enumerate(self.systems) - ], - prob, - [ss._data_system.pbc for ss in self.systems], - ) + rank = dist.get_rank() if dist.is_initialized() else 0 + if rank == 0: + print_summary( + name, + len(self.systems), + [ss.system for ss in self.systems], + [ss._natoms for ss in self.systems], + self.batch_sizes, + [ + ss._data_system.get_sys_numb_batch(self.batch_sizes[ii]) + for ii, ss in enumerate(self.systems) + ], + prob, + [ss._data_system.pbc for ss in self.systems], + ) _sentinel = object() @@ -239,13 +242,13 @@ def print_summary( class BackgroundConsumer(Thread): - def __init__(self, queue, source, max_len): + def __init__(self, queue, source, max_len) -> None: Thread.__init__(self) self._queue = queue self._source = source # Main DL iterator self._max_len = max_len # - def run(self): + def run(self) -> None: for item in self._source: self._queue.put(item) # Blocking if the queue is full @@ -254,7 +257,7 @@ def run(self): class BufferedIterator: - def __init__(self, iterable): + def __init__(self, iterable) -> None: self._queue = queue.Queue(QUEUESIZE) self._iterable = iterable self._consumer = None @@ -263,7 +266,7 @@ def __init__(self, iterable): self.warning_time = None self.total = len(iterable) - def _create_consumer(self): + def _create_consumer(self) -> None: self._consumer = BackgroundConsumer(self._queue, self._iterable, self.total) self._consumer.daemon = True self._consumer.start() @@ -271,7 +274,7 @@ def _create_consumer(self): def __iter__(self): return self - def __len__(self): + def __len__(self) -> int: return self.total def __next__(self): @@ -337,3 +340,19 @@ def get_weighted_sampler(training_data, prob_style, sys_prob=False): len_sampler = training_data.total_batch * max(env.NUM_WORKERS, 1) sampler = WeightedRandomSampler(probs, len_sampler, replacement=True) return sampler + + +def get_sampler_from_params(_data, _params): + if ( + "sys_probs" in _params and _params["sys_probs"] is not None + ): # use sys_probs first + _sampler = get_weighted_sampler( + _data, + _params["sys_probs"], + sys_prob=True, + ) + elif "auto_prob" in _params: + _sampler = get_weighted_sampler(_data, _params["auto_prob"]) + else: + _sampler = get_weighted_sampler(_data, "prob_sys_size") + return _sampler diff --git a/deepmd/pd/utils/decomp.py b/deepmd/pd/utils/decomp.py index 272c2deacb..3b7bddbcd1 100644 --- a/deepmd/pd/utils/decomp.py +++ b/deepmd/pd/utils/decomp.py @@ -10,100 +10,17 @@ annotations, ) +import numpy as np import paddle __all__ = [ "masked_add_", - "norm", + "numel", "scatter_reduce", "sec", - "softmax", - "take_along_axis", ] -# decomposition for forward function -def softmax_decomp(x: paddle.Tensor, axis: int = -1) -> paddle.Tensor: - """Forward decompsition function of softmax. - - Parameters - ---------- - x : paddle.Tensor - Input. - axis : int, defaults: -1. - A dimension along which softmax will be computed. - - Returns - ------- - paddle.Tensor - Computed output. - """ - x_max = paddle.max(x, axis=axis, keepdim=True) - x = x - x_max - return paddle.exp(x) / paddle.sum(paddle.exp(x), axis=axis, keepdim=True) - - -def norm_decomp( - x: paddle.Tensor, p: float = 2, axis: bool = -1, keepdim: bool = False -) -> paddle.Tensor: - """Forward decompsition function of norm. - - Parameters - ---------- - x : paddle.Tensor - Input - p : float, default: 2 - Order of norm - axis : bool, default: -1 - Dimensions over which to compute the vector or matrix norm - keepdim : bool, default: False - If set to True, the reduced dimensions are retained in the result as dimensions - with size one - - Returns - ------- - paddle.Tensor - A real-valued tensor, even when A is complex. - """ - if p == 2 or p == 2.0: - # clip for negative indexing, or 1/(0^(k-1)) will cause inf in backward - return (x * x).sum(axis=axis, keepdim=keepdim) ** 0.5 - return (x.abs() ** p).sum(axis=axis, keepdim=keepdim) ** (1 / p) - - -def take_along_axis_decomp( - x: paddle.Tensor, indices: paddle.Tensor, axis: int, broadcast: bool = True -) -> paddle.Tensor: - """Forward decompsition function of take_along_axis. - - Parameters - ---------- - x : paddle.Tensor - The input tensor. - indices : paddle.Tensor - Indices to take along each 1d slice of array. - axis : int - The axis to take 1d slices along. - broadcast : bool, default: True - Whether the indices broadcast. - - Returns - ------- - paddle.Tensor - Computed output. - """ - # manually contruct indices for gather_nd(ind_gather_nd.ndim == indices.ndim + 1, - # the lsat 1 represents the number of dimension(s) of indices) - ind_gather_nd = paddle.stack( - paddle.meshgrid(*[paddle.arange(v) for v in indices.shape], indexing="ij"), - axis=-1, - ) - ind_gather_nd[..., axis] = indices - # compute output using constructed indices via gather_nd - out = paddle.gather_nd(x, ind_gather_nd) - return out - - def scatter_reduce_decomp( input: paddle.Tensor, axis: int, @@ -210,38 +127,13 @@ def masked_add__decomp( return x -def normalize_decomp( - x: paddle.Tensor, - p: float = 2, - axis: int = 1, - epsilon: float = 1e-12, -) -> paddle.Tensor: - """Forward decompsition function of normalize. - - Parameters - ---------- - x : paddle.Tensor - Input tensor. - p : float, optional - Order of the norm, default: 2 - axis : int, optional - Axis on which to perform normalization, default: 1 - epsilon : float, optional - Epislon value, default: 1e-12 +def numel(x: paddle.Tensor) -> int: + if paddle.in_dynamic_mode(): + return np.prod(x.shape) - Returns - ------- - paddle.Tensor - Computed output. - """ - return paddle.nn.functional.normalize(x, p, axis, epsilon) - # return x / norm(x, p=p, axis=axis, keepdim=True) + return paddle.numel(x) # alias for decomposed functions for convinience -normalize = normalize_decomp masked_add_ = masked_add__decomp scatter_reduce = scatter_reduce_decomp -take_along_axis = take_along_axis_decomp -norm = norm_decomp -softmax = softmax_decomp diff --git a/deepmd/pd/utils/env.py b/deepmd/pd/utils/env.py index 6dbdc69f30..041c231282 100644 --- a/deepmd/pd/utils/env.py +++ b/deepmd/pd/utils/env.py @@ -77,13 +77,75 @@ def enable_prim(enable: bool = True): + # NOTE: operator in list below will not use composite + # operator but kernel instead + EAGER_COMP_OP_BLACK_LIST = [ + "abs_grad", + "cast_grad", + # "concat_grad", + "cos_double_grad", + "cos_grad", + "cumprod_grad", + "cumsum_grad", + "dropout_grad", + "erf_grad", + "exp_grad", + "expand_grad", + "floor_grad", + "gather_grad", + "gather_nd_grad", + "gelu_grad", + "group_norm_grad", + "instance_norm_grad", + "layer_norm_grad", + "leaky_relu_grad", + "log_grad", + "max_grad", + "pad_grad", + "pow_double_grad", + "pow_grad", + "prod_grad", + "relu_grad", + "roll_grad", + "rsqrt_grad", + "scatter_grad", + "scatter_nd_add_grad", + "sigmoid_grad", + "silu_grad", + "sin_double_grad", + "sin_grad", + "slice_grad", + # "split_grad", + "sqrt_grad", + "stack_grad", + "sum_grad", + "tanh_double_grad", + "tanh_grad", + "topk_grad", + "transpose_grad", + "add_double_grad", + "add_grad", + "assign_grad", + "batch_norm_grad", + "divide_grad", + "elementwise_pow_grad", + "maximum_grad", + "min_grad", + "minimum_grad", + "multiply_grad", + "subtract_grad", + "tile_grad", + ] + EAGER_COMP_OP_BLACK_LIST = list(set(EAGER_COMP_OP_BLACK_LIST)) + """Enable running program in primitive C++ API in eager/static mode.""" from paddle.framework import ( core, ) core.set_prim_eager_enabled(enable) - core._set_prim_all_enabled(enable) + if enable: + paddle.framework.core._set_prim_backward_blacklist(*EAGER_COMP_OP_BLACK_LIST) log = logging.getLogger(__name__) log.info(f"{'Enable' if enable else 'Disable'} prim in eager and static mode.") diff --git a/deepmd/pd/utils/exclude_mask.py b/deepmd/pd/utils/exclude_mask.py index 088ac186a8..29c9cc3501 100644 --- a/deepmd/pd/utils/exclude_mask.py +++ b/deepmd/pd/utils/exclude_mask.py @@ -3,9 +3,6 @@ import numpy as np import paddle -from deepmd.pd.utils import ( - decomp, -) from deepmd.pd.utils.utils import ( to_paddle_tensor, ) @@ -18,7 +15,7 @@ def __init__( self, ntypes: int, exclude_types: list[int] = [], - ): + ) -> None: super().__init__() self.reinit(ntypes, exclude_types) @@ -26,7 +23,7 @@ def reinit( self, ntypes: int, exclude_types: list[int] = [], - ): + ) -> None: self.ntypes = ntypes self.exclude_types = exclude_types self.type_mask = np.array( @@ -71,7 +68,7 @@ def __init__( self, ntypes: int, exclude_types: list[tuple[int, int]] = [], - ): + ) -> None: super().__init__() self.reinit(ntypes, exclude_types) @@ -79,7 +76,7 @@ def reinit( self, ntypes: int, exclude_types: list[tuple[int, int]] = [], - ): + ) -> None: self.ntypes = ntypes self._exclude_types: set[tuple[int, int]] = set() for tt in exclude_types: @@ -137,19 +134,14 @@ def forward( [ atype_ext, self.ntypes - * paddle.ones([nf, 1], dtype=atype_ext.dtype).to( - device=atype_ext.place - ), + * paddle.ones([nf, 1], dtype=atype_ext.dtype).to(atype_ext.place), ], axis=-1, ) type_i = atype_ext[:, :nloc].reshape([nf, nloc]) * (self.ntypes + 1) # nf x nloc x nnei index = paddle.where(nlist == -1, nall, nlist).reshape([nf, nloc * nnei]) - # type_j = paddle.take_along_axis(ae, axis=1, indices=index).reshape( - # [nf, nloc, nnei] - # ) - type_j = decomp.take_along_axis(ae, axis=1, indices=index).reshape( + type_j = paddle.take_along_axis(ae, axis=1, indices=index).reshape( [nf, nloc, nnei] ) type_ij = type_i[:, :, None] + type_j diff --git a/deepmd/pd/utils/nlist.py b/deepmd/pd/utils/nlist.py index 44924ce07d..ae9db628a1 100644 --- a/deepmd/pd/utils/nlist.py +++ b/deepmd/pd/utils/nlist.py @@ -7,7 +7,6 @@ import paddle from deepmd.pd.utils import ( - decomp, env, ) from deepmd.pd.utils.region import ( @@ -118,8 +117,7 @@ def build_neighbor_list( if paddle.in_dynamic_mode(): assert list(diff.shape) == [batch_size, nloc, nall, 3] # nloc x nall - # rr = paddle.linalg.norm(diff, axis=-1) - rr = decomp.norm(diff, axis=-1) + rr = paddle.linalg.norm(diff, axis=-1) # if central atom has two zero distances, sorting sometimes can not exclude itself rr = rr - paddle.eye(nloc, nall, dtype=rr.dtype).to(device=rr.place).unsqueeze(0) rr, nlist = paddle.sort(rr, axis=-1), paddle.argsort(rr, axis=-1) @@ -267,8 +265,7 @@ def build_directional_neighbor_list( if paddle.in_dynamic_mode(): assert list(diff.shape) == [batch_size, nloc_cntl, nall_neig, 3] # nloc x nall - # rr = paddle.linalg.norm(diff, axis=-1) - rr = decomp.norm(diff, axis=-1) + rr = paddle.linalg.norm(diff, axis=-1) rr, nlist = paddle.sort(rr, axis=-1), paddle.argsort(rr, axis=-1) # We assume that the central and neighbor atoms are diffferent, @@ -300,12 +297,7 @@ def nlist_distinguish_types( tmp_atype = paddle.tile(atype.unsqueeze(1), [1, nloc, 1]) mask = nlist == -1 # nloc x s(nsel) - # tnlist = paddle.take_along_axis( - # tmp_atype, - # axis=2, - # indices=nlist.masked_fill(mask, 0), - # ) - tnlist = decomp.take_along_axis( + tnlist = paddle.take_along_axis( tmp_atype, axis=2, indices=nlist.masked_fill(mask, 0), @@ -322,8 +314,7 @@ def nlist_distinguish_types( paddle.argsort(pick_mask, axis=-1, descending=True, stable=True), ) # nloc x s(nsel) - # inlist = paddle.take_along_axis(nlist, axis=2, indices=imap) - inlist = decomp.take_along_axis(nlist, axis=2, indices=imap) + inlist = paddle.take_along_axis(nlist, axis=2, indices=imap) inlist = inlist.masked_fill(~(pick_mask.to(paddle.bool)), -1) # nloc x nsel[ii] ret_nlist.append(paddle.split(inlist, [ss, snsel - ss], axis=-1)[0]) @@ -404,17 +395,13 @@ def build_multiple_neighbor_list( .expand([-1, -1, 3]) ) # nb x nloc x nsel x 3 - # coord2 = paddle.take_along_axis(coord1, axis=1, index=index).reshape( - # [nb, nloc, nsel, 3] - # ) - coord2 = decomp.take_along_axis(coord1, axis=1, indices=index).reshape( + coord2 = paddle.take_along_axis(coord1, axis=1, indices=index).reshape( [nb, nloc, nsel, 3] ) # nb x nloc x nsel x 3 diff = coord2 - coord0[:, :, None, :] # nb x nloc x nsel - # rr = paddle.linalg.norm(diff, axis=-1) - rr = decomp.norm(diff, axis=-1) + rr = paddle.linalg.norm(diff, axis=-1) rr.masked_fill(nlist_mask, float("inf")) nlist0 = nlist ret = {} @@ -516,8 +503,7 @@ def extend_coord_with_ghosts( xyz = xyz.reshape([-1, 3]) # xyz = xyz.to(device=device) # ns x 3 - # shift_idx = xyz[paddle.argsort(paddle.norm(xyz, axis=1))] - shift_idx = xyz[paddle.argsort(decomp.norm(xyz, axis=1))] + shift_idx = xyz[paddle.argsort(paddle.norm(xyz, axis=1))] ns, _ = shift_idx.shape nall = ns * nloc # nf x ns x 3 diff --git a/deepmd/pd/utils/region.py b/deepmd/pd/utils/region.py index 21927e3619..f3e3eaa52d 100644 --- a/deepmd/pd/utils/region.py +++ b/deepmd/pd/utils/region.py @@ -1,10 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import paddle -from deepmd.pd.utils import ( - decomp, -) - def phys2inter( coord: paddle.Tensor, @@ -82,14 +78,11 @@ def to_face_distance( def b_to_face_distance(cell): volume = paddle.linalg.det(cell) c_yz = paddle.cross(cell[:, 1], cell[:, 2], axis=-1) - # _h2yz = volume / paddle.linalg.norm(c_yz, axis=-1) - _h2yz = volume / decomp.norm(c_yz, axis=-1) + _h2yz = volume / paddle.linalg.norm(c_yz, axis=-1) c_zx = paddle.cross(cell[:, 2], cell[:, 0], axis=-1) - # _h2zx = volume / paddle.linalg.norm(c_zx, axis=-1) - _h2zx = volume / decomp.norm(c_zx, axis=-1) + _h2zx = volume / paddle.linalg.norm(c_zx, axis=-1) c_xy = paddle.cross(cell[:, 0], cell[:, 1], axis=-1) - # _h2xy = volume / paddle.linalg.norm(c_xy, axis=-1) - _h2xy = volume / decomp.norm(c_xy, axis=-1) + _h2xy = volume / paddle.linalg.norm(c_xy, axis=-1) return paddle.stack([_h2yz, _h2zx, _h2xy], axis=1) diff --git a/deepmd/pd/utils/utils.py b/deepmd/pd/utils/utils.py index 48732ff84e..87072eb3cd 100644 --- a/deepmd/pd/utils/utils.py +++ b/deepmd/pd/utils/utils.py @@ -3,6 +3,9 @@ annotations, ) +from contextlib import ( + contextmanager, +) from typing import ( TYPE_CHECKING, overload, @@ -12,6 +15,9 @@ import numpy as np import paddle import paddle.nn.functional as F +from paddle.framework import ( + core, +) from deepmd.dpmodel.common import PRECISION_DICT as NP_PRECISION_DICT @@ -177,3 +183,16 @@ def get_generator( return generator else: return None + + +@contextmanager +def nvprof_context(enable_profiler: bool, name: str): + if enable_profiler: + core.nvprof_nvtx_push(name) + + try: + yield + + finally: + if enable_profiler: + core.nvprof_nvtx_pop() diff --git a/source/tests/consistent/descriptor/test_dpa1.py b/source/tests/consistent/descriptor/test_dpa1.py index 8be219f5ea..92b2c6bd0b 100644 --- a/source/tests/consistent/descriptor/test_dpa1.py +++ b/source/tests/consistent/descriptor/test_dpa1.py @@ -18,6 +18,7 @@ from ..common import ( INSTALLED_ARRAY_API_STRICT, INSTALLED_JAX, + INSTALLED_PD, INSTALLED_PT, INSTALLED_TF, CommonTest, @@ -39,6 +40,10 @@ from deepmd.jax.descriptor.dpa1 import DescrptDPA1 as DescriptorDPA1JAX else: DescriptorDPA1JAX = None +if INSTALLED_PD: + from deepmd.pd.model.descriptor.dpa1 import DescrptDPA1 as DescrptDPA1PD +else: + DescrptDPA1PD = None if INSTALLED_ARRAY_API_STRICT: from ...array_api_strict.descriptor.dpa1 import DescrptDPA1 as DescriptorDPA1Strict else: @@ -187,6 +192,34 @@ def skip_dp(self) -> bool: temperature, ) + @property + def skip_pd(self) -> bool: + ( + tebd_dim, + tebd_input_mode, + resnet_dt, + type_one_side, + attn, + attn_layer, + attn_dotr, + excluded_types, + env_protection, + set_davg_zero, + scaling_factor, + normalize, + temperature, + ln_eps, + smooth_type_embedding, + concat_output_tebd, + precision, + use_econf_tebd, + use_tebd_bias, + ) = self.param + return not INSTALLED_PD or self.is_meaningless_zero_attention_layer_tests( + attn_layer, + temperature, + ) + @property def skip_jax(self) -> bool: ( @@ -287,6 +320,7 @@ def skip_tf(self) -> bool: tf_class = DescrptDPA1TF dp_class = DescrptDPA1DP pt_class = DescrptDPA1PT + pd_class = DescrptDPA1PD jax_class = DescriptorDPA1JAX array_api_strict_class = DescriptorDPA1Strict @@ -387,6 +421,16 @@ def eval_jax(self, jax_obj: Any) -> Any: mixed_types=True, ) + def eval_pd(self, pd_obj: Any) -> Any: + return self.eval_pd_descriptor( + pd_obj, + self.natoms, + self.coords, + self.atype, + self.box, + mixed_types=True, + ) + def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: return self.eval_array_api_strict_descriptor( array_api_strict_obj, diff --git a/source/tests/consistent/model/test_dpa1.py b/source/tests/consistent/model/test_dpa1.py index 774c624ac7..8b8fab7ae1 100644 --- a/source/tests/consistent/model/test_dpa1.py +++ b/source/tests/consistent/model/test_dpa1.py @@ -14,6 +14,7 @@ from ..common import ( INSTALLED_JAX, + INSTALLED_PD, INSTALLED_PT, INSTALLED_TF, SKIP_FLAG, @@ -37,6 +38,11 @@ model_args, ) +if INSTALLED_PD: + from deepmd.pd.model.model import get_model as get_model_pd + from deepmd.pd.model.model.ener_model import EnergyModel as EnergyModelPD +else: + EnergyModelPD = None if INSTALLED_JAX: from deepmd.jax.model.ener_model import EnergyModel as EnergyModelJAX from deepmd.jax.model.model import get_model as get_model_jax @@ -90,6 +96,7 @@ def data(self) -> dict: tf_class = EnergyModelTF dp_class = EnergyModelDP pt_class = EnergyModelPT + pd_class = EnergyModelPD jax_class = EnergyModelJAX args = model_args() @@ -102,6 +109,8 @@ def get_reference_backend(self): return self.RefBackend.PT if not self.skip_tf: return self.RefBackend.TF + if not self.skip_pd: + return self.RefBackend.PD if not self.skip_jax: return self.RefBackend.JAX if not self.skip_dp: @@ -119,6 +128,8 @@ def pass_data_to_cls(self, cls, data) -> Any: return get_model_dp(data) elif cls is EnergyModelPT: return get_model_pt(data) + elif cls is EnergyModelPD: + return get_model_pd(data) elif cls is EnergyModelJAX: return get_model_jax(data) return cls(**data, **self.additional_data) @@ -190,6 +201,15 @@ def eval_pt(self, pt_obj: Any) -> Any: self.box, ) + def eval_pd(self, pd_obj: Any) -> Any: + return self.eval_pd_model( + pd_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + def eval_jax(self, jax_obj: Any) -> Any: return self.eval_jax_model( jax_obj, @@ -225,6 +245,14 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: ret[3].ravel(), ret[4].ravel(), ) + elif backend is self.RefBackend.PD: + return ( + ret["energy"].flatten(), + ret["atom_energy"].flatten(), + ret["force"].flatten(), + ret["virial"].flatten(), + ret["atom_virial"].flatten(), + ) elif backend is self.RefBackend.JAX: return ( ret["energy_redu"].ravel(), diff --git a/source/tests/pd/common.py b/source/tests/pd/common.py index 59a9672330..d73544c5f1 100644 --- a/source/tests/pd/common.py +++ b/source/tests/pd/common.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import pathlib from typing import ( Optional, Union, @@ -7,6 +8,7 @@ import numpy as np import paddle +from deepmd.common import j_loader as dp_j_loader from deepmd.main import ( main, ) @@ -15,6 +17,12 @@ GLOBAL_PD_FLOAT_PRECISION, ) +tests_path = pathlib.Path(__file__).parent.absolute() + + +def j_loader(filename): + return dp_j_loader(tests_path / filename) + def run_dp(cmd: str) -> int: """Run DP directly from the entry point instead of the subprocess. diff --git a/source/tests/pd/model/models/dpa1.json b/source/tests/pd/model/models/dpa1.json new file mode 100644 index 0000000000..a969c290ae --- /dev/null +++ b/source/tests/pd/model/models/dpa1.json @@ -0,0 +1,36 @@ +{ + "type_map": [ + "O", + "H" + ], + "descriptor": { + "type": "se_atten", + "sel": 30, + "rcut_smth": 2.0, + "rcut": 6.0, + "neuron": [ + 2, + 4, + 8 + ], + "axis_neuron": 4, + "attn": 5, + "attn_layer": 2, + "attn_dotr": true, + "attn_mask": false, + "activation_function": "tanh", + "scaling_factor": 1.0, + "normalize": true, + "temperature": 1.0, + "seed": 1 + }, + "fitting_net": { + "neuron": [ + 240, + 240, + 240 + ], + "resnet_dt": true, + "seed": 1 + } +} diff --git a/source/tests/pd/model/models/dpa1.pd b/source/tests/pd/model/models/dpa1.pd new file mode 100644 index 0000000000..147312635c Binary files /dev/null and b/source/tests/pd/model/models/dpa1.pd differ diff --git a/source/tests/pd/model/models/dpa2_tebd.pd b/source/tests/pd/model/models/dpa2_tebd.pd new file mode 100644 index 0000000000..2d3149d9a4 Binary files /dev/null and b/source/tests/pd/model/models/dpa2_tebd.pd differ diff --git a/source/tests/pd/model/test_atomic_model_atomic_stat.py b/source/tests/pd/model/test_atomic_model_atomic_stat.py new file mode 100644 index 0000000000..93aa7b8905 --- /dev/null +++ b/source/tests/pd/model/test_atomic_model_atomic_stat.py @@ -0,0 +1,431 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import tempfile +import unittest +from pathlib import ( + Path, +) +from typing import ( + Optional, +) + +import h5py +import numpy as np +import paddle + +from deepmd.dpmodel.output_def import ( + FittingOutputDef, + OutputVariableDef, +) +from deepmd.pd.model.atomic_model import ( + BaseAtomicModel, + DPAtomicModel, +) +from deepmd.pd.model.descriptor.dpa1 import ( + DescrptDPA1, +) +from deepmd.pd.model.task.base_fitting import ( + BaseFitting, +) +from deepmd.pd.utils import ( + env, +) +from deepmd.pd.utils.utils import ( + to_numpy_array, + to_paddle_tensor, +) +from deepmd.utils.path import ( + DPPath, +) + +from .test_env_mat import ( + TestCaseSingleFrameWithNlist, +) + +dtype = env.GLOBAL_PD_FLOAT_PRECISION + + +class FooFitting(paddle.nn.Layer, BaseFitting): + def output_def(self): + return FittingOutputDef( + [ + OutputVariableDef( + "foo", + [1], + reducible=True, + r_differentiable=True, + c_differentiable=True, + ), + OutputVariableDef( + "bar", + [1, 2], + reducible=True, + r_differentiable=True, + c_differentiable=True, + ), + ] + ) + + def serialize(self) -> dict: + raise NotImplementedError + + def change_type_map( + self, type_map: list[str], model_with_new_type_stat=None + ) -> None: + raise NotImplementedError + + def get_type_map(self) -> list[str]: + raise NotImplementedError + + def forward( + self, + descriptor: paddle.Tensor, + atype: paddle.Tensor, + gr: Optional[paddle.Tensor] = None, + g2: Optional[paddle.Tensor] = None, + h2: Optional[paddle.Tensor] = None, + fparam: Optional[paddle.Tensor] = None, + aparam: Optional[paddle.Tensor] = None, + ): + nf, nloc, _ = descriptor.shape + ret = {} + ret["foo"] = ( + paddle.to_tensor( + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + ] + ) + .reshape([nf, nloc, *self.output_def()["foo"].shape]) + .to(env.GLOBAL_PD_FLOAT_PRECISION) + .to(env.DEVICE) + ) + ret["bar"] = ( + paddle.to_tensor( + [ + [1.0, 2.0, 3.0, 7.0, 8.0, 9.0], + [4.0, 5.0, 6.0, 10.0, 11.0, 12.0], + ] + ) + .reshape([nf, nloc, *self.output_def()["bar"].shape]) + .to(env.GLOBAL_PD_FLOAT_PRECISION) + .to(env.DEVICE) + ) + return ret + + +class TestAtomicModelStat(unittest.TestCase, TestCaseSingleFrameWithNlist): + def tearDown(self): + self.tempdir.cleanup() + + def setUp(self): + TestCaseSingleFrameWithNlist.setUp(self) + self.merged_output_stat = [ + { + "coord": to_paddle_tensor(np.zeros([2, 3, 3])), + "atype": to_paddle_tensor( + np.array([[0, 0, 1], [0, 1, 1]], dtype=np.int32) + ), + "atype_ext": to_paddle_tensor( + np.array([[0, 0, 1, 0], [0, 1, 1, 0]], dtype=np.int32) + ), + "box": to_paddle_tensor(np.zeros([2, 3, 3])), + "natoms": to_paddle_tensor( + np.array([[3, 3, 2, 1], [3, 3, 1, 2]], dtype=np.int32) + ), + # bias of foo: 5, 6 + "atom_foo": to_paddle_tensor( + np.array([[5.0, 5.0, 5.0], [5.0, 6.0, 7.0]]).reshape(2, 3, 1) + ), + # bias of bar: [1, 5], [3, 2] + "bar": to_paddle_tensor( + np.array([5.0, 12.0, 7.0, 9.0]).reshape(2, 1, 2) + ), + "find_atom_foo": np.float32(1.0), + "find_bar": np.float32(1.0), + }, + { + "coord": to_paddle_tensor(np.zeros([2, 3, 3])), + "atype": to_paddle_tensor( + np.array([[0, 0, 1], [0, 1, 1]], dtype=np.int32) + ), + "atype_ext": to_paddle_tensor( + np.array([[0, 0, 1, 0], [0, 1, 1, 0]], dtype=np.int32) + ), + "box": to_paddle_tensor(np.zeros([2, 3, 3])), + "natoms": to_paddle_tensor( + np.array([[3, 3, 2, 1], [3, 3, 1, 2]], dtype=np.int32) + ), + # bias of foo: 5, 6 from atomic label. + "foo": to_paddle_tensor(np.array([5.0, 7.0]).reshape(2, 1)), + # bias of bar: [1, 5], [3, 2] + "bar": to_paddle_tensor( + np.array([5.0, 12.0, 7.0, 9.0]).reshape(2, 1, 2) + ), + "find_foo": np.float32(1.0), + "find_bar": np.float32(1.0), + }, + ] + self.tempdir = tempfile.TemporaryDirectory() + h5file = str((Path(self.tempdir.name) / "testcase.h5").resolve()) + with h5py.File(h5file, "w") as f: + pass + self.stat_file_path = DPPath(h5file, "a") + + def test_output_stat(self): + nf, nloc, nnei = self.nlist.shape + ds = DescrptDPA1( + self.rcut, + self.rcut_smth, + sum(self.sel), + self.nt, + ).to(env.DEVICE) + ft = FooFitting().to(env.DEVICE) + type_map = ["foo", "bar"] + md0 = DPAtomicModel( + ds, + ft, + type_map=type_map, + ).to(env.DEVICE) + args = [ + to_paddle_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist] + ] + # nf x nloc + at = self.atype_ext[:, :nloc] + + def cvt_ret(x): + return {kk: to_numpy_array(vv) for kk, vv in x.items()} + + # 1. test run without bias + # nf x na x odim + ret0 = md0.forward_common_atomic(*args) + ret0 = cvt_ret(ret0) + expected_ret0 = {} + expected_ret0["foo"] = np.array( + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + ] + ).reshape([nf, nloc, *md0.fitting_output_def()["foo"].shape]) + expected_ret0["bar"] = np.array( + [ + [1.0, 2.0, 3.0, 7.0, 8.0, 9.0], + [4.0, 5.0, 6.0, 10.0, 11.0, 12.0], + ] + ).reshape([nf, nloc, *md0.fitting_output_def()["bar"].shape]) + for kk in ["foo", "bar"]: + np.testing.assert_almost_equal(ret0[kk], expected_ret0[kk]) + + # 2. test bias is applied + md0.compute_or_load_out_stat( + self.merged_output_stat, stat_file_path=self.stat_file_path + ) + ret1 = md0.forward_common_atomic(*args) + expected_std = np.ones( + (2, 2, 2), dtype=np.float64 + ) # 2 keys, 2 atypes, 2 max dims. + expected_std[0, :, :1] = np.array([0.0, 0.816496]).reshape( + 2, 1 + ) # updating std for foo based on [5.0, 5.0, 5.0], [5.0, 6.0, 7.0]] + np.testing.assert_almost_equal( + to_numpy_array(md0.out_std), expected_std, decimal=4 + ) + ret1 = cvt_ret(ret1) + # nt x odim + foo_bias = np.array([5.0, 6.0]).reshape(2, 1) + bar_bias = np.array([1.0, 5.0, 3.0, 2.0]).reshape(2, 1, 2) + expected_ret1 = {} + expected_ret1["foo"] = ret0["foo"] + foo_bias[at] + expected_ret1["bar"] = ret0["bar"] + bar_bias[at] + for kk in ["foo", "bar"]: + np.testing.assert_almost_equal(ret1[kk], expected_ret1[kk]) + + # 3. test bias load from file + def raise_error(): + raise RuntimeError + + md0.compute_or_load_out_stat(raise_error, stat_file_path=self.stat_file_path) + ret2 = md0.forward_common_atomic(*args) + ret2 = cvt_ret(ret2) + for kk in ["foo", "bar"]: + np.testing.assert_almost_equal(ret1[kk], ret2[kk]) + np.testing.assert_almost_equal( + to_numpy_array(md0.out_std), expected_std, decimal=4 + ) + + # 4. test change bias + BaseAtomicModel.change_out_bias( + md0, self.merged_output_stat, bias_adjust_mode="change-by-statistic" + ) + args = [ + to_paddle_tensor(ii) + for ii in [ + self.coord_ext, + to_numpy_array(self.merged_output_stat[0]["atype_ext"]), + self.nlist, + ] + ] + ret3 = md0.forward_common_atomic(*args) + ret3 = cvt_ret(ret3) + expected_std[0, :, :1] = np.array([1.24722, 0.47140]).reshape( + 2, 1 + ) # updating std for foo based on [4.0, 3.0, 2.0], [1.0, 1.0, 1.0]] + expected_ret3 = {} + # new bias [2.666, 1.333] + expected_ret3["foo"] = np.array( + [[3.6667, 4.6667, 4.3333], [6.6667, 6.3333, 7.3333]] + ).reshape(2, 3, 1) + for kk in ["foo"]: + np.testing.assert_almost_equal(ret3[kk], expected_ret3[kk], decimal=4) + np.testing.assert_almost_equal( + to_numpy_array(md0.out_std), expected_std, decimal=4 + ) + + +class TestAtomicModelStatMergeGlobalAtomic( + unittest.TestCase, TestCaseSingleFrameWithNlist +): + def tearDown(self): + self.tempdir.cleanup() + + def setUp(self): + TestCaseSingleFrameWithNlist.setUp(self) + self.merged_output_stat = [ + { + "coord": to_paddle_tensor(np.zeros([2, 3, 3])), + "atype": to_paddle_tensor( + np.array([[0, 0, 0], [0, 0, 0]], dtype=np.int32) + ), + "atype_ext": to_paddle_tensor( + np.array([[0, 0, 1, 0], [0, 1, 1, 0]], dtype=np.int32) + ), + "box": to_paddle_tensor(np.zeros([2, 3, 3])), + "natoms": to_paddle_tensor( + np.array([[3, 3, 2, 1], [3, 3, 1, 2]], dtype=np.int32) + ), + # bias of foo: 5.5, nan + "atom_foo": to_paddle_tensor( + np.array([[5.0, 5.0, 5.0], [5.0, 6.0, 7.0]]).reshape(2, 3, 1) + ), + # bias of bar: [1, 5], [3, 2] + "bar": to_paddle_tensor( + np.array([5.0, 12.0, 7.0, 9.0]).reshape(2, 1, 2) + ), + "find_atom_foo": np.float32(1.0), + "find_bar": np.float32(1.0), + }, + { + "coord": to_paddle_tensor(np.zeros([2, 3, 3])), + "atype": to_paddle_tensor( + np.array([[0, 0, 1], [0, 1, 1]], dtype=np.int32) + ), + "atype_ext": to_paddle_tensor( + np.array([[0, 0, 1, 0], [0, 1, 1, 0]], dtype=np.int32) + ), + "box": to_paddle_tensor(np.zeros([2, 3, 3])), + "natoms": to_paddle_tensor( + np.array([[3, 3, 2, 1], [3, 3, 1, 2]], dtype=np.int32) + ), + # bias of foo: 5.5, 3 from atomic label. + "foo": to_paddle_tensor(np.array([5.0, 7.0]).reshape(2, 1)), + # bias of bar: [1, 5], [3, 2] + "bar": to_paddle_tensor( + np.array([5.0, 12.0, 7.0, 9.0]).reshape(2, 1, 2) + ), + "find_foo": np.float32(1.0), + "find_bar": np.float32(1.0), + }, + ] + self.tempdir = tempfile.TemporaryDirectory() + h5file = str((Path(self.tempdir.name) / "testcase.h5").resolve()) + with h5py.File(h5file, "w") as f: + pass + self.stat_file_path = DPPath(h5file, "a") + + def test_output_stat(self): + nf, nloc, nnei = self.nlist.shape + ds = DescrptDPA1( + self.rcut, + self.rcut_smth, + sum(self.sel), + self.nt, + ).to(env.DEVICE) + ft = FooFitting().to(env.DEVICE) + type_map = ["foo", "bar"] + md0 = DPAtomicModel( + ds, + ft, + type_map=type_map, + ).to(env.DEVICE) + args = [ + to_paddle_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist] + ] + # nf x nloc + at = self.atype_ext[:, :nloc] + + def cvt_ret(x): + return {kk: to_numpy_array(vv) for kk, vv in x.items()} + + # 1. test run without bias + # nf x na x odim + ret0 = md0.forward_common_atomic(*args) + ret0 = cvt_ret(ret0) + expected_ret0 = {} + expected_ret0["foo"] = np.array( + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + ] + ).reshape([nf, nloc, *md0.fitting_output_def()["foo"].shape]) + expected_ret0["bar"] = np.array( + [ + [1.0, 2.0, 3.0, 7.0, 8.0, 9.0], + [4.0, 5.0, 6.0, 10.0, 11.0, 12.0], + ] + ).reshape([nf, nloc, *md0.fitting_output_def()["bar"].shape]) + for kk in ["foo", "bar"]: + np.testing.assert_almost_equal(ret0[kk], expected_ret0[kk]) + + # 2. test bias is applied + md0.compute_or_load_out_stat( + self.merged_output_stat, stat_file_path=self.stat_file_path + ) + ret1 = md0.forward_common_atomic(*args) + ret1 = cvt_ret(ret1) + # nt x odim + foo_bias = np.array([5.5, 3.0]).reshape(2, 1) + bar_bias = np.array([1.0, 5.0, 3.0, 2.0]).reshape(2, 1, 2) + expected_ret1 = {} + expected_ret1["foo"] = ret0["foo"] + foo_bias[at] + expected_ret1["bar"] = ret0["bar"] + bar_bias[at] + for kk in ["foo", "bar"]: + np.testing.assert_almost_equal(ret1[kk], expected_ret1[kk]) + + # 3. test bias load from file + def raise_error(): + raise RuntimeError + + md0.compute_or_load_out_stat(raise_error, stat_file_path=self.stat_file_path) + ret2 = md0.forward_common_atomic(*args) + ret2 = cvt_ret(ret2) + for kk in ["foo", "bar"]: + np.testing.assert_almost_equal(ret1[kk], ret2[kk]) + + # 4. test change bias + BaseAtomicModel.change_out_bias( + md0, self.merged_output_stat, bias_adjust_mode="change-by-statistic" + ) + args = [ + to_paddle_tensor(ii) + for ii in [ + self.coord_ext, + to_numpy_array(self.merged_output_stat[0]["atype_ext"]), + self.nlist, + ] + ] + ret3 = md0.forward_common_atomic(*args) + ret3 = cvt_ret(ret3) + expected_ret3 = {} + # new bias [2, -5] + expected_ret3["foo"] = np.array([[3, 4, -2], [6, 0, 1]]).reshape(2, 3, 1) + for kk in ["foo"]: + np.testing.assert_almost_equal(ret3[kk], expected_ret3[kk], decimal=4) diff --git a/source/tests/pd/model/test_atomic_model_global_stat.py b/source/tests/pd/model/test_atomic_model_global_stat.py new file mode 100644 index 0000000000..abd7928a0f --- /dev/null +++ b/source/tests/pd/model/test_atomic_model_global_stat.py @@ -0,0 +1,510 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import tempfile +import unittest +from pathlib import ( + Path, +) +from typing import ( + Optional, +) + +import h5py +import numpy as np +import paddle + +from deepmd.dpmodel.atomic_model import DPAtomicModel as DPDPAtomicModel +from deepmd.dpmodel.output_def import ( + FittingOutputDef, + OutputVariableDef, +) +from deepmd.pd.model.atomic_model import ( + BaseAtomicModel, + DPAtomicModel, +) +from deepmd.pd.model.descriptor import ( + DescrptDPA1, + DescrptSeA, +) +from deepmd.pd.model.task.base_fitting import ( + BaseFitting, +) +from deepmd.pd.model.task.ener import ( + InvarFitting, +) +from deepmd.pd.utils import ( + env, +) +from deepmd.pd.utils.utils import ( + to_numpy_array, + to_paddle_tensor, +) +from deepmd.utils.path import ( + DPPath, +) + +from .test_env_mat import ( + TestCaseSingleFrameWithNlist, +) + +dtype = env.GLOBAL_PD_FLOAT_PRECISION + + +class FooFitting(paddle.nn.Layer, BaseFitting): + def output_def(self): + return FittingOutputDef( + [ + OutputVariableDef( + "foo", + [1], + reducible=True, + r_differentiable=True, + c_differentiable=True, + ), + OutputVariableDef( + "pix", + [1], + reducible=True, + r_differentiable=True, + c_differentiable=True, + ), + OutputVariableDef( + "bar", + [1, 2], + reducible=True, + r_differentiable=True, + c_differentiable=True, + ), + ] + ) + + def serialize(self) -> dict: + raise NotImplementedError + + def change_type_map( + self, type_map: list[str], model_with_new_type_stat=None + ) -> None: + raise NotImplementedError + + def get_type_map(self) -> list[str]: + raise NotImplementedError + + def forward( + self, + descriptor: paddle.Tensor, + atype: paddle.Tensor, + gr: Optional[paddle.Tensor] = None, + g2: Optional[paddle.Tensor] = None, + h2: Optional[paddle.Tensor] = None, + fparam: Optional[paddle.Tensor] = None, + aparam: Optional[paddle.Tensor] = None, + ): + nf, nloc, _ = descriptor.shape + ret = {} + ret["foo"] = ( + paddle.to_tensor( + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + ] + ) + .reshape([nf, nloc] + self.output_def()["foo"].shape) # noqa: RUF005 + .to(env.GLOBAL_PD_FLOAT_PRECISION) + .to(env.DEVICE) + ) + ret["pix"] = ( + paddle.to_tensor( + [ + [3.0, 2.0, 1.0], + [6.0, 5.0, 4.0], + ] + ) + .reshape([nf, nloc] + self.output_def()["pix"].shape) # noqa: RUF005 + .to(env.GLOBAL_PD_FLOAT_PRECISION) + .to(env.DEVICE) + ) + ret["bar"] = ( + paddle.to_tensor( + [ + [1.0, 2.0, 3.0, 7.0, 8.0, 9.0], + [4.0, 5.0, 6.0, 10.0, 11.0, 12.0], + ] + ) + .reshape([nf, nloc] + self.output_def()["bar"].shape) # noqa: RUF005 + .to(env.GLOBAL_PD_FLOAT_PRECISION) + .to(env.DEVICE) + ) + return ret + + +class TestAtomicModelStat(unittest.TestCase, TestCaseSingleFrameWithNlist): + def tearDown(self): + self.tempdir.cleanup() + + def setUp(self): + TestCaseSingleFrameWithNlist.setUp(self) + nf, nloc, nnei = self.nlist.shape + self.merged_output_stat = [ + { + "coord": to_paddle_tensor(np.zeros([2, 3, 3])), + "atype": to_paddle_tensor( + np.array([[0, 0, 1], [0, 1, 1]], dtype=np.int32) + ), + "atype_ext": to_paddle_tensor( + np.array([[0, 0, 1, 0], [0, 1, 1, 0]], dtype=np.int32) + ), + "box": to_paddle_tensor(np.zeros([2, 3, 3])), + "natoms": to_paddle_tensor( + np.array([[3, 3, 2, 1], [3, 3, 1, 2]], dtype=np.int32) + ), + # bias of foo: 1, 3 + "foo": to_paddle_tensor(np.array([5.0, 7.0]).reshape(2, 1)), + # no bias of pix + # bias of bar: [1, 5], [3, 2] + "bar": to_paddle_tensor( + np.array([5.0, 12.0, 7.0, 9.0]).reshape(2, 1, 2) + ), + "find_foo": np.float32(1.0), + "find_bar": np.float32(1.0), + } + ] + self.tempdir = tempfile.TemporaryDirectory() + h5file = str((Path(self.tempdir.name) / "testcase.h5").resolve()) + with h5py.File(h5file, "w") as f: + pass + self.stat_file_path = DPPath(h5file, "a") + + def test_output_stat(self): + nf, nloc, nnei = self.nlist.shape + ds = DescrptDPA1( + self.rcut, + self.rcut_smth, + sum(self.sel), + self.nt, + ).to(env.DEVICE) + ft = FooFitting().to(env.DEVICE) + type_map = ["foo", "bar"] + md0 = DPAtomicModel( + ds, + ft, + type_map=type_map, + ).to(env.DEVICE) + args = [ + to_paddle_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist] + ] + # nf x nloc + at = self.atype_ext[:, :nloc] + + def cvt_ret(x): + return {kk: to_numpy_array(vv) for kk, vv in x.items()} + + # 1. test run without bias + # nf x na x odim + ret0 = md0.forward_common_atomic(*args) + ret0 = cvt_ret(ret0) + + expected_ret0 = {} + expected_ret0["foo"] = np.array( + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + ] + ).reshape([nf, nloc] + md0.fitting_output_def()["foo"].shape) # noqa: RUF005 + expected_ret0["pix"] = np.array( + [ + [3.0, 2.0, 1.0], + [6.0, 5.0, 4.0], + ] + ).reshape([nf, nloc] + md0.fitting_output_def()["pix"].shape) # noqa: RUF005 + expected_ret0["bar"] = np.array( + [ + [1.0, 2.0, 3.0, 7.0, 8.0, 9.0], + [4.0, 5.0, 6.0, 10.0, 11.0, 12.0], + ] + ).reshape([nf, nloc] + md0.fitting_output_def()["bar"].shape) # noqa: RUF005 + for kk in ["foo", "pix", "bar"]: + np.testing.assert_almost_equal(ret0[kk], expected_ret0[kk]) + + # 2. test bias is applied + md0.compute_or_load_out_stat( + self.merged_output_stat, stat_file_path=self.stat_file_path + ) + ret1 = md0.forward_common_atomic(*args) + ret1 = cvt_ret(ret1) + expected_std = np.ones((3, 2, 2)) # 3 keys, 2 atypes, 2 max dims. + # nt x odim + foo_bias = np.array([1.0, 3.0]).reshape(2, 1) + bar_bias = np.array([1.0, 5.0, 3.0, 2.0]).reshape(2, 1, 2) + expected_ret1 = {} + expected_ret1["foo"] = ret0["foo"] + foo_bias[at] + expected_ret1["pix"] = ret0["pix"] + expected_ret1["bar"] = ret0["bar"] + bar_bias[at] + for kk in ["foo", "pix", "bar"]: + np.testing.assert_almost_equal(ret1[kk], expected_ret1[kk]) + np.testing.assert_almost_equal(to_numpy_array(md0.out_std), expected_std) + + # 3. test bias load from file + def raise_error(): + raise RuntimeError + + md0.compute_or_load_out_stat(raise_error, stat_file_path=self.stat_file_path) + ret2 = md0.forward_common_atomic(*args) + ret2 = cvt_ret(ret2) + for kk in ["foo", "pix", "bar"]: + np.testing.assert_almost_equal(ret1[kk], ret2[kk]) + np.testing.assert_almost_equal(to_numpy_array(md0.out_std), expected_std) + + # 4. test change bias + BaseAtomicModel.change_out_bias( + md0, self.merged_output_stat, bias_adjust_mode="change-by-statistic" + ) + args = [ + to_paddle_tensor(ii) + for ii in [ + self.coord_ext, + to_numpy_array(self.merged_output_stat[0]["atype_ext"]), + self.nlist, + ] + ] + ret3 = md0.forward_common_atomic(*args) + ret3 = cvt_ret(ret3) + ## model output on foo: [[2, 3, 6], [5, 8, 9]] given bias [1, 3] + ## foo sumed: [11, 22] compared with [5, 7], fit target is [-6, -15] + ## fit bias is [1, -8] + ## old bias + fit bias [2, -5] + ## new model output is [[3, 4, -2], [6, 0, 1]], which sumed to [5, 7] + expected_ret3 = {} + expected_ret3["foo"] = np.array([[3, 4, -2], [6, 0, 1]]).reshape(2, 3, 1) + expected_ret3["pix"] = ret0["pix"] + for kk in ["foo", "pix"]: + np.testing.assert_almost_equal(ret3[kk], expected_ret3[kk]) + # bar is too complicated to be manually computed. + np.testing.assert_almost_equal(to_numpy_array(md0.out_std), expected_std) + + def test_preset_bias(self): + nf, nloc, nnei = self.nlist.shape + ds = DescrptDPA1( + self.rcut, + self.rcut_smth, + sum(self.sel), + self.nt, + ).to(env.DEVICE) + ft = FooFitting().to(env.DEVICE) + type_map = ["foo", "bar"] + preset_out_bias = { + # "foo": np.array(3.0, 2.0]).reshape(2, 1), + "foo": [None, 2], + "bar": np.array([7.0, 5.0, 13.0, 11.0]).reshape(2, 1, 2), + } + md0 = DPAtomicModel( + ds, + ft, + type_map=type_map, + preset_out_bias=preset_out_bias, + ).to(env.DEVICE) + args = [ + to_paddle_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist] + ] + # nf x nloc + at = self.atype_ext[:, :nloc] + + def cvt_ret(x): + return {kk: to_numpy_array(vv) for kk, vv in x.items()} + + # 1. test run without bias + # nf x na x odim + ret0 = md0.forward_common_atomic(*args) + ret0 = cvt_ret(ret0) + expected_ret0 = {} + expected_ret0["foo"] = np.array( + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + ] + ).reshape([nf, nloc] + md0.fitting_output_def()["foo"].shape) # noqa: RUF005 + expected_ret0["pix"] = np.array( + [ + [3.0, 2.0, 1.0], + [6.0, 5.0, 4.0], + ] + ).reshape([nf, nloc] + md0.fitting_output_def()["pix"].shape) # noqa: RUF005 + expected_ret0["bar"] = np.array( + [ + [1.0, 2.0, 3.0, 7.0, 8.0, 9.0], + [4.0, 5.0, 6.0, 10.0, 11.0, 12.0], + ] + ).reshape([nf, nloc] + md0.fitting_output_def()["bar"].shape) # noqa: RUF005 + for kk in ["foo", "pix", "bar"]: + np.testing.assert_almost_equal(ret0[kk], expected_ret0[kk]) + + # 2. test bias is applied + md0.compute_or_load_out_stat( + self.merged_output_stat, stat_file_path=self.stat_file_path + ) + ret1 = md0.forward_common_atomic(*args) + ret1 = cvt_ret(ret1) + # foo sums: [5, 7], + # given bias of type 1 being 2, the bias left for type 0 is [5-2*1, 7-2*2] = [3,3] + # the solution of type 0 is 1.8 + foo_bias = np.array([1.8, preset_out_bias["foo"][1]]).reshape(2, 1) + bar_bias = preset_out_bias["bar"] + expected_ret1 = {} + expected_ret1["foo"] = ret0["foo"] + foo_bias[at] + expected_ret1["pix"] = ret0["pix"] + expected_ret1["bar"] = ret0["bar"] + bar_bias[at] + for kk in ["foo", "pix", "bar"]: + np.testing.assert_almost_equal(ret1[kk], expected_ret1[kk]) + + # 3. test bias load from file + def raise_error(): + raise RuntimeError + + md0.compute_or_load_out_stat(raise_error, stat_file_path=self.stat_file_path) + ret2 = md0.forward_common_atomic(*args) + ret2 = cvt_ret(ret2) + for kk in ["foo", "pix", "bar"]: + np.testing.assert_almost_equal(ret1[kk], ret2[kk]) + + # 4. test change bias + BaseAtomicModel.change_out_bias( + md0, self.merged_output_stat, bias_adjust_mode="change-by-statistic" + ) + args = [ + to_paddle_tensor(ii) + for ii in [ + self.coord_ext, + to_numpy_array(self.merged_output_stat[0]["atype_ext"]), + self.nlist, + ] + ] + ret3 = md0.forward_common_atomic(*args) + ret3 = cvt_ret(ret3) + ## model output on foo: [[2.8, 3.8, 5], [5.8, 7., 8.]] given bias [1.8, 2] + ## foo sumed: [11.6, 20.8] compared with [5, 7], fit target is [-6.6, -13.8] + ## fit bias is [-7, 2] (2 is assigned. -7 is fit to [-8.6, -17.8]) + ## old bias[1.8,2] + fit bias[-7, 2] = [-5.2, 4] + ## new model output is [[-4.2, -3.2, 7], [-1.2, 9, 10]] + expected_ret3 = {} + expected_ret3["foo"] = np.array([[-4.2, -3.2, 7.0], [-1.2, 9.0, 10.0]]).reshape( + 2, 3, 1 + ) + expected_ret3["pix"] = ret0["pix"] + for kk in ["foo", "pix"]: + np.testing.assert_almost_equal(ret3[kk], expected_ret3[kk]) + # bar is too complicated to be manually computed. + + def test_preset_bias_all_none(self): + nf, nloc, nnei = self.nlist.shape + ds = DescrptDPA1( + self.rcut, + self.rcut_smth, + sum(self.sel), + self.nt, + ).to(env.DEVICE) + ft = FooFitting().to(env.DEVICE) + type_map = ["foo", "bar"] + preset_out_bias = { + "foo": [None, None], + } + md0 = DPAtomicModel( + ds, + ft, + type_map=type_map, + preset_out_bias=preset_out_bias, + ).to(env.DEVICE) + args = [ + to_paddle_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist] + ] + # nf x nloc + at = self.atype_ext[:, :nloc] + + def cvt_ret(x): + return {kk: to_numpy_array(vv) for kk, vv in x.items()} + + # 1. test run without bias + # nf x na x odim + ret0 = md0.forward_common_atomic(*args) + ret0 = cvt_ret(ret0) + expected_ret0 = {} + expected_ret0["foo"] = np.array( + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + ] + ).reshape([nf, nloc] + md0.fitting_output_def()["foo"].shape) # noqa: RUF005 + expected_ret0["pix"] = np.array( + [ + [3.0, 2.0, 1.0], + [6.0, 5.0, 4.0], + ] + ).reshape([nf, nloc] + md0.fitting_output_def()["pix"].shape) # noqa: RUF005 + expected_ret0["bar"] = np.array( + [ + [1.0, 2.0, 3.0, 7.0, 8.0, 9.0], + [4.0, 5.0, 6.0, 10.0, 11.0, 12.0], + ] + ).reshape([nf, nloc] + md0.fitting_output_def()["bar"].shape) # noqa: RUF005 + for kk in ["foo", "pix", "bar"]: + np.testing.assert_almost_equal(ret0[kk], expected_ret0[kk]) + + # 2. test bias is applied + md0.compute_or_load_out_stat( + self.merged_output_stat, stat_file_path=self.stat_file_path + ) + ret1 = md0.forward_common_atomic(*args) + ret1 = cvt_ret(ret1) + # nt x odim + foo_bias = np.array([1.0, 3.0]).reshape(2, 1) + bar_bias = np.array([1.0, 5.0, 3.0, 2.0]).reshape(2, 1, 2) + expected_ret1 = {} + expected_ret1["foo"] = ret0["foo"] + foo_bias[at] + expected_ret1["pix"] = ret0["pix"] + expected_ret1["bar"] = ret0["bar"] + bar_bias[at] + for kk in ["foo", "pix", "bar"]: + np.testing.assert_almost_equal(ret1[kk], expected_ret1[kk]) + + def test_serialize(self): + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(env.DEVICE) + ft = InvarFitting( + "foo", + self.nt, + ds.get_dim_out(), + 1, + mixed_types=ds.mixed_types(), + ).to(env.DEVICE) + type_map = ["A", "B"] + md0 = DPAtomicModel( + ds, + ft, + type_map=type_map, + ).to(env.DEVICE) + args = [ + to_paddle_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist] + ] + # nf x nloc + at = self.atype_ext[:, :nloc] + + def cvt_ret(x): + return {kk: to_numpy_array(vv) for kk, vv in x.items()} + + md0.compute_or_load_out_stat( + self.merged_output_stat, stat_file_path=self.stat_file_path + ) + ret0 = md0.forward_common_atomic(*args) + ret0 = cvt_ret(ret0) + md1 = DPAtomicModel.deserialize(md0.serialize()) + ret1 = md1.forward_common_atomic(*args) + ret1 = cvt_ret(ret1) + + for kk in ["foo"]: + np.testing.assert_almost_equal(ret0[kk], ret1[kk]) + + md2 = DPDPAtomicModel.deserialize(md0.serialize()) + args = [self.coord_ext, self.atype_ext, self.nlist] + ret2 = md2.forward_common_atomic(*args) + for kk in ["foo"]: + np.testing.assert_almost_equal(ret0[kk], ret2[kk]) diff --git a/source/tests/pd/model/test_autodiff.py b/source/tests/pd/model/test_autodiff.py index a056491fb3..1bd9dd0d0f 100644 --- a/source/tests/pd/model/test_autodiff.py +++ b/source/tests/pd/model/test_autodiff.py @@ -190,7 +190,6 @@ def setUp(self): self.model = get_model(model_params).to(env.DEVICE) -@unittest.skip("Skip for not implemented yet") class TestEnergyModelDPA1Force(unittest.TestCase, ForceTest): def setUp(self): model_params = copy.deepcopy(model_dpa1) @@ -198,7 +197,6 @@ def setUp(self): self.model = get_model(model_params).to(env.DEVICE) -@unittest.skip("Skip for not implemented yet") class TestEnergyModelDPA1Virial(unittest.TestCase, VirialTest): def setUp(self): model_params = copy.deepcopy(model_dpa1) diff --git a/source/tests/pd/model/test_descriptor.py b/source/tests/pd/model/test_descriptor.py index 10f2fd271b..dc78856851 100644 --- a/source/tests/pd/model/test_descriptor.py +++ b/source/tests/pd/model/test_descriptor.py @@ -17,7 +17,6 @@ prod_env_mat, ) from deepmd.pd.utils import ( - decomp, dp_random, env, ) @@ -179,7 +178,7 @@ def test_consistency(self): my_nlist = nlist.reshape([bsz, -1]).cpu() mask = my_nlist == -1 my_nlist = my_nlist * (~mask).astype(my_nlist.dtype) - my_nlist = decomp.take_along_axis(mapping, axis=-1, indices=my_nlist) + my_nlist = paddle.take_along_axis(mapping, axis=-1, indices=my_nlist) my_nlist = my_nlist * (~mask).astype(my_nlist.dtype) - mask.astype( my_nlist.dtype ) diff --git a/source/tests/pd/model/test_descriptor_dpa1.py b/source/tests/pd/model/test_descriptor_dpa1.py new file mode 100644 index 0000000000..bfcf4ba6ee --- /dev/null +++ b/source/tests/pd/model/test_descriptor_dpa1.py @@ -0,0 +1,387 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import json +import os +import unittest +from pathlib import ( + Path, +) + +import numpy as np +import paddle + +from deepmd.pd.model.descriptor import ( + DescrptBlockSeAtten, + DescrptDPA1, +) +from deepmd.pd.model.network.network import ( + TypeEmbedNet, +) +from deepmd.pd.utils import ( + env, +) +from deepmd.pd.utils.nlist import ( + extend_input_and_build_neighbor_list, +) + +CUR_DIR = os.path.dirname(__file__) + + +class TestDPA1(unittest.TestCase): + def setUp(self): + cell = [ + 5.122106549439247480e00, + 4.016537340154059388e-01, + 6.951654033828678081e-01, + 4.016537340154059388e-01, + 6.112136112297989143e00, + 8.178091365465004481e-01, + 6.951654033828678081e-01, + 8.178091365465004481e-01, + 6.159552512682983760e00, + ] + self.cell = ( + paddle.to_tensor( + cell, + dtype=env.GLOBAL_PD_FLOAT_PRECISION, + ) + .to(device=env.DEVICE) + .reshape([1, 3, 3]) + ) + coord = [ + 2.978060152121375648e00, + 3.588469695887098077e00, + 2.792459820604495491e00, + 3.895592322591093115e00, + 2.712091020667753760e00, + 1.366836847133650501e00, + 9.955616170888935690e-01, + 4.121324820711413039e00, + 1.817239061889086571e00, + 3.553661462345699906e00, + 5.313046969500791583e00, + 6.635182659098815883e00, + 6.088601018589653080e00, + 6.575011420004332585e00, + 6.825240650611076099e00, + ] + self.coord = paddle.to_tensor( + coord, dtype=env.GLOBAL_PD_FLOAT_PRECISION, place=env.DEVICE + ).reshape([1, -1, 3]) + self.atype = paddle.to_tensor( + [0, 0, 0, 1, 1], dtype=paddle.int32, place=env.DEVICE + ).reshape([1, -1]) + self.ref_d = paddle.to_tensor( + [ + 8.382518544113587780e-03, + -3.390120566088597812e-03, + 6.145981571114964362e-03, + -4.880300873973819273e-03, + -3.390120566088597812e-03, + 1.372540996564941464e-03, + -2.484163690574096341e-03, + 1.972313058658722688e-03, + 6.145981571114964362e-03, + -2.484163690574096341e-03, + 4.507748738021747671e-03, + -3.579717194906019764e-03, + -4.880300873973819273e-03, + 1.972313058658722688e-03, + -3.579717194906019764e-03, + 2.842794615687799838e-03, + 6.733043802494966066e-04, + -2.721540313345096771e-04, + 4.936158526085561134e-04, + -3.919743287822345223e-04, + -1.311123004527576900e-02, + 5.301179352601203924e-03, + -9.614612349318877454e-03, + 7.634884975521277241e-03, + 8.877088452901006621e-03, + -3.590945566653638409e-03, + 6.508042782015627942e-03, + -5.167671664327699171e-03, + -2.697241463040870365e-03, + 1.091350446825975137e-03, + -1.976895708961905022e-03, + 1.569671412121975348e-03, + 8.645131636261189911e-03, + -3.557395265621639355e-03, + 6.298048561552698106e-03, + -4.999272007935521948e-03, + -3.557395265621639355e-03, + 1.467866637220284964e-03, + -2.587004431651147504e-03, + 2.052752235601402672e-03, + 6.298048561552698106e-03, + -2.587004431651147504e-03, + 4.594085551315935101e-03, + -3.647656549789176847e-03, + -4.999272007935521948e-03, + 2.052752235601402672e-03, + -3.647656549789176847e-03, + 2.896359275520481256e-03, + 6.689620176492027878e-04, + -2.753606422414641049e-04, + 4.864958810186969444e-04, + -3.860599754167503119e-04, + -1.349238259226558101e-02, + 5.547478630961994242e-03, + -9.835472300819447095e-03, + 7.808197926069362048e-03, + 9.220744348752592245e-03, + -3.795799103392961601e-03, + 6.716516319358462918e-03, + -5.331265718473574867e-03, + -2.783836698392940304e-03, + 1.147461939123531121e-03, + -2.025013030986024063e-03, + 1.606944814423778541e-03, + 9.280385723343491378e-03, + -3.515852178447095942e-03, + 7.085282215778941628e-03, + -5.675852414643783178e-03, + -3.515852178447095942e-03, + 1.337760635271160884e-03, + -2.679428786337713451e-03, + 2.145400621815936413e-03, + 7.085282215778941628e-03, + -2.679428786337713451e-03, + 5.414439648102228192e-03, + -4.338426468139268931e-03, + -5.675852414643783178e-03, + 2.145400621815936413e-03, + -4.338426468139268931e-03, + 3.476467482674507146e-03, + 7.166961981167455130e-04, + -2.697932188839837972e-04, + 5.474643906631899504e-04, + -4.386556623669893621e-04, + -1.480434821331240956e-02, + 5.604647062899507579e-03, + -1.130745349141585449e-02, + 9.059113563516829268e-03, + 9.758791063112262978e-03, + -3.701477720487638626e-03, + 7.448215522796466058e-03, + -5.966057584545172120e-03, + -2.845102393948158344e-03, + 1.078743584169829543e-03, + -2.170093031447992756e-03, + 1.738010461687942770e-03, + 9.867599071916231118e-03, + -3.811041717688905522e-03, + 7.121877634386481262e-03, + -5.703120290113914553e-03, + -3.811041717688905522e-03, + 1.474046183772771213e-03, + -2.747386907428428938e-03, + 2.199711055637492037e-03, + 7.121877634386481262e-03, + -2.747386907428428938e-03, + 5.145050639440944609e-03, + -4.120642824501622239e-03, + -5.703120290113914553e-03, + 2.199711055637492037e-03, + -4.120642824501622239e-03, + 3.300262321758350853e-03, + 1.370499995344566383e-03, + -5.313041843655797901e-04, + 9.860110343046961986e-04, + -7.892505817954784597e-04, + -1.507686316307561489e-02, + 5.818961290579217904e-03, + -1.088774506142304276e-02, + 8.719460408506790952e-03, + 9.764630842803939323e-03, + -3.770134041110058572e-03, + 7.049438389985595785e-03, + -5.645302934019884485e-03, + -3.533582373572779437e-03, + 1.367148320603491559e-03, + -2.546602904764623705e-03, + 2.038882844528267305e-03, + 7.448297038731285964e-03, + -2.924276815200288742e-03, + 5.355960540523636154e-03, + -4.280386435083473329e-03, + -2.924276815200288742e-03, + 1.150311064893848757e-03, + -2.100635980860638373e-03, + 1.678427895009850001e-03, + 5.355960540523636154e-03, + -2.100635980860638373e-03, + 3.853607053247790071e-03, + -3.080076301871465493e-03, + -4.280386435083473329e-03, + 1.678427895009850001e-03, + -3.080076301871465493e-03, + 2.461876613756722523e-03, + 9.730712866459405395e-04, + -3.821759579990726546e-04, + 6.994242056622360787e-04, + -5.589662297882965055e-04, + -1.138916742131982317e-02, + 4.469391132927387489e-03, + -8.192016282448397885e-03, + 6.547234460517113892e-03, + 7.460070829043288082e-03, + -2.929867802018087421e-03, + 5.363646855497249989e-03, + -4.286347242903034739e-03, + -2.643569023340565718e-03, + 1.038826463247002245e-03, + -1.899910089750410976e-03, + 1.518237240362583541e-03, + ], + dtype=env.GLOBAL_PD_FLOAT_PRECISION, + place=env.DEVICE, + ) + with open(Path(CUR_DIR) / "models" / "dpa1.json") as fp: + self.model_json = json.load(fp) + self.file_model_param = Path(CUR_DIR) / "models" / "dpa1.pd" + self.file_type_embed = Path(CUR_DIR) / "models" / "dpa2_tebd.pd" + + def test_descriptor_block(self) -> None: + # paddle.seed(0) + model_dpa1 = self.model_json + dparams = model_dpa1["descriptor"] + ntypes = len(model_dpa1["type_map"]) + assert "se_atten" == dparams["type"] + dparams.pop("type") + dparams["ntypes"] = ntypes + des = DescrptBlockSeAtten( + **dparams, + ).to(env.DEVICE) + state_dict = paddle.load(str(self.file_model_param)) + # this is an old state dict, modify manually + state_dict["compress_info.0"] = des.compress_info[0] + state_dict["compress_data.0"] = des.compress_data[0] + des.set_state_dict(state_dict) + coord = self.coord + atype = self.atype + box = self.cell + # handle type_embedding + type_embedding = TypeEmbedNet(ntypes, 8, use_tebd_bias=True).to(env.DEVICE) + type_embedding.set_state_dict(paddle.load(str(self.file_type_embed))) + + ## to save model parameters + # paddle.save(des.state_dict(), 'model_weights.pd') + # paddle.save(type_embedding.state_dict(), 'model_weights.pd') + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + coord, + atype, + des.get_rcut(), + des.get_sel(), + mixed_types=des.mixed_types(), + box=box, + ) + descriptor, env_mat, diff, rot_mat, sw = des( + nlist, + extended_coord, + extended_atype, + type_embedding(extended_atype), + mapping=None, + ) + # np.savetxt('tmp.out', descriptor.detach().numpy().reshape(1,-1), delimiter=",") + self.assertEqual(descriptor.shape[-1], des.get_dim_out()) + self.assertAlmostEqual(6.0, des.get_rcut()) + self.assertEqual(30, des.get_nsel()) + self.assertEqual(2, des.get_ntypes()) + np.testing.assert_allclose( + descriptor.reshape([-1]).numpy(), self.ref_d.numpy(), atol=1e-10, rtol=1e-10 + ) + + def test_descriptor(self) -> None: + with open(Path(CUR_DIR) / "models" / "dpa1.json") as fp: + self.model_json = json.load(fp) + model_dpa2 = self.model_json + ntypes = len(model_dpa2["type_map"]) + dparams = model_dpa2["descriptor"] + dparams["ntypes"] = ntypes + assert dparams["type"] == "se_atten" + dparams.pop("type") + dparams["concat_output_tebd"] = False + dparams["use_tebd_bias"] = True + des = DescrptDPA1( + **dparams, + ).to(env.DEVICE) + target_dict = des.state_dict() + source_dict = paddle.load(str(self.file_model_param)) + type_embd_dict = paddle.load(str(self.file_type_embed)) + target_dict = translate_se_atten_and_type_embd_dicts_to_dpa1( + target_dict, + source_dict, + type_embd_dict, + ) + des.set_state_dict(target_dict) + + coord = self.coord + atype = self.atype + box = self.cell + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + coord, + atype, + des.get_rcut(), + des.get_sel(), + mixed_types=des.mixed_types(), + box=box, + ) + descriptor, env_mat, diff, rot_mat, sw = des( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + ) + self.assertEqual(descriptor.shape[-1], des.get_dim_out()) + self.assertAlmostEqual(6.0, des.get_rcut()) + self.assertEqual(30, des.get_nsel()) + self.assertEqual(2, des.get_ntypes()) + np.testing.assert_allclose( + descriptor.reshape([-1]).numpy(), self.ref_d.numpy(), atol=1e-10, rtol=1e-10 + ) + + dparams["concat_output_tebd"] = True + des = DescrptDPA1( + **dparams, + ).to(env.DEVICE) + descriptor, env_mat, diff, rot_mat, sw = des( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + ) + self.assertEqual(descriptor.shape[-1], des.get_dim_out()) + + +def translate_se_atten_and_type_embd_dicts_to_dpa1( + target_dict, + source_dict, + type_embd_dict, +): + all_keys = list(target_dict.keys()) + record = [False for ii in all_keys] + for kk, vv in source_dict.items(): + tk = "se_atten." + kk + record[all_keys.index(tk)] = True + target_dict[tk] = vv + assert len(type_embd_dict.keys()) == 2 + it = iter(type_embd_dict.keys()) + for _ in range(2): + kk = next(it) + tk = "type_embedding." + kk + record[all_keys.index(tk)] = True + target_dict[tk] = type_embd_dict[kk] + record[all_keys.index("se_atten.compress_data.0")] = True + record[all_keys.index("se_atten.compress_info.0")] = True + assert all(record) + return target_dict diff --git a/source/tests/pd/model/test_dpa1.py b/source/tests/pd/model/test_dpa1.py new file mode 100644 index 0000000000..285dd3d4cd --- /dev/null +++ b/source/tests/pd/model/test_dpa1.py @@ -0,0 +1,164 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import itertools +import unittest + +import numpy as np +import paddle + +from deepmd.dpmodel.descriptor.dpa1 import DescrptDPA1 as DPDescrptDPA1 +from deepmd.pd.model.descriptor.dpa1 import ( + DescrptDPA1, +) +from deepmd.pd.utils import ( + env, +) +from deepmd.pd.utils.env import ( + PRECISION_DICT, +) + +from ...seed import ( + GLOBAL_SEED, +) +from .test_env_mat import ( + TestCaseSingleFrameWithNlist, +) +from .test_mlp import ( + get_tols, +) + +dtype = env.GLOBAL_PD_FLOAT_PRECISION + + +class TestDescrptSeAtten(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self): + TestCaseSingleFrameWithNlist.setUp(self) + + def test_consistency( + self, + ): + rng = np.random.default_rng(100) + nf, nloc, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + for idt, sm, to, tm, prec, ect in itertools.product( + [False, True], # resnet_dt + [False, True], # smooth_type_embedding + [False, True], # type_one_side + ["concat", "strip"], # tebd_input_mode + [ + "float64", + ], # precision + [False, True], # use_econf_tebd + ): + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + err_msg = f"idt={idt} prec={prec}" + + # dpa1 new impl + dd0 = DescrptDPA1( + self.rcut, + self.rcut_smth, + self.sel_mix, + self.nt, + attn_layer=2, + precision=prec, + resnet_dt=idt, + smooth_type_embedding=sm, + type_one_side=to, + tebd_input_mode=tm, + use_econf_tebd=ect, + type_map=["O", "H"] if ect else None, + seed=GLOBAL_SEED, + ).to(env.DEVICE) + dd0.se_atten.mean = paddle.to_tensor(davg, dtype=dtype).to( + device=env.DEVICE + ) + dd0.se_atten.stddev = paddle.to_tensor(dstd, dtype=dtype).to( + device=env.DEVICE + ) + rd0, _, _, _, _ = dd0( + paddle.to_tensor(self.coord_ext, dtype=dtype).to(device=env.DEVICE), + paddle.to_tensor(self.atype_ext, dtype="int64").to(device=env.DEVICE), + paddle.to_tensor(self.nlist, dtype="int64").to(device=env.DEVICE), + ) + # serialization + dd1 = DescrptDPA1.deserialize(dd0.serialize()) + rd1, _, _, _, _ = dd1( + paddle.to_tensor(self.coord_ext, dtype=dtype).to(device=env.DEVICE), + paddle.to_tensor(self.atype_ext, dtype="int64").to(device=env.DEVICE), + paddle.to_tensor(self.nlist, dtype="int64").to(device=env.DEVICE), + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd1.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + # dp impl + dd2 = DPDescrptDPA1.deserialize(dd0.serialize()) + rd2, _, _, _, _ = dd2.call( + self.coord_ext, + self.atype_ext, + self.nlist, + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd2, + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + + def test_jit( + self, + ): + rng = np.random.default_rng(GLOBAL_SEED) + nf, nloc, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + for idt, prec, sm, to, tm, ect in itertools.product( + [ + False, + ], # resnet_dt + [ + "float64", + ], # precision + [False, True], # smooth_type_embedding + [ + False, + ], # type_one_side + ["concat", "strip"], # tebd_input_mode + [False, True], # use_econf_tebd + ): + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + err_msg = f"idt={idt} prec={prec}" + # dpa1 new impl + dd0 = DescrptDPA1( + self.rcut, + self.rcut_smth, + self.sel, + self.nt, + precision=prec, + resnet_dt=idt, + smooth_type_embedding=sm, + type_one_side=to, + tebd_input_mode=tm, + use_econf_tebd=ect, + type_map=["O", "H"] if ect else None, + seed=GLOBAL_SEED, + ) + dd0.se_atten.mean = paddle.to_tensor(davg, dtype=dtype).to( + device=env.DEVICE + ) + dd0.se_atten.dstd = paddle.to_tensor(dstd, dtype=dtype).to( + device=env.DEVICE + ) + # dd1 = DescrptDPA1.deserialize(dd0.serialize()) + model = paddle.jit.to_static(dd0) + # model = paddle.jit.to_static(dd1) diff --git a/source/tests/pd/model/test_env_mat.py b/source/tests/pd/model/test_env_mat.py index 7cbc698264..bbdb7c75a3 100644 --- a/source/tests/pd/model/test_env_mat.py +++ b/source/tests/pd/model/test_env_mat.py @@ -22,7 +22,7 @@ class TestCaseSingleFrameWithNlist: - def setUp(self): + def setUp(self) -> None: # nloc == 3, nall == 4 self.nloc = 3 self.nall = 4 @@ -155,12 +155,12 @@ def setUp(self): # to be merged with the tf test case class TestEnvMat(unittest.TestCase, TestCaseSingleFrameWithNlist): - def setUp(self): + def setUp(self) -> None: TestCaseSingleFrameWithNlist.setUp(self) def test_consistency( self, - ): + ) -> None: rng = np.random.default_rng(GLOBAL_SEED) nf, nloc, nnei = self.nlist.shape davg = rng.normal(size=(self.nt, nnei, 4)) diff --git a/source/tests/pd/model/test_forward_lower.py b/source/tests/pd/model/test_forward_lower.py index ac8d0f54fc..db6497b605 100644 --- a/source/tests/pd/model/test_forward_lower.py +++ b/source/tests/pd/model/test_forward_lower.py @@ -96,7 +96,7 @@ def test( mixed_types=self.model.mixed_types(), box=cell.unsqueeze(0), ) - extended_spin = decomp.take_along_axis( + extended_spin = paddle.take_along_axis( spin.unsqueeze(0), indices=mapping.unsqueeze(-1).tile((1, 1, 3)), axis=1 ) input_dict = { @@ -146,7 +146,6 @@ def setUp(self): self.model = get_model(model_params).to(env.DEVICE) -@unittest.skip("Skip for not implemented yet") class TestEnergyModelDPA1(unittest.TestCase, ForwardLowerTest): def setUp(self): self.prec = 1e-10 diff --git a/source/tests/pd/model/test_null_input.py b/source/tests/pd/model/test_null_input.py index 9bf0860265..5d67491943 100644 --- a/source/tests/pd/model/test_null_input.py +++ b/source/tests/pd/model/test_null_input.py @@ -22,6 +22,7 @@ eval_model, ) from .test_permutation import ( + model_dpa1, model_se_e2_a, ) @@ -92,3 +93,10 @@ def setUp(self): model_params = copy.deepcopy(model_se_e2_a) self.type_split = False self.model = get_model(model_params).to(env.DEVICE) + + +class TestEnergyModelDPA1(unittest.TestCase, NullTest): + def setUp(self): + model_params = copy.deepcopy(model_dpa1) + self.type_split = True + self.model = get_model(model_params).to(env.DEVICE) diff --git a/source/tests/pd/model/test_permutation.py b/source/tests/pd/model/test_permutation.py index 8482ca7ffe..4543348d3b 100644 --- a/source/tests/pd/model/test_permutation.py +++ b/source/tests/pd/model/test_permutation.py @@ -3,6 +3,7 @@ import os import unittest +import numpy as np import paddle from deepmd.pd.model.model import ( @@ -22,7 +23,6 @@ CUR_DIR = os.path.dirname(__file__) dtype = paddle.float64 -import numpy as np model_se_e2_a = { "type_map": ["O", "H", "B"], @@ -344,7 +344,7 @@ class PermutationTest: def test( self, - ): + ) -> None: natoms = 5 generator = paddle.seed(GLOBAL_SEED) cell = paddle.rand([3, 3], dtype=dtype) @@ -395,7 +395,7 @@ def test( class TestEnergyModelSeA(unittest.TestCase, PermutationTest): - def setUp(self): + def setUp(self) -> None: model_params = copy.deepcopy(model_se_e2_a) self.type_split = False self.model = get_model(model_params).to(env.DEVICE) @@ -403,15 +403,14 @@ def setUp(self): @unittest.skip("Skip for not implemented yet") class TestDOSModelSeA(unittest.TestCase, PermutationTest): - def setUp(self): + def setUp(self) -> None: model_params = copy.deepcopy(model_dos) self.type_split = False self.model = get_model(model_params).to(env.DEVICE) -@unittest.skip("Skip for not implemented yet") class TestEnergyModelDPA1(unittest.TestCase, PermutationTest): - def setUp(self): + def setUp(self) -> None: model_params = copy.deepcopy(model_dpa1) self.type_split = True self.model = get_model(model_params).to(env.DEVICE) @@ -419,7 +418,7 @@ def setUp(self): @unittest.skip("Skip for not implemented yet") class TestEnergyModelDPA2(unittest.TestCase, PermutationTest): - def setUp(self): + def setUp(self) -> None: model_params = copy.deepcopy(model_dpa2) self.type_split = True self.model = get_model(model_params).to(env.DEVICE) @@ -427,7 +426,7 @@ def setUp(self): @unittest.skip("Skip for not implemented yet") class TestForceModelDPA2(unittest.TestCase, PermutationTest): - def setUp(self): + def setUp(self) -> None: model_params = copy.deepcopy(model_dpa2) model_params["fitting_net"]["type"] = "direct_force_ener" self.type_split = True @@ -437,7 +436,7 @@ def setUp(self): @unittest.skip("Skip for not implemented yet") class TestEnergyModelHybrid(unittest.TestCase, PermutationTest): - def setUp(self): + def setUp(self) -> None: model_params = copy.deepcopy(model_hybrid) self.type_split = True self.model = get_model(model_params).to(env.DEVICE) @@ -445,7 +444,7 @@ def setUp(self): @unittest.skip("Skip for not implemented yet") class TestForceModelHybrid(unittest.TestCase, PermutationTest): - def setUp(self): + def setUp(self) -> None: model_params = copy.deepcopy(model_hybrid) model_params["fitting_net"]["type"] = "direct_force_ener" self.type_split = True @@ -455,7 +454,7 @@ def setUp(self): @unittest.skip("Skip for not implemented yet") class TestEnergyModelZBL(unittest.TestCase, PermutationTest): - def setUp(self): + def setUp(self) -> None: model_params = copy.deepcopy(model_zbl) self.type_split = False self.model = get_model(model_params).to(env.DEVICE) @@ -463,7 +462,7 @@ def setUp(self): @unittest.skip("Skip for not implemented yet") class TestEnergyModelSpinSeA(unittest.TestCase, PermutationTest): - def setUp(self): + def setUp(self) -> None: model_params = copy.deepcopy(model_spin) self.type_split = False self.test_spin = True diff --git a/source/tests/pd/model/test_permutation_denoise.py b/source/tests/pd/model/test_permutation_denoise.py new file mode 100644 index 0000000000..a0de541f0b --- /dev/null +++ b/source/tests/pd/model/test_permutation_denoise.py @@ -0,0 +1,109 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import copy +import unittest + +import numpy as np +import paddle + +from deepmd.pd.model.model import ( + get_model, +) +from deepmd.pd.utils import ( + env, +) +from deepmd.pd.utils.utils import ( + get_generator, +) + +from ...seed import ( + GLOBAL_SEED, +) +from ..common import ( + eval_model, +) +from .test_permutation import ( # model_dpau, + model_dpa1, + model_dpa2, + model_hybrid, +) + +dtype = paddle.float64 + +model_dpa1 = copy.deepcopy(model_dpa1) +model_dpa2 = copy.deepcopy(model_dpa2) +model_hybrid = copy.deepcopy(model_hybrid) +model_dpa1["type_map"] = ["O", "H", "B", "MASKED_TOKEN"] +model_dpa1.pop("fitting_net") +model_dpa2["type_map"] = ["O", "H", "B", "MASKED_TOKEN"] +model_dpa2.pop("fitting_net") +model_hybrid["type_map"] = ["O", "H", "B", "MASKED_TOKEN"] +model_hybrid.pop("fitting_net") + + +class PermutationDenoiseTest: + def test( + self, + ) -> None: + generator = get_generator(GLOBAL_SEED) + natoms = 5 + cell = paddle.rand([3, 3], dtype=dtype).to(env.DEVICE) + cell = (cell + cell.T) + 5.0 * paddle.eye(3).to(env.DEVICE) + coord = paddle.rand([natoms, 3], dtype=dtype).to(env.DEVICE) + coord = paddle.matmul(coord, cell) + atype = paddle.to_tensor([0, 0, 0, 1, 1]).to(env.DEVICE) + idx_perm = [1, 0, 4, 3, 2] + updated_c0, logits0 = eval_model( + self.model, coord.unsqueeze(0), cell.unsqueeze(0), atype, denoise=True + ) + ret0 = {"updated_coord": updated_c0.squeeze(0), "logits": logits0.squeeze(0)} + updated_c1, logits1 = eval_model( + self.model, + coord[idx_perm].unsqueeze(0), + cell.unsqueeze(0), + atype[idx_perm], + denoise=True, + ) + ret1 = {"updated_coord": updated_c1.squeeze(0), "logits": logits1.squeeze(0)} + prec = 1e-10 + np.testing.assert_allclose( + ret0["updated_coord"][idx_perm].numpy(), + ret1["updated_coord"].numpy(), + rtol=prec, + atol=prec, + ) + np.testing.assert_allclose( + ret0["logits"][idx_perm].numpy(), + ret1["logits"].numpy(), + rtol=prec, + atol=prec, + ) + + +@unittest.skip("support of the denoise is temporally disabled") +class TestDenoiseModelDPA1(unittest.TestCase, PermutationDenoiseTest): + def setUp(self) -> None: + model_params = copy.deepcopy(model_dpa1) + self.type_split = True + self.model = get_model(model_params).to(env.DEVICE) + + +@unittest.skip("support of the denoise is temporally disabled") +class TestDenoiseModelDPA2(unittest.TestCase, PermutationDenoiseTest): + def setUp(self) -> None: + model_params = copy.deepcopy(model_dpa2) + self.type_split = True + self.model = get_model( + model_params, + ).to(env.DEVICE) + + +# @unittest.skip("hybrid not supported at the moment") +# class TestDenoiseModelHybrid(unittest.TestCase, TestPermutationDenoise): +# def setUp(self): +# model_params = copy.deepcopy(model_hybrid_denoise) +# self.type_split = True +# self.model = get_model(model_params).to(env.DEVICE) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pd/model/test_rot.py b/source/tests/pd/model/test_rot.py index 4d59117560..85c90dc60f 100644 --- a/source/tests/pd/model/test_rot.py +++ b/source/tests/pd/model/test_rot.py @@ -169,7 +169,6 @@ def setUp(self): self.model = get_model(model_params).to(env.DEVICE) -@unittest.skip("Skip for not implemented yet") class TestEnergyModelDPA1(unittest.TestCase, RotTest): def setUp(self): model_params = copy.deepcopy(model_dpa1) diff --git a/source/tests/pd/model/test_rot_denoise.py b/source/tests/pd/model/test_rot_denoise.py new file mode 100644 index 0000000000..74d5d41791 --- /dev/null +++ b/source/tests/pd/model/test_rot_denoise.py @@ -0,0 +1,124 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import copy +import unittest + +import numpy as np +import paddle + +from deepmd.pd.model.model import ( + get_model, +) +from deepmd.pd.utils import ( + env, +) + +from ...seed import ( + GLOBAL_SEED, +) +from ..common import ( + eval_model, +) +from .test_permutation_denoise import ( # model_dpa2, + model_dpa1, +) + +dtype = paddle.float64 + + +class RotDenoiseTest: + def test( + self, + ): + generator = paddle.seed(GLOBAL_SEED) + prec = 1e-10 + natoms = 5 + cell = 10.0 * paddle.eye(3, dtype=dtype).to(env.DEVICE) + coord = 2 * paddle.rand([natoms, 3], dtype=dtype).to(device=env.DEVICE) + shift = paddle.to_tensor([4, 4, 4], dtype=dtype).to(env.DEVICE) + atype = paddle.to_tensor([0, 0, 0, 1, 1]).to(env.DEVICE) + from scipy.stats import ( + special_ortho_group, + ) + + rmat = paddle.to_tensor(special_ortho_group.rvs(3), dtype=dtype).to(env.DEVICE) + + # rotate only coord and shift to the center of cell + coord_rot = paddle.matmul(coord, rmat) + update_c0, logits0 = eval_model( + self.model, + (coord + shift).unsqueeze(0), + cell.unsqueeze(0), + atype, + denoise=True, + ) + update_c0 = update_c0 - (coord + shift).unsqueeze(0) + ret0 = {"updated_coord": update_c0.squeeze(0), "logits": logits0.squeeze(0)} + update_c1, logits1 = eval_model( + self.model, + (coord_rot + shift).unsqueeze(0), + cell.unsqueeze(0), + atype, + denoise=True, + ) + update_c1 = update_c1 - (coord_rot + shift).unsqueeze(0) + ret1 = {"updated_coord": update_c1.squeeze(0), "logits": logits1.squeeze(0)} + np.testing.assert_allclose( + paddle.matmul(ret0["updated_coord"], rmat).numpy(), + ret1["updated_coord"].numpy(), + rtol=prec, + atol=prec, + ) + np.testing.assert_allclose( + ret0["logits"].numpy(), ret1["logits"].numpy(), rtol=prec, atol=prec + ) + + # rotate coord and cell + paddle.seed(0) + cell = paddle.rand([3, 3], dtype=dtype).to(env.DEVICE) + cell = (cell + cell.T) + 5.0 * paddle.eye(3).to(env.DEVICE) + coord = paddle.rand([natoms, 3], dtype=dtype).to(env.DEVICE) + coord = paddle.matmul(coord, cell) + atype = paddle.to_tensor([0, 0, 0, 1, 1]).to(env.DEVICE) + coord_rot = paddle.matmul(coord, rmat) + cell_rot = paddle.matmul(cell, rmat) + update_c0, logits0 = eval_model( + self.model, coord.unsqueeze(0), cell.unsqueeze(0), atype, denoise=True + ) + ret0 = {"updated_coord": update_c0.squeeze(0), "logits": logits0.squeeze(0)} + update_c1, logits1 = eval_model( + self.model, + coord_rot.unsqueeze(0), + cell_rot.unsqueeze(0), + atype, + denoise=True, + ) + ret1 = {"updated_coord": update_c1.squeeze(0), "logits": logits1.squeeze(0)} + np.testing.assert_allclose( + ret0["logits"].numpy(), ret1["logits"].numpy(), rtol=prec, atol=prec + ) + np.testing.assert_allclose( + paddle.matmul(ret0["updated_coord"], rmat).numpy(), + ret1["updated_coord"].numpy(), + rtol=prec, + atol=prec, + ) + + +@unittest.skip("support of the denoise is temporally disabled") +class TestDenoiseModelDPA1(unittest.TestCase, RotDenoiseTest): + def setUp(self): + model_params = copy.deepcopy(model_dpa1) + self.type_split = True + self.model = get_model(model_params).to(env.DEVICE) + + +# @unittest.skip("hybrid not supported at the moment") +# class TestEnergyModelHybrid(unittest.TestCase, TestRotDenoise): +# def setUp(self): +# model_params = copy.deepcopy(model_hybrid_denoise) +# self.type_split = True +# self.model = get_model(model_params).to(env.DEVICE) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pd/model/test_saveload_dpa1.py b/source/tests/pd/model/test_saveload_dpa1.py new file mode 100644 index 0000000000..54a82e479a --- /dev/null +++ b/source/tests/pd/model/test_saveload_dpa1.py @@ -0,0 +1,144 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import copy +import json +import os +import unittest +from pathlib import ( + Path, +) + +import numpy as np +import paddle +from paddle.io import ( + DataLoader, +) + +from deepmd.pd.loss import ( + EnergyStdLoss, +) +from deepmd.pd.model.model import ( + get_model, +) +from deepmd.pd.train.wrapper import ( + ModelWrapper, +) +from deepmd.pd.utils import ( + env, +) +from deepmd.pd.utils.dataloader import ( + BufferedIterator, + DpLoaderSet, +) +from deepmd.pd.utils.stat import ( + make_stat_input, +) +from deepmd.tf.common import ( + expand_sys_str, +) + + +def get_dataset(config): + model_config = config["model"] + rcut = model_config["descriptor"]["rcut"] + sel = model_config["descriptor"]["sel"] + systems = config["training"]["validation_data"]["systems"] + if isinstance(systems, str): + systems = expand_sys_str(systems) + batch_size = config["training"]["training_data"]["batch_size"] + type_map = model_config["type_map"] + + dataset = DpLoaderSet(systems, batch_size, type_map) + data_stat_nbatch = model_config.get("data_stat_nbatch", 10) + sampled = make_stat_input(dataset.systems, dataset.dataloaders, data_stat_nbatch) + return dataset, sampled + + +class TestSaveLoadDPA1(unittest.TestCase): + def setUp(self): + input_json = str(Path(__file__).parent / "water/se_atten.json") + with open(input_json) as fin: + self.config = json.load(fin) + self.config["loss"]["starter_learning_rate"] = self.config["learning_rate"][ + "start_lr" + ] + data_file = [str(Path(__file__).parent / "water/data/data_0")] + self.config["training"]["training_data"]["systems"] = data_file + self.config["training"]["validation_data"]["systems"] = data_file + self.dataset, self.sampled = get_dataset(self.config) + self.training_dataloader = DataLoader( + self.dataset, + batch_sampler=paddle.io.BatchSampler( + sampler=paddle.io.RandomSampler(self.dataset), + drop_last=False, + ), + num_workers=0, # setting to 0 diverges the behavior of its iterator; should be >=1 + collate_fn=lambda x: x[0], + ) + device = paddle.get_device() + paddle.set_device("cpu") + self.training_data = BufferedIterator(iter(self.training_dataloader)) + paddle.set_device(device) + self.loss = EnergyStdLoss(**self.config["loss"]) + self.cur_lr = 1 + self.task_key = "Default" + self.input_dict, self.label_dict = self.get_data() + self.start_lr = self.config["learning_rate"]["start_lr"] + + def get_model_result(self, read=False, model_file="tmp_model.pd"): + wrapper = self.create_wrapper(read) + optimizer = paddle.optimizer.Adam( + learning_rate=self.start_lr, parameters=wrapper.parameters() + ) + optimizer.clear_grad() + if read: + wrapper.set_state_dict(paddle.load(model_file)) + os.remove(model_file) + else: + paddle.save(wrapper.state_dict(), model_file) + result = wrapper( + **self.input_dict, + cur_lr=self.cur_lr, + label=self.label_dict, + task_key=self.task_key, + )[0] + return result + + def create_wrapper(self, read: bool): + model_config = copy.deepcopy(self.config["model"]) + model_config["resuming"] = read + model_config["stat_file_dir"] = "stat_files" + model_config["stat_file"] = "stat.hdf5" + model_config["stat_file_path"] = os.path.join( + model_config["stat_file_dir"], model_config["stat_file"] + ) + model = get_model(model_config).to(env.DEVICE) + return ModelWrapper(model, self.loss) + + def get_data(self): + try: + batch_data = next(iter(self.training_data)) + except StopIteration: + # Refresh the status of the dataloader to start from a new epoch + self.training_data = BufferedIterator(iter(self.training_dataloader)) + batch_data = next(iter(self.training_data)) + input_dict = {} + for item in ["coord", "atype", "box"]: + if item in batch_data: + input_dict[item] = batch_data[item].to(env.DEVICE) + else: + input_dict[item] = None + label_dict = {} + for item in ["energy", "force", "virial"]: + if item in batch_data: + label_dict[item] = batch_data[item].to(env.DEVICE) + return input_dict, label_dict + + def test_saveload(self): + result1 = self.get_model_result() + result2 = self.get_model_result(read=True) + for item in result1: + np.testing.assert_allclose(result1[item].numpy(), result2[item].numpy()) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pd/model/test_smooth.py b/source/tests/pd/model/test_smooth.py index 7f77a6f188..cc50043ad8 100644 --- a/source/tests/pd/model/test_smooth.py +++ b/source/tests/pd/model/test_smooth.py @@ -19,6 +19,7 @@ eval_model, ) from .test_permutation import ( # model_dpau, + model_dpa1, model_se_e2_a, ) @@ -153,6 +154,41 @@ def setUp(self): self.epsilon, self.aprec = None, None +class TestEnergyModelDPA1(unittest.TestCase, SmoothTest): + def setUp(self): + model_params = copy.deepcopy(model_dpa1) + self.type_split = True + self.model = get_model(model_params).to(env.DEVICE) + # less degree of smoothness, + # error can be systematically removed by reducing epsilon + self.epsilon = 1e-5 + self.aprec = 1e-5 + + +class TestEnergyModelDPA1Excl1(unittest.TestCase, SmoothTest): + def setUp(self): + model_params = copy.deepcopy(model_dpa1) + model_params["pair_exclude_types"] = [[0, 1]] + self.type_split = True + self.model = get_model(model_params).to(env.DEVICE) + # less degree of smoothness, + # error can be systematically removed by reducing epsilon + self.epsilon = 1e-5 + self.aprec = 1e-5 + + +class TestEnergyModelDPA1Excl12(unittest.TestCase, SmoothTest): + def setUp(self): + model_params = copy.deepcopy(model_dpa1) + model_params["pair_exclude_types"] = [[0, 1], [0, 2]] + self.type_split = True + self.model = get_model(model_params).to(env.DEVICE) + # less degree of smoothness, + # error can be systematically removed by reducing epsilon + self.epsilon = 1e-5 + self.aprec = 1e-5 + + # class TestEnergyFoo(unittest.TestCase): # def test(self): # model_params = model_dpau diff --git a/source/tests/pd/model/test_trans.py b/source/tests/pd/model/test_trans.py index f69d2f5b83..3fae49d598 100644 --- a/source/tests/pd/model/test_trans.py +++ b/source/tests/pd/model/test_trans.py @@ -103,7 +103,6 @@ def setUp(self): self.model = get_model(model_params).to(env.DEVICE) -@unittest.skip("Skip for not implemented yet") class TestEnergyModelDPA1(unittest.TestCase, TransTest): def setUp(self): model_params = copy.deepcopy(model_dpa1) diff --git a/source/tests/pd/model/test_trans_denoise.py b/source/tests/pd/model/test_trans_denoise.py new file mode 100644 index 0000000000..8317d4d2ae --- /dev/null +++ b/source/tests/pd/model/test_trans_denoise.py @@ -0,0 +1,95 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import copy +import unittest + +import numpy as np +import paddle + +from deepmd.pd.model.model import ( + get_model, +) +from deepmd.pd.utils import ( + env, +) + +from ...seed import ( + GLOBAL_SEED, +) +from ..common import ( + eval_model, +) +from .test_permutation_denoise import ( + model_dpa1, + model_dpa2, + model_hybrid, +) + +dtype = paddle.float64 + + +class TransDenoiseTest: + def test( + self, + ): + natoms = 5 + generator = paddle.seed(GLOBAL_SEED) + cell = paddle.rand([3, 3], dtype=dtype).to(env.DEVICE) + cell = (cell + cell.T) + 5.0 * paddle.eye(3).to(env.DEVICE) + coord = paddle.rand([natoms, 3], dtype=dtype).to(env.DEVICE) + coord = paddle.matmul(coord, cell) + atype = paddle.to_tensor([0, 0, 0, 1, 1]).to(env.DEVICE) + shift = (paddle.rand([3], dtype=dtype) - 0.5).to(env.DEVICE) * 2.0 + coord_s = paddle.matmul( + paddle.remainder( + paddle.matmul(coord + shift, paddle.linalg.inv(cell)), 1.0 + ), + cell, + ) + updated_c0, logits0 = eval_model( + self.model, coord.unsqueeze(0), cell.unsqueeze(0), atype, denoise=True + ) + updated_c0 = updated_c0 - coord.unsqueeze(0) + ret0 = {"updated_coord": updated_c0.squeeze(0), "logits": logits0.squeeze(0)} + updated_c1, logits1 = eval_model( + self.model, coord_s.unsqueeze(0), cell.unsqueeze(0), atype, denoise=True + ) + updated_c1 = updated_c1 - coord_s.unsqueeze(0) + ret1 = {"updated_coord": updated_c1.squeeze(0), "logits": logits1.squeeze(0)} + prec = 1e-10 + np.testing.assert_allclose( + ret0["updated_coord"].numpy(), + ret1["updated_coord"].numpy(), + rtol=prec, + atol=prec, + ) + np.testing.assert_allclose( + ret0["logits"].numpy(), ret1["logits"].numpy(), rtol=prec, atol=prec + ) + + +@unittest.skip("support of the denoise is temporally disabled") +class TestDenoiseModelDPA1(unittest.TestCase, TransDenoiseTest): + def setUp(self): + model_params = copy.deepcopy(model_dpa1) + self.type_split = True + self.model = get_model(model_params).to(env.DEVICE) + + +@unittest.skip("support of the denoise is temporally disabled") +class TestDenoiseModelDPA2(unittest.TestCase, TransDenoiseTest): + def setUp(self): + model_params = copy.deepcopy(model_dpa2) + self.type_split = True + self.model = get_model(model_params).to(env.DEVICE) + + +@unittest.skip("hybrid not supported at the moment") +class TestDenoiseModelHybrid(unittest.TestCase, TransDenoiseTest): + def setUp(self): + model_params = copy.deepcopy(model_hybrid) + self.type_split = True + self.model = get_model(model_params).to(env.DEVICE) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pd/model/water/multitask_sharefit.json b/source/tests/pd/model/water/multitask_sharefit.json new file mode 100644 index 0000000000..246b5992f7 --- /dev/null +++ b/source/tests/pd/model/water/multitask_sharefit.json @@ -0,0 +1,134 @@ +{ + "model": { + "shared_dict": { + "my_type_map": [ + "O", + "H", + "B" + ], + "my_descriptor": { + "type": "se_e2_a", + "sel": [ + 46, + 92, + 4 + ], + "rcut_smth": 0.50, + "rcut": 6.00, + "neuron": [ + 25, + 50, + 100 + ], + "resnet_dt": false, + "axis_neuron": 16, + "seed": 1, + "_comment": " that's all" + }, + "my_fitting": { + "dim_case_embd": 2, + "neuron": [ + 240, + 240, + 240 + ], + "resnet_dt": true, + "seed": 1, + "_comment": " that's all" + }, + "_comment": "that's all" + }, + "model_dict": { + "model_1": { + "type_map": "my_type_map", + "descriptor": "my_descriptor", + "fitting_net": "my_fitting", + "data_stat_nbatch": 1 + }, + "model_2": { + "type_map": "my_type_map", + "descriptor": "my_descriptor", + "fitting_net": "my_fitting", + "data_stat_nbatch": 1 + } + } + }, + "learning_rate": { + "type": "exp", + "decay_steps": 5000, + "start_lr": 0.0002, + "decay_rate": 0.98, + "stop_lr": 3.51e-08, + "_comment": "that's all" + }, + "loss_dict": { + "model_1": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0 + }, + "model_2": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0 + } + }, + "training": { + "model_prob": { + "model_1": 0.5, + "model_2": 0.5 + }, + "data_dict": { + "model_1": { + "stat_file": "./stat_files/model_1.hdf5", + "training_data": { + "systems": [ + "pt/water/data/data_0" + ], + "batch_size": 1, + "_comment": "that's all" + }, + "validation_data": { + "systems": [ + "pt/water/data/data_0" + ], + "batch_size": 1, + "_comment": "that's all" + } + }, + "model_2": { + "stat_file": "./stat_files/model_2.hdf5", + "training_data": { + "systems": [ + "pt/water/data/data_0" + ], + "batch_size": 1, + "_comment": "that's all" + }, + "validation_data": { + "systems": [ + "pt/water/data/data_0" + ], + "batch_size": 1, + "_comment": "that's all" + } + } + }, + "numb_steps": 100000, + "warmup_steps": 0, + "gradient_max_norm": 5.0, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 100, + "save_freq": 100, + "_comment": "that's all" + } +} diff --git a/source/tests/pd/test_decomp.py b/source/tests/pd/test_decomp.py index d8439ad994..c554083bda 100644 --- a/source/tests/pd/test_decomp.py +++ b/source/tests/pd/test_decomp.py @@ -17,50 +17,6 @@ class TestDecomp(unittest.TestCase): def setUp(self): paddle.seed(GLOBAL_SEED) - def test_softmax_decomp(self): - raw_api = paddle.nn.functional.softmax - decomp_api = decomp.softmax - - raw_input = paddle.randn([100, 100], "float32") - raw_output = raw_api(raw_input) - decomp_output = decomp_api(raw_input) - - np.testing.assert_allclose( - raw_output.numpy(), - decomp_output.numpy(), - 1e-6, - 1e-8, - ) - - def test_norm_decomp(self): - raw_api = paddle.linalg.norm - decomp_api = decomp.norm - - raw_input = paddle.randn([100, 100], "float32") - raw_output = raw_api(raw_input, p=2, axis=-1) - decomp_output = decomp_api(raw_input, p=2, axis=-1) - - np.testing.assert_allclose( - raw_output.numpy(), - decomp_output.numpy(), - 1e-5, - 1e-8, - ) - - def test_take_along_axis_decomp(self): - raw_api = paddle.take_along_axis - decomp_api = decomp.take_along_axis - - raw_input = paddle.randn([100, 100], "float32") - raw_indices = paddle.randint(0, 100, [100, 2]) - raw_output = raw_api(raw_input, raw_indices, axis=-1) - decomp_output = decomp_api(raw_input, raw_indices, axis=-1) - - np.testing.assert_equal( - raw_output.numpy(), - decomp_output.numpy(), - ) - def test_scatter_reduce_decomp(self): raw_api = paddle.put_along_axis decomp_api = decomp.scatter_reduce @@ -112,20 +68,3 @@ def test_masked_add_(self): raw_output.numpy(), raw_input.numpy(), # inplace ) - - def test_normalize_decomp(self): - raw_api = paddle.nn.functional.normalize - decomp_api = decomp.normalize_decomp - - raw_input = paddle.randn([100, 100], "float32") - axis = -1 - - raw_output = raw_api(raw_input, p=2, axis=axis) - decomp_output = decomp_api(raw_input, p=2, axis=axis) - - np.testing.assert_allclose( - raw_output.numpy(), - decomp_output.numpy(), # inplace - 1e-5, - 1e-8, - ) diff --git a/source/tests/pd/test_finetune.py b/source/tests/pd/test_finetune.py index 2c6cca83aa..f82f7a8cd0 100644 --- a/source/tests/pd/test_finetune.py +++ b/source/tests/pd/test_finetune.py @@ -341,7 +341,6 @@ def setUp(self): self.testkey = "dos" -@unittest.skip("Skip for not implemented yet") class TestEnergyModelDPA1(FinetuneTest, unittest.TestCase): def setUp(self): input_json = str(Path(__file__).parent / "water/se_atten.json") diff --git a/source/tests/pd/test_multitask.py b/source/tests/pd/test_multitask.py index 65210d07b3..d59990dcca 100644 --- a/source/tests/pd/test_multitask.py +++ b/source/tests/pd/test_multitask.py @@ -29,23 +29,24 @@ ) from .model.test_permutation import ( + model_dpa1, model_se_e2_a, ) -def setUpModule(): +def setUpModule() -> None: global multitask_template multitask_template_json = str(Path(__file__).parent / "water/multitask.json") with open(multitask_template_json) as f: multitask_template = json.load(f) -@unittest.skip("Skip until solving cuda error 709 in jit.save") class MultiTaskTrainTest: - def test_multitask_train(self): + def test_multitask_train(self) -> None: # test multitask training self.config = update_deepmd_input(self.config, warning=True) self.config = normalize(self.config, multi_task=True) + self.share_fitting = getattr(self, "share_fitting", False) trainer = get_trainer(deepcopy(self.config), shared_links=self.shared_links) trainer.run() # check model keys @@ -60,7 +61,12 @@ def test_multitask_train(self): self.assertIn(state_key.replace("model_1", "model_2"), multi_state_dict) if "model_2" in state_key: self.assertIn(state_key.replace("model_2", "model_1"), multi_state_dict) - if "model_1.descriptor" in state_key: + if ("model_1.atomic_model.descriptor" in state_key) or ( + self.share_fitting + and "model_1.atomic_model.fitting_net" in state_key + and "fitting_net.bias_atom_e" not in state_key + and "fitting_net.case_embd" not in state_key + ): np.testing.assert_allclose( multi_state_dict[state_key].numpy(), multi_state_dict[state_key.replace("model_1", "model_2")].numpy(), @@ -172,7 +178,7 @@ def test_multitask_train(self): trainer_finetune.run() self.tearDown() - def tearDown(self): + def tearDown(self) -> None: for f in os.listdir("."): if f.startswith("model") and f.endswith(".pd"): os.remove(f) @@ -182,9 +188,8 @@ def tearDown(self): shutil.rmtree(f) -@unittest.skip("Skip until solving cuda error 709 in jit.save") class TestMultiTaskSeA(unittest.TestCase, MultiTaskTrainTest): - def setUp(self): + def setUp(self) -> None: multitask_se_e2_a = deepcopy(multitask_template) multitask_se_e2_a["model"]["shared_dict"]["my_descriptor"] = model_se_e2_a[ "descriptor" @@ -222,5 +227,44 @@ def tearDown(self) -> None: MultiTaskTrainTest.tearDown(self) +class TestMultiTaskDPA1(unittest.TestCase, MultiTaskTrainTest): + def setUp(self) -> None: + multitask_DPA1 = deepcopy(multitask_template) + multitask_DPA1["model"]["shared_dict"]["my_descriptor"] = model_dpa1[ + "descriptor" + ] + data_file = [str(Path(__file__).parent / "water/data/data_0")] + self.stat_files = "DPA1" + os.makedirs(self.stat_files, exist_ok=True) + self.config = multitask_DPA1 + self.config["training"]["data_dict"]["model_1"]["training_data"]["systems"] = ( + data_file + ) + self.config["training"]["data_dict"]["model_1"]["validation_data"][ + "systems" + ] = data_file + self.config["training"]["data_dict"]["model_1"]["stat_file"] = ( + f"{self.stat_files}/model_1" + ) + self.config["training"]["data_dict"]["model_2"]["training_data"]["systems"] = ( + data_file + ) + self.config["training"]["data_dict"]["model_2"]["validation_data"][ + "systems" + ] = data_file + self.config["training"]["data_dict"]["model_2"]["stat_file"] = ( + f"{self.stat_files}/model_2" + ) + self.config["training"]["numb_steps"] = 1 + self.config["training"]["save_freq"] = 1 + self.origin_config = deepcopy(self.config) + self.config["model"], self.shared_links = preprocess_shared_params( + self.config["model"] + ) + + def tearDown(self) -> None: + MultiTaskTrainTest.tearDown(self) + + if __name__ == "__main__": unittest.main() diff --git a/source/tests/pd/test_training.py b/source/tests/pd/test_training.py index d4e7309a65..c3d65c09df 100644 --- a/source/tests/pd/test_training.py +++ b/source/tests/pd/test_training.py @@ -15,17 +15,21 @@ from deepmd.pd.entrypoints.main import ( get_trainer, ) +from deepmd.pd.utils.env import ( + enable_prim, +) from deepmd.pd.utils.finetune import ( get_finetune_rules, ) from .model.test_permutation import ( + model_dpa1, model_se_e2_a, ) class DPTrainTest: - def test_dp_train(self): + def test_dp_train(self) -> None: # test training from scratch trainer = get_trainer(deepcopy(self.config)) trainer.run() @@ -95,7 +99,7 @@ def test_dp_train(self): trainer_finetune_empty.run() trainer_finetune_random.run() - def test_trainable(self): + def test_trainable(self) -> None: fix_params = deepcopy(self.config) fix_params["model"]["descriptor"]["trainable"] = False fix_params["model"]["fitting_net"]["trainable"] = False @@ -124,7 +128,7 @@ def test_trainable(self): model_dict_after_training[key].numpy(), ) - def tearDown(self): + def tearDown(self) -> None: for f in os.listdir("."): if f.startswith("model") and f.endswith(".pd"): os.remove(f) @@ -135,7 +139,7 @@ def tearDown(self): class TestEnergyModelSeA(unittest.TestCase, DPTrainTest): - def setUp(self): + def setUp(self) -> None: input_json = str(Path(__file__).parent / "water/se_atten.json") with open(input_json) as f: self.config = json.load(f) @@ -145,6 +149,9 @@ def setUp(self): self.config["model"] = deepcopy(model_se_e2_a) self.config["training"]["numb_steps"] = 1 self.config["training"]["save_freq"] = 1 + # import paddle + enable_prim(True) + # assert paddle.framework.core._is_eager_prim_enabled() def tearDown(self) -> None: DPTrainTest.tearDown(self) @@ -153,7 +160,7 @@ def tearDown(self) -> None: class TestFparam(unittest.TestCase, DPTrainTest): """Test if `fparam` can be loaded correctly.""" - def setUp(self): + def setUp(self) -> None: input_json = str(Path(__file__).parent / "water/se_atten.json") with open(input_json) as f: self.config = json.load(f) @@ -172,5 +179,21 @@ def tearDown(self) -> None: DPTrainTest.tearDown(self) +class TestEnergyModelDPA1(unittest.TestCase, DPTrainTest): + def setUp(self) -> None: + input_json = str(Path(__file__).parent / "water/se_atten.json") + with open(input_json) as f: + self.config = json.load(f) + data_file = [str(Path(__file__).parent / "water/data/data_0")] + self.config["training"]["training_data"]["systems"] = data_file + self.config["training"]["validation_data"]["systems"] = data_file + self.config["model"] = deepcopy(model_dpa1) + self.config["training"]["numb_steps"] = 1 + self.config["training"]["save_freq"] = 1 + + def tearDown(self) -> None: + DPTrainTest.tearDown(self) + + if __name__ == "__main__": unittest.main()