From e23dc5f0e43c77ff822e0a01c9d003acdfa6fe88 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Wed, 18 Dec 2024 02:01:24 +0800 Subject: [PATCH 1/5] add dpa3 alpha --- deepmd/dpmodel/descriptor/dpa3.py | 110 +++ deepmd/pt/model/descriptor/__init__.py | 4 + deepmd/pt/model/descriptor/dpa3.py | 563 ++++++++++++ deepmd/pt/model/descriptor/repflow_layer.py | 913 ++++++++++++++++++++ deepmd/pt/model/descriptor/repflows.py | 570 ++++++++++++ deepmd/pt/utils/utils.py | 2 + deepmd/utils/argcheck.py | 159 ++++ 7 files changed, 2321 insertions(+) create mode 100644 deepmd/dpmodel/descriptor/dpa3.py create mode 100644 deepmd/pt/model/descriptor/dpa3.py create mode 100644 deepmd/pt/model/descriptor/repflow_layer.py create mode 100644 deepmd/pt/model/descriptor/repflows.py diff --git a/deepmd/dpmodel/descriptor/dpa3.py b/deepmd/dpmodel/descriptor/dpa3.py new file mode 100644 index 0000000000..228c652930 --- /dev/null +++ b/deepmd/dpmodel/descriptor/dpa3.py @@ -0,0 +1,110 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + + +class RepFlowArgs: + def __init__( + self, + n_dim: int = 128, + e_dim: int = 64, + a_dim: int = 64, + nlayers: int = 6, + e_rcut: float = 6.0, + e_rcut_smth: float = 5.0, + e_sel: int = 120, + a_rcut: float = 4.0, + a_rcut_smth: float = 3.5, + a_sel: int = 20, + axis_neuron: int = 4, + node_has_conv: bool = False, + update_angle: bool = True, + update_style: str = "res_residual", + update_residual: float = 0.1, + update_residual_init: str = "const", + ) -> None: + r"""The constructor for the RepFlowArgs class which defines the parameters of the repflow block in DPA3 descriptor. + + Parameters + ---------- + n_dim : int, optional + The dimension of node representation. + e_dim : int, optional + The dimension of edge representation. + a_dim : int, optional + The dimension of angle representation. + nlayers : int, optional + Number of repflow layers. + e_rcut : float, optional + The edge cut-off radius. + e_rcut_smth : float, optional + Where to start smoothing for edge. For example the 1/r term is smoothed from rcut to rcut_smth. + e_sel : int, optional + Maximally possible number of selected edge neighbors. + a_rcut : float, optional + The angle cut-off radius. + a_rcut_smth : float, optional + Where to start smoothing for angle. For example the 1/r term is smoothed from rcut to rcut_smth. + a_sel : int, optional + Maximally possible number of selected angle neighbors. + axis_neuron : int, optional + The number of dimension of submatrix in the symmetrization ops. + update_angle : bool, optional + Where to update the angle rep. If not, only node and edge rep will be used. + update_style : str, optional + Style to update a representation. + Supported options are: + -'res_avg': Updates a rep `u` with: u = 1/\\sqrt{n+1} (u + u_1 + u_2 + ... + u_n) + -'res_incr': Updates a rep `u` with: u = u + 1/\\sqrt{n} (u_1 + u_2 + ... + u_n) + -'res_residual': Updates a rep `u` with: u = u + (r1*u_1 + r2*u_2 + ... + r3*u_n) + where `r1`, `r2` ... `r3` are residual weights defined by `update_residual` + and `update_residual_init`. + update_residual : float, optional + When update using residual mode, the initial std of residual vector weights. + update_residual_init : str, optional + When update using residual mode, the initialization mode of residual vector weights. + """ + self.n_dim = n_dim + self.e_dim = e_dim + self.a_dim = a_dim + self.nlayers = nlayers + self.e_rcut = e_rcut + self.e_rcut_smth = e_rcut_smth + self.e_sel = e_sel + self.a_rcut = a_rcut + self.a_rcut_smth = a_rcut_smth + self.a_sel = a_sel + self.axis_neuron = axis_neuron + self.node_has_conv = node_has_conv # tmp + self.update_angle = update_angle + self.update_style = update_style + self.update_residual = update_residual + self.update_residual_init = update_residual_init + + def __getitem__(self, key): + if hasattr(self, key): + return getattr(self, key) + else: + raise KeyError(key) + + def serialize(self) -> dict: + return { + "n_dim": self.n_dim, + "e_dim": self.e_dim, + "a_dim": self.a_dim, + "nlayers": self.nlayers, + "e_rcut": self.e_rcut, + "e_rcut_smth": self.e_rcut_smth, + "e_sel": self.e_sel, + "a_rcut": self.a_rcut, + "a_rcut_smth": self.a_rcut_smth, + "a_sel": self.a_sel, + "axis_neuron": self.axis_neuron, + "node_has_conv": self.node_has_conv, # tmp + "update_angle": self.update_angle, + "update_style": self.update_style, + "update_residual": self.update_residual, + "update_residual_init": self.update_residual_init, + } + + @classmethod + def deserialize(cls, data: dict) -> "RepFlowArgs": + return cls(**data) diff --git a/deepmd/pt/model/descriptor/__init__.py b/deepmd/pt/model/descriptor/__init__.py index 4a227918fe..9f3468d1db 100644 --- a/deepmd/pt/model/descriptor/__init__.py +++ b/deepmd/pt/model/descriptor/__init__.py @@ -13,6 +13,9 @@ from .dpa2 import ( DescrptDPA2, ) +from .dpa3 import ( + DescrptDPA3, +) from .env_mat import ( prod_env_mat, ) @@ -49,6 +52,7 @@ "DescrptBlockSeTTebd", "DescrptDPA1", "DescrptDPA2", + "DescrptDPA3", "DescrptHybrid", "DescrptSeA", "DescrptSeAttenV2", diff --git a/deepmd/pt/model/descriptor/dpa3.py b/deepmd/pt/model/descriptor/dpa3.py new file mode 100644 index 0000000000..5d785e0de9 --- /dev/null +++ b/deepmd/pt/model/descriptor/dpa3.py @@ -0,0 +1,563 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Callable, + Optional, + Union, +) + +import torch + +from deepmd.dpmodel.descriptor.dpa3 import ( + RepFlowArgs, +) +from deepmd.dpmodel.utils import EnvMat as DPEnvMat +from deepmd.dpmodel.utils.seed import ( + child_seed, +) +from deepmd.pt.model.network.mlp import ( + MLPLayer, +) +from deepmd.pt.model.network.network import ( + TypeEmbedNet, + TypeEmbedNetConsistent, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + PRECISION_DICT, +) +from deepmd.pt.utils.update_sel import ( + UpdateSel, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, +) +from deepmd.utils.data_system import ( + DeepmdDataSystem, +) +from deepmd.utils.finetune import ( + get_index_between_two_maps, + map_pair_exclude_types, +) +from deepmd.utils.path import ( + DPPath, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + +from .base_descriptor import ( + BaseDescriptor, +) +from .descriptor import ( + extend_descrpt_stat, +) +from .repflow_layer import ( + RepFlowLayer, +) +from .repflows import ( + DescrptBlockRepflows, +) + + +@BaseDescriptor.register("dpa3") +class DescrptDPA3(BaseDescriptor, torch.nn.Module): + def __init__( + self, + ntypes: int, + # args for repflow + repflow: Union[RepFlowArgs, dict], + # kwargs for descriptor + concat_output_tebd: bool = False, + activation_function: str = "silu", + precision: str = "float64", + exclude_types: list[tuple[int, int]] = [], + env_protection: float = 0.0, + trainable: bool = True, + seed: Optional[Union[int, list[int]]] = None, + use_econf_tebd: bool = False, + use_tebd_bias: bool = False, + type_map: Optional[list[str]] = None, + ) -> None: + r"""The DPA-3 descriptor. + + Parameters + ---------- + repflow : Union[RepFlowArgs, dict] + The arguments used to initialize the repflow block, see docstr in `RepFlowArgs` for details information. + concat_output_tebd : bool, optional + Whether to concat type embedding at the output of the descriptor. + activation_function : str, optional + The activation function in the embedding net. + precision : str, optional + The precision of the embedding net parameters. + exclude_types : list[list[int]], optional + The excluded pairs of types which have no interaction with each other. + For example, `[[0, 1]]` means no interaction between type 0 and type 1. + env_protection : float, optional + Protection parameter to prevent division by zero errors during environment matrix calculations. + For example, when using paddings, there may be zero distances of neighbors, which may make division by zero error during environment matrix calculations without protection. + trainable : bool, optional + If the parameters are trainable. + seed : int, optional + Random seed for parameter initialization. + use_econf_tebd : bool, Optional + Whether to use electronic configuration type embedding. + use_tebd_bias : bool, Optional + Whether to use bias in the type embedding layer. + type_map : list[str], Optional + A list of strings. Give the name to each type of atoms. + + Returns + ------- + descriptor: torch.Tensor + the descriptor of shape nb x nloc x n_dim. + invariant single-atom representation. + g2: torch.Tensor + invariant pair-atom representation. + h2: torch.Tensor + equivariant pair-atom representation. + rot_mat: torch.Tensor + rotation matrix for equivariant fittings + sw: torch.Tensor + The switch function for decaying inverse distance. + + """ + super().__init__() + + def init_subclass_params(sub_data, sub_class): + if isinstance(sub_data, dict): + return sub_class(**sub_data) + elif isinstance(sub_data, sub_class): + return sub_data + else: + raise ValueError( + f"Input args must be a {sub_class.__name__} class or a dict!" + ) + + self.repflow_args = init_subclass_params(repflow, RepFlowArgs) + self.activation_function = activation_function + + self.repflows = DescrptBlockRepflows( + self.repflow_args.e_rcut, + self.repflow_args.e_rcut_smth, + self.repflow_args.e_sel, + self.repflow_args.a_rcut, + self.repflow_args.a_rcut_smth, + self.repflow_args.a_sel, + ntypes, + nlayers=self.repflow_args.nlayers, + n_dim=self.repflow_args.n_dim, + e_dim=self.repflow_args.e_dim, + a_dim=self.repflow_args.a_dim, + axis_neuron=self.repflow_args.axis_neuron, + node_has_conv=self.repflow_args.node_has_conv, + update_angle=self.repflow_args.update_angle, + activation_function=self.activation_function, + update_style=self.repflow_args.update_style, + update_residual=self.repflow_args.update_residual, + update_residual_init=self.repflow_args.update_residual_init, + exclude_types=exclude_types, + env_protection=env_protection, + precision=precision, + seed=child_seed(seed, 1), + ) + + self.use_econf_tebd = use_econf_tebd + self.use_tebd_bias = use_tebd_bias + self.type_map = type_map + self.tebd_dim = self.repflow_args.n_dim + self.type_embedding = TypeEmbedNet( + ntypes, + self.tebd_dim, + precision=precision, + seed=child_seed(seed, 2), + use_econf_tebd=self.use_econf_tebd, + use_tebd_bias=use_tebd_bias, + type_map=type_map, + ) + self.concat_output_tebd = concat_output_tebd + self.precision = precision + self.prec = PRECISION_DICT[self.precision] + self.exclude_types = exclude_types + self.env_protection = env_protection + self.trainable = trainable + + assert self.repflows.e_rcut > self.repflows.a_rcut + assert self.repflows.e_sel > self.repflows.a_sel + + self.rcut = self.repflows.get_rcut() + self.rcut_smth = self.repflows.get_rcut_smth() + self.sel = self.repflows.get_sel() + self.ntypes = ntypes + + # set trainable + for param in self.parameters(): + param.requires_grad = trainable + self.compress = False + + def get_rcut(self) -> float: + """Returns the cut-off radius.""" + return self.rcut + + def get_rcut_smth(self) -> float: + """Returns the radius where the neighbor information starts to smoothly decay to 0.""" + return self.rcut_smth + + def get_nsel(self) -> int: + """Returns the number of selected atoms in the cut-off radius.""" + return sum(self.sel) + + def get_sel(self) -> list[int]: + """Returns the number of selected atoms for each type.""" + return self.sel + + def get_ntypes(self) -> int: + """Returns the number of element types.""" + return self.ntypes + + def get_type_map(self) -> list[str]: + """Get the name to each type of atoms.""" + return self.type_map + + def get_dim_out(self) -> int: + """Returns the output dimension of this descriptor.""" + ret = self.repflows.dim_out + if self.concat_output_tebd: + ret += self.tebd_dim + return ret + + def get_dim_emb(self) -> int: + """Returns the embedding dimension of this descriptor.""" + return self.repflows.dim_emb + + def mixed_types(self) -> bool: + """If true, the descriptor + 1. assumes total number of atoms aligned across frames; + 2. requires a neighbor list that does not distinguish different atomic types. + + If false, the descriptor + 1. assumes total number of atoms of each atom type aligned across frames; + 2. requires a neighbor list that distinguishes different atomic types. + + """ + return True + + def has_message_passing(self) -> bool: + """Returns whether the descriptor has message passing.""" + return self.repflows.has_message_passing() + + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the descriptor needs sorted nlist when using `forward_lower`.""" + return True + + def get_env_protection(self) -> float: + """Returns the protection of building environment matrix.""" + return self.repflows.get_env_protection() + + def share_params(self, base_class, shared_level, resume=False) -> None: + """ + Share the parameters of self to the base_class with shared_level during multitask training. + If not start from checkpoint (resume is False), + some separated parameters (e.g. mean and stddev) will be re-calculated across different classes. + """ + assert ( + self.__class__ == base_class.__class__ + ), "Only descriptors of the same type can share params!" + # For DPA3 descriptors, the user-defined share-level + # shared_level: 0 + # share all parameters in type_embedding, repflow + if shared_level == 0: + self._modules["type_embedding"] = base_class._modules["type_embedding"] + self.repflows.share_params(base_class.repflow, 0, resume=resume) + # shared_level: 1 + # share all parameters in type_embedding + elif shared_level == 1: + self._modules["type_embedding"] = base_class._modules["type_embedding"] + # Other shared levels + else: + raise NotImplementedError + + def change_type_map( + self, type_map: list[str], model_with_new_type_stat=None + ) -> None: + """Change the type related params to new ones, according to `type_map` and the original one in the model. + If there are new types in `type_map`, statistics will be updated accordingly to `model_with_new_type_stat` for these new types. + """ + assert ( + self.type_map is not None + ), "'type_map' must be defined when performing type changing!" + remap_index, has_new_type = get_index_between_two_maps(self.type_map, type_map) + self.type_map = type_map + self.type_embedding.change_type_map(type_map=type_map) + self.exclude_types = map_pair_exclude_types(self.exclude_types, remap_index) + self.ntypes = len(type_map) + repflow = self.repflows + if has_new_type: + # the avg and std of new types need to be updated + extend_descrpt_stat( + repflow, + type_map, + des_with_stat=model_with_new_type_stat.repflow + if model_with_new_type_stat is not None + else None, + ) + repflow.ntypes = self.ntypes + repflow.reinit_exclude(self.exclude_types) + repflow["davg"] = repflow["davg"][remap_index] + repflow["dstd"] = repflow["dstd"][remap_index] + + @property + def dim_out(self): + return self.get_dim_out() + + @property + def dim_emb(self): + """Returns the embedding dimension g2.""" + return self.get_dim_emb() + + def compute_input_stats( + self, + merged: Union[Callable[[], list[dict]], list[dict]], + path: Optional[DPPath] = None, + ) -> None: + """ + Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data. + + Parameters + ---------- + merged : Union[Callable[[], list[dict]], list[dict]] + - list[dict]: A list of data samples from various data systems. + Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor` + originating from the `i`-th data system. + - Callable[[], list[dict]]: A lazy function that returns data samples in the above format + only when needed. Since the sampling process can be slow and memory-intensive, + the lazy function helps by only sampling once. + path : Optional[DPPath] + The path to the stat file. + + """ + descrpt_list = [self.repflows] + for ii, descrpt in enumerate(descrpt_list): + descrpt.compute_input_stats(merged, path) + + def set_stat_mean_and_stddev( + self, + mean: list[torch.Tensor], + stddev: list[torch.Tensor], + ) -> None: + """Update mean and stddev for descriptor.""" + descrpt_list = [self.repflows] + for ii, descrpt in enumerate(descrpt_list): + descrpt.mean = mean[ii] + descrpt.stddev = stddev[ii] + + def get_stat_mean_and_stddev(self) -> tuple[list[torch.Tensor], list[torch.Tensor]]: + """Get mean and stddev for descriptor.""" + mean_list = [self.repflows.mean] + stddev_list = [self.repflows.stddev] + return mean_list, stddev_list + + def serialize(self) -> dict: + repflows = self.repflows + data = { + "@class": "Descriptor", + "type": "dpa3", + "@version": 1, + "ntypes": self.ntypes, + "repflow_args": self.repflow_args.serialize(), + "concat_output_tebd": self.concat_output_tebd, + "activation_function": self.activation_function, + "precision": self.precision, + "exclude_types": self.exclude_types, + "env_protection": self.env_protection, + "trainable": self.trainable, + "use_econf_tebd": self.use_econf_tebd, + "use_tebd_bias": self.use_tebd_bias, + "type_map": self.type_map, + "type_embedding": self.type_embedding.embedding.serialize(), + } + repflow_variable = { + "edge_embd": repflows.edge_embd.serialize(), + "repflow_layers": [layer.serialize() for layer in repflows.layers], + "env_mat": DPEnvMat(repflows.rcut, repflows.rcut_smth).serialize(), + "@variables": { + "davg": to_numpy_array(repflows["davg"]), + "dstd": to_numpy_array(repflows["dstd"]), + }, + } + data.update( + { + "repflow_variable": repflow_variable, + } + ) + return data + + @classmethod + def deserialize(cls, data: dict) -> "DescrptDPA3": + data = data.copy() + version = data.pop("@version") + check_version_compatibility(version, 1, 1) + data.pop("@class") + data.pop("type") + repflow_variable = data.pop("repflow_variable").copy() + type_embedding = data.pop("type_embedding") + data["repflow"] = RepFlowArgs(**data.pop("repflow_args")) + obj = cls(**data) + obj.type_embedding.embedding = TypeEmbedNetConsistent.deserialize( + type_embedding + ) + + def t_cvt(xx): + return torch.tensor(xx, dtype=obj.repflows.prec, device=env.DEVICE) + + # deserialize repflow + statistic_repflows = repflow_variable.pop("@variables") + env_mat = repflow_variable.pop("env_mat") + repflow_layers = repflow_variable.pop("repflow_layers") + obj.repflows.edge_embd = MLPLayer.deserialize(repflow_variable.pop("edge_embd")) + obj.repflows["davg"] = t_cvt(statistic_repflows["davg"]) + obj.repflows["dstd"] = t_cvt(statistic_repflows["dstd"]) + obj.repflows.layers = torch.nn.ModuleList( + [RepFlowLayer.deserialize(layer) for layer in repflow_layers] + ) + return obj + + def forward( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: Optional[torch.Tensor] = None, + comm_dict: Optional[dict[str, torch.Tensor]] = None, + ): + """Compute the descriptor. + + Parameters + ---------- + extended_coord + The extended coordinates of atoms. shape: nf x (nallx3) + extended_atype + The extended aotm types. shape: nf x nall + nlist + The neighbor list. shape: nf x nloc x nnei + mapping + The index mapping, mapps extended region index to local region. + comm_dict + The data needed for communication for parallel inference. + + Returns + ------- + node_embd + The output descriptor. shape: nf x nloc x n_dim (or n_dim + tebd_dim) + rot_mat + The rotationally equivariant and permutationally invariant single particle + representation. shape: nf x nloc x e_dim x 3 + edge_embd + The edge embedding. + shape: nf x nloc x nnei x e_dim + h2 + The rotationally equivariant pair-partical representation. + shape: nf x nloc x nnei x 3 + sw + The smooth switch function. shape: nf x nloc x nnei + + """ + # cast the input to internal precsion + extended_coord = extended_coord.to(dtype=self.prec) + nframes, nloc, nnei = nlist.shape + nall = extended_coord.view(nframes, -1).shape[1] // 3 + + node_embd_ext = self.type_embedding(extended_atype) + node_embd_inp = node_embd_ext[:, :nloc, :] + # repflows + node_embd, edge_embd, h2, rot_mat, sw = self.repflows( + nlist, + extended_coord, + extended_atype, + node_embd_ext, + mapping, + comm_dict=comm_dict, + ) + if self.concat_output_tebd: + node_embd = torch.cat([node_embd, node_embd_inp], dim=-1) + return ( + node_embd.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + edge_embd.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + h2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + ) + + @classmethod + def update_sel( + cls, + train_data: DeepmdDataSystem, + type_map: Optional[list[str]], + local_jdata: dict, + ) -> tuple[dict, Optional[float]]: + """Update the selection and perform neighbor statistics. + + Parameters + ---------- + train_data : DeepmdDataSystem + data used to do neighbor statistics + type_map : list[str], optional + The name of each type of atoms + local_jdata : dict + The local data refer to the current class + + Returns + ------- + dict + The updated local data + float + The minimum distance between two atoms + """ + local_jdata_cpy = local_jdata.copy() + update_sel = UpdateSel() + min_nbor_dist, repflow_e_sel = update_sel.update_one_sel( + train_data, + type_map, + local_jdata_cpy["repflow"]["e_rcut"], + local_jdata_cpy["repflow"]["e_sel"], + True, + ) + local_jdata_cpy["repflow"]["e_sel"] = repflow_e_sel[0] + + min_nbor_dist, repflow_a_sel = update_sel.update_one_sel( + train_data, + type_map, + local_jdata_cpy["repflow"]["a_rcut"], + local_jdata_cpy["repflow"]["a_sel"], + True, + ) + local_jdata_cpy["repflow"]["a_sel"] = repflow_a_sel[0] + + return local_jdata_cpy, min_nbor_dist + + def enable_compression( + self, + min_nbor_dist: float, + table_extrapolate: float = 5, + table_stride_1: float = 0.01, + table_stride_2: float = 0.1, + check_frequency: int = -1, + ) -> None: + """Receive the statistics (distance, max_nbor_size and env_mat_range) of the training data. + + Parameters + ---------- + min_nbor_dist + The nearest distance between atoms + table_extrapolate + The scale of model extrapolation + table_stride_1 + The uniform stride of the first table + table_stride_2 + The uniform stride of the second table + check_frequency + The overflow check frequency + """ + raise NotImplementedError("Compression is unsupported for DPA3.") diff --git a/deepmd/pt/model/descriptor/repflow_layer.py b/deepmd/pt/model/descriptor/repflow_layer.py new file mode 100644 index 0000000000..2395986366 --- /dev/null +++ b/deepmd/pt/model/descriptor/repflow_layer.py @@ -0,0 +1,913 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Optional, + Union, +) + +import torch +import torch.nn as nn + +from deepmd.dpmodel.utils.seed import ( + child_seed, +) +from deepmd.pt.model.network.init import ( + constant_, + normal_, +) +from deepmd.pt.model.network.mlp import ( + MLPLayer, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + PRECISION_DICT, +) +from deepmd.pt.utils.utils import ( + ActivationFn, + get_generator, + to_numpy_array, + to_torch_tensor, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + + +def get_residual( + _dim: int, + _scale: float, + _mode: str = "norm", + trainable: bool = True, + precision: str = "float64", + seed: Optional[Union[int, list[int]]] = None, +) -> torch.Tensor: + r""" + Get residual tensor for one update vector. + + Parameters + ---------- + _dim : int + The dimension of the update vector. + _scale + The initial scale of the residual tensor. See `_mode` for details. + _mode + The mode of residual initialization for the residual tensor. + - "norm" (default): init residual using normal with `_scale` std. + - "const": init residual using element-wise constants of `_scale`. + trainable + Whether the residual tensor is trainable. + precision + The precision of the residual tensor. + seed : int, optional + Random seed for parameter initialization. + """ + random_generator = get_generator(seed) + residual = nn.Parameter( + data=torch.zeros(_dim, dtype=PRECISION_DICT[precision], device=env.DEVICE), + requires_grad=trainable, + ) + if _mode == "norm": + normal_(residual.data, std=_scale, generator=random_generator) + elif _mode == "const": + constant_(residual.data, val=_scale) + else: + raise RuntimeError(f"Unsupported initialization mode '{_mode}'!") + return residual + + +# common ops +def _make_nei_g1( + g1_ext: torch.Tensor, + nlist: torch.Tensor, +) -> torch.Tensor: + """ + Make neighbor-wise atomic invariant rep. + + Parameters + ---------- + g1_ext + Extended atomic invariant rep, with shape nb x nall x ng1. + nlist + Neighbor list, with shape nb x nloc x nnei. + + Returns + ------- + gg1: torch.Tensor + Neighbor-wise atomic invariant rep, with shape nb x nloc x nnei x ng1. + + """ + # nlist: nb x nloc x nnei + nb, nloc, nnei = nlist.shape + # g1_ext: nb x nall x ng1 + ng1 = g1_ext.shape[-1] + # index: nb x (nloc x nnei) x ng1 + index = nlist.reshape(nb, nloc * nnei).unsqueeze(-1).expand(-1, -1, ng1) + # gg1 : nb x (nloc x nnei) x ng1 + gg1 = torch.gather(g1_ext, dim=1, index=index) + # gg1 : nb x nloc x nnei x ng1 + gg1 = gg1.view(nb, nloc, nnei, ng1) + return gg1 + + +def _apply_nlist_mask( + gg: torch.Tensor, + nlist_mask: torch.Tensor, +) -> torch.Tensor: + """ + Apply nlist mask to neighbor-wise rep tensors. + + Parameters + ---------- + gg + Neighbor-wise rep tensors, with shape nf x nloc x nnei x d. + nlist_mask + Neighbor list mask, where zero means no neighbor, with shape nf x nloc x nnei. + """ + # gg: nf x nloc x nnei x d + # msk: nf x nloc x nnei + return gg.masked_fill(~nlist_mask.unsqueeze(-1), 0.0) + + +def _apply_switch(gg: torch.Tensor, sw: torch.Tensor) -> torch.Tensor: + """ + Apply switch function to neighbor-wise rep tensors. + + Parameters + ---------- + gg + Neighbor-wise rep tensors, with shape nf x nloc x nnei x d. + sw + The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, + and remains 0 beyond rcut, with shape nf x nloc x nnei. + """ + # gg: nf x nloc x nnei x d + # sw: nf x nloc x nnei + return gg * sw.unsqueeze(-1) + + +class RepFlowLayer(torch.nn.Module): + def __init__( + self, + e_rcut: float, + e_rcut_smth: float, + e_sel: int, + a_rcut: float, + a_rcut_smth: float, + a_sel: int, + ntypes: int, + n_dim: int = 128, + e_dim: int = 16, + a_dim: int = 64, + axis_neuron: int = 4, + update_angle: bool = True, # angle + update_g1_has_conv: bool = True, + activation_function: str = "silu", + update_style: str = "res_avg", + update_residual: float = 0.001, + update_residual_init: str = "norm", + precision: str = "float64", + seed: Optional[Union[int, list[int]]] = None, + ) -> None: + super().__init__() + self.epsilon = 1e-4 # protection of 1./nnei + self.e_rcut = float(e_rcut) + self.e_rcut_smth = float(e_rcut_smth) + self.ntypes = ntypes + e_sel = [e_sel] if isinstance(e_sel, int) else e_sel + self.nnei = sum(e_sel) + assert len(e_sel) == 1 + self.e_sel = e_sel + self.sec = self.e_sel + self.a_rcut = a_rcut + self.a_rcut_smth = a_rcut_smth + self.a_sel = a_sel + self.n_dim = n_dim + self.e_dim = e_dim + self.a_dim = a_dim + self.axis_neuron = axis_neuron + self.update_angle = update_angle + self.activation_function = activation_function + self.act = ActivationFn(activation_function) + self.update_style = update_style + self.update_residual = update_residual + self.update_residual_init = update_residual_init + self.precision = precision + self.seed = seed + self.prec = PRECISION_DICT[precision] + + self.update_g1_has_conv = update_g1_has_conv + + assert update_residual_init in [ + "norm", + "const", + ], "'update_residual_init' only support 'norm' or 'const'!" + + self.update_residual = update_residual + self.update_residual_init = update_residual_init + self.g1_residual = [] + self.g2_residual = [] + self.h2_residual = [] + self.a_residual = [] + self.proj_g1g2 = None + self.edge_info_dim = self.n_dim * 2 + self.e_dim + + # g1 self mlp + self.node_self_mlp = MLPLayer( + n_dim, + n_dim, + precision=precision, + seed=child_seed(seed, 15), + ) + if self.update_style == "res_residual": + self.g1_residual.append( + get_residual( + n_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + seed=child_seed(seed, 16), + ) + ) + + # g1 conv # tmp + if self.update_g1_has_conv: + self.proj_g1g2 = MLPLayer( + e_dim, + n_dim, + bias=False, + precision=precision, + seed=child_seed(seed, 4), + ) + if self.update_style == "res_residual": + self.g1_residual.append( + get_residual( + n_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + seed=child_seed(seed, 17), + ) + ) + + # g1 sym + self.g1_sym_dim = self.cal_1_dim(n_dim, e_dim, self.axis_neuron) + self.linear1 = MLPLayer( + self.g1_sym_dim, + n_dim, + precision=precision, + seed=child_seed(seed, 1), + ) + if self.update_style == "res_residual": + self.g1_residual.append( + get_residual( + n_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + seed=child_seed(seed, 0), + ) + ) + + # g1 edge + self.g1_edge_linear1 = MLPLayer( + self.edge_info_dim, + n_dim, + precision=precision, + seed=child_seed(seed, 11), + ) # need act + if self.update_style == "res_residual": + self.g1_residual.append( + get_residual( + n_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + seed=child_seed(seed, 13), + ) + ) + + # g2 edge + self.linear2 = MLPLayer( + self.edge_info_dim, + e_dim, + precision=precision, + seed=child_seed(seed, 2), + ) + if self.update_style == "res_residual": + self.g2_residual.append( + get_residual( + e_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + seed=child_seed(seed, 3), + ) + ) + + if self.update_angle: + angle_seed = 20 + self.angle_dim = self.a_dim + self.n_dim + 2 * self.e_dim + self.angle_linear = MLPLayer( + self.angle_dim, + self.a_dim, + precision=precision, + seed=child_seed(seed, angle_seed + 1), + ) # need act + if self.update_style == "res_residual": + self.a_residual.append( + get_residual( + self.a_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + seed=child_seed(seed, angle_seed + 2), + ) + ) + + self.g2_angle_linear1 = MLPLayer( + self.angle_dim, + self.e_dim, + precision=precision, + seed=child_seed(seed, angle_seed + 3), + ) # need act + self.g2_angle_linear2 = MLPLayer( + self.e_dim, + self.e_dim, + precision=precision, + seed=child_seed(seed, angle_seed + 4), + ) + if self.update_style == "res_residual": + self.g2_residual.append( + get_residual( + self.e_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + seed=child_seed(seed, angle_seed + 5), + ) + ) + + else: + self.angle_linear = None + self.g2_angle_linear1 = None + self.g2_angle_linear2 = None + self.angle_dim = 0 + self.angle_dim = 0 + + self.g1_residual = nn.ParameterList(self.g1_residual) + self.g2_residual = nn.ParameterList(self.g2_residual) + self.h2_residual = nn.ParameterList(self.h2_residual) + self.a_residual = nn.ParameterList(self.a_residual) + + def cal_1_dim(self, g1d: int, g2d: int, ax: int) -> int: + ret = g2d * ax + g1d * ax + return ret + + def _update_g1_conv( + self, + gg1: torch.Tensor, + g2: torch.Tensor, + nlist_mask: torch.Tensor, + sw: torch.Tensor, + ) -> torch.Tensor: + """ + Calculate the convolution update for atomic invariant rep. + + Parameters + ---------- + gg1 + Neighbor-wise atomic invariant rep, with shape nb x nloc x nnei x ng1. + g2 + Pair invariant rep, with shape nb x nloc x nnei x ng2. + nlist_mask + Neighbor list mask, where zero means no neighbor, with shape nb x nloc x nnei. + sw + The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, + and remains 0 beyond rcut, with shape nb x nloc x nnei. + """ + assert self.proj_g1g2 is not None + nb, nloc, nnei, _ = g2.shape + ng1 = gg1.shape[-1] + ng2 = g2.shape[-1] + gg1 = gg1.view(nb, nloc, nnei, ng1) + # nb x nloc x nnei x ng2/ng1 + gg1 = _apply_nlist_mask(gg1, nlist_mask) + gg1 = _apply_switch(gg1, sw) + invnnei = (1.0 / float(nnei)) * torch.ones( + (nb, nloc, 1), dtype=gg1.dtype, device=gg1.device + ) + g2 = self.proj_g1g2(g2).view(nb, nloc, nnei, ng1) + # nb x nloc x ng1 + g1_11 = torch.sum(g2 * gg1, dim=2) * invnnei + return g1_11 + + @staticmethod + def _cal_hg( + g2: torch.Tensor, + h2: torch.Tensor, + nlist_mask: torch.Tensor, + sw: torch.Tensor, + smooth: bool = True, + epsilon: float = 1e-4, + use_sqrt_nnei: bool = True, + ) -> torch.Tensor: + """ + Calculate the transposed rotation matrix. + + Parameters + ---------- + g2 + Neighbor-wise/Pair-wise invariant rep tensors, with shape nb x nloc x nnei x ng2. + h2 + Neighbor-wise/Pair-wise equivariant rep tensors, with shape nb x nloc x nnei x 3. + nlist_mask + Neighbor list mask, where zero means no neighbor, with shape nb x nloc x nnei. + sw + The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, + and remains 0 beyond rcut, with shape nb x nloc x nnei. + smooth + Whether to use smoothness in processes such as attention weights calculation. + epsilon + Protection of 1./nnei. + + Returns + ------- + hg + The transposed rotation matrix, with shape nb x nloc x 3 x ng2. + """ + # g2: nb x nloc x nnei x ng2 + # h2: nb x nloc x nnei x 3 + # msk: nb x nloc x nnei + nb, nloc, nnei, _ = g2.shape + ng2 = g2.shape[-1] + # nb x nloc x nnei x ng2 + g2 = _apply_nlist_mask(g2, nlist_mask) + if not smooth: + # nb x nloc + # must use type_as here to convert bool to float, otherwise there will be numerical difference from numpy + if not use_sqrt_nnei: + invnnei = 1.0 / (epsilon + torch.sum(nlist_mask.type_as(g2), dim=-1)) + else: + invnnei = 1.0 / ( + epsilon + torch.sqrt(torch.sum(nlist_mask.type_as(g2), dim=-1)) + ) + # nb x nloc x 1 x 1 + invnnei = invnnei.unsqueeze(-1).unsqueeze(-1) + else: + g2 = _apply_switch(g2, sw) + if not use_sqrt_nnei: + invnnei = (1.0 / float(nnei)) * torch.ones( + (nb, nloc, 1, 1), dtype=g2.dtype, device=g2.device + ) + else: + invnnei = torch.rsqrt( + float(nnei) + * torch.ones((nb, nloc, 1, 1), dtype=g2.dtype, device=g2.device) + ) + # nb x nloc x 3 x ng2 + h2g2 = torch.matmul(torch.transpose(h2, -1, -2), g2) * invnnei + return h2g2 + + @staticmethod + def _cal_grrg(h2g2: torch.Tensor, axis_neuron: int) -> torch.Tensor: + """ + Calculate the atomic invariant rep. + + Parameters + ---------- + h2g2 + The transposed rotation matrix, with shape nb x nloc x 3 x ng2. + axis_neuron + Size of the submatrix. + + Returns + ------- + grrg + Atomic invariant rep, with shape nb x nloc x (axis_neuron x ng2) + """ + # nb x nloc x 3 x ng2 + nb, nloc, _, ng2 = h2g2.shape + # nb x nloc x 3 x axis + h2g2m = torch.split(h2g2, axis_neuron, dim=-1)[0] + # nb x nloc x axis x ng2 + g1_13 = torch.matmul(torch.transpose(h2g2m, -1, -2), h2g2) / (3.0**1) + # nb x nloc x (axisxng2) + g1_13 = g1_13.view(nb, nloc, axis_neuron * ng2) + return g1_13 + + def symmetrization_op( + self, + g2: torch.Tensor, + h2: torch.Tensor, + nlist_mask: torch.Tensor, + sw: torch.Tensor, + axis_neuron: int, + smooth: bool = True, + epsilon: float = 1e-4, + ) -> torch.Tensor: + """ + Symmetrization operator to obtain atomic invariant rep. + + Parameters + ---------- + g2 + Neighbor-wise/Pair-wise invariant rep tensors, with shape nb x nloc x nnei x ng2. + h2 + Neighbor-wise/Pair-wise equivariant rep tensors, with shape nb x nloc x nnei x 3. + nlist_mask + Neighbor list mask, where zero means no neighbor, with shape nb x nloc x nnei. + sw + The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, + and remains 0 beyond rcut, with shape nb x nloc x nnei. + axis_neuron + Size of the submatrix. + smooth + Whether to use smoothness in processes such as attention weights calculation. + epsilon + Protection of 1./nnei. + + Returns + ------- + grrg + Atomic invariant rep, with shape nb x nloc x (axis_neuron x ng2) + """ + # g2: nb x nloc x nnei x ng2 + # h2: nb x nloc x nnei x 3 + # msk: nb x nloc x nnei + nb, nloc, nnei, _ = g2.shape + # nb x nloc x 3 x ng2 + h2g2 = self._cal_hg( + g2, + h2, + nlist_mask, + sw, + smooth=smooth, + epsilon=epsilon, + use_sqrt_nnei=True, + ) + # nb x nloc x (axisxng2) + g1_13 = self._cal_grrg(h2g2, axis_neuron) + return g1_13 + + def forward( + self, + g1_ext: torch.Tensor, # nf x nall x ng1 + g2: torch.Tensor, # nf x nloc x nnei x ng2 + h2: torch.Tensor, # nf x nloc x nnei x 3 + angle_embed: torch.Tensor, # nf x nloc x a_nnei x a_nnei x a_dim + nlist: torch.Tensor, # nf x nloc x nnei + nlist_mask: torch.Tensor, # nf x nloc x nnei + sw: torch.Tensor, # switch func, nf x nloc x nnei + angle_nlist: torch.Tensor, # nf x nloc x a_nnei + angle_nlist_mask: torch.Tensor, # nf x nloc x a_nnei + angle_sw: torch.Tensor, # switch func, nf x nloc x a_nnei + ): + """ + Parameters + ---------- + g1_ext : nf x nall x ng1 extended single-atom channel + g2 : nf x nloc x nnei x ng2 pair-atom channel, invariant + h2 : nf x nloc x nnei x 3 pair-atom channel, equivariant + nlist : nf x nloc x nnei neighbor list (padded neis are set to 0) + nlist_mask : nf x nloc x nnei masks of the neighbor list. real nei 1 otherwise 0 + sw : nf x nloc x nnei switch function + + Returns + ------- + g1: nf x nloc x ng1 updated single-atom channel + g2: nf x nloc x nnei x ng2 updated pair-atom channel, invariant + h2: nf x nloc x nnei x 3 updated pair-atom channel, equivariant + """ + nb, nloc, nnei, _ = g2.shape + nall = g1_ext.shape[1] + g1, _ = torch.split(g1_ext, [nloc, nall - nloc], dim=1) + assert (nb, nloc) == g1.shape[:2] + assert (nb, nloc, nnei) == h2.shape[:3] + + g1_update: list[torch.Tensor] = [g1] + g2_update: list[torch.Tensor] = [g2] + a_update: list[torch.Tensor] = [angle_embed] + h2_update: list[torch.Tensor] = [h2] + + g1_sym: list[torch.Tensor] = [] + + # g1 self mlp + node_self_mlp = self.act(self.node_self_mlp(g1)) + g1_update.append(node_self_mlp) + + gg1 = _make_nei_g1(g1_ext, nlist) + # g1 conv # tmp + if self.update_g1_has_conv: + assert gg1 is not None + g1_conv = self._update_g1_conv(gg1, g2, nlist_mask, sw) + g1_update.append(g1_conv) + + # g1 sym mlp + g1_sym.append( + self.symmetrization_op( + g2, + h2, + nlist_mask, + sw, + self.axis_neuron, + smooth=True, + epsilon=self.epsilon, + ) + ) + g1_sym.append( + self.symmetrization_op( + gg1, + h2, + nlist_mask, + sw, + self.axis_neuron, + smooth=True, + epsilon=self.epsilon, + ) + ) + g1_1 = self.act(self.linear1(torch.cat(g1_sym, dim=-1))) + g1_update.append(g1_1) + + edge_info = torch.cat( + [torch.tile(g1.unsqueeze(-2), [1, 1, self.nnei, 1]), gg1, g2], dim=-1 + ) + + # g1 edge update + # nb x nloc x nnei x ng1 + g1_edge_info = self.act(self.g1_edge_linear1(edge_info)) * sw.unsqueeze(-1) + g1_edge_update = torch.sum(g1_edge_info, dim=-2) / self.nnei + g1_update.append(g1_edge_update) + # update g1 + g1_new = self.list_update(g1_update, "g1") + + # g2 edge update + g2_edge_info = self.act(self.linear2(edge_info)) + g2_update.append(g2_edge_info) + + if self.update_angle: + assert self.angle_linear is not None + assert self.g2_angle_linear1 is not None + assert self.g2_angle_linear2 is not None + # nb x nloc x a_nnei x a_nnei x g1 + g1_angle_embed = torch.tile( + g1.unsqueeze(2).unsqueeze(2), (1, 1, self.a_sel, self.a_sel, 1) + ) + # nb x nloc x a_nnei x g2 + g2_angle = g2[:, :, : self.a_sel, :] + # nb x nloc x a_nnei x g2 + g2_angle = torch.where(angle_nlist_mask.unsqueeze(-1), g2_angle, 0.0) + # nb x nloc x (a_nnei) x a_nnei x g2 + g2_angle_i = torch.tile(g2_angle.unsqueeze(2), (1, 1, self.a_sel, 1, 1)) + # nb x nloc x a_nnei x (a_nnei) x g2 + g2_angle_j = torch.tile(g2_angle.unsqueeze(3), (1, 1, 1, self.a_sel, 1)) + # nb x nloc x a_nnei x a_nnei x (g2 + g2) + g2_angle_embed = torch.cat([g2_angle_i, g2_angle_j], dim=-1) + + # angle for g2: + updated_g2_angle_list = [angle_embed] + # nb x nloc x a_nnei x a_nnei x (a + g1 + g2*2) + updated_g2_angle_list += [g1_angle_embed, g2_angle_embed] + updated_g2_angle = torch.cat(updated_g2_angle_list, dim=-1) + # nb x nloc x a_nnei x a_nnei x g2 + updated_angle_g2 = self.act(self.g2_angle_linear1(updated_g2_angle)) + # nb x nloc x a_nnei x a_nnei x g2 + weighted_updated_angle_g2 = ( + updated_angle_g2 + * angle_sw[:, :, :, None, None] + * angle_sw[:, :, None, :, None] + ) + # nb x nloc x a_nnei x g2 + reduced_updated_angle_g2 = torch.sum(weighted_updated_angle_g2, dim=-2) / ( + self.a_sel**0.5 + ) + # nb x nloc x nnei x g2 + padding_updated_angle_g2 = torch.concat( + [ + reduced_updated_angle_g2, + torch.zeros( + [nb, nloc, self.nnei - self.a_sel, self.e_dim], + dtype=g2.dtype, + device=g2.device, + ), + ], + dim=2, + ) + full_mask = torch.concat( + [ + angle_nlist_mask, + torch.zeros( + [nb, nloc, self.nnei - self.a_sel], + dtype=angle_nlist_mask.dtype, + device=angle_nlist_mask.device, + ), + ], + dim=-1, + ) + padding_updated_angle_g2 = torch.where( + full_mask.unsqueeze(-1), padding_updated_angle_g2, g2 + ) + g2_update.append(self.act(self.g2_angle_linear2(padding_updated_angle_g2))) + + # update g2 + g2_new = self.list_update(g2_update, "g2") + # angle for angle + updated_angle = updated_g2_angle + # nb x nloc x a_nnei x a_nnei x dim_a + angle_message = self.act(self.angle_linear(updated_angle)) + # angle update + a_update.append(angle_message) + else: + # update g2 + g2_new = self.list_update(g2_update, "g2") + + # update + h2_new = self.list_update(h2_update, "h2") + a_new = self.list_update(a_update, "a") + return g1_new, g2_new, h2_new, a_new + + @torch.jit.export + def list_update_res_avg( + self, + update_list: list[torch.Tensor], + ) -> torch.Tensor: + nitem = len(update_list) + uu = update_list[0] + for ii in range(1, nitem): + uu = uu + update_list[ii] + return uu / (float(nitem) ** 0.5) + + @torch.jit.export + def list_update_res_incr(self, update_list: list[torch.Tensor]) -> torch.Tensor: + nitem = len(update_list) + uu = update_list[0] + scale = 1.0 / (float(nitem - 1) ** 0.5) if nitem > 1 else 0.0 + for ii in range(1, nitem): + uu = uu + scale * update_list[ii] + return uu + + @torch.jit.export + def list_update_res_residual( + self, update_list: list[torch.Tensor], update_name: str = "g1" + ) -> torch.Tensor: + nitem = len(update_list) + uu = update_list[0] + # make jit happy + if update_name == "g1": + for ii, vv in enumerate(self.g1_residual): + uu = uu + vv * update_list[ii + 1] + elif update_name == "g2": + for ii, vv in enumerate(self.g2_residual): + uu = uu + vv * update_list[ii + 1] + elif update_name == "h2": + for ii, vv in enumerate(self.h2_residual): + uu = uu + vv * update_list[ii + 1] + elif update_name == "a": + for ii, vv in enumerate(self.a_residual): + uu = uu + vv * update_list[ii + 1] + else: + raise NotImplementedError + return uu + + @torch.jit.export + def list_update( + self, update_list: list[torch.Tensor], update_name: str = "g1" + ) -> torch.Tensor: + if self.update_style == "res_avg": + return self.list_update_res_avg(update_list) + elif self.update_style == "res_incr": + return self.list_update_res_incr(update_list) + elif self.update_style == "res_residual": + return self.list_update_res_residual(update_list, update_name=update_name) + else: + raise RuntimeError(f"unknown update style {self.update_style}") + + def serialize(self) -> dict: + """Serialize the networks to a dict. + + Returns + ------- + dict + The serialized networks. + """ + data = { + "@class": "RepformerLayer", + "@version": 1, + "e_rcut": self.e_rcut, + "e_rcut_smth": self.e_rcut_smth, + "e_sel": self.e_sel, + "ntypes": self.ntypes, + "n_dim": self.n_dim, + "e_dim": self.e_dim, + "axis_neuron": self.axis_neuron, + "activation_function": self.activation_function, + "update_style": self.update_style, + "precision": self.precision, + "linear1": self.linear1.serialize(), + } + if self.update_g1_has_conv: + data.update( + { + "proj_g1g2": self.proj_g1g2.serialize(), + } + ) + + if self.update_g2_has_attn or self.update_h2: + data.update( + { + "attn2g_map": self.attn2g_map.serialize(), + } + ) + if self.update_g2_has_attn: + data.update( + { + "attn2_mh_apply": self.attn2_mh_apply.serialize(), + "attn2_lm": self.attn2_lm.serialize(), + } + ) + + if self.update_h2: + data.update( + { + "attn2_ev_apply": self.attn2_ev_apply.serialize(), + } + ) + if self.update_g1_has_attn: + data.update( + { + "loc_attn": self.loc_attn.serialize(), + } + ) + if self.g1_out_mlp: + data.update( + { + "node_self_mlp": self.node_self_mlp.serialize(), + } + ) + if self.update_style == "res_residual": + data.update( + { + "@variables": { + "g1_residual": [to_numpy_array(t) for t in self.g1_residual], + "g2_residual": [to_numpy_array(t) for t in self.g2_residual], + "h2_residual": [to_numpy_array(t) for t in self.h2_residual], + } + } + ) + return data + + @classmethod + def deserialize(cls, data: dict) -> "RepFlowLayer": + """Deserialize the networks from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + data = data.copy() + check_version_compatibility(data.pop("@version"), 2, 1) + data.pop("@class") + linear1 = data.pop("linear1") + update_chnnl_2 = data["update_chnnl_2"] + update_g1_has_conv = data["update_g1_has_conv"] + update_g2_has_g1g1 = data["update_g2_has_g1g1"] + update_g2_has_attn = data["update_g2_has_attn"] + update_h2 = data["update_h2"] + update_g1_has_attn = data["update_g1_has_attn"] + update_style = data["update_style"] + g1_out_mlp = data["g1_out_mlp"] + + linear2 = data.pop("linear2", None) + proj_g1g2 = data.pop("proj_g1g2", None) + attn2g_map = data.pop("attn2g_map", None) + attn2_mh_apply = data.pop("attn2_mh_apply", None) + attn2_lm = data.pop("attn2_lm", None) + attn2_ev_apply = data.pop("attn2_ev_apply", None) + loc_attn = data.pop("loc_attn", None) + node_self_mlp = data.pop("node_self_mlp", None) + variables = data.pop("@variables", {}) + g1_residual = variables.get("g1_residual", data.pop("g1_residual", [])) + g2_residual = variables.get("g2_residual", data.pop("g2_residual", [])) + h2_residual = variables.get("h2_residual", data.pop("h2_residual", [])) + + obj = cls(**data) + obj.linear1 = MLPLayer.deserialize(linear1) + if update_chnnl_2: + assert isinstance(linear2, dict) + obj.linear2 = MLPLayer.deserialize(linear2) + if update_g1_has_conv: + assert isinstance(proj_g1g2, dict) + obj.proj_g1g2 = MLPLayer.deserialize(proj_g1g2) + + if g1_out_mlp: + assert isinstance(node_self_mlp, dict) + obj.node_self_mlp = MLPLayer.deserialize(node_self_mlp) + if update_style == "res_residual": + for ii, t in enumerate(obj.g1_residual): + t.data = to_torch_tensor(g1_residual[ii]) + for ii, t in enumerate(obj.g2_residual): + t.data = to_torch_tensor(g2_residual[ii]) + for ii, t in enumerate(obj.h2_residual): + t.data = to_torch_tensor(h2_residual[ii]) + return obj diff --git a/deepmd/pt/model/descriptor/repflows.py b/deepmd/pt/model/descriptor/repflows.py new file mode 100644 index 0000000000..c5a39f8b3b --- /dev/null +++ b/deepmd/pt/model/descriptor/repflows.py @@ -0,0 +1,570 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Callable, + Optional, + Union, +) + +import torch + +from deepmd.dpmodel.utils.seed import ( + child_seed, +) +from deepmd.pt.model.descriptor.descriptor import ( + DescriptorBlock, +) +from deepmd.pt.model.descriptor.env_mat import ( + prod_env_mat, +) +from deepmd.pt.model.network.mlp import ( + MLPLayer, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + PRECISION_DICT, +) +from deepmd.pt.utils.env_mat_stat import ( + EnvMatStatSe, +) +from deepmd.pt.utils.exclude_mask import ( + PairExcludeMask, +) +from deepmd.pt.utils.spin import ( + concat_switch_virtual, +) +from deepmd.pt.utils.utils import ( + ActivationFn, +) +from deepmd.utils.env_mat_stat import ( + StatItem, +) +from deepmd.utils.path import ( + DPPath, +) + +from .repflow_layer import ( + RepFlowLayer, +) + +if not hasattr(torch.ops.deepmd, "border_op"): + + def border_op( + argument0, + argument1, + argument2, + argument3, + argument4, + argument5, + argument6, + argument7, + argument8, + ) -> torch.Tensor: + raise NotImplementedError( + "border_op is not available since customized PyTorch OP library is not built when freezing the model. " + "See documentation for DPA-3 for details." + ) + + # Note: this hack cannot actually save a model that can be run using LAMMPS. + torch.ops.deepmd.border_op = border_op + + +@DescriptorBlock.register("se_repflow") +class DescrptBlockRepflows(DescriptorBlock): + def __init__( + self, + e_rcut, + e_rcut_smth, + e_sel: int, + a_rcut, + a_rcut_smth, + a_sel: int, + ntypes: int, + nlayers: int = 6, + n_dim: int = 128, + e_dim: int = 64, + a_dim: int = 64, + axis_neuron: int = 4, + node_has_conv: bool = False, + update_angle: bool = True, + activation_function: str = "silu", + update_style: str = "res_residual", + update_residual: float = 0.1, + update_residual_init: str = "const", + set_davg_zero: bool = True, + exclude_types: list[tuple[int, int]] = [], + env_protection: float = 0.0, + precision: str = "float64", + seed: Optional[Union[int, list[int]]] = None, + ) -> None: + r""" + The repflow descriptor block. + + Parameters + ---------- + n_dim : int, optional + The dimension of node representation. + e_dim : int, optional + The dimension of edge representation. + a_dim : int, optional + The dimension of angle representation. + nlayers : int, optional + Number of repflow layers. + e_rcut : float, optional + The edge cut-off radius. + e_rcut_smth : float, optional + Where to start smoothing for edge. For example the 1/r term is smoothed from rcut to rcut_smth. + e_sel : int, optional + Maximally possible number of selected edge neighbors. + a_rcut : float, optional + The angle cut-off radius. + a_rcut_smth : float, optional + Where to start smoothing for angle. For example the 1/r term is smoothed from rcut to rcut_smth. + a_sel : int, optional + Maximally possible number of selected angle neighbors. + axis_neuron : int, optional + The number of dimension of submatrix in the symmetrization ops. + update_angle : bool, optional + Where to update the angle rep. If not, only node and edge rep will be used. + update_style : str, optional + Style to update a representation. + Supported options are: + -'res_avg': Updates a rep `u` with: u = 1/\\sqrt{n+1} (u + u_1 + u_2 + ... + u_n) + -'res_incr': Updates a rep `u` with: u = u + 1/\\sqrt{n} (u_1 + u_2 + ... + u_n) + -'res_residual': Updates a rep `u` with: u = u + (r1*u_1 + r2*u_2 + ... + r3*u_n) + where `r1`, `r2` ... `r3` are residual weights defined by `update_residual` + and `update_residual_init`. + update_residual : float, optional + When update using residual mode, the initial std of residual vector weights. + update_residual_init : str, optional + When update using residual mode, the initialization mode of residual vector weights. + ntypes : int + Number of element types + activation_function : str, optional + The activation function in the embedding net. + set_davg_zero : bool, optional + Set the normalization average to zero. + precision : str, optional + The precision of the embedding net parameters. + exclude_types : list[list[int]], optional + The excluded pairs of types which have no interaction with each other. + For example, `[[0, 1]]` means no interaction between type 0 and type 1. + env_protection : float, optional + Protection parameter to prevent division by zero errors during environment matrix calculations. + For example, when using paddings, there may be zero distances of neighbors, which may make division by zero error during environment matrix calculations without protection. + seed : int, optional + Random seed for parameter initialization. + """ + super().__init__() + self.e_rcut = float(e_rcut) + self.e_rcut_smth = float(e_rcut_smth) + self.e_sel = e_sel + self.a_rcut = float(a_rcut) + self.a_rcut_smth = float(a_rcut_smth) + self.a_sel = a_sel + self.ntypes = ntypes + self.nlayers = nlayers + # for other common desciptor method + sel = [e_sel] if isinstance(e_sel, int) else e_sel + self.nnei = sum(sel) + self.ndescrpt = self.nnei * 4 # use full descriptor. + assert len(sel) == 1 + self.sel = sel + self.rcut = e_rcut + self.rcut_smth = e_rcut_smth + self.sec = self.sel + self.split_sel = self.sel + self.axis_neuron = axis_neuron + self.set_davg_zero = set_davg_zero + + self.n_dim = n_dim + self.e_dim = e_dim + self.a_dim = a_dim + self.update_angle = update_angle + self.node_has_conv = node_has_conv + + self.activation_function = activation_function + self.update_style = update_style + self.update_residual = update_residual + self.update_residual_init = update_residual_init + self.act = ActivationFn(activation_function) + self.prec = PRECISION_DICT[precision] + self.angle_embedding = torch.nn.Linear( + in_features=1, + out_features=self.a_dim, + bias=False, + dtype=self.prec, + ) + + # order matters, placed after the assignment of self.ntypes + self.reinit_exclude(exclude_types) + self.env_protection = env_protection + self.precision = precision + self.epsilon = 1e-4 + self.seed = seed + + self.edge_embd = MLPLayer( + 1, self.e_dim, precision=precision, seed=child_seed(seed, 0) + ) + layers = [] + for ii in range(nlayers): + layers.append( + RepFlowLayer( + e_rcut=self.e_rcut, + e_rcut_smth=self.e_rcut_smth, + e_sel=self.sel, + a_rcut=self.a_rcut, + a_rcut_smth=self.a_rcut_smth, + a_sel=self.a_sel, + ntypes=self.ntypes, + n_dim=self.n_dim, + e_dim=self.e_dim, + a_dim=self.a_dim, + axis_neuron=self.axis_neuron, + update_g1_has_conv=self.node_has_conv, # tmp + update_angle=self.update_angle, + activation_function=self.activation_function, + update_style=self.update_style, + update_residual=self.update_residual, + update_residual_init=self.update_residual_init, + precision=precision, + seed=child_seed(child_seed(seed, 1), ii), + ) + ) + self.layers = torch.nn.ModuleList(layers) + + wanted_shape = (self.ntypes, self.nnei, 4) + mean = torch.zeros(wanted_shape, dtype=self.prec, device=env.DEVICE) + stddev = torch.ones(wanted_shape, dtype=self.prec, device=env.DEVICE) + self.register_buffer("mean", mean) + self.register_buffer("stddev", stddev) + self.stats = None + + def get_rcut(self) -> float: + """Returns the cut-off radius.""" + return self.e_rcut + + def get_rcut_smth(self) -> float: + """Returns the radius where the neighbor information starts to smoothly decay to 0.""" + return self.e_rcut_smth + + def get_nsel(self) -> int: + """Returns the number of selected atoms in the cut-off radius.""" + return sum(self.sel) + + def get_sel(self) -> list[int]: + """Returns the number of selected atoms for each type.""" + return self.sel + + def get_ntypes(self) -> int: + """Returns the number of element types.""" + return self.ntypes + + def get_dim_out(self) -> int: + """Returns the output dimension.""" + return self.dim_out + + def get_dim_in(self) -> int: + """Returns the input dimension.""" + return self.dim_in + + def get_dim_emb(self) -> int: + """Returns the embedding dimension g2.""" + return self.e_dim + + def __setitem__(self, key, value) -> None: + if key in ("avg", "data_avg", "davg"): + self.mean = value + elif key in ("std", "data_std", "dstd"): + self.stddev = value + else: + raise KeyError(key) + + def __getitem__(self, key): + if key in ("avg", "data_avg", "davg"): + return self.mean + elif key in ("std", "data_std", "dstd"): + return self.stddev + else: + raise KeyError(key) + + def mixed_types(self) -> bool: + """If true, the descriptor + 1. assumes total number of atoms aligned across frames; + 2. requires a neighbor list that does not distinguish different atomic types. + + If false, the descriptor + 1. assumes total number of atoms of each atom type aligned across frames; + 2. requires a neighbor list that distinguishes different atomic types. + + """ + return True + + def get_env_protection(self) -> float: + """Returns the protection of building environment matrix.""" + return self.env_protection + + @property + def dim_out(self): + """Returns the output dimension of this descriptor.""" + return self.n_dim + + @property + def dim_in(self): + """Returns the atomic input dimension of this descriptor.""" + return self.n_dim + + @property + def dim_emb(self): + """Returns the embedding dimension g2.""" + return self.get_dim_emb() + + def reinit_exclude( + self, + exclude_types: list[tuple[int, int]] = [], + ) -> None: + self.exclude_types = exclude_types + self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types) + + def forward( + self, + nlist: torch.Tensor, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + extended_atype_embd: Optional[torch.Tensor] = None, + mapping: Optional[torch.Tensor] = None, + comm_dict: Optional[dict[str, torch.Tensor]] = None, + ): + if comm_dict is None: + assert mapping is not None + assert extended_atype_embd is not None + nframes, nloc, nnei = nlist.shape + nall = extended_coord.view(nframes, -1).shape[1] // 3 + atype = extended_atype[:, :nloc] + # nb x nloc x nnei + exclude_mask = self.emask(nlist, extended_atype) + nlist = torch.where(exclude_mask != 0, nlist, -1) + # nb x nloc x nnei x 4, nb x nloc x nnei x 3, nb x nloc x nnei x 1 + dmatrix, diff, sw = prod_env_mat( + extended_coord, + nlist, + atype, + self.mean, + self.stddev, + self.e_rcut, + self.e_rcut_smth, + protection=self.env_protection, + ) + nlist_mask = nlist != -1 + sw = torch.squeeze(sw, -1) + # beyond the cutoff sw should be 0.0 + sw = sw.masked_fill(~nlist_mask, 0.0) + + # [nframes, nloc, tebd_dim] + if comm_dict is None: + assert isinstance(extended_atype_embd, torch.Tensor) # for jit + atype_embd = extended_atype_embd[:, :nloc, :] + assert list(atype_embd.shape) == [nframes, nloc, self.n_dim] + else: + atype_embd = extended_atype_embd + assert isinstance(atype_embd, torch.Tensor) # for jit + g1 = self.act(atype_embd) + ng1 = g1.shape[-1] + # nb x nloc x nnei x 1, nb x nloc x nnei x 3 + g2, h2 = torch.split(dmatrix, [1, 3], dim=-1) + # nb x nloc x nnei x ng2 + g2 = self.act(self.edge_embd(g2)) + + # get angle nlist (maybe smaller) + a_dist_mask = (torch.linalg.norm(diff, dim=-1) < self.a_rcut)[ + :, :, : self.a_sel + ] + angle_nlist = nlist[:, :, : self.a_sel] + angle_nlist = torch.where(a_dist_mask, angle_nlist, -1) + _, angle_diff, angle_sw = prod_env_mat( + extended_coord, + angle_nlist, + atype, + self.mean[:, : self.a_sel], + self.stddev[:, : self.a_sel], + self.a_rcut, + self.a_rcut_smth, + protection=self.env_protection, + ) + angle_nlist_mask = angle_nlist != -1 + angle_sw = torch.squeeze(angle_sw, -1) + # beyond the cutoff sw should be 0.0 + angle_sw = angle_sw.masked_fill(~angle_nlist_mask, 0.0) + angle_nlist[angle_nlist == -1] = 0 + + # nf x nloc x a_nnei x 3 + normalized_diff_i = angle_diff / ( + torch.linalg.norm(angle_diff, dim=-1, keepdim=True) + 1e-6 + ) + # nf x nloc x 3 x a_nnei + normalized_diff_j = torch.transpose(normalized_diff_i, 2, 3) + # nf x nloc x a_nnei x a_nnei + # 1 - 1e-6 for torch.acos stability + cosine_ij = torch.matmul(normalized_diff_i, normalized_diff_j) * (1 - 1e-6) + # nf x nloc x a_nnei x a_nnei x 1 + cosine_ij = cosine_ij.unsqueeze(-1) / (torch.pi**0.5) + # nf x nloc x a_nnei x a_nnei x a_dim + angle_embed = self.angle_embedding(cosine_ij).reshape( + nframes, nloc, self.a_sel, self.a_sel, self.a_dim + ) + + # set all padding positions to index of 0 + # if the a neighbor is real or not is indicated by nlist_mask + nlist[nlist == -1] = 0 + # nb x nall x ng1 + if comm_dict is None: + assert mapping is not None + mapping = ( + mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, self.n_dim) + ) + for idx, ll in enumerate(self.layers): + # g1: nb x nloc x ng1 + # g1_ext: nb x nall x ng1 + if comm_dict is None: + assert mapping is not None + g1_ext = torch.gather(g1, 1, mapping) + else: + has_spin = "has_spin" in comm_dict + if not has_spin: + n_padding = nall - nloc + g1 = torch.nn.functional.pad( + g1.squeeze(0), (0, 0, 0, n_padding), value=0.0 + ) + real_nloc = nloc + real_nall = nall + else: + # for spin + real_nloc = nloc // 2 + real_nall = nall // 2 + real_n_padding = real_nall - real_nloc + g1_real, g1_virtual = torch.split(g1, [real_nloc, real_nloc], dim=1) + # mix_g1: nb x real_nloc x (ng1 * 2) + mix_g1 = torch.cat([g1_real, g1_virtual], dim=2) + # nb x real_nall x (ng1 * 2) + g1 = torch.nn.functional.pad( + mix_g1.squeeze(0), (0, 0, 0, real_n_padding), value=0.0 + ) + + assert "send_list" in comm_dict + assert "send_proc" in comm_dict + assert "recv_proc" in comm_dict + assert "send_num" in comm_dict + assert "recv_num" in comm_dict + assert "communicator" in comm_dict + ret = torch.ops.deepmd.border_op( + comm_dict["send_list"], + comm_dict["send_proc"], + comm_dict["recv_proc"], + comm_dict["send_num"], + comm_dict["recv_num"], + g1, + comm_dict["communicator"], + torch.tensor( + real_nloc, + dtype=torch.int32, + device=env.DEVICE, + ), # should be int of c++ + torch.tensor( + real_nall - real_nloc, + dtype=torch.int32, + device=env.DEVICE, + ), # should be int of c++ + ) + g1_ext = ret[0].unsqueeze(0) + if has_spin: + g1_real_ext, g1_virtual_ext = torch.split(g1_ext, [ng1, ng1], dim=2) + g1_ext = concat_switch_virtual( + g1_real_ext, g1_virtual_ext, real_nloc + ) + g1, g2, h2, angle_embed = ll.forward( + g1_ext, + g2, + h2, + angle_embed, + nlist, + nlist_mask, + sw, + angle_nlist, + angle_nlist_mask, + angle_sw, + ) + + # nb x nloc x 3 x ng2 + h2g2 = RepFlowLayer._cal_hg( + g2, + h2, + nlist_mask, + sw, + smooth=True, + epsilon=self.epsilon, + use_sqrt_nnei=True, + ) + # (nb x nloc) x ng2 x 3 + rot_mat = torch.permute(h2g2, (0, 1, 3, 2)) + + return g1, g2, h2, rot_mat.view(nframes, nloc, self.dim_emb, 3), sw + + def compute_input_stats( + self, + merged: Union[Callable[[], list[dict]], list[dict]], + path: Optional[DPPath] = None, + ) -> None: + """ + Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data. + + Parameters + ---------- + merged : Union[Callable[[], list[dict]], list[dict]] + - list[dict]: A list of data samples from various data systems. + Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor` + originating from the `i`-th data system. + - Callable[[], list[dict]]: A lazy function that returns data samples in the above format + only when needed. Since the sampling process can be slow and memory-intensive, + the lazy function helps by only sampling once. + path : Optional[DPPath] + The path to the stat file. + + """ + env_mat_stat = EnvMatStatSe(self) + if path is not None: + path = path / env_mat_stat.get_hash() + if path is None or not path.is_dir(): + if callable(merged): + # only get data for once + sampled = merged() + else: + sampled = merged + else: + sampled = [] + env_mat_stat.load_or_compute_stats(sampled, path) + self.stats = env_mat_stat.stats + mean, stddev = env_mat_stat() + if not self.set_davg_zero: + self.mean.copy_( + torch.tensor(mean, device=env.DEVICE, dtype=self.mean.dtype) + ) + self.stddev.copy_( + torch.tensor(stddev, device=env.DEVICE, dtype=self.stddev.dtype) + ) + + def get_stats(self) -> dict[str, StatItem]: + """Get the statistics of the descriptor.""" + if self.stats is None: + raise RuntimeError( + "The statistics of the descriptor has not been computed." + ) + return self.stats + + def has_message_passing(self) -> bool: + """Returns whether the descriptor block has message passing.""" + return True + + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the descriptor block needs sorted nlist when using `forward_lower`.""" + return True diff --git a/deepmd/pt/utils/utils.py b/deepmd/pt/utils/utils.py index 6ce4f5d6fc..50d378455b 100644 --- a/deepmd/pt/utils/utils.py +++ b/deepmd/pt/utils/utils.py @@ -39,6 +39,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return F.softplus(x) elif self.activation.lower() == "sigmoid": return torch.sigmoid(x) + elif self.activation.lower() == "silu": + return F.silu(x) elif self.activation.lower() == "linear" or self.activation.lower() == "none": return x else: diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 5b57f15979..0ac084faf4 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -1355,6 +1355,165 @@ def dpa2_repformer_args(): ] +@descrpt_args_plugin.register("dpa3", doc=doc_only_pt_supported) +def descrpt_dpa3_args(): + # repflow args + doc_repflow = "The arguments used to initialize the repflow block." + # descriptor args + doc_concat_output_tebd = ( + "Whether to concat type embedding at the output of the descriptor." + ) + doc_activation_function = f"The activation function in the embedding net. Supported activation functions are {list_to_doc(ACTIVATION_FN_DICT.keys())}." + doc_precision = f"The precision of the embedding net parameters, supported options are {list_to_doc(PRECISION_DICT.keys())} Default follows the interface precision." + doc_exclude_types = "The excluded pairs of types which have no interaction with each other. For example, `[[0, 1]]` means no interaction between type 0 and type 1." + doc_env_protection = "Protection parameter to prevent division by zero errors during environment matrix calculations. For example, when using paddings, there may be zero distances of neighbors, which may make division by zero error during environment matrix calculations without protection." + doc_trainable = "If the parameters in the embedding net is trainable." + doc_seed = "Random seed for parameter initialization." + doc_use_econf_tebd = "Whether to use electronic configuration type embedding." + doc_use_tebd_bias = "Whether to use bias in the type embedding layer." + return [ + # doc_repflow args + Argument("repflow", dict, dpa3_repflow_args(), doc=doc_repflow), + # descriptor args + Argument( + "concat_output_tebd", + bool, + optional=True, + default=False, + doc=doc_concat_output_tebd, + ), + Argument( + "activation_function", + str, + optional=True, + default="silu", + doc=doc_activation_function, + ), + Argument("precision", str, optional=True, default="default", doc=doc_precision), + Argument( + "exclude_types", + list[list[int]], + optional=True, + default=[], + doc=doc_exclude_types, + ), + Argument( + "env_protection", + float, + optional=True, + default=0.0, + doc=doc_only_pt_supported + doc_env_protection, + ), + Argument("trainable", bool, optional=True, default=True, doc=doc_trainable), + Argument("seed", [int, None], optional=True, doc=doc_seed), + Argument( + "use_econf_tebd", + bool, + optional=True, + default=False, + doc=doc_only_pt_supported + doc_use_econf_tebd, + ), + Argument( + "use_tebd_bias", + bool, + optional=True, + default=False, + doc=doc_use_tebd_bias, + ), + ] + + +# repflow for dpa3 +def dpa3_repflow_args(): + # repflow args + doc_n_dim = "The dimension of node representation." + doc_e_dim = "The dimension of edge representation." + doc_a_dim = "The dimension of angle representation." + doc_nlayers = "The number of repflow layers." + doc_e_rcut = "The edge cut-off radius." + doc_e_rcut_smth = "Where to start smoothing for edge. For example the 1/r term is smoothed from `rcut` to `rcut_smth`." + doc_e_sel = 'Maximally possible number of selected edge neighbors. It can be:\n\n\ + - `int`. The maximum number of neighbor atoms to be considered. We recommend it to be less than 200. \n\n\ + - `str`. Can be "auto:factor" or "auto". "factor" is a float number larger than 1. This option will automatically determine the `sel`. In detail it counts the maximal number of neighbors with in the cutoff radius for each type of neighbor, then multiply the maximum by the "factor". Finally the number is wrapped up to 4 divisible. The option "auto" is equivalent to "auto:1.1".' + doc_a_rcut = "The angle cut-off radius." + doc_a_rcut_smth = "Where to start smoothing for angle. For example the 1/r term is smoothed from `rcut` to `rcut_smth`." + doc_a_sel = 'Maximally possible number of selected angle neighbors. It can be:\n\n\ + - `int`. The maximum number of neighbor atoms to be considered. We recommend it to be less than 200. \n\n\ + - `str`. Can be "auto:factor" or "auto". "factor" is a float number larger than 1. This option will automatically determine the `sel`. In detail it counts the maximal number of neighbors with in the cutoff radius for each type of neighbor, then multiply the maximum by the "factor". Finally the number is wrapped up to 4 divisible. The option "auto" is equivalent to "auto:1.1".' + doc_axis_neuron = "The number of dimension of submatrix in the symmetrization ops." + doc_update_angle = ( + "Where to update the angle rep. If not, only node and edge rep will be used." + ) + doc_update_style = ( + "Style to update a representation. " + "Supported options are: " + "-'res_avg': Updates a rep `u` with: u = 1/\\sqrt{n+1} (u + u_1 + u_2 + ... + u_n) " + "-'res_incr': Updates a rep `u` with: u = u + 1/\\sqrt{n} (u_1 + u_2 + ... + u_n)" + "-'res_residual': Updates a rep `u` with: u = u + (r1*u_1 + r2*u_2 + ... + r3*u_n) " + "where `r1`, `r2` ... `r3` are residual weights defined by `update_residual` " + "and `update_residual_init`." + ) + doc_update_residual = ( + "When update using residual mode, " + "the initial std of residual vector weights." + ) + doc_update_residual_init = ( + "When update using residual mode, " + "the initialization mode of residual vector weights." + "Supported modes are: ['norm', 'const']." + ) + + return [ + # repflow args + Argument("n_dim", int, optional=True, default=128, doc=doc_n_dim), + Argument("e_dim", int, optional=True, default=64, doc=doc_e_dim), + Argument("a_dim", int, optional=True, default=64, doc=doc_a_dim), + Argument("nlayers", int, optional=True, default=6, doc=doc_nlayers), + Argument("e_rcut", float, doc=doc_e_rcut), + Argument("e_rcut_smth", float, doc=doc_e_rcut_smth), + Argument("e_sel", [int, str], doc=doc_e_sel), + Argument("a_rcut", float, doc=doc_a_rcut), + Argument("a_rcut_smth", float, doc=doc_a_rcut_smth), + Argument("a_sel", [int, str], doc=doc_a_sel), + Argument( + "axis_neuron", + int, + optional=True, + default=4, + doc=doc_axis_neuron, + ), + Argument("node_has_conv", bool, optional=True, default=False, doc="TMP"), + Argument( + "update_angle", + bool, + optional=True, + default=True, + doc=doc_update_angle, + ), + Argument( + "update_style", + str, + optional=True, + default="res_residual", + doc=doc_update_style, + ), + Argument( + "update_residual", + float, + optional=True, + default=0.1, + doc=doc_update_residual, + ), + Argument( + "update_residual_init", + str, + optional=True, + default="const", + doc=doc_update_residual_init, + ), + ] + + @descrpt_args_plugin.register( "se_a_ebd_v2", alias=["se_a_tpe_v2"], doc=doc_only_tf_supported ) From 527cb852757a00c2efff467d0a193cc629b0331c Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Wed, 18 Dec 2024 18:39:46 +0800 Subject: [PATCH 2/5] rename and add uts --- deepmd/dpmodel/descriptor/dpa3.py | 3 - deepmd/pt/model/descriptor/dpa3.py | 25 +- deepmd/pt/model/descriptor/repflow_layer.py | 780 +++++++----------- deepmd/pt/model/descriptor/repflows.py | 116 ++- deepmd/utils/argcheck.py | 1 - source/tests/pt/model/test_dpa3.py | 173 ++++ .../dpmodel/descriptor/test_descriptor.py | 69 ++ .../pt/descriptor/test_descriptor.py | 3 + source/tests/universal/pt/model/test_model.py | 14 + 9 files changed, 603 insertions(+), 581 deletions(-) create mode 100644 source/tests/pt/model/test_dpa3.py diff --git a/deepmd/dpmodel/descriptor/dpa3.py b/deepmd/dpmodel/descriptor/dpa3.py index 228c652930..df1a2ae258 100644 --- a/deepmd/dpmodel/descriptor/dpa3.py +++ b/deepmd/dpmodel/descriptor/dpa3.py @@ -15,7 +15,6 @@ def __init__( a_rcut_smth: float = 3.5, a_sel: int = 20, axis_neuron: int = 4, - node_has_conv: bool = False, update_angle: bool = True, update_style: str = "res_residual", update_residual: float = 0.1, @@ -73,7 +72,6 @@ def __init__( self.a_rcut_smth = a_rcut_smth self.a_sel = a_sel self.axis_neuron = axis_neuron - self.node_has_conv = node_has_conv # tmp self.update_angle = update_angle self.update_style = update_style self.update_residual = update_residual @@ -98,7 +96,6 @@ def serialize(self) -> dict: "a_rcut_smth": self.a_rcut_smth, "a_sel": self.a_sel, "axis_neuron": self.axis_neuron, - "node_has_conv": self.node_has_conv, # tmp "update_angle": self.update_angle, "update_style": self.update_style, "update_residual": self.update_residual, diff --git a/deepmd/pt/model/descriptor/dpa3.py b/deepmd/pt/model/descriptor/dpa3.py index 5d785e0de9..e526e6f82b 100644 --- a/deepmd/pt/model/descriptor/dpa3.py +++ b/deepmd/pt/model/descriptor/dpa3.py @@ -152,7 +152,6 @@ def init_subclass_params(sub_data, sub_class): e_dim=self.repflow_args.e_dim, a_dim=self.repflow_args.a_dim, axis_neuron=self.repflow_args.axis_neuron, - node_has_conv=self.repflow_args.node_has_conv, update_angle=self.repflow_args.update_angle, activation_function=self.activation_function, update_style=self.repflow_args.update_style, @@ -299,7 +298,7 @@ def change_type_map( extend_descrpt_stat( repflow, type_map, - des_with_stat=model_with_new_type_stat.repflow + des_with_stat=model_with_new_type_stat.repflows if model_with_new_type_stat is not None else None, ) @@ -380,6 +379,7 @@ def serialize(self) -> dict: } repflow_variable = { "edge_embd": repflows.edge_embd.serialize(), + "angle_embd": repflows.angle_embd.serialize(), "repflow_layers": [layer.serialize() for layer in repflows.layers], "env_mat": DPEnvMat(repflows.rcut, repflows.rcut_smth).serialize(), "@variables": { @@ -417,6 +417,9 @@ def t_cvt(xx): env_mat = repflow_variable.pop("env_mat") repflow_layers = repflow_variable.pop("repflow_layers") obj.repflows.edge_embd = MLPLayer.deserialize(repflow_variable.pop("edge_embd")) + obj.repflows.angle_embd = MLPLayer.deserialize( + repflow_variable.pop("angle_embd") + ) obj.repflows["davg"] = t_cvt(statistic_repflows["davg"]) obj.repflows["dstd"] = t_cvt(statistic_repflows["dstd"]) obj.repflows.layers = torch.nn.ModuleList( @@ -449,12 +452,12 @@ def forward( Returns ------- - node_embd + node_ebd The output descriptor. shape: nf x nloc x n_dim (or n_dim + tebd_dim) rot_mat The rotationally equivariant and permutationally invariant single particle representation. shape: nf x nloc x e_dim x 3 - edge_embd + edge_ebd The edge embedding. shape: nf x nloc x nnei x e_dim h2 @@ -469,23 +472,23 @@ def forward( nframes, nloc, nnei = nlist.shape nall = extended_coord.view(nframes, -1).shape[1] // 3 - node_embd_ext = self.type_embedding(extended_atype) - node_embd_inp = node_embd_ext[:, :nloc, :] + node_ebd_ext = self.type_embedding(extended_atype) + node_ebd_inp = node_ebd_ext[:, :nloc, :] # repflows - node_embd, edge_embd, h2, rot_mat, sw = self.repflows( + node_ebd, edge_ebd, h2, rot_mat, sw = self.repflows( nlist, extended_coord, extended_atype, - node_embd_ext, + node_ebd_ext, mapping, comm_dict=comm_dict, ) if self.concat_output_tebd: - node_embd = torch.cat([node_embd, node_embd_inp], dim=-1) + node_ebd = torch.cat([node_ebd, node_ebd_inp], dim=-1) return ( - node_embd.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + node_ebd.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), - edge_embd.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + edge_ebd.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), h2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), ) diff --git a/deepmd/pt/model/descriptor/repflow_layer.py b/deepmd/pt/model/descriptor/repflow_layer.py index 2395986366..bbbdb3e20e 100644 --- a/deepmd/pt/model/descriptor/repflow_layer.py +++ b/deepmd/pt/model/descriptor/repflow_layer.py @@ -10,22 +10,20 @@ from deepmd.dpmodel.utils.seed import ( child_seed, ) -from deepmd.pt.model.network.init import ( - constant_, - normal_, +from deepmd.pt.model.descriptor.repformer_layer import ( + _apply_nlist_mask, + _apply_switch, + _make_nei_g1, + get_residual, ) from deepmd.pt.model.network.mlp import ( MLPLayer, ) -from deepmd.pt.utils import ( - env, -) from deepmd.pt.utils.env import ( PRECISION_DICT, ) from deepmd.pt.utils.utils import ( ActivationFn, - get_generator, to_numpy_array, to_torch_tensor, ) @@ -34,118 +32,6 @@ ) -def get_residual( - _dim: int, - _scale: float, - _mode: str = "norm", - trainable: bool = True, - precision: str = "float64", - seed: Optional[Union[int, list[int]]] = None, -) -> torch.Tensor: - r""" - Get residual tensor for one update vector. - - Parameters - ---------- - _dim : int - The dimension of the update vector. - _scale - The initial scale of the residual tensor. See `_mode` for details. - _mode - The mode of residual initialization for the residual tensor. - - "norm" (default): init residual using normal with `_scale` std. - - "const": init residual using element-wise constants of `_scale`. - trainable - Whether the residual tensor is trainable. - precision - The precision of the residual tensor. - seed : int, optional - Random seed for parameter initialization. - """ - random_generator = get_generator(seed) - residual = nn.Parameter( - data=torch.zeros(_dim, dtype=PRECISION_DICT[precision], device=env.DEVICE), - requires_grad=trainable, - ) - if _mode == "norm": - normal_(residual.data, std=_scale, generator=random_generator) - elif _mode == "const": - constant_(residual.data, val=_scale) - else: - raise RuntimeError(f"Unsupported initialization mode '{_mode}'!") - return residual - - -# common ops -def _make_nei_g1( - g1_ext: torch.Tensor, - nlist: torch.Tensor, -) -> torch.Tensor: - """ - Make neighbor-wise atomic invariant rep. - - Parameters - ---------- - g1_ext - Extended atomic invariant rep, with shape nb x nall x ng1. - nlist - Neighbor list, with shape nb x nloc x nnei. - - Returns - ------- - gg1: torch.Tensor - Neighbor-wise atomic invariant rep, with shape nb x nloc x nnei x ng1. - - """ - # nlist: nb x nloc x nnei - nb, nloc, nnei = nlist.shape - # g1_ext: nb x nall x ng1 - ng1 = g1_ext.shape[-1] - # index: nb x (nloc x nnei) x ng1 - index = nlist.reshape(nb, nloc * nnei).unsqueeze(-1).expand(-1, -1, ng1) - # gg1 : nb x (nloc x nnei) x ng1 - gg1 = torch.gather(g1_ext, dim=1, index=index) - # gg1 : nb x nloc x nnei x ng1 - gg1 = gg1.view(nb, nloc, nnei, ng1) - return gg1 - - -def _apply_nlist_mask( - gg: torch.Tensor, - nlist_mask: torch.Tensor, -) -> torch.Tensor: - """ - Apply nlist mask to neighbor-wise rep tensors. - - Parameters - ---------- - gg - Neighbor-wise rep tensors, with shape nf x nloc x nnei x d. - nlist_mask - Neighbor list mask, where zero means no neighbor, with shape nf x nloc x nnei. - """ - # gg: nf x nloc x nnei x d - # msk: nf x nloc x nnei - return gg.masked_fill(~nlist_mask.unsqueeze(-1), 0.0) - - -def _apply_switch(gg: torch.Tensor, sw: torch.Tensor) -> torch.Tensor: - """ - Apply switch function to neighbor-wise rep tensors. - - Parameters - ---------- - gg - Neighbor-wise rep tensors, with shape nf x nloc x nnei x d. - sw - The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, - and remains 0 beyond rcut, with shape nf x nloc x nnei. - """ - # gg: nf x nloc x nnei x d - # sw: nf x nloc x nnei - return gg * sw.unsqueeze(-1) - - class RepFlowLayer(torch.nn.Module): def __init__( self, @@ -161,11 +47,10 @@ def __init__( a_dim: int = 64, axis_neuron: int = 4, update_angle: bool = True, # angle - update_g1_has_conv: bool = True, activation_function: str = "silu", - update_style: str = "res_avg", - update_residual: float = 0.001, - update_residual_init: str = "norm", + update_style: str = "res_residual", + update_residual: float = 0.1, + update_residual_init: str = "const", precision: str = "float64", seed: Optional[Union[int, list[int]]] = None, ) -> None: @@ -196,8 +81,6 @@ def __init__( self.seed = seed self.prec = PRECISION_DICT[precision] - self.update_g1_has_conv = update_g1_has_conv - assert update_residual_init in [ "norm", "const", @@ -205,220 +88,152 @@ def __init__( self.update_residual = update_residual self.update_residual_init = update_residual_init - self.g1_residual = [] - self.g2_residual = [] - self.h2_residual = [] + self.n_residual = [] + self.e_residual = [] self.a_residual = [] - self.proj_g1g2 = None self.edge_info_dim = self.n_dim * 2 + self.e_dim - # g1 self mlp + # node self mlp self.node_self_mlp = MLPLayer( n_dim, n_dim, precision=precision, - seed=child_seed(seed, 15), + seed=child_seed(seed, 0), ) if self.update_style == "res_residual": - self.g1_residual.append( + self.n_residual.append( get_residual( n_dim, self.update_residual, self.update_residual_init, precision=precision, - seed=child_seed(seed, 16), + seed=child_seed(seed, 1), ) ) - # g1 conv # tmp - if self.update_g1_has_conv: - self.proj_g1g2 = MLPLayer( - e_dim, - n_dim, - bias=False, - precision=precision, - seed=child_seed(seed, 4), - ) - if self.update_style == "res_residual": - self.g1_residual.append( - get_residual( - n_dim, - self.update_residual, - self.update_residual_init, - precision=precision, - seed=child_seed(seed, 17), - ) - ) - - # g1 sym - self.g1_sym_dim = self.cal_1_dim(n_dim, e_dim, self.axis_neuron) - self.linear1 = MLPLayer( - self.g1_sym_dim, + # node sym (grrg + drrd) + self.n_sym_dim = n_dim * self.axis_neuron + e_dim * self.axis_neuron + self.node_sym_linear = MLPLayer( + self.n_sym_dim, n_dim, precision=precision, - seed=child_seed(seed, 1), + seed=child_seed(seed, 2), ) if self.update_style == "res_residual": - self.g1_residual.append( + self.n_residual.append( get_residual( n_dim, self.update_residual, self.update_residual_init, precision=precision, - seed=child_seed(seed, 0), + seed=child_seed(seed, 3), ) ) - # g1 edge - self.g1_edge_linear1 = MLPLayer( + # node edge message + self.node_edge_linear = MLPLayer( self.edge_info_dim, n_dim, precision=precision, - seed=child_seed(seed, 11), - ) # need act + seed=child_seed(seed, 4), + ) if self.update_style == "res_residual": - self.g1_residual.append( + self.n_residual.append( get_residual( n_dim, self.update_residual, self.update_residual_init, precision=precision, - seed=child_seed(seed, 13), + seed=child_seed(seed, 5), ) ) - # g2 edge - self.linear2 = MLPLayer( + # edge self message + self.edge_self_linear = MLPLayer( self.edge_info_dim, e_dim, precision=precision, - seed=child_seed(seed, 2), + seed=child_seed(seed, 6), ) if self.update_style == "res_residual": - self.g2_residual.append( + self.e_residual.append( get_residual( e_dim, self.update_residual, self.update_residual_init, precision=precision, - seed=child_seed(seed, 3), + seed=child_seed(seed, 7), ) ) if self.update_angle: - angle_seed = 20 self.angle_dim = self.a_dim + self.n_dim + 2 * self.e_dim - self.angle_linear = MLPLayer( + + # edge angle message + self.edge_angle_linear1 = MLPLayer( self.angle_dim, - self.a_dim, + self.e_dim, precision=precision, - seed=child_seed(seed, angle_seed + 1), - ) # need act + seed=child_seed(seed, 8), + ) + self.edge_angle_linear2 = MLPLayer( + self.e_dim, + self.e_dim, + precision=precision, + seed=child_seed(seed, 9), + ) if self.update_style == "res_residual": - self.a_residual.append( + self.e_residual.append( get_residual( - self.a_dim, + self.e_dim, self.update_residual, self.update_residual_init, precision=precision, - seed=child_seed(seed, angle_seed + 2), + seed=child_seed(seed, 10), ) ) - self.g2_angle_linear1 = MLPLayer( + # angle self message + self.angle_self_linear = MLPLayer( self.angle_dim, - self.e_dim, - precision=precision, - seed=child_seed(seed, angle_seed + 3), - ) # need act - self.g2_angle_linear2 = MLPLayer( - self.e_dim, - self.e_dim, + self.a_dim, precision=precision, - seed=child_seed(seed, angle_seed + 4), + seed=child_seed(seed, 11), ) if self.update_style == "res_residual": - self.g2_residual.append( + self.a_residual.append( get_residual( - self.e_dim, + self.a_dim, self.update_residual, self.update_residual_init, precision=precision, - seed=child_seed(seed, angle_seed + 5), + seed=child_seed(seed, 12), ) ) - else: - self.angle_linear = None - self.g2_angle_linear1 = None - self.g2_angle_linear2 = None - self.angle_dim = 0 + self.angle_self_linear = None + self.edge_angle_linear1 = None + self.edge_angle_linear2 = None self.angle_dim = 0 - self.g1_residual = nn.ParameterList(self.g1_residual) - self.g2_residual = nn.ParameterList(self.g2_residual) - self.h2_residual = nn.ParameterList(self.h2_residual) + self.n_residual = nn.ParameterList(self.n_residual) + self.e_residual = nn.ParameterList(self.e_residual) self.a_residual = nn.ParameterList(self.a_residual) - def cal_1_dim(self, g1d: int, g2d: int, ax: int) -> int: - ret = g2d * ax + g1d * ax - return ret - - def _update_g1_conv( - self, - gg1: torch.Tensor, - g2: torch.Tensor, - nlist_mask: torch.Tensor, - sw: torch.Tensor, - ) -> torch.Tensor: - """ - Calculate the convolution update for atomic invariant rep. - - Parameters - ---------- - gg1 - Neighbor-wise atomic invariant rep, with shape nb x nloc x nnei x ng1. - g2 - Pair invariant rep, with shape nb x nloc x nnei x ng2. - nlist_mask - Neighbor list mask, where zero means no neighbor, with shape nb x nloc x nnei. - sw - The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, - and remains 0 beyond rcut, with shape nb x nloc x nnei. - """ - assert self.proj_g1g2 is not None - nb, nloc, nnei, _ = g2.shape - ng1 = gg1.shape[-1] - ng2 = g2.shape[-1] - gg1 = gg1.view(nb, nloc, nnei, ng1) - # nb x nloc x nnei x ng2/ng1 - gg1 = _apply_nlist_mask(gg1, nlist_mask) - gg1 = _apply_switch(gg1, sw) - invnnei = (1.0 / float(nnei)) * torch.ones( - (nb, nloc, 1), dtype=gg1.dtype, device=gg1.device - ) - g2 = self.proj_g1g2(g2).view(nb, nloc, nnei, ng1) - # nb x nloc x ng1 - g1_11 = torch.sum(g2 * gg1, dim=2) * invnnei - return g1_11 - @staticmethod def _cal_hg( - g2: torch.Tensor, + edge_ebd: torch.Tensor, h2: torch.Tensor, nlist_mask: torch.Tensor, sw: torch.Tensor, - smooth: bool = True, - epsilon: float = 1e-4, - use_sqrt_nnei: bool = True, ) -> torch.Tensor: """ Calculate the transposed rotation matrix. Parameters ---------- - g2 - Neighbor-wise/Pair-wise invariant rep tensors, with shape nb x nloc x nnei x ng2. + edge_ebd + Neighbor-wise/Pair-wise edge embeddings, with shape nb x nloc x nnei x e_dim. h2 Neighbor-wise/Pair-wise equivariant rep tensors, with shape nb x nloc x nnei x 3. nlist_mask @@ -426,47 +241,26 @@ def _cal_hg( sw The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, and remains 0 beyond rcut, with shape nb x nloc x nnei. - smooth - Whether to use smoothness in processes such as attention weights calculation. - epsilon - Protection of 1./nnei. Returns ------- hg - The transposed rotation matrix, with shape nb x nloc x 3 x ng2. + The transposed rotation matrix, with shape nb x nloc x 3 x e_dim. """ - # g2: nb x nloc x nnei x ng2 + # edge_ebd: nb x nloc x nnei x e_dim # h2: nb x nloc x nnei x 3 # msk: nb x nloc x nnei - nb, nloc, nnei, _ = g2.shape - ng2 = g2.shape[-1] - # nb x nloc x nnei x ng2 - g2 = _apply_nlist_mask(g2, nlist_mask) - if not smooth: - # nb x nloc - # must use type_as here to convert bool to float, otherwise there will be numerical difference from numpy - if not use_sqrt_nnei: - invnnei = 1.0 / (epsilon + torch.sum(nlist_mask.type_as(g2), dim=-1)) - else: - invnnei = 1.0 / ( - epsilon + torch.sqrt(torch.sum(nlist_mask.type_as(g2), dim=-1)) - ) - # nb x nloc x 1 x 1 - invnnei = invnnei.unsqueeze(-1).unsqueeze(-1) - else: - g2 = _apply_switch(g2, sw) - if not use_sqrt_nnei: - invnnei = (1.0 / float(nnei)) * torch.ones( - (nb, nloc, 1, 1), dtype=g2.dtype, device=g2.device - ) - else: - invnnei = torch.rsqrt( - float(nnei) - * torch.ones((nb, nloc, 1, 1), dtype=g2.dtype, device=g2.device) - ) - # nb x nloc x 3 x ng2 - h2g2 = torch.matmul(torch.transpose(h2, -1, -2), g2) * invnnei + nb, nloc, nnei, _ = edge_ebd.shape + e_dim = edge_ebd.shape[-1] + # nb x nloc x nnei x e_dim + edge_ebd = _apply_nlist_mask(edge_ebd, nlist_mask) + edge_ebd = _apply_switch(edge_ebd, sw) + invnnei = torch.rsqrt( + float(nnei) + * torch.ones((nb, nloc, 1, 1), dtype=edge_ebd.dtype, device=edge_ebd.device) + ) + # nb x nloc x 3 x e_dim + h2g2 = torch.matmul(torch.transpose(h2, -1, -2), edge_ebd) * invnnei return h2g2 @staticmethod @@ -477,42 +271,40 @@ def _cal_grrg(h2g2: torch.Tensor, axis_neuron: int) -> torch.Tensor: Parameters ---------- h2g2 - The transposed rotation matrix, with shape nb x nloc x 3 x ng2. + The transposed rotation matrix, with shape nb x nloc x 3 x e_dim. axis_neuron Size of the submatrix. Returns ------- grrg - Atomic invariant rep, with shape nb x nloc x (axis_neuron x ng2) + Atomic invariant rep, with shape nb x nloc x (axis_neuron x e_dim) """ - # nb x nloc x 3 x ng2 - nb, nloc, _, ng2 = h2g2.shape + # nb x nloc x 3 x e_dim + nb, nloc, _, e_dim = h2g2.shape # nb x nloc x 3 x axis h2g2m = torch.split(h2g2, axis_neuron, dim=-1)[0] - # nb x nloc x axis x ng2 + # nb x nloc x axis x e_dim g1_13 = torch.matmul(torch.transpose(h2g2m, -1, -2), h2g2) / (3.0**1) # nb x nloc x (axisxng2) - g1_13 = g1_13.view(nb, nloc, axis_neuron * ng2) + g1_13 = g1_13.view(nb, nloc, axis_neuron * e_dim) return g1_13 def symmetrization_op( self, - g2: torch.Tensor, + edge_ebd: torch.Tensor, h2: torch.Tensor, nlist_mask: torch.Tensor, sw: torch.Tensor, axis_neuron: int, - smooth: bool = True, - epsilon: float = 1e-4, ) -> torch.Tensor: """ Symmetrization operator to obtain atomic invariant rep. Parameters ---------- - g2 - Neighbor-wise/Pair-wise invariant rep tensors, with shape nb x nloc x nnei x ng2. + edge_ebd + Neighbor-wise/Pair-wise invariant rep tensors, with shape nb x nloc x nnei x e_dim. h2 Neighbor-wise/Pair-wise equivariant rep tensors, with shape nb x nloc x nnei x 3. nlist_mask @@ -522,29 +314,22 @@ def symmetrization_op( and remains 0 beyond rcut, with shape nb x nloc x nnei. axis_neuron Size of the submatrix. - smooth - Whether to use smoothness in processes such as attention weights calculation. - epsilon - Protection of 1./nnei. Returns ------- grrg - Atomic invariant rep, with shape nb x nloc x (axis_neuron x ng2) + Atomic invariant rep, with shape nb x nloc x (axis_neuron x e_dim) """ - # g2: nb x nloc x nnei x ng2 + # edge_ebd: nb x nloc x nnei x e_dim # h2: nb x nloc x nnei x 3 # msk: nb x nloc x nnei - nb, nloc, nnei, _ = g2.shape - # nb x nloc x 3 x ng2 + nb, nloc, nnei, _ = edge_ebd.shape + # nb x nloc x 3 x e_dim h2g2 = self._cal_hg( - g2, + edge_ebd, h2, nlist_mask, sw, - smooth=smooth, - epsilon=epsilon, - use_sqrt_nnei=True, ) # nb x nloc x (axisxng2) g1_13 = self._cal_grrg(h2g2, axis_neuron) @@ -552,179 +337,199 @@ def symmetrization_op( def forward( self, - g1_ext: torch.Tensor, # nf x nall x ng1 - g2: torch.Tensor, # nf x nloc x nnei x ng2 + node_ebd_ext: torch.Tensor, # nf x nall x n_dim + edge_ebd: torch.Tensor, # nf x nloc x nnei x e_dim h2: torch.Tensor, # nf x nloc x nnei x 3 - angle_embed: torch.Tensor, # nf x nloc x a_nnei x a_nnei x a_dim + angle_ebd: torch.Tensor, # nf x nloc x a_nnei x a_nnei x a_dim nlist: torch.Tensor, # nf x nloc x nnei nlist_mask: torch.Tensor, # nf x nloc x nnei sw: torch.Tensor, # switch func, nf x nloc x nnei - angle_nlist: torch.Tensor, # nf x nloc x a_nnei - angle_nlist_mask: torch.Tensor, # nf x nloc x a_nnei - angle_sw: torch.Tensor, # switch func, nf x nloc x a_nnei + a_nlist: torch.Tensor, # nf x nloc x a_nnei + a_nlist_mask: torch.Tensor, # nf x nloc x a_nnei + a_sw: torch.Tensor, # switch func, nf x nloc x a_nnei ): """ Parameters ---------- - g1_ext : nf x nall x ng1 extended single-atom channel - g2 : nf x nloc x nnei x ng2 pair-atom channel, invariant - h2 : nf x nloc x nnei x 3 pair-atom channel, equivariant - nlist : nf x nloc x nnei neighbor list (padded neis are set to 0) - nlist_mask : nf x nloc x nnei masks of the neighbor list. real nei 1 otherwise 0 - sw : nf x nloc x nnei switch function + node_ebd_ext : nf x nall x n_dim + Extended node embedding. + edge_ebd : nf x nloc x nnei x e_dim + Edge embedding. + h2 : nf x nloc x nnei x 3 + Pair-atom channel, equivariant. + angle_ebd : nf x nloc x a_nnei x a_nnei x a_dim + Angle embedding. + nlist : nf x nloc x nnei + Neighbor list. (padded neis are set to 0) + nlist_mask : nf x nloc x nnei + Masks of the neighbor list. real nei 1 otherwise 0 + sw : nf x nloc x nnei + Switch function. + a_nlist : nf x nloc x a_nnei + Neighbor list for angle. (padded neis are set to 0) + a_nlist_mask : nf x nloc x a_nnei + Masks of the neighbor list for angle. real nei 1 otherwise 0 + a_sw : nf x nloc x a_nnei + Switch function for angle. Returns ------- - g1: nf x nloc x ng1 updated single-atom channel - g2: nf x nloc x nnei x ng2 updated pair-atom channel, invariant - h2: nf x nloc x nnei x 3 updated pair-atom channel, equivariant + n_updated: nf x nloc x n_dim + Updated node embedding. + e_updated: nf x nloc x nnei x e_dim + Updated edge embedding. + a_updated : nf x nloc x a_nnei x a_nnei x a_dim + Updated angle embedding. """ - nb, nloc, nnei, _ = g2.shape - nall = g1_ext.shape[1] - g1, _ = torch.split(g1_ext, [nloc, nall - nloc], dim=1) - assert (nb, nloc) == g1.shape[:2] + nb, nloc, nnei, _ = edge_ebd.shape + nall = node_ebd_ext.shape[1] + node_ebd, _ = torch.split(node_ebd_ext, [nloc, nall - nloc], dim=1) + assert (nb, nloc) == node_ebd.shape[:2] assert (nb, nloc, nnei) == h2.shape[:3] + del a_nlist # may be used in the future - g1_update: list[torch.Tensor] = [g1] - g2_update: list[torch.Tensor] = [g2] - a_update: list[torch.Tensor] = [angle_embed] - h2_update: list[torch.Tensor] = [h2] - - g1_sym: list[torch.Tensor] = [] + n_update_list: list[torch.Tensor] = [node_ebd] + e_update_list: list[torch.Tensor] = [edge_ebd] + a_update_list: list[torch.Tensor] = [angle_ebd] - # g1 self mlp - node_self_mlp = self.act(self.node_self_mlp(g1)) - g1_update.append(node_self_mlp) + # node self mlp + node_self_mlp = self.act(self.node_self_mlp(node_ebd)) + n_update_list.append(node_self_mlp) - gg1 = _make_nei_g1(g1_ext, nlist) - # g1 conv # tmp - if self.update_g1_has_conv: - assert gg1 is not None - g1_conv = self._update_g1_conv(gg1, g2, nlist_mask, sw) - g1_update.append(g1_conv) + nei_node_ebd = _make_nei_g1(node_ebd_ext, nlist) - # g1 sym mlp - g1_sym.append( + # node sym (grrg + drrd) + node_sym_list: list[torch.Tensor] = [] + node_sym_list.append( self.symmetrization_op( - g2, + edge_ebd, h2, nlist_mask, sw, self.axis_neuron, - smooth=True, - epsilon=self.epsilon, ) ) - g1_sym.append( + node_sym_list.append( self.symmetrization_op( - gg1, + nei_node_ebd, h2, nlist_mask, sw, self.axis_neuron, - smooth=True, - epsilon=self.epsilon, ) ) - g1_1 = self.act(self.linear1(torch.cat(g1_sym, dim=-1))) - g1_update.append(g1_1) + node_sym = self.act(self.node_sym_linear(torch.cat(node_sym_list, dim=-1))) + n_update_list.append(node_sym) + # nb x nloc x nnei x (n_dim * 2 + e_dim) edge_info = torch.cat( - [torch.tile(g1.unsqueeze(-2), [1, 1, self.nnei, 1]), gg1, g2], dim=-1 + [ + torch.tile(node_ebd.unsqueeze(-2), [1, 1, self.nnei, 1]), + nei_node_ebd, + edge_ebd, + ], + dim=-1, ) - # g1 edge update - # nb x nloc x nnei x ng1 - g1_edge_info = self.act(self.g1_edge_linear1(edge_info)) * sw.unsqueeze(-1) - g1_edge_update = torch.sum(g1_edge_info, dim=-2) / self.nnei - g1_update.append(g1_edge_update) - # update g1 - g1_new = self.list_update(g1_update, "g1") + # node edge message + # nb x nloc x nnei x n_dim + node_edge_update = self.act(self.node_edge_linear(edge_info)) * sw.unsqueeze(-1) + node_edge_update = torch.sum(node_edge_update, dim=-2) / self.nnei + n_update_list.append(node_edge_update) + # update node_ebd + n_updated = self.list_update(n_update_list, "node") - # g2 edge update - g2_edge_info = self.act(self.linear2(edge_info)) - g2_update.append(g2_edge_info) + # edge self message + edge_self_update = self.act(self.edge_self_linear(edge_info)) + e_update_list.append(edge_self_update) if self.update_angle: - assert self.angle_linear is not None - assert self.g2_angle_linear1 is not None - assert self.g2_angle_linear2 is not None - # nb x nloc x a_nnei x a_nnei x g1 - g1_angle_embed = torch.tile( - g1.unsqueeze(2).unsqueeze(2), (1, 1, self.a_sel, self.a_sel, 1) + assert self.angle_self_linear is not None + assert self.edge_angle_linear1 is not None + assert self.edge_angle_linear2 is not None + # get angle info + # nb x nloc x a_nnei x a_nnei x n_dim + node_for_angle_info = torch.tile( + node_ebd.unsqueeze(2).unsqueeze(2), (1, 1, self.a_sel, self.a_sel, 1) + ) + # nb x nloc x a_nnei x e_dim + edge_for_angle = edge_ebd[:, :, : self.a_sel, :] + # nb x nloc x a_nnei x e_dim + edge_for_angle = torch.where( + a_nlist_mask.unsqueeze(-1), edge_for_angle, 0.0 + ) + # nb x nloc x (a_nnei) x a_nnei x edge_ebd + edge_for_angle_i = torch.tile( + edge_for_angle.unsqueeze(2), (1, 1, self.a_sel, 1, 1) + ) + # nb x nloc x a_nnei x (a_nnei) x e_dim + edge_for_angle_j = torch.tile( + edge_for_angle.unsqueeze(3), (1, 1, 1, self.a_sel, 1) ) - # nb x nloc x a_nnei x g2 - g2_angle = g2[:, :, : self.a_sel, :] - # nb x nloc x a_nnei x g2 - g2_angle = torch.where(angle_nlist_mask.unsqueeze(-1), g2_angle, 0.0) - # nb x nloc x (a_nnei) x a_nnei x g2 - g2_angle_i = torch.tile(g2_angle.unsqueeze(2), (1, 1, self.a_sel, 1, 1)) - # nb x nloc x a_nnei x (a_nnei) x g2 - g2_angle_j = torch.tile(g2_angle.unsqueeze(3), (1, 1, 1, self.a_sel, 1)) - # nb x nloc x a_nnei x a_nnei x (g2 + g2) - g2_angle_embed = torch.cat([g2_angle_i, g2_angle_j], dim=-1) - - # angle for g2: - updated_g2_angle_list = [angle_embed] - # nb x nloc x a_nnei x a_nnei x (a + g1 + g2*2) - updated_g2_angle_list += [g1_angle_embed, g2_angle_embed] - updated_g2_angle = torch.cat(updated_g2_angle_list, dim=-1) - # nb x nloc x a_nnei x a_nnei x g2 - updated_angle_g2 = self.act(self.g2_angle_linear1(updated_g2_angle)) - # nb x nloc x a_nnei x a_nnei x g2 - weighted_updated_angle_g2 = ( - updated_angle_g2 - * angle_sw[:, :, :, None, None] - * angle_sw[:, :, None, :, None] + # nb x nloc x a_nnei x a_nnei x (e_dim + e_dim) + edge_for_angle_info = torch.cat( + [edge_for_angle_i, edge_for_angle_j], dim=-1 ) - # nb x nloc x a_nnei x g2 - reduced_updated_angle_g2 = torch.sum(weighted_updated_angle_g2, dim=-2) / ( - self.a_sel**0.5 + angle_info_list = [angle_ebd, node_for_angle_info, edge_for_angle_info] + # nb x nloc x a_nnei x a_nnei x (a + n_dim + e_dim*2) + angle_info = torch.cat(angle_info_list, dim=-1) + + # edge angle message + # nb x nloc x a_nnei x a_nnei x e_dim + edge_angle_update = self.act(self.edge_angle_linear1(angle_info)) + # nb x nloc x a_nnei x a_nnei x e_dim + weighted_edge_angle_update = ( + edge_angle_update + * a_sw[:, :, :, None, None] + * a_sw[:, :, None, :, None] ) - # nb x nloc x nnei x g2 - padding_updated_angle_g2 = torch.concat( + # nb x nloc x a_nnei x e_dim + reduced_edge_angle_update = torch.sum( + weighted_edge_angle_update, dim=-2 + ) / (self.a_sel**0.5) + # nb x nloc x nnei x e_dim + padding_edge_angle_update = torch.concat( [ - reduced_updated_angle_g2, + reduced_edge_angle_update, torch.zeros( [nb, nloc, self.nnei - self.a_sel, self.e_dim], - dtype=g2.dtype, - device=g2.device, + dtype=edge_ebd.dtype, + device=edge_ebd.device, ), ], dim=2, ) full_mask = torch.concat( [ - angle_nlist_mask, + a_nlist_mask, torch.zeros( [nb, nloc, self.nnei - self.a_sel], - dtype=angle_nlist_mask.dtype, - device=angle_nlist_mask.device, + dtype=a_nlist_mask.dtype, + device=a_nlist_mask.device, ), ], dim=-1, ) - padding_updated_angle_g2 = torch.where( - full_mask.unsqueeze(-1), padding_updated_angle_g2, g2 + padding_edge_angle_update = torch.where( + full_mask.unsqueeze(-1), padding_edge_angle_update, edge_ebd ) - g2_update.append(self.act(self.g2_angle_linear2(padding_updated_angle_g2))) + e_update_list.append( + self.act(self.edge_angle_linear2(padding_edge_angle_update)) + ) + # update edge_ebd + e_updated = self.list_update(e_update_list, "edge") - # update g2 - g2_new = self.list_update(g2_update, "g2") - # angle for angle - updated_angle = updated_g2_angle + # angle self message # nb x nloc x a_nnei x a_nnei x dim_a - angle_message = self.act(self.angle_linear(updated_angle)) - # angle update - a_update.append(angle_message) + angle_self_update = self.act(self.angle_self_linear(angle_info)) + a_update_list.append(angle_self_update) else: - # update g2 - g2_new = self.list_update(g2_update, "g2") + # update edge_ebd + e_updated = self.list_update(e_update_list, "edge") - # update - h2_new = self.list_update(h2_update, "h2") - a_new = self.list_update(a_update, "a") - return g1_new, g2_new, h2_new, a_new + # update angle_ebd + a_updated = self.list_update(a_update_list, "angle") + return n_updated, e_updated, a_updated @torch.jit.export def list_update_res_avg( @@ -748,21 +553,18 @@ def list_update_res_incr(self, update_list: list[torch.Tensor]) -> torch.Tensor: @torch.jit.export def list_update_res_residual( - self, update_list: list[torch.Tensor], update_name: str = "g1" + self, update_list: list[torch.Tensor], update_name: str = "node" ) -> torch.Tensor: nitem = len(update_list) uu = update_list[0] # make jit happy - if update_name == "g1": - for ii, vv in enumerate(self.g1_residual): - uu = uu + vv * update_list[ii + 1] - elif update_name == "g2": - for ii, vv in enumerate(self.g2_residual): + if update_name == "node": + for ii, vv in enumerate(self.n_residual): uu = uu + vv * update_list[ii + 1] - elif update_name == "h2": - for ii, vv in enumerate(self.h2_residual): + elif update_name == "edge": + for ii, vv in enumerate(self.e_residual): uu = uu + vv * update_list[ii + 1] - elif update_name == "a": + elif update_name == "angle": for ii, vv in enumerate(self.a_residual): uu = uu + vv * update_list[ii + 1] else: @@ -771,7 +573,7 @@ def list_update_res_residual( @torch.jit.export def list_update( - self, update_list: list[torch.Tensor], update_name: str = "g1" + self, update_list: list[torch.Tensor], update_name: str = "node" ) -> torch.Tensor: if self.update_style == "res_avg": return self.list_update_res_avg(update_list) @@ -796,61 +598,40 @@ def serialize(self) -> dict: "e_rcut": self.e_rcut, "e_rcut_smth": self.e_rcut_smth, "e_sel": self.e_sel, + "a_rcut": self.a_rcut, + "a_rcut_smth": self.a_rcut_smth, + "a_sel": self.a_sel, "ntypes": self.ntypes, "n_dim": self.n_dim, "e_dim": self.e_dim, + "a_dim": self.a_dim, "axis_neuron": self.axis_neuron, "activation_function": self.activation_function, + "update_angle": self.update_angle, "update_style": self.update_style, + "update_residual": self.update_residual, + "update_residual_init": self.update_residual_init, "precision": self.precision, - "linear1": self.linear1.serialize(), + "node_self_mlp": self.node_self_mlp.serialize(), + "node_sym_linear": self.node_sym_linear.serialize(), + "node_edge_linear": self.node_edge_linear.serialize(), + "edge_self_linear": self.edge_self_linear.serialize(), } - if self.update_g1_has_conv: - data.update( - { - "proj_g1g2": self.proj_g1g2.serialize(), - } - ) - - if self.update_g2_has_attn or self.update_h2: - data.update( - { - "attn2g_map": self.attn2g_map.serialize(), - } - ) - if self.update_g2_has_attn: - data.update( - { - "attn2_mh_apply": self.attn2_mh_apply.serialize(), - "attn2_lm": self.attn2_lm.serialize(), - } - ) - - if self.update_h2: - data.update( - { - "attn2_ev_apply": self.attn2_ev_apply.serialize(), - } - ) - if self.update_g1_has_attn: - data.update( - { - "loc_attn": self.loc_attn.serialize(), - } - ) - if self.g1_out_mlp: + if self.update_angle: data.update( { - "node_self_mlp": self.node_self_mlp.serialize(), + "edge_angle_linear1": self.edge_angle_linear1.serialize(), + "edge_angle_linear2": self.edge_angle_linear2.serialize(), + "angle_self_linear": self.angle_self_linear.serialize(), } ) if self.update_style == "res_residual": data.update( { "@variables": { - "g1_residual": [to_numpy_array(t) for t in self.g1_residual], - "g2_residual": [to_numpy_array(t) for t in self.g2_residual], - "h2_residual": [to_numpy_array(t) for t in self.h2_residual], + "n_residual": [to_numpy_array(t) for t in self.n_residual], + "e_residual": [to_numpy_array(t) for t in self.e_residual], + "a_residual": [to_numpy_array(t) for t in self.a_residual], } } ) @@ -866,48 +647,41 @@ def deserialize(cls, data: dict) -> "RepFlowLayer": The dict to deserialize from. """ data = data.copy() - check_version_compatibility(data.pop("@version"), 2, 1) + check_version_compatibility(data.pop("@version"), 1, 1) data.pop("@class") - linear1 = data.pop("linear1") - update_chnnl_2 = data["update_chnnl_2"] - update_g1_has_conv = data["update_g1_has_conv"] - update_g2_has_g1g1 = data["update_g2_has_g1g1"] - update_g2_has_attn = data["update_g2_has_attn"] - update_h2 = data["update_h2"] - update_g1_has_attn = data["update_g1_has_attn"] + update_angle = data["update_angle"] + node_self_mlp = data.pop("node_self_mlp") + node_sym_linear = data.pop("node_sym_linear") + node_edge_linear = data.pop("node_edge_linear") + edge_self_linear = data.pop("edge_self_linear") + edge_angle_linear1 = data.pop("edge_angle_linear1", None) + edge_angle_linear2 = data.pop("edge_angle_linear2", None) + angle_self_linear = data.pop("angle_self_linear", None) update_style = data["update_style"] - g1_out_mlp = data["g1_out_mlp"] - - linear2 = data.pop("linear2", None) - proj_g1g2 = data.pop("proj_g1g2", None) - attn2g_map = data.pop("attn2g_map", None) - attn2_mh_apply = data.pop("attn2_mh_apply", None) - attn2_lm = data.pop("attn2_lm", None) - attn2_ev_apply = data.pop("attn2_ev_apply", None) - loc_attn = data.pop("loc_attn", None) - node_self_mlp = data.pop("node_self_mlp", None) variables = data.pop("@variables", {}) - g1_residual = variables.get("g1_residual", data.pop("g1_residual", [])) - g2_residual = variables.get("g2_residual", data.pop("g2_residual", [])) - h2_residual = variables.get("h2_residual", data.pop("h2_residual", [])) + n_residual = variables.get("n_residual", data.pop("n_residual", [])) + e_residual = variables.get("e_residual", data.pop("e_residual", [])) + a_residual = variables.get("a_residual", data.pop("a_residual", [])) obj = cls(**data) - obj.linear1 = MLPLayer.deserialize(linear1) - if update_chnnl_2: - assert isinstance(linear2, dict) - obj.linear2 = MLPLayer.deserialize(linear2) - if update_g1_has_conv: - assert isinstance(proj_g1g2, dict) - obj.proj_g1g2 = MLPLayer.deserialize(proj_g1g2) - - if g1_out_mlp: - assert isinstance(node_self_mlp, dict) - obj.node_self_mlp = MLPLayer.deserialize(node_self_mlp) + obj.node_self_mlp = MLPLayer.deserialize(node_self_mlp) + obj.node_sym_linear = MLPLayer.deserialize(node_sym_linear) + obj.node_edge_linear = MLPLayer.deserialize(node_edge_linear) + obj.edge_self_linear = MLPLayer.deserialize(edge_self_linear) + + if update_angle: + assert isinstance(edge_angle_linear1, dict) + assert isinstance(edge_angle_linear2, dict) + assert isinstance(angle_self_linear, dict) + obj.edge_angle_linear1 = MLPLayer.deserialize(edge_angle_linear1) + obj.edge_angle_linear2 = MLPLayer.deserialize(edge_angle_linear2) + obj.angle_self_linear = MLPLayer.deserialize(angle_self_linear) + if update_style == "res_residual": - for ii, t in enumerate(obj.g1_residual): - t.data = to_torch_tensor(g1_residual[ii]) - for ii, t in enumerate(obj.g2_residual): - t.data = to_torch_tensor(g2_residual[ii]) - for ii, t in enumerate(obj.h2_residual): - t.data = to_torch_tensor(h2_residual[ii]) + for ii, t in enumerate(obj.n_residual): + t.data = to_torch_tensor(n_residual[ii]) + for ii, t in enumerate(obj.e_residual): + t.data = to_torch_tensor(e_residual[ii]) + for ii, t in enumerate(obj.a_residual): + t.data = to_torch_tensor(a_residual[ii]) return obj diff --git a/deepmd/pt/model/descriptor/repflows.py b/deepmd/pt/model/descriptor/repflows.py index c5a39f8b3b..302f020754 100644 --- a/deepmd/pt/model/descriptor/repflows.py +++ b/deepmd/pt/model/descriptor/repflows.py @@ -86,7 +86,6 @@ def __init__( e_dim: int = 64, a_dim: int = 64, axis_neuron: int = 4, - node_has_conv: bool = False, update_angle: bool = True, activation_function: str = "silu", update_style: str = "res_residual", @@ -182,7 +181,6 @@ def __init__( self.e_dim = e_dim self.a_dim = a_dim self.update_angle = update_angle - self.node_has_conv = node_has_conv self.activation_function = activation_function self.update_style = update_style @@ -190,12 +188,6 @@ def __init__( self.update_residual_init = update_residual_init self.act = ActivationFn(activation_function) self.prec = PRECISION_DICT[precision] - self.angle_embedding = torch.nn.Linear( - in_features=1, - out_features=self.a_dim, - bias=False, - dtype=self.prec, - ) # order matters, placed after the assignment of self.ntypes self.reinit_exclude(exclude_types) @@ -207,6 +199,9 @@ def __init__( self.edge_embd = MLPLayer( 1, self.e_dim, precision=precision, seed=child_seed(seed, 0) ) + self.angle_embd = MLPLayer( + 1, self.a_dim, precision=precision, bias=False, seed=child_seed(seed, 1) + ) layers = [] for ii in range(nlayers): layers.append( @@ -222,7 +217,6 @@ def __init__( e_dim=self.e_dim, a_dim=self.a_dim, axis_neuron=self.axis_neuron, - update_g1_has_conv=self.node_has_conv, # tmp update_angle=self.update_angle, activation_function=self.activation_function, update_style=self.update_style, @@ -270,7 +264,7 @@ def get_dim_in(self) -> int: return self.dim_in def get_dim_emb(self) -> int: - """Returns the embedding dimension g2.""" + """Returns the embedding dimension e_dim.""" return self.e_dim def __setitem__(self, key, value) -> None: @@ -317,7 +311,7 @@ def dim_in(self): @property def dim_emb(self): - """Returns the embedding dimension g2.""" + """Returns the embedding dimension e_dim.""" return self.get_dim_emb() def reinit_exclude( @@ -369,22 +363,22 @@ def forward( else: atype_embd = extended_atype_embd assert isinstance(atype_embd, torch.Tensor) # for jit - g1 = self.act(atype_embd) - ng1 = g1.shape[-1] + node_ebd = self.act(atype_embd) + n_dim = node_ebd.shape[-1] # nb x nloc x nnei x 1, nb x nloc x nnei x 3 - g2, h2 = torch.split(dmatrix, [1, 3], dim=-1) - # nb x nloc x nnei x ng2 - g2 = self.act(self.edge_embd(g2)) + edge_input, h2 = torch.split(dmatrix, [1, 3], dim=-1) + # nb x nloc x nnei x e_dim + edge_ebd = self.act(self.edge_embd(edge_input)) # get angle nlist (maybe smaller) a_dist_mask = (torch.linalg.norm(diff, dim=-1) < self.a_rcut)[ :, :, : self.a_sel ] - angle_nlist = nlist[:, :, : self.a_sel] - angle_nlist = torch.where(a_dist_mask, angle_nlist, -1) - _, angle_diff, angle_sw = prod_env_mat( + a_nlist = nlist[:, :, : self.a_sel] + a_nlist = torch.where(a_dist_mask, a_nlist, -1) + _, a_diff, a_sw = prod_env_mat( extended_coord, - angle_nlist, + a_nlist, atype, self.mean[:, : self.a_sel], self.stddev[:, : self.a_sel], @@ -392,15 +386,15 @@ def forward( self.a_rcut_smth, protection=self.env_protection, ) - angle_nlist_mask = angle_nlist != -1 - angle_sw = torch.squeeze(angle_sw, -1) + a_nlist_mask = a_nlist != -1 + a_sw = torch.squeeze(a_sw, -1) # beyond the cutoff sw should be 0.0 - angle_sw = angle_sw.masked_fill(~angle_nlist_mask, 0.0) - angle_nlist[angle_nlist == -1] = 0 + a_sw = a_sw.masked_fill(~a_nlist_mask, 0.0) + a_nlist[a_nlist == -1] = 0 # nf x nloc x a_nnei x 3 - normalized_diff_i = angle_diff / ( - torch.linalg.norm(angle_diff, dim=-1, keepdim=True) + 1e-6 + normalized_diff_i = a_diff / ( + torch.linalg.norm(a_diff, dim=-1, keepdim=True) + 1e-6 ) # nf x nloc x 3 x a_nnei normalized_diff_j = torch.transpose(normalized_diff_i, 2, 3) @@ -410,31 +404,31 @@ def forward( # nf x nloc x a_nnei x a_nnei x 1 cosine_ij = cosine_ij.unsqueeze(-1) / (torch.pi**0.5) # nf x nloc x a_nnei x a_nnei x a_dim - angle_embed = self.angle_embedding(cosine_ij).reshape( + angle_ebd = self.angle_embd(cosine_ij).reshape( nframes, nloc, self.a_sel, self.a_sel, self.a_dim ) # set all padding positions to index of 0 # if the a neighbor is real or not is indicated by nlist_mask nlist[nlist == -1] = 0 - # nb x nall x ng1 + # nb x nall x n_dim if comm_dict is None: assert mapping is not None mapping = ( mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, self.n_dim) ) for idx, ll in enumerate(self.layers): - # g1: nb x nloc x ng1 - # g1_ext: nb x nall x ng1 + # node_ebd: nb x nloc x n_dim + # node_ebd_ext: nb x nall x n_dim if comm_dict is None: assert mapping is not None - g1_ext = torch.gather(g1, 1, mapping) + node_ebd_ext = torch.gather(node_ebd, 1, mapping) else: has_spin = "has_spin" in comm_dict if not has_spin: n_padding = nall - nloc - g1 = torch.nn.functional.pad( - g1.squeeze(0), (0, 0, 0, n_padding), value=0.0 + node_ebd = torch.nn.functional.pad( + node_ebd.squeeze(0), (0, 0, 0, n_padding), value=0.0 ) real_nloc = nloc real_nall = nall @@ -443,12 +437,14 @@ def forward( real_nloc = nloc // 2 real_nall = nall // 2 real_n_padding = real_nall - real_nloc - g1_real, g1_virtual = torch.split(g1, [real_nloc, real_nloc], dim=1) - # mix_g1: nb x real_nloc x (ng1 * 2) - mix_g1 = torch.cat([g1_real, g1_virtual], dim=2) - # nb x real_nall x (ng1 * 2) - g1 = torch.nn.functional.pad( - mix_g1.squeeze(0), (0, 0, 0, real_n_padding), value=0.0 + node_ebd_real, node_ebd_virtual = torch.split( + node_ebd, [real_nloc, real_nloc], dim=1 + ) + # mix_node_ebd: nb x real_nloc x (n_dim * 2) + mix_node_ebd = torch.cat([node_ebd_real, node_ebd_virtual], dim=2) + # nb x real_nall x (n_dim * 2) + node_ebd = torch.nn.functional.pad( + mix_node_ebd.squeeze(0), (0, 0, 0, real_n_padding), value=0.0 ) assert "send_list" in comm_dict @@ -463,7 +459,7 @@ def forward( comm_dict["recv_proc"], comm_dict["send_num"], comm_dict["recv_num"], - g1, + node_ebd, comm_dict["communicator"], torch.tensor( real_nloc, @@ -476,39 +472,33 @@ def forward( device=env.DEVICE, ), # should be int of c++ ) - g1_ext = ret[0].unsqueeze(0) + node_ebd_ext = ret[0].unsqueeze(0) if has_spin: - g1_real_ext, g1_virtual_ext = torch.split(g1_ext, [ng1, ng1], dim=2) - g1_ext = concat_switch_virtual( - g1_real_ext, g1_virtual_ext, real_nloc + node_ebd_real_ext, node_ebd_virtual_ext = torch.split( + node_ebd_ext, [n_dim, n_dim], dim=2 ) - g1, g2, h2, angle_embed = ll.forward( - g1_ext, - g2, + node_ebd_ext = concat_switch_virtual( + node_ebd_real_ext, node_ebd_virtual_ext, real_nloc + ) + node_ebd, edge_ebd, angle_ebd = ll.forward( + node_ebd_ext, + edge_ebd, h2, - angle_embed, + angle_ebd, nlist, nlist_mask, sw, - angle_nlist, - angle_nlist_mask, - angle_sw, + a_nlist, + a_nlist_mask, + a_sw, ) - # nb x nloc x 3 x ng2 - h2g2 = RepFlowLayer._cal_hg( - g2, - h2, - nlist_mask, - sw, - smooth=True, - epsilon=self.epsilon, - use_sqrt_nnei=True, - ) - # (nb x nloc) x ng2 x 3 + # nb x nloc x 3 x e_dim + h2g2 = RepFlowLayer._cal_hg(edge_ebd, h2, nlist_mask, sw) + # (nb x nloc) x e_dim x 3 rot_mat = torch.permute(h2g2, (0, 1, 3, 2)) - return g1, g2, h2, rot_mat.view(nframes, nloc, self.dim_emb, 3), sw + return node_ebd, edge_ebd, h2, rot_mat.view(nframes, nloc, self.dim_emb, 3), sw def compute_input_stats( self, diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 0ac084faf4..237bd1003d 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -1482,7 +1482,6 @@ def dpa3_repflow_args(): default=4, doc=doc_axis_neuron, ), - Argument("node_has_conv", bool, optional=True, default=False, doc="TMP"), Argument( "update_angle", bool, diff --git a/source/tests/pt/model/test_dpa3.py b/source/tests/pt/model/test_dpa3.py new file mode 100644 index 0000000000..300ff44bdd --- /dev/null +++ b/source/tests/pt/model/test_dpa3.py @@ -0,0 +1,173 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import itertools +import unittest + +import numpy as np +import torch + +from deepmd.dpmodel.descriptor.dpa3 import ( + RepFlowArgs, +) +from deepmd.pt.model.descriptor import ( + DescrptDPA3, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + PRECISION_DICT, +) + +from ...seed import ( + GLOBAL_SEED, +) +from .test_env_mat import ( + TestCaseSingleFrameWithNlist, +) +from .test_mlp import ( + get_tols, +) + +dtype = env.GLOBAL_PT_FLOAT_PRECISION + + +class TestDescrptDPA3(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + + def test_consistency( + self, + ) -> None: + rng = np.random.default_rng(100) + nf, nloc, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + for ( + ua, + rus, + ruri, + prec, + ect, + ) in itertools.product( + [True, False], # update_angle + ["res_residual"], # update_style + ["norm", "const"], # update_residual_init + ["float64"], # precision + [False], # use_econf_tebd + ): + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + if prec == "float64": + atol = 1e-8 # marginal GPU test cases... + + repflow = RepFlowArgs( + n_dim=20, + e_dim=10, + a_dim=10, + nlayers=3, + e_rcut=self.rcut, + e_rcut_smth=self.rcut_smth, + e_sel=nnei, + a_rcut=self.rcut - 0.1, + a_rcut_smth=self.rcut_smth, + a_sel=nnei - 1, + axis_neuron=4, + update_angle=ua, + update_style=rus, + update_residual_init=ruri, + ) + + # dpa3 new impl + dd0 = DescrptDPA3( + self.nt, + repflow=repflow, + # kwargs for descriptor + exclude_types=[], + precision=prec, + use_econf_tebd=ect, + type_map=["O", "H"] if ect else None, + seed=GLOBAL_SEED, + ).to(env.DEVICE) + + dd0.repflows.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE) + dd0.repflows.stddev = torch.tensor(dstd, dtype=dtype, device=env.DEVICE) + rd0, _, _, _, _ = dd0( + torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), + torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), + torch.tensor(self.nlist, dtype=int, device=env.DEVICE), + torch.tensor(self.mapping, dtype=int, device=env.DEVICE), + ) + # serialization + dd1 = DescrptDPA3.deserialize(dd0.serialize()) + rd1, _, _, _, _ = dd1( + torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), + torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), + torch.tensor(self.nlist, dtype=int, device=env.DEVICE), + torch.tensor(self.mapping, dtype=int, device=env.DEVICE), + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd1.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + ) + + def test_jit( + self, + ) -> None: + rng = np.random.default_rng(100) + nf, nloc, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + for ( + ua, + rus, + ruri, + prec, + ect, + ) in itertools.product( + [True, False], # update_angle + ["res_residual"], # update_style + ["norm", "const"], # update_residual_init + ["float64"], # precision + [False], # use_econf_tebd + ): + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + + repflow = RepFlowArgs( + n_dim=20, + e_dim=10, + a_dim=10, + nlayers=3, + e_rcut=self.rcut, + e_rcut_smth=self.rcut_smth, + e_sel=nnei, + a_rcut=self.rcut - 0.1, + a_rcut_smth=self.rcut_smth, + a_sel=nnei - 1, + axis_neuron=4, + update_angle=ua, + update_style=rus, + update_residual_init=ruri, + ) + + # dpa3 new impl + dd0 = DescrptDPA3( + self.nt, + repflow=repflow, + # kwargs for descriptor + exclude_types=[], + precision=prec, + use_econf_tebd=ect, + type_map=["O", "H"] if ect else None, + seed=GLOBAL_SEED, + ).to(env.DEVICE) + + dd0.repflows.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE) + dd0.repflows.stddev = torch.tensor(dstd, dtype=dtype, device=env.DEVICE) + model = torch.jit.script(dd0) diff --git a/source/tests/universal/dpmodel/descriptor/test_descriptor.py b/source/tests/universal/dpmodel/descriptor/test_descriptor.py index 7911cb9395..1d55ee085b 100644 --- a/source/tests/universal/dpmodel/descriptor/test_descriptor.py +++ b/source/tests/universal/dpmodel/descriptor/test_descriptor.py @@ -17,6 +17,9 @@ RepformerArgs, RepinitArgs, ) +from deepmd.dpmodel.descriptor.dpa3 import ( + RepFlowArgs, +) from ....consistent.common import ( parameterize_func, @@ -460,6 +463,72 @@ def DescriptorParamDPA2( DescriptorParamDPA2 = DescriptorParamDPA2List[0] +def DescriptorParamDPA3( + ntypes, + rcut, + rcut_smth, + sel, + type_map, + env_protection=0.0, + exclude_types=[], + update_style="res_residual", + update_residual=0.1, + update_residual_init="const", + update_angle=True, + precision="float64", +): + input_dict = { + # kwargs for repformer + "repflow": RepFlowArgs( + **{ + "n_dim": 20, + "e_dim": 10, + "a_dim": 10, + "nlayers": 3, + "e_rcut": rcut, + "e_rcut_smth": rcut_smth, + "e_sel": sum(sel), + "a_rcut": rcut / 2, + "a_rcut_smth": rcut_smth / 2, + "a_sel": sum(sel) // 4, + "axis_neuron": 4, + "update_angle": update_angle, + "update_style": update_style, + "update_residual": update_residual, + "update_residual_init": update_residual_init, + } + ), + "ntypes": ntypes, + "concat_output_tebd": False, + "precision": precision, + "activation_function": "silu", + "exclude_types": exclude_types, + "env_protection": env_protection, + "trainable": True, + "use_econf_tebd": False, + "use_tebd_bias": False, + "type_map": type_map, + "seed": GLOBAL_SEED, + } + return input_dict + + +DescriptorParamDPA3List = parameterize_func( + DescriptorParamDPA3, + OrderedDict( + { + "update_residual_init": ("const",), + "exclude_types": ([], [[0, 1]]), + "update_angle": (True, False), + "env_protection": (0.0, 1e-8), + "precision": ("float64",), + } + ), +) +# to get name for the default function +DescriptorParamDPA3 = DescriptorParamDPA3List[0] + + def DescriptorParamHybrid(ntypes, rcut, rcut_smth, sel, type_map, **kwargs): ddsub0 = { "type": "se_e2_a", diff --git a/source/tests/universal/pt/descriptor/test_descriptor.py b/source/tests/universal/pt/descriptor/test_descriptor.py index 349eb65588..25c78b43c1 100644 --- a/source/tests/universal/pt/descriptor/test_descriptor.py +++ b/source/tests/universal/pt/descriptor/test_descriptor.py @@ -4,6 +4,7 @@ from deepmd.pt.model.descriptor import ( DescrptDPA1, DescrptDPA2, + DescrptDPA3, DescrptHybrid, DescrptSeA, DescrptSeR, @@ -20,6 +21,7 @@ from ...dpmodel.descriptor.test_descriptor import ( DescriptorParamDPA1, DescriptorParamDPA2, + DescriptorParamDPA3, DescriptorParamHybrid, DescriptorParamHybridMixed, DescriptorParamSeA, @@ -40,6 +42,7 @@ (DescriptorParamSeTTebd, DescrptSeTTebd), (DescriptorParamDPA1, DescrptDPA1), (DescriptorParamDPA2, DescrptDPA2), + (DescriptorParamDPA3, DescrptDPA3), (DescriptorParamHybrid, DescrptHybrid), (DescriptorParamHybridMixed, DescrptHybrid), ) # class_param & class diff --git a/source/tests/universal/pt/model/test_model.py b/source/tests/universal/pt/model/test_model.py index 3eb1484c45..b86c5ecc40 100644 --- a/source/tests/universal/pt/model/test_model.py +++ b/source/tests/universal/pt/model/test_model.py @@ -10,6 +10,7 @@ from deepmd.pt.model.descriptor import ( DescrptDPA1, DescrptDPA2, + DescrptDPA3, DescrptHybrid, DescrptSeA, DescrptSeR, @@ -55,6 +56,8 @@ DescriptorParamDPA1List, DescriptorParamDPA2, DescriptorParamDPA2List, + DescriptorParamDPA3, + DescriptorParamDPA3List, DescriptorParamHybrid, DescriptorParamHybridMixed, DescriptorParamHybridMixedTTebd, @@ -93,6 +96,7 @@ DescriptorParamSeTTebd, DescriptorParamDPA1, DescriptorParamDPA2, + DescriptorParamDPA3, DescriptorParamHybrid, DescriptorParamHybridMixed, ] @@ -219,6 +223,7 @@ def setUpClass(cls) -> None: ], *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List], *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List], + *[(param_func, DescrptDPA3) for param_func in DescriptorParamDPA3List], (DescriptorParamHybrid, DescrptHybrid), (DescriptorParamHybridMixed, DescrptHybrid), (DescriptorParamHybridMixedTTebd, DescrptHybrid), @@ -233,6 +238,7 @@ def setUpClass(cls) -> None: (DescriptorParamSeTTebd, DescrptSeTTebd), (DescriptorParamDPA1, DescrptDPA1), (DescriptorParamDPA2, DescrptDPA2), + (DescriptorParamDPA3, DescrptDPA3), ), # descrpt_class_param & class ( *[(param_func, DOSFittingNet) for param_func in FittingParamDosList], @@ -316,6 +322,7 @@ def setUpClass(cls) -> None: *[(param_func, DescrptSeA) for param_func in DescriptorParamSeAList], *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List], *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List], + *[(param_func, DescrptDPA3) for param_func in DescriptorParamDPA3List], (DescriptorParamHybrid, DescrptHybrid), (DescriptorParamHybridMixed, DescrptHybrid), ), # descrpt_class_param & class @@ -326,6 +333,7 @@ def setUpClass(cls) -> None: (DescriptorParamSeA, DescrptSeA), (DescriptorParamDPA1, DescrptDPA1), (DescriptorParamDPA2, DescrptDPA2), + (DescriptorParamDPA3, DescrptDPA3), ), # descrpt_class_param & class ( *[(param_func, DipoleFittingNet) for param_func in FittingParamDipoleList], @@ -409,6 +417,7 @@ def setUpClass(cls) -> None: *[(param_func, DescrptSeA) for param_func in DescriptorParamSeAList], *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List], *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List], + *[(param_func, DescrptDPA3) for param_func in DescriptorParamDPA3List], (DescriptorParamHybrid, DescrptHybrid), (DescriptorParamHybridMixed, DescrptHybrid), ), # descrpt_class_param & class @@ -419,6 +428,7 @@ def setUpClass(cls) -> None: (DescriptorParamSeA, DescrptSeA), (DescriptorParamDPA1, DescrptDPA1), (DescriptorParamDPA2, DescrptDPA2), + (DescriptorParamDPA3, DescrptDPA3), ), # descrpt_class_param & class ( *[(param_func, PolarFittingNet) for param_func in FittingParamPolarList], @@ -721,6 +731,7 @@ def setUpClass(cls) -> None: *[(param_func, DescrptSeA) for param_func in DescriptorParamSeAList], *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List], *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List], + *[(param_func, DescrptDPA3) for param_func in DescriptorParamDPA3List], (DescriptorParamHybrid, DescrptHybrid), (DescriptorParamHybridMixed, DescrptHybrid), ), # descrpt_class_param & class @@ -731,6 +742,7 @@ def setUpClass(cls) -> None: (DescriptorParamSeA, DescrptSeA), (DescriptorParamDPA1, DescrptDPA1), (DescriptorParamDPA2, DescrptDPA2), + (DescriptorParamDPA3, DescrptDPA3), ), # descrpt_class_param & class ( *[ @@ -812,6 +824,7 @@ def setUpClass(cls) -> None: ( *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List], *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List], + *[(param_func, DescrptDPA3) for param_func in DescriptorParamDPA3List], (DescriptorParamHybridMixed, DescrptHybrid), (DescriptorParamHybridMixedTTebd, DescrptHybrid), ), # descrpt_class_param & class @@ -821,6 +834,7 @@ def setUpClass(cls) -> None: ( (DescriptorParamDPA1, DescrptDPA1), (DescriptorParamDPA2, DescrptDPA2), + (DescriptorParamDPA3, DescrptDPA3), ), # descrpt_class_param & class ( *[(param_func, EnergyFittingNet) for param_func in FittingParamEnergyList], From fe6a92eb9ab8726aaa3feaf7e026541efafb1353 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Thu, 19 Dec 2024 00:25:15 +0800 Subject: [PATCH 3/5] Update dpa3.py --- deepmd/pt/model/descriptor/dpa3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/pt/model/descriptor/dpa3.py b/deepmd/pt/model/descriptor/dpa3.py index e526e6f82b..c7141e376d 100644 --- a/deepmd/pt/model/descriptor/dpa3.py +++ b/deepmd/pt/model/descriptor/dpa3.py @@ -269,7 +269,7 @@ def share_params(self, base_class, shared_level, resume=False) -> None: # share all parameters in type_embedding, repflow if shared_level == 0: self._modules["type_embedding"] = base_class._modules["type_embedding"] - self.repflows.share_params(base_class.repflow, 0, resume=resume) + self.repflows.share_params(base_class.repflows, 0, resume=resume) # shared_level: 1 # share all parameters in type_embedding elif shared_level == 1: From 20a60c611e0ee6acb95a1adf26b10735288b5adc Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Wed, 11 Dec 2024 22:55:44 +0800 Subject: [PATCH 4/5] add mae --- deepmd/pt/loss/ener.py | 71 ++++++++++++++++++++++------------------ deepmd/utils/argcheck.py | 6 ++++ 2 files changed, 46 insertions(+), 31 deletions(-) diff --git a/deepmd/pt/loss/ener.py b/deepmd/pt/loss/ener.py index 327d75c2cd..b66f4a5b09 100644 --- a/deepmd/pt/loss/ener.py +++ b/deepmd/pt/loss/ener.py @@ -187,28 +187,26 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False): ) # more_loss['log_keys'].append('rmse_e') else: # use l1 and for all atoms + energy_pred = energy_pred * atom_norm + energy_label = energy_label * atom_norm l1_ener_loss = F.l1_loss( energy_pred.reshape(-1), energy_label.reshape(-1), - reduction="sum", + reduction="mean", ) loss += pref_e * l1_ener_loss more_loss["mae_e"] = self.display_if_exist( - F.l1_loss( - energy_pred.reshape(-1), - energy_label.reshape(-1), - reduction="mean", - ).detach(), + l1_ener_loss.detach(), find_energy, ) # more_loss['log_keys'].append('rmse_e') - if mae: - mae_e = torch.mean(torch.abs(energy_pred - energy_label)) * atom_norm - more_loss["mae_e"] = self.display_if_exist(mae_e.detach(), find_energy) - mae_e_all = torch.mean(torch.abs(energy_pred - energy_label)) - more_loss["mae_e_all"] = self.display_if_exist( - mae_e_all.detach(), find_energy - ) + # if mae: + # mae_e = torch.mean(torch.abs(energy_pred - energy_label)) * atom_norm + # more_loss["mae_e"] = self.display_if_exist(mae_e.detach(), find_energy) + # mae_e_all = torch.mean(torch.abs(energy_pred - energy_label)) + # more_loss["mae_e_all"] = self.display_if_exist( + # mae_e_all.detach(), find_energy + # ) if ( (self.has_f or self.has_pf or self.relative_f or self.has_gf) @@ -241,17 +239,17 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False): rmse_f.detach(), find_force ) else: - l1_force_loss = F.l1_loss(force_label, force_pred, reduction="none") + l1_force_loss = F.l1_loss(force_label, force_pred, reduction="mean") more_loss["mae_f"] = self.display_if_exist( - l1_force_loss.mean().detach(), find_force + l1_force_loss.detach(), find_force ) - l1_force_loss = l1_force_loss.sum(-1).mean(-1).sum() + # l1_force_loss = l1_force_loss.sum(-1).mean(-1).sum() loss += (pref_f * l1_force_loss).to(GLOBAL_PT_FLOAT_PRECISION) - if mae: - mae_f = torch.mean(torch.abs(diff_f)) - more_loss["mae_f"] = self.display_if_exist( - mae_f.detach(), find_force - ) + # if mae: + # mae_f = torch.mean(torch.abs(diff_f)) + # more_loss["mae_f"] = self.display_if_exist( + # mae_f.detach(), find_force + # ) if self.has_pf and "atom_pref" in label: atom_pref = label["atom_pref"] @@ -297,18 +295,29 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False): if self.has_v and "virial" in model_pred and "virial" in label: find_virial = label.get("find_virial", 0.0) pref_v = pref_v * find_virial + virial_label = label["virial"] + virial_pred = model_pred["virial"].reshape(-1, 9) diff_v = label["virial"] - model_pred["virial"].reshape(-1, 9) - l2_virial_loss = torch.mean(torch.square(diff_v)) - if not self.inference: - more_loss["l2_virial_loss"] = self.display_if_exist( - l2_virial_loss.detach(), find_virial + if not self.use_l1_all: + l2_virial_loss = torch.mean(torch.square(diff_v)) + if not self.inference: + more_loss["l2_virial_loss"] = self.display_if_exist( + l2_virial_loss.detach(), find_virial + ) + loss += atom_norm * (pref_v * l2_virial_loss) + rmse_v = l2_virial_loss.sqrt() * atom_norm + more_loss["rmse_v"] = self.display_if_exist( + rmse_v.detach(), find_virial + ) + else: + l1_virial_loss = F.l1_loss(virial_label, virial_pred, reduction="mean") + more_loss["mae_v"] = self.display_if_exist( + l1_virial_loss.detach(), find_virial ) - loss += atom_norm * (pref_v * l2_virial_loss) - rmse_v = l2_virial_loss.sqrt() * atom_norm - more_loss["rmse_v"] = self.display_if_exist(rmse_v.detach(), find_virial) - if mae: - mae_v = torch.mean(torch.abs(diff_v)) * atom_norm - more_loss["mae_v"] = self.display_if_exist(mae_v.detach(), find_virial) + loss += (pref_v * l1_virial_loss).to(GLOBAL_PT_FLOAT_PRECISION) + # if mae: + # mae_v = torch.mean(torch.abs(diff_v)) * atom_norm + # more_loss["mae_v"] = self.display_if_exist(mae_v.detach(), find_virial) if self.has_ae and "atom_energy" in model_pred and "atom_ener" in label: atom_ener = model_pred["atom_energy"] diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 237bd1003d..b8c1403f2b 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -2346,6 +2346,12 @@ def loss_ener(): doc_relative_f = "If provided, relative force error will be used in the loss. The difference of force will be normalized by the magnitude of the force in the label with a shift given by `relative_f`, i.e. DF_i / ( || F || + relative_f ) with DF denoting the difference between prediction and label and || F || denoting the L2 norm of the label." doc_enable_atom_ener_coeff = "If true, the energy will be computed as \\sum_i c_i E_i. c_i should be provided by file atom_ener_coeff.npy in each data system, otherwise it's 1." return [ + Argument( + "use_l1_all", + bool, + optional=True, + default=False, + ), Argument( "start_pref_e", [float, int], From 1309e26f134c26dd56942417eaca7a5296e25494 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Mon, 23 Dec 2024 22:18:55 +0800 Subject: [PATCH 5/5] add compress --- deepmd/dpmodel/descriptor/dpa3.py | 7 ++ deepmd/pt/model/descriptor/dpa3.py | 1 + deepmd/pt/model/descriptor/repflow_layer.py | 75 ++++++++++++++++--- deepmd/pt/model/descriptor/repflows.py | 7 ++ deepmd/utils/argcheck.py | 8 ++ source/tests/pt/model/test_dpa3.py | 6 ++ .../dpmodel/descriptor/test_descriptor.py | 3 + 7 files changed, 98 insertions(+), 9 deletions(-) diff --git a/deepmd/dpmodel/descriptor/dpa3.py b/deepmd/dpmodel/descriptor/dpa3.py index df1a2ae258..e1e8632b0e 100644 --- a/deepmd/dpmodel/descriptor/dpa3.py +++ b/deepmd/dpmodel/descriptor/dpa3.py @@ -14,6 +14,7 @@ def __init__( a_rcut: float = 4.0, a_rcut_smth: float = 3.5, a_sel: int = 20, + a_compress_rate: int = 0, axis_neuron: int = 4, update_angle: bool = True, update_style: str = "res_residual", @@ -44,6 +45,10 @@ def __init__( Where to start smoothing for angle. For example the 1/r term is smoothed from rcut to rcut_smth. a_sel : int, optional Maximally possible number of selected angle neighbors. + a_compress_rate : int, optional + The compression rate for angular messages. The default value is 0, indicating no compression. + If a non-zero integer c is provided, the node and edge dimensions will be compressed + to n_dim/c and e_dim/2c, respectively, within the angular message. axis_neuron : int, optional The number of dimension of submatrix in the symmetrization ops. update_angle : bool, optional @@ -71,6 +76,7 @@ def __init__( self.a_rcut = a_rcut self.a_rcut_smth = a_rcut_smth self.a_sel = a_sel + self.a_compress_rate = a_compress_rate self.axis_neuron = axis_neuron self.update_angle = update_angle self.update_style = update_style @@ -95,6 +101,7 @@ def serialize(self) -> dict: "a_rcut": self.a_rcut, "a_rcut_smth": self.a_rcut_smth, "a_sel": self.a_sel, + "a_compress_rate": self.a_compress_rate, "axis_neuron": self.axis_neuron, "update_angle": self.update_angle, "update_style": self.update_style, diff --git a/deepmd/pt/model/descriptor/dpa3.py b/deepmd/pt/model/descriptor/dpa3.py index c7141e376d..c5cfd9cb89 100644 --- a/deepmd/pt/model/descriptor/dpa3.py +++ b/deepmd/pt/model/descriptor/dpa3.py @@ -151,6 +151,7 @@ def init_subclass_params(sub_data, sub_class): n_dim=self.repflow_args.n_dim, e_dim=self.repflow_args.e_dim, a_dim=self.repflow_args.a_dim, + a_compress_rate=self.repflow_args.a_compress_rate, axis_neuron=self.repflow_args.axis_neuron, update_angle=self.repflow_args.update_angle, activation_function=self.activation_function, diff --git a/deepmd/pt/model/descriptor/repflow_layer.py b/deepmd/pt/model/descriptor/repflow_layer.py index bbbdb3e20e..94c4945c76 100644 --- a/deepmd/pt/model/descriptor/repflow_layer.py +++ b/deepmd/pt/model/descriptor/repflow_layer.py @@ -45,6 +45,7 @@ def __init__( n_dim: int = 128, e_dim: int = 16, a_dim: int = 64, + a_compress_rate: int = 0, axis_neuron: int = 4, update_angle: bool = True, # angle activation_function: str = "silu", @@ -70,6 +71,12 @@ def __init__( self.n_dim = n_dim self.e_dim = e_dim self.a_dim = a_dim + self.a_compress_rate = a_compress_rate + if a_compress_rate != 0: + assert a_dim % (2 * a_compress_rate) == 0, ( + f"For a_compress_rate of {a_compress_rate}, a_dim must be divisible by {2 * a_compress_rate}. " + f"Currently, a_dim={a_dim} is not valid." + ) self.axis_neuron = axis_neuron self.update_angle = update_angle self.activation_function = activation_function @@ -167,20 +174,42 @@ def __init__( ) if self.update_angle: - self.angle_dim = self.a_dim + self.n_dim + 2 * self.e_dim + self.angle_dim = self.a_dim + if self.a_compress_rate == 0: + # angle + node + edge * 2 + self.angle_dim += self.n_dim + 2 * self.e_dim + self.a_compress_n_linear = None + self.a_compress_e_linear = None + else: + # angle + node/c + edge/2c * 2 + self.angle_dim += 2 * (self.a_dim // self.a_compress_rate) + self.a_compress_n_linear = MLPLayer( + self.n_dim, + self.a_dim // self.a_compress_rate, + precision=precision, + bias=False, + seed=child_seed(seed, 8), + ) + self.a_compress_e_linear = MLPLayer( + self.e_dim, + self.a_dim // (2 * self.a_compress_rate), + precision=precision, + bias=False, + seed=child_seed(seed, 9), + ) # edge angle message self.edge_angle_linear1 = MLPLayer( self.angle_dim, self.e_dim, precision=precision, - seed=child_seed(seed, 8), + seed=child_seed(seed, 10), ) self.edge_angle_linear2 = MLPLayer( self.e_dim, self.e_dim, precision=precision, - seed=child_seed(seed, 9), + seed=child_seed(seed, 11), ) if self.update_style == "res_residual": self.e_residual.append( @@ -189,7 +218,7 @@ def __init__( self.update_residual, self.update_residual_init, precision=precision, - seed=child_seed(seed, 10), + seed=child_seed(seed, 12), ) ) @@ -198,7 +227,7 @@ def __init__( self.angle_dim, self.a_dim, precision=precision, - seed=child_seed(seed, 11), + seed=child_seed(seed, 13), ) if self.update_style == "res_residual": self.a_residual.append( @@ -207,13 +236,15 @@ def __init__( self.update_residual, self.update_residual_init, precision=precision, - seed=child_seed(seed, 12), + seed=child_seed(seed, 14), ) ) else: self.angle_self_linear = None self.edge_angle_linear1 = None self.edge_angle_linear2 = None + self.a_compress_n_linear = None + self.a_compress_e_linear = None self.angle_dim = 0 self.n_residual = nn.ParameterList(self.n_residual) @@ -448,12 +479,22 @@ def forward( assert self.edge_angle_linear1 is not None assert self.edge_angle_linear2 is not None # get angle info + if self.a_compress_rate != 0: + assert self.a_compress_n_linear is not None + assert self.a_compress_e_linear is not None + node_ebd_for_angle = self.a_compress_n_linear(node_ebd) + edge_ebd_for_angle = self.a_compress_e_linear(edge_ebd) + else: + node_ebd_for_angle = node_ebd + edge_ebd_for_angle = edge_ebd + # nb x nloc x a_nnei x a_nnei x n_dim node_for_angle_info = torch.tile( - node_ebd.unsqueeze(2).unsqueeze(2), (1, 1, self.a_sel, self.a_sel, 1) + node_ebd_for_angle.unsqueeze(2).unsqueeze(2), + (1, 1, self.a_sel, self.a_sel, 1), ) # nb x nloc x a_nnei x e_dim - edge_for_angle = edge_ebd[:, :, : self.a_sel, :] + edge_for_angle = edge_ebd_for_angle[:, :, : self.a_sel, :] # nb x nloc x a_nnei x e_dim edge_for_angle = torch.where( a_nlist_mask.unsqueeze(-1), edge_for_angle, 0.0 @@ -471,7 +512,7 @@ def forward( [edge_for_angle_i, edge_for_angle_j], dim=-1 ) angle_info_list = [angle_ebd, node_for_angle_info, edge_for_angle_info] - # nb x nloc x a_nnei x a_nnei x (a + n_dim + e_dim*2) + # nb x nloc x a_nnei x a_nnei x (a + n_dim + e_dim*2) or (a + a/c + a/c) angle_info = torch.cat(angle_info_list, dim=-1) # edge angle message @@ -605,6 +646,7 @@ def serialize(self) -> dict: "n_dim": self.n_dim, "e_dim": self.e_dim, "a_dim": self.a_dim, + "a_compress_rate": self.a_compress_rate, "axis_neuron": self.axis_neuron, "activation_function": self.activation_function, "update_angle": self.update_angle, @@ -625,6 +667,13 @@ def serialize(self) -> dict: "angle_self_linear": self.angle_self_linear.serialize(), } ) + if self.a_compress_rate != 0: + data.update( + { + "a_compress_n_linear": self.a_compress_n_linear.serialize(), + "a_compress_e_linear": self.a_compress_e_linear.serialize(), + } + ) if self.update_style == "res_residual": data.update( { @@ -650,6 +699,7 @@ def deserialize(cls, data: dict) -> "RepFlowLayer": check_version_compatibility(data.pop("@version"), 1, 1) data.pop("@class") update_angle = data["update_angle"] + a_compress_rate = data["a_compress_rate"] node_self_mlp = data.pop("node_self_mlp") node_sym_linear = data.pop("node_sym_linear") node_edge_linear = data.pop("node_edge_linear") @@ -657,6 +707,8 @@ def deserialize(cls, data: dict) -> "RepFlowLayer": edge_angle_linear1 = data.pop("edge_angle_linear1", None) edge_angle_linear2 = data.pop("edge_angle_linear2", None) angle_self_linear = data.pop("angle_self_linear", None) + a_compress_n_linear = data.pop("a_compress_n_linear", None) + a_compress_e_linear = data.pop("a_compress_e_linear", None) update_style = data["update_style"] variables = data.pop("@variables", {}) n_residual = variables.get("n_residual", data.pop("n_residual", [])) @@ -676,6 +728,11 @@ def deserialize(cls, data: dict) -> "RepFlowLayer": obj.edge_angle_linear1 = MLPLayer.deserialize(edge_angle_linear1) obj.edge_angle_linear2 = MLPLayer.deserialize(edge_angle_linear2) obj.angle_self_linear = MLPLayer.deserialize(angle_self_linear) + if a_compress_rate != 0: + assert isinstance(a_compress_n_linear, dict) + assert isinstance(a_compress_e_linear, dict) + obj.a_compress_n_linear = MLPLayer.deserialize(a_compress_n_linear) + obj.a_compress_e_linear = MLPLayer.deserialize(a_compress_e_linear) if update_style == "res_residual": for ii, t in enumerate(obj.n_residual): diff --git a/deepmd/pt/model/descriptor/repflows.py b/deepmd/pt/model/descriptor/repflows.py index 302f020754..e8ad3a78e0 100644 --- a/deepmd/pt/model/descriptor/repflows.py +++ b/deepmd/pt/model/descriptor/repflows.py @@ -85,6 +85,7 @@ def __init__( n_dim: int = 128, e_dim: int = 64, a_dim: int = 64, + a_compress_rate: int = 0, axis_neuron: int = 4, update_angle: bool = True, activation_function: str = "silu", @@ -122,6 +123,10 @@ def __init__( Where to start smoothing for angle. For example the 1/r term is smoothed from rcut to rcut_smth. a_sel : int, optional Maximally possible number of selected angle neighbors. + a_compress_rate : int, optional + The compression rate for angular messages. The default value is 0, indicating no compression. + If a non-zero integer c is provided, the node and edge dimensions will be compressed + to n_dim/c and e_dim/2c, respectively, within the angular message. axis_neuron : int, optional The number of dimension of submatrix in the symmetrization ops. update_angle : bool, optional @@ -174,6 +179,7 @@ def __init__( self.rcut_smth = e_rcut_smth self.sec = self.sel self.split_sel = self.sel + self.a_compress_rate = a_compress_rate self.axis_neuron = axis_neuron self.set_davg_zero = set_davg_zero @@ -216,6 +222,7 @@ def __init__( n_dim=self.n_dim, e_dim=self.e_dim, a_dim=self.a_dim, + a_compress_rate=self.a_compress_rate, axis_neuron=self.axis_neuron, update_angle=self.update_angle, activation_function=self.activation_function, diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index b8c1403f2b..bf78e195b8 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -1440,6 +1440,11 @@ def dpa3_repflow_args(): doc_a_sel = 'Maximally possible number of selected angle neighbors. It can be:\n\n\ - `int`. The maximum number of neighbor atoms to be considered. We recommend it to be less than 200. \n\n\ - `str`. Can be "auto:factor" or "auto". "factor" is a float number larger than 1. This option will automatically determine the `sel`. In detail it counts the maximal number of neighbors with in the cutoff radius for each type of neighbor, then multiply the maximum by the "factor". Finally the number is wrapped up to 4 divisible. The option "auto" is equivalent to "auto:1.1".' + doc_a_compress_rate = ( + "The compression rate for angular messages. The default value is 0, indicating no compression. " + " If a non-zero integer c is provided, the node and edge dimensions will be compressed " + "to n_dim/c and e_dim/2c, respectively, within the angular message." + ) doc_axis_neuron = "The number of dimension of submatrix in the symmetrization ops." doc_update_angle = ( "Where to update the angle rep. If not, only node and edge rep will be used." @@ -1475,6 +1480,9 @@ def dpa3_repflow_args(): Argument("a_rcut", float, doc=doc_a_rcut), Argument("a_rcut_smth", float, doc=doc_a_rcut_smth), Argument("a_sel", [int, str], doc=doc_a_sel), + Argument( + "a_compress_rate", int, optional=True, default=0, doc=doc_a_compress_rate + ), Argument( "axis_neuron", int, diff --git a/source/tests/pt/model/test_dpa3.py b/source/tests/pt/model/test_dpa3.py index 300ff44bdd..701726a631 100644 --- a/source/tests/pt/model/test_dpa3.py +++ b/source/tests/pt/model/test_dpa3.py @@ -48,12 +48,14 @@ def test_consistency( ua, rus, ruri, + acr, prec, ect, ) in itertools.product( [True, False], # update_angle ["res_residual"], # update_style ["norm", "const"], # update_residual_init + [0, 1], # a_compress_rate ["float64"], # precision [False], # use_econf_tebd ): @@ -73,6 +75,7 @@ def test_consistency( a_rcut=self.rcut - 0.1, a_rcut_smth=self.rcut_smth, a_sel=nnei - 1, + a_compress_rate=acr, axis_neuron=4, update_angle=ua, update_style=rus, @@ -127,12 +130,14 @@ def test_jit( ua, rus, ruri, + acr, prec, ect, ) in itertools.product( [True, False], # update_angle ["res_residual"], # update_style ["norm", "const"], # update_residual_init + [0, 1], # a_compress_rate ["float64"], # precision [False], # use_econf_tebd ): @@ -150,6 +155,7 @@ def test_jit( a_rcut=self.rcut - 0.1, a_rcut_smth=self.rcut_smth, a_sel=nnei - 1, + a_compress_rate=acr, axis_neuron=4, update_angle=ua, update_style=rus, diff --git a/source/tests/universal/dpmodel/descriptor/test_descriptor.py b/source/tests/universal/dpmodel/descriptor/test_descriptor.py index 1d55ee085b..df13a7ff92 100644 --- a/source/tests/universal/dpmodel/descriptor/test_descriptor.py +++ b/source/tests/universal/dpmodel/descriptor/test_descriptor.py @@ -475,6 +475,7 @@ def DescriptorParamDPA3( update_residual=0.1, update_residual_init="const", update_angle=True, + a_compress_rate=0, precision="float64", ): input_dict = { @@ -491,6 +492,7 @@ def DescriptorParamDPA3( "a_rcut": rcut / 2, "a_rcut_smth": rcut_smth / 2, "a_sel": sum(sel) // 4, + "a_compress_rate": a_compress_rate, "axis_neuron": 4, "update_angle": update_angle, "update_style": update_style, @@ -520,6 +522,7 @@ def DescriptorParamDPA3( "update_residual_init": ("const",), "exclude_types": ([], [[0, 1]]), "update_angle": (True, False), + "a_compress_rate": (0, 1), "env_protection": (0.0, 1e-8), "precision": ("float64",), }