diff --git a/deepmd/model_format/__init__.py b/deepmd/model_format/__init__.py
index e15f73758e..3aa28ec192 100644
--- a/deepmd/model_format/__init__.py
+++ b/deepmd/model_format/__init__.py
@@ -14,6 +14,8 @@
     EmbeddingNet,
     FittingNet,
     NativeLayer,
+    EmbdLayer,
+    LayerNorm,
     NativeNet,
     NetworkCollection,
     load_dp_model,
@@ -35,10 +37,14 @@
 from .se_e2_a import (
     DescrptSeA,
 )
+from .dpa1 import DescrptDPA1
 
 __all__ = [
     "InvarFitting",
     "DescrptSeA",
+    "DescrptDPA1",
+    "EmbdLayer",
+    "LayerNorm",
     "EnvMat",
     "make_multilayer_network",
     "make_embedding_network",
diff --git a/deepmd/model_format/dpa1.py b/deepmd/model_format/dpa1.py
new file mode 100644
index 0000000000..829339838f
--- /dev/null
+++ b/deepmd/model_format/dpa1.py
@@ -0,0 +1,394 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+import numpy as np
+
+try:
+    from deepmd._version import version as __version__
+except ImportError:
+    __version__ = "unknown"
+
+import copy
+from typing import (
+    Any,
+    List,
+    Optional,
+)
+
+from .common import (
+    DEFAULT_PRECISION,
+    NativeOP,
+)
+from .env_mat import (
+    EnvMat,
+)
+from .network import (
+    EmbeddingNet,
+    NetworkCollection,
+    EmbdLayer,
+)
+
+
+class DescrptDPA1(NativeOP):
+    r"""Attention-based descriptor :math:`\mathcal{D}^i \in \mathbb{R}^{M \times M_{<}}`,
+    which is proposed in the pretrainable DPA-1 [1]_ model, is given by
+
+    .. math::
+        \mathcal{D}^i = \frac{1}{N_c^2}(\hat{\mathcal{G}}^i)^T \mathcal{R}^i (\mathcal{R}^i)^T \hat{\mathcal{G}}^i_<,
+
+    where :math:`\hat{\mathcal{G}}^i` represents the embedding matrix :math:`\mathcal{G}^i`
+    after the additional self-attention mechanism, and :math:`\mathcal{R}^i` is defined by the full case in the se_e2_a descriptor.
+    Note that we obtain :math:`\mathcal{G}^i` using the type embedding method by default in this descriptor.
+
+    To perform the self-attention mechanism, the queries :math:`\mathcal{Q}^{i,l} \in \mathbb{R}^{N_c\times d_k}`,
+    keys :math:`\mathcal{K}^{i,l} \in \mathbb{R}^{N_c\times d_k}`,
+    and values :math:`\mathcal{V}^{i,l} \in \mathbb{R}^{N_c\times d_v}` are first obtained:
+
+    .. math::
+        \left(\mathcal{Q}^{i,l}\right)_{j}=Q_{l}\left(\left(\mathcal{G}^{i,l-1}\right)_{j}\right),
+
+    .. math::
+        \left(\mathcal{K}^{i,l}\right)_{j}=K_{l}\left(\left(\mathcal{G}^{i,l-1}\right)_{j}\right),
+
+    .. math::
+        \left(\mathcal{V}^{i,l}\right)_{j}=V_{l}\left(\left(\mathcal{G}^{i,l-1}\right)_{j}\right),
+
+    where :math:`Q_{l}`, :math:`K_{l}`, :math:`V_{l}` represent three trainable linear transformations
+    that output the queries and keys of dimension :math:`d_k` and values of dimension :math:`d_v`, and :math:`l`
+    is the index of the attention layer.
+    The input embedding matrix to the attention layers, denoted by :math:`\mathcal{G}^{i,0}`,
+    is chosen as the two-body embedding matrix.
+
+    Then the scaled dot-product attention method is adopted:
+
+    .. math::
+        A(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l}, \mathcal{V}^{i,l}, \mathcal{R}^{i,l})=\varphi\left(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l},\mathcal{R}^{i,l}\right)\mathcal{V}^{i,l},
+
+    where :math:`\varphi\left(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l},\mathcal{R}^{i,l}\right) \in \mathbb{R}^{N_c\times N_c}` denotes the attention weights.
+    In the original attention method,
+    one typically has :math:`\varphi\left(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l}\right)=\mathrm{softmax}\left(\frac{\mathcal{Q}^{i,l} (\mathcal{K}^{i,l})^{T}}{\sqrt{d_{k}}}\right)`,
+    with :math:`\sqrt{d_{k}}` being the normalization temperature.
+    This is slightly modified to incorporate the angular information:
+
+    .. math::
+        \varphi\left(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l},\mathcal{R}^{i,l}\right) = \mathrm{softmax}\left(\frac{\mathcal{Q}^{i,l} (\mathcal{K}^{i,l})^{T}}{\sqrt{d_{k}}}\right) \odot \hat{\mathcal{R}}^{i}(\hat{\mathcal{R}}^{i})^{T},
+
+    where :math:`\hat{\mathcal{R}}^{i} \in \mathbb{R}^{N_c\times 3}` denotes the normalized relative coordinates,
+    :math:`\hat{\mathcal{R}}^{i}_{j} = \frac{\boldsymbol{r}_{ij}}{\lVert \boldsymbol{r}_{ij} \rVert}`,
+    and :math:`\odot` means element-wise multiplication.
+
+    Then layer normalization is added in a residual way to finally obtain the self-attention local embedding matrix
+    :math:`\hat{\mathcal{G}}^{i} = \mathcal{G}^{i,L_a}` after :math:`L_a` attention layers:
+
+    .. math::
+        \mathcal{G}^{i,l} = \mathcal{G}^{i,l-1} + \mathrm{LayerNorm}(A(\mathcal{Q}^{i,l}, \mathcal{K}^{i,l}, \mathcal{V}^{i,l}, \mathcal{R}^{i,l})).
+
+    Parameters
+    ----------
+    rcut
+        The cut-off radius :math:`r_c`
+    rcut_smth
+        From where the environment matrix should be smoothed :math:`r_s`
+    sel : list[str]
+        sel[i] specifies the maximum number of type i atoms in the cut-off radius
+    ntypes : int
+        Number of element types
+    neuron : list[int]
+        Number of neurons in each hidden layer of the embedding net :math:`\mathcal{N}`
+    axis_neuron
+        Number of the axis neuron :math:`M_2` (number of columns of the sub-matrix of the embedding matrix)
+    tebd_dim : int
+        Dimension of the type embedding
+    tebd_input_mode : str
+        The way to mix the type embeddings. Supported options are `concat`, `dot_residual_s`.
+    resnet_dt
+        Time-step `dt` in the resnet construction:
+        y = x + dt * \phi (Wx + b)
+    trainable
+        If the weights of the embedding net are trainable.
+    type_one_side
+        Try to build N_types embedding nets. Otherwise, build N_types^2 embedding nets.
+    attn : int
+        Hidden dimension of the attention vectors
+    attn_layer : int
+        Number of attention layers
+    attn_dotr : bool
+        Whether to dot the angular gate to the attention weights
+    attn_mask : bool
+        Whether to mask the diagonal of the attention weights
+    exclude_types : List[List[int]]
+        The excluded pairs of types which have no interaction with each other.
+        For example, `[[0, 1]]` means no interaction between type 0 and type 1.
+    set_davg_zero
+        Set the shift of embedding net input to zero.
+    activation_function
+        The activation function in the embedding net. Supported options are |ACTIVATION_FN|
+    precision
+        The precision of the embedding net parameters. Supported options are |PRECISION|
+    scaling_factor : float
+        The scaling factor of normalization in the calculation of attention weights.
+        If `temperature` is None, the attention weights are scaled by
+        (N_hidden_dim * scaling_factor)**-0.5.
+    normalize : bool
+        Whether to normalize the hidden vectors in the calculation of attention weights.
+    temperature : Optional[float]
+        If not None, the attention weights are scaled by `temperature` instead.
+    smooth : bool
+        Whether to gate the attention weights by the smooth switch function so that
+        they decay smoothly at the cut-off radius.
+    concat_output_tebd : bool
+        Whether to concatenate the type embedding to the output of the descriptor.
+    spin
+        The deepspin object.
+
+    Limitations
+    -----------
+    The current implementation does not support the following features:
+
+    1. type_one_side == False
+    2. exclude_types != []
+    3. spin is not None
+    4. tebd_input_mode != 'concat'
+    5. smooth == False
+
+    References
+    ----------
+    .. [1] Duo Zhang, Hangrui Bi, Fu-Zhi Dai, Wanrun Jiang, Linfeng Zhang, and Han Wang. 2022.
+       DPA-1: Pretraining of Attention-based Deep Potential Model for Molecular Simulation.
+       arXiv preprint arXiv:2208.08236.
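+
+    Examples
+    --------
+    A minimal usage sketch of this numpy reference implementation; the toy cell,
+    shapes, and neighbor list below are illustrative only and are not taken from
+    the test suite.
+
+    >>> import numpy as np
+    >>> from deepmd.model_format import DescrptDPA1
+    >>> dd = DescrptDPA1(rcut=4.0, rcut_smth=0.5, sel=[2], ntypes=2)
+    >>> coord_ext = np.array([[0.0, 0.0, 0.0, 0.0, 0.0, 1.5, 0.0, 1.5, 0.0]])  # nf x (nall x 3)
+    >>> atype_ext = np.array([[0, 1, 1]])  # nf x nall
+    >>> nlist = np.array([[[1, 2], [0, 2]]])  # nf x nloc x nnei
+    >>> grrg, gr, g2, h2, sw = dd.call(coord_ext, atype_ext, nlist)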
+ """ + def __init__( + self, + rcut: float, + rcut_smth: float, + sel: List[str], + ntypes: int, + neuron: List[int] = [25, 50, 100], + axis_neuron: int = 8, + tebd_dim: int = 8, + tebd_input_mode: str = "concat", + resnet_dt: bool = False, + trainable: bool = True, + type_one_side: bool = True, + attn: int = 128, + attn_layer: int = 2, + attn_dotr: bool = True, + attn_mask: bool = False, + exclude_types: List[List[int]] = [], + set_davg_zero: bool = False, + activation_function: str = "tanh", + precision: str = DEFAULT_PRECISION, + scaling_factor=1.0, + normalize=True, + temperature=None, + smooth: bool = True, + concat_output_tebd: bool = True, + spin: Optional[Any] = None, + ) -> None: + ## seed, uniform_seed, multi_task, not included. + if not type_one_side: + raise NotImplementedError("type_one_side == False not implemented") + if exclude_types != []: + raise NotImplementedError("exclude_types is not implemented") + if spin is not None: + raise NotImplementedError("spin is not implemented") + # TODO + if tebd_input_mode != 'concat': + raise NotImplementedError("tebd_input_mode != 'concat' not implemented") + if not smooth: + raise NotImplementedError("smooth == False not implemented") + + self.rcut = rcut + self.rcut_smth = rcut_smth + if isinstance(sel, int): + sel = [sel] + self.sel = sel + self.ntypes = ntypes + self.neuron = neuron + self.axis_neuron = axis_neuron + self.tebd_dim = tebd_dim + self.tebd_input_mode = tebd_input_mode + self.resnet_dt = resnet_dt + self.trainable = trainable + self.type_one_side = type_one_side + self.attn = attn + self.attn_layer = attn_layer + self.attn_dotr = attn_dotr + self.attn_mask = attn_mask + self.exclude_types = exclude_types + self.set_davg_zero = set_davg_zero + self.activation_function = activation_function + self.precision = precision + self.scaling_factor = scaling_factor + self.normalize = normalize + self.temperature = temperature + self.concat_output_tebd = concat_output_tebd + self.spin = spin + + self.type_embedding = EmbdLayer(ntypes, tebd_dim, padding=True, precision=precision) + in_dim = 1 + self.tebd_dim * 2 if self.tebd_input_mode in ['concat'] else 1 + self.embeddings = NetworkCollection( + ndim=0, + ntypes=self.ntypes, + network_type="embedding_network", + ) + self.embeddings[0] = EmbeddingNet( + in_dim, + self.neuron, + self.activation_function, + self.resnet_dt, + self.precision, + ) + # self.dpa1_attention = NeighborGatedAttention + self.env_mat = EnvMat(self.rcut, self.rcut_smth) + self.nnei = np.sum(self.sel) + self.davg = np.zeros([self.ntypes, self.nnei, 4]) + self.dstd = np.ones([self.ntypes, self.nnei, 4]) + self.orig_sel = self.sel + + def __setitem__(self, key, value): + if key in ("avg", "data_avg", "davg"): + self.davg = value + elif key in ("std", "data_std", "dstd"): + self.dstd = value + else: + raise KeyError(key) + + def __getitem__(self, key): + if key in ("avg", "data_avg", "davg"): + return self.davg + elif key in ("std", "data_std", "dstd"): + return self.dstd + else: + raise KeyError(key) + + @property + def dim_out(self): + """Returns the output dimension of this descriptor.""" + return self.neuron[-1] * self.axis_neuron + self.tebd_dim * 2 \ + if self.concat_output_tebd else self.neuron[-1] * self.axis_neuron + + def cal_g( + self, + ss, + ll, + ): + nf, nloc, nnei = ss.shape[0:3] + ss = ss.reshape(nf, nloc, nnei, -1) + # nf x nloc x nnei x ng + gg = self.embeddings[ll].call(ss) + return gg + + def call( + self, + coord_ext, + atype_ext, + nlist, + ): + """Compute the descriptor. 
+
+        Parameters
+        ----------
+        coord_ext
+            The extended coordinates of atoms. shape: nf x (nall x 3)
+        atype_ext
+            The extended atom types. shape: nf x nall
+        nlist
+            The neighbor list. shape: nf x nloc x nnei
+
+        Returns
+        -------
+        descriptor
+            The descriptor. shape: nf x nloc x (ng x axis_neuron)
+        gr
+            The rotationally equivariant and permutationally invariant single particle
+            representation. shape: nf x nloc x ng x 3
+        g2
+            The rotationally invariant pair-particle representation.
+            This descriptor returns None.
+        h2
+            The rotationally equivariant pair-particle representation.
+            This descriptor returns None.
+        sw
+            The smooth switch function.
+        """
+        # nf x nloc x nnei x 4
+        rr, ww = self.env_mat.call(coord_ext, atype_ext, nlist, self.davg, self.dstd)
+        nf, nloc, nnei, _ = rr.shape
+
+        # add type embedding into input
+        # nf x nall x tebd_dim
+        atype_embd_ext = self.type_embedding.call(atype_ext)
+        atype_embd = atype_embd_ext[:, :nloc, :]
+        # nf x nloc x nnei x tebd_dim
+        atype_embd_nnei = np.tile(atype_embd[:, :, np.newaxis, :], (1, 1, nnei, 1))
+        nlist_mask = nlist != -1
+        nlist_masked = np.copy(nlist)
+        nlist_masked[nlist_masked == -1] = 0
+        index = np.tile(nlist_masked.reshape(nf, -1, 1), (1, 1, self.tebd_dim))
+        # nf x nloc x nnei x tebd_dim
+        atype_embd_nlist = np.take_along_axis(atype_embd_ext, index, axis=1).reshape(nf, nloc, nnei, self.tebd_dim)
+        ng = self.neuron[-1]
+        ss = rr[..., 0:1]
+        ss = np.concatenate([ss, atype_embd_nlist, atype_embd_nnei], axis=-1)
+
+        # calculate gg
+        gg = self.cal_g(ss, 0)
+        # nf x nloc x ng x 4
+        gr = np.einsum("flni,flnj->flij", gg, rr)
+        gr /= self.nnei
+        gr1 = gr[:, :, : self.axis_neuron, :]
+        # nf x nloc x ng x ng1
+        grrg = np.einsum("flid,fljd->flij", gr, gr1)
+        # nf x nloc x (ng x ng1)
+        grrg = grrg.reshape(nf, nloc, ng * self.axis_neuron)
+        if self.concat_output_tebd:
+            grrg = np.concatenate([grrg, atype_embd], axis=-1)
+        return grrg, gr[..., 1:], None, None, ww
+
+    def serialize(self) -> dict:
+        """Serialize the descriptor to dict."""
+        return {
+            "rcut": self.rcut,
+            "rcut_smth": self.rcut_smth,
+            "sel": self.sel,
+            "ntypes": self.ntypes,
+            "neuron": self.neuron,
+            "axis_neuron": self.axis_neuron,
+            "tebd_dim": self.tebd_dim,
+            "tebd_input_mode": self.tebd_input_mode,
+            "resnet_dt": self.resnet_dt,
+            "trainable": self.trainable,
+            "type_one_side": self.type_one_side,
+            "exclude_types": self.exclude_types,
+            "set_davg_zero": self.set_davg_zero,
+            "attn": self.attn,
+            "attn_layer": self.attn_layer,
+            "attn_dotr": self.attn_dotr,
+            "attn_mask": self.attn_mask,
+            "activation_function": self.activation_function,
+            "precision": self.precision,
+            "spin": self.spin,
+            "scaling_factor": self.scaling_factor,
+            "normalize": self.normalize,
+            "temperature": self.temperature,
+            "concat_output_tebd": self.concat_output_tebd,
+            "embeddings": self.embeddings.serialize(),
+            # "attention_layers": self.dpa1_attention.serialize(),
+            "env_mat": self.env_mat.serialize(),
+            "type_embedding": self.type_embedding.serialize(),
+            "@variables": {
+                "davg": self.davg,
+                "dstd": self.dstd,
+            },
+        }
+
+    @classmethod
+    def deserialize(cls, data: dict) -> "DescrptDPA1":
+        """Deserialize from dict."""
+        data = copy.deepcopy(data)
+        variables = data.pop("@variables")
+        embeddings = data.pop("embeddings")
+        type_embedding = data.pop("type_embedding")
+        # serialize() above does not emit this key yet; tolerate its absence
+        attention_layers = data.pop("attention_layers", None)
+        env_mat = data.pop("env_mat")
+        obj = cls(**data)
+        obj["davg"] = variables["davg"]
+        obj["dstd"] = variables["dstd"]
+        obj.type_embedding = EmbdLayer.deserialize(type_embedding)
+        obj.embeddings = NetworkCollection.deserialize(embeddings)
+        obj.env_mat = EnvMat.deserialize(env_mat)
+        # obj.dpa1_attention = NeighborGatedAttention.deserialize(attention_layers)
+        return obj
diff --git a/deepmd/model_format/network.py b/deepmd/model_format/network.py
index f2056c0b95..508eb07d56 100644
--- a/deepmd/model_format/network.py
+++ b/deepmd/model_format/network.py
@@ -322,6 +322,272 @@ def fn(x):
         return y
 
+
+class EmbdLayer(NativeLayer):
+    """Implementation of the embedding layer.
+
+    Parameters
+    ----------
+    num_channel : int
+        The number of input channels (e.g. the number of atom types).
+    num_out : int
+        The output dimension of the embedding.
+    padding : bool, optional
+        Whether to append one zero-initialized padding channel as the last row
+        of the embedding weights.
+    precision : str, optional
+        The precision of the layer parameters.
+    """
+
+    def __init__(
+        self,
+        num_channel,
+        num_out,
+        padding: bool = True,
+        precision: str = DEFAULT_PRECISION,
+    ) -> None:
+        self.padding = padding
+        self.num_channel = num_channel + 1 if self.padding else num_channel
+        super().__init__(
+            num_in=self.num_channel,
+            num_out=num_out,
+            bias=False,
+            use_timestep=False,
+            activation_function=None,
+            resnet=False,
+            precision=precision,
+        )
+        if self.padding:
+            self.w[-1] = 0.0
+
+    def serialize(self) -> dict:
+        """Serialize the layer to a dict.
+
+        Returns
+        -------
+        dict
+            The serialized layer.
+        """
+        data = {
+            "w": self.w,
+        }
+        return {
+            "padding": self.padding,
+            "precision": self.precision,
+            "@variables": data,
+        }
+
+    @classmethod
+    def deserialize(cls, data: dict) -> "EmbdLayer":
+        """Deserialize the layer from a dict.
+
+        Parameters
+        ----------
+        data : dict
+            The dict to deserialize from.
+        """
+        data = copy.deepcopy(data)
+        variables = data.pop("@variables")
+        padding = data.pop("padding")
+        assert variables["w"] is not None and len(variables["w"].shape) == 2
+        num_channel, num_out = variables["w"].shape
+        # the stored weights already contain the padding row,
+        # so rebuild with padding=False and restore the flag afterwards
+        obj = cls(
+            num_channel,
+            num_out,
+            padding=False,
+            **data,
+        )
+        (obj.w,) = (variables["w"],)
+        obj.padding = padding
+        obj.check_shape_consistency()
+        return obj
+
+    def __setitem__(self, key, value):
+        if key in ("w", "matrix"):
+            self.w = value
+        elif key == "precision":
+            self.precision = value
+        elif key == "padding":
+            self.padding = value
+        else:
+            raise KeyError(key)
+
+    def __getitem__(self, key):
+        if key in ("w", "matrix"):
+            return self.w
+        elif key == "precision":
+            return self.precision
+        elif key == "padding":
+            return self.padding
+        else:
+            raise KeyError(key)
+
+    def dim_channel(self) -> int:
+        return self.w.shape[0]
+
+    def call(self, x: np.ndarray) -> np.ndarray:
+        """Forward pass.
+
+        Parameters
+        ----------
+        x : np.ndarray
+            The input.
+
+        Returns
+        -------
+        np.ndarray
+            The output.
+        """
+        if self.w is None:
+            raise ValueError("w must be set")
+        y = np.take(self.w, x, axis=0)
+        return y
+
+
+class LayerNorm(NativeLayer):
+    """Implementation of the layer normalization layer.
+
+    Parameters
+    ----------
+    num_in : int
+        The input and output dimension of the layer.
+    eps : float, optional
+        A small value added to prevent division by zero in the calculation.
+    uni_init : bool, optional
+        Whether to initialize the weights to ones and the biases to zeros.
+    precision : str, optional
+        The precision of the layer parameters.
+ """ + + def __init__( + self, + num_in, + eps: float = 1e-5, + uni_init: bool = True, + precision: str = DEFAULT_PRECISION, + ) -> None: + self.eps = eps + self.uni_init = uni_init + self.num_in = num_in + super().__init__(num_in=1, + num_out=num_in, + bias=True, + use_timestep=False, + activation_function=None, + resnet=False, + precision=precision, + ) + self.w = self.w.squeeze(0) # keep the weight shape to be [num_in] + if self.uni_init: + self.w = 1. + self.b = 0. + + def serialize(self) -> dict: + """Serialize the layer to a dict. + + Returns + ------- + dict + The serialized layer. + """ + data = { + "w": self.w, + "b": self.b, + } + return { + "eps": self.eps, + "precision": self.precision, + "@variables": data, + } + + @classmethod + def deserialize(cls, data: dict) -> "LayerNorm": + """Deserialize the layer from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + data = copy.deepcopy(data) + variables = data.pop("@variables") + if variables["w"] is not None: + assert len(variables["w"].shape) == 1 + if variables["b"] is not None: + assert len(variables["b"].shape) == 1 + num_in, = variables["w"].shape + obj = cls( + num_in, + **data, + ) + obj.w, = ( + variables["w"], + ) + obj.b, = ( + variables["b"], + ) + obj._check_shape_consistency() + return obj + + def _check_shape_consistency(self): + if self.b is not None and self.w.shape[0] != self.b.shape[0]: + raise ValueError( + f"dim 1 of w {self.w.shape[0]} is not equal to shape " + f"of b {self.b.shape[0]}", + ) + + def __setitem__(self, key, value): + if key in ("w", "matrix"): + self.w = value + elif key in ("b", "bias"): + self.b = value + elif key == "precision": + self.precision = value + elif key == "eps": + self.eps = value + else: + raise KeyError(key) + + def __getitem__(self, key): + if key in ("w", "matrix"): + return self.w + elif key in ("b", "bias"): + return self.b + elif key == "precision": + return self.precision + elif key == "eps": + return self.eps + else: + raise KeyError(key) + + def dim_out(self) -> int: + return self.w.shape[0] + + def call(self, x: np.ndarray) -> np.ndarray: + """Forward pass. + + Parameters + ---------- + x : np.ndarray + The input. + + Returns + ------- + np.ndarray + The output. + """ + if self.w is None or self.b is None: + raise ValueError("w/b must be set") + y = self.layer_norm_numpy(x, tuple((self.num_in,)), self.w, self.b, self.eps) + return y + + @staticmethod + def layer_norm_numpy(x, shape, weight, bias, eps): + # mean and variance + mean = np.mean(x, axis=tuple(range(-len(shape), 0)), keepdims=True) + var = np.var(x, axis=tuple(range(-len(shape), 0)), keepdims=True) + # normalize + x_normalized = (x - mean) / np.sqrt(var + eps) + # shift and scale + x_ln = x_normalized * weight + bias + return x_ln + + def make_multilayer_network(T_NetworkLayer, ModuleBase): class NN(ModuleBase): """Native representation of a neural network. 
diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index 23f521b6d8..611c9b1179 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -13,9 +13,12 @@ TypeEmbedNet, ) -from .se_atten import ( - DescrptBlockSeAtten, +from .se_atten import DescrptBlockSeAtten, NeighborGatedAttention +from deepmd.pt.model.network.mlp import EmbdLayer, NetworkCollection +from deepmd.model_format import ( + EnvMat as DPEnvMat, ) +from deepmd.pt.utils import env @Descriptor.register("dpa1") @@ -37,17 +40,16 @@ def __init__( attn_layer: int = 2, attn_dotr: bool = True, attn_mask: bool = False, - post_ln=True, - ffn=False, - ffn_embed_dim=1024, - activation="tanh", + activation_function="tanh", + precision: str = "float64", + resnet_dt: bool = False, scaling_factor=1.0, - head_num=1, normalize=True, temperature=None, - return_rot=False, concat_output_tebd: bool = True, type: Optional[str] = None, + old_impl: bool = False, + **kwargs, ): super().__init__() del type @@ -65,17 +67,22 @@ def __init__( attn_layer=attn_layer, attn_dotr=attn_dotr, attn_mask=attn_mask, - post_ln=post_ln, - ffn=ffn, - ffn_embed_dim=ffn_embed_dim, - activation=activation, + activation_function=activation_function, + precision=precision, + resnet_dt=resnet_dt, scaling_factor=scaling_factor, - head_num=head_num, normalize=normalize, temperature=temperature, - return_rot=return_rot, + old_impl=old_impl, + **kwargs ) - self.type_embedding = TypeEmbedNet(ntypes, tebd_dim) + self.type_embedding_old = None + self.type_embedding = None + self.old_impl = old_impl + if self.old_impl: + self.type_embedding_old = TypeEmbedNet(ntypes, tebd_dim) + else: + self.type_embedding = EmbdLayer(ntypes, tebd_dim, padding=True, precision=precision) self.tebd_dim = tebd_dim self.concat_output_tebd = concat_output_tebd @@ -168,7 +175,12 @@ def forward( del mapping nframes, nloc, nnei = nlist.shape nall = extended_coord.view(nframes, -1).shape[1] // 3 - g1_ext = self.type_embedding(extended_atype) + if self.old_impl: + assert self.type_embedding_old is not None + g1_ext = self.type_embedding_old(extended_atype) + else: + assert self.type_embedding is not None + g1_ext = self.type_embedding(extended_atype) g1_inp = g1_ext[:, :nloc, :] g1, g2, h2, rot_mat, sw = self.se_atten( nlist, @@ -181,3 +193,65 @@ def forward( g1 = torch.cat([g1, g1_inp], dim=-1) return g1, rot_mat, g2, h2, sw + + def set_stat_mean_and_stddev( + self, + mean: torch.Tensor, + stddev: torch.Tensor, + ) -> None: + self.se_atten.mean = mean + self.se_atten.stddev = stddev + + def serialize(self) -> dict: + obj = self.se_atten + return { + "rcut": obj.rcut, + "rcut_smth": obj.rcut_smth, + "sel": obj.sel, + "ntypes": obj.ntypes, + "neuron": obj.neuron, + "axis_neuron": obj.axis_neuron, + "tebd_dim": obj.tebd_dim, + "tebd_input_mode": obj.tebd_input_mode, + "set_davg_zero": obj.set_davg_zero, + "attn": obj.attn_dim, + "attn_layer": obj.attn_layer, + "attn_dotr": obj.attn_dotr, + "attn_mask": obj.attn_mask, + "activation_function": obj.activation_function, + "precision": obj.precision, + "resnet_dt": obj.resnet_dt, + "scaling_factor": obj.scaling_factor, + "normalize": obj.normalize, + "temperature": obj.temperature, + "concat_output_tebd": self.concat_output_tebd, + "embeddings": obj.filter_layers.serialize(), + "attention_layers": obj.dpa1_attention.serialize(), + "env_mat": DPEnvMat(obj.rcut, obj.rcut_smth).serialize(), + "type_embedding": self.type_embedding.serialize(), + "@variables": { + "davg": 
obj["davg"].detach().cpu().numpy(), + "dstd": obj["dstd"].detach().cpu().numpy(), + }, + ## to be updated when the options are supported. + "trainable": True, + "type_one_side": True, + "exclude_types": [], + "spin": None, + } + + @classmethod + def deserialize(cls, data: dict) -> "DescrptDPA1": + variables = data.pop("@variables") + embeddings = data.pop("embeddings") + type_embedding = data.pop("type_embedding") + attention_layers = data.pop("attention_layers") + env_mat = data.pop("env_mat") + obj = cls(**data) + t_cvt = lambda xx: torch.tensor(xx, dtype=obj.se_atten.prec, device=env.DEVICE) + obj.type_embedding = EmbdLayer.deserialize(type_embedding) + obj.se_atten["davg"] = t_cvt(variables["davg"]) + obj.se_atten["dstd"] = t_cvt(variables["dstd"]) + obj.se_atten.filter_layers = NetworkCollection.deserialize(embeddings) + obj.se_atten.dpa1_attention = NeighborGatedAttention.deserialize(attention_layers) + return obj diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index 3f42736dca..0ece6a61c6 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -451,7 +451,7 @@ def forward( xyz_scatter = torch.zeros( [nfnl, 4, self.filter_neuron[-1]], dtype=self.prec, device=env.DEVICE ) - for ii, ll in enumerate(self.filter_layers.networks): + for ii, ll in enumerate(self.filter_layers._networks): # nfnl x nt x 4 rr = dmatrix[:, self.sec[ii] : self.sec[ii + 1], :] ss = rr[:, :, :1] diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index 78cba59da7..ba7c533ff7 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -6,6 +6,13 @@ import numpy as np import torch +import torch.nn as nn +import torch.nn.functional as torch_func +from deepmd.pt.utils.env import ( + PRECISION_DICT, + DEFAULT_PRECISION, +) +from deepmd.pt.utils.utils import ActivationFn from deepmd.pt.model.descriptor.descriptor import ( DescriptorBlock, @@ -18,6 +25,11 @@ NeighborWiseAttention, TypeFilter, ) +from deepmd.pt.model.network.mlp import EmbeddingNet, NetworkCollection, MLPLayer, LayerNorm + +from deepmd.model_format import ( + EnvMat as DPEnvMat, +) from deepmd.pt.utils import ( env, ) @@ -34,23 +46,22 @@ def __init__( neuron: list = [25, 50, 100], axis_neuron: int = 16, tebd_dim: int = 8, - tebd_input_mode: str = "concat", + tebd_input_mode: str = 'concat', # set_davg_zero: bool = False, set_davg_zero: bool = True, # TODO attn: int = 128, attn_layer: int = 2, attn_dotr: bool = True, attn_mask: bool = False, - post_ln=True, - ffn=False, - ffn_embed_dim=1024, - activation="tanh", + activation_function="tanh", + precision: str = "float64", + resnet_dt: bool = False, scaling_factor=1.0, - head_num=1, normalize=True, temperature=None, - return_rot=False, type: Optional[str] = None, + old_impl: bool = False, + **kwargs, ): """Construct an embedding net of type `se_atten`. 
@@ -65,6 +76,8 @@ def __init__( del type self.rcut = rcut self.rcut_smth = rcut_smth + self.neuron = neuron + self.filter_neuron = self.neuron self.filter_neuron = neuron self.axis_neuron = axis_neuron self.tebd_dim = tebd_dim @@ -74,15 +87,14 @@ def __init__( self.attn_layer = attn_layer self.attn_dotr = attn_dotr self.attn_mask = attn_mask - self.post_ln = post_ln - self.ffn = ffn - self.ffn_embed_dim = ffn_embed_dim - self.activation = activation + self.activation_function = activation_function + self.precision = precision + self.prec = PRECISION_DICT[self.precision] + self.resnet_dt = resnet_dt self.scaling_factor = scaling_factor - self.head_num = head_num self.normalize = normalize self.temperature = temperature - self.return_rot = return_rot + self.old_impl = old_impl if isinstance(sel, int): sel = [sel] @@ -93,22 +105,24 @@ def __init__( self.split_sel = self.sel self.nnei = sum(sel) self.ndescrpt = self.nnei * 4 - self.dpa1_attention = NeighborWiseAttention( - self.attn_layer, - self.nnei, - self.filter_neuron[-1], - self.attn_dim, - dotr=self.attn_dotr, - do_mask=self.attn_mask, - post_ln=self.post_ln, - ffn=self.ffn, - ffn_embed_dim=self.ffn_embed_dim, - activation=self.activation, - scaling_factor=self.scaling_factor, - head_num=self.head_num, - normalize=self.normalize, - temperature=self.temperature, - ) + if self.old_impl: + self.dpa1_attention = NeighborWiseAttention(self.attn_layer, self.nnei, self.filter_neuron[-1], + self.attn_dim, + dotr=self.attn_dotr, do_mask=self.attn_mask, + activation=self.activation_function, + scaling_factor=self.scaling_factor, + normalize=self.normalize, + temperature=self.temperature) + else: + self.dpa1_attention = NeighborGatedAttention(self.attn_layer, + self.nnei, + self.filter_neuron[-1], + self.attn_dim, + dotr=self.attn_dotr, + do_mask=self.attn_mask, + scaling_factor=self.scaling_factor, + normalize=self.normalize, + temperature=self.temperature) wanted_shape = (self.ntypes, self.nnei, 4) mean = torch.zeros( @@ -119,19 +133,26 @@ def __init__( ) self.register_buffer("mean", mean) self.register_buffer("stddev", stddev) + self.embd_input_dim = 1 + self.tebd_dim * 2 if self.tebd_input_mode in ['concat'] else 1 + self.filter_layers_old = None + self.filter_layers = None - filter_layers = [] - one = TypeFilter( - 0, - self.nnei, - self.filter_neuron, - return_G=True, - tebd_dim=self.tebd_dim, - use_tebd=True, - tebd_mode=self.tebd_input_mode, - ) - filter_layers.append(one) - self.filter_layers = torch.nn.ModuleList(filter_layers) + if self.old_impl: + filter_layers = [] + one = TypeFilter(0, self.nnei, self.filter_neuron, return_G=True, tebd_dim=self.tebd_dim, use_tebd=True, + tebd_mode=self.tebd_input_mode) + filter_layers.append(one) + self.filter_layers_old = torch.nn.ModuleList(filter_layers) + else: + filter_layers = NetworkCollection(ndim=0, ntypes=len(sel), network_type="embedding_network") + filter_layers[0] = EmbeddingNet( + self.embd_input_dim, + self.filter_neuron, + activation_function=self.activation_function, + precision=self.precision, + resnet_dt=self.resnet_dt, + ) + self.filter_layers = filter_layers def get_rcut(self) -> float: """Returns the cut-off radius.""" @@ -172,6 +193,22 @@ def dim_emb(self): """Returns the output dimension of embedding.""" return self.filter_neuron[-1] + def __setitem__(self, key, value): + if key in ("avg", "data_avg", "davg"): + self.mean = value + elif key in ("std", "data_std", "dstd"): + self.stddev = value + else: + raise KeyError(key) + + def __getitem__(self, key): + if key in ("avg", 
"data_avg", "davg"): + return self.mean + elif key in ("std", "data_std", "dstd"): + return self.stddev + else: + raise KeyError(key) + def compute_input_stats(self, merged): """Update mean and stddev for descriptor elements.""" sumr = [] @@ -282,7 +319,6 @@ def forward( self.rcut_smth, ) # [nfxnlocxnnei, self.ndescrpt] - dmatrix = dmatrix.view(-1, self.ndescrpt) nlist_mask = nlist != -1 nlist[nlist == -1] = 0 sw = torch.squeeze(sw, -1) @@ -300,23 +336,39 @@ def forward( atype_tebd_nlist = torch.gather(atype_tebd_ext, dim=1, index=index) # nb x nloc x nnei x nt atype_tebd_nlist = atype_tebd_nlist.view(nb, nloc, nnei, nt) - ret = self.filter_layers[0]( - dmatrix, - atype_tebd=atype_tebd_nnei, - nlist_tebd=atype_tebd_nlist, - ) # shape is [nframes*nall, self.neei, out_size] - input_r = torch.nn.functional.normalize( - dmatrix.reshape(-1, self.nnei, 4)[:, :, 1:4], dim=-1 - ) - ret = self.dpa1_attention( - ret, nlist_mask, input_r=input_r, sw=sw - ) # shape is [nframes*nloc, self.neei, out_size] - inputs_reshape = dmatrix.view(-1, self.nnei, 4).permute( - 0, 2, 1 - ) # shape is [nframes*natoms[0], 4, self.neei] - xyz_scatter = torch.matmul( - inputs_reshape, ret - ) # shape is [nframes*natoms[0], 4, out_size] + if self.old_impl: + assert self.filter_layers_old is not None + dmatrix = dmatrix.view(-1, self.ndescrpt) # shape is [nframes*nall, self.ndescrpt] + gg = self.filter_layers_old[0]( + dmatrix, + atype_tebd=atype_tebd_nnei, + nlist_tebd=atype_tebd_nlist, + ) # shape is [nframes*nall, self.neei, out_size] + input_r = torch.nn.functional.normalize(dmatrix.reshape(-1, self.nnei, 4)[:, :, 1:4], dim=-1) + gg = self.dpa1_attention(gg, nlist_mask, input_r=input_r, + sw=sw) # shape is [nframes*nloc, self.neei, out_size] + inputs_reshape = dmatrix.view(-1, self.nnei, 4).permute(0, 2, + 1) # shape is [nframes*natoms[0], 4, self.neei] + xyz_scatter = torch.matmul(inputs_reshape, gg) # shape is [nframes*natoms[0], 4, out_size] + else: + assert self.filter_layers is not None + dmatrix = dmatrix.view(-1, self.nnei, 4) + nfnl = dmatrix.shape[0] + # nfnl x nnei x 4 + rr = dmatrix + ss = rr[:, :, :1] + if self.tebd_input_mode in ['concat']: + nlist_tebd = atype_tebd_nlist.reshape(nfnl, nnei, self.tebd_dim) + atype_tebd = atype_tebd_nnei.reshape(nfnl, nnei, self.tebd_dim) + # nfnl x nnei x (1 + tebd_dim * 2) + ss = torch.concat([ss, nlist_tebd, atype_tebd], dim=2) + # nfnl x nnei x ng + gg = self.filter_layers._networks[0](ss) + input_r = torch.nn.functional.normalize(dmatrix.reshape(-1, self.nnei, 4)[:, :, 1:4], dim=-1) + gg = self.dpa1_attention(gg, nlist_mask, input_r=input_r, + sw=sw) # shape is [nframes*nloc, self.neei, out_size] + # nfnl x 4 x ng + xyz_scatter = torch.matmul(rr.permute(0, 2, 1), gg) xyz_scatter = xyz_scatter / self.nnei xyz_scatter_1 = xyz_scatter.permute(0, 2, 1) rot_mat = xyz_scatter_1[:, :, 1:4] @@ -326,13 +378,354 @@ def forward( ) # shape is [nframes*nloc, self.filter_neuron[-1], self.axis_neuron] return ( result.view(-1, nloc, self.filter_neuron[-1] * self.axis_neuron), - ret.view(-1, nloc, self.nnei, self.filter_neuron[-1]), + gg.view(-1, nloc, self.nnei, self.filter_neuron[-1]), dmatrix.view(-1, nloc, self.nnei, 4)[..., 1:], rot_mat.view(-1, self.filter_neuron[-1], 3), sw, ) +class NeighborGatedAttention(nn.Module): + def __init__(self, + layer_num: int, + nnei: int, + embed_dim: int, + hidden_dim: int, + dotr: bool = False, + do_mask: bool = False, + scaling_factor: float = 1.0, + normalize: bool = True, + temperature: float = None, + precision: str = DEFAULT_PRECISION, + ): 
+ """Construct a neighbor-wise attention net. + """ + super(NeighborGatedAttention, self).__init__() + self.layer_num = layer_num + self.nnei = nnei + self.embed_dim = embed_dim + self.hidden_dim = hidden_dim + self.dotr = dotr + self.do_mask = do_mask + self.scaling_factor = scaling_factor + self.normalize = normalize + self.temperature = temperature + self.precision = precision + self.network_type = NeighborGatedAttentionLayer + attention_layers = [] + for i in range(self.layer_num): + attention_layers.append(NeighborGatedAttentionLayer(nnei, + embed_dim, + hidden_dim, + dotr=dotr, + do_mask=do_mask, + scaling_factor=scaling_factor, + normalize=normalize, + temperature=temperature, + precision=precision)) + self.attention_layers = nn.ModuleList(attention_layers) + + def forward( + self, + input_G, + nei_mask, + input_r: Optional[torch.Tensor] = None, + sw: Optional[torch.Tensor] = None, + ): + """ + Args: + input_G: Input G, [nframes * nloc, nnei, embed_dim] + nei_mask: neighbor mask, [nframes * nloc, nnei] + input_r: normalized radial, [nframes, nloc, nei, 3] + Returns: + out: Output G, [nframes * nloc, nnei, embed_dim] + """ + out = input_G + # https://github.com/pytorch/pytorch/issues/39165#issuecomment-635472592 + for layer in self.attention_layers: + out = layer(out, nei_mask, input_r=input_r, sw=sw) + return out + + def _convert_key(self, key): + if isinstance(key, int): + idx = key + else: + if isinstance(key, tuple): + pass + elif isinstance(key, str): + key = tuple([int(tt) for tt in key.split("_")[1:]]) + else: + raise TypeError(key) + assert isinstance(key, tuple) + assert len(key) == self.ndim + idx = sum([tt * self.ntypes**ii for ii, tt in enumerate(key)]) + return idx + + def __getitem__(self, key): + return self.attention_layers[self._convert_key(key)] + + def __setitem__(self, key, value): + if isinstance(value, self.network_type): + pass + elif isinstance(value, dict): + value = self.network_type.deserialize(value) + else: + raise TypeError(value) + self.attention_layers[self._convert_key(key)] = value + + def serialize(self) -> dict: + """Serialize the networks to a dict. + Returns + ------- + dict + The serialized networks. + """ + # network_type_map_inv = {v: k for k, v in self.NETWORK_TYPE_MAP.items()} + # network_type_name = network_type_map_inv[self.network_type] + return { + "layer_num": self.layer_num, + "nnei": self.nnei, + "embed_dim": self.embed_dim, + "hidden_dim": self.hidden_dim, + "dotr": self.dotr, + "do_mask": self.do_mask, + "scaling_factor": self.scaling_factor, + "normalize": self.normalize, + "temperature": self.temperature, + "precision": self.precision, + "attention_layers": [layer.serialize() for layer in self.attention_layers] + } + + @classmethod + def deserialize(cls, data: dict) -> "NeighborGatedAttention": + """Deserialize the networks from a dict. + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + attention_layers = data.pop("attention_layers") + obj = cls(**data) + for ii, network in enumerate(attention_layers): + obj[ii] = network + return obj + + +class NeighborGatedAttentionLayer(nn.Module): + def __init__(self, + nnei: int, + embed_dim: int, + hidden_dim: int, + dotr: bool = False, + do_mask: bool = False, + scaling_factor: float = 1.0, + normalize: bool = True, + temperature: float = None, + precision: str = DEFAULT_PRECISION, + ): + """Construct a neighbor-wise attention layer. 
+ """ + super(NeighborGatedAttentionLayer, self).__init__() + self.nnei = nnei + self.embed_dim = embed_dim + self.hidden_dim = hidden_dim + self.dotr = dotr + self.do_mask = do_mask + self.scaling_factor = scaling_factor + self.normalize = normalize + self.temperature = temperature + self.precision = precision + self.attention_layer = GatedAttentionLayer(nnei, + embed_dim, + hidden_dim, + dotr=dotr, + do_mask=do_mask, + scaling_factor=scaling_factor, + normalize=normalize, + temperature=temperature, + precision=precision, + ) + self.attn_layer_norm = LayerNorm(self.embed_dim, precision=precision) + + def forward( + self, + x, + nei_mask, + input_r: Optional[torch.Tensor] = None, + sw: Optional[torch.Tensor] = None, + ): + residual = x + x = self.attention_layer(x, nei_mask, input_r=input_r, sw=sw) + x = residual + x + x = self.attn_layer_norm(x) + return x + + def serialize(self) -> dict: + """Serialize the networks to a dict. + Returns + ------- + dict + The serialized networks. + """ + return { + "nnei": self.nnei, + "embed_dim": self.embed_dim, + "hidden_dim": self.hidden_dim, + "dotr": self.dotr, + "do_mask": self.do_mask, + "scaling_factor": self.scaling_factor, + "normalize": self.normalize, + "temperature": self.temperature, + "precision": self.precision, + "attention_layer": self.attention_layer.serialize(), + "attn_layer_norm": self.attn_layer_norm.serialize() + } + + @classmethod + def deserialize(cls, data: dict) -> "NeighborGatedAttentionLayer": + """Deserialize the networks from a dict. + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + attention_layer = data.pop("attention_layer") + attn_layer_norm = data.pop("attn_layer_norm") + obj = cls(**data) + obj.attention_layer = GatedAttentionLayer.deserialize(attention_layer) + obj.attn_layer_norm = LayerNorm.deserialize(attn_layer_norm) + return obj + + +class GatedAttentionLayer(nn.Module): + def __init__(self, + nnei: int, + embed_dim: int, + hidden_dim: int, + dotr: bool = False, + do_mask: bool = False, + scaling_factor: float = 1.0, + normalize: bool = True, + temperature: float = None, + bias: bool = True, + smooth: bool = True, + precision: str = DEFAULT_PRECISION, + ): + """Construct a neighbor-wise attention net. 
+ """ + super(GatedAttentionLayer, self).__init__() + self.nnei = nnei + self.embed_dim = embed_dim + self.hidden_dim = hidden_dim + self.dotr = dotr + self.do_mask = do_mask + self.bias = bias + self.smooth = smooth + self.scaling_factor = scaling_factor + self.temperature = temperature + self.precision = precision + if temperature is None: + self.scaling = (self.hidden_dim * scaling_factor) ** -0.5 + else: + self.scaling = temperature + self.normalize = normalize + self.in_proj = MLPLayer(embed_dim, hidden_dim * 3, bias=bias, use_timestep=False, bavg=0., stddev=1., + precision=precision) + self.out_proj = MLPLayer(hidden_dim, embed_dim, bias=bias, use_timestep=False, bavg=0., stddev=1., + precision=precision) + + def forward( + self, + query, + nei_mask, + input_r: Optional[torch.Tensor] = None, + sw: Optional[torch.Tensor] = None, + attnw_shift: float = 20.0, + ): + """ + Args: + query: input G, [nframes * nloc, nnei, embed_dim] + nei_mask: neighbor mask, [nframes * nloc, nnei] + input_r: normalized radial, [nframes, nloc, nei, 3] + Returns: + type_embedding: + """ + q, k, v = self.in_proj(query).chunk(3, dim=-1) + # [nframes * nloc, nnei, hidden_dim] + q = q.view(-1, self.nnei, self.hidden_dim) + k = k.view(-1, self.nnei, self.hidden_dim) + v = v.view(-1, self.nnei, self.hidden_dim) + if self.normalize: + q = torch_func.normalize(q, dim=-1) + k = torch_func.normalize(k, dim=-1) + v = torch_func.normalize(v, dim=-1) + q = q * self.scaling + k = k.transpose(1, 2) + # [nframes * nloc, nnei, nnei] + attn_weights = torch.bmm(q, k) + # [nframes * nloc, nnei] + nei_mask = nei_mask.view(-1, self.nnei) + if self.smooth: + # [nframes * nloc, nnei] + assert sw is not None + sw = sw.view([-1, self.nnei]) + attn_weights = (attn_weights + attnw_shift) * sw[:, :, None] * sw[:, None, :] - attnw_shift + else: + attn_weights = attn_weights.masked_fill(~nei_mask.unsqueeze(1), float("-inf")) + attn_weights = torch_func.softmax(attn_weights, dim=-1) + attn_weights = attn_weights.masked_fill(~nei_mask.unsqueeze(-1), float(0.0)) + if self.smooth: + assert sw is not None + attn_weights = attn_weights * sw[:, :, None] * sw[:, None, :] + if self.dotr: + assert input_r is not None, "input_r must be provided when dotr is True!" + angular_weight = torch.bmm(input_r, input_r.transpose(1, 2)) + attn_weights = attn_weights * angular_weight + o = torch.bmm(attn_weights, v) + output = self.out_proj(o) + return output + + def serialize(self) -> dict: + """Serialize the networks to a dict. + Returns + ------- + dict + The serialized networks. + """ + # network_type_map_inv = {v: k for k, v in self.NETWORK_TYPE_MAP.items()} + # network_type_name = network_type_map_inv[self.network_type] + return { + "nnei": self.nnei, + "embed_dim": self.embed_dim, + "hidden_dim": self.hidden_dim, + "dotr": self.dotr, + "do_mask": self.do_mask, + "scaling_factor": self.scaling_factor, + "normalize": self.normalize, + "temperature": self.temperature, + "bias": self.bias, + "smooth": self.smooth, + "precision": self.precision, + "in_proj": self.in_proj.serialize(), + "out_proj": self.out_proj.serialize() + } + + @classmethod + def deserialize(cls, data: dict) -> "GatedAttentionLayer": + """Deserialize the networks from a dict. + Parameters + ---------- + data : dict + The dict to deserialize from. 
+ """ + in_proj = data.pop("in_proj") + out_proj = data.pop("out_proj") + obj = cls(**data) + obj.in_proj = MLPLayer.deserialize(in_proj) + obj.out_proj = MLPLayer.deserialize(out_proj) + return obj + + def analyze_descrpt(matrix, ndescrpt, natoms, mixed_type=False, real_atype=None): """Collect avg, square avg and count of descriptors in a batch.""" ntypes = natoms.shape[1] - 2 diff --git a/deepmd/pt/model/network/mlp.py b/deepmd/pt/model/network/mlp.py index d76abd82f9..5bd9c1a23d 100644 --- a/deepmd/pt/model/network/mlp.py +++ b/deepmd/pt/model/network/mlp.py @@ -8,6 +8,7 @@ import numpy as np import torch import torch.nn as nn +import torch.nn.functional as torch_func from deepmd.pt.utils import ( env, @@ -18,6 +19,8 @@ from deepmd.model_format import ( NativeLayer, ) +from deepmd.model_format import EmbdLayer as DPEmbdLayer +from deepmd.model_format import LayerNorm as DPLayerNorm from deepmd.model_format import NetworkCollection as DPNetworkCollection from deepmd.model_format import ( make_embedding_network, @@ -188,6 +191,188 @@ def check_load_param(ss): return obj +class EmbdLayer(MLPLayer): + def __init__( + self, + num_channel, + num_out, + padding: bool = True, + stddev: float = 1., + precision: str = DEFAULT_PRECISION, + ): + self.padding = padding + self.num_channel = num_channel + 1 if self.padding else num_channel + super().__init__(num_in=self.num_channel, + num_out=num_out, + bias=False, + use_timestep=False, + activation_function=None, + resnet=False, + stddev=stddev, + precision=precision, + ) + if self.padding: + nn.init.zeros_(self.matrix.data[-1]) + + def dim_channel(self) -> int: + return self.matrix.shape[0] + + def forward( + self, + xx: torch.Tensor, + ) -> torch.Tensor: + """One Embedding layer used by DP model. + + Parameters + ---------- + xx: torch.Tensor + The input of index. + + Returns + ------- + yy: torch.Tensor + The output. + """ + yy = torch_func.embedding(xx, self.matrix) + return yy + + def serialize(self) -> dict: + """Serialize the layer to a dict. + + Returns + ------- + dict + The serialized layer. + """ + nl = DPEmbdLayer( + self.matrix.shape[0], + self.matrix.shape[1], + padding=False, + precision=self.precision, + ) + nl.w = self.matrix.detach().cpu().numpy() + data = nl.serialize() + data["padding"] = self.padding + return data + + @classmethod + def deserialize(cls, data: dict) -> "EmbdLayer": + """Deserialize the layer from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. 
+ """ + padding = data["padding"] + nl = DPEmbdLayer.deserialize(data) + obj = cls( + nl["matrix"].shape[0], + nl["matrix"].shape[1], + padding=False, + precision=nl["precision"], + ) + obj.padding = padding + prec = PRECISION_DICT[obj.precision] + check_load_param = \ + lambda ss: nn.Parameter(data=torch.tensor(nl[ss], dtype=prec, device=device)) \ + if nl[ss] is not None else None + obj.matrix = check_load_param("matrix") + return obj + + +class LayerNorm(MLPLayer): + def __init__( + self, + num_in, + eps: float = 1e-5, + uni_init: bool = True, + bavg: float = 0., + stddev: float = 1., + precision: str = DEFAULT_PRECISION, + ): + self.eps = eps + self.uni_init = uni_init + self.num_in = num_in + super().__init__(num_in=1, + num_out=num_in, + bias=True, + use_timestep=False, + activation_function=None, + resnet=False, + bavg=bavg, + stddev=stddev, + precision=precision, + ) + self.matrix = torch.nn.Parameter(self.matrix.squeeze(0)) + if self.uni_init: + nn.init.ones_(self.matrix.data) + nn.init.zeros_(self.bias.data) + + def dim_out(self) -> int: + return self.matrix.shape[0] + + def forward( + self, + xx: torch.Tensor, + ) -> torch.Tensor: + """One Layer Norm used by DP model. + + Parameters + ---------- + xx: torch.Tensor + The input of index. + + Returns + ------- + yy: torch.Tensor + The output. + """ + yy = torch_func.layer_norm(xx, tuple((self.num_in,)), self.matrix, self.bias, self.eps) + return yy + + def serialize(self) -> dict: + """Serialize the layer to a dict. + + Returns + ------- + dict + The serialized layer. + """ + nl = DPLayerNorm( + self.matrix.shape[0], + eps=self.eps, + precision=self.precision, + ) + nl.w = self.matrix.detach().cpu().numpy() + nl.b = self.bias.detach().cpu().numpy() + data = nl.serialize() + return data + + @classmethod + def deserialize(cls, data: dict) -> "LayerNorm": + """Deserialize the layer from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. 
+ """ + nl = DPLayerNorm.deserialize(data) + obj = cls( + nl["matrix"].shape[0], + eps=nl["eps"], + precision=nl["precision"], + ) + prec = PRECISION_DICT[obj.precision] + check_load_param = \ + lambda ss: nn.Parameter(data=torch.tensor(nl[ss], dtype=prec, device=device)) \ + if nl[ss] is not None else None + obj.matrix = check_load_param("matrix") + obj.bias = check_load_param("bias") + return obj + + MLP_ = make_multilayer_network(MLPLayer, nn.Module) @@ -217,4 +402,4 @@ def __init__(self, *args, **kwargs): # init both two base classes DPNetworkCollection.__init__(self, *args, **kwargs) nn.Module.__init__(self) - self.networks = self._networks = torch.nn.ModuleList(self._networks) + self._networks = torch.nn.ModuleList(self._networks) diff --git a/source/tests/pt/models/dpa1.pth b/source/tests/pt/models/dpa1.pt similarity index 52% rename from source/tests/pt/models/dpa1.pth rename to source/tests/pt/models/dpa1.pt index 75acf2fa15..74f69b4b6e 100644 Binary files a/source/tests/pt/models/dpa1.pth and b/source/tests/pt/models/dpa1.pt differ diff --git a/source/tests/pt/models/dpa2.pth b/source/tests/pt/models/dpa2.pt similarity index 50% rename from source/tests/pt/models/dpa2.pth rename to source/tests/pt/models/dpa2.pt index 0559d30c48..ac0b02379a 100644 Binary files a/source/tests/pt/models/dpa2.pth and b/source/tests/pt/models/dpa2.pt differ diff --git a/source/tests/pt/models/dpa2_tebd.pt b/source/tests/pt/models/dpa2_tebd.pt new file mode 100644 index 0000000000..fa84a9b5fa Binary files /dev/null and b/source/tests/pt/models/dpa2_tebd.pt differ diff --git a/source/tests/pt/models/dpa2_tebd.pth b/source/tests/pt/models/dpa2_tebd.pth deleted file mode 100644 index 3d4fc5511c..0000000000 Binary files a/source/tests/pt/models/dpa2_tebd.pth and /dev/null differ diff --git a/source/tests/pt/test_descriptor_dpa1.py b/source/tests/pt/test_descriptor_dpa1.py index 725369d68d..2caeb5890e 100644 --- a/source/tests/pt/test_descriptor_dpa1.py +++ b/source/tests/pt/test_descriptor_dpa1.py @@ -12,9 +12,7 @@ DescrptBlockSeAtten, DescrptDPA1, ) -from deepmd.pt.model.network.network import ( - TypeEmbedNet, -) +from deepmd.pt.model.network.mlp import EmbdLayer from deepmd.pt.utils import ( env, ) @@ -231,8 +229,8 @@ def setUp(self): ).to(env.DEVICE) with open(Path(CUR_DIR) / "models" / "dpa1.json") as fp: self.model_json = json.load(fp) - self.file_model_param = Path(CUR_DIR) / "models" / "dpa1.pth" - self.file_type_embed = Path(CUR_DIR) / "models" / "dpa2_tebd.pth" + self.file_model_param = Path(CUR_DIR) / "models" / "dpa1.pt" + self.file_type_embed = Path(CUR_DIR) / "models" / "dpa2_tebd.pt" def test_descriptor_block(self): # torch.manual_seed(0) @@ -260,7 +258,7 @@ def test_descriptor_block(self): extended_coord, extended_atype, nloc, rcut, nsel, distinguish_types=False ) # handel type_embedding - type_embedding = TypeEmbedNet(ntypes, 8).to(env.DEVICE) + type_embedding = EmbdLayer(ntypes, 8, padding=True) type_embedding.load_state_dict(torch.load(self.file_type_embed)) ## to save model parameters diff --git a/source/tests/pt/test_descriptor_dpa2.py b/source/tests/pt/test_descriptor_dpa2.py index aa6b16964e..22371f612d 100644 --- a/source/tests/pt/test_descriptor_dpa2.py +++ b/source/tests/pt/test_descriptor_dpa2.py @@ -111,7 +111,7 @@ def setUp(self): ).to(env.DEVICE) with open(Path(CUR_DIR) / "models" / "dpa2_hyb.json") as fp: self.model_json = json.load(fp) - self.file_model_param = Path(CUR_DIR) / "models" / "dpa2.pth" + self.file_model_param = Path(CUR_DIR) / "models" / "dpa2.pt" 
self.file_type_embed = Path(CUR_DIR) / "models" / "dpa2_tebd.pth" def test_descriptor_hyb(self): diff --git a/source/tests/pt/test_dpa1.py b/source/tests/pt/test_dpa1.py new file mode 100644 index 0000000000..400a611c05 --- /dev/null +++ b/source/tests/pt/test_dpa1.py @@ -0,0 +1,153 @@ +import torch, copy +import unittest +import itertools +import numpy as np + +try: + from deepmd.model_format import ( + DescrptDPA1 as DPDescrptDPA1 + ) + support_se_atten = True +except ModuleNotFoundError: + support_se_atten = False +except ImportError: + support_se_atten = False + +from deepmd.pt.model.descriptor.dpa1 import ( + DescrptDPA1 +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + PRECISION_DICT, +) + +from .test_env_mat import ( + TestCaseSingleFrameWithNlist, +) +from .test_mlp import ( + get_tols, +) + +dtype = env.GLOBAL_PT_FLOAT_PRECISION + +@unittest.skipIf(not support_se_atten, "EnvMat not supported") +class TestDescrptSeAtten(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self): + TestCaseSingleFrameWithNlist.setUp(self) + + def test_consistency( + self, + ): + rng = np.random.default_rng(100) + nf, nloc, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + for idt, prec in itertools.product( + [False, True], + ["float64", "float32"], + ): + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + err_msg = f"idt={idt} prec={prec}" + # dpa1 new impl + dd0 = DescrptDPA1( + self.rcut, self.rcut_smth, self.sel, self.nt, attn_layer=0, # TODO add support for non-zero layer + # precision=prec, + # resnet_dt=idt, + old_impl=False, + ).to(env.DEVICE) + dd0.se_atten.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE) + dd0.se_atten.dstd = torch.tensor(dstd, dtype=dtype, device=env.DEVICE) + rd0, _, _, _, _ = dd0( + torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), + torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), + torch.tensor(self.nlist, dtype=int, device=env.DEVICE), + ) + # serialization + dd1 = DescrptDPA1.deserialize(dd0.serialize()) + rd1, _, _, _, _ = dd1( + torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), + torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), + torch.tensor(self.nlist, dtype=int, device=env.DEVICE), + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), rd1.detach().cpu().numpy(), + rtol=rtol, atol=atol, err_msg=err_msg, + ) + # dp impl + dd2 = DPDescrptDPA1.deserialize(dd0.serialize()) + rd2, _, _, _, _ = dd2.call( + self.coord_ext, self.atype_ext, self.nlist, + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), rd2, + rtol=rtol, atol=atol, err_msg=err_msg, + ) + # old impl + if idt is False and prec == "float64": + dd3 = DescrptDPA1( + self.rcut, self.rcut_smth, self.sel, self.nt, attn_layer=0, # TODO add support for non-zero layer + # precision=prec, + # resnet_dt=idt, + old_impl=True, + ).to(env.DEVICE) + dd0_state_dict = dd0.se_atten.state_dict() + dd3_state_dict = dd3.se_atten.state_dict() + + dd0_state_dict_attn = dd0.se_atten.dpa1_attention.state_dict() + dd3_state_dict_attn = dd3.se_atten.dpa1_attention.state_dict() + for i in dd3_state_dict: + dd3_state_dict[i] = dd0_state_dict[i.replace('.deep_layers.', '.layers.') + .replace('filter_layers_old.', 'filter_layers._networks.').replace('.attn_layer_norm.weight', '.attn_layer_norm.matrix')].detach().clone() + if '.bias' in i and 'attn_layer_norm' not in i: + dd3_state_dict[i] = 
dd3_state_dict[i].unsqueeze(0) + dd3.se_atten.load_state_dict(dd3_state_dict) + + dd0_state_dict_tebd = dd0.type_embedding.state_dict() + dd3_state_dict_tebd = dd3.type_embedding_old.state_dict() + for i in dd3_state_dict_tebd: + dd3_state_dict_tebd[i] = dd0_state_dict_tebd[i.replace('embedding.weight', 'matrix')].detach().clone() + dd3.type_embedding_old.load_state_dict(dd3_state_dict_tebd) + + rd3, _, _, _, _ = dd3( + torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), + torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), + torch.tensor(self.nlist, dtype=int, device=env.DEVICE), + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), rd3.detach().cpu().numpy(), + rtol=rtol, atol=atol, err_msg=err_msg, + ) + + def test_jit( + self, + ): + rng = np.random.default_rng() + nf, nloc, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + for idt, prec in itertools.product( + [False, True], + ["float64", "float32"], + ): + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + err_msg = f"idt={idt} prec={prec}" + # sea new impl + dd0 = DescrptDPA1( + self.rcut, self.rcut_smth, self.sel, self.nt, + # precision=prec, + # resnet_dt=idt, + old_impl=False, + ) + dd0.se_atten.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE) + dd0.se_atten.dstd = torch.tensor(dstd, dtype=dtype, device=env.DEVICE) + # dd1 = DescrptDPA1.deserialize(dd0.serialize()) + model = torch.jit.script(dd0) + # model = torch.jit.script(dd1) \ No newline at end of file diff --git a/source/tests/pt/test_model.py b/source/tests/pt/test_model.py index e87a53969c..e61310071c 100644 --- a/source/tests/pt/test_model.py +++ b/source/tests/pt/test_model.py @@ -59,7 +59,7 @@ def torch2tf(torch_name, last_layer_id=None): fields = torch_name.split(".") - offset = int(fields[2] == "networks") + offset = int(fields[2] == "_networks") element_id = int(fields[2 + offset]) if fields[0] == "descriptor": layer_id = int(fields[4 + offset]) + 1 diff --git a/source/tests/pt/test_se_e2_a.py b/source/tests/pt/test_se_e2_a.py index 0da80ea1ea..5ff0a1d32e 100644 --- a/source/tests/pt/test_se_e2_a.py +++ b/source/tests/pt/test_se_e2_a.py @@ -118,7 +118,7 @@ def test_consistency( dd3_state_dict[i] = ( dd0_state_dict[ i.replace(".deep_layers.", ".layers.").replace( - "filter_layers_old.", "filter_layers.networks." + "filter_layers_old.", "filter_layers._networks." ) ] .detach()
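For reference, a standalone numpy sketch (illustrative, not part of the patch) of the attention-weight computation performed by GatedAttentionLayer.forward in the smooth branch with dotr=True; the helper name and shapes are made up:

import numpy as np

def gated_attn_weights(q, k, sw, input_r, scaling, attnw_shift=20.0):
    # q, k: (nloc, nnei, d) queries/keys; sw: (nloc, nnei) smooth switch;
    # input_r: (nloc, nnei, 3) normalized relative coordinates;
    # scaling = (hidden_dim * scaling_factor) ** -0.5 when temperature is None.
    logits = (q * scaling) @ np.swapaxes(k, -1, -2)  # (nloc, nnei, nnei)
    # smooth masking: shift up, gate both sides by sw, shift back
    logits = (logits + attnw_shift) * sw[:, :, None] * sw[:, None, :] - attnw_shift
    w = np.exp(logits - logits.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)               # softmax over neighbors
    w *= sw[:, :, None] * sw[:, None, :]             # re-apply the smooth gate
    w *= input_r @ np.swapaxes(input_r, -1, -2)      # angular gate (dotr=True)
    return w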
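And the cross-backend contract that the new test_consistency exercises, in miniature — a sketch assuming a working installation; the constructor arguments are illustrative and the input tensors are elided:

from deepmd.pt.model.descriptor.dpa1 import DescrptDPA1
from deepmd.model_format import DescrptDPA1 as DPDescrptDPA1
from deepmd.pt.utils import env

# build the PyTorch descriptor (attn_layer=0, matching the current test coverage),
# then rebuild the numpy reference implementation from its serialized state
dd_pt = DescrptDPA1(rcut=4.0, rcut_smth=0.5, sel=[6], ntypes=2, attn_layer=0).to(env.DEVICE)
dd_np = DPDescrptDPA1.deserialize(dd_pt.serialize())
# given identical (coord_ext, atype_ext, nlist) inputs, the two should agree:
#   rd_pt, *_ = dd_pt(coord_ext_t, atype_ext_t, nlist_t)
#   rd_np, *_ = dd_np.call(coord_ext, atype_ext, nlist)
#   np.testing.assert_allclose(rd_pt.detach().cpu().numpy(), rd_np)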