diff --git a/deepmd/pd/model/descriptor/__init__.py b/deepmd/pd/model/descriptor/__init__.py index 7eaa0df85b..cee9dbf226 100644 --- a/deepmd/pd/model/descriptor/__init__.py +++ b/deepmd/pd/model/descriptor/__init__.py @@ -9,20 +9,34 @@ DescrptBlockSeAtten, DescrptDPA1, ) +from .dpa2 import ( + DescrptDPA2, +) from .env_mat import ( prod_env_mat, ) +from .repformers import ( + DescrptBlockRepformers, +) from .se_a import ( DescrptBlockSeA, DescrptSeA, ) +from .se_t_tebd import ( + DescrptBlockSeTTebd, + DescrptSeTTebd, +) __all__ = [ "BaseDescriptor", "DescriptorBlock", + "DescrptBlockRepformers", "DescrptBlockSeA", "DescrptBlockSeAtten", + "DescrptBlockSeTTebd", "DescrptDPA1", + "DescrptDPA2", "DescrptSeA", + "DescrptSeTTebd", "prod_env_mat", ] diff --git a/deepmd/pd/model/descriptor/dpa2.py b/deepmd/pd/model/descriptor/dpa2.py new file mode 100644 index 0000000000..e0ec0a501d --- /dev/null +++ b/deepmd/pd/model/descriptor/dpa2.py @@ -0,0 +1,897 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Callable, + Optional, + Union, +) + +import paddle + +from deepmd.dpmodel.descriptor.dpa2 import ( + RepformerArgs, + RepinitArgs, +) +from deepmd.dpmodel.utils import EnvMat as DPEnvMat +from deepmd.dpmodel.utils.seed import ( + child_seed, +) +from deepmd.pd.model.network.mlp import ( + Identity, + MLPLayer, + NetworkCollection, +) +from deepmd.pd.model.network.network import ( + TypeEmbedNet, + TypeEmbedNetConsistent, +) +from deepmd.pd.utils import ( + env, +) +from deepmd.pd.utils.env import ( + PRECISION_DICT, +) +from deepmd.pd.utils.nlist import ( + build_multiple_neighbor_list, + get_multiple_nlist_key, +) +from deepmd.pd.utils.update_sel import ( + UpdateSel, +) +from deepmd.pd.utils.utils import ( + to_numpy_array, +) +from deepmd.utils.data_system import ( + DeepmdDataSystem, +) +from deepmd.utils.finetune import ( + get_index_between_two_maps, + map_pair_exclude_types, +) +from deepmd.utils.path import ( + DPPath, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + +from .base_descriptor import ( + BaseDescriptor, +) +from .descriptor import ( + extend_descrpt_stat, +) +from .repformer_layer import ( + RepformerLayer, +) +from .repformers import ( + DescrptBlockRepformers, +) +from .se_atten import ( + DescrptBlockSeAtten, +) +from .se_t_tebd import ( + DescrptBlockSeTTebd, +) + + +@BaseDescriptor.register("dpa2") +class DescrptDPA2(BaseDescriptor, paddle.nn.Layer): + def __init__( + self, + ntypes: int, + # args for repinit + repinit: Union[RepinitArgs, dict], + # args for repformer + repformer: Union[RepformerArgs, dict], + # kwargs for descriptor + concat_output_tebd: bool = True, + precision: str = "float64", + smooth: bool = True, + exclude_types: list[tuple[int, int]] = [], + env_protection: float = 0.0, + trainable: bool = True, + seed: Optional[Union[int, list[int]]] = None, + add_tebd_to_repinit_out: bool = False, + use_econf_tebd: bool = False, + use_tebd_bias: bool = False, + type_map: Optional[list[str]] = None, + ) -> None: + r"""The DPA-2 descriptor. see https://arxiv.org/abs/2312.15492. + + Parameters + ---------- + repinit : Union[RepinitArgs, dict] + The arguments used to initialize the repinit block, see docstr in `RepinitArgs` for details information. + repformer : Union[RepformerArgs, dict] + The arguments used to initialize the repformer block, see docstr in `RepformerArgs` for details information. + concat_output_tebd : bool, optional + Whether to concat type embedding at the output of the descriptor. 
+ precision : str, optional + The precision of the embedding net parameters. + smooth : bool, optional + Whether to use smoothness in processes such as attention weights calculation. + exclude_types : list[list[int]], optional + The excluded pairs of types which have no interaction with each other. + For example, `[[0, 1]]` means no interaction between type 0 and type 1. + env_protection : float, optional + Protection parameter to prevent division by zero errors during environment matrix calculations. + For example, when using paddings, there may be zero distances of neighbors, which may make division by zero error during environment matrix calculations without protection. + trainable : bool, optional + If the parameters are trainable. + seed : int, optional + Random seed for parameter initialization. + add_tebd_to_repinit_out : bool, optional + Whether to add type embedding to the output representation from repinit before inputting it into repformer. + use_econf_tebd : bool, Optional + Whether to use electronic configuration type embedding. + use_tebd_bias : bool, Optional + Whether to use bias in the type embedding layer. + type_map : list[str], Optional + A list of strings. Give the name to each type of atoms. + + Returns + ------- + descriptor: paddle.Tensor + the descriptor of shape nb x nloc x g1_dim. + invariant single-atom representation. + g2: paddle.Tensor + invariant pair-atom representation. + h2: paddle.Tensor + equivariant pair-atom representation. + rot_mat: paddle.Tensor + rotation matrix for equivariant fittings + sw: paddle.Tensor + The switch function for decaying inverse distance. + + """ + super().__init__() + + def init_subclass_params(sub_data, sub_class): + if isinstance(sub_data, dict): + return sub_class(**sub_data) + elif isinstance(sub_data, sub_class): + return sub_data + else: + raise ValueError( + f"Input args must be a {sub_class.__name__} class or a dict!" 
+ ) + + self.repinit_args = init_subclass_params(repinit, RepinitArgs) + self.repformer_args = init_subclass_params(repformer, RepformerArgs) + self.tebd_input_mode = self.repinit_args.tebd_input_mode + + self.repinit = DescrptBlockSeAtten( + self.repinit_args.rcut, + self.repinit_args.rcut_smth, + self.repinit_args.nsel, + ntypes, + attn_layer=0, + neuron=self.repinit_args.neuron, + axis_neuron=self.repinit_args.axis_neuron, + tebd_dim=self.repinit_args.tebd_dim, + tebd_input_mode=self.repinit_args.tebd_input_mode, + set_davg_zero=self.repinit_args.set_davg_zero, + exclude_types=exclude_types, + env_protection=env_protection, + activation_function=self.repinit_args.activation_function, + precision=precision, + resnet_dt=self.repinit_args.resnet_dt, + smooth=smooth, + type_one_side=self.repinit_args.type_one_side, + seed=child_seed(seed, 0), + ) + self.use_three_body = self.repinit_args.use_three_body + if self.use_three_body: + self.repinit_three_body = DescrptBlockSeTTebd( + self.repinit_args.three_body_rcut, + self.repinit_args.three_body_rcut_smth, + self.repinit_args.three_body_sel, + ntypes, + neuron=self.repinit_args.three_body_neuron, + tebd_dim=self.repinit_args.tebd_dim, + tebd_input_mode=self.repinit_args.tebd_input_mode, + set_davg_zero=self.repinit_args.set_davg_zero, + exclude_types=exclude_types, + env_protection=env_protection, + activation_function=self.repinit_args.activation_function, + precision=precision, + resnet_dt=self.repinit_args.resnet_dt, + smooth=smooth, + seed=child_seed(seed, 5), + ) + else: + self.repinit_three_body = None + self.repformers = DescrptBlockRepformers( + self.repformer_args.rcut, + self.repformer_args.rcut_smth, + self.repformer_args.nsel, + ntypes, + nlayers=self.repformer_args.nlayers, + g1_dim=self.repformer_args.g1_dim, + g2_dim=self.repformer_args.g2_dim, + axis_neuron=self.repformer_args.axis_neuron, + direct_dist=self.repformer_args.direct_dist, + update_g1_has_conv=self.repformer_args.update_g1_has_conv, + update_g1_has_drrd=self.repformer_args.update_g1_has_drrd, + update_g1_has_grrg=self.repformer_args.update_g1_has_grrg, + update_g1_has_attn=self.repformer_args.update_g1_has_attn, + update_g2_has_g1g1=self.repformer_args.update_g2_has_g1g1, + update_g2_has_attn=self.repformer_args.update_g2_has_attn, + update_h2=self.repformer_args.update_h2, + attn1_hidden=self.repformer_args.attn1_hidden, + attn1_nhead=self.repformer_args.attn1_nhead, + attn2_hidden=self.repformer_args.attn2_hidden, + attn2_nhead=self.repformer_args.attn2_nhead, + attn2_has_gate=self.repformer_args.attn2_has_gate, + activation_function=self.repformer_args.activation_function, + update_style=self.repformer_args.update_style, + update_residual=self.repformer_args.update_residual, + update_residual_init=self.repformer_args.update_residual_init, + set_davg_zero=self.repformer_args.set_davg_zero, + smooth=smooth, + exclude_types=exclude_types, + env_protection=env_protection, + precision=precision, + trainable_ln=self.repformer_args.trainable_ln, + ln_eps=self.repformer_args.ln_eps, + use_sqrt_nnei=self.repformer_args.use_sqrt_nnei, + g1_out_conv=self.repformer_args.g1_out_conv, + g1_out_mlp=self.repformer_args.g1_out_mlp, + seed=child_seed(seed, 1), + ) + self.rcsl_list = [ + (self.repformers.get_rcut(), self.repformers.get_nsel()), + (self.repinit.get_rcut(), self.repinit.get_nsel()), + ] + if self.use_three_body: + self.rcsl_list.append( + (self.repinit_three_body.get_rcut(), self.repinit_three_body.get_nsel()) + ) + self.rcsl_list.sort() + for ii in range(1, 
len(self.rcsl_list)): + assert ( + self.rcsl_list[ii - 1][1] <= self.rcsl_list[ii][1] + ), "rcut and sel are not in the same order" + self.rcut_list = [ii[0] for ii in self.rcsl_list] + self.nsel_list = [ii[1] for ii in self.rcsl_list] + self.use_econf_tebd = use_econf_tebd + self.use_tebd_bias = use_tebd_bias + self.type_map = type_map + self.type_embedding = TypeEmbedNet( + ntypes, + self.repinit_args.tebd_dim, + precision=precision, + seed=child_seed(seed, 2), + use_econf_tebd=self.use_econf_tebd, + use_tebd_bias=use_tebd_bias, + type_map=type_map, + ) + self.concat_output_tebd = concat_output_tebd + self.precision = precision + self.prec = PRECISION_DICT[self.precision] + self.smooth = smooth + self.exclude_types = exclude_types + self.env_protection = env_protection + self.trainable = trainable + self.add_tebd_to_repinit_out = add_tebd_to_repinit_out + + self.repinit_out_dim = self.repinit.dim_out + if self.repinit_args.use_three_body: + assert self.repinit_three_body is not None + self.repinit_out_dim += self.repinit_three_body.dim_out + + if self.repinit_out_dim == self.repformers.dim_in: + self.g1_shape_tranform = Identity() + else: + self.g1_shape_tranform = MLPLayer( + self.repinit_out_dim, + self.repformers.dim_in, + bias=False, + precision=precision, + init="glorot", + seed=child_seed(seed, 3), + ) + self.tebd_transform = None + if self.add_tebd_to_repinit_out: + self.tebd_transform = MLPLayer( + self.repinit_args.tebd_dim, + self.repformers.dim_in, + bias=False, + precision=precision, + seed=child_seed(seed, 4), + ) + assert self.repinit.rcut > self.repformers.rcut + assert self.repinit.sel[0] > self.repformers.sel[0] + + self.tebd_dim = self.repinit_args.tebd_dim + self.rcut = self.repinit.get_rcut() + self.rcut_smth = self.repinit.get_rcut_smth() + self.ntypes = ntypes + self.sel = self.repinit.sel + # set trainable + for param in self.parameters(): + param.stop_gradient = not trainable + self.compress = False + + def get_rcut(self) -> float: + """Returns the cut-off radius.""" + return self.rcut + + def get_rcut_smth(self) -> float: + """Returns the radius where the neighbor information starts to smoothly decay to 0.""" + return self.rcut_smth + + def get_nsel(self) -> int: + """Returns the number of selected atoms in the cut-off radius.""" + return sum(self.sel) + + def get_sel(self) -> list[int]: + """Returns the number of selected atoms for each type.""" + return self.sel + + def get_ntypes(self) -> int: + """Returns the number of element types.""" + return self.ntypes + + def get_type_map(self) -> list[str]: + """Get the name to each type of atoms.""" + return self.type_map + + def get_dim_out(self) -> int: + """Returns the output dimension of this descriptor.""" + ret = self.repformers.dim_out + if self.concat_output_tebd: + ret += self.tebd_dim + return ret + + def get_dim_emb(self) -> int: + """Returns the embedding dimension of this descriptor.""" + return self.repformers.dim_emb + + def mixed_types(self) -> bool: + """If true, the descriptor + 1. assumes total number of atoms aligned across frames; + 2. requires a neighbor list that does not distinguish different atomic types. + + If false, the descriptor + 1. assumes total number of atoms of each atom type aligned across frames; + 2. requires a neighbor list that distinguishes different atomic types. 
+ + """ + return True + + def has_message_passing(self) -> bool: + """Returns whether the descriptor has message passing.""" + return any( + [self.repinit.has_message_passing(), self.repformers.has_message_passing()] + ) + + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the descriptor needs sorted nlist when using `forward_lower`.""" + return True + + def get_env_protection(self) -> float: + """Returns the protection of building environment matrix.""" + # the env_protection of repinit is the same as that of the repformer + return self.repinit.get_env_protection() + + def share_params(self, base_class, shared_level, resume=False) -> None: + """ + Share the parameters of self to the base_class with shared_level during multitask training. + If not start from checkpoint (resume is False), + some separated parameters (e.g. mean and stddev) will be re-calculated across different classes. + """ + assert ( + self.__class__ == base_class.__class__ + ), "Only descriptors of the same type can share params!" + # For DPA2 descriptors, the user-defined share-level + # shared_level: 0 + # share all parameters in type_embedding, repinit and repformers + if shared_level == 0: + self._sub_layers["type_embedding"] = base_class._sub_layers[ + "type_embedding" + ] + self.repinit.share_params(base_class.repinit, 0, resume=resume) + if self.use_three_body: + self.repinit_three_body.share_params( + base_class.repinit_three_body, 0, resume=resume + ) + self._sub_layers["g1_shape_tranform"] = base_class._sub_layers[ + "g1_shape_tranform" + ] + self.repformers.share_params(base_class.repformers, 0, resume=resume) + # shared_level: 1 + # share all parameters in type_embedding + elif shared_level == 1: + self._modules["type_embedding"] = base_class._modules["type_embedding"] + # Other shared levels + else: + raise NotImplementedError + + def change_type_map( + self, type_map: list[str], model_with_new_type_stat=None + ) -> None: + """Change the type related params to new ones, according to `type_map` and the original one in the model. + If there are new types in `type_map`, statistics will be updated accordingly to `model_with_new_type_stat` for these new types. + """ + assert ( + self.type_map is not None + ), "'type_map' must be defined when performing type changing!" 
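+        # remap_index maps each type in the new type_map back to its index in the
+        # original type_map; has_new_type marks whether statistics for previously
+        # unseen types must be extended below.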
+ remap_index, has_new_type = get_index_between_two_maps(self.type_map, type_map) + self.type_map = type_map + self.type_embedding.change_type_map(type_map=type_map) + self.exclude_types = map_pair_exclude_types(self.exclude_types, remap_index) + self.ntypes = len(type_map) + repinit = self.repinit + repformers = self.repformers + repinit_three_body = self.repinit_three_body + if has_new_type: + # the avg and std of new types need to be updated + extend_descrpt_stat( + repinit, + type_map, + des_with_stat=model_with_new_type_stat.repinit + if model_with_new_type_stat is not None + else None, + ) + extend_descrpt_stat( + repformers, + type_map, + des_with_stat=model_with_new_type_stat.repformers + if model_with_new_type_stat is not None + else None, + ) + if self.use_three_body: + extend_descrpt_stat( + repinit_three_body, + type_map, + des_with_stat=model_with_new_type_stat.repinit_three_body + if model_with_new_type_stat is not None + else None, + ) + repinit.ntypes = self.ntypes + repformers.ntypes = self.ntypes + repinit.reinit_exclude(self.exclude_types) + repformers.reinit_exclude(self.exclude_types) + repinit["davg"] = repinit["davg"][remap_index] + repinit["dstd"] = repinit["dstd"][remap_index] + repformers["davg"] = repformers["davg"][remap_index] + repformers["dstd"] = repformers["dstd"][remap_index] + if self.use_three_body: + repinit_three_body.ntypes = self.ntypes + repinit_three_body.reinit_exclude(self.exclude_types) + repinit_three_body["davg"] = repinit_three_body["davg"][remap_index] + repinit_three_body["dstd"] = repinit_three_body["dstd"][remap_index] + + @property + def dim_out(self): + return self.get_dim_out() + + @property + def dim_emb(self): + """Returns the embedding dimension g2.""" + return self.get_dim_emb() + + def compute_input_stats( + self, + merged: Union[Callable[[], list[dict]], list[dict]], + path: Optional[DPPath] = None, + ) -> None: + """ + Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data. + + Parameters + ---------- + merged : Union[Callable[[], list[dict]], list[dict]] + - list[dict]: A list of data samples from various data systems. + Each element, `merged[i]`, is a data dictionary containing `keys`: `paddle.Tensor` + originating from the `i`-th data system. + - Callable[[], list[dict]]: A lazy function that returns data samples in the above format + only when needed. Since the sampling process can be slow and memory-intensive, + the lazy function helps by only sampling once. + path : Optional[DPPath] + The path to the stat file. 
+ + """ + descrpt_list = [self.repinit, self.repformers] + if self.use_three_body: + descrpt_list.append(self.repinit_three_body) + for ii, descrpt in enumerate(descrpt_list): + descrpt.compute_input_stats(merged, path) + + def set_stat_mean_and_stddev( + self, + mean: list[paddle.Tensor], + stddev: list[paddle.Tensor], + ) -> None: + """Update mean and stddev for descriptor.""" + descrpt_list = [self.repinit, self.repformers] + if self.use_three_body: + descrpt_list.append(self.repinit_three_body) + for ii, descrpt in enumerate(descrpt_list): + descrpt.mean = mean[ii] + descrpt.stddev = stddev[ii] + + def get_stat_mean_and_stddev( + self, + ) -> tuple[list[paddle.Tensor], list[paddle.Tensor]]: + """Get mean and stddev for descriptor.""" + mean_list = [self.repinit.mean, self.repformers.mean] + stddev_list = [ + self.repinit.stddev, + self.repformers.stddev, + ] + if self.use_three_body: + mean_list.append(self.repinit_three_body.mean) + stddev_list.append(self.repinit_three_body.stddev) + return mean_list, stddev_list + + def serialize(self) -> dict: + repinit = self.repinit + repformers = self.repformers + repinit_three_body = self.repinit_three_body + data = { + "@class": "Descriptor", + "type": "dpa2", + "@version": 3, + "ntypes": self.ntypes, + "repinit_args": self.repinit_args.serialize(), + "repformer_args": self.repformer_args.serialize(), + "concat_output_tebd": self.concat_output_tebd, + "precision": self.precision, + "smooth": self.smooth, + "exclude_types": self.exclude_types, + "env_protection": self.env_protection, + "trainable": self.trainable, + "add_tebd_to_repinit_out": self.add_tebd_to_repinit_out, + "use_econf_tebd": self.use_econf_tebd, + "use_tebd_bias": self.use_tebd_bias, + "type_map": self.type_map, + "type_embedding": self.type_embedding.embedding.serialize(), + "g1_shape_tranform": self.g1_shape_tranform.serialize(), + } + if self.add_tebd_to_repinit_out: + data.update( + { + "tebd_transform": self.tebd_transform.serialize(), + } + ) + repinit_variable = { + "embeddings": repinit.filter_layers.serialize(), + "env_mat": DPEnvMat(repinit.rcut, repinit.rcut_smth).serialize(), + "@variables": { + "davg": to_numpy_array(repinit["davg"]), + "dstd": to_numpy_array(repinit["dstd"]), + }, + } + if repinit.tebd_input_mode in ["strip"]: + repinit_variable.update( + {"embeddings_strip": repinit.filter_layers_strip.serialize()} + ) + repformers_variable = { + "g2_embd": repformers.g2_embd.serialize(), + "repformer_layers": [layer.serialize() for layer in repformers.layers], + "env_mat": DPEnvMat(repformers.rcut, repformers.rcut_smth).serialize(), + "@variables": { + "davg": to_numpy_array(repformers["davg"]), + "dstd": to_numpy_array(repformers["dstd"]), + }, + } + data.update( + { + "repinit_variable": repinit_variable, + "repformers_variable": repformers_variable, + } + ) + if self.use_three_body: + repinit_three_body_variable = { + "embeddings": repinit_three_body.filter_layers.serialize(), + "env_mat": DPEnvMat( + repinit_three_body.rcut, repinit_three_body.rcut_smth + ).serialize(), + "@variables": { + "davg": to_numpy_array(repinit_three_body["davg"]), + "dstd": to_numpy_array(repinit_three_body["dstd"]), + }, + } + if repinit_three_body.tebd_input_mode in ["strip"]: + repinit_three_body_variable.update( + { + "embeddings_strip": repinit_three_body.filter_layers_strip.serialize() + } + ) + data.update( + { + "repinit_three_body_variable": repinit_three_body_variable, + } + ) + return data + + @classmethod + def deserialize(cls, data: dict) -> "DescrptDPA2": + data = 
data.copy() + version = data.pop("@version") + check_version_compatibility(version, 3, 1) + data.pop("@class") + data.pop("type") + repinit_variable = data.pop("repinit_variable").copy() + repformers_variable = data.pop("repformers_variable").copy() + repinit_three_body_variable = ( + data.pop("repinit_three_body_variable").copy() + if "repinit_three_body_variable" in data + else None + ) + type_embedding = data.pop("type_embedding") + g1_shape_tranform = data.pop("g1_shape_tranform") + tebd_transform = data.pop("tebd_transform", None) + add_tebd_to_repinit_out = data["add_tebd_to_repinit_out"] + if version < 3: + # compat with old version + data["repformer_args"]["use_sqrt_nnei"] = False + data["repformer_args"]["g1_out_conv"] = False + data["repformer_args"]["g1_out_mlp"] = False + data["repinit"] = RepinitArgs(**data.pop("repinit_args")) + data["repformer"] = RepformerArgs(**data.pop("repformer_args")) + # compat with version 1 + if "use_tebd_bias" not in data: + data["use_tebd_bias"] = True + obj = cls(**data) + obj.type_embedding.embedding = TypeEmbedNetConsistent.deserialize( + type_embedding + ) + if add_tebd_to_repinit_out: + assert isinstance(tebd_transform, dict) + obj.tebd_transform = MLPLayer.deserialize(tebd_transform) + if obj.repinit.dim_out != obj.repformers.dim_in: + obj.g1_shape_tranform = MLPLayer.deserialize(g1_shape_tranform) + + def t_cvt(xx): + return paddle.to_tensor(xx, dtype=obj.repinit.prec, place=env.DEVICE) + + # deserialize repinit + statistic_repinit = repinit_variable.pop("@variables") + env_mat = repinit_variable.pop("env_mat") + tebd_input_mode = data["repinit"].tebd_input_mode + obj.repinit.filter_layers = NetworkCollection.deserialize( + repinit_variable.pop("embeddings") + ) + if tebd_input_mode in ["strip"]: + obj.repinit.filter_layers_strip = NetworkCollection.deserialize( + repinit_variable.pop("embeddings_strip") + ) + obj.repinit["davg"] = t_cvt(statistic_repinit["davg"]) + obj.repinit["dstd"] = t_cvt(statistic_repinit["dstd"]) + + if data["repinit"].use_three_body: + # deserialize repinit_three_body + statistic_repinit_three_body = repinit_three_body_variable.pop("@variables") + env_mat = repinit_three_body_variable.pop("env_mat") + tebd_input_mode = data["repinit"].tebd_input_mode + obj.repinit_three_body.filter_layers = NetworkCollection.deserialize( + repinit_three_body_variable.pop("embeddings") + ) + if tebd_input_mode in ["strip"]: + obj.repinit_three_body.filter_layers_strip = ( + NetworkCollection.deserialize( + repinit_three_body_variable.pop("embeddings_strip") + ) + ) + obj.repinit_three_body["davg"] = t_cvt(statistic_repinit_three_body["davg"]) + obj.repinit_three_body["dstd"] = t_cvt(statistic_repinit_three_body["dstd"]) + + # deserialize repformers + statistic_repformers = repformers_variable.pop("@variables") + env_mat = repformers_variable.pop("env_mat") + repformer_layers = repformers_variable.pop("repformer_layers") + obj.repformers.g2_embd = MLPLayer.deserialize( + repformers_variable.pop("g2_embd") + ) + obj.repformers["davg"] = t_cvt(statistic_repformers["davg"]) + obj.repformers["dstd"] = t_cvt(statistic_repformers["dstd"]) + obj.repformers.layers = paddle.nn.LayerList( + [RepformerLayer.deserialize(layer) for layer in repformer_layers] + ) + return obj + + def forward( + self, + extended_coord: paddle.Tensor, + extended_atype: paddle.Tensor, + nlist: paddle.Tensor, + mapping: Optional[paddle.Tensor] = None, + comm_dict: Optional[dict[str, paddle.Tensor]] = None, + ): + """Compute the descriptor. 
+ + Parameters + ---------- + extended_coord + The extended coordinates of atoms. shape: nf x (nallx3) + extended_atype + The extended aotm types. shape: nf x nall + nlist + The neighbor list. shape: nf x nloc x nnei + mapping + The index mapping, mapps extended region index to local region. + comm_dict + The data needed for communication for parallel inference. + + Returns + ------- + descriptor + The descriptor. shape: nf x nloc x (ng x axis_neuron) + gr + The rotationally equivariant and permutationally invariant single particle + representation. shape: nf x nloc x ng x 3 + g2 + The rotationally invariant pair-partical representation. + shape: nf x nloc x nnei x ng + h2 + The rotationally equivariant pair-partical representation. + shape: nf x nloc x nnei x 3 + sw + The smooth switch function. shape: nf x nloc x nnei + + """ + # cast the input to internal precsion + extended_coord = extended_coord.to(dtype=self.prec) + + use_three_body = self.use_three_body + nframes, nloc, nnei = nlist.shape + nall = extended_coord.reshape([nframes, -1]).shape[1] // 3 + # nlists + nlist_dict = build_multiple_neighbor_list( + extended_coord.detach(), + nlist, + self.rcut_list, + self.nsel_list, + ) + # repinit + g1_ext = self.type_embedding(extended_atype) + g1_inp = g1_ext[:, :nloc, :] + if self.tebd_input_mode in ["strip"]: + type_embedding = self.type_embedding.get_full_embedding(g1_ext.place) + else: + type_embedding = None + g1, _, _, _, _ = self.repinit( + nlist_dict[ + get_multiple_nlist_key(self.repinit.get_rcut(), self.repinit.get_nsel()) + ], + extended_coord, + extended_atype, + g1_ext, + mapping, + type_embedding, + ) + if use_three_body: + assert self.repinit_three_body is not None + g1_three_body, __, __, __, __ = self.repinit_three_body( + nlist_dict[ + get_multiple_nlist_key( + self.repinit_three_body.get_rcut(), + self.repinit_three_body.get_nsel(), + ) + ], + extended_coord, + extended_atype, + g1_ext, + mapping, + type_embedding, + ) + g1 = paddle.concat([g1, g1_three_body], axis=-1) + # linear to change shape + g1 = self.g1_shape_tranform(g1) + if self.add_tebd_to_repinit_out: + assert self.tebd_transform is not None + g1 = g1 + self.tebd_transform(g1_inp) + # mapping g1 + if comm_dict is None: + assert mapping is not None + mapping_ext = ( + mapping.reshape([nframes, nall]) + .unsqueeze(-1) + .expand([-1, -1, g1.shape[-1]]) + ) + g1_ext = paddle.take_along_axis(g1, mapping_ext, 1) + g1 = g1_ext + # repformer + g1, g2, h2, rot_mat, sw = self.repformers( + nlist_dict[ + get_multiple_nlist_key( + self.repformers.get_rcut(), self.repformers.get_nsel() + ) + ], + extended_coord, + extended_atype, + g1, + mapping, + comm_dict=comm_dict, + ) + if self.concat_output_tebd: + g1 = paddle.concat([g1, g1_inp], axis=-1) + return ( + g1.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION), + rot_mat.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION), + g2.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION), + h2.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION), + sw.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION), + ) + + @classmethod + def update_sel( + cls, + train_data: DeepmdDataSystem, + type_map: Optional[list[str]], + local_jdata: dict, + ) -> tuple[dict, Optional[float]]: + """Update the selection and perform neighbor statistics. 
+ + Parameters + ---------- + train_data : DeepmdDataSystem + data used to do neighbor statistics + type_map : list[str], optional + The name of each type of atoms + local_jdata : dict + The local data refer to the current class + + Returns + ------- + dict + The updated local data + float + The minimum distance between two atoms + """ + local_jdata_cpy = local_jdata.copy() + update_sel = UpdateSel() + min_nbor_dist, repinit_sel = update_sel.update_one_sel( + train_data, + type_map, + local_jdata_cpy["repinit"]["rcut"], + local_jdata_cpy["repinit"]["nsel"], + True, + ) + local_jdata_cpy["repinit"]["nsel"] = repinit_sel[0] + min_nbor_dist, repinit_three_body_sel = update_sel.update_one_sel( + train_data, + type_map, + local_jdata_cpy["repinit"]["three_body_rcut"], + local_jdata_cpy["repinit"]["three_body_sel"], + True, + ) + local_jdata_cpy["repinit"]["three_body_sel"] = repinit_three_body_sel[0] + min_nbor_dist, repformer_sel = update_sel.update_one_sel( + train_data, + type_map, + local_jdata_cpy["repformer"]["rcut"], + local_jdata_cpy["repformer"]["nsel"], + True, + ) + local_jdata_cpy["repformer"]["nsel"] = repformer_sel[0] + return local_jdata_cpy, min_nbor_dist + + def enable_compression( + self, + min_nbor_dist: float, + table_extrapolate: float = 5, + table_stride_1: float = 0.01, + table_stride_2: float = 0.1, + check_frequency: int = -1, + ) -> None: + """Receive the statistics (distance, max_nbor_size and env_mat_range) of the training data. + + Parameters + ---------- + min_nbor_dist + The nearest distance between atoms + table_extrapolate + The scale of model extrapolation + table_stride_1 + The uniform stride of the first table + table_stride_2 + The uniform stride of the second table + check_frequency + The overflow check frequency + """ + # do some checks before the mocel compression process + raise NotImplementedError("enable_compression is not implemented yet") diff --git a/deepmd/pd/model/descriptor/repformer_layer.py b/deepmd/pd/model/descriptor/repformer_layer.py new file mode 100644 index 0000000000..a09c5cbe17 --- /dev/null +++ b/deepmd/pd/model/descriptor/repformer_layer.py @@ -0,0 +1,1484 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Optional, + Union, +) + +import paddle +import paddle.nn as nn + +from deepmd.dpmodel.utils.seed import ( + child_seed, +) +from deepmd.pd.model.network.init import ( + constant_, + normal_, +) +from deepmd.pd.model.network.layernorm import ( + LayerNorm, +) +from deepmd.pd.model.network.mlp import ( + MLPLayer, +) +from deepmd.pd.utils import ( + decomp, + env, +) +from deepmd.pd.utils.env import ( + PRECISION_DICT, +) +from deepmd.pd.utils.utils import ( + ActivationFn, + get_generator, + to_numpy_array, + to_paddle_tensor, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + + +def get_residual( + _dim: int, + _scale: float, + _mode: str = "norm", + trainable: bool = True, + precision: str = "float64", + seed: Optional[Union[int, list[int]]] = None, +) -> paddle.Tensor: + r""" + Get residual tensor for one update vector. + + Parameters + ---------- + _dim : int + The dimension of the update vector. + _scale + The initial scale of the residual tensor. See `_mode` for details. + _mode + The mode of residual initialization for the residual tensor. + - "norm" (default): init residual using normal with `_scale` std. + - "const": init residual using element-wise constants of `_scale`. + trainable + Whether the residual tensor is trainable. + precision + The precision of the residual tensor. 
+ seed : int, optional + Random seed for parameter initialization. + """ + random_generator = get_generator(seed) + residual = paddle.create_parameter( + [_dim], + dtype=PRECISION_DICT[precision], + default_initializer=nn.initializer.Constant(0), + ).to(device=env.DEVICE) + residual.stop_gradient = not trainable + if _mode == "norm": + normal_(residual.data, std=_scale, generator=random_generator) + elif _mode == "const": + constant_(residual.data, val=_scale) + else: + raise RuntimeError(f"Unsupported initialization mode '{_mode}'!") + return residual + + +# common ops +def _make_nei_g1( + g1_ext: paddle.Tensor, + nlist: paddle.Tensor, +) -> paddle.Tensor: + """ + Make neighbor-wise atomic invariant rep. + + Parameters + ---------- + g1_ext + Extended atomic invariant rep, with shape nb x nall x ng1. + nlist + Neighbor list, with shape nb x nloc x nnei. + + Returns + ------- + gg1: paddle.Tensor + Neighbor-wise atomic invariant rep, with shape nb x nloc x nnei x ng1. + + """ + # nlist: nb x nloc x nnei + nb, nloc, nnei = nlist.shape + # g1_ext: nb x nall x ng1 + ng1 = g1_ext.shape[-1] + # index: nb x (nloc x nnei) x ng1 + index = nlist.reshape([nb, nloc * nnei]).unsqueeze(-1).expand([-1, -1, ng1]) + # gg1 : nb x (nloc x nnei) x ng1 + gg1 = paddle.take_along_axis(g1_ext, axis=1, indices=index) + # gg1 : nb x nloc x nnei x ng1 + gg1 = gg1.reshape([nb, nloc, nnei, ng1]) + return gg1 + + +def _apply_nlist_mask( + gg: paddle.Tensor, + nlist_mask: paddle.Tensor, +) -> paddle.Tensor: + """ + Apply nlist mask to neighbor-wise rep tensors. + + Parameters + ---------- + gg + Neighbor-wise rep tensors, with shape nf x nloc x nnei x d. + nlist_mask + Neighbor list mask, where zero means no neighbor, with shape nf x nloc x nnei. + """ + # gg: nf x nloc x nnei x d + # msk: nf x nloc x nnei + return gg.masked_fill(~nlist_mask.unsqueeze(-1), 0.0) + + +def _apply_switch(gg: paddle.Tensor, sw: paddle.Tensor) -> paddle.Tensor: + """ + Apply switch function to neighbor-wise rep tensors. + + Parameters + ---------- + gg + Neighbor-wise rep tensors, with shape nf x nloc x nnei x d. + sw + The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, + and remains 0 beyond rcut, with shape nf x nloc x nnei. 
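+
+    Returns
+    -------
+    gg: paddle.Tensor
+        The switch-weighted rep tensors, with shape nf x nloc x nnei x d.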
+ """ + # gg: nf x nloc x nnei x d + # sw: nf x nloc x nnei + return gg * sw.unsqueeze(-1) + + +class Atten2Map(paddle.nn.Layer): + def __init__( + self, + input_dim: int, + hidden_dim: int, + head_num: int, + has_gate: bool = False, # apply gate to attn map + smooth: bool = True, + attnw_shift: float = 20.0, + precision: str = "float64", + seed: Optional[Union[int, list[int]]] = None, + ): + """Return neighbor-wise multi-head self-attention maps, with gate mechanism.""" + super().__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.head_num = head_num + self.mapqk = MLPLayer( + input_dim, + hidden_dim * 2 * head_num, + bias=False, + precision=precision, + seed=seed, + ) + self.has_gate = has_gate + self.smooth = smooth + self.attnw_shift = attnw_shift + self.precision = precision + + def forward( + self, + g2: paddle.Tensor, # nb x nloc x nnei x ng2 + h2: paddle.Tensor, # nb x nloc x nnei x 3 + nlist_mask: paddle.Tensor, # nb x nloc x nnei + sw: paddle.Tensor, # nb x nloc x nnei + ) -> paddle.Tensor: + ( + nb, + nloc, + nnei, + _, + ) = g2.shape + nd, nh = self.hidden_dim, self.head_num + # nb x nloc x nnei x nd x (nh x 2) + g2qk = self.mapqk(g2).reshape([nb, nloc, nnei, nd, nh * 2]) + # nb x nloc x (nh x 2) x nnei x nd + g2qk = paddle.transpose(g2qk, (0, 1, 4, 2, 3)) + # nb x nloc x nh x nnei x nd + g2q, g2k = paddle.split(g2qk, decomp.sec(g2qk.shape[2], nh), axis=2) + # g2q = paddle.nn.functional.normalize(g2q, axis=-1) + # g2k = paddle.nn.functional.normalize(g2k, axis=-1) + # nb x nloc x nh x nnei x nnei + attnw = paddle.matmul(g2q, paddle.transpose(g2k, [0, 1, 2, 4, 3])) / nd**0.5 + if self.has_gate: + gate = paddle.matmul(h2, paddle.transpose(h2, [0, 1, 3, 2])).unsqueeze(-3) + attnw = attnw * gate + # mask the attenmap, nb x nloc x 1 x 1 x nnei + attnw_mask = ~nlist_mask.unsqueeze(2).unsqueeze(2) + # mask the attenmap, nb x nloc x 1 x nnei x 1 + attnw_mask_c = ~nlist_mask.unsqueeze(2).unsqueeze(-1) + if self.smooth: + attnw = (attnw + self.attnw_shift) * sw[:, :, None, :, None] * sw[ + :, :, None, None, : + ] - self.attnw_shift + else: + attnw = attnw.masked_fill( + attnw_mask, + float("-inf"), + ) + attnw = paddle.nn.functional.softmax(attnw, axis=-1) + attnw = attnw.masked_fill( + attnw_mask, + 0.0, + ) + # nb x nloc x nh x nnei x nnei + attnw = attnw.masked_fill( + attnw_mask_c, + 0.0, + ) + if self.smooth: + attnw = attnw * sw[:, :, None, :, None] * sw[:, :, None, None, :] + # nb x nloc x nnei x nnei + h2h2t = paddle.matmul(h2, paddle.transpose(h2, [0, 1, 3, 2])) / 3.0**0.5 + # nb x nloc x nh x nnei x nnei + ret = attnw * h2h2t[:, :, None, :, :] + # ret = paddle.nn.functional.softmax(g2qk, axis=-1) + # nb x nloc x nnei x nnei x nh + ret = paddle.transpose(ret, (0, 1, 3, 4, 2)) + return ret + + def serialize(self) -> dict: + """Serialize the networks to a dict. + + Returns + ------- + dict + The serialized networks. + """ + return { + "@class": "Atten2Map", + "@version": 1, + "input_dim": self.input_dim, + "hidden_dim": self.hidden_dim, + "head_num": self.head_num, + "has_gate": self.has_gate, + "smooth": self.smooth, + "attnw_shift": self.attnw_shift, + "precision": self.precision, + "mapqk": self.mapqk.serialize(), + } + + @classmethod + def deserialize(cls, data: dict) -> "Atten2Map": + """Deserialize the networks from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. 
+ """ + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + mapqk = data.pop("mapqk") + obj = cls(**data) + obj.mapqk = MLPLayer.deserialize(mapqk) + return obj + + +class Atten2MultiHeadApply(paddle.nn.Layer): + def __init__( + self, + input_dim: int, + head_num: int, + precision: str = "float64", + seed: Optional[Union[int, list[int]]] = None, + ) -> None: + super().__init__() + self.input_dim = input_dim + self.head_num = head_num + self.mapv = MLPLayer( + input_dim, + input_dim * head_num, + bias=False, + precision=precision, + seed=child_seed(seed, 0), + ) + self.head_map = MLPLayer( + input_dim * head_num, + input_dim, + precision=precision, + seed=child_seed(seed, 1), + ) + self.precision = precision + + def forward( + self, + AA: paddle.Tensor, # nf x nloc x nnei x nnei x nh + g2: paddle.Tensor, # nf x nloc x nnei x ng2 + ) -> paddle.Tensor: + nf, nloc, nnei, ng2 = g2.shape + nh = self.head_num + # nf x nloc x nnei x ng2 x nh + g2v = self.mapv(g2).reshape([nf, nloc, nnei, ng2, nh]) + # nf x nloc x nh x nnei x ng2 + g2v = paddle.transpose(g2v, (0, 1, 4, 2, 3)) + # g2v = paddle.nn.functional.normalize(g2v, axis=-1) + # nf x nloc x nh x nnei x nnei + AA = paddle.transpose(AA, (0, 1, 4, 2, 3)) + # nf x nloc x nh x nnei x ng2 + ret = paddle.matmul(AA, g2v) + # nf x nloc x nnei x ng2 x nh + ret = paddle.transpose(ret, (0, 1, 3, 4, 2)).reshape( + [nf, nloc, nnei, (ng2 * nh)] + ) + # nf x nloc x nnei x ng2 + return self.head_map(ret) + + def serialize(self) -> dict: + """Serialize the networks to a dict. + + Returns + ------- + dict + The serialized networks. + """ + return { + "@class": "Atten2MultiHeadApply", + "@version": 1, + "input_dim": self.input_dim, + "head_num": self.head_num, + "precision": self.precision, + "mapv": self.mapv.serialize(), + "head_map": self.head_map.serialize(), + } + + @classmethod + def deserialize(cls, data: dict) -> "Atten2MultiHeadApply": + """Deserialize the networks from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + mapv = data.pop("mapv") + head_map = data.pop("head_map") + obj = cls(**data) + obj.mapv = MLPLayer.deserialize(mapv) + obj.head_map = MLPLayer.deserialize(head_map) + return obj + + +class Atten2EquiVarApply(paddle.nn.Layer): + def __init__( + self, + input_dim: int, + head_num: int, + precision: str = "float64", + seed: Optional[Union[int, list[int]]] = None, + ) -> None: + super().__init__() + self.input_dim = input_dim + self.head_num = head_num + self.head_map = MLPLayer( + head_num, 1, bias=False, precision=precision, seed=seed + ) + self.precision = precision + + def forward( + self, + AA: paddle.Tensor, # nf x nloc x nnei x nnei x nh + h2: paddle.Tensor, # nf x nloc x nnei x 3 + ) -> paddle.Tensor: + nf, nloc, nnei, _ = h2.shape + nh = self.head_num + # nf x nloc x nh x nnei x nnei + AA = paddle.transpose(AA, (0, 1, 4, 2, 3)) + h2m = paddle.unsqueeze(h2, axis=2) + # nf x nloc x nh x nnei x 3 + h2m = paddle.tile(h2m, [1, 1, nh, 1, 1]) + # nf x nloc x nh x nnei x 3 + ret = paddle.matmul(AA, h2m) + # nf x nloc x nnei x 3 x nh + ret = paddle.transpose(ret, (0, 1, 3, 4, 2)).reshape([nf, nloc, nnei, 3, nh]) + # nf x nloc x nnei x 3 + return paddle.squeeze(self.head_map(ret), axis=-1) + + def serialize(self) -> dict: + """Serialize the networks to a dict. + + Returns + ------- + dict + The serialized networks. 
+ """ + return { + "@class": "Atten2EquiVarApply", + "@version": 1, + "input_dim": self.input_dim, + "head_num": self.head_num, + "precision": self.precision, + "head_map": self.head_map.serialize(), + } + + @classmethod + def deserialize(cls, data: dict) -> "Atten2EquiVarApply": + """Deserialize the networks from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + head_map = data.pop("head_map") + obj = cls(**data) + obj.head_map = MLPLayer.deserialize(head_map) + return obj + + +class LocalAtten(paddle.nn.Layer): + def __init__( + self, + input_dim: int, + hidden_dim: int, + head_num: int, + smooth: bool = True, + attnw_shift: float = 20.0, + precision: str = "float64", + seed: Optional[Union[int, list[int]]] = None, + ) -> None: + super().__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.head_num = head_num + self.mapq = MLPLayer( + input_dim, + hidden_dim * 1 * head_num, + bias=False, + precision=precision, + seed=child_seed(seed, 0), + ) + self.mapkv = MLPLayer( + input_dim, + (hidden_dim + input_dim) * head_num, + bias=False, + precision=precision, + seed=child_seed(seed, 1), + ) + self.head_map = MLPLayer( + input_dim * head_num, + input_dim, + precision=precision, + seed=child_seed(seed, 2), + ) + self.smooth = smooth + self.attnw_shift = attnw_shift + self.precision = precision + + def forward( + self, + g1: paddle.Tensor, # nb x nloc x ng1 + gg1: paddle.Tensor, # nb x nloc x nnei x ng1 + nlist_mask: paddle.Tensor, # nb x nloc x nnei + sw: paddle.Tensor, # nb x nloc x nnei + ) -> paddle.Tensor: + nb, nloc, nnei = nlist_mask.shape + ni, nd, nh = self.input_dim, self.hidden_dim, self.head_num + assert ni == g1.shape[-1] + assert ni == gg1.shape[-1] + # nb x nloc x nd x nh + g1q = self.mapq(g1).reshape([nb, nloc, nd, nh]) + # nb x nloc x nh x nd + g1q = paddle.transpose(g1q, (0, 1, 3, 2)) + # nb x nloc x nnei x (nd+ni) x nh + gg1kv = self.mapkv(gg1).reshape([nb, nloc, nnei, nd + ni, nh]) + gg1kv = paddle.transpose(gg1kv, (0, 1, 4, 2, 3)) + # nb x nloc x nh x nnei x nd, nb x nloc x nh x nnei x ng1 + gg1k, gg1v = paddle.split(gg1kv, [nd, ni], axis=-1) + + # nb x nloc x nh x 1 x nnei + attnw = ( + paddle.matmul(g1q.unsqueeze(-2), paddle.transpose(gg1k, [0, 1, 2, 4, 3])) + / nd**0.5 + ) + # nb x nloc x nh x nnei + attnw = attnw.squeeze(-2) + # mask the attenmap, nb x nloc x 1 x nnei + attnw_mask = ~nlist_mask.unsqueeze(-2) + # nb x nloc x nh x nnei + if self.smooth: + attnw = (attnw + self.attnw_shift) * sw.unsqueeze(-2) - self.attnw_shift + else: + attnw = attnw.masked_fill( + attnw_mask, + float("-inf"), + ) + attnw = paddle.nn.functional.softmax(attnw, axis=-1) + attnw = attnw.masked_fill( + attnw_mask, + 0.0, + ) + if self.smooth: + attnw = attnw * sw.unsqueeze(-2) + + # nb x nloc x nh x ng1 + ret = ( + paddle.matmul(attnw.unsqueeze(-2), gg1v) + .squeeze(-2) + .reshape([nb, nloc, nh * ni]) + ) + # nb x nloc x ng1 + ret = self.head_map(ret) + return ret + + def serialize(self) -> dict: + """Serialize the networks to a dict. + + Returns + ------- + dict + The serialized networks. 
+ """ + return { + "@class": "LocalAtten", + "@version": 1, + "input_dim": self.input_dim, + "hidden_dim": self.hidden_dim, + "head_num": self.head_num, + "smooth": self.smooth, + "attnw_shift": self.attnw_shift, + "precision": self.precision, + "mapq": self.mapq.serialize(), + "mapkv": self.mapkv.serialize(), + "head_map": self.head_map.serialize(), + } + + @classmethod + def deserialize(cls, data: dict) -> "LocalAtten": + """Deserialize the networks from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + mapq = data.pop("mapq") + mapkv = data.pop("mapkv") + head_map = data.pop("head_map") + obj = cls(**data) + obj.mapq = MLPLayer.deserialize(mapq) + obj.mapkv = MLPLayer.deserialize(mapkv) + obj.head_map = MLPLayer.deserialize(head_map) + return obj + + +class RepformerLayer(paddle.nn.Layer): + def __init__( + self, + rcut, + rcut_smth, + sel: int, + ntypes: int, + g1_dim=128, + g2_dim=16, + axis_neuron: int = 4, + update_chnnl_2: bool = True, + update_g1_has_conv: bool = True, + update_g1_has_drrd: bool = True, + update_g1_has_grrg: bool = True, + update_g1_has_attn: bool = True, + update_g2_has_g1g1: bool = True, + update_g2_has_attn: bool = True, + update_h2: bool = False, + attn1_hidden: int = 64, + attn1_nhead: int = 4, + attn2_hidden: int = 16, + attn2_nhead: int = 4, + attn2_has_gate: bool = False, + activation_function: str = "tanh", + update_style: str = "res_avg", + update_residual: float = 0.001, + update_residual_init: str = "norm", + smooth: bool = True, + precision: str = "float64", + trainable_ln: bool = True, + ln_eps: Optional[float] = 1e-5, + use_sqrt_nnei: bool = True, + g1_out_conv: bool = True, + g1_out_mlp: bool = True, + seed: Optional[Union[int, list[int]]] = None, + ) -> None: + super().__init__() + self.epsilon = 1e-4 # protection of 1./nnei + self.rcut = float(rcut) + self.rcut_smth = float(rcut_smth) + self.ntypes = ntypes + sel = [sel] if isinstance(sel, int) else sel + self.nnei = sum(sel) + assert len(sel) == 1 + self.sel = sel + self.sec = self.sel + self.axis_neuron = axis_neuron + self.activation_function = activation_function + self.act = ActivationFn(activation_function) + self.update_g1_has_grrg = update_g1_has_grrg + self.update_g1_has_drrd = update_g1_has_drrd + self.update_g1_has_conv = update_g1_has_conv + self.update_g1_has_attn = update_g1_has_attn + self.update_chnnl_2 = update_chnnl_2 + self.update_g2_has_g1g1 = update_g2_has_g1g1 if self.update_chnnl_2 else False + self.update_g2_has_attn = update_g2_has_attn if self.update_chnnl_2 else False + self.update_h2 = update_h2 if self.update_chnnl_2 else False + del update_g2_has_g1g1, update_g2_has_attn, update_h2 + self.attn1_hidden = attn1_hidden + self.attn1_nhead = attn1_nhead + self.attn2_hidden = attn2_hidden + self.attn2_nhead = attn2_nhead + self.attn2_has_gate = attn2_has_gate + self.update_style = update_style + self.update_residual = update_residual + self.update_residual_init = update_residual_init + self.smooth = smooth + self.g1_dim = g1_dim + self.g2_dim = g2_dim + self.trainable_ln = trainable_ln + self.ln_eps = ln_eps + self.precision = precision + self.seed = seed + self.use_sqrt_nnei = use_sqrt_nnei + self.g1_out_conv = g1_out_conv + self.g1_out_mlp = g1_out_mlp + + assert update_residual_init in [ + "norm", + "const", + ], "'update_residual_init' only support 'norm' or 'const'!" 
+ self.update_residual = update_residual + self.update_residual_init = update_residual_init + self.g1_residual = [] + self.g2_residual = [] + self.h2_residual = [] + + if self.update_style == "res_residual": + self.g1_residual.append( + get_residual( + g1_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + seed=child_seed(seed, 0), + ) + ) + + g1_in_dim = self.cal_1_dim(g1_dim, g2_dim, self.axis_neuron) + self.linear1 = MLPLayer( + g1_in_dim, + g1_dim, + precision=precision, + seed=child_seed(seed, 1), + ) + self.linear2 = None + self.proj_g1g2 = None + self.proj_g1g1g2 = None + self.attn2g_map = None + self.attn2_mh_apply = None + self.attn2_lm = None + self.attn2_ev_apply = None + self.loc_attn = None + + if self.update_chnnl_2: + self.linear2 = MLPLayer( + g2_dim, + g2_dim, + precision=precision, + seed=child_seed(seed, 2), + ) + if self.update_style == "res_residual": + self.g2_residual.append( + get_residual( + g2_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + seed=child_seed(seed, 3), + ) + ) + if self.g1_out_mlp: + self.g1_self_mlp = MLPLayer( + g1_dim, + g1_dim, + precision=precision, + seed=child_seed(seed, 15), + ) + if self.update_style == "res_residual": + self.g1_residual.append( + get_residual( + g1_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + seed=child_seed(seed, 16), + ) + ) + else: + self.g1_self_mlp = None + if self.update_g1_has_conv: + if not self.g1_out_conv: + self.proj_g1g2 = MLPLayer( + g1_dim, + g2_dim, + bias=False, + precision=precision, + seed=child_seed(seed, 4), + ) + else: + self.proj_g1g2 = MLPLayer( + g2_dim, + g1_dim, + bias=False, + precision=precision, + seed=child_seed(seed, 4), + ) + if self.update_style == "res_residual": + self.g1_residual.append( + get_residual( + g1_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + seed=child_seed(seed, 17), + ) + ) + if self.update_g2_has_g1g1: + self.proj_g1g1g2 = MLPLayer( + g1_dim, + g2_dim, + bias=False, + precision=precision, + seed=child_seed(seed, 5), + ) + if self.update_style == "res_residual": + self.g2_residual.append( + get_residual( + g2_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + seed=child_seed(seed, 6), + ) + ) + if self.update_g2_has_attn or self.update_h2: + self.attn2g_map = Atten2Map( + g2_dim, + attn2_hidden, + attn2_nhead, + attn2_has_gate, + self.smooth, + precision=precision, + seed=child_seed(seed, 7), + ) + if self.update_g2_has_attn: + self.attn2_mh_apply = Atten2MultiHeadApply( + g2_dim, attn2_nhead, precision=precision, seed=child_seed(seed, 8) + ) + self.attn2_lm = LayerNorm( + g2_dim, + eps=ln_eps, + trainable=trainable_ln, + precision=precision, + seed=child_seed(seed, 9), + ) + if self.update_style == "res_residual": + self.g2_residual.append( + get_residual( + g2_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + seed=child_seed(seed, 10), + ) + ) + + if self.update_h2: + self.attn2_ev_apply = Atten2EquiVarApply( + g2_dim, attn2_nhead, precision=precision, seed=child_seed(seed, 11) + ) + if self.update_style == "res_residual": + self.h2_residual.append( + get_residual( + 1, + self.update_residual, + self.update_residual_init, + precision=precision, + seed=child_seed(seed, 12), + ) + ) + if self.update_g1_has_attn: + self.loc_attn = LocalAtten( + g1_dim, + attn1_hidden, + attn1_nhead, + self.smooth, + precision=precision, + seed=child_seed(seed, 13), + ) + if 
self.update_style == "res_residual": + self.g1_residual.append( + get_residual( + g1_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + seed=child_seed(seed, 14), + ) + ) + + self.g1_residual = nn.ParameterList(self.g1_residual) + self.g2_residual = nn.ParameterList(self.g2_residual) + self.h2_residual = nn.ParameterList(self.h2_residual) + + def cal_1_dim(self, g1d: int, g2d: int, ax: int) -> int: + ret = g1d if not self.g1_out_mlp else 0 + if self.update_g1_has_grrg: + ret += g2d * ax + if self.update_g1_has_drrd: + ret += g1d * ax + if self.update_g1_has_conv and not self.g1_out_conv: + ret += g2d + return ret + + def _update_h2( + self, + h2: paddle.Tensor, + attn: paddle.Tensor, + ) -> paddle.Tensor: + """ + Calculate the attention weights update for pair-wise equivariant rep. + + Parameters + ---------- + h2 + Pair-wise equivariant rep tensors, with shape nf x nloc x nnei x 3. + attn + Attention weights from g2 attention, with shape nf x nloc x nnei x nnei x nh2. + """ + assert self.attn2_ev_apply is not None + # nf x nloc x nnei x nh2 + h2_1 = self.attn2_ev_apply(attn, h2) + return h2_1 + + def _update_g1_conv( + self, + gg1: paddle.Tensor, + g2: paddle.Tensor, + nlist_mask: paddle.Tensor, + sw: paddle.Tensor, + ) -> paddle.Tensor: + """ + Calculate the convolution update for atomic invariant rep. + + Parameters + ---------- + gg1 + Neighbor-wise atomic invariant rep, with shape nb x nloc x nnei x ng1. + g2 + Pair invariant rep, with shape nb x nloc x nnei x ng2. + nlist_mask + Neighbor list mask, where zero means no neighbor, with shape nb x nloc x nnei. + sw + The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, + and remains 0 beyond rcut, with shape nb x nloc x nnei. + """ + assert self.proj_g1g2 is not None + nb, nloc, nnei, _ = g2.shape + ng1 = gg1.shape[-1] + ng2 = g2.shape[-1] + if not self.g1_out_conv: + # gg1 : nb x nloc x nnei x ng2 + gg1 = self.proj_g1g2(gg1).reshape([nb, nloc, nnei, ng2]) + else: + gg1 = gg1.reshape([nb, nloc, nnei, ng1]) + # nb x nloc x nnei x ng2/ng1 + gg1 = _apply_nlist_mask(gg1, nlist_mask) + if not self.smooth: + # normalized by number of neighbors, not smooth + # nb x nloc x 1 + # must use astype here to convert bool to float, otherwise there will be numerical difference from numpy + invnnei = 1.0 / ( + self.epsilon + paddle.sum(nlist_mask.astype(gg1.dtype), axis=-1) + ).unsqueeze(-1) + else: + gg1 = _apply_switch(gg1, sw) + invnnei = (1.0 / float(nnei)) * paddle.ones( + (nb, nloc, 1), dtype=gg1.dtype + ).to(device=gg1.place) + if not self.g1_out_conv: + # nb x nloc x ng2 + g1_11 = paddle.sum(g2 * gg1, axis=2) * invnnei + else: + g2 = self.proj_g1g2(g2).reshape([nb, nloc, nnei, ng1]) + # nb x nloc x ng1 + g1_11 = paddle.sum(g2 * gg1, axis=2) * invnnei + return g1_11 + + @staticmethod + def _cal_hg( + g2: paddle.Tensor, + h2: paddle.Tensor, + nlist_mask: paddle.Tensor, + sw: paddle.Tensor, + smooth: bool = True, + epsilon: float = 1e-4, + use_sqrt_nnei: bool = True, + ) -> paddle.Tensor: + """ + Calculate the transposed rotation matrix. + + Parameters + ---------- + g2 + Neighbor-wise/Pair-wise invariant rep tensors, with shape nb x nloc x nnei x ng2. + h2 + Neighbor-wise/Pair-wise equivariant rep tensors, with shape nb x nloc x nnei x 3. + nlist_mask + Neighbor list mask, where zero means no neighbor, with shape nb x nloc x nnei. 
+ sw + The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, + and remains 0 beyond rcut, with shape nb x nloc x nnei. + smooth + Whether to use smoothness in processes such as attention weights calculation. + epsilon + Protection of 1./nnei. + + Returns + ------- + hg + The transposed rotation matrix, with shape nb x nloc x 3 x ng2. + """ + # g2: nb x nloc x nnei x ng2 + # h2: nb x nloc x nnei x 3 + # msk: nb x nloc x nnei + nb, nloc, nnei, _ = g2.shape + ng2 = g2.shape[-1] + # nb x nloc x nnei x ng2 + g2 = _apply_nlist_mask(g2, nlist_mask) + if not smooth: + # nb x nloc + # must use astype here to convert bool to float, otherwise there will be numerical difference from numpy + if not use_sqrt_nnei: + invnnei = 1.0 / ( + epsilon + paddle.sum(nlist_mask.astype(g2.dtype), axis=-1) + ) + else: + invnnei = 1.0 / ( + epsilon + + paddle.sqrt(paddle.sum(nlist_mask.astype(g2.dtype), axis=-1)) + ) + # nb x nloc x 1 x 1 + invnnei = invnnei.unsqueeze(-1).unsqueeze(-1) + else: + g2 = _apply_switch(g2, sw) + if not use_sqrt_nnei: + invnnei = (1.0 / float(nnei)) * paddle.ones( + (nb, nloc, 1, 1), dtype=g2.dtype + ).to(device=g2.place) + else: + invnnei = paddle.rsqrt( + float(nnei) + * paddle.ones((nb, nloc, 1, 1), dtype=g2.dtype).to(device=g2.place) + ) + # nb x nloc x 3 x ng2 + h2g2 = paddle.matmul(paddle.transpose(h2, [0, 1, 3, 2]), g2) * invnnei + return h2g2 + + @staticmethod + def _cal_grrg(h2g2: paddle.Tensor, axis_neuron: int) -> paddle.Tensor: + """ + Calculate the atomic invariant rep. + + Parameters + ---------- + h2g2 + The transposed rotation matrix, with shape nb x nloc x 3 x ng2. + axis_neuron + Size of the submatrix. + + Returns + ------- + grrg + Atomic invariant rep, with shape nb x nloc x (axis_neuron x ng2) + """ + # nb x nloc x 3 x ng2 + nb, nloc, _, ng2 = h2g2.shape + # nb x nloc x 3 x axis + # h2g2m = paddle.split(h2g2, decomp.sec(h2g2.shape[-1], axis_neuron), axis=-1)[0] + h2g2m = h2g2[..., :axis_neuron] # use slice instead of split + # nb x nloc x axis x ng2 + g1_13 = paddle.matmul(paddle.transpose(h2g2m, [0, 1, 3, 2]), h2g2) / (3.0**1) + # nb x nloc x (axisxng2) + g1_13 = g1_13.reshape([nb, nloc, axis_neuron * ng2]) + return g1_13 + + def symmetrization_op( + self, + g2: paddle.Tensor, + h2: paddle.Tensor, + nlist_mask: paddle.Tensor, + sw: paddle.Tensor, + axis_neuron: int, + smooth: bool = True, + epsilon: float = 1e-4, + ) -> paddle.Tensor: + """ + Symmetrization operator to obtain atomic invariant rep. + + Parameters + ---------- + g2 + Neighbor-wise/Pair-wise invariant rep tensors, with shape nb x nloc x nnei x ng2. + h2 + Neighbor-wise/Pair-wise equivariant rep tensors, with shape nb x nloc x nnei x 3. + nlist_mask + Neighbor list mask, where zero means no neighbor, with shape nb x nloc x nnei. + sw + The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, + and remains 0 beyond rcut, with shape nb x nloc x nnei. + axis_neuron + Size of the submatrix. + smooth + Whether to use smoothness in processes such as attention weights calculation. + epsilon + Protection of 1./nnei. 
+ + Returns + ------- + grrg + Atomic invariant rep, with shape nb x nloc x (axis_neuron x ng2) + """ + # g2: nb x nloc x nnei x ng2 + # h2: nb x nloc x nnei x 3 + # msk: nb x nloc x nnei + nb, nloc, nnei, _ = g2.shape + # nb x nloc x 3 x ng2 + h2g2 = self._cal_hg( + g2, + h2, + nlist_mask, + sw, + smooth=smooth, + epsilon=epsilon, + use_sqrt_nnei=self.use_sqrt_nnei, + ) + # nb x nloc x (axisxng2) + g1_13 = self._cal_grrg(h2g2, axis_neuron) + return g1_13 + + def _update_g2_g1g1( + self, + g1: paddle.Tensor, # nb x nloc x ng1 + gg1: paddle.Tensor, # nb x nloc x nnei x ng1 + nlist_mask: paddle.Tensor, # nb x nloc x nnei + sw: paddle.Tensor, # nb x nloc x nnei + ) -> paddle.Tensor: + """ + Update the g2 using element-wise dot g1_i * g1_j. + + Parameters + ---------- + g1 + Atomic invariant rep, with shape nb x nloc x ng1. + gg1 + Neighbor-wise atomic invariant rep, with shape nb x nloc x nnei x ng1. + nlist_mask + Neighbor list mask, where zero means no neighbor, with shape nb x nloc x nnei. + sw + The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, + and remains 0 beyond rcut, with shape nb x nloc x nnei. + """ + ret = g1.unsqueeze(-2) * gg1 + # nb x nloc x nnei x ng1 + ret = _apply_nlist_mask(ret, nlist_mask) + if self.smooth: + ret = _apply_switch(ret, sw) + return ret + + def forward( + self, + g1_ext: paddle.Tensor, # nf x nall x ng1 + g2: paddle.Tensor, # nf x nloc x nnei x ng2 + h2: paddle.Tensor, # nf x nloc x nnei x 3 + nlist: paddle.Tensor, # nf x nloc x nnei + nlist_mask: paddle.Tensor, # nf x nloc x nnei + sw: paddle.Tensor, # switch func, nf x nloc x nnei + ): + """ + Parameters + ---------- + g1_ext : nf x nall x ng1 extended single-atom channel + g2 : nf x nloc x nnei x ng2 pair-atom channel, invariant + h2 : nf x nloc x nnei x 3 pair-atom channel, equivariant + nlist : nf x nloc x nnei neighbor list (padded neis are set to 0) + nlist_mask : nf x nloc x nnei masks of the neighbor list. 
real nei 1 otherwise 0 + sw : nf x nloc x nnei switch function + + Returns + ------- + g1: nf x nloc x ng1 updated single-atom channel + g2: nf x nloc x nnei x ng2 updated pair-atom channel, invariant + h2: nf x nloc x nnei x 3 updated pair-atom channel, equivariant + """ + cal_gg1 = ( + self.update_g1_has_drrd + or self.update_g1_has_conv + or self.update_g1_has_attn + or self.update_g2_has_g1g1 + ) + + nb, nloc, nnei, _ = g2.shape + nall = g1_ext.shape[1] + g1, _ = paddle.split(g1_ext, [nloc, nall - nloc], axis=1) + if paddle.in_dynamic_mode(): + assert [nb, nloc] == g1.shape[:2] + if paddle.in_dynamic_mode(): + assert [nb, nloc, nnei] == h2.shape[:3] + + g2_update: list[paddle.Tensor] = [g2] + h2_update: list[paddle.Tensor] = [h2] + g1_update: list[paddle.Tensor] = [g1] + g1_mlp: list[paddle.Tensor] = [g1] if not self.g1_out_mlp else [] + if self.g1_out_mlp: + if paddle.in_dynamic_mode(): + assert self.g1_self_mlp is not None + g1_self_mlp = self.act(self.g1_self_mlp(g1)) + g1_update.append(g1_self_mlp) + + if cal_gg1: + gg1 = _make_nei_g1(g1_ext, nlist) + else: + gg1 = None + + if self.update_chnnl_2: + # mlp(g2) + if paddle.in_dynamic_mode(): + assert self.linear2 is not None + # nb x nloc x nnei x ng2 + g2_1 = self.act(self.linear2(g2)) + g2_update.append(g2_1) + + if self.update_g2_has_g1g1: + # linear(g1_i * g1_j) + if paddle.in_dynamic_mode(): + assert gg1 is not None + if paddle.in_dynamic_mode(): + assert self.proj_g1g1g2 is not None + g2_update.append( + self.proj_g1g1g2(self._update_g2_g1g1(g1, gg1, nlist_mask, sw)) + ) + + if self.update_g2_has_attn or self.update_h2: + # gated_attention(g2, h2) + if paddle.in_dynamic_mode(): + assert self.attn2g_map is not None + # nb x nloc x nnei x nnei x nh + AAg = self.attn2g_map(g2, h2, nlist_mask, sw) + + if self.update_g2_has_attn: + if paddle.in_dynamic_mode(): + assert self.attn2_mh_apply is not None + if paddle.in_dynamic_mode(): + assert self.attn2_lm is not None + # nb x nloc x nnei x ng2 + g2_2 = self.attn2_mh_apply(AAg, g2) + g2_2 = self.attn2_lm(g2_2) + g2_update.append(g2_2) + + if self.update_h2: + # linear_head(attention_weights * h2) + h2_update.append(self._update_h2(h2, AAg)) + + if self.update_g1_has_conv: + if paddle.in_dynamic_mode(): + assert gg1 is not None + g1_conv = self._update_g1_conv(gg1, g2, nlist_mask, sw) + if not self.g1_out_conv: + g1_mlp.append(g1_conv) + else: + g1_update.append(g1_conv) + + if self.update_g1_has_grrg: + g1_mlp.append( + self.symmetrization_op( + g2, + h2, + nlist_mask, + sw, + self.axis_neuron, + smooth=self.smooth, + epsilon=self.epsilon, + ) + ) + + if self.update_g1_has_drrd: + if paddle.in_dynamic_mode(): + assert gg1 is not None + g1_mlp.append( + self.symmetrization_op( + gg1, + h2, + nlist_mask, + sw, + self.axis_neuron, + smooth=self.smooth, + epsilon=self.epsilon, + ) + ) + + # nb x nloc x [ng1+ng2+(axisxng2)+(axisxng1)] + # conv grrg drrd + g1_1 = self.act(self.linear1(paddle.concat(g1_mlp, axis=-1))) + g1_update.append(g1_1) + + if self.update_g1_has_attn: + assert gg1 is not None + assert self.loc_attn is not None + g1_update.append(self.loc_attn(g1, gg1, nlist_mask, sw)) + + # update + if self.update_chnnl_2: + g2_new = self.list_update(g2_update, "g2") + h2_new = self.list_update(h2_update, "h2") + else: + g2_new, h2_new = g2, h2 + g1_new = self.list_update(g1_update, "g1") + return g1_new, g2_new, h2_new + + def list_update_res_avg( + self, + update_list: list[paddle.Tensor], + ) -> paddle.Tensor: + nitem = len(update_list) + uu = update_list[0] + for ii in range(1, 
nitem): + uu = uu + update_list[ii] + return uu / (float(nitem) ** 0.5) + + def list_update_res_incr(self, update_list: list[paddle.Tensor]) -> paddle.Tensor: + nitem = len(update_list) + uu = update_list[0] + scale = 1.0 / (float(nitem - 1) ** 0.5) if nitem > 1 else 0.0 + for ii in range(1, nitem): + uu = uu + scale * update_list[ii] + return uu + + def list_update_res_residual( + self, update_list: list[paddle.Tensor], update_name: str = "g1" + ) -> paddle.Tensor: + nitem = len(update_list) + uu = update_list[0] + # make jit happy + if update_name == "g1": + for ii, vv in enumerate(self.g1_residual): + uu = uu + vv * update_list[ii + 1] + elif update_name == "g2": + for ii, vv in enumerate(self.g2_residual): + uu = uu + vv * update_list[ii + 1] + elif update_name == "h2": + for ii, vv in enumerate(self.h2_residual): + uu = uu + vv * update_list[ii + 1] + else: + raise NotImplementedError + return uu + + def list_update( + self, update_list: list[paddle.Tensor], update_name: str = "g1" + ) -> paddle.Tensor: + if self.update_style == "res_avg": + return self.list_update_res_avg(update_list) + elif self.update_style == "res_incr": + return self.list_update_res_incr(update_list) + elif self.update_style == "res_residual": + return self.list_update_res_residual(update_list, update_name=update_name) + else: + raise RuntimeError(f"unknown update style {self.update_style}") + + def serialize(self) -> dict: + """Serialize the networks to a dict. + + Returns + ------- + dict + The serialized networks. + """ + data = { + "@class": "RepformerLayer", + "@version": 2, + "rcut": self.rcut, + "rcut_smth": self.rcut_smth, + "sel": self.sel, + "ntypes": self.ntypes, + "g1_dim": self.g1_dim, + "g2_dim": self.g2_dim, + "axis_neuron": self.axis_neuron, + "update_chnnl_2": self.update_chnnl_2, + "update_g1_has_conv": self.update_g1_has_conv, + "update_g1_has_drrd": self.update_g1_has_drrd, + "update_g1_has_grrg": self.update_g1_has_grrg, + "update_g1_has_attn": self.update_g1_has_attn, + "update_g2_has_g1g1": self.update_g2_has_g1g1, + "update_g2_has_attn": self.update_g2_has_attn, + "update_h2": self.update_h2, + "attn1_hidden": self.attn1_hidden, + "attn1_nhead": self.attn1_nhead, + "attn2_hidden": self.attn2_hidden, + "attn2_nhead": self.attn2_nhead, + "attn2_has_gate": self.attn2_has_gate, + "activation_function": self.activation_function, + "update_style": self.update_style, + "smooth": self.smooth, + "precision": self.precision, + "trainable_ln": self.trainable_ln, + "use_sqrt_nnei": self.use_sqrt_nnei, + "g1_out_conv": self.g1_out_conv, + "g1_out_mlp": self.g1_out_mlp, + "ln_eps": self.ln_eps, + "linear1": self.linear1.serialize(), + } + if self.update_chnnl_2: + data.update( + { + "linear2": self.linear2.serialize(), + } + ) + if self.update_g1_has_conv: + data.update( + { + "proj_g1g2": self.proj_g1g2.serialize(), + } + ) + if self.update_g2_has_g1g1: + data.update( + { + "proj_g1g1g2": self.proj_g1g1g2.serialize(), + } + ) + if self.update_g2_has_attn or self.update_h2: + data.update( + { + "attn2g_map": self.attn2g_map.serialize(), + } + ) + if self.update_g2_has_attn: + data.update( + { + "attn2_mh_apply": self.attn2_mh_apply.serialize(), + "attn2_lm": self.attn2_lm.serialize(), + } + ) + + if self.update_h2: + data.update( + { + "attn2_ev_apply": self.attn2_ev_apply.serialize(), + } + ) + if self.update_g1_has_attn: + data.update( + { + "loc_attn": self.loc_attn.serialize(), + } + ) + if self.g1_out_mlp: + data.update( + { + "g1_self_mlp": self.g1_self_mlp.serialize(), + } + ) + if 
self.update_style == "res_residual": + data.update( + { + "@variables": { + "g1_residual": [to_numpy_array(t) for t in self.g1_residual], + "g2_residual": [to_numpy_array(t) for t in self.g2_residual], + "h2_residual": [to_numpy_array(t) for t in self.h2_residual], + } + } + ) + return data + + @classmethod + def deserialize(cls, data: dict) -> "RepformerLayer": + """Deserialize the networks from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + data = data.copy() + check_version_compatibility(data.pop("@version"), 2, 1) + data.pop("@class") + linear1 = data.pop("linear1") + update_chnnl_2 = data["update_chnnl_2"] + update_g1_has_conv = data["update_g1_has_conv"] + update_g2_has_g1g1 = data["update_g2_has_g1g1"] + update_g2_has_attn = data["update_g2_has_attn"] + update_h2 = data["update_h2"] + update_g1_has_attn = data["update_g1_has_attn"] + update_style = data["update_style"] + g1_out_mlp = data["g1_out_mlp"] + + linear2 = data.pop("linear2", None) + proj_g1g2 = data.pop("proj_g1g2", None) + proj_g1g1g2 = data.pop("proj_g1g1g2", None) + attn2g_map = data.pop("attn2g_map", None) + attn2_mh_apply = data.pop("attn2_mh_apply", None) + attn2_lm = data.pop("attn2_lm", None) + attn2_ev_apply = data.pop("attn2_ev_apply", None) + loc_attn = data.pop("loc_attn", None) + g1_self_mlp = data.pop("g1_self_mlp", None) + variables = data.pop("@variables", {}) + g1_residual = variables.get("g1_residual", data.pop("g1_residual", [])) + g2_residual = variables.get("g2_residual", data.pop("g2_residual", [])) + h2_residual = variables.get("h2_residual", data.pop("h2_residual", [])) + + obj = cls(**data) + obj.linear1 = MLPLayer.deserialize(linear1) + if update_chnnl_2: + assert isinstance(linear2, dict) + obj.linear2 = MLPLayer.deserialize(linear2) + if update_g1_has_conv: + assert isinstance(proj_g1g2, dict) + obj.proj_g1g2 = MLPLayer.deserialize(proj_g1g2) + if update_g2_has_g1g1: + assert isinstance(proj_g1g1g2, dict) + obj.proj_g1g1g2 = MLPLayer.deserialize(proj_g1g1g2) + if update_g2_has_attn or update_h2: + assert isinstance(attn2g_map, dict) + obj.attn2g_map = Atten2Map.deserialize(attn2g_map) + if update_g2_has_attn: + assert isinstance(attn2_mh_apply, dict) + assert isinstance(attn2_lm, dict) + obj.attn2_mh_apply = Atten2MultiHeadApply.deserialize(attn2_mh_apply) + obj.attn2_lm = LayerNorm.deserialize(attn2_lm) + if update_h2: + assert isinstance(attn2_ev_apply, dict) + obj.attn2_ev_apply = Atten2EquiVarApply.deserialize(attn2_ev_apply) + if update_g1_has_attn: + assert isinstance(loc_attn, dict) + obj.loc_attn = LocalAtten.deserialize(loc_attn) + if g1_out_mlp: + assert isinstance(g1_self_mlp, dict) + obj.g1_self_mlp = MLPLayer.deserialize(g1_self_mlp) + if update_style == "res_residual": + for ii, t in enumerate(obj.g1_residual): + t.data = to_paddle_tensor(g1_residual[ii]) + for ii, t in enumerate(obj.g2_residual): + t.data = to_paddle_tensor(g2_residual[ii]) + for ii, t in enumerate(obj.h2_residual): + t.data = to_paddle_tensor(h2_residual[ii]) + return obj diff --git a/deepmd/pd/model/descriptor/repformers.py b/deepmd/pd/model/descriptor/repformers.py new file mode 100644 index 0000000000..47d92317df --- /dev/null +++ b/deepmd/pd/model/descriptor/repformers.py @@ -0,0 +1,576 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Callable, + Optional, + Union, +) + +import paddle + +from deepmd.dpmodel.utils.seed import ( + child_seed, +) +from deepmd.pd.model.descriptor.descriptor import ( + DescriptorBlock, +) +from 
deepmd.pd.model.descriptor.env_mat import ( + prod_env_mat, +) +from deepmd.pd.model.network.mlp import ( + MLPLayer, +) +from deepmd.pd.utils import ( + env, +) +from deepmd.pd.utils.env import ( + PRECISION_DICT, +) +from deepmd.pd.utils.env_mat_stat import ( + EnvMatStatSe, +) +from deepmd.pd.utils.exclude_mask import ( + PairExcludeMask, +) +from deepmd.pd.utils.utils import ( + ActivationFn, +) +from deepmd.utils.env_mat_stat import ( + StatItem, +) +from deepmd.utils.path import ( + DPPath, +) + +from .repformer_layer import ( + RepformerLayer, +) + + +@DescriptorBlock.register("se_repformer") +@DescriptorBlock.register("se_uni") +class DescrptBlockRepformers(DescriptorBlock): + def __init__( + self, + rcut, + rcut_smth, + sel: int, + ntypes: int, + nlayers: int = 3, + g1_dim=128, + g2_dim=16, + axis_neuron: int = 4, + direct_dist: bool = False, + update_g1_has_conv: bool = True, + update_g1_has_drrd: bool = True, + update_g1_has_grrg: bool = True, + update_g1_has_attn: bool = True, + update_g2_has_g1g1: bool = True, + update_g2_has_attn: bool = True, + update_h2: bool = False, + attn1_hidden: int = 64, + attn1_nhead: int = 4, + attn2_hidden: int = 16, + attn2_nhead: int = 4, + attn2_has_gate: bool = False, + activation_function: str = "tanh", + update_style: str = "res_avg", + update_residual: float = 0.001, + update_residual_init: str = "norm", + set_davg_zero: bool = True, + smooth: bool = True, + exclude_types: list[tuple[int, int]] = [], + env_protection: float = 0.0, + precision: str = "float64", + trainable_ln: bool = True, + ln_eps: Optional[float] = 1e-5, + seed: Optional[Union[int, list[int]]] = None, + use_sqrt_nnei: bool = True, + g1_out_conv: bool = True, + g1_out_mlp: bool = True, + ) -> None: + r""" + The repformer descriptor block. + + Parameters + ---------- + rcut : float + The cut-off radius. + rcut_smth : float + Where to start smoothing. For example the 1/r term is smoothed from rcut to rcut_smth. + sel : int + Maximally possible number of selected neighbors. + ntypes : int + Number of element types + nlayers : int, optional + Number of repformer layers. + g1_dim : int, optional + Dimension of the first graph convolution layer. + g2_dim : int, optional + Dimension of the second graph convolution layer. + axis_neuron : int, optional + Size of the submatrix of G (embedding matrix). + direct_dist : bool, optional + Whether to use direct distance information (1/r term) in the repformer block. + update_g1_has_conv : bool, optional + Whether to update the g1 rep with convolution term. + update_g1_has_drrd : bool, optional + Whether to update the g1 rep with the drrd term. + update_g1_has_grrg : bool, optional + Whether to update the g1 rep with the grrg term. + update_g1_has_attn : bool, optional + Whether to update the g1 rep with the localized self-attention. + update_g2_has_g1g1 : bool, optional + Whether to update the g2 rep with the g1xg1 term. + update_g2_has_attn : bool, optional + Whether to update the g2 rep with the gated self-attention. + update_h2 : bool, optional + Whether to update the h2 rep. + attn1_hidden : int, optional + The hidden dimension of localized self-attention to update the g1 rep. + attn1_nhead : int, optional + The number of heads in localized self-attention to update the g1 rep. + attn2_hidden : int, optional + The hidden dimension of gated self-attention to update the g2 rep. + attn2_nhead : int, optional + The number of heads in gated self-attention to update the g2 rep. 
+ attn2_has_gate : bool, optional + Whether to use gate in the gated self-attention to update the g2 rep. + activation_function : str, optional + The activation function in the embedding net. + update_style : str, optional + Style to update a representation. + Supported options are: + -'res_avg': Updates a rep `u` with: u = 1/\\sqrt{n+1} (u + u_1 + u_2 + ... + u_n) + -'res_incr': Updates a rep `u` with: u = u + 1/\\sqrt{n} (u_1 + u_2 + ... + u_n) + -'res_residual': Updates a rep `u` with: u = u + (r1*u_1 + r2*u_2 + ... + r3*u_n) + where `r1`, `r2` ... `r3` are residual weights defined by `update_residual` + and `update_residual_init`. + update_residual : float, optional + When update using residual mode, the initial std of residual vector weights. + update_residual_init : str, optional + When update using residual mode, the initialization mode of residual vector weights. + set_davg_zero : bool, optional + Set the normalization average to zero. + precision : str, optional + The precision of the embedding net parameters. + smooth : bool, optional + Whether to use smoothness in processes such as attention weights calculation. + exclude_types : list[list[int]], optional + The excluded pairs of types which have no interaction with each other. + For example, `[[0, 1]]` means no interaction between type 0 and type 1. + env_protection : float, optional + Protection parameter to prevent division by zero errors during environment matrix calculations. + For example, when using paddings, there may be zero distances of neighbors, which may make division by zero error during environment matrix calculations without protection. + trainable_ln : bool, optional + Whether to use trainable shift and scale weights in layer normalization. + use_sqrt_nnei : bool, optional + Whether to use the square root of the number of neighbors for symmetrization_op normalization instead of using the number of neighbors directly. + g1_out_conv : bool, optional + Whether to put the convolutional update of g1 separately outside the concatenated MLP update. + g1_out_mlp : bool, optional + Whether to put the self MLP update of g1 separately outside the concatenated MLP update. + ln_eps : float, optional + The epsilon value for layer normalization. + seed : int, optional + Random seed for parameter initialization. + """ + super().__init__() + self.rcut = float(rcut) + self.rcut_smth = float(rcut_smth) + self.ntypes = ntypes + self.nlayers = nlayers + sel = [sel] if isinstance(sel, int) else sel + self.nnei = sum(sel) + self.ndescrpt = self.nnei * 4 # use full descriptor. 
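# --- Illustrative aside (not part of the patch) -----------------------------
# The three `update_style` options documented above combine the per-layer
# update list exactly as list_update_res_avg / _res_incr / _res_residual do.
# A toy sketch with n = 3 updates and stand-in residual weights (in the real
# layer the residual weights come from `get_residual` and are trainable):
import paddle

updates = [paddle.full([2, 3], float(i)) for i in range(4)]  # u, u_1, u_2, u_3
n = len(updates) - 1

# 'res_avg': u_new = (u + u_1 + ... + u_n) / sqrt(n + 1)
res_avg = updates[0]
for u in updates[1:]:
    res_avg = res_avg + u
res_avg = res_avg / float(n + 1) ** 0.5

# 'res_incr': u_new = u + (u_1 + ... + u_n) / sqrt(n)
res_incr = updates[0]
for u in updates[1:]:
    res_incr = res_incr + u / float(n) ** 0.5

# 'res_residual': u_new = u + r_1 * u_1 + ... + r_n * u_n
residual_weights = [0.001] * n  # stand-ins for the learnable residual vectors
res_residual = updates[0]
for r, u in zip(residual_weights, updates[1:]):
    res_residual = res_residual + r * u
# -----------------------------------------------------------------------------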
+ assert len(sel) == 1 + self.sel = sel + self.sec = self.sel + self.split_sel = self.sel + self.axis_neuron = axis_neuron + self.set_davg_zero = set_davg_zero + self.g1_dim = g1_dim + self.g2_dim = g2_dim + self.update_g1_has_conv = update_g1_has_conv + self.update_g1_has_drrd = update_g1_has_drrd + self.update_g1_has_grrg = update_g1_has_grrg + self.update_g1_has_attn = update_g1_has_attn + self.update_g2_has_g1g1 = update_g2_has_g1g1 + self.update_g2_has_attn = update_g2_has_attn + self.update_h2 = update_h2 + self.attn1_hidden = attn1_hidden + self.attn1_nhead = attn1_nhead + self.attn2_has_gate = attn2_has_gate + self.attn2_hidden = attn2_hidden + self.attn2_nhead = attn2_nhead + self.activation_function = activation_function + self.update_style = update_style + self.update_residual = update_residual + self.update_residual_init = update_residual_init + self.direct_dist = direct_dist + self.act = ActivationFn(activation_function) + self.smooth = smooth + self.use_sqrt_nnei = use_sqrt_nnei + self.g1_out_conv = g1_out_conv + self.g1_out_mlp = g1_out_mlp + # order matters, placed after the assignment of self.ntypes + self.reinit_exclude(exclude_types) + self.env_protection = env_protection + self.precision = precision + self.prec = PRECISION_DICT[precision] + self.trainable_ln = trainable_ln + self.ln_eps = ln_eps + self.epsilon = 1e-4 + self.seed = seed + + self.g2_embd = MLPLayer( + 1, self.g2_dim, precision=precision, seed=child_seed(seed, 0) + ) + layers = [] + for ii in range(nlayers): + layers.append( + RepformerLayer( + self.rcut, + self.rcut_smth, + self.sel, + self.ntypes, + self.g1_dim, + self.g2_dim, + axis_neuron=self.axis_neuron, + update_chnnl_2=(ii != nlayers - 1), + update_g1_has_conv=self.update_g1_has_conv, + update_g1_has_drrd=self.update_g1_has_drrd, + update_g1_has_grrg=self.update_g1_has_grrg, + update_g1_has_attn=self.update_g1_has_attn, + update_g2_has_g1g1=self.update_g2_has_g1g1, + update_g2_has_attn=self.update_g2_has_attn, + update_h2=self.update_h2, + attn1_hidden=self.attn1_hidden, + attn1_nhead=self.attn1_nhead, + attn2_has_gate=self.attn2_has_gate, + attn2_hidden=self.attn2_hidden, + attn2_nhead=self.attn2_nhead, + activation_function=self.activation_function, + update_style=self.update_style, + update_residual=self.update_residual, + update_residual_init=self.update_residual_init, + smooth=self.smooth, + trainable_ln=self.trainable_ln, + ln_eps=self.ln_eps, + precision=precision, + use_sqrt_nnei=self.use_sqrt_nnei, + g1_out_conv=self.g1_out_conv, + g1_out_mlp=self.g1_out_mlp, + seed=child_seed(child_seed(seed, 1), ii), + ) + ) + self.layers = paddle.nn.LayerList(layers) + + wanted_shape = (self.ntypes, self.nnei, 4) + mean = paddle.zeros(wanted_shape, dtype=env.GLOBAL_PD_FLOAT_PRECISION).to( + device=env.DEVICE + ) + stddev = paddle.ones(wanted_shape, dtype=env.GLOBAL_PD_FLOAT_PRECISION).to( + device=env.DEVICE + ) + self.register_buffer("mean", mean) + self.register_buffer("stddev", stddev) + self.stats = None + + def get_rcut(self) -> float: + """Returns the cut-off radius.""" + return self.rcut + + def get_rcut_smth(self) -> float: + """Returns the radius where the neighbor information starts to smoothly decay to 0.""" + return self.rcut_smth + + def get_nsel(self) -> int: + """Returns the number of selected atoms in the cut-off radius.""" + return sum(self.sel) + + def get_sel(self) -> list[int]: + """Returns the number of selected atoms for each type.""" + return self.sel + + def get_ntypes(self) -> int: + """Returns the number of element types.""" + 
return self.ntypes + + def get_dim_out(self) -> int: + """Returns the output dimension.""" + return self.dim_out + + def get_dim_in(self) -> int: + """Returns the input dimension.""" + return self.dim_in + + def get_dim_emb(self) -> int: + """Returns the embedding dimension g2.""" + return self.g2_dim + + def __setitem__(self, key, value) -> None: + if key in ("avg", "data_avg", "davg"): + self.mean = value + elif key in ("std", "data_std", "dstd"): + self.stddev = value + else: + raise KeyError(key) + + def __getitem__(self, key): + if key in ("avg", "data_avg", "davg"): + return self.mean + elif key in ("std", "data_std", "dstd"): + return self.stddev + else: + raise KeyError(key) + + def mixed_types(self) -> bool: + """If true, the descriptor + 1. assumes total number of atoms aligned across frames; + 2. requires a neighbor list that does not distinguish different atomic types. + + If false, the descriptor + 1. assumes total number of atoms of each atom type aligned across frames; + 2. requires a neighbor list that distinguishes different atomic types. + + """ + return True + + def get_env_protection(self) -> float: + """Returns the protection of building environment matrix.""" + return self.env_protection + + @property + def dim_out(self): + """Returns the output dimension of this descriptor.""" + return self.g1_dim + + @property + def dim_in(self): + """Returns the atomic input dimension of this descriptor.""" + return self.g1_dim + + @property + def dim_emb(self): + """Returns the embedding dimension g2.""" + return self.get_dim_emb() + + def reinit_exclude( + self, + exclude_types: list[tuple[int, int]] = [], + ) -> None: + self.exclude_types = exclude_types + self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types) + + def forward( + self, + nlist: paddle.Tensor, + extended_coord: paddle.Tensor, + extended_atype: paddle.Tensor, + extended_atype_embd: Optional[paddle.Tensor] = None, + mapping: Optional[paddle.Tensor] = None, + type_embedding: Optional[paddle.Tensor] = None, + comm_dict: Optional[dict[str, paddle.Tensor]] = None, + ): + if comm_dict is None: + assert mapping is not None + assert extended_atype_embd is not None + nframes, nloc, nnei = nlist.shape + nall = extended_coord.reshape([nframes, -1]).shape[1] // 3 + atype = extended_atype[:, :nloc] + # nb x nloc x nnei + exclude_mask = self.emask(nlist, extended_atype) + nlist = paddle.where(exclude_mask != 0, nlist, paddle.full_like(nlist, -1)) + # nb x nloc x nnei x 4, nb x nloc x nnei x 3, nb x nloc x nnei x 1 + dmatrix, diff, sw = prod_env_mat( + extended_coord, + nlist, + atype, + self.mean, + self.stddev, + self.rcut, + self.rcut_smth, + protection=self.env_protection, + ) + nlist_mask = nlist != -1 + sw = paddle.squeeze(sw, -1) + # beyond the cutoff sw should be 0.0 + sw = sw.masked_fill(~nlist_mask, 0.0) + + # [nframes, nloc, tebd_dim] + if comm_dict is None: + if paddle.in_dynamic_mode(): + assert isinstance(extended_atype_embd, paddle.Tensor) # for jit + atype_embd = extended_atype_embd[:, :nloc, :] + if paddle.in_dynamic_mode(): + assert list(atype_embd.shape) == [nframes, nloc, self.g1_dim] + else: + atype_embd = extended_atype_embd + if paddle.in_dynamic_mode(): + assert isinstance(atype_embd, paddle.Tensor) # for jit + g1 = self.act(atype_embd) + ng1 = g1.shape[-1] + # nb x nloc x nnei x 1, nb x nloc x nnei x 3 + if not self.direct_dist: + g2, h2 = paddle.split(dmatrix, [1, 3], axis=-1) + else: + # g2, h2 = paddle.linalg.norm(diff, axis=-1, keepdim=True), diff + g2, h2 = paddle.linalg.norm(diff, 
axis=-1, keepdim=True), diff + g2 = g2 / self.rcut + h2 = h2 / self.rcut + # nb x nloc x nnei x ng2 + g2 = self.act(self.g2_embd(g2)) + + # set all padding positions to index of 0 + # if the a neighbor is real or not is indicated by nlist_mask + nlist[nlist == -1] = 0 + # nb x nall x ng1 + if comm_dict is None: + assert mapping is not None + mapping = ( + mapping.reshape([nframes, nall]) + .unsqueeze(-1) + .expand([-1, -1, self.g1_dim]) + ) + for idx, ll in enumerate(self.layers): + # g1: nb x nloc x ng1 + # g1_ext: nb x nall x ng1 + if comm_dict is None: + assert mapping is not None + g1_ext = paddle.take_along_axis(g1, axis=1, indices=mapping) + else: + raise NotImplementedError("Not implemented yet") + # has_spin = "has_spin" in comm_dict + # if not has_spin: + # n_padding = nall - nloc + # g1 = paddle.nn.functional.pad( + # g1.squeeze(0), (0, 0, 0, n_padding), value=0.0 + # ) + # real_nloc = nloc + # real_nall = nall + # else: + # # for spin + # real_nloc = nloc // 2 + # real_nall = nall // 2 + # real_n_padding = real_nall - real_nloc + # g1_real, g1_virtual = paddle.split( + # g1, [real_nloc, real_nloc], axis=1 + # ) + # # mix_g1: nb x real_nloc x (ng1 * 2) + # mix_g1 = paddle.concat([g1_real, g1_virtual], axis=2) + # # nb x real_nall x (ng1 * 2) + # g1 = paddle.nn.functional.pad( + # mix_g1.squeeze(0), (0, 0, 0, real_n_padding), value=0.0 + # ) + + # assert "send_list" in comm_dict + # assert "send_proc" in comm_dict + # assert "recv_proc" in comm_dict + # assert "send_num" in comm_dict + # assert "recv_num" in comm_dict + # assert "communicator" in comm_dict + # ret = paddle.ops.deepmd.border_op( + # comm_dict["send_list"], + # comm_dict["send_proc"], + # comm_dict["recv_proc"], + # comm_dict["send_num"], + # comm_dict["recv_num"], + # g1, + # comm_dict["communicator"], + # paddle.to_tensor( + # real_nloc, + # dtype=paddle.int32, + # place=env.DEVICE, + # ), # should be int of c++ + # paddle.to_tensor( + # real_nall - real_nloc, + # dtype=paddle.int32, + # place=env.DEVICE, + # ), # should be int of c++ + # ) + # g1_ext = ret[0].unsqueeze(0) + # if has_spin: + # g1_real_ext, g1_virtual_ext = paddle.split( + # g1_ext, [ng1, ng1], axis=2 + # ) + # g1_ext = concat_switch_virtual( + # g1_real_ext, g1_virtual_ext, real_nloc + # ) + g1, g2, h2 = ll.forward( + g1_ext, + g2, + h2, + nlist, + nlist_mask, + sw, + ) + + # nb x nloc x 3 x ng2 + h2g2 = RepformerLayer._cal_hg( + g2, + h2, + nlist_mask, + sw, + smooth=self.smooth, + epsilon=self.epsilon, + use_sqrt_nnei=self.use_sqrt_nnei, + ) + # (nb x nloc) x ng2 x 3 + rot_mat = paddle.transpose(h2g2, (0, 1, 3, 2)) + + return g1, g2, h2, rot_mat.reshape([nframes, nloc, self.dim_emb, 3]), sw + + def compute_input_stats( + self, + merged: Union[Callable[[], list[dict]], list[dict]], + path: Optional[DPPath] = None, + ) -> None: + """ + Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data. + + Parameters + ---------- + merged : Union[Callable[[], list[dict]], list[dict]] + - list[dict]: A list of data samples from various data systems. + Each element, `merged[i]`, is a data dictionary containing `keys`: `paddle.Tensor` + originating from the `i`-th data system. + - Callable[[], list[dict]]: A lazy function that returns data samples in the above format + only when needed. Since the sampling process can be slow and memory-intensive, + the lazy function helps by only sampling once. + path : Optional[DPPath] + The path to the stat file. 
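# --- Illustrative aside (not part of the patch) -----------------------------
# A note on the ghost-atom gather used in `forward` above: `mapping` is
# expanded along the feature axis and `paddle.take_along_axis` copies each
# ghost atom's g1 from the local atom it is an image of. A toy sketch with
# nloc = 2 local atoms and one ghost atom that maps back to local atom 0:
import paddle

nf, nloc, nall, ng1 = 1, 2, 3, 4
g1 = paddle.arange(nf * nloc * ng1, dtype="float32").reshape([nf, nloc, ng1])
mapping = paddle.to_tensor([[0, 1, 0]])                        # nf x nall
mapping = mapping.unsqueeze(-1).expand([-1, -1, ng1])          # nf x nall x ng1
g1_ext = paddle.take_along_axis(g1, axis=1, indices=mapping)   # nf x nall x ng1
# g1_ext[:, 2] now equals g1[:, 0]
# -----------------------------------------------------------------------------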
+ + """ + env_mat_stat = EnvMatStatSe(self) + if path is not None: + path = path / env_mat_stat.get_hash() + if path is None or not path.is_dir(): + if callable(merged): + # only get data for once + sampled = merged() + else: + sampled = merged + else: + sampled = [] + env_mat_stat.load_or_compute_stats(sampled, path) + self.stats = env_mat_stat.stats + mean, stddev = env_mat_stat() + if not self.set_davg_zero: + paddle.assign(paddle.to_tensor(mean).to(device=env.DEVICE), self.mean) # pylint: disable=no-explicit-dtype + paddle.assign(paddle.to_tensor(stddev).to(device=env.DEVICE), self.stddev) # pylint: disable=no-explicit-dtype + + def get_stats(self) -> dict[str, StatItem]: + """Get the statistics of the descriptor.""" + if self.stats is None: + raise RuntimeError( + "The statistics of the descriptor has not been computed." + ) + return self.stats + + def has_message_passing(self) -> bool: + """Returns whether the descriptor block has message passing.""" + return True + + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the descriptor block needs sorted nlist when using `forward_lower`.""" + return False diff --git a/deepmd/pd/model/descriptor/se_t_tebd.py b/deepmd/pd/model/descriptor/se_t_tebd.py new file mode 100644 index 0000000000..a8b9a6a417 --- /dev/null +++ b/deepmd/pd/model/descriptor/se_t_tebd.py @@ -0,0 +1,931 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Callable, + Optional, + Union, +) + +import paddle + +from deepmd.dpmodel.utils import EnvMat as DPEnvMat +from deepmd.dpmodel.utils.seed import ( + child_seed, +) +from deepmd.pd.model.descriptor import ( + DescriptorBlock, +) +from deepmd.pd.model.descriptor.env_mat import ( + prod_env_mat, +) +from deepmd.pd.model.network.mlp import ( + EmbeddingNet, + NetworkCollection, +) +from deepmd.pd.model.network.network import ( + TypeEmbedNet, + TypeEmbedNetConsistent, +) +from deepmd.pd.utils import ( + env, +) +from deepmd.pd.utils.env import ( + PRECISION_DICT, + RESERVED_PRECISON_DICT, +) +from deepmd.pd.utils.env_mat_stat import ( + EnvMatStatSe, +) +from deepmd.pd.utils.exclude_mask import ( + PairExcludeMask, +) +from deepmd.pd.utils.update_sel import ( + UpdateSel, +) +from deepmd.utils.data_system import ( + DeepmdDataSystem, +) +from deepmd.utils.env_mat_stat import ( + StatItem, +) +from deepmd.utils.finetune import ( + get_index_between_two_maps, + map_pair_exclude_types, +) +from deepmd.utils.path import ( + DPPath, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + +from .base_descriptor import ( + BaseDescriptor, +) +from .descriptor import ( + extend_descrpt_stat, +) + + +@BaseDescriptor.register("se_e3_tebd") +class DescrptSeTTebd(BaseDescriptor, paddle.nn.Layer): + r"""Construct an embedding net that takes angles between two neighboring atoms and type embeddings as input. + + Parameters + ---------- + rcut + The cut-off radius + rcut_smth + From where the environment matrix should be smoothed + sel : Union[list[int], int] + list[int]: sel[i] specifies the maxmum number of type i atoms in the cut-off radius + int: the total maxmum number of atoms in the cut-off radius + ntypes : int + Number of element types + neuron : list[int] + Number of neurons in each hidden layers of the embedding net + tebd_dim : int + Dimension of the type embedding + tebd_input_mode : str + The input mode of the type embedding. Supported modes are ["concat", "strip"]. 
+ - "concat": Concatenate the type embedding with the smoothed angular information as the union input for the embedding network. + - "strip": Use a separated embedding network for the type embedding and combine the output with the angular embedding network output. + resnet_dt + Time-step `dt` in the resnet construction: + y = x + dt * \phi (Wx + b) + set_davg_zero + Set the shift of embedding net input to zero. + activation_function + The activation function in the embedding net. Supported options are |ACTIVATION_FN| + env_protection: float + Protection parameter to prevent division by zero errors during environment matrix calculations. + exclude_types : list[tuple[int, int]] + The excluded pairs of types which have no interaction with each other. + For example, `[[0, 1]]` means no interaction between type 0 and type 1. + precision + The precision of the embedding net parameters. Supported options are |PRECISION| + trainable + If the weights of embedding net are trainable. + seed + Random seed for initializing the network parameters. + type_map: list[str], Optional + A list of strings. Give the name to each type of atoms. + concat_output_tebd: bool + Whether to concat type embedding at the output of the descriptor. + use_econf_tebd: bool, Optional + Whether to use electronic configuration type embedding. + use_tebd_bias : bool, Optional + Whether to use bias in the type embedding layer. + smooth: bool + Whether to use smooth process in calculation. + + """ + + def __init__( + self, + rcut: float, + rcut_smth: float, + sel: Union[list[int], int], + ntypes: int, + neuron: list = [2, 4, 8], + tebd_dim: int = 8, + tebd_input_mode: str = "concat", + resnet_dt: bool = False, + set_davg_zero: bool = True, + activation_function: str = "tanh", + env_protection: float = 0.0, + exclude_types: list[tuple[int, int]] = [], + precision: str = "float64", + trainable: bool = True, + seed: Optional[Union[int, list[int]]] = None, + type_map: Optional[list[str]] = None, + concat_output_tebd: bool = True, + use_econf_tebd: bool = False, + use_tebd_bias=False, + smooth: bool = True, + ) -> None: + super().__init__() + self.se_ttebd = DescrptBlockSeTTebd( + rcut, + rcut_smth, + sel, + ntypes, + neuron=neuron, + tebd_dim=tebd_dim, + tebd_input_mode=tebd_input_mode, + set_davg_zero=set_davg_zero, + activation_function=activation_function, + precision=precision, + resnet_dt=resnet_dt, + exclude_types=exclude_types, + env_protection=env_protection, + smooth=smooth, + seed=child_seed(seed, 1), + ) + self.prec = PRECISION_DICT[precision] + self.use_econf_tebd = use_econf_tebd + self.type_map = type_map + self.smooth = smooth + self.type_embedding = TypeEmbedNet( + ntypes, + tebd_dim, + precision=precision, + seed=child_seed(seed, 2), + use_econf_tebd=use_econf_tebd, + type_map=type_map, + use_tebd_bias=use_tebd_bias, + ) + self.tebd_dim = tebd_dim + self.tebd_input_mode = tebd_input_mode + self.concat_output_tebd = concat_output_tebd + self.trainable = trainable + # set trainable + for param in self.parameters(): + param.stop_gradient = not trainable + + def get_rcut(self) -> float: + """Returns the cut-off radius.""" + return self.se_ttebd.get_rcut() + + def get_rcut_smth(self) -> float: + """Returns the radius where the neighbor information starts to smoothly decay to 0.""" + return self.se_ttebd.get_rcut_smth() + + def get_nsel(self) -> int: + """Returns the number of selected atoms in the cut-off radius.""" + return self.se_ttebd.get_nsel() + + def get_sel(self) -> list[int]: + """Returns the number of selected 
atoms for each type.""" + return self.se_ttebd.get_sel() + + def get_ntypes(self) -> int: + """Returns the number of element types.""" + return self.se_ttebd.get_ntypes() + + def get_type_map(self) -> list[str]: + """Get the name to each type of atoms.""" + return self.type_map + + def get_dim_out(self) -> int: + """Returns the output dimension.""" + ret = self.se_ttebd.get_dim_out() + if self.concat_output_tebd: + ret += self.tebd_dim + return ret + + def get_dim_emb(self) -> int: + return self.se_ttebd.dim_emb + + def mixed_types(self) -> bool: + """If true, the descriptor + 1. assumes total number of atoms aligned across frames; + 2. requires a neighbor list that does not distinguish different atomic types. + + If false, the descriptor + 1. assumes total number of atoms of each atom type aligned across frames; + 2. requires a neighbor list that distinguishes different atomic types. + + """ + return self.se_ttebd.mixed_types() + + def has_message_passing(self) -> bool: + """Returns whether the descriptor has message passing.""" + return self.se_ttebd.has_message_passing() + + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the descriptor needs sorted nlist when using `forward_lower`.""" + return self.se_ttebd.need_sorted_nlist_for_lower() + + def get_env_protection(self) -> float: + """Returns the protection of building environment matrix.""" + return self.se_ttebd.get_env_protection() + + def share_params(self, base_class, shared_level, resume=False) -> None: + """ + Share the parameters of self to the base_class with shared_level during multitask training. + If not start from checkpoint (resume is False), + some separated parameters (e.g. mean and stddev) will be re-calculated across different classes. + """ + assert ( + self.__class__ == base_class.__class__ + ), "Only descriptors of the same type can share params!" + # For DPA1 descriptors, the user-defined share-level + # shared_level: 0 + # share all parameters in both type_embedding and se_ttebd + if shared_level == 0: + self._sub_layers["type_embedding"] = base_class._sub_layers[ + "type_embedding" + ] + self.se_ttebd.share_params(base_class.se_ttebd, 0, resume=resume) + # shared_level: 1 + # share all parameters in type_embedding + elif shared_level == 1: + self._sub_layers["type_embedding"] = base_class._sub_layers[ + "type_embedding" + ] + # Other shared levels + else: + raise NotImplementedError + + @property + def dim_out(self): + return self.get_dim_out() + + @property + def dim_emb(self): + return self.get_dim_emb() + + def compute_input_stats( + self, + merged: Union[Callable[[], list[dict]], list[dict]], + path: Optional[DPPath] = None, + ): + """ + Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data. + + Parameters + ---------- + merged : Union[Callable[[], list[dict]], list[dict]] + - list[dict]: A list of data samples from various data systems. + Each element, `merged[i]`, is a data dictionary containing `keys`: `paddle.Tensor` + originating from the `i`-th data system. + - Callable[[], list[dict]]: A lazy function that returns data samples in the above format + only when needed. Since the sampling process can be slow and memory-intensive, + the lazy function helps by only sampling once. + path : Optional[DPPath] + The path to the stat file. 
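# --- Illustrative aside (not part of the patch) -----------------------------
# `merged` may be either a pre-sampled list of data dicts or a zero-argument
# callable that produces the list lazily, so the expensive sampling only runs
# when statistics actually have to be recomputed. A hypothetical lazy wrapper
# (helper name and structure are assumptions, not patch code) could look like:
def make_lazy_sampler(sample_fn):
    cache = {}

    def lazy():
        if "data" not in cache:
            cache["data"] = sample_fn()  # sample once and reuse afterwards
        return cache["data"]

    return lazy
# -----------------------------------------------------------------------------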
+ + """ + return self.se_ttebd.compute_input_stats(merged, path) + + def set_stat_mean_and_stddev( + self, + mean: paddle.Tensor, + stddev: paddle.Tensor, + ) -> None: + """Update mean and stddev for descriptor.""" + self.se_ttebd.mean = mean + self.se_ttebd.stddev = stddev + + def get_stat_mean_and_stddev(self) -> tuple[paddle.Tensor, paddle.Tensor]: + """Get mean and stddev for descriptor.""" + return self.se_ttebd.mean, self.se_ttebd.stddev + + def change_type_map( + self, type_map: list[str], model_with_new_type_stat=None + ) -> None: + """Change the type related params to new ones, according to `type_map` and the original one in the model. + If there are new types in `type_map`, statistics will be updated accordingly to `model_with_new_type_stat` for these new types. + """ + assert ( + self.type_map is not None + ), "'type_map' must be defined when performing type changing!" + remap_index, has_new_type = get_index_between_two_maps(self.type_map, type_map) + obj = self.se_ttebd + obj.ntypes = len(type_map) + self.type_map = type_map + self.type_embedding.change_type_map(type_map=type_map) + obj.reinit_exclude(map_pair_exclude_types(obj.exclude_types, remap_index)) + if has_new_type: + # the avg and std of new types need to be updated + extend_descrpt_stat( + obj, + type_map, + des_with_stat=model_with_new_type_stat.se_ttebd + if model_with_new_type_stat is not None + else None, + ) + obj["davg"] = obj["davg"][remap_index] + obj["dstd"] = obj["dstd"][remap_index] + + def serialize(self) -> dict: + obj = self.se_ttebd + data = { + "@class": "Descriptor", + "type": "se_e3_tebd", + "@version": 1, + "rcut": obj.rcut, + "rcut_smth": obj.rcut_smth, + "sel": obj.sel, + "ntypes": obj.ntypes, + "neuron": obj.neuron, + "tebd_dim": obj.tebd_dim, + "tebd_input_mode": obj.tebd_input_mode, + "set_davg_zero": obj.set_davg_zero, + "activation_function": obj.activation_function, + "resnet_dt": obj.resnet_dt, + "concat_output_tebd": self.concat_output_tebd, + "use_econf_tebd": self.use_econf_tebd, + "type_map": self.type_map, + # make deterministic + "precision": RESERVED_PRECISON_DICT[obj.prec], + "embeddings": obj.filter_layers.serialize(), + "env_mat": DPEnvMat(obj.rcut, obj.rcut_smth).serialize(), + "type_embedding": self.type_embedding.embedding.serialize(), + "exclude_types": obj.exclude_types, + "env_protection": obj.env_protection, + "smooth": self.smooth, + "@variables": { + "davg": obj["davg"].numpy(), + "dstd": obj["dstd"].numpy(), + }, + "trainable": self.trainable, + } + if obj.tebd_input_mode in ["strip"]: + data.update({"embeddings_strip": obj.filter_layers_strip.serialize()}) + return data + + @classmethod + def deserialize(cls, data: dict) -> "DescrptSeTTebd": + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + data.pop("type") + variables = data.pop("@variables") + embeddings = data.pop("embeddings") + type_embedding = data.pop("type_embedding") + env_mat = data.pop("env_mat") + tebd_input_mode = data["tebd_input_mode"] + if tebd_input_mode in ["strip"]: + embeddings_strip = data.pop("embeddings_strip") + else: + embeddings_strip = None + obj = cls(**data) + + def t_cvt(xx): + return paddle.to_tensor(xx, dtype=obj.se_ttebd.prec).to(device=env.DEVICE) + + obj.type_embedding.embedding = TypeEmbedNetConsistent.deserialize( + type_embedding + ) + obj.se_ttebd["davg"] = t_cvt(variables["davg"]) + obj.se_ttebd["dstd"] = t_cvt(variables["dstd"]) + obj.se_ttebd.filter_layers = NetworkCollection.deserialize(embeddings) + if tebd_input_mode in 
["strip"]: + obj.se_ttebd.filter_layers_strip = NetworkCollection.deserialize( + embeddings_strip + ) + return obj + + def forward( + self, + extended_coord: paddle.Tensor, + extended_atype: paddle.Tensor, + nlist: paddle.Tensor, + mapping: Optional[paddle.Tensor] = None, + comm_dict: Optional[dict[str, paddle.Tensor]] = None, + ): + """Compute the descriptor. + + Parameters + ---------- + extended_coord + The extended coordinates of atoms. shape: nf x (nallx3) + extended_atype + The extended aotm types. shape: nf x nall + nlist + The neighbor list. shape: nf x nloc x nnei + mapping + The index mapping, not required by this descriptor. + comm_dict + The data needed for communication for parallel inference. + + Returns + ------- + descriptor + The descriptor. shape: nf x nloc x (ng x axis_neuron) + gr + The rotationally equivariant and permutationally invariant single particle + representation. shape: nf x nloc x ng x 3 + g2 + The rotationally invariant pair-partical representation. + shape: nf x nloc x nnei x ng + h2 + The rotationally equivariant pair-partical representation. + shape: nf x nloc x nnei x 3 + sw + The smooth switch function. shape: nf x nloc x nnei + + """ + # cast the input to internal precsion + extended_coord = extended_coord.to(dtype=self.prec) + del mapping + nframes, nloc, nnei = nlist.shape + nall = extended_coord.reshape([nframes, -1]).shape[1] // 3 + g1_ext = self.type_embedding(extended_atype) + g1_inp = g1_ext[:, :nloc, :] + if self.tebd_input_mode in ["strip"]: + type_embedding = self.type_embedding.get_full_embedding(g1_ext.place) + else: + type_embedding = None + g1, _, _, _, sw = self.se_ttebd( + nlist, + extended_coord, + extended_atype, + g1_ext, + mapping=None, + type_embedding=type_embedding, + ) + if self.concat_output_tebd: + g1 = paddle.concat([g1, g1_inp], axis=-1) + + return ( + g1.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION), + None, + None, + None, + sw.to(dtype=env.GLOBAL_PD_FLOAT_PRECISION), + ) + + @classmethod + def update_sel( + cls, + train_data: DeepmdDataSystem, + type_map: Optional[list[str]], + local_jdata: dict, + ) -> tuple[dict, Optional[float]]: + """Update the selection and perform neighbor statistics. 
+ + Parameters + ---------- + train_data : DeepmdDataSystem + data used to do neighbor statistics + type_map : list[str], optional + The name of each type of atoms + local_jdata : dict + The local data refer to the current class + + Returns + ------- + dict + The updated local data + float + The minimum distance between two atoms + """ + local_jdata_cpy = local_jdata.copy() + min_nbor_dist, sel = UpdateSel().update_one_sel( + train_data, type_map, local_jdata_cpy["rcut"], local_jdata_cpy["sel"], True + ) + local_jdata_cpy["sel"] = sel[0] + return local_jdata_cpy, min_nbor_dist + + +@DescriptorBlock.register("se_ttebd") +class DescrptBlockSeTTebd(DescriptorBlock): + def __init__( + self, + rcut: float, + rcut_smth: float, + sel: Union[list[int], int], + ntypes: int, + neuron: list = [25, 50, 100], + tebd_dim: int = 8, + tebd_input_mode: str = "concat", + set_davg_zero: bool = True, + activation_function="tanh", + precision: str = "float64", + resnet_dt: bool = False, + exclude_types: list[tuple[int, int]] = [], + env_protection: float = 0.0, + smooth: bool = True, + seed: Optional[Union[int, list[int]]] = None, + ) -> None: + super().__init__() + self.rcut = float(rcut) + self.rcut_smth = float(rcut_smth) + self.neuron = neuron + self.filter_neuron = self.neuron + self.tebd_dim = tebd_dim + self.tebd_input_mode = tebd_input_mode + self.set_davg_zero = set_davg_zero + self.activation_function = activation_function + self.precision = precision + self.prec = PRECISION_DICT[self.precision] + self.resnet_dt = resnet_dt + self.env_protection = env_protection + self.seed = seed + self.smooth = smooth + + if isinstance(sel, int): + sel = [sel] + + self.ntypes = ntypes + self.sel = sel + self.sec = self.sel + self.split_sel = self.sel + self.nnei = sum(sel) + self.ndescrpt = self.nnei * 4 + # order matters, placed after the assignment of self.ntypes + self.reinit_exclude(exclude_types) + + wanted_shape = (self.ntypes, self.nnei, 4) + mean = paddle.zeros(wanted_shape, dtype=env.GLOBAL_PD_FLOAT_PRECISION).to( + device=env.DEVICE + ) + stddev = paddle.ones(wanted_shape, dtype=env.GLOBAL_PD_FLOAT_PRECISION).to( + device=env.DEVICE + ) + self.register_buffer("mean", mean) + self.register_buffer("stddev", stddev) + self.tebd_dim_input = self.tebd_dim * 2 + if self.tebd_input_mode in ["concat"]: + self.embd_input_dim = 1 + self.tebd_dim_input + else: + self.embd_input_dim = 1 + + self.filter_layers = None + self.filter_layers_strip = None + filter_layers = NetworkCollection( + ndim=0, ntypes=self.ntypes, network_type="embedding_network" + ) + filter_layers[0] = EmbeddingNet( + self.embd_input_dim, + self.filter_neuron, + activation_function=self.activation_function, + precision=self.precision, + resnet_dt=self.resnet_dt, + seed=child_seed(self.seed, 1), + ) + self.filter_layers = filter_layers + if self.tebd_input_mode in ["strip"]: + filter_layers_strip = NetworkCollection( + ndim=0, ntypes=self.ntypes, network_type="embedding_network" + ) + filter_layers_strip[0] = EmbeddingNet( + self.tebd_dim_input, + self.filter_neuron, + activation_function=self.activation_function, + precision=self.precision, + resnet_dt=self.resnet_dt, + seed=child_seed(self.seed, 2), + ) + self.filter_layers_strip = filter_layers_strip + self.stats = None + + def get_rcut(self) -> float: + """Returns the cut-off radius.""" + return self.rcut + + def get_rcut_smth(self) -> float: + """Returns the radius where the neighbor information starts to smoothly decay to 0.""" + return self.rcut_smth + + def get_nsel(self) -> int: + 
"""Returns the number of selected atoms in the cut-off radius.""" + return sum(self.sel) + + def get_sel(self) -> list[int]: + """Returns the number of selected atoms for each type.""" + return self.sel + + def get_ntypes(self) -> int: + """Returns the number of element types.""" + return self.ntypes + + def get_dim_in(self) -> int: + """Returns the input dimension.""" + return self.dim_in + + def get_dim_out(self) -> int: + """Returns the output dimension.""" + return self.dim_out + + def get_dim_emb(self) -> int: + """Returns the output dimension of embedding.""" + return self.filter_neuron[-1] + + def __setitem__(self, key, value) -> None: + if key in ("avg", "data_avg", "davg"): + self.mean = value + elif key in ("std", "data_std", "dstd"): + self.stddev = value + else: + raise KeyError(key) + + def __getitem__(self, key): + if key in ("avg", "data_avg", "davg"): + return self.mean + elif key in ("std", "data_std", "dstd"): + return self.stddev + else: + raise KeyError(key) + + def mixed_types(self) -> bool: + """If true, the descriptor + 1. assumes total number of atoms aligned across frames; + 2. requires a neighbor list that does not distinguish different atomic types. + + If false, the descriptor + 1. assumes total number of atoms of each atom type aligned across frames; + 2. requires a neighbor list that distinguishes different atomic types. + + """ + return True + + def get_env_protection(self) -> float: + """Returns the protection of building environment matrix.""" + return self.env_protection + + @property + def dim_out(self): + """Returns the output dimension of this descriptor.""" + return self.filter_neuron[-1] + + @property + def dim_in(self): + """Returns the atomic input dimension of this descriptor.""" + return self.tebd_dim + + @property + def dim_emb(self): + """Returns the output dimension of embedding.""" + return self.get_dim_emb() + + def compute_input_stats( + self, + merged: Union[Callable[[], list[dict]], list[dict]], + path: Optional[DPPath] = None, + ) -> None: + """ + Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data. + + Parameters + ---------- + merged : Union[Callable[[], list[dict]], list[dict]] + - list[dict]: A list of data samples from various data systems. + Each element, `merged[i]`, is a data dictionary containing `keys`: `paddle.Tensor` + originating from the `i`-th data system. + - Callable[[], list[dict]]: A lazy function that returns data samples in the above format + only when needed. Since the sampling process can be slow and memory-intensive, + the lazy function helps by only sampling once. + path : Optional[DPPath] + The path to the stat file. + + """ + env_mat_stat = EnvMatStatSe(self) + if path is not None: + path = path / env_mat_stat.get_hash() + if path is None or not path.is_dir(): + if callable(merged): + # only get data for once + sampled = merged() + else: + sampled = merged + else: + sampled = [] + env_mat_stat.load_or_compute_stats(sampled, path) + self.stats = env_mat_stat.stats + mean, stddev = env_mat_stat() + if not self.set_davg_zero: + paddle.assign(paddle.to_tensor(mean).to(device=env.DEVICE), self.mean) # pylint: disable=no-explicit-dtype + paddle.assign(paddle.to_tensor(stddev).to(device=env.DEVICE), self.stddev) # pylint: disable=no-explicit-dtype + + def get_stats(self) -> dict[str, StatItem]: + """Get the statistics of the descriptor.""" + if self.stats is None: + raise RuntimeError( + "The statistics of the descriptor has not been computed." 
+ ) + return self.stats + + def reinit_exclude( + self, + exclude_types: list[tuple[int, int]] = [], + ) -> None: + self.exclude_types = exclude_types + self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types) + + def forward( + self, + nlist: paddle.Tensor, + extended_coord: paddle.Tensor, + extended_atype: paddle.Tensor, + extended_atype_embd: Optional[paddle.Tensor] = None, + mapping: Optional[paddle.Tensor] = None, + type_embedding: Optional[paddle.Tensor] = None, + ): + """Compute the descriptor. + + Parameters + ---------- + nlist + The neighbor list. shape: nf x nloc x nnei + extended_coord + The extended coordinates of atoms. shape: nf x (nallx3) + extended_atype + The extended aotm types. shape: nf x nall x nt + extended_atype_embd + The extended type embedding of atoms. shape: nf x nall + mapping + The index mapping, not required by this descriptor. + type_embedding + Full type embeddings. shape: (ntypes+1) x nt + Required for stripped type embeddings. + + Returns + ------- + result + The descriptor. shape: nf x nloc x (ng x axis_neuron) + g2 + The rotationally invariant pair-partical representation. + shape: nf x nloc x nnei x ng + h2 + The rotationally equivariant pair-partical representation. + shape: nf x nloc x nnei x 3 + gr + The rotationally equivariant and permutationally invariant single particle + representation. shape: nf x nloc x ng x 3 + sw + The smooth switch function. shape: nf x nloc x nnei + + """ + del mapping + assert extended_atype_embd is not None + nframes, nloc, nnei = nlist.shape + atype = extended_atype[:, :nloc] + nb = nframes + nall = extended_coord.reshape([nb, -1, 3]).shape[1] + dmatrix, diff, sw = prod_env_mat( + extended_coord, + nlist, + atype, + self.mean, + self.stddev, + self.rcut, + self.rcut_smth, + protection=self.env_protection, + ) + # nb x nloc x nnei + exclude_mask = self.emask(nlist, extended_atype) + nlist = paddle.where(exclude_mask != 0, nlist, paddle.full_like(nlist, -1)) + nlist_mask = nlist != -1 + nlist = paddle.where(nlist == -1, paddle.zeros_like(nlist), nlist) + sw = paddle.squeeze(sw, -1) + # nf x nall x nt + nt = extended_atype_embd.shape[-1] + # beyond the cutoff sw should be 0.0 + sw = sw.masked_fill(~nlist_mask, 0.0) + # (nb x nloc) x nnei + exclude_mask = exclude_mask.reshape([nb * nloc, nnei]) + assert self.filter_layers is not None + # nfnl x nnei x 4 + dmatrix = dmatrix.reshape([-1, self.nnei, 4]) + nfnl = dmatrix.shape[0] + # nfnl x nnei x 4 + rr = dmatrix + rr = rr * exclude_mask[:, :, None].astype(rr.dtype) + + # nfnl x nt_i x 3 + rr_i = rr[:, :, 1:] + # nfnl x nt_j x 3 + rr_j = rr[:, :, 1:] + # nfnl x nt_i x nt_j + # env_ij = paddle.einsum("ijm,ikm->ijk", rr_i, rr_j) + env_ij = ( + # ij1m x i1km -> ijkm -> ijk + rr_i.unsqueeze(2) * rr_j.unsqueeze(1) + ).sum(-1) + # nfnl x nt_i x nt_j x 1 + ss = env_ij.unsqueeze(-1) + if self.tebd_input_mode in ["concat"]: + atype_tebd_ext = extended_atype_embd + # nb x (nloc x nnei) x nt + index = nlist.reshape([nb, nloc * nnei]).unsqueeze(-1).expand([-1, -1, nt]) + # nb x (nloc x nnei) x nt + # atype_tebd_nlist = paddle.take_along_axis(atype_tebd_ext, axis=1, index=index) + atype_tebd_nlist = paddle.take_along_axis( + atype_tebd_ext, axis=1, indices=index + ) + # nb x nloc x nnei x nt + atype_tebd_nlist = atype_tebd_nlist.reshape([nb, nloc, nnei, nt]) + # nfnl x nnei x tebd_dim + nlist_tebd = atype_tebd_nlist.reshape([nfnl, nnei, self.tebd_dim]) + # nfnl x nt_i x nt_j x tebd_dim + nlist_tebd_i = nlist_tebd.unsqueeze(2).expand([-1, -1, self.nnei, -1]) + nlist_tebd_j = 
nlist_tebd.unsqueeze(1).expand([-1, self.nnei, -1, -1]) + # nfnl x nt_i x nt_j x (1 + tebd_dim * 2) + ss = paddle.concat([ss, nlist_tebd_i, nlist_tebd_j], axis=-1) + # nfnl x nt_i x nt_j x ng + gg = self.filter_layers.networks[0](ss) + elif self.tebd_input_mode in ["strip"]: + # nfnl x nt_i x nt_j x ng + gg_s = self.filter_layers.networks[0](ss) + assert self.filter_layers_strip is not None + assert type_embedding is not None + ng = self.filter_neuron[-1] + ntypes_with_padding = type_embedding.shape[0] + # nf x (nl x nnei) + nlist_index = nlist.reshape([nb, nloc * nnei]) + # nf x (nl x nnei) + nei_type = paddle.take_along_axis( + extended_atype, indices=nlist_index, axis=1 + ) + # nfnl x nnei + nei_type = nei_type.reshape([nfnl, nnei]) + # nfnl x nnei x nnei + nei_type_i = nei_type.unsqueeze(2).expand([-1, -1, nnei]) + nei_type_j = nei_type.unsqueeze(1).expand([-1, nnei, -1]) + idx_i = nei_type_i * ntypes_with_padding + idx_j = nei_type_j + # (nf x nl x nt_i x nt_j) x ng + idx = ( + (idx_i + idx_j) + .reshape([-1, 1]) + .expand([-1, ng]) + .astype(paddle.int64) + .to(paddle.int64) + ) + # ntypes * (ntypes) * nt + type_embedding_i = paddle.tile( + type_embedding.reshape([ntypes_with_padding, 1, nt]), + [1, ntypes_with_padding, 1], + ) + # (ntypes) * ntypes * nt + type_embedding_j = paddle.tile( + type_embedding.reshape([1, ntypes_with_padding, nt]), + [ntypes_with_padding, 1, 1], + ) + # (ntypes * ntypes) * (nt+nt) + two_side_type_embedding = paddle.concat( + [type_embedding_i, type_embedding_j], -1 + ).reshape([-1, nt * 2]) + tt_full = self.filter_layers_strip.networks[0](two_side_type_embedding) + # (nfnl x nt_i x nt_j) x ng + gg_t = paddle.take_along_axis(tt_full, indices=idx, axis=0) + # (nfnl x nt_i x nt_j) x ng + gg_t = gg_t.reshape([nfnl, nnei, nnei, ng]) + if self.smooth: + gg_t = ( + gg_t + * sw.reshape([nfnl, self.nnei, 1, 1]) + * sw.reshape([nfnl, 1, self.nnei, 1]) + ) + # nfnl x nt_i x nt_j x ng + gg = gg_s * gg_t + gg_s + else: + raise NotImplementedError + + # nfnl x ng + # res_ij = paddle.einsum("ijk,ijkm->im", env_ij, gg) + res_ij = ( + # ijk1 x ijkm -> ijkm -> im + env_ij.unsqueeze(-1) * gg + ).sum([1, 2]) + res_ij = res_ij * (1.0 / float(self.nnei) / float(self.nnei)) + # nf x nl x ng + result = res_ij.reshape([nframes, nloc, self.filter_neuron[-1]]) + return ( + result, + None, + None, + None, + sw, + ) + + def has_message_passing(self) -> bool: + """Returns whether the descriptor block has message passing.""" + return False + + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the descriptor block needs sorted nlist when using `forward_lower`.""" + return False diff --git a/deepmd/pd/model/task/fitting.py b/deepmd/pd/model/task/fitting.py index d9db44aff5..6e96b7b081 100644 --- a/deepmd/pd/model/task/fitting.py +++ b/deepmd/pd/model/task/fitting.py @@ -211,8 +211,8 @@ def __init__( if self.dim_case_embd > 0: self.register_buffer( "case_embd", - paddle.zeros(self.dim_case_embd, dtype=self.prec, place=device), - # paddle.eye(self.dim_case_embd, dtype=self.prec, place=device)[0], + paddle.zeros(self.dim_case_embd, dtype=self.prec).to(device=device), + # paddle.eye(self.dim_case_embd, dtype=self.prec).to(device=device)[0], ) else: self.case_embd = None diff --git a/deepmd/pd/utils/multi_task.py b/deepmd/pd/utils/multi_task.py index 680dc53c79..321883c12e 100644 --- a/deepmd/pd/utils/multi_task.py +++ b/deepmd/pd/utils/multi_task.py @@ -96,7 +96,9 @@ def preprocess_shared_params(model_config): shared_links = {} type_map_keys = [] - def 
replace_one_item(params_dict, key_type, key_in_dict, suffix="", index=None): + def replace_one_item( + params_dict, key_type, key_in_dict, suffix="", index=None + ) -> None: shared_type = key_type shared_key = key_in_dict shared_level = 0 diff --git a/deepmd/pd/utils/spin.py b/deepmd/pd/utils/spin.py new file mode 100644 index 0000000000..934fb3762a --- /dev/null +++ b/deepmd/pd/utils/spin.py @@ -0,0 +1,30 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + +import paddle + + +def concat_switch_virtual( + extended_tensor, + extended_tensor_virtual, + nloc: int, +): + """ + Concat real and virtual extended tensors, and switch all the local ones to the first nloc * 2 atoms. + - [:, :nloc]: original nloc real atoms. + - [:, nloc: nloc + nloc]: virtual atoms corresponding to nloc real atoms. + - [:, nloc + nloc: nloc + nall]: ghost real atoms. + - [:, nloc + nall: nall + nall]: virtual atoms corresponding to ghost real atoms. + """ + nframes, nall = extended_tensor.shape[:2] + out_shape = list(extended_tensor.shape) + out_shape[1] *= 2 + extended_tensor_updated = paddle.zeros( + out_shape, + dtype=extended_tensor.dtype, + device=extended_tensor.place, + ) + extended_tensor_updated[:, :nloc] = extended_tensor[:, :nloc] + extended_tensor_updated[:, nloc : nloc + nloc] = extended_tensor_virtual[:, :nloc] + extended_tensor_updated[:, nloc + nloc : nloc + nall] = extended_tensor[:, nloc:] + extended_tensor_updated[:, nloc + nall :] = extended_tensor_virtual[:, nloc:] + return extended_tensor_updated.reshape(out_shape) diff --git a/source/tests/consistent/descriptor/test_dpa2.py b/source/tests/consistent/descriptor/test_dpa2.py index 72c0967a78..ef840bf9d7 100644 --- a/source/tests/consistent/descriptor/test_dpa2.py +++ b/source/tests/consistent/descriptor/test_dpa2.py @@ -17,6 +17,7 @@ from ..common import ( INSTALLED_ARRAY_API_STRICT, INSTALLED_JAX, + INSTALLED_PD, INSTALLED_PT, CommonTest, parameterized, @@ -34,6 +35,12 @@ from deepmd.jax.descriptor.dpa2 import DescrptDPA2 as DescrptDPA2JAX else: DescrptDPA2JAX = None + +if INSTALLED_PD: + from deepmd.pd.model.descriptor.dpa2 import DescrptDPA2 as DescrptDPA2PD +else: + DescrptDPA2PD = None + if INSTALLED_ARRAY_API_STRICT: from ...array_api_strict.descriptor.dpa2 import DescrptDPA2 as DescrptDPA2Strict else: @@ -214,6 +221,39 @@ def skip_pt(self) -> bool: ) = self.param return CommonTest.skip_pt + @property + def skip_pd(self) -> bool: + ( + repinit_tebd_input_mode, + repinit_set_davg_zero, + repinit_type_one_side, + repinit_use_three_body, + repformer_update_g1_has_conv, + repformer_direct_dist, + repformer_update_g1_has_drrd, + repformer_update_g1_has_grrg, + repformer_update_g1_has_attn, + repformer_update_g2_has_g1g1, + repformer_update_g2_has_attn, + repformer_update_h2, + repformer_attn2_has_gate, + repformer_update_style, + repformer_update_residual_init, + repformer_set_davg_zero, + repformer_trainable_ln, + repformer_ln_eps, + repformer_use_sqrt_nnei, + repformer_g1_out_conv, + repformer_g1_out_mlp, + smooth, + exclude_types, + precision, + add_tebd_to_repinit_out, + use_econf_tebd, + use_tebd_bias, + ) = self.param + return not INSTALLED_PD or precision == "bfloat16" + @property def skip_dp(self) -> bool: ( @@ -286,6 +326,7 @@ def skip_tf(self) -> bool: tf_class = DescrptDPA2TF dp_class = DescrptDPA2DP pt_class = DescrptDPA2PT + pd_class = DescrptDPA2PD jax_class = DescrptDPA2JAX array_api_strict_class = DescrptDPA2Strict args = descrpt_dpa2_args().append(Argument("ntypes", int, optional=False)) @@ -383,6 +424,16 @@ def eval_pt(self, 
pt_obj: Any) -> Any: mixed_types=True, ) + def eval_pd(self, pd_obj: Any) -> Any: + return self.eval_pd_descriptor( + pd_obj, + self.natoms, + self.coords, + self.atype, + self.box, + mixed_types=True, + ) + def eval_jax(self, jax_obj: Any) -> Any: return self.eval_jax_descriptor( jax_obj, diff --git a/source/tests/consistent/descriptor/test_se_t_tebd.py b/source/tests/consistent/descriptor/test_se_t_tebd.py index bb4a5db6e7..9cdca9bde3 100644 --- a/source/tests/consistent/descriptor/test_se_t_tebd.py +++ b/source/tests/consistent/descriptor/test_se_t_tebd.py @@ -17,6 +17,7 @@ from ..common import ( INSTALLED_ARRAY_API_STRICT, INSTALLED_JAX, + INSTALLED_PD, INSTALLED_PT, CommonTest, parameterized, @@ -34,6 +35,10 @@ from deepmd.jax.descriptor.se_t_tebd import DescrptSeTTebd as DescrptSeTTebdJAX else: DescrptSeTTebdJAX = None +if INSTALLED_PD: + from deepmd.pd.model.descriptor.se_t_tebd import DescrptSeTTebd as DescrptSeTTebdPD +else: + DescrptSeTTebdPD = None if INSTALLED_ARRAY_API_STRICT: from ...array_api_strict.descriptor.se_t_tebd import ( DescrptSeTTebd as DescrptSeTTebdStrict, @@ -146,12 +151,14 @@ def skip_tf(self) -> bool: ) = self.param return True + skip_pd = not INSTALLED_PD skip_jax = not INSTALLED_JAX skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT tf_class = DescrptSeTTebdTF dp_class = DescrptSeTTebdDP pt_class = DescrptSeTTebdPT + pd_class = DescrptSeTTebdPD jax_class = DescrptSeTTebdJAX array_api_strict_class = DescrptSeTTebdStrict args = descrpt_se_e3_tebd_args().append(Argument("ntypes", int, optional=False)) @@ -243,6 +250,16 @@ def eval_jax(self, jax_obj: Any) -> Any: mixed_types=True, ) + def eval_pd(self, pd_obj: Any) -> Any: + return self.eval_pd_descriptor( + pd_obj, + self.natoms, + self.coords, + self.atype, + self.box, + mixed_types=True, + ) + def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: return self.eval_array_api_strict_descriptor( array_api_strict_obj, diff --git a/source/tests/pd/model/models/dpa2.json b/source/tests/pd/model/models/dpa2.json new file mode 100644 index 0000000000..f83e319de3 --- /dev/null +++ b/source/tests/pd/model/models/dpa2.json @@ -0,0 +1,57 @@ +{ + "type_map": [ + "O", + "H" + ], + "descriptor": { + "type": "dpa2", + "repinit": { + "rcut": 6.0, + "rcut_smth": 2.0, + "nsel": 30, + "neuron": [ + 2, + 4, + 8 + ], + "axis_neuron": 4, + "activation_function": "tanh" + + }, + "repformer": { + "rcut": 4.0, + "rcut_smth": 0.5, + "nsel": 10, + "nlayers": 12, + "g1_dim": 8, + "g2_dim": 5, + "attn2_hidden": 3, + "attn2_nhead": 1, + "attn1_hidden": 5, + "attn1_nhead": 1, + "axis_neuron": 4, + "update_h2": false, + "update_g1_has_conv": true, + "update_g1_has_grrg": true, + "update_g1_has_drrd": true, + "update_g1_has_attn": true, + "update_g2_has_g1g1": true, + "update_g2_has_attn": true, + "attn2_has_gate": true, + "use_sqrt_nnei": false, + "g1_out_conv": false, + "g1_out_mlp": false + }, + "seed": 1, + "add_tebd_to_repinit_out": false + }, + "fitting_net": { + "neuron": [ + 240, + 240, + 240 + ], + "resnet_dt": true, + "seed": 1 + } +} diff --git a/source/tests/pd/model/models/dpa2.pd b/source/tests/pd/model/models/dpa2.pd new file mode 100644 index 0000000000..650f0c144e Binary files /dev/null and b/source/tests/pd/model/models/dpa2.pd differ diff --git a/source/tests/pd/model/test_autodiff.py b/source/tests/pd/model/test_autodiff.py index 1bd9dd0d0f..8442844a24 100644 --- a/source/tests/pd/model/test_autodiff.py +++ b/source/tests/pd/model/test_autodiff.py @@ -60,7 +60,7 @@ def stretch_box(old_coord, old_box, 
new_box): class ForceTest: def test( self, - ): + ) -> None: env.enable_prim(True) places = 5 delta = 1e-5 @@ -86,10 +86,10 @@ def np_infer_coord( ): result = eval_model( self.model, - paddle.to_tensor(coord).to(device=env.DEVICE).unsqueeze(0), + paddle.to_tensor(coord, place=env.DEVICE).unsqueeze(0), cell.unsqueeze(0), atype, - spins=paddle.to_tensor(spin).to(device=env.DEVICE).unsqueeze(0), + spins=paddle.to_tensor(spin, place=env.DEVICE).unsqueeze(0), ) # detach ret = {key: to_numpy_array(result[key].squeeze(0)) for key in test_keys} @@ -100,10 +100,10 @@ def np_infer_spin( ): result = eval_model( self.model, - paddle.to_tensor(coord).to(device=env.DEVICE).unsqueeze(0), + paddle.to_tensor(coord, place=env.DEVICE).unsqueeze(0), cell.unsqueeze(0), atype, - spins=paddle.to_tensor(spin).to(device=env.DEVICE).unsqueeze(0), + spins=paddle.to_tensor(spin, place=env.DEVICE).unsqueeze(0), ) # detach ret = {key: to_numpy_array(result[key].squeeze(0)) for key in test_keys} @@ -133,7 +133,7 @@ def ff_spin(_spin): class VirialTest: def test( self, - ): + ) -> None: places = 5 delta = 1e-4 natoms = 5 @@ -153,10 +153,10 @@ def np_infer( ): result = eval_model( self.model, - paddle.to_tensor(stretch_box(coord, cell, new_cell)) - .to(device="cpu") - .unsqueeze(0), - paddle.to_tensor(new_cell).to(device="cpu").unsqueeze(0), + paddle.to_tensor( + stretch_box(coord, cell, new_cell), place="cpu" + ).unsqueeze(0), + paddle.to_tensor(new_cell, place="cpu").unsqueeze(0), atype, ) # detach @@ -177,36 +177,35 @@ def ff(bb): class TestEnergyModelSeAForce(unittest.TestCase, ForceTest): - def setUp(self): + def setUp(self) -> None: model_params = copy.deepcopy(model_se_e2_a) self.type_split = False self.model = get_model(model_params).to(env.DEVICE) class TestEnergyModelSeAVirial(unittest.TestCase, VirialTest): - def setUp(self): + def setUp(self) -> None: model_params = copy.deepcopy(model_se_e2_a) self.type_split = False self.model = get_model(model_params).to(env.DEVICE) class TestEnergyModelDPA1Force(unittest.TestCase, ForceTest): - def setUp(self): + def setUp(self) -> None: model_params = copy.deepcopy(model_dpa1) self.type_split = True self.model = get_model(model_params).to(env.DEVICE) class TestEnergyModelDPA1Virial(unittest.TestCase, VirialTest): - def setUp(self): + def setUp(self) -> None: model_params = copy.deepcopy(model_dpa1) self.type_split = True self.model = get_model(model_params).to(env.DEVICE) -@unittest.skip("Skip for not implemented yet") class TestEnergyModelDPA2Force(unittest.TestCase, ForceTest): - def setUp(self): + def setUp(self) -> None: model_params = copy.deepcopy(model_dpa2) self.type_split = True self.model = get_model(model_params).to(env.DEVICE) @@ -214,7 +213,7 @@ def setUp(self): @unittest.skip("Skip for not implemented yet") class TestEnergyModelDPAUniVirial(unittest.TestCase, VirialTest): - def setUp(self): + def setUp(self) -> None: model_params = copy.deepcopy(model_dpa2) self.type_split = True self.model = get_model(model_params).to(env.DEVICE) @@ -222,7 +221,7 @@ def setUp(self): @unittest.skip("Skip for not implemented yet") class TestEnergyModelHybridForce(unittest.TestCase, ForceTest): - def setUp(self): + def setUp(self) -> None: model_params = copy.deepcopy(model_hybrid) self.type_split = True self.model = get_model(model_params).to(env.DEVICE) @@ -230,7 +229,7 @@ def setUp(self): @unittest.skip("Skip for not implemented yet") class TestEnergyModelHybridVirial(unittest.TestCase, VirialTest): - def setUp(self): + def setUp(self) -> None: model_params = 
copy.deepcopy(model_hybrid) self.type_split = True self.model = get_model(model_params).to(env.DEVICE) @@ -238,7 +237,7 @@ def setUp(self): @unittest.skip("Skip for not implemented yet") class TestEnergyModelZBLForce(unittest.TestCase, ForceTest): - def setUp(self): + def setUp(self) -> None: model_params = copy.deepcopy(model_zbl) self.type_split = False self.model = get_model(model_params).to(env.DEVICE) @@ -246,7 +245,7 @@ def setUp(self): @unittest.skip("Skip for not implemented yet") class TestEnergyModelZBLVirial(unittest.TestCase, VirialTest): - def setUp(self): + def setUp(self) -> None: model_params = copy.deepcopy(model_zbl) self.type_split = False self.model = get_model(model_params).to(env.DEVICE) @@ -254,7 +253,7 @@ def setUp(self): @unittest.skip("Skip for not implemented yet") class TestEnergyModelSpinSeAForce(unittest.TestCase, ForceTest): - def setUp(self): + def setUp(self) -> None: model_params = copy.deepcopy(model_spin) self.type_split = False self.test_spin = True diff --git a/source/tests/pd/model/test_descriptor_dpa2.py b/source/tests/pd/model/test_descriptor_dpa2.py new file mode 100644 index 0000000000..12017bb840 --- /dev/null +++ b/source/tests/pd/model/test_descriptor_dpa2.py @@ -0,0 +1,208 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import json +import os +import unittest +from pathlib import ( + Path, +) + +import numpy as np +import paddle + +from deepmd.pd.model.descriptor import ( + DescrptDPA2, +) +from deepmd.pd.utils import ( + env, +) +from deepmd.pd.utils.nlist import ( + extend_input_and_build_neighbor_list, +) + +CUR_DIR = os.path.dirname(__file__) + + +class TestDPA2(unittest.TestCase): + def setUp(self): + cell = [ + 5.122106549439247480e00, + 4.016537340154059388e-01, + 6.951654033828678081e-01, + 4.016537340154059388e-01, + 6.112136112297989143e00, + 8.178091365465004481e-01, + 6.951654033828678081e-01, + 8.178091365465004481e-01, + 6.159552512682983760e00, + ] + self.cell = ( + paddle.to_tensor(cell, dtype=env.GLOBAL_PD_FLOAT_PRECISION) + .reshape([1, 3, 3]) + .to(device=env.DEVICE) + ) + coord = [ + 2.978060152121375648e00, + 3.588469695887098077e00, + 2.792459820604495491e00, + 3.895592322591093115e00, + 2.712091020667753760e00, + 1.366836847133650501e00, + 9.955616170888935690e-01, + 4.121324820711413039e00, + 1.817239061889086571e00, + 3.553661462345699906e00, + 5.313046969500791583e00, + 6.635182659098815883e00, + 6.088601018589653080e00, + 6.575011420004332585e00, + 6.825240650611076099e00, + ] + self.coord = ( + paddle.to_tensor(coord, dtype=env.GLOBAL_PD_FLOAT_PRECISION) + .reshape([1, -1, 3]) + .to(device=env.DEVICE) + ) + self.atype = ( + paddle.to_tensor([0, 0, 0, 1, 1], dtype=paddle.int32) + .reshape([1, -1]) + .to(device=env.DEVICE) + ) + self.ref_d = paddle.to_tensor( + [ + 8.435412613327306630e-01, + -4.717109614540972440e-01, + -1.812643456954206256e00, + -2.315248767961955167e-01, + -7.112973006771171613e-01, + -4.162041919507591392e-01, + -1.505159810095323181e00, + -1.191652416985768403e-01, + 8.439214937875325617e-01, + -4.712976890460106594e-01, + -1.812605149396642856e00, + -2.307222236291133766e-01, + -7.115427800870099961e-01, + -4.164729253167227530e-01, + -1.505483119125936797e00, + -1.191288524278367872e-01, + 8.286420823261241297e-01, + -4.535033763979030574e-01, + -1.787877160970498425e00, + -1.961763875645104460e-01, + -7.475459187804838201e-01, + -5.231446874663764346e-01, + -1.488399984491664219e00, + -3.974117581747104583e-02, + 8.283793431613817315e-01, + -4.551551577556525729e-01, + 
-1.789253136645859943e00, + -1.977673627726055372e-01, + -7.448826048241211639e-01, + -5.161350182531234676e-01, + -1.487589463573479209e00, + -4.377376017839779143e-02, + 8.295404560710329944e-01, + -4.492219258475603216e-01, + -1.784484611185287450e00, + -1.901182059718481143e-01, + -7.537407667483000395e-01, + -5.384371277650709109e-01, + -1.490368056268364549e00, + -3.073744832541754762e-02, + ], + dtype=env.GLOBAL_PD_FLOAT_PRECISION, + place=env.DEVICE, + ) + self.file_model_param = Path(CUR_DIR) / "models" / "dpa2.pd" + self.file_type_embed = Path(CUR_DIR) / "models" / "dpa2_tebd.pd" + + def test_descriptor(self) -> None: + with open(Path(CUR_DIR) / "models" / "dpa2.json") as fp: + self.model_json = json.load(fp) + model_dpa2 = self.model_json + ntypes = len(model_dpa2["type_map"]) + dparams = model_dpa2["descriptor"] + dparams["ntypes"] = ntypes + assert dparams["type"] == "dpa2" + dparams.pop("type") + dparams["concat_output_tebd"] = False + dparams["use_tebd_bias"] = True + des = DescrptDPA2( + **dparams, + ).to(env.DEVICE) + target_dict = des.state_dict() + source_dict = paddle.load(str(self.file_model_param)) + # type_embd of repformer is removed + source_dict.pop("type_embedding.embedding.embedding_net.layers.0.bias") + type_embd_dict = paddle.load(str(self.file_type_embed)) + target_dict = translate_type_embd_dicts_to_dpa2( + target_dict, + source_dict, + type_embd_dict, + ) + des.set_state_dict(target_dict) + + coord = self.coord + atype = self.atype + box = self.cell + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + coord, + atype, + des.get_rcut(), + des.get_sel(), + mixed_types=des.mixed_types(), + box=box, + ) + descriptor, env_mat, diff, rot_mat, sw = des( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + ) + self.assertEqual(descriptor.shape[-1], des.get_dim_out()) + self.assertAlmostEqual(6.0, des.get_rcut()) + self.assertEqual(30, des.get_nsel()) + self.assertEqual(2, des.get_ntypes()) + np.testing.assert_allclose( + descriptor.reshape([-1]).numpy(), self.ref_d.numpy(), atol=1e-10, rtol=1e-10 + ) + + dparams["concat_output_tebd"] = True + des = DescrptDPA2( + **dparams, + ).to(env.DEVICE) + descriptor, env_mat, diff, rot_mat, sw = des( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + ) + self.assertEqual(descriptor.shape[-1], des.get_dim_out()) + + +def translate_type_embd_dicts_to_dpa2( + target_dict, + source_dict, + type_embd_dict, +): + all_keys = list(target_dict.keys()) + record = [False for ii in all_keys] + for kk, vv in source_dict.items(): + record[all_keys.index(kk)] = True + target_dict[kk] = vv + assert len(type_embd_dict.keys()) == 2 + it = iter(type_embd_dict.keys()) + for _ in range(2): + kk = next(it) + tk = "type_embedding." 
+ kk + record[all_keys.index(tk)] = True + target_dict[tk] = type_embd_dict[kk] + record[all_keys.index("repinit.compress_data.0")] = True + record[all_keys.index("repinit.compress_info.0")] = True + assert all(record) + return target_dict diff --git a/source/tests/pd/model/test_dpa2.py b/source/tests/pd/model/test_dpa2.py new file mode 100644 index 0000000000..f441007cad --- /dev/null +++ b/source/tests/pd/model/test_dpa2.py @@ -0,0 +1,333 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import itertools +import unittest + +import numpy as np +import paddle + +from deepmd.dpmodel.descriptor.dpa2 import DescrptDPA2 as DPDescrptDPA2 +from deepmd.dpmodel.descriptor.dpa2 import ( + RepformerArgs, + RepinitArgs, +) +from deepmd.pd.model.descriptor.dpa2 import ( + DescrptDPA2, +) +from deepmd.pd.utils import ( + env, +) +from deepmd.pd.utils.env import ( + PRECISION_DICT, +) + +from ...seed import ( + GLOBAL_SEED, +) +from .test_env_mat import ( + TestCaseSingleFrameWithNlist, +) +from .test_mlp import ( + get_tols, +) + +dtype = env.GLOBAL_PD_FLOAT_PRECISION + + +class TestDescrptDPA2(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + + def test_consistency( + self, + ) -> None: + rng = np.random.default_rng(100) + nf, nloc, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + davg_2 = rng.normal(size=(self.nt, nnei // 2, 4)) + dstd_2 = rng.normal(size=(self.nt, nnei // 2, 4)) + dstd = 0.1 + np.abs(dstd) + dstd_2 = 0.1 + np.abs(dstd_2) + + for ( + riti, + riz, + rp1c, + rp1d, + rp1g, + rp1a, + rp2g, + rp2a, + rph, + rp2gate, + rus, + rpz, + sm, + prec, + ect, + ns, + ) in itertools.product( + ["concat", "strip"], # repinit_tebd_input_mode + [ + True, + ], # repinit_set_davg_zero + [True, False], # repformer_update_g1_has_conv + [True, False], # repformer_update_g1_has_drrd + [True, False], # repformer_update_g1_has_grrg + [ + False, + ], # repformer_update_g1_has_attn + [ + False, + ], # repformer_update_g2_has_g1g1 + [True, False], # repformer_update_g2_has_attn + [ + False, + ], # repformer_update_h2 + [ + True, + ], # repformer_attn2_has_gate + ["res_avg", "res_residual"], # repformer_update_style + [ + True, + ], # repformer_set_davg_zero + [ + True, + ], # smooth + ["float64"], # precision + [False, True], # use_econf_tebd + [ + False, + True, + ], # new sub-structures (use_sqrt_nnei, g1_out_conv, g1_out_mlp) + ): + if ns and not rp1d and not rp1g: + continue + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + if prec == "float64": + atol = 1e-8 # marginal GPU test cases... 
+ + repinit = RepinitArgs( + rcut=self.rcut, + rcut_smth=self.rcut_smth, + nsel=self.sel_mix, + tebd_input_mode=riti, + set_davg_zero=riz, + ) + repformer = RepformerArgs( + rcut=self.rcut / 2, + rcut_smth=self.rcut_smth, + nsel=nnei // 2, + nlayers=3, + g1_dim=20, + g2_dim=10, + axis_neuron=4, + update_g1_has_conv=rp1c, + update_g1_has_drrd=rp1d, + update_g1_has_grrg=rp1g, + update_g1_has_attn=rp1a, + update_g2_has_g1g1=rp2g, + update_g2_has_attn=rp2a, + update_h2=rph, + attn1_hidden=20, + attn1_nhead=2, + attn2_hidden=10, + attn2_nhead=2, + attn2_has_gate=rp2gate, + update_style=rus, + set_davg_zero=rpz, + use_sqrt_nnei=ns, + g1_out_conv=ns, + g1_out_mlp=ns, + ) + + # dpa2 new impl + dd0 = DescrptDPA2( + self.nt, + repinit=repinit, + repformer=repformer, + # kwargs for descriptor + smooth=sm, + exclude_types=[], + add_tebd_to_repinit_out=False, + precision=prec, + use_econf_tebd=ect, + type_map=["O", "H"] if ect else None, + seed=GLOBAL_SEED, + ).to(env.DEVICE) + + dd0.repinit.mean = paddle.to_tensor(davg, dtype=dtype).to(device=env.DEVICE) + dd0.repinit.stddev = paddle.to_tensor(dstd, dtype=dtype).to( + device=env.DEVICE + ) + dd0.repformers.mean = paddle.to_tensor(davg_2, dtype=dtype).to( + device=env.DEVICE + ) + dd0.repformers.stddev = paddle.to_tensor(dstd_2, dtype=dtype).to( + device=env.DEVICE + ) + rd0, _, _, _, _ = dd0( + paddle.to_tensor(self.coord_ext, dtype=dtype).to(device=env.DEVICE), + paddle.to_tensor(self.atype_ext, dtype="int64").to(device=env.DEVICE), + paddle.to_tensor(self.nlist, dtype="int64").to(device=env.DEVICE), + paddle.to_tensor(self.mapping, dtype="int64").to(device=env.DEVICE), + ) + # serialization + dd1 = DescrptDPA2.deserialize(dd0.serialize()) + rd1, _, _, _, _ = dd1( + paddle.to_tensor(self.coord_ext, dtype=dtype).to(device=env.DEVICE), + paddle.to_tensor(self.atype_ext, dtype="int64").to(device=env.DEVICE), + paddle.to_tensor(self.nlist, dtype="int64").to(device=env.DEVICE), + paddle.to_tensor(self.mapping, dtype="int64").to(device=env.DEVICE), + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd1.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + ) + # dp impl + dd2 = DPDescrptDPA2.deserialize(dd0.serialize()) + rd2, _, _, _, _ = dd2.call( + self.coord_ext, self.atype_ext, self.nlist, self.mapping + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd2, + rtol=rtol, + atol=atol, + ) + + @unittest.skip("skip jit in paddle temporally") + def test_jit( + self, + ) -> None: + rng = np.random.default_rng(100) + nf, nloc, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + davg_2 = rng.normal(size=(self.nt, nnei // 2, 4)) + dstd_2 = rng.normal(size=(self.nt, nnei // 2, 4)) + dstd = 0.1 + np.abs(dstd) + + for ( + riti, + riz, + rp1c, + rp1d, + rp1g, + rp1a, + rp2g, + rp2a, + rph, + rp2gate, + rus, + rpz, + sm, + prec, + ect, + ns, + ) in itertools.product( + ["concat", "strip"], # repinit_tebd_input_mode + [ + True, + ], # repinit_set_davg_zero + [ + True, + ], # repformer_update_g1_has_conv + [ + True, + ], # repformer_update_g1_has_drrd + [ + True, + ], # repformer_update_g1_has_grrg + [ + True, + ], # repformer_update_g1_has_attn + [ + True, + ], # repformer_update_g2_has_g1g1 + [ + True, + ], # repformer_update_g2_has_attn + [ + False, + ], # repformer_update_h2 + [ + True, + ], # repformer_attn2_has_gate + ["res_avg", "res_residual"], # repformer_update_style + [ + True, + ], # repformer_set_davg_zero + [ + True, + ], # smooth + ["float64"], # precision + 
[False, True], # use_econf_tebd + [True], # new sub-structures (use_sqrt_nnei, g1_out_conv, g1_out_mlp) + ): + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + + repinit = RepinitArgs( + rcut=self.rcut, + rcut_smth=self.rcut_smth, + nsel=self.sel_mix, + tebd_input_mode=riti, + set_davg_zero=riz, + ) + repformer = RepformerArgs( + rcut=self.rcut / 2, + rcut_smth=self.rcut_smth, + nsel=nnei // 2, + nlayers=3, + g1_dim=20, + g2_dim=10, + axis_neuron=4, + update_g1_has_conv=rp1c, + update_g1_has_drrd=rp1d, + update_g1_has_grrg=rp1g, + update_g1_has_attn=rp1a, + update_g2_has_g1g1=rp2g, + update_g2_has_attn=rp2a, + update_h2=rph, + attn1_hidden=20, + attn1_nhead=2, + attn2_hidden=10, + attn2_nhead=2, + attn2_has_gate=rp2gate, + update_style=rus, + set_davg_zero=rpz, + use_sqrt_nnei=ns, + g1_out_conv=ns, + g1_out_mlp=ns, + ) + + # dpa2 new impl + dd0 = DescrptDPA2( + self.nt, + repinit=repinit, + repformer=repformer, + # kwargs for descriptor + smooth=sm, + exclude_types=[], + add_tebd_to_repinit_out=False, + precision=prec, + use_econf_tebd=ect, + type_map=["O", "H"] if ect else None, + seed=GLOBAL_SEED, + ).to(env.DEVICE) + + dd0.repinit.mean = paddle.to_tensor(davg, dtype=dtype).to(device=env.DEVICE) + dd0.repinit.stddev = paddle.to_tensor(dstd, dtype=dtype).to( + device=env.DEVICE + ) + dd0.repformers.mean = paddle.to_tensor(davg_2, dtype=dtype).to( + device=env.DEVICE + ) + dd0.repformers.stddev = paddle.to_tensor(dstd_2, dtype=dtype).to( + device=env.DEVICE + ) + model = paddle.jit.to_static(dd0) diff --git a/source/tests/pd/model/test_forward_lower.py b/source/tests/pd/model/test_forward_lower.py index db6497b605..1d924e2d3d 100644 --- a/source/tests/pd/model/test_forward_lower.py +++ b/source/tests/pd/model/test_forward_lower.py @@ -140,22 +140,21 @@ def test( class TestEnergyModelSeA(unittest.TestCase, ForwardLowerTest): - def setUp(self): + def setUp(self) -> None: self.prec = 1e-10 model_params = copy.deepcopy(model_se_e2_a) self.model = get_model(model_params).to(env.DEVICE) class TestEnergyModelDPA1(unittest.TestCase, ForwardLowerTest): - def setUp(self): + def setUp(self) -> None: self.prec = 1e-10 model_params = copy.deepcopy(model_dpa1) self.model = get_model(model_params).to(env.DEVICE) -@unittest.skip("Skip for not implemented yet") class TestEnergyModelDPA2(unittest.TestCase, ForwardLowerTest): - def setUp(self): + def setUp(self) -> None: self.prec = 1e-10 model_params = copy.deepcopy(model_dpa2) self.model = get_model(model_params).to(env.DEVICE) @@ -163,7 +162,7 @@ def setUp(self): @unittest.skip("Skip for not implemented yet") class TestEnergyModelZBL(unittest.TestCase, ForwardLowerTest): - def setUp(self): + def setUp(self) -> None: self.prec = 1e-10 model_params = copy.deepcopy(model_zbl) self.model = get_model(model_params).to(env.DEVICE) @@ -171,7 +170,7 @@ def setUp(self): @unittest.skip("Skip for not implemented yet") class TestEnergyModelSpinSeA(unittest.TestCase, ForwardLowerTest): - def setUp(self): + def setUp(self) -> None: self.prec = 1e-10 model_params = copy.deepcopy(model_spin) self.test_spin = True @@ -180,7 +179,7 @@ def setUp(self): @unittest.skip("Skip for not implemented yet") class TestEnergyModelSpinDPA1(unittest.TestCase, ForwardLowerTest): - def setUp(self): + def setUp(self) -> None: self.prec = 1e-10 model_params = copy.deepcopy(model_spin) model_params["descriptor"] = copy.deepcopy(model_dpa1)["descriptor"] @@ -192,7 +191,7 @@ def setUp(self): @unittest.skip("Skip for not implemented yet") class 
TestEnergyModelSpinDPA2(unittest.TestCase, ForwardLowerTest): - def setUp(self): + def setUp(self) -> None: self.prec = 1e-10 model_params = copy.deepcopy(model_spin) model_params["descriptor"] = copy.deepcopy(model_dpa2)["descriptor"] diff --git a/source/tests/pd/model/test_null_input.py b/source/tests/pd/model/test_null_input.py index 5d67491943..29d2f84eea 100644 --- a/source/tests/pd/model/test_null_input.py +++ b/source/tests/pd/model/test_null_input.py @@ -23,6 +23,7 @@ ) from .test_permutation import ( model_dpa1, + model_dpa2, model_se_e2_a, ) @@ -32,7 +33,7 @@ class NullTest: def test_nloc_1( self, - ): + ) -> None: natoms = 1 generator = paddle.seed(GLOBAL_SEED) # paddle.seed(1000) @@ -60,7 +61,7 @@ def test_nloc_1( def test_nloc_2_far( self, - ): + ) -> None: natoms = 2 generator = paddle.seed(GLOBAL_SEED) cell = paddle.rand([3, 3], dtype=dtype).to(device=env.DEVICE) @@ -100,3 +101,10 @@ def setUp(self): model_params = copy.deepcopy(model_dpa1) self.type_split = True self.model = get_model(model_params).to(env.DEVICE) + + +class TestEnergyModelDPA2(unittest.TestCase, NullTest): + def setUp(self) -> None: + model_params = copy.deepcopy(model_dpa2) + self.type_split = True + self.model = get_model(model_params).to(env.DEVICE) diff --git a/source/tests/pd/model/test_permutation.py b/source/tests/pd/model/test_permutation.py index 4543348d3b..297614b45d 100644 --- a/source/tests/pd/model/test_permutation.py +++ b/source/tests/pd/model/test_permutation.py @@ -416,7 +416,6 @@ def setUp(self) -> None: self.model = get_model(model_params).to(env.DEVICE) -@unittest.skip("Skip for not implemented yet") class TestEnergyModelDPA2(unittest.TestCase, PermutationTest): def setUp(self) -> None: model_params = copy.deepcopy(model_dpa2) diff --git a/source/tests/pd/model/test_rot.py b/source/tests/pd/model/test_rot.py index 85c90dc60f..84a0d3d724 100644 --- a/source/tests/pd/model/test_rot.py +++ b/source/tests/pd/model/test_rot.py @@ -176,7 +176,6 @@ def setUp(self): self.model = get_model(model_params).to(env.DEVICE) -@unittest.skip("Skip for not implemented yet") class TestEnergyModelDPA2(unittest.TestCase, RotTest): def setUp(self): model_params = copy.deepcopy(model_dpa2) diff --git a/source/tests/pd/model/test_rot_denoise.py b/source/tests/pd/model/test_rot_denoise.py index 74d5d41791..4a1841d10b 100644 --- a/source/tests/pd/model/test_rot_denoise.py +++ b/source/tests/pd/model/test_rot_denoise.py @@ -18,8 +18,9 @@ from ..common import ( eval_model, ) -from .test_permutation_denoise import ( # model_dpa2, +from .test_permutation_denoise import ( model_dpa1, + model_dpa2, ) dtype = paddle.float64 @@ -112,6 +113,14 @@ def setUp(self): self.model = get_model(model_params).to(env.DEVICE) +@unittest.skip("support of the denoise is temporally disabled") +class TestDenoiseModelDPA2(unittest.TestCase, RotDenoiseTest): + def setUp(self) -> None: + model_params = copy.deepcopy(model_dpa2) + self.type_split = True + self.model = get_model(model_params).to(env.DEVICE) + + # @unittest.skip("hybrid not supported at the moment") # class TestEnergyModelHybrid(unittest.TestCase, TestRotDenoise): # def setUp(self): diff --git a/source/tests/pd/model/test_smooth.py b/source/tests/pd/model/test_smooth.py index cc50043ad8..f907e6f4ee 100644 --- a/source/tests/pd/model/test_smooth.py +++ b/source/tests/pd/model/test_smooth.py @@ -20,6 +20,7 @@ ) from .test_permutation import ( # model_dpau, model_dpa1, + model_dpa2, model_se_e2_a, ) @@ -189,6 +190,36 @@ def setUp(self): self.aprec = 1e-5 +class 
TestEnergyModelDPA2(unittest.TestCase, SmoothTest): + def setUp(self) -> None: + model_params = copy.deepcopy(model_dpa2) + model_params["descriptor"]["repinit"]["rcut"] = 8 + model_params["descriptor"]["repinit"]["rcut_smth"] = 3.5 + self.type_split = True + self.model = get_model(model_params).to(env.DEVICE) + self.epsilon, self.aprec = 1e-5, 1e-4 + + +class TestEnergyModelDPA2_1(unittest.TestCase, SmoothTest): + def setUp(self) -> None: + model_params = copy.deepcopy(model_dpa2) + model_params["fitting_net"]["type"] = "ener" + self.type_split = True + self.test_virial = False + self.model = get_model(model_params).to(env.DEVICE) + self.epsilon, self.aprec = None, None + + +class TestEnergyModelDPA2_2(unittest.TestCase, SmoothTest): + def setUp(self) -> None: + model_params = copy.deepcopy(model_dpa2) + model_params["fitting_net"]["type"] = "ener" + self.type_split = True + self.test_virial = False + self.model = get_model(model_params).to(env.DEVICE) + self.epsilon, self.aprec = None, None + + # class TestEnergyFoo(unittest.TestCase): # def test(self): # model_params = model_dpau diff --git a/source/tests/pd/model/test_trans.py b/source/tests/pd/model/test_trans.py index 3fae49d598..f050596996 100644 --- a/source/tests/pd/model/test_trans.py +++ b/source/tests/pd/model/test_trans.py @@ -110,7 +110,6 @@ def setUp(self): self.model = get_model(model_params).to(env.DEVICE) -@unittest.skip("Skip for not implemented yet") class TestEnergyModelDPA2(unittest.TestCase, TransTest): def setUp(self): model_params = copy.deepcopy(model_dpa2) diff --git a/source/tests/pd/model/test_unused_params.py b/source/tests/pd/model/test_unused_params.py new file mode 100644 index 0000000000..bf92171da1 --- /dev/null +++ b/source/tests/pd/model/test_unused_params.py @@ -0,0 +1,92 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import copy +import unittest + +import paddle + +from deepmd.pd.model.model import ( + get_model, +) +from deepmd.pd.utils import ( + env, +) + +from ...seed import ( + GLOBAL_SEED, +) +from ..common import ( + eval_model, +) +from .test_permutation import ( + model_dpa2, +) + +dtype = paddle.float64 + + +@unittest.skip("paddle do not support unpacking grad_fn.next_functions") +class TestUnusedParamsDPA2(unittest.TestCase): + def test_unused(self): + import itertools + + for conv, drrd, grrg, attn1, g1g1, attn2, h2 in itertools.product( + [True], + [True], + [True], + [True], + [True], + [True], + [True], + ): + if (not drrd) and (not grrg) and h2: + # skip the case h2 is not envolved + continue + if (not grrg) and (not conv): + # skip the case g2 is not envolved + continue + model = copy.deepcopy(model_dpa2) + model["descriptor"]["repformer"]["nlayers"] = 2 + # model["descriptor"]["combine_grrg"] = cmbg2 + model["descriptor"]["repformer"]["update_g1_has_conv"] = conv + model["descriptor"]["repformer"]["update_g1_has_drrd"] = drrd + model["descriptor"]["repformer"]["update_g1_has_grrg"] = grrg + model["descriptor"]["repformer"]["update_g1_has_attn"] = attn1 + model["descriptor"]["repformer"]["update_g2_has_g1g1"] = g1g1 + model["descriptor"]["repformer"]["update_g2_has_attn"] = attn2 + model["descriptor"]["repformer"]["update_h2"] = h2 + model["fitting_net"]["neuron"] = [12, 12, 12] + self._test_unused(model) + + def _test_unused(self, model_params): + self.model = get_model(model_params).to(env.DEVICE) + natoms = 5 + generator = paddle.seed(GLOBAL_SEED) + cell = paddle.rand([3, 3], dtype=dtype).to(device=env.DEVICE) + cell = (cell + cell.T) + 5.0 * paddle.eye(3).to(device=env.DEVICE) 
+ coord = paddle.rand([natoms, 3], dtype=dtype).to(device=env.DEVICE) + coord = paddle.matmul(coord, cell) + atype = paddle.to_tensor([0, 0, 0, 1, 1]).to(env.DEVICE) + idx_perm = [1, 0, 4, 3, 2] + result_0 = eval_model(self.model, coord.unsqueeze(0), cell.unsqueeze(0), atype) + test_keys = ["energy", "force", "virial"] + ret0 = {key: result_0[key].squeeze(0) for key in test_keys} + + # use computation graph to find all contributing tensors + def get_contributing_params(y, top_level=True): + nf = y.grad_fn.next_functions if top_level else y.next_functions + for f, _ in nf: + try: + yield f.variable + except AttributeError: + pass # node has no tensor + if f is not None: + yield from get_contributing_params(f, top_level=False) + + contributing_parameters = set(get_contributing_params(ret0["energy"])) + all_parameters = set(self.model.parameters()) + non_contributing = all_parameters - contributing_parameters + self.assertEqual(len(non_contributing), 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pd/model/water/multitask.json b/source/tests/pd/model/water/multitask.json index 83524a8b77..2786afca59 100644 --- a/source/tests/pd/model/water/multitask.json +++ b/source/tests/pd/model/water/multitask.json @@ -10,7 +10,8 @@ "type": "se_e2_a", "sel": [ 46, - 92 + 92, + 4 ], "rcut_smth": 0.50, "rcut": 6.00, diff --git a/source/tests/pd/model/water/multitask_sharefit.json b/source/tests/pd/model/water/multitask_sharefit.json index 246b5992f7..934ef04998 100644 --- a/source/tests/pd/model/water/multitask_sharefit.json +++ b/source/tests/pd/model/water/multitask_sharefit.json @@ -91,14 +91,14 @@ "stat_file": "./stat_files/model_1.hdf5", "training_data": { "systems": [ - "pt/water/data/data_0" + "pd/water/data/data_0" ], "batch_size": 1, "_comment": "that's all" }, "validation_data": { "systems": [ - "pt/water/data/data_0" + "pd/water/data/data_0" ], "batch_size": 1, "_comment": "that's all" @@ -108,14 +108,14 @@ "stat_file": "./stat_files/model_2.hdf5", "training_data": { "systems": [ - "pt/water/data/data_0" + "pd/water/data/data_0" ], "batch_size": 1, "_comment": "that's all" }, "validation_data": { "systems": [ - "pt/water/data/data_0" + "pd/water/data/data_0" ], "batch_size": 1, "_comment": "that's all" diff --git a/source/tests/pd/test_finetune.py b/source/tests/pd/test_finetune.py index f82f7a8cd0..769ea6f6d3 100644 --- a/source/tests/pd/test_finetune.py +++ b/source/tests/pd/test_finetune.py @@ -197,7 +197,7 @@ def test_finetune_change_out_bias(self): self.tearDown() - def test_finetune_change_type(self): + def test_finetune_change_type(self) -> None: if not self.mixed_types: # skip when not mixed_types return @@ -284,7 +284,7 @@ def test_finetune_change_type(self): self.tearDown() - def tearDown(self): + def tearDown(self) -> None: for f in os.listdir("."): if f.startswith("model") and f.endswith(".pd"): os.remove(f) @@ -295,7 +295,7 @@ def tearDown(self): class TestEnergyModelSeA(FinetuneTest, unittest.TestCase): - def setUp(self): + def setUp(self) -> None: input_json = str(Path(__file__).parent / "water/se_atten.json") with open(input_json) as f: self.config = json.load(f) @@ -311,7 +311,7 @@ def setUp(self): @unittest.skip("Skip for not implemented yet") class TestEnergyZBLModelSeA(FinetuneTest, unittest.TestCase): - def setUp(self): + def setUp(self) -> None: input_json = str(Path(__file__).parent / "water/se_atten.json") with open(input_json) as f: self.config = json.load(f) @@ -327,7 +327,7 @@ def setUp(self): @unittest.skip("Skip for not implemented yet") 
class TestEnergyDOSModelSeA(FinetuneTest, unittest.TestCase): - def setUp(self): + def setUp(self) -> None: input_json = str(Path(__file__).parent / "dos/input.json") with open(input_json) as f: self.config = json.load(f) @@ -342,7 +342,7 @@ def setUp(self): class TestEnergyModelDPA1(FinetuneTest, unittest.TestCase): - def setUp(self): + def setUp(self) -> None: input_json = str(Path(__file__).parent / "water/se_atten.json") with open(input_json) as f: self.config = json.load(f) @@ -356,9 +356,8 @@ def setUp(self): self.testkey = None -@unittest.skip("Skip for not implemented yet") class TestEnergyModelDPA2(FinetuneTest, unittest.TestCase): - def setUp(self): + def setUp(self) -> None: input_json = str(Path(__file__).parent / "water/se_atten.json") with open(input_json) as f: self.config = json.load(f) diff --git a/source/tests/pd/test_multitask.py b/source/tests/pd/test_multitask.py index d59990dcca..72ad251068 100644 --- a/source/tests/pd/test_multitask.py +++ b/source/tests/pd/test_multitask.py @@ -30,6 +30,8 @@ from .model.test_permutation import ( model_dpa1, + model_dpa2, + model_dpa2tebd, model_se_e2_a, ) @@ -40,6 +42,13 @@ def setUpModule() -> None: with open(multitask_template_json) as f: multitask_template = json.load(f) + global multitask_sharefit_template + multitask_sharefit_template_json = str( + Path(__file__).parent / "water/multitask_sharefit.json" + ) + with open(multitask_sharefit_template_json) as f: + multitask_sharefit_template = json.load(f) + class MultiTaskTrainTest: def test_multitask_train(self) -> None: @@ -227,6 +236,46 @@ def tearDown(self) -> None: MultiTaskTrainTest.tearDown(self) +class TestMultiTaskSeASharefit(unittest.TestCase, MultiTaskTrainTest): + def setUp(self) -> None: + multitask_se_e2_a = deepcopy(multitask_sharefit_template) + multitask_se_e2_a["model"]["shared_dict"]["my_descriptor"] = model_se_e2_a[ + "descriptor" + ] + data_file = [str(Path(__file__).parent / "water/data/data_0")] + self.stat_files = "se_e2_a_share_fit" + os.makedirs(self.stat_files, exist_ok=True) + self.config = multitask_se_e2_a + self.config["training"]["data_dict"]["model_1"]["training_data"]["systems"] = ( + data_file + ) + self.config["training"]["data_dict"]["model_1"]["validation_data"][ + "systems" + ] = data_file + self.config["training"]["data_dict"]["model_1"]["stat_file"] = ( + f"{self.stat_files}/model_1" + ) + self.config["training"]["data_dict"]["model_2"]["training_data"]["systems"] = ( + data_file + ) + self.config["training"]["data_dict"]["model_2"]["validation_data"][ + "systems" + ] = data_file + self.config["training"]["data_dict"]["model_2"]["stat_file"] = ( + f"{self.stat_files}/model_2" + ) + self.config["training"]["numb_steps"] = 1 + self.config["training"]["save_freq"] = 1 + self.origin_config = deepcopy(self.config) + self.config["model"], self.shared_links = preprocess_shared_params( + self.config["model"] + ) + self.share_fitting = True + + def tearDown(self) -> None: + MultiTaskTrainTest.tearDown(self) + + class TestMultiTaskDPA1(unittest.TestCase, MultiTaskTrainTest): def setUp(self) -> None: multitask_DPA1 = deepcopy(multitask_template) @@ -266,5 +315,83 @@ def tearDown(self) -> None: MultiTaskTrainTest.tearDown(self) +class TestMultiTaskDPA2(unittest.TestCase, MultiTaskTrainTest): + def setUp(self) -> None: + multitask_DPA2 = deepcopy(multitask_template) + multitask_DPA2["model"]["shared_dict"]["my_descriptor"] = model_dpa2[ + "descriptor" + ] + data_file = [str(Path(__file__).parent / "water/data/data_0")] + self.stat_files = "DPA2" + 
os.makedirs(self.stat_files, exist_ok=True) + self.config = multitask_DPA2 + self.config["training"]["data_dict"]["model_1"]["training_data"]["systems"] = ( + data_file + ) + self.config["training"]["data_dict"]["model_1"]["validation_data"][ + "systems" + ] = data_file + self.config["training"]["data_dict"]["model_1"]["stat_file"] = ( + f"{self.stat_files}/model_1" + ) + self.config["training"]["data_dict"]["model_2"]["training_data"]["systems"] = ( + data_file + ) + self.config["training"]["data_dict"]["model_2"]["validation_data"][ + "systems" + ] = data_file + self.config["training"]["data_dict"]["model_2"]["stat_file"] = ( + f"{self.stat_files}/model_2" + ) + self.config["training"]["numb_steps"] = 1 + self.config["training"]["save_freq"] = 1 + self.origin_config = deepcopy(self.config) + self.config["model"], self.shared_links = preprocess_shared_params( + self.config["model"] + ) + + def tearDown(self) -> None: + MultiTaskTrainTest.tearDown(self) + + +class TestMultiTaskDPA2Tebd(unittest.TestCase, MultiTaskTrainTest): + def setUp(self) -> None: + multitask_DPA2 = deepcopy(multitask_template) + multitask_DPA2["model"]["shared_dict"]["my_descriptor"] = model_dpa2tebd[ + "descriptor" + ] + data_file = [str(Path(__file__).parent / "water/data/data_0")] + self.stat_files = "DPA2Tebd" + os.makedirs(self.stat_files, exist_ok=True) + self.config = multitask_DPA2 + self.config["training"]["data_dict"]["model_1"]["training_data"]["systems"] = ( + data_file + ) + self.config["training"]["data_dict"]["model_1"]["validation_data"][ + "systems" + ] = data_file + self.config["training"]["data_dict"]["model_1"]["stat_file"] = ( + f"{self.stat_files}/model_1" + ) + self.config["training"]["data_dict"]["model_2"]["training_data"]["systems"] = ( + data_file + ) + self.config["training"]["data_dict"]["model_2"]["validation_data"][ + "systems" + ] = data_file + self.config["training"]["data_dict"]["model_2"]["stat_file"] = ( + f"{self.stat_files}/model_2" + ) + self.config["training"]["numb_steps"] = 1 + self.config["training"]["save_freq"] = 1 + self.origin_config = deepcopy(self.config) + self.config["model"], self.shared_links = preprocess_shared_params( + self.config["model"] + ) + + def tearDown(self) -> None: + MultiTaskTrainTest.tearDown(self) + + if __name__ == "__main__": unittest.main() diff --git a/source/tests/pd/test_training.py b/source/tests/pd/test_training.py index c3d65c09df..8958dcb165 100644 --- a/source/tests/pd/test_training.py +++ b/source/tests/pd/test_training.py @@ -24,6 +24,7 @@ from .model.test_permutation import ( model_dpa1, + model_dpa2, model_se_e2_a, ) @@ -195,5 +196,21 @@ def tearDown(self) -> None: DPTrainTest.tearDown(self) +class TestEnergyModelDPA2(unittest.TestCase, DPTrainTest): + def setUp(self) -> None: + input_json = str(Path(__file__).parent / "water/se_atten.json") + with open(input_json) as f: + self.config = json.load(f) + data_file = [str(Path(__file__).parent / "water/data/data_0")] + self.config["training"]["training_data"]["systems"] = data_file + self.config["training"]["validation_data"]["systems"] = data_file + self.config["model"] = deepcopy(model_dpa2) + self.config["training"]["numb_steps"] = 1 + self.config["training"]["save_freq"] = 1 + + def tearDown(self) -> None: + DPTrainTest.tearDown(self) + + if __name__ == "__main__": unittest.main() diff --git a/source/tests/pd/test_update_sel.py b/source/tests/pd/test_update_sel.py index e7b1acf6ff..10342357c6 100644 --- a/source/tests/pd/test_update_sel.py +++ b/source/tests/pd/test_update_sel.py @@ 
-31,7 +31,7 @@ def setUp(self) -> None: return super().setUp() @patch("deepmd.pd.utils.update_sel.UpdateSel.get_nbor_stat") - def test_update_one_sel(self, sel_mock): + def test_update_one_sel(self, sel_mock) -> None: sel_mock.return_value = self.mock_min_nbor_dist, [10, 20] min_nbor_dist, sel = self.update_sel.update_one_sel(None, None, 6, "auto") @@ -45,7 +45,7 @@ def test_update_one_sel(self, sel_mock): @unittest.skip("Skip for not implemented yet") @patch("deepmd.pd.utils.update_sel.UpdateSel.get_nbor_stat") - def test_update_sel_hybrid(self, sel_mock): + def test_update_sel_hybrid(self, sel_mock) -> None: sel_mock.return_value = self.mock_min_nbor_dist, [10, 20] jdata = { @@ -76,7 +76,7 @@ def test_update_sel_hybrid(self, sel_mock): self.assertEqual(jdata, expected_out) @patch("deepmd.pd.utils.update_sel.UpdateSel.get_nbor_stat") - def test_update_sel(self, sel_mock): + def test_update_sel(self, sel_mock) -> None: sel_mock.return_value = self.mock_min_nbor_dist, [10, 20] jdata = { @@ -90,9 +90,8 @@ def test_update_sel(self, sel_mock): jdata = update_sel(jdata) self.assertEqual(jdata, expected_out) - @unittest.skip("Skip for not implemented yet") @patch("deepmd.pd.utils.update_sel.UpdateSel.get_nbor_stat") - def test_update_sel_atten_auto(self, sel_mock): + def test_update_sel_atten_auto(self, sel_mock) -> None: sel_mock.return_value = self.mock_min_nbor_dist, [25] jdata = { @@ -118,9 +117,8 @@ def test_update_sel_atten_auto(self, sel_mock): jdata = update_sel(jdata) self.assertEqual(jdata, expected_out) - @unittest.skip("Skip for not implemented yet") @patch("deepmd.pd.utils.update_sel.UpdateSel.get_nbor_stat") - def test_update_sel_atten_int(self, sel_mock): + def test_update_sel_atten_int(self, sel_mock) -> None: sel_mock.return_value = self.mock_min_nbor_dist, [25] jdata = { @@ -146,9 +144,8 @@ def test_update_sel_atten_int(self, sel_mock): jdata = update_sel(jdata) self.assertEqual(jdata, expected_out) - @unittest.skip("Skip for not implemented yet") @patch("deepmd.pd.utils.update_sel.UpdateSel.get_nbor_stat") - def test_update_sel_atten_list(self, sel_mock): + def test_update_sel_atten_list(self, sel_mock) -> None: sel_mock.return_value = self.mock_min_nbor_dist, [25] jdata = { @@ -174,7 +171,50 @@ def test_update_sel_atten_list(self, sel_mock): jdata = update_sel(jdata) self.assertEqual(jdata, expected_out) - def test_skip_frozen(self): + @patch("deepmd.pd.utils.update_sel.UpdateSel.get_nbor_stat") + def test_update_sel_dpa2_auto(self, sel_mock) -> None: + sel_mock.return_value = self.mock_min_nbor_dist, [25] + + jdata = { + "model": { + "descriptor": { + "type": "dpa2", + "repinit": { + "rcut": 6.0, + "nsel": "auto", + "three_body_rcut": 4.0, + "three_body_sel": "auto", + }, + "repformer": { + "rcut": 4.0, + "nsel": "auto", + }, + } + }, + "training": {"training_data": {}}, + } + expected_out = { + "model": { + "descriptor": { + "type": "dpa2", + "repinit": { + "rcut": 6.0, + "nsel": 28, + "three_body_rcut": 4.0, + "three_body_sel": 28, + }, + "repformer": { + "rcut": 4.0, + "nsel": 28, + }, + } + }, + "training": {"training_data": {}}, + } + jdata = update_sel(jdata) + self.assertEqual(jdata, expected_out) + + def test_skip_frozen(self) -> None: jdata = { "model": { "type": "frozen", @@ -185,7 +225,7 @@ def test_skip_frozen(self): jdata = update_sel(jdata) self.assertEqual(jdata, expected_out) - def test_wrap_up_4(self): + def test_wrap_up_4(self) -> None: self.assertEqual(self.update_sel.wrap_up_4(12), 3 * 4) self.assertEqual(self.update_sel.wrap_up_4(13), 4 * 4) 
self.assertEqual(self.update_sel.wrap_up_4(14), 4 * 4)
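
Note (editorial): the new test_update_sel_dpa2_auto case above expects every "auto" nsel to be replaced by the mocked neighbor count (25) rounded up to the next multiple of 4, giving 28, consistent with the wrap_up_4 assertions at the end of the file. A minimal sketch of that rounding rule, assuming only what the assertions show; the helper name below is illustrative and is not the actual UpdateSel.wrap_up_4 implementation in deepmd.pd.utils.update_sel:

def wrap_up_4_sketch(n: int) -> int:
    # round n up to the nearest multiple of 4, e.g. 12 -> 12, 13 -> 16, 25 -> 28
    return ((n + 3) // 4) * 4

assert wrap_up_4_sketch(12) == 12
assert wrap_up_4_sketch(13) == 16
assert wrap_up_4_sketch(14) == 16
assert wrap_up_4_sketch(25) == 28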