diff --git a/deepmd/dpmodel/descriptor/dpa3.py b/deepmd/dpmodel/descriptor/dpa3.py new file mode 100644 index 0000000000..e1e8632b0e --- /dev/null +++ b/deepmd/dpmodel/descriptor/dpa3.py @@ -0,0 +1,114 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + + +class RepFlowArgs: + def __init__( + self, + n_dim: int = 128, + e_dim: int = 64, + a_dim: int = 64, + nlayers: int = 6, + e_rcut: float = 6.0, + e_rcut_smth: float = 5.0, + e_sel: int = 120, + a_rcut: float = 4.0, + a_rcut_smth: float = 3.5, + a_sel: int = 20, + a_compress_rate: int = 0, + axis_neuron: int = 4, + update_angle: bool = True, + update_style: str = "res_residual", + update_residual: float = 0.1, + update_residual_init: str = "const", + ) -> None: + r"""The constructor for the RepFlowArgs class which defines the parameters of the repflow block in DPA3 descriptor. + + Parameters + ---------- + n_dim : int, optional + The dimension of node representation. + e_dim : int, optional + The dimension of edge representation. + a_dim : int, optional + The dimension of angle representation. + nlayers : int, optional + Number of repflow layers. + e_rcut : float, optional + The edge cut-off radius. + e_rcut_smth : float, optional + Where to start smoothing for edge. For example the 1/r term is smoothed from rcut to rcut_smth. + e_sel : int, optional + Maximally possible number of selected edge neighbors. + a_rcut : float, optional + The angle cut-off radius. + a_rcut_smth : float, optional + Where to start smoothing for angle. For example the 1/r term is smoothed from rcut to rcut_smth. + a_sel : int, optional + Maximally possible number of selected angle neighbors. + a_compress_rate : int, optional + The compression rate for angular messages. The default value is 0, indicating no compression. + If a non-zero integer c is provided, the node and edge dimensions will be compressed + to n_dim/c and e_dim/2c, respectively, within the angular message. + axis_neuron : int, optional + The number of dimension of submatrix in the symmetrization ops. + update_angle : bool, optional + Where to update the angle rep. If not, only node and edge rep will be used. + update_style : str, optional + Style to update a representation. + Supported options are: + -'res_avg': Updates a rep `u` with: u = 1/\\sqrt{n+1} (u + u_1 + u_2 + ... + u_n) + -'res_incr': Updates a rep `u` with: u = u + 1/\\sqrt{n} (u_1 + u_2 + ... + u_n) + -'res_residual': Updates a rep `u` with: u = u + (r1*u_1 + r2*u_2 + ... + r3*u_n) + where `r1`, `r2` ... `r3` are residual weights defined by `update_residual` + and `update_residual_init`. + update_residual : float, optional + When update using residual mode, the initial std of residual vector weights. + update_residual_init : str, optional + When update using residual mode, the initialization mode of residual vector weights. + """ + self.n_dim = n_dim + self.e_dim = e_dim + self.a_dim = a_dim + self.nlayers = nlayers + self.e_rcut = e_rcut + self.e_rcut_smth = e_rcut_smth + self.e_sel = e_sel + self.a_rcut = a_rcut + self.a_rcut_smth = a_rcut_smth + self.a_sel = a_sel + self.a_compress_rate = a_compress_rate + self.axis_neuron = axis_neuron + self.update_angle = update_angle + self.update_style = update_style + self.update_residual = update_residual + self.update_residual_init = update_residual_init + + def __getitem__(self, key): + if hasattr(self, key): + return getattr(self, key) + else: + raise KeyError(key) + + def serialize(self) -> dict: + return { + "n_dim": self.n_dim, + "e_dim": self.e_dim, + "a_dim": self.a_dim, + "nlayers": self.nlayers, + "e_rcut": self.e_rcut, + "e_rcut_smth": self.e_rcut_smth, + "e_sel": self.e_sel, + "a_rcut": self.a_rcut, + "a_rcut_smth": self.a_rcut_smth, + "a_sel": self.a_sel, + "a_compress_rate": self.a_compress_rate, + "axis_neuron": self.axis_neuron, + "update_angle": self.update_angle, + "update_style": self.update_style, + "update_residual": self.update_residual, + "update_residual_init": self.update_residual_init, + } + + @classmethod + def deserialize(cls, data: dict) -> "RepFlowArgs": + return cls(**data) diff --git a/deepmd/pt/loss/ener.py b/deepmd/pt/loss/ener.py index 327d75c2cd..b66f4a5b09 100644 --- a/deepmd/pt/loss/ener.py +++ b/deepmd/pt/loss/ener.py @@ -187,28 +187,26 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False): ) # more_loss['log_keys'].append('rmse_e') else: # use l1 and for all atoms + energy_pred = energy_pred * atom_norm + energy_label = energy_label * atom_norm l1_ener_loss = F.l1_loss( energy_pred.reshape(-1), energy_label.reshape(-1), - reduction="sum", + reduction="mean", ) loss += pref_e * l1_ener_loss more_loss["mae_e"] = self.display_if_exist( - F.l1_loss( - energy_pred.reshape(-1), - energy_label.reshape(-1), - reduction="mean", - ).detach(), + l1_ener_loss.detach(), find_energy, ) # more_loss['log_keys'].append('rmse_e') - if mae: - mae_e = torch.mean(torch.abs(energy_pred - energy_label)) * atom_norm - more_loss["mae_e"] = self.display_if_exist(mae_e.detach(), find_energy) - mae_e_all = torch.mean(torch.abs(energy_pred - energy_label)) - more_loss["mae_e_all"] = self.display_if_exist( - mae_e_all.detach(), find_energy - ) + # if mae: + # mae_e = torch.mean(torch.abs(energy_pred - energy_label)) * atom_norm + # more_loss["mae_e"] = self.display_if_exist(mae_e.detach(), find_energy) + # mae_e_all = torch.mean(torch.abs(energy_pred - energy_label)) + # more_loss["mae_e_all"] = self.display_if_exist( + # mae_e_all.detach(), find_energy + # ) if ( (self.has_f or self.has_pf or self.relative_f or self.has_gf) @@ -241,17 +239,17 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False): rmse_f.detach(), find_force ) else: - l1_force_loss = F.l1_loss(force_label, force_pred, reduction="none") + l1_force_loss = F.l1_loss(force_label, force_pred, reduction="mean") more_loss["mae_f"] = self.display_if_exist( - l1_force_loss.mean().detach(), find_force + l1_force_loss.detach(), find_force ) - l1_force_loss = l1_force_loss.sum(-1).mean(-1).sum() + # l1_force_loss = l1_force_loss.sum(-1).mean(-1).sum() loss += (pref_f * l1_force_loss).to(GLOBAL_PT_FLOAT_PRECISION) - if mae: - mae_f = torch.mean(torch.abs(diff_f)) - more_loss["mae_f"] = self.display_if_exist( - mae_f.detach(), find_force - ) + # if mae: + # mae_f = torch.mean(torch.abs(diff_f)) + # more_loss["mae_f"] = self.display_if_exist( + # mae_f.detach(), find_force + # ) if self.has_pf and "atom_pref" in label: atom_pref = label["atom_pref"] @@ -297,18 +295,29 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False): if self.has_v and "virial" in model_pred and "virial" in label: find_virial = label.get("find_virial", 0.0) pref_v = pref_v * find_virial + virial_label = label["virial"] + virial_pred = model_pred["virial"].reshape(-1, 9) diff_v = label["virial"] - model_pred["virial"].reshape(-1, 9) - l2_virial_loss = torch.mean(torch.square(diff_v)) - if not self.inference: - more_loss["l2_virial_loss"] = self.display_if_exist( - l2_virial_loss.detach(), find_virial + if not self.use_l1_all: + l2_virial_loss = torch.mean(torch.square(diff_v)) + if not self.inference: + more_loss["l2_virial_loss"] = self.display_if_exist( + l2_virial_loss.detach(), find_virial + ) + loss += atom_norm * (pref_v * l2_virial_loss) + rmse_v = l2_virial_loss.sqrt() * atom_norm + more_loss["rmse_v"] = self.display_if_exist( + rmse_v.detach(), find_virial + ) + else: + l1_virial_loss = F.l1_loss(virial_label, virial_pred, reduction="mean") + more_loss["mae_v"] = self.display_if_exist( + l1_virial_loss.detach(), find_virial ) - loss += atom_norm * (pref_v * l2_virial_loss) - rmse_v = l2_virial_loss.sqrt() * atom_norm - more_loss["rmse_v"] = self.display_if_exist(rmse_v.detach(), find_virial) - if mae: - mae_v = torch.mean(torch.abs(diff_v)) * atom_norm - more_loss["mae_v"] = self.display_if_exist(mae_v.detach(), find_virial) + loss += (pref_v * l1_virial_loss).to(GLOBAL_PT_FLOAT_PRECISION) + # if mae: + # mae_v = torch.mean(torch.abs(diff_v)) * atom_norm + # more_loss["mae_v"] = self.display_if_exist(mae_v.detach(), find_virial) if self.has_ae and "atom_energy" in model_pred and "atom_ener" in label: atom_ener = model_pred["atom_energy"] diff --git a/deepmd/pt/model/descriptor/__init__.py b/deepmd/pt/model/descriptor/__init__.py index 4a227918fe..9f3468d1db 100644 --- a/deepmd/pt/model/descriptor/__init__.py +++ b/deepmd/pt/model/descriptor/__init__.py @@ -13,6 +13,9 @@ from .dpa2 import ( DescrptDPA2, ) +from .dpa3 import ( + DescrptDPA3, +) from .env_mat import ( prod_env_mat, ) @@ -49,6 +52,7 @@ "DescrptBlockSeTTebd", "DescrptDPA1", "DescrptDPA2", + "DescrptDPA3", "DescrptHybrid", "DescrptSeA", "DescrptSeAttenV2", diff --git a/deepmd/pt/model/descriptor/dpa3.py b/deepmd/pt/model/descriptor/dpa3.py new file mode 100644 index 0000000000..c5cfd9cb89 --- /dev/null +++ b/deepmd/pt/model/descriptor/dpa3.py @@ -0,0 +1,567 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Callable, + Optional, + Union, +) + +import torch + +from deepmd.dpmodel.descriptor.dpa3 import ( + RepFlowArgs, +) +from deepmd.dpmodel.utils import EnvMat as DPEnvMat +from deepmd.dpmodel.utils.seed import ( + child_seed, +) +from deepmd.pt.model.network.mlp import ( + MLPLayer, +) +from deepmd.pt.model.network.network import ( + TypeEmbedNet, + TypeEmbedNetConsistent, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + PRECISION_DICT, +) +from deepmd.pt.utils.update_sel import ( + UpdateSel, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, +) +from deepmd.utils.data_system import ( + DeepmdDataSystem, +) +from deepmd.utils.finetune import ( + get_index_between_two_maps, + map_pair_exclude_types, +) +from deepmd.utils.path import ( + DPPath, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + +from .base_descriptor import ( + BaseDescriptor, +) +from .descriptor import ( + extend_descrpt_stat, +) +from .repflow_layer import ( + RepFlowLayer, +) +from .repflows import ( + DescrptBlockRepflows, +) + + +@BaseDescriptor.register("dpa3") +class DescrptDPA3(BaseDescriptor, torch.nn.Module): + def __init__( + self, + ntypes: int, + # args for repflow + repflow: Union[RepFlowArgs, dict], + # kwargs for descriptor + concat_output_tebd: bool = False, + activation_function: str = "silu", + precision: str = "float64", + exclude_types: list[tuple[int, int]] = [], + env_protection: float = 0.0, + trainable: bool = True, + seed: Optional[Union[int, list[int]]] = None, + use_econf_tebd: bool = False, + use_tebd_bias: bool = False, + type_map: Optional[list[str]] = None, + ) -> None: + r"""The DPA-3 descriptor. + + Parameters + ---------- + repflow : Union[RepFlowArgs, dict] + The arguments used to initialize the repflow block, see docstr in `RepFlowArgs` for details information. + concat_output_tebd : bool, optional + Whether to concat type embedding at the output of the descriptor. + activation_function : str, optional + The activation function in the embedding net. + precision : str, optional + The precision of the embedding net parameters. + exclude_types : list[list[int]], optional + The excluded pairs of types which have no interaction with each other. + For example, `[[0, 1]]` means no interaction between type 0 and type 1. + env_protection : float, optional + Protection parameter to prevent division by zero errors during environment matrix calculations. + For example, when using paddings, there may be zero distances of neighbors, which may make division by zero error during environment matrix calculations without protection. + trainable : bool, optional + If the parameters are trainable. + seed : int, optional + Random seed for parameter initialization. + use_econf_tebd : bool, Optional + Whether to use electronic configuration type embedding. + use_tebd_bias : bool, Optional + Whether to use bias in the type embedding layer. + type_map : list[str], Optional + A list of strings. Give the name to each type of atoms. + + Returns + ------- + descriptor: torch.Tensor + the descriptor of shape nb x nloc x n_dim. + invariant single-atom representation. + g2: torch.Tensor + invariant pair-atom representation. + h2: torch.Tensor + equivariant pair-atom representation. + rot_mat: torch.Tensor + rotation matrix for equivariant fittings + sw: torch.Tensor + The switch function for decaying inverse distance. + + """ + super().__init__() + + def init_subclass_params(sub_data, sub_class): + if isinstance(sub_data, dict): + return sub_class(**sub_data) + elif isinstance(sub_data, sub_class): + return sub_data + else: + raise ValueError( + f"Input args must be a {sub_class.__name__} class or a dict!" + ) + + self.repflow_args = init_subclass_params(repflow, RepFlowArgs) + self.activation_function = activation_function + + self.repflows = DescrptBlockRepflows( + self.repflow_args.e_rcut, + self.repflow_args.e_rcut_smth, + self.repflow_args.e_sel, + self.repflow_args.a_rcut, + self.repflow_args.a_rcut_smth, + self.repflow_args.a_sel, + ntypes, + nlayers=self.repflow_args.nlayers, + n_dim=self.repflow_args.n_dim, + e_dim=self.repflow_args.e_dim, + a_dim=self.repflow_args.a_dim, + a_compress_rate=self.repflow_args.a_compress_rate, + axis_neuron=self.repflow_args.axis_neuron, + update_angle=self.repflow_args.update_angle, + activation_function=self.activation_function, + update_style=self.repflow_args.update_style, + update_residual=self.repflow_args.update_residual, + update_residual_init=self.repflow_args.update_residual_init, + exclude_types=exclude_types, + env_protection=env_protection, + precision=precision, + seed=child_seed(seed, 1), + ) + + self.use_econf_tebd = use_econf_tebd + self.use_tebd_bias = use_tebd_bias + self.type_map = type_map + self.tebd_dim = self.repflow_args.n_dim + self.type_embedding = TypeEmbedNet( + ntypes, + self.tebd_dim, + precision=precision, + seed=child_seed(seed, 2), + use_econf_tebd=self.use_econf_tebd, + use_tebd_bias=use_tebd_bias, + type_map=type_map, + ) + self.concat_output_tebd = concat_output_tebd + self.precision = precision + self.prec = PRECISION_DICT[self.precision] + self.exclude_types = exclude_types + self.env_protection = env_protection + self.trainable = trainable + + assert self.repflows.e_rcut > self.repflows.a_rcut + assert self.repflows.e_sel > self.repflows.a_sel + + self.rcut = self.repflows.get_rcut() + self.rcut_smth = self.repflows.get_rcut_smth() + self.sel = self.repflows.get_sel() + self.ntypes = ntypes + + # set trainable + for param in self.parameters(): + param.requires_grad = trainable + self.compress = False + + def get_rcut(self) -> float: + """Returns the cut-off radius.""" + return self.rcut + + def get_rcut_smth(self) -> float: + """Returns the radius where the neighbor information starts to smoothly decay to 0.""" + return self.rcut_smth + + def get_nsel(self) -> int: + """Returns the number of selected atoms in the cut-off radius.""" + return sum(self.sel) + + def get_sel(self) -> list[int]: + """Returns the number of selected atoms for each type.""" + return self.sel + + def get_ntypes(self) -> int: + """Returns the number of element types.""" + return self.ntypes + + def get_type_map(self) -> list[str]: + """Get the name to each type of atoms.""" + return self.type_map + + def get_dim_out(self) -> int: + """Returns the output dimension of this descriptor.""" + ret = self.repflows.dim_out + if self.concat_output_tebd: + ret += self.tebd_dim + return ret + + def get_dim_emb(self) -> int: + """Returns the embedding dimension of this descriptor.""" + return self.repflows.dim_emb + + def mixed_types(self) -> bool: + """If true, the descriptor + 1. assumes total number of atoms aligned across frames; + 2. requires a neighbor list that does not distinguish different atomic types. + + If false, the descriptor + 1. assumes total number of atoms of each atom type aligned across frames; + 2. requires a neighbor list that distinguishes different atomic types. + + """ + return True + + def has_message_passing(self) -> bool: + """Returns whether the descriptor has message passing.""" + return self.repflows.has_message_passing() + + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the descriptor needs sorted nlist when using `forward_lower`.""" + return True + + def get_env_protection(self) -> float: + """Returns the protection of building environment matrix.""" + return self.repflows.get_env_protection() + + def share_params(self, base_class, shared_level, resume=False) -> None: + """ + Share the parameters of self to the base_class with shared_level during multitask training. + If not start from checkpoint (resume is False), + some separated parameters (e.g. mean and stddev) will be re-calculated across different classes. + """ + assert ( + self.__class__ == base_class.__class__ + ), "Only descriptors of the same type can share params!" + # For DPA3 descriptors, the user-defined share-level + # shared_level: 0 + # share all parameters in type_embedding, repflow + if shared_level == 0: + self._modules["type_embedding"] = base_class._modules["type_embedding"] + self.repflows.share_params(base_class.repflows, 0, resume=resume) + # shared_level: 1 + # share all parameters in type_embedding + elif shared_level == 1: + self._modules["type_embedding"] = base_class._modules["type_embedding"] + # Other shared levels + else: + raise NotImplementedError + + def change_type_map( + self, type_map: list[str], model_with_new_type_stat=None + ) -> None: + """Change the type related params to new ones, according to `type_map` and the original one in the model. + If there are new types in `type_map`, statistics will be updated accordingly to `model_with_new_type_stat` for these new types. + """ + assert ( + self.type_map is not None + ), "'type_map' must be defined when performing type changing!" + remap_index, has_new_type = get_index_between_two_maps(self.type_map, type_map) + self.type_map = type_map + self.type_embedding.change_type_map(type_map=type_map) + self.exclude_types = map_pair_exclude_types(self.exclude_types, remap_index) + self.ntypes = len(type_map) + repflow = self.repflows + if has_new_type: + # the avg and std of new types need to be updated + extend_descrpt_stat( + repflow, + type_map, + des_with_stat=model_with_new_type_stat.repflows + if model_with_new_type_stat is not None + else None, + ) + repflow.ntypes = self.ntypes + repflow.reinit_exclude(self.exclude_types) + repflow["davg"] = repflow["davg"][remap_index] + repflow["dstd"] = repflow["dstd"][remap_index] + + @property + def dim_out(self): + return self.get_dim_out() + + @property + def dim_emb(self): + """Returns the embedding dimension g2.""" + return self.get_dim_emb() + + def compute_input_stats( + self, + merged: Union[Callable[[], list[dict]], list[dict]], + path: Optional[DPPath] = None, + ) -> None: + """ + Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data. + + Parameters + ---------- + merged : Union[Callable[[], list[dict]], list[dict]] + - list[dict]: A list of data samples from various data systems. + Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor` + originating from the `i`-th data system. + - Callable[[], list[dict]]: A lazy function that returns data samples in the above format + only when needed. Since the sampling process can be slow and memory-intensive, + the lazy function helps by only sampling once. + path : Optional[DPPath] + The path to the stat file. + + """ + descrpt_list = [self.repflows] + for ii, descrpt in enumerate(descrpt_list): + descrpt.compute_input_stats(merged, path) + + def set_stat_mean_and_stddev( + self, + mean: list[torch.Tensor], + stddev: list[torch.Tensor], + ) -> None: + """Update mean and stddev for descriptor.""" + descrpt_list = [self.repflows] + for ii, descrpt in enumerate(descrpt_list): + descrpt.mean = mean[ii] + descrpt.stddev = stddev[ii] + + def get_stat_mean_and_stddev(self) -> tuple[list[torch.Tensor], list[torch.Tensor]]: + """Get mean and stddev for descriptor.""" + mean_list = [self.repflows.mean] + stddev_list = [self.repflows.stddev] + return mean_list, stddev_list + + def serialize(self) -> dict: + repflows = self.repflows + data = { + "@class": "Descriptor", + "type": "dpa3", + "@version": 1, + "ntypes": self.ntypes, + "repflow_args": self.repflow_args.serialize(), + "concat_output_tebd": self.concat_output_tebd, + "activation_function": self.activation_function, + "precision": self.precision, + "exclude_types": self.exclude_types, + "env_protection": self.env_protection, + "trainable": self.trainable, + "use_econf_tebd": self.use_econf_tebd, + "use_tebd_bias": self.use_tebd_bias, + "type_map": self.type_map, + "type_embedding": self.type_embedding.embedding.serialize(), + } + repflow_variable = { + "edge_embd": repflows.edge_embd.serialize(), + "angle_embd": repflows.angle_embd.serialize(), + "repflow_layers": [layer.serialize() for layer in repflows.layers], + "env_mat": DPEnvMat(repflows.rcut, repflows.rcut_smth).serialize(), + "@variables": { + "davg": to_numpy_array(repflows["davg"]), + "dstd": to_numpy_array(repflows["dstd"]), + }, + } + data.update( + { + "repflow_variable": repflow_variable, + } + ) + return data + + @classmethod + def deserialize(cls, data: dict) -> "DescrptDPA3": + data = data.copy() + version = data.pop("@version") + check_version_compatibility(version, 1, 1) + data.pop("@class") + data.pop("type") + repflow_variable = data.pop("repflow_variable").copy() + type_embedding = data.pop("type_embedding") + data["repflow"] = RepFlowArgs(**data.pop("repflow_args")) + obj = cls(**data) + obj.type_embedding.embedding = TypeEmbedNetConsistent.deserialize( + type_embedding + ) + + def t_cvt(xx): + return torch.tensor(xx, dtype=obj.repflows.prec, device=env.DEVICE) + + # deserialize repflow + statistic_repflows = repflow_variable.pop("@variables") + env_mat = repflow_variable.pop("env_mat") + repflow_layers = repflow_variable.pop("repflow_layers") + obj.repflows.edge_embd = MLPLayer.deserialize(repflow_variable.pop("edge_embd")) + obj.repflows.angle_embd = MLPLayer.deserialize( + repflow_variable.pop("angle_embd") + ) + obj.repflows["davg"] = t_cvt(statistic_repflows["davg"]) + obj.repflows["dstd"] = t_cvt(statistic_repflows["dstd"]) + obj.repflows.layers = torch.nn.ModuleList( + [RepFlowLayer.deserialize(layer) for layer in repflow_layers] + ) + return obj + + def forward( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: Optional[torch.Tensor] = None, + comm_dict: Optional[dict[str, torch.Tensor]] = None, + ): + """Compute the descriptor. + + Parameters + ---------- + extended_coord + The extended coordinates of atoms. shape: nf x (nallx3) + extended_atype + The extended aotm types. shape: nf x nall + nlist + The neighbor list. shape: nf x nloc x nnei + mapping + The index mapping, mapps extended region index to local region. + comm_dict + The data needed for communication for parallel inference. + + Returns + ------- + node_ebd + The output descriptor. shape: nf x nloc x n_dim (or n_dim + tebd_dim) + rot_mat + The rotationally equivariant and permutationally invariant single particle + representation. shape: nf x nloc x e_dim x 3 + edge_ebd + The edge embedding. + shape: nf x nloc x nnei x e_dim + h2 + The rotationally equivariant pair-partical representation. + shape: nf x nloc x nnei x 3 + sw + The smooth switch function. shape: nf x nloc x nnei + + """ + # cast the input to internal precsion + extended_coord = extended_coord.to(dtype=self.prec) + nframes, nloc, nnei = nlist.shape + nall = extended_coord.view(nframes, -1).shape[1] // 3 + + node_ebd_ext = self.type_embedding(extended_atype) + node_ebd_inp = node_ebd_ext[:, :nloc, :] + # repflows + node_ebd, edge_ebd, h2, rot_mat, sw = self.repflows( + nlist, + extended_coord, + extended_atype, + node_ebd_ext, + mapping, + comm_dict=comm_dict, + ) + if self.concat_output_tebd: + node_ebd = torch.cat([node_ebd, node_ebd_inp], dim=-1) + return ( + node_ebd.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + edge_ebd.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + h2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + ) + + @classmethod + def update_sel( + cls, + train_data: DeepmdDataSystem, + type_map: Optional[list[str]], + local_jdata: dict, + ) -> tuple[dict, Optional[float]]: + """Update the selection and perform neighbor statistics. + + Parameters + ---------- + train_data : DeepmdDataSystem + data used to do neighbor statistics + type_map : list[str], optional + The name of each type of atoms + local_jdata : dict + The local data refer to the current class + + Returns + ------- + dict + The updated local data + float + The minimum distance between two atoms + """ + local_jdata_cpy = local_jdata.copy() + update_sel = UpdateSel() + min_nbor_dist, repflow_e_sel = update_sel.update_one_sel( + train_data, + type_map, + local_jdata_cpy["repflow"]["e_rcut"], + local_jdata_cpy["repflow"]["e_sel"], + True, + ) + local_jdata_cpy["repflow"]["e_sel"] = repflow_e_sel[0] + + min_nbor_dist, repflow_a_sel = update_sel.update_one_sel( + train_data, + type_map, + local_jdata_cpy["repflow"]["a_rcut"], + local_jdata_cpy["repflow"]["a_sel"], + True, + ) + local_jdata_cpy["repflow"]["a_sel"] = repflow_a_sel[0] + + return local_jdata_cpy, min_nbor_dist + + def enable_compression( + self, + min_nbor_dist: float, + table_extrapolate: float = 5, + table_stride_1: float = 0.01, + table_stride_2: float = 0.1, + check_frequency: int = -1, + ) -> None: + """Receive the statistics (distance, max_nbor_size and env_mat_range) of the training data. + + Parameters + ---------- + min_nbor_dist + The nearest distance between atoms + table_extrapolate + The scale of model extrapolation + table_stride_1 + The uniform stride of the first table + table_stride_2 + The uniform stride of the second table + check_frequency + The overflow check frequency + """ + raise NotImplementedError("Compression is unsupported for DPA3.") diff --git a/deepmd/pt/model/descriptor/repflow_layer.py b/deepmd/pt/model/descriptor/repflow_layer.py new file mode 100644 index 0000000000..94c4945c76 --- /dev/null +++ b/deepmd/pt/model/descriptor/repflow_layer.py @@ -0,0 +1,744 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Optional, + Union, +) + +import torch +import torch.nn as nn + +from deepmd.dpmodel.utils.seed import ( + child_seed, +) +from deepmd.pt.model.descriptor.repformer_layer import ( + _apply_nlist_mask, + _apply_switch, + _make_nei_g1, + get_residual, +) +from deepmd.pt.model.network.mlp import ( + MLPLayer, +) +from deepmd.pt.utils.env import ( + PRECISION_DICT, +) +from deepmd.pt.utils.utils import ( + ActivationFn, + to_numpy_array, + to_torch_tensor, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + + +class RepFlowLayer(torch.nn.Module): + def __init__( + self, + e_rcut: float, + e_rcut_smth: float, + e_sel: int, + a_rcut: float, + a_rcut_smth: float, + a_sel: int, + ntypes: int, + n_dim: int = 128, + e_dim: int = 16, + a_dim: int = 64, + a_compress_rate: int = 0, + axis_neuron: int = 4, + update_angle: bool = True, # angle + activation_function: str = "silu", + update_style: str = "res_residual", + update_residual: float = 0.1, + update_residual_init: str = "const", + precision: str = "float64", + seed: Optional[Union[int, list[int]]] = None, + ) -> None: + super().__init__() + self.epsilon = 1e-4 # protection of 1./nnei + self.e_rcut = float(e_rcut) + self.e_rcut_smth = float(e_rcut_smth) + self.ntypes = ntypes + e_sel = [e_sel] if isinstance(e_sel, int) else e_sel + self.nnei = sum(e_sel) + assert len(e_sel) == 1 + self.e_sel = e_sel + self.sec = self.e_sel + self.a_rcut = a_rcut + self.a_rcut_smth = a_rcut_smth + self.a_sel = a_sel + self.n_dim = n_dim + self.e_dim = e_dim + self.a_dim = a_dim + self.a_compress_rate = a_compress_rate + if a_compress_rate != 0: + assert a_dim % (2 * a_compress_rate) == 0, ( + f"For a_compress_rate of {a_compress_rate}, a_dim must be divisible by {2 * a_compress_rate}. " + f"Currently, a_dim={a_dim} is not valid." + ) + self.axis_neuron = axis_neuron + self.update_angle = update_angle + self.activation_function = activation_function + self.act = ActivationFn(activation_function) + self.update_style = update_style + self.update_residual = update_residual + self.update_residual_init = update_residual_init + self.precision = precision + self.seed = seed + self.prec = PRECISION_DICT[precision] + + assert update_residual_init in [ + "norm", + "const", + ], "'update_residual_init' only support 'norm' or 'const'!" + + self.update_residual = update_residual + self.update_residual_init = update_residual_init + self.n_residual = [] + self.e_residual = [] + self.a_residual = [] + self.edge_info_dim = self.n_dim * 2 + self.e_dim + + # node self mlp + self.node_self_mlp = MLPLayer( + n_dim, + n_dim, + precision=precision, + seed=child_seed(seed, 0), + ) + if self.update_style == "res_residual": + self.n_residual.append( + get_residual( + n_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + seed=child_seed(seed, 1), + ) + ) + + # node sym (grrg + drrd) + self.n_sym_dim = n_dim * self.axis_neuron + e_dim * self.axis_neuron + self.node_sym_linear = MLPLayer( + self.n_sym_dim, + n_dim, + precision=precision, + seed=child_seed(seed, 2), + ) + if self.update_style == "res_residual": + self.n_residual.append( + get_residual( + n_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + seed=child_seed(seed, 3), + ) + ) + + # node edge message + self.node_edge_linear = MLPLayer( + self.edge_info_dim, + n_dim, + precision=precision, + seed=child_seed(seed, 4), + ) + if self.update_style == "res_residual": + self.n_residual.append( + get_residual( + n_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + seed=child_seed(seed, 5), + ) + ) + + # edge self message + self.edge_self_linear = MLPLayer( + self.edge_info_dim, + e_dim, + precision=precision, + seed=child_seed(seed, 6), + ) + if self.update_style == "res_residual": + self.e_residual.append( + get_residual( + e_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + seed=child_seed(seed, 7), + ) + ) + + if self.update_angle: + self.angle_dim = self.a_dim + if self.a_compress_rate == 0: + # angle + node + edge * 2 + self.angle_dim += self.n_dim + 2 * self.e_dim + self.a_compress_n_linear = None + self.a_compress_e_linear = None + else: + # angle + node/c + edge/2c * 2 + self.angle_dim += 2 * (self.a_dim // self.a_compress_rate) + self.a_compress_n_linear = MLPLayer( + self.n_dim, + self.a_dim // self.a_compress_rate, + precision=precision, + bias=False, + seed=child_seed(seed, 8), + ) + self.a_compress_e_linear = MLPLayer( + self.e_dim, + self.a_dim // (2 * self.a_compress_rate), + precision=precision, + bias=False, + seed=child_seed(seed, 9), + ) + + # edge angle message + self.edge_angle_linear1 = MLPLayer( + self.angle_dim, + self.e_dim, + precision=precision, + seed=child_seed(seed, 10), + ) + self.edge_angle_linear2 = MLPLayer( + self.e_dim, + self.e_dim, + precision=precision, + seed=child_seed(seed, 11), + ) + if self.update_style == "res_residual": + self.e_residual.append( + get_residual( + self.e_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + seed=child_seed(seed, 12), + ) + ) + + # angle self message + self.angle_self_linear = MLPLayer( + self.angle_dim, + self.a_dim, + precision=precision, + seed=child_seed(seed, 13), + ) + if self.update_style == "res_residual": + self.a_residual.append( + get_residual( + self.a_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + seed=child_seed(seed, 14), + ) + ) + else: + self.angle_self_linear = None + self.edge_angle_linear1 = None + self.edge_angle_linear2 = None + self.a_compress_n_linear = None + self.a_compress_e_linear = None + self.angle_dim = 0 + + self.n_residual = nn.ParameterList(self.n_residual) + self.e_residual = nn.ParameterList(self.e_residual) + self.a_residual = nn.ParameterList(self.a_residual) + + @staticmethod + def _cal_hg( + edge_ebd: torch.Tensor, + h2: torch.Tensor, + nlist_mask: torch.Tensor, + sw: torch.Tensor, + ) -> torch.Tensor: + """ + Calculate the transposed rotation matrix. + + Parameters + ---------- + edge_ebd + Neighbor-wise/Pair-wise edge embeddings, with shape nb x nloc x nnei x e_dim. + h2 + Neighbor-wise/Pair-wise equivariant rep tensors, with shape nb x nloc x nnei x 3. + nlist_mask + Neighbor list mask, where zero means no neighbor, with shape nb x nloc x nnei. + sw + The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, + and remains 0 beyond rcut, with shape nb x nloc x nnei. + + Returns + ------- + hg + The transposed rotation matrix, with shape nb x nloc x 3 x e_dim. + """ + # edge_ebd: nb x nloc x nnei x e_dim + # h2: nb x nloc x nnei x 3 + # msk: nb x nloc x nnei + nb, nloc, nnei, _ = edge_ebd.shape + e_dim = edge_ebd.shape[-1] + # nb x nloc x nnei x e_dim + edge_ebd = _apply_nlist_mask(edge_ebd, nlist_mask) + edge_ebd = _apply_switch(edge_ebd, sw) + invnnei = torch.rsqrt( + float(nnei) + * torch.ones((nb, nloc, 1, 1), dtype=edge_ebd.dtype, device=edge_ebd.device) + ) + # nb x nloc x 3 x e_dim + h2g2 = torch.matmul(torch.transpose(h2, -1, -2), edge_ebd) * invnnei + return h2g2 + + @staticmethod + def _cal_grrg(h2g2: torch.Tensor, axis_neuron: int) -> torch.Tensor: + """ + Calculate the atomic invariant rep. + + Parameters + ---------- + h2g2 + The transposed rotation matrix, with shape nb x nloc x 3 x e_dim. + axis_neuron + Size of the submatrix. + + Returns + ------- + grrg + Atomic invariant rep, with shape nb x nloc x (axis_neuron x e_dim) + """ + # nb x nloc x 3 x e_dim + nb, nloc, _, e_dim = h2g2.shape + # nb x nloc x 3 x axis + h2g2m = torch.split(h2g2, axis_neuron, dim=-1)[0] + # nb x nloc x axis x e_dim + g1_13 = torch.matmul(torch.transpose(h2g2m, -1, -2), h2g2) / (3.0**1) + # nb x nloc x (axisxng2) + g1_13 = g1_13.view(nb, nloc, axis_neuron * e_dim) + return g1_13 + + def symmetrization_op( + self, + edge_ebd: torch.Tensor, + h2: torch.Tensor, + nlist_mask: torch.Tensor, + sw: torch.Tensor, + axis_neuron: int, + ) -> torch.Tensor: + """ + Symmetrization operator to obtain atomic invariant rep. + + Parameters + ---------- + edge_ebd + Neighbor-wise/Pair-wise invariant rep tensors, with shape nb x nloc x nnei x e_dim. + h2 + Neighbor-wise/Pair-wise equivariant rep tensors, with shape nb x nloc x nnei x 3. + nlist_mask + Neighbor list mask, where zero means no neighbor, with shape nb x nloc x nnei. + sw + The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, + and remains 0 beyond rcut, with shape nb x nloc x nnei. + axis_neuron + Size of the submatrix. + + Returns + ------- + grrg + Atomic invariant rep, with shape nb x nloc x (axis_neuron x e_dim) + """ + # edge_ebd: nb x nloc x nnei x e_dim + # h2: nb x nloc x nnei x 3 + # msk: nb x nloc x nnei + nb, nloc, nnei, _ = edge_ebd.shape + # nb x nloc x 3 x e_dim + h2g2 = self._cal_hg( + edge_ebd, + h2, + nlist_mask, + sw, + ) + # nb x nloc x (axisxng2) + g1_13 = self._cal_grrg(h2g2, axis_neuron) + return g1_13 + + def forward( + self, + node_ebd_ext: torch.Tensor, # nf x nall x n_dim + edge_ebd: torch.Tensor, # nf x nloc x nnei x e_dim + h2: torch.Tensor, # nf x nloc x nnei x 3 + angle_ebd: torch.Tensor, # nf x nloc x a_nnei x a_nnei x a_dim + nlist: torch.Tensor, # nf x nloc x nnei + nlist_mask: torch.Tensor, # nf x nloc x nnei + sw: torch.Tensor, # switch func, nf x nloc x nnei + a_nlist: torch.Tensor, # nf x nloc x a_nnei + a_nlist_mask: torch.Tensor, # nf x nloc x a_nnei + a_sw: torch.Tensor, # switch func, nf x nloc x a_nnei + ): + """ + Parameters + ---------- + node_ebd_ext : nf x nall x n_dim + Extended node embedding. + edge_ebd : nf x nloc x nnei x e_dim + Edge embedding. + h2 : nf x nloc x nnei x 3 + Pair-atom channel, equivariant. + angle_ebd : nf x nloc x a_nnei x a_nnei x a_dim + Angle embedding. + nlist : nf x nloc x nnei + Neighbor list. (padded neis are set to 0) + nlist_mask : nf x nloc x nnei + Masks of the neighbor list. real nei 1 otherwise 0 + sw : nf x nloc x nnei + Switch function. + a_nlist : nf x nloc x a_nnei + Neighbor list for angle. (padded neis are set to 0) + a_nlist_mask : nf x nloc x a_nnei + Masks of the neighbor list for angle. real nei 1 otherwise 0 + a_sw : nf x nloc x a_nnei + Switch function for angle. + + Returns + ------- + n_updated: nf x nloc x n_dim + Updated node embedding. + e_updated: nf x nloc x nnei x e_dim + Updated edge embedding. + a_updated : nf x nloc x a_nnei x a_nnei x a_dim + Updated angle embedding. + """ + nb, nloc, nnei, _ = edge_ebd.shape + nall = node_ebd_ext.shape[1] + node_ebd, _ = torch.split(node_ebd_ext, [nloc, nall - nloc], dim=1) + assert (nb, nloc) == node_ebd.shape[:2] + assert (nb, nloc, nnei) == h2.shape[:3] + del a_nlist # may be used in the future + + n_update_list: list[torch.Tensor] = [node_ebd] + e_update_list: list[torch.Tensor] = [edge_ebd] + a_update_list: list[torch.Tensor] = [angle_ebd] + + # node self mlp + node_self_mlp = self.act(self.node_self_mlp(node_ebd)) + n_update_list.append(node_self_mlp) + + nei_node_ebd = _make_nei_g1(node_ebd_ext, nlist) + + # node sym (grrg + drrd) + node_sym_list: list[torch.Tensor] = [] + node_sym_list.append( + self.symmetrization_op( + edge_ebd, + h2, + nlist_mask, + sw, + self.axis_neuron, + ) + ) + node_sym_list.append( + self.symmetrization_op( + nei_node_ebd, + h2, + nlist_mask, + sw, + self.axis_neuron, + ) + ) + node_sym = self.act(self.node_sym_linear(torch.cat(node_sym_list, dim=-1))) + n_update_list.append(node_sym) + + # nb x nloc x nnei x (n_dim * 2 + e_dim) + edge_info = torch.cat( + [ + torch.tile(node_ebd.unsqueeze(-2), [1, 1, self.nnei, 1]), + nei_node_ebd, + edge_ebd, + ], + dim=-1, + ) + + # node edge message + # nb x nloc x nnei x n_dim + node_edge_update = self.act(self.node_edge_linear(edge_info)) * sw.unsqueeze(-1) + node_edge_update = torch.sum(node_edge_update, dim=-2) / self.nnei + n_update_list.append(node_edge_update) + # update node_ebd + n_updated = self.list_update(n_update_list, "node") + + # edge self message + edge_self_update = self.act(self.edge_self_linear(edge_info)) + e_update_list.append(edge_self_update) + + if self.update_angle: + assert self.angle_self_linear is not None + assert self.edge_angle_linear1 is not None + assert self.edge_angle_linear2 is not None + # get angle info + if self.a_compress_rate != 0: + assert self.a_compress_n_linear is not None + assert self.a_compress_e_linear is not None + node_ebd_for_angle = self.a_compress_n_linear(node_ebd) + edge_ebd_for_angle = self.a_compress_e_linear(edge_ebd) + else: + node_ebd_for_angle = node_ebd + edge_ebd_for_angle = edge_ebd + + # nb x nloc x a_nnei x a_nnei x n_dim + node_for_angle_info = torch.tile( + node_ebd_for_angle.unsqueeze(2).unsqueeze(2), + (1, 1, self.a_sel, self.a_sel, 1), + ) + # nb x nloc x a_nnei x e_dim + edge_for_angle = edge_ebd_for_angle[:, :, : self.a_sel, :] + # nb x nloc x a_nnei x e_dim + edge_for_angle = torch.where( + a_nlist_mask.unsqueeze(-1), edge_for_angle, 0.0 + ) + # nb x nloc x (a_nnei) x a_nnei x edge_ebd + edge_for_angle_i = torch.tile( + edge_for_angle.unsqueeze(2), (1, 1, self.a_sel, 1, 1) + ) + # nb x nloc x a_nnei x (a_nnei) x e_dim + edge_for_angle_j = torch.tile( + edge_for_angle.unsqueeze(3), (1, 1, 1, self.a_sel, 1) + ) + # nb x nloc x a_nnei x a_nnei x (e_dim + e_dim) + edge_for_angle_info = torch.cat( + [edge_for_angle_i, edge_for_angle_j], dim=-1 + ) + angle_info_list = [angle_ebd, node_for_angle_info, edge_for_angle_info] + # nb x nloc x a_nnei x a_nnei x (a + n_dim + e_dim*2) or (a + a/c + a/c) + angle_info = torch.cat(angle_info_list, dim=-1) + + # edge angle message + # nb x nloc x a_nnei x a_nnei x e_dim + edge_angle_update = self.act(self.edge_angle_linear1(angle_info)) + # nb x nloc x a_nnei x a_nnei x e_dim + weighted_edge_angle_update = ( + edge_angle_update + * a_sw[:, :, :, None, None] + * a_sw[:, :, None, :, None] + ) + # nb x nloc x a_nnei x e_dim + reduced_edge_angle_update = torch.sum( + weighted_edge_angle_update, dim=-2 + ) / (self.a_sel**0.5) + # nb x nloc x nnei x e_dim + padding_edge_angle_update = torch.concat( + [ + reduced_edge_angle_update, + torch.zeros( + [nb, nloc, self.nnei - self.a_sel, self.e_dim], + dtype=edge_ebd.dtype, + device=edge_ebd.device, + ), + ], + dim=2, + ) + full_mask = torch.concat( + [ + a_nlist_mask, + torch.zeros( + [nb, nloc, self.nnei - self.a_sel], + dtype=a_nlist_mask.dtype, + device=a_nlist_mask.device, + ), + ], + dim=-1, + ) + padding_edge_angle_update = torch.where( + full_mask.unsqueeze(-1), padding_edge_angle_update, edge_ebd + ) + e_update_list.append( + self.act(self.edge_angle_linear2(padding_edge_angle_update)) + ) + # update edge_ebd + e_updated = self.list_update(e_update_list, "edge") + + # angle self message + # nb x nloc x a_nnei x a_nnei x dim_a + angle_self_update = self.act(self.angle_self_linear(angle_info)) + a_update_list.append(angle_self_update) + else: + # update edge_ebd + e_updated = self.list_update(e_update_list, "edge") + + # update angle_ebd + a_updated = self.list_update(a_update_list, "angle") + return n_updated, e_updated, a_updated + + @torch.jit.export + def list_update_res_avg( + self, + update_list: list[torch.Tensor], + ) -> torch.Tensor: + nitem = len(update_list) + uu = update_list[0] + for ii in range(1, nitem): + uu = uu + update_list[ii] + return uu / (float(nitem) ** 0.5) + + @torch.jit.export + def list_update_res_incr(self, update_list: list[torch.Tensor]) -> torch.Tensor: + nitem = len(update_list) + uu = update_list[0] + scale = 1.0 / (float(nitem - 1) ** 0.5) if nitem > 1 else 0.0 + for ii in range(1, nitem): + uu = uu + scale * update_list[ii] + return uu + + @torch.jit.export + def list_update_res_residual( + self, update_list: list[torch.Tensor], update_name: str = "node" + ) -> torch.Tensor: + nitem = len(update_list) + uu = update_list[0] + # make jit happy + if update_name == "node": + for ii, vv in enumerate(self.n_residual): + uu = uu + vv * update_list[ii + 1] + elif update_name == "edge": + for ii, vv in enumerate(self.e_residual): + uu = uu + vv * update_list[ii + 1] + elif update_name == "angle": + for ii, vv in enumerate(self.a_residual): + uu = uu + vv * update_list[ii + 1] + else: + raise NotImplementedError + return uu + + @torch.jit.export + def list_update( + self, update_list: list[torch.Tensor], update_name: str = "node" + ) -> torch.Tensor: + if self.update_style == "res_avg": + return self.list_update_res_avg(update_list) + elif self.update_style == "res_incr": + return self.list_update_res_incr(update_list) + elif self.update_style == "res_residual": + return self.list_update_res_residual(update_list, update_name=update_name) + else: + raise RuntimeError(f"unknown update style {self.update_style}") + + def serialize(self) -> dict: + """Serialize the networks to a dict. + + Returns + ------- + dict + The serialized networks. + """ + data = { + "@class": "RepformerLayer", + "@version": 1, + "e_rcut": self.e_rcut, + "e_rcut_smth": self.e_rcut_smth, + "e_sel": self.e_sel, + "a_rcut": self.a_rcut, + "a_rcut_smth": self.a_rcut_smth, + "a_sel": self.a_sel, + "ntypes": self.ntypes, + "n_dim": self.n_dim, + "e_dim": self.e_dim, + "a_dim": self.a_dim, + "a_compress_rate": self.a_compress_rate, + "axis_neuron": self.axis_neuron, + "activation_function": self.activation_function, + "update_angle": self.update_angle, + "update_style": self.update_style, + "update_residual": self.update_residual, + "update_residual_init": self.update_residual_init, + "precision": self.precision, + "node_self_mlp": self.node_self_mlp.serialize(), + "node_sym_linear": self.node_sym_linear.serialize(), + "node_edge_linear": self.node_edge_linear.serialize(), + "edge_self_linear": self.edge_self_linear.serialize(), + } + if self.update_angle: + data.update( + { + "edge_angle_linear1": self.edge_angle_linear1.serialize(), + "edge_angle_linear2": self.edge_angle_linear2.serialize(), + "angle_self_linear": self.angle_self_linear.serialize(), + } + ) + if self.a_compress_rate != 0: + data.update( + { + "a_compress_n_linear": self.a_compress_n_linear.serialize(), + "a_compress_e_linear": self.a_compress_e_linear.serialize(), + } + ) + if self.update_style == "res_residual": + data.update( + { + "@variables": { + "n_residual": [to_numpy_array(t) for t in self.n_residual], + "e_residual": [to_numpy_array(t) for t in self.e_residual], + "a_residual": [to_numpy_array(t) for t in self.a_residual], + } + } + ) + return data + + @classmethod + def deserialize(cls, data: dict) -> "RepFlowLayer": + """Deserialize the networks from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + update_angle = data["update_angle"] + a_compress_rate = data["a_compress_rate"] + node_self_mlp = data.pop("node_self_mlp") + node_sym_linear = data.pop("node_sym_linear") + node_edge_linear = data.pop("node_edge_linear") + edge_self_linear = data.pop("edge_self_linear") + edge_angle_linear1 = data.pop("edge_angle_linear1", None) + edge_angle_linear2 = data.pop("edge_angle_linear2", None) + angle_self_linear = data.pop("angle_self_linear", None) + a_compress_n_linear = data.pop("a_compress_n_linear", None) + a_compress_e_linear = data.pop("a_compress_e_linear", None) + update_style = data["update_style"] + variables = data.pop("@variables", {}) + n_residual = variables.get("n_residual", data.pop("n_residual", [])) + e_residual = variables.get("e_residual", data.pop("e_residual", [])) + a_residual = variables.get("a_residual", data.pop("a_residual", [])) + + obj = cls(**data) + obj.node_self_mlp = MLPLayer.deserialize(node_self_mlp) + obj.node_sym_linear = MLPLayer.deserialize(node_sym_linear) + obj.node_edge_linear = MLPLayer.deserialize(node_edge_linear) + obj.edge_self_linear = MLPLayer.deserialize(edge_self_linear) + + if update_angle: + assert isinstance(edge_angle_linear1, dict) + assert isinstance(edge_angle_linear2, dict) + assert isinstance(angle_self_linear, dict) + obj.edge_angle_linear1 = MLPLayer.deserialize(edge_angle_linear1) + obj.edge_angle_linear2 = MLPLayer.deserialize(edge_angle_linear2) + obj.angle_self_linear = MLPLayer.deserialize(angle_self_linear) + if a_compress_rate != 0: + assert isinstance(a_compress_n_linear, dict) + assert isinstance(a_compress_e_linear, dict) + obj.a_compress_n_linear = MLPLayer.deserialize(a_compress_n_linear) + obj.a_compress_e_linear = MLPLayer.deserialize(a_compress_e_linear) + + if update_style == "res_residual": + for ii, t in enumerate(obj.n_residual): + t.data = to_torch_tensor(n_residual[ii]) + for ii, t in enumerate(obj.e_residual): + t.data = to_torch_tensor(e_residual[ii]) + for ii, t in enumerate(obj.a_residual): + t.data = to_torch_tensor(a_residual[ii]) + return obj diff --git a/deepmd/pt/model/descriptor/repflows.py b/deepmd/pt/model/descriptor/repflows.py new file mode 100644 index 0000000000..e8ad3a78e0 --- /dev/null +++ b/deepmd/pt/model/descriptor/repflows.py @@ -0,0 +1,567 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Callable, + Optional, + Union, +) + +import torch + +from deepmd.dpmodel.utils.seed import ( + child_seed, +) +from deepmd.pt.model.descriptor.descriptor import ( + DescriptorBlock, +) +from deepmd.pt.model.descriptor.env_mat import ( + prod_env_mat, +) +from deepmd.pt.model.network.mlp import ( + MLPLayer, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + PRECISION_DICT, +) +from deepmd.pt.utils.env_mat_stat import ( + EnvMatStatSe, +) +from deepmd.pt.utils.exclude_mask import ( + PairExcludeMask, +) +from deepmd.pt.utils.spin import ( + concat_switch_virtual, +) +from deepmd.pt.utils.utils import ( + ActivationFn, +) +from deepmd.utils.env_mat_stat import ( + StatItem, +) +from deepmd.utils.path import ( + DPPath, +) + +from .repflow_layer import ( + RepFlowLayer, +) + +if not hasattr(torch.ops.deepmd, "border_op"): + + def border_op( + argument0, + argument1, + argument2, + argument3, + argument4, + argument5, + argument6, + argument7, + argument8, + ) -> torch.Tensor: + raise NotImplementedError( + "border_op is not available since customized PyTorch OP library is not built when freezing the model. " + "See documentation for DPA-3 for details." + ) + + # Note: this hack cannot actually save a model that can be run using LAMMPS. + torch.ops.deepmd.border_op = border_op + + +@DescriptorBlock.register("se_repflow") +class DescrptBlockRepflows(DescriptorBlock): + def __init__( + self, + e_rcut, + e_rcut_smth, + e_sel: int, + a_rcut, + a_rcut_smth, + a_sel: int, + ntypes: int, + nlayers: int = 6, + n_dim: int = 128, + e_dim: int = 64, + a_dim: int = 64, + a_compress_rate: int = 0, + axis_neuron: int = 4, + update_angle: bool = True, + activation_function: str = "silu", + update_style: str = "res_residual", + update_residual: float = 0.1, + update_residual_init: str = "const", + set_davg_zero: bool = True, + exclude_types: list[tuple[int, int]] = [], + env_protection: float = 0.0, + precision: str = "float64", + seed: Optional[Union[int, list[int]]] = None, + ) -> None: + r""" + The repflow descriptor block. + + Parameters + ---------- + n_dim : int, optional + The dimension of node representation. + e_dim : int, optional + The dimension of edge representation. + a_dim : int, optional + The dimension of angle representation. + nlayers : int, optional + Number of repflow layers. + e_rcut : float, optional + The edge cut-off radius. + e_rcut_smth : float, optional + Where to start smoothing for edge. For example the 1/r term is smoothed from rcut to rcut_smth. + e_sel : int, optional + Maximally possible number of selected edge neighbors. + a_rcut : float, optional + The angle cut-off radius. + a_rcut_smth : float, optional + Where to start smoothing for angle. For example the 1/r term is smoothed from rcut to rcut_smth. + a_sel : int, optional + Maximally possible number of selected angle neighbors. + a_compress_rate : int, optional + The compression rate for angular messages. The default value is 0, indicating no compression. + If a non-zero integer c is provided, the node and edge dimensions will be compressed + to n_dim/c and e_dim/2c, respectively, within the angular message. + axis_neuron : int, optional + The number of dimension of submatrix in the symmetrization ops. + update_angle : bool, optional + Where to update the angle rep. If not, only node and edge rep will be used. + update_style : str, optional + Style to update a representation. + Supported options are: + -'res_avg': Updates a rep `u` with: u = 1/\\sqrt{n+1} (u + u_1 + u_2 + ... + u_n) + -'res_incr': Updates a rep `u` with: u = u + 1/\\sqrt{n} (u_1 + u_2 + ... + u_n) + -'res_residual': Updates a rep `u` with: u = u + (r1*u_1 + r2*u_2 + ... + r3*u_n) + where `r1`, `r2` ... `r3` are residual weights defined by `update_residual` + and `update_residual_init`. + update_residual : float, optional + When update using residual mode, the initial std of residual vector weights. + update_residual_init : str, optional + When update using residual mode, the initialization mode of residual vector weights. + ntypes : int + Number of element types + activation_function : str, optional + The activation function in the embedding net. + set_davg_zero : bool, optional + Set the normalization average to zero. + precision : str, optional + The precision of the embedding net parameters. + exclude_types : list[list[int]], optional + The excluded pairs of types which have no interaction with each other. + For example, `[[0, 1]]` means no interaction between type 0 and type 1. + env_protection : float, optional + Protection parameter to prevent division by zero errors during environment matrix calculations. + For example, when using paddings, there may be zero distances of neighbors, which may make division by zero error during environment matrix calculations without protection. + seed : int, optional + Random seed for parameter initialization. + """ + super().__init__() + self.e_rcut = float(e_rcut) + self.e_rcut_smth = float(e_rcut_smth) + self.e_sel = e_sel + self.a_rcut = float(a_rcut) + self.a_rcut_smth = float(a_rcut_smth) + self.a_sel = a_sel + self.ntypes = ntypes + self.nlayers = nlayers + # for other common desciptor method + sel = [e_sel] if isinstance(e_sel, int) else e_sel + self.nnei = sum(sel) + self.ndescrpt = self.nnei * 4 # use full descriptor. + assert len(sel) == 1 + self.sel = sel + self.rcut = e_rcut + self.rcut_smth = e_rcut_smth + self.sec = self.sel + self.split_sel = self.sel + self.a_compress_rate = a_compress_rate + self.axis_neuron = axis_neuron + self.set_davg_zero = set_davg_zero + + self.n_dim = n_dim + self.e_dim = e_dim + self.a_dim = a_dim + self.update_angle = update_angle + + self.activation_function = activation_function + self.update_style = update_style + self.update_residual = update_residual + self.update_residual_init = update_residual_init + self.act = ActivationFn(activation_function) + self.prec = PRECISION_DICT[precision] + + # order matters, placed after the assignment of self.ntypes + self.reinit_exclude(exclude_types) + self.env_protection = env_protection + self.precision = precision + self.epsilon = 1e-4 + self.seed = seed + + self.edge_embd = MLPLayer( + 1, self.e_dim, precision=precision, seed=child_seed(seed, 0) + ) + self.angle_embd = MLPLayer( + 1, self.a_dim, precision=precision, bias=False, seed=child_seed(seed, 1) + ) + layers = [] + for ii in range(nlayers): + layers.append( + RepFlowLayer( + e_rcut=self.e_rcut, + e_rcut_smth=self.e_rcut_smth, + e_sel=self.sel, + a_rcut=self.a_rcut, + a_rcut_smth=self.a_rcut_smth, + a_sel=self.a_sel, + ntypes=self.ntypes, + n_dim=self.n_dim, + e_dim=self.e_dim, + a_dim=self.a_dim, + a_compress_rate=self.a_compress_rate, + axis_neuron=self.axis_neuron, + update_angle=self.update_angle, + activation_function=self.activation_function, + update_style=self.update_style, + update_residual=self.update_residual, + update_residual_init=self.update_residual_init, + precision=precision, + seed=child_seed(child_seed(seed, 1), ii), + ) + ) + self.layers = torch.nn.ModuleList(layers) + + wanted_shape = (self.ntypes, self.nnei, 4) + mean = torch.zeros(wanted_shape, dtype=self.prec, device=env.DEVICE) + stddev = torch.ones(wanted_shape, dtype=self.prec, device=env.DEVICE) + self.register_buffer("mean", mean) + self.register_buffer("stddev", stddev) + self.stats = None + + def get_rcut(self) -> float: + """Returns the cut-off radius.""" + return self.e_rcut + + def get_rcut_smth(self) -> float: + """Returns the radius where the neighbor information starts to smoothly decay to 0.""" + return self.e_rcut_smth + + def get_nsel(self) -> int: + """Returns the number of selected atoms in the cut-off radius.""" + return sum(self.sel) + + def get_sel(self) -> list[int]: + """Returns the number of selected atoms for each type.""" + return self.sel + + def get_ntypes(self) -> int: + """Returns the number of element types.""" + return self.ntypes + + def get_dim_out(self) -> int: + """Returns the output dimension.""" + return self.dim_out + + def get_dim_in(self) -> int: + """Returns the input dimension.""" + return self.dim_in + + def get_dim_emb(self) -> int: + """Returns the embedding dimension e_dim.""" + return self.e_dim + + def __setitem__(self, key, value) -> None: + if key in ("avg", "data_avg", "davg"): + self.mean = value + elif key in ("std", "data_std", "dstd"): + self.stddev = value + else: + raise KeyError(key) + + def __getitem__(self, key): + if key in ("avg", "data_avg", "davg"): + return self.mean + elif key in ("std", "data_std", "dstd"): + return self.stddev + else: + raise KeyError(key) + + def mixed_types(self) -> bool: + """If true, the descriptor + 1. assumes total number of atoms aligned across frames; + 2. requires a neighbor list that does not distinguish different atomic types. + + If false, the descriptor + 1. assumes total number of atoms of each atom type aligned across frames; + 2. requires a neighbor list that distinguishes different atomic types. + + """ + return True + + def get_env_protection(self) -> float: + """Returns the protection of building environment matrix.""" + return self.env_protection + + @property + def dim_out(self): + """Returns the output dimension of this descriptor.""" + return self.n_dim + + @property + def dim_in(self): + """Returns the atomic input dimension of this descriptor.""" + return self.n_dim + + @property + def dim_emb(self): + """Returns the embedding dimension e_dim.""" + return self.get_dim_emb() + + def reinit_exclude( + self, + exclude_types: list[tuple[int, int]] = [], + ) -> None: + self.exclude_types = exclude_types + self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types) + + def forward( + self, + nlist: torch.Tensor, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + extended_atype_embd: Optional[torch.Tensor] = None, + mapping: Optional[torch.Tensor] = None, + comm_dict: Optional[dict[str, torch.Tensor]] = None, + ): + if comm_dict is None: + assert mapping is not None + assert extended_atype_embd is not None + nframes, nloc, nnei = nlist.shape + nall = extended_coord.view(nframes, -1).shape[1] // 3 + atype = extended_atype[:, :nloc] + # nb x nloc x nnei + exclude_mask = self.emask(nlist, extended_atype) + nlist = torch.where(exclude_mask != 0, nlist, -1) + # nb x nloc x nnei x 4, nb x nloc x nnei x 3, nb x nloc x nnei x 1 + dmatrix, diff, sw = prod_env_mat( + extended_coord, + nlist, + atype, + self.mean, + self.stddev, + self.e_rcut, + self.e_rcut_smth, + protection=self.env_protection, + ) + nlist_mask = nlist != -1 + sw = torch.squeeze(sw, -1) + # beyond the cutoff sw should be 0.0 + sw = sw.masked_fill(~nlist_mask, 0.0) + + # [nframes, nloc, tebd_dim] + if comm_dict is None: + assert isinstance(extended_atype_embd, torch.Tensor) # for jit + atype_embd = extended_atype_embd[:, :nloc, :] + assert list(atype_embd.shape) == [nframes, nloc, self.n_dim] + else: + atype_embd = extended_atype_embd + assert isinstance(atype_embd, torch.Tensor) # for jit + node_ebd = self.act(atype_embd) + n_dim = node_ebd.shape[-1] + # nb x nloc x nnei x 1, nb x nloc x nnei x 3 + edge_input, h2 = torch.split(dmatrix, [1, 3], dim=-1) + # nb x nloc x nnei x e_dim + edge_ebd = self.act(self.edge_embd(edge_input)) + + # get angle nlist (maybe smaller) + a_dist_mask = (torch.linalg.norm(diff, dim=-1) < self.a_rcut)[ + :, :, : self.a_sel + ] + a_nlist = nlist[:, :, : self.a_sel] + a_nlist = torch.where(a_dist_mask, a_nlist, -1) + _, a_diff, a_sw = prod_env_mat( + extended_coord, + a_nlist, + atype, + self.mean[:, : self.a_sel], + self.stddev[:, : self.a_sel], + self.a_rcut, + self.a_rcut_smth, + protection=self.env_protection, + ) + a_nlist_mask = a_nlist != -1 + a_sw = torch.squeeze(a_sw, -1) + # beyond the cutoff sw should be 0.0 + a_sw = a_sw.masked_fill(~a_nlist_mask, 0.0) + a_nlist[a_nlist == -1] = 0 + + # nf x nloc x a_nnei x 3 + normalized_diff_i = a_diff / ( + torch.linalg.norm(a_diff, dim=-1, keepdim=True) + 1e-6 + ) + # nf x nloc x 3 x a_nnei + normalized_diff_j = torch.transpose(normalized_diff_i, 2, 3) + # nf x nloc x a_nnei x a_nnei + # 1 - 1e-6 for torch.acos stability + cosine_ij = torch.matmul(normalized_diff_i, normalized_diff_j) * (1 - 1e-6) + # nf x nloc x a_nnei x a_nnei x 1 + cosine_ij = cosine_ij.unsqueeze(-1) / (torch.pi**0.5) + # nf x nloc x a_nnei x a_nnei x a_dim + angle_ebd = self.angle_embd(cosine_ij).reshape( + nframes, nloc, self.a_sel, self.a_sel, self.a_dim + ) + + # set all padding positions to index of 0 + # if the a neighbor is real or not is indicated by nlist_mask + nlist[nlist == -1] = 0 + # nb x nall x n_dim + if comm_dict is None: + assert mapping is not None + mapping = ( + mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, self.n_dim) + ) + for idx, ll in enumerate(self.layers): + # node_ebd: nb x nloc x n_dim + # node_ebd_ext: nb x nall x n_dim + if comm_dict is None: + assert mapping is not None + node_ebd_ext = torch.gather(node_ebd, 1, mapping) + else: + has_spin = "has_spin" in comm_dict + if not has_spin: + n_padding = nall - nloc + node_ebd = torch.nn.functional.pad( + node_ebd.squeeze(0), (0, 0, 0, n_padding), value=0.0 + ) + real_nloc = nloc + real_nall = nall + else: + # for spin + real_nloc = nloc // 2 + real_nall = nall // 2 + real_n_padding = real_nall - real_nloc + node_ebd_real, node_ebd_virtual = torch.split( + node_ebd, [real_nloc, real_nloc], dim=1 + ) + # mix_node_ebd: nb x real_nloc x (n_dim * 2) + mix_node_ebd = torch.cat([node_ebd_real, node_ebd_virtual], dim=2) + # nb x real_nall x (n_dim * 2) + node_ebd = torch.nn.functional.pad( + mix_node_ebd.squeeze(0), (0, 0, 0, real_n_padding), value=0.0 + ) + + assert "send_list" in comm_dict + assert "send_proc" in comm_dict + assert "recv_proc" in comm_dict + assert "send_num" in comm_dict + assert "recv_num" in comm_dict + assert "communicator" in comm_dict + ret = torch.ops.deepmd.border_op( + comm_dict["send_list"], + comm_dict["send_proc"], + comm_dict["recv_proc"], + comm_dict["send_num"], + comm_dict["recv_num"], + node_ebd, + comm_dict["communicator"], + torch.tensor( + real_nloc, + dtype=torch.int32, + device=env.DEVICE, + ), # should be int of c++ + torch.tensor( + real_nall - real_nloc, + dtype=torch.int32, + device=env.DEVICE, + ), # should be int of c++ + ) + node_ebd_ext = ret[0].unsqueeze(0) + if has_spin: + node_ebd_real_ext, node_ebd_virtual_ext = torch.split( + node_ebd_ext, [n_dim, n_dim], dim=2 + ) + node_ebd_ext = concat_switch_virtual( + node_ebd_real_ext, node_ebd_virtual_ext, real_nloc + ) + node_ebd, edge_ebd, angle_ebd = ll.forward( + node_ebd_ext, + edge_ebd, + h2, + angle_ebd, + nlist, + nlist_mask, + sw, + a_nlist, + a_nlist_mask, + a_sw, + ) + + # nb x nloc x 3 x e_dim + h2g2 = RepFlowLayer._cal_hg(edge_ebd, h2, nlist_mask, sw) + # (nb x nloc) x e_dim x 3 + rot_mat = torch.permute(h2g2, (0, 1, 3, 2)) + + return node_ebd, edge_ebd, h2, rot_mat.view(nframes, nloc, self.dim_emb, 3), sw + + def compute_input_stats( + self, + merged: Union[Callable[[], list[dict]], list[dict]], + path: Optional[DPPath] = None, + ) -> None: + """ + Compute the input statistics (e.g. mean and stddev) for the descriptors from packed data. + + Parameters + ---------- + merged : Union[Callable[[], list[dict]], list[dict]] + - list[dict]: A list of data samples from various data systems. + Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor` + originating from the `i`-th data system. + - Callable[[], list[dict]]: A lazy function that returns data samples in the above format + only when needed. Since the sampling process can be slow and memory-intensive, + the lazy function helps by only sampling once. + path : Optional[DPPath] + The path to the stat file. + + """ + env_mat_stat = EnvMatStatSe(self) + if path is not None: + path = path / env_mat_stat.get_hash() + if path is None or not path.is_dir(): + if callable(merged): + # only get data for once + sampled = merged() + else: + sampled = merged + else: + sampled = [] + env_mat_stat.load_or_compute_stats(sampled, path) + self.stats = env_mat_stat.stats + mean, stddev = env_mat_stat() + if not self.set_davg_zero: + self.mean.copy_( + torch.tensor(mean, device=env.DEVICE, dtype=self.mean.dtype) + ) + self.stddev.copy_( + torch.tensor(stddev, device=env.DEVICE, dtype=self.stddev.dtype) + ) + + def get_stats(self) -> dict[str, StatItem]: + """Get the statistics of the descriptor.""" + if self.stats is None: + raise RuntimeError( + "The statistics of the descriptor has not been computed." + ) + return self.stats + + def has_message_passing(self) -> bool: + """Returns whether the descriptor block has message passing.""" + return True + + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the descriptor block needs sorted nlist when using `forward_lower`.""" + return True diff --git a/deepmd/pt/utils/utils.py b/deepmd/pt/utils/utils.py index 6ce4f5d6fc..50d378455b 100644 --- a/deepmd/pt/utils/utils.py +++ b/deepmd/pt/utils/utils.py @@ -39,6 +39,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return F.softplus(x) elif self.activation.lower() == "sigmoid": return torch.sigmoid(x) + elif self.activation.lower() == "silu": + return F.silu(x) elif self.activation.lower() == "linear" or self.activation.lower() == "none": return x else: diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 5b57f15979..bf78e195b8 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -1355,6 +1355,172 @@ def dpa2_repformer_args(): ] +@descrpt_args_plugin.register("dpa3", doc=doc_only_pt_supported) +def descrpt_dpa3_args(): + # repflow args + doc_repflow = "The arguments used to initialize the repflow block." + # descriptor args + doc_concat_output_tebd = ( + "Whether to concat type embedding at the output of the descriptor." + ) + doc_activation_function = f"The activation function in the embedding net. Supported activation functions are {list_to_doc(ACTIVATION_FN_DICT.keys())}." + doc_precision = f"The precision of the embedding net parameters, supported options are {list_to_doc(PRECISION_DICT.keys())} Default follows the interface precision." + doc_exclude_types = "The excluded pairs of types which have no interaction with each other. For example, `[[0, 1]]` means no interaction between type 0 and type 1." + doc_env_protection = "Protection parameter to prevent division by zero errors during environment matrix calculations. For example, when using paddings, there may be zero distances of neighbors, which may make division by zero error during environment matrix calculations without protection." + doc_trainable = "If the parameters in the embedding net is trainable." + doc_seed = "Random seed for parameter initialization." + doc_use_econf_tebd = "Whether to use electronic configuration type embedding." + doc_use_tebd_bias = "Whether to use bias in the type embedding layer." + return [ + # doc_repflow args + Argument("repflow", dict, dpa3_repflow_args(), doc=doc_repflow), + # descriptor args + Argument( + "concat_output_tebd", + bool, + optional=True, + default=False, + doc=doc_concat_output_tebd, + ), + Argument( + "activation_function", + str, + optional=True, + default="silu", + doc=doc_activation_function, + ), + Argument("precision", str, optional=True, default="default", doc=doc_precision), + Argument( + "exclude_types", + list[list[int]], + optional=True, + default=[], + doc=doc_exclude_types, + ), + Argument( + "env_protection", + float, + optional=True, + default=0.0, + doc=doc_only_pt_supported + doc_env_protection, + ), + Argument("trainable", bool, optional=True, default=True, doc=doc_trainable), + Argument("seed", [int, None], optional=True, doc=doc_seed), + Argument( + "use_econf_tebd", + bool, + optional=True, + default=False, + doc=doc_only_pt_supported + doc_use_econf_tebd, + ), + Argument( + "use_tebd_bias", + bool, + optional=True, + default=False, + doc=doc_use_tebd_bias, + ), + ] + + +# repflow for dpa3 +def dpa3_repflow_args(): + # repflow args + doc_n_dim = "The dimension of node representation." + doc_e_dim = "The dimension of edge representation." + doc_a_dim = "The dimension of angle representation." + doc_nlayers = "The number of repflow layers." + doc_e_rcut = "The edge cut-off radius." + doc_e_rcut_smth = "Where to start smoothing for edge. For example the 1/r term is smoothed from `rcut` to `rcut_smth`." + doc_e_sel = 'Maximally possible number of selected edge neighbors. It can be:\n\n\ + - `int`. The maximum number of neighbor atoms to be considered. We recommend it to be less than 200. \n\n\ + - `str`. Can be "auto:factor" or "auto". "factor" is a float number larger than 1. This option will automatically determine the `sel`. In detail it counts the maximal number of neighbors with in the cutoff radius for each type of neighbor, then multiply the maximum by the "factor". Finally the number is wrapped up to 4 divisible. The option "auto" is equivalent to "auto:1.1".' + doc_a_rcut = "The angle cut-off radius." + doc_a_rcut_smth = "Where to start smoothing for angle. For example the 1/r term is smoothed from `rcut` to `rcut_smth`." + doc_a_sel = 'Maximally possible number of selected angle neighbors. It can be:\n\n\ + - `int`. The maximum number of neighbor atoms to be considered. We recommend it to be less than 200. \n\n\ + - `str`. Can be "auto:factor" or "auto". "factor" is a float number larger than 1. This option will automatically determine the `sel`. In detail it counts the maximal number of neighbors with in the cutoff radius for each type of neighbor, then multiply the maximum by the "factor". Finally the number is wrapped up to 4 divisible. The option "auto" is equivalent to "auto:1.1".' + doc_a_compress_rate = ( + "The compression rate for angular messages. The default value is 0, indicating no compression. " + " If a non-zero integer c is provided, the node and edge dimensions will be compressed " + "to n_dim/c and e_dim/2c, respectively, within the angular message." + ) + doc_axis_neuron = "The number of dimension of submatrix in the symmetrization ops." + doc_update_angle = ( + "Where to update the angle rep. If not, only node and edge rep will be used." + ) + doc_update_style = ( + "Style to update a representation. " + "Supported options are: " + "-'res_avg': Updates a rep `u` with: u = 1/\\sqrt{n+1} (u + u_1 + u_2 + ... + u_n) " + "-'res_incr': Updates a rep `u` with: u = u + 1/\\sqrt{n} (u_1 + u_2 + ... + u_n)" + "-'res_residual': Updates a rep `u` with: u = u + (r1*u_1 + r2*u_2 + ... + r3*u_n) " + "where `r1`, `r2` ... `r3` are residual weights defined by `update_residual` " + "and `update_residual_init`." + ) + doc_update_residual = ( + "When update using residual mode, " + "the initial std of residual vector weights." + ) + doc_update_residual_init = ( + "When update using residual mode, " + "the initialization mode of residual vector weights." + "Supported modes are: ['norm', 'const']." + ) + + return [ + # repflow args + Argument("n_dim", int, optional=True, default=128, doc=doc_n_dim), + Argument("e_dim", int, optional=True, default=64, doc=doc_e_dim), + Argument("a_dim", int, optional=True, default=64, doc=doc_a_dim), + Argument("nlayers", int, optional=True, default=6, doc=doc_nlayers), + Argument("e_rcut", float, doc=doc_e_rcut), + Argument("e_rcut_smth", float, doc=doc_e_rcut_smth), + Argument("e_sel", [int, str], doc=doc_e_sel), + Argument("a_rcut", float, doc=doc_a_rcut), + Argument("a_rcut_smth", float, doc=doc_a_rcut_smth), + Argument("a_sel", [int, str], doc=doc_a_sel), + Argument( + "a_compress_rate", int, optional=True, default=0, doc=doc_a_compress_rate + ), + Argument( + "axis_neuron", + int, + optional=True, + default=4, + doc=doc_axis_neuron, + ), + Argument( + "update_angle", + bool, + optional=True, + default=True, + doc=doc_update_angle, + ), + Argument( + "update_style", + str, + optional=True, + default="res_residual", + doc=doc_update_style, + ), + Argument( + "update_residual", + float, + optional=True, + default=0.1, + doc=doc_update_residual, + ), + Argument( + "update_residual_init", + str, + optional=True, + default="const", + doc=doc_update_residual_init, + ), + ] + + @descrpt_args_plugin.register( "se_a_ebd_v2", alias=["se_a_tpe_v2"], doc=doc_only_tf_supported ) @@ -2188,6 +2354,12 @@ def loss_ener(): doc_relative_f = "If provided, relative force error will be used in the loss. The difference of force will be normalized by the magnitude of the force in the label with a shift given by `relative_f`, i.e. DF_i / ( || F || + relative_f ) with DF denoting the difference between prediction and label and || F || denoting the L2 norm of the label." doc_enable_atom_ener_coeff = "If true, the energy will be computed as \\sum_i c_i E_i. c_i should be provided by file atom_ener_coeff.npy in each data system, otherwise it's 1." return [ + Argument( + "use_l1_all", + bool, + optional=True, + default=False, + ), Argument( "start_pref_e", [float, int], diff --git a/source/tests/pt/model/test_dpa3.py b/source/tests/pt/model/test_dpa3.py new file mode 100644 index 0000000000..701726a631 --- /dev/null +++ b/source/tests/pt/model/test_dpa3.py @@ -0,0 +1,179 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import itertools +import unittest + +import numpy as np +import torch + +from deepmd.dpmodel.descriptor.dpa3 import ( + RepFlowArgs, +) +from deepmd.pt.model.descriptor import ( + DescrptDPA3, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + PRECISION_DICT, +) + +from ...seed import ( + GLOBAL_SEED, +) +from .test_env_mat import ( + TestCaseSingleFrameWithNlist, +) +from .test_mlp import ( + get_tols, +) + +dtype = env.GLOBAL_PT_FLOAT_PRECISION + + +class TestDescrptDPA3(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + + def test_consistency( + self, + ) -> None: + rng = np.random.default_rng(100) + nf, nloc, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + for ( + ua, + rus, + ruri, + acr, + prec, + ect, + ) in itertools.product( + [True, False], # update_angle + ["res_residual"], # update_style + ["norm", "const"], # update_residual_init + [0, 1], # a_compress_rate + ["float64"], # precision + [False], # use_econf_tebd + ): + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + if prec == "float64": + atol = 1e-8 # marginal GPU test cases... + + repflow = RepFlowArgs( + n_dim=20, + e_dim=10, + a_dim=10, + nlayers=3, + e_rcut=self.rcut, + e_rcut_smth=self.rcut_smth, + e_sel=nnei, + a_rcut=self.rcut - 0.1, + a_rcut_smth=self.rcut_smth, + a_sel=nnei - 1, + a_compress_rate=acr, + axis_neuron=4, + update_angle=ua, + update_style=rus, + update_residual_init=ruri, + ) + + # dpa3 new impl + dd0 = DescrptDPA3( + self.nt, + repflow=repflow, + # kwargs for descriptor + exclude_types=[], + precision=prec, + use_econf_tebd=ect, + type_map=["O", "H"] if ect else None, + seed=GLOBAL_SEED, + ).to(env.DEVICE) + + dd0.repflows.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE) + dd0.repflows.stddev = torch.tensor(dstd, dtype=dtype, device=env.DEVICE) + rd0, _, _, _, _ = dd0( + torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), + torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), + torch.tensor(self.nlist, dtype=int, device=env.DEVICE), + torch.tensor(self.mapping, dtype=int, device=env.DEVICE), + ) + # serialization + dd1 = DescrptDPA3.deserialize(dd0.serialize()) + rd1, _, _, _, _ = dd1( + torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), + torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), + torch.tensor(self.nlist, dtype=int, device=env.DEVICE), + torch.tensor(self.mapping, dtype=int, device=env.DEVICE), + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd1.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + ) + + def test_jit( + self, + ) -> None: + rng = np.random.default_rng(100) + nf, nloc, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + for ( + ua, + rus, + ruri, + acr, + prec, + ect, + ) in itertools.product( + [True, False], # update_angle + ["res_residual"], # update_style + ["norm", "const"], # update_residual_init + [0, 1], # a_compress_rate + ["float64"], # precision + [False], # use_econf_tebd + ): + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + + repflow = RepFlowArgs( + n_dim=20, + e_dim=10, + a_dim=10, + nlayers=3, + e_rcut=self.rcut, + e_rcut_smth=self.rcut_smth, + e_sel=nnei, + a_rcut=self.rcut - 0.1, + a_rcut_smth=self.rcut_smth, + a_sel=nnei - 1, + a_compress_rate=acr, + axis_neuron=4, + update_angle=ua, + update_style=rus, + update_residual_init=ruri, + ) + + # dpa3 new impl + dd0 = DescrptDPA3( + self.nt, + repflow=repflow, + # kwargs for descriptor + exclude_types=[], + precision=prec, + use_econf_tebd=ect, + type_map=["O", "H"] if ect else None, + seed=GLOBAL_SEED, + ).to(env.DEVICE) + + dd0.repflows.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE) + dd0.repflows.stddev = torch.tensor(dstd, dtype=dtype, device=env.DEVICE) + model = torch.jit.script(dd0) diff --git a/source/tests/universal/dpmodel/descriptor/test_descriptor.py b/source/tests/universal/dpmodel/descriptor/test_descriptor.py index 7911cb9395..df13a7ff92 100644 --- a/source/tests/universal/dpmodel/descriptor/test_descriptor.py +++ b/source/tests/universal/dpmodel/descriptor/test_descriptor.py @@ -17,6 +17,9 @@ RepformerArgs, RepinitArgs, ) +from deepmd.dpmodel.descriptor.dpa3 import ( + RepFlowArgs, +) from ....consistent.common import ( parameterize_func, @@ -460,6 +463,75 @@ def DescriptorParamDPA2( DescriptorParamDPA2 = DescriptorParamDPA2List[0] +def DescriptorParamDPA3( + ntypes, + rcut, + rcut_smth, + sel, + type_map, + env_protection=0.0, + exclude_types=[], + update_style="res_residual", + update_residual=0.1, + update_residual_init="const", + update_angle=True, + a_compress_rate=0, + precision="float64", +): + input_dict = { + # kwargs for repformer + "repflow": RepFlowArgs( + **{ + "n_dim": 20, + "e_dim": 10, + "a_dim": 10, + "nlayers": 3, + "e_rcut": rcut, + "e_rcut_smth": rcut_smth, + "e_sel": sum(sel), + "a_rcut": rcut / 2, + "a_rcut_smth": rcut_smth / 2, + "a_sel": sum(sel) // 4, + "a_compress_rate": a_compress_rate, + "axis_neuron": 4, + "update_angle": update_angle, + "update_style": update_style, + "update_residual": update_residual, + "update_residual_init": update_residual_init, + } + ), + "ntypes": ntypes, + "concat_output_tebd": False, + "precision": precision, + "activation_function": "silu", + "exclude_types": exclude_types, + "env_protection": env_protection, + "trainable": True, + "use_econf_tebd": False, + "use_tebd_bias": False, + "type_map": type_map, + "seed": GLOBAL_SEED, + } + return input_dict + + +DescriptorParamDPA3List = parameterize_func( + DescriptorParamDPA3, + OrderedDict( + { + "update_residual_init": ("const",), + "exclude_types": ([], [[0, 1]]), + "update_angle": (True, False), + "a_compress_rate": (0, 1), + "env_protection": (0.0, 1e-8), + "precision": ("float64",), + } + ), +) +# to get name for the default function +DescriptorParamDPA3 = DescriptorParamDPA3List[0] + + def DescriptorParamHybrid(ntypes, rcut, rcut_smth, sel, type_map, **kwargs): ddsub0 = { "type": "se_e2_a", diff --git a/source/tests/universal/pt/descriptor/test_descriptor.py b/source/tests/universal/pt/descriptor/test_descriptor.py index 349eb65588..25c78b43c1 100644 --- a/source/tests/universal/pt/descriptor/test_descriptor.py +++ b/source/tests/universal/pt/descriptor/test_descriptor.py @@ -4,6 +4,7 @@ from deepmd.pt.model.descriptor import ( DescrptDPA1, DescrptDPA2, + DescrptDPA3, DescrptHybrid, DescrptSeA, DescrptSeR, @@ -20,6 +21,7 @@ from ...dpmodel.descriptor.test_descriptor import ( DescriptorParamDPA1, DescriptorParamDPA2, + DescriptorParamDPA3, DescriptorParamHybrid, DescriptorParamHybridMixed, DescriptorParamSeA, @@ -40,6 +42,7 @@ (DescriptorParamSeTTebd, DescrptSeTTebd), (DescriptorParamDPA1, DescrptDPA1), (DescriptorParamDPA2, DescrptDPA2), + (DescriptorParamDPA3, DescrptDPA3), (DescriptorParamHybrid, DescrptHybrid), (DescriptorParamHybridMixed, DescrptHybrid), ) # class_param & class diff --git a/source/tests/universal/pt/model/test_model.py b/source/tests/universal/pt/model/test_model.py index 3eb1484c45..b86c5ecc40 100644 --- a/source/tests/universal/pt/model/test_model.py +++ b/source/tests/universal/pt/model/test_model.py @@ -10,6 +10,7 @@ from deepmd.pt.model.descriptor import ( DescrptDPA1, DescrptDPA2, + DescrptDPA3, DescrptHybrid, DescrptSeA, DescrptSeR, @@ -55,6 +56,8 @@ DescriptorParamDPA1List, DescriptorParamDPA2, DescriptorParamDPA2List, + DescriptorParamDPA3, + DescriptorParamDPA3List, DescriptorParamHybrid, DescriptorParamHybridMixed, DescriptorParamHybridMixedTTebd, @@ -93,6 +96,7 @@ DescriptorParamSeTTebd, DescriptorParamDPA1, DescriptorParamDPA2, + DescriptorParamDPA3, DescriptorParamHybrid, DescriptorParamHybridMixed, ] @@ -219,6 +223,7 @@ def setUpClass(cls) -> None: ], *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List], *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List], + *[(param_func, DescrptDPA3) for param_func in DescriptorParamDPA3List], (DescriptorParamHybrid, DescrptHybrid), (DescriptorParamHybridMixed, DescrptHybrid), (DescriptorParamHybridMixedTTebd, DescrptHybrid), @@ -233,6 +238,7 @@ def setUpClass(cls) -> None: (DescriptorParamSeTTebd, DescrptSeTTebd), (DescriptorParamDPA1, DescrptDPA1), (DescriptorParamDPA2, DescrptDPA2), + (DescriptorParamDPA3, DescrptDPA3), ), # descrpt_class_param & class ( *[(param_func, DOSFittingNet) for param_func in FittingParamDosList], @@ -316,6 +322,7 @@ def setUpClass(cls) -> None: *[(param_func, DescrptSeA) for param_func in DescriptorParamSeAList], *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List], *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List], + *[(param_func, DescrptDPA3) for param_func in DescriptorParamDPA3List], (DescriptorParamHybrid, DescrptHybrid), (DescriptorParamHybridMixed, DescrptHybrid), ), # descrpt_class_param & class @@ -326,6 +333,7 @@ def setUpClass(cls) -> None: (DescriptorParamSeA, DescrptSeA), (DescriptorParamDPA1, DescrptDPA1), (DescriptorParamDPA2, DescrptDPA2), + (DescriptorParamDPA3, DescrptDPA3), ), # descrpt_class_param & class ( *[(param_func, DipoleFittingNet) for param_func in FittingParamDipoleList], @@ -409,6 +417,7 @@ def setUpClass(cls) -> None: *[(param_func, DescrptSeA) for param_func in DescriptorParamSeAList], *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List], *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List], + *[(param_func, DescrptDPA3) for param_func in DescriptorParamDPA3List], (DescriptorParamHybrid, DescrptHybrid), (DescriptorParamHybridMixed, DescrptHybrid), ), # descrpt_class_param & class @@ -419,6 +428,7 @@ def setUpClass(cls) -> None: (DescriptorParamSeA, DescrptSeA), (DescriptorParamDPA1, DescrptDPA1), (DescriptorParamDPA2, DescrptDPA2), + (DescriptorParamDPA3, DescrptDPA3), ), # descrpt_class_param & class ( *[(param_func, PolarFittingNet) for param_func in FittingParamPolarList], @@ -721,6 +731,7 @@ def setUpClass(cls) -> None: *[(param_func, DescrptSeA) for param_func in DescriptorParamSeAList], *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List], *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List], + *[(param_func, DescrptDPA3) for param_func in DescriptorParamDPA3List], (DescriptorParamHybrid, DescrptHybrid), (DescriptorParamHybridMixed, DescrptHybrid), ), # descrpt_class_param & class @@ -731,6 +742,7 @@ def setUpClass(cls) -> None: (DescriptorParamSeA, DescrptSeA), (DescriptorParamDPA1, DescrptDPA1), (DescriptorParamDPA2, DescrptDPA2), + (DescriptorParamDPA3, DescrptDPA3), ), # descrpt_class_param & class ( *[ @@ -812,6 +824,7 @@ def setUpClass(cls) -> None: ( *[(param_func, DescrptDPA1) for param_func in DescriptorParamDPA1List], *[(param_func, DescrptDPA2) for param_func in DescriptorParamDPA2List], + *[(param_func, DescrptDPA3) for param_func in DescriptorParamDPA3List], (DescriptorParamHybridMixed, DescrptHybrid), (DescriptorParamHybridMixedTTebd, DescrptHybrid), ), # descrpt_class_param & class @@ -821,6 +834,7 @@ def setUpClass(cls) -> None: ( (DescriptorParamDPA1, DescrptDPA1), (DescriptorParamDPA2, DescrptDPA2), + (DescriptorParamDPA3, DescrptDPA3), ), # descrpt_class_param & class ( *[(param_func, EnergyFittingNet) for param_func in FittingParamEnergyList],