diff --git a/deepmd/dpmodel/descriptor/__init__.py b/deepmd/dpmodel/descriptor/__init__.py index 5eca26acc5..08f8eb5052 100644 --- a/deepmd/dpmodel/descriptor/__init__.py +++ b/deepmd/dpmodel/descriptor/__init__.py @@ -5,8 +5,12 @@ from .se_e2_a import ( DescrptSeA, ) +from .se_r import ( + DescrptSeR, +) __all__ = [ "DescrptSeA", + "DescrptSeR", "make_base_descriptor", ] diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index 97ab719c62..a28215c35a 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -3,6 +3,9 @@ import numpy as np +from deepmd.env import ( + GLOBAL_NP_FLOAT_PRECISION, +) from deepmd.utils.path import ( DPPath, ) @@ -183,8 +186,12 @@ def __init__( ) self.env_mat = EnvMat(self.rcut, self.rcut_smth) self.nnei = np.sum(self.sel) - self.davg = np.zeros([self.ntypes, self.nnei, 4]) - self.dstd = np.ones([self.ntypes, self.nnei, 4]) + self.davg = np.zeros( + [self.ntypes, self.nnei, 4], dtype=PRECISION_DICT[self.precision] + ) + self.dstd = np.ones( + [self.ntypes, self.nnei, 4], dtype=PRECISION_DICT[self.precision] + ) self.orig_sel = self.sel def __setitem__(self, key, value): @@ -292,7 +299,7 @@ def call( sec = np.append([0], np.cumsum(self.sel)) ng = self.neuron[-1] - gr = np.zeros([nf * nloc, ng, 4]) + gr = np.zeros([nf * nloc, ng, 4], dtype=PRECISION_DICT[self.precision]) exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) # merge nf and nloc axis, so for type_one_side == False, # we don't require atype is the same in all frames @@ -322,7 +329,9 @@ def call( # nf x nloc x ng x ng1 grrg = np.einsum("flid,fljd->flij", gr, gr1) # nf x nloc x (ng x ng1) - grrg = grrg.reshape(nf, nloc, ng * self.axis_neuron) + grrg = grrg.reshape(nf, nloc, ng * self.axis_neuron).astype( + GLOBAL_NP_FLOAT_PRECISION + ) return grrg, gr[..., 1:], None, None, ww def serialize(self) -> dict: diff --git a/deepmd/dpmodel/descriptor/se_r.py b/deepmd/dpmodel/descriptor/se_r.py new file mode 100644 index 0000000000..77e43f7d85 --- /dev/null +++ b/deepmd/dpmodel/descriptor/se_r.py @@ -0,0 +1,321 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import numpy as np + +from deepmd.utils.path import ( + DPPath, +) + +try: + from deepmd._version import version as __version__ +except ImportError: + __version__ = "unknown" + +import copy +from typing import ( + Any, + List, + Optional, +) + +from deepmd.dpmodel import ( + DEFAULT_PRECISION, + PRECISION_DICT, + NativeOP, +) +from deepmd.dpmodel.utils import ( + EmbeddingNet, + EnvMat, + NetworkCollection, + PairExcludeMask, +) +from deepmd.env import ( + GLOBAL_NP_FLOAT_PRECISION, +) + +from .base_descriptor import ( + BaseDescriptor, +) + + +@BaseDescriptor.register("se_e2_r") +@BaseDescriptor.register("se_r") +class DescrptSeR(NativeOP, BaseDescriptor): + r"""DeepPot-SE_R constructed from only the radial imformation of atomic configurations. + + + Parameters + ---------- + rcut + The cut-off radius :math:`r_c` + rcut_smth + From where the environment matrix should be smoothed :math:`r_s` + sel : list[str] + sel[i] specifies the maxmum number of type i atoms in the cut-off radius + neuron : list[int] + Number of neurons in each hidden layers of the embedding net :math:`\mathcal{N}` + resnet_dt + Time-step `dt` in the resnet construction: + y = x + dt * \phi (Wx + b) + trainable + If the weights of embedding net are trainable. + type_one_side + Try to build N_types embedding nets. Otherwise, building N_types^2 embedding nets + exclude_types : List[List[int]] + The excluded pairs of types which have no interaction with each other. + For example, `[[0, 1]]` means no interaction between type 0 and type 1. + set_davg_zero + Set the shift of embedding net input to zero. + activation_function + The activation function in the embedding net. Supported options are |ACTIVATION_FN| + precision + The precision of the embedding net parameters. Supported options are |PRECISION| + multi_task + If the model has multi fitting nets to train. + spin + The deepspin object. + + Limitations + ----------- + The currently implementation does not support the following features + + 1. type_one_side == False + 2. exclude_types != [] + 3. spin is not None + + References + ---------- + .. [1] Linfeng Zhang, Jiequn Han, Han Wang, Wissam A. Saidi, Roberto Car, and E. Weinan. 2018. + End-to-end symmetry preserving inter-atomic potential energy model for finite and extended + systems. In Proceedings of the 32nd International Conference on Neural Information Processing + Systems (NIPS'18). Curran Associates Inc., Red Hook, NY, USA, 4441-4451. + """ + + def __init__( + self, + rcut: float, + rcut_smth: float, + sel: List[str], + neuron: List[int] = [24, 48, 96], + resnet_dt: bool = False, + trainable: bool = True, + type_one_side: bool = True, + exclude_types: List[List[int]] = [], + set_davg_zero: bool = False, + activation_function: str = "tanh", + precision: str = DEFAULT_PRECISION, + spin: Optional[Any] = None, + # consistent with argcheck, not used though + seed: Optional[int] = None, + ) -> None: + ## seed, uniform_seed, multi_task, not included. + if not type_one_side: + raise NotImplementedError("type_one_side == False not implemented") + if spin is not None: + raise NotImplementedError("spin is not implemented") + + self.rcut = rcut + self.rcut_smth = rcut_smth + self.sel = sel + self.ntypes = len(self.sel) + self.neuron = neuron + self.resnet_dt = resnet_dt + self.trainable = trainable + self.type_one_side = type_one_side + self.exclude_types = exclude_types + self.set_davg_zero = set_davg_zero + self.activation_function = activation_function + self.precision = precision + self.spin = spin + self.emask = PairExcludeMask(self.ntypes, self.exclude_types) + + in_dim = 1 # not considiering type embedding + self.embeddings = NetworkCollection( + ntypes=self.ntypes, + ndim=(1 if self.type_one_side else 2), + network_type="embedding_network", + ) + if not self.type_one_side: + raise NotImplementedError("type_one_side == False not implemented") + for ii in range(self.ntypes): + self.embeddings[(ii,)] = EmbeddingNet( + in_dim, + self.neuron, + self.activation_function, + self.resnet_dt, + self.precision, + ) + self.env_mat = EnvMat(self.rcut, self.rcut_smth) + self.nnei = np.sum(self.sel) + self.davg = np.zeros( + [self.ntypes, self.nnei, 1], dtype=PRECISION_DICT[self.precision] + ) + self.dstd = np.ones( + [self.ntypes, self.nnei, 1], dtype=PRECISION_DICT[self.precision] + ) + self.orig_sel = self.sel + + def __setitem__(self, key, value): + if key in ("avg", "data_avg", "davg"): + self.davg = value + elif key in ("std", "data_std", "dstd"): + self.dstd = value + else: + raise KeyError(key) + + def __getitem__(self, key): + if key in ("avg", "data_avg", "davg"): + return self.davg + elif key in ("std", "data_std", "dstd"): + return self.dstd + else: + raise KeyError(key) + + @property + def dim_out(self): + """Returns the output dimension of this descriptor.""" + return self.get_dim_out() + + def get_dim_out(self): + """Returns the output dimension of this descriptor.""" + return self.neuron[-1] + + def get_dim_emb(self): + """Returns the embedding (g2) dimension of this descriptor.""" + raise NotImplementedError + + def get_rcut(self): + """Returns cutoff radius.""" + return self.rcut + + def get_sel(self): + """Returns cutoff radius.""" + return self.sel + + def mixed_types(self): + """Returns if the descriptor requires a neighbor list that distinguish different + atomic types or not. + """ + return False + + def get_ntypes(self) -> int: + """Returns the number of element types.""" + return self.ntypes + + def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None): + """Update mean and stddev for descriptor elements.""" + raise NotImplementedError + + def cal_g( + self, + ss, + ll, + ): + nf, nloc, nnei = ss.shape[0:3] + ss = ss.reshape(nf, nloc, nnei, 1) + # nf x nloc x nnei x ng + gg = self.embeddings[(ll,)].call(ss) + return gg + + def call( + self, + coord_ext, + atype_ext, + nlist, + mapping: Optional[np.ndarray] = None, + ): + """Compute the descriptor. + + Parameters + ---------- + coord_ext + The extended coordinates of atoms. shape: nf x (nallx3) + atype_ext + The extended aotm types. shape: nf x nall + nlist + The neighbor list. shape: nf x nloc x nnei + mapping + The index mapping from extended to lcoal region. not used by this descriptor. + + Returns + ------- + descriptor + The descriptor. shape: nf x nloc x (ng x axis_neuron) + gr + The rotationally equivariant and permutationally invariant single particle + representation. shape: nf x nloc x ng x 3 + g2 + The rotationally invariant pair-partical representation. + this descriptor returns None + h2 + The rotationally equivariant pair-partical representation. + this descriptor returns None + sw + The smooth switch function. + """ + del mapping + # nf x nloc x nnei x 1 + rr, ww = self.env_mat.call( + coord_ext, atype_ext, nlist, self.davg, self.dstd, True + ) + nf, nloc, nnei, _ = rr.shape + sec = np.append([0], np.cumsum(self.sel)) + + ng = self.neuron[-1] + xyz_scatter = np.zeros([nf, nloc, ng], dtype=PRECISION_DICT[self.precision]) + exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) + for tt in range(self.ntypes): + mm = exclude_mask[:, :, sec[tt] : sec[tt + 1]] + tr = rr[:, :, sec[tt] : sec[tt + 1], :] + tr = tr * mm[:, :, :, None] + gg = self.cal_g(tr, tt) + gg = np.mean(gg, axis=2) + # nf x nloc x ng x 1 + xyz_scatter += gg + + res_rescale = 1.0 / 10.0 + res = xyz_scatter * res_rescale + res = res.reshape(nf, nloc, -1).astype(GLOBAL_NP_FLOAT_PRECISION) + return res, None, None, None, ww + + def serialize(self) -> dict: + """Serialize the descriptor to dict.""" + return { + "@class": "Descriptor", + "type": "se_r", + "rcut": self.rcut, + "rcut_smth": self.rcut_smth, + "sel": self.sel, + "neuron": self.neuron, + "resnet_dt": self.resnet_dt, + "trainable": self.trainable, + "type_one_side": self.type_one_side, + "exclude_types": self.exclude_types, + "set_davg_zero": self.set_davg_zero, + "activation_function": self.activation_function, + # make deterministic + "precision": np.dtype(PRECISION_DICT[self.precision]).name, + "spin": self.spin, + "env_mat": self.env_mat.serialize(), + "embeddings": self.embeddings.serialize(), + "@variables": { + "davg": self.davg, + "dstd": self.dstd, + }, + } + + @classmethod + def deserialize(cls, data: dict) -> "DescrptSeR": + """Deserialize from dict.""" + data = copy.deepcopy(data) + data.pop("@class", None) + data.pop("type", None) + variables = data.pop("@variables") + embeddings = data.pop("embeddings") + env_mat = data.pop("env_mat") + obj = cls(**data) + + obj["davg"] = variables["davg"] + obj["dstd"] = variables["dstd"] + obj.embeddings = NetworkCollection.deserialize(embeddings) + obj.env_mat = EnvMat.deserialize(env_mat) + return obj diff --git a/deepmd/dpmodel/utils/env_mat.py b/deepmd/dpmodel/utils/env_mat.py index 070b0e1549..0e861d9f38 100644 --- a/deepmd/dpmodel/utils/env_mat.py +++ b/deepmd/dpmodel/utils/env_mat.py @@ -30,6 +30,7 @@ def _make_env_mat( coord, rcut: float, ruct_smth: float, + radial_only: bool = False, ): """Make smooth environment matrix.""" nf, nloc, nnei = nlist.shape @@ -54,8 +55,11 @@ def _make_env_mat( t1 = diff / length**2 weight = compute_smooth_weight(length, ruct_smth, rcut) weight = weight * np.expand_dims(mask, -1) - env_mat_se_a = np.concatenate([t0, t1], axis=-1) * weight - return env_mat_se_a, diff * np.expand_dims(mask, -1), weight + if radial_only: + env_mat = t0 * weight + else: + env_mat = np.concatenate([t0, t1], axis=-1) * weight + return env_mat, diff * np.expand_dims(mask, -1), weight class EnvMat(NativeOP): @@ -74,6 +78,7 @@ def call( nlist: np.ndarray, davg: Optional[np.ndarray] = None, dstd: Optional[np.ndarray] = None, + radial_only: bool = False, ) -> Union[np.ndarray, np.ndarray]: """Compute the environment matrix. @@ -86,18 +91,23 @@ def call( atype_ext The extended aotm types. shape: nf x nall davg - The data avg. shape: nt x nnei x 4 + The data avg. shape: nt x nnei x (4 or 1) dstd - The inverse of data std. shape: nt x nnei x 4 + The inverse of data std. shape: nt x nnei x (4 or 1) + radial_only + Whether to only compute radial part of the environment matrix. + If True, the output will be of shape nf x nloc x nnei x 1. + Otherwise, the output will be of shape nf x nloc x nnei x 4. + Default: False. Returns ------- env_mat - The environment matrix. shape: nf x nloc x nnei x 4 + The environment matrix. shape: nf x nloc x nnei x (4 or 1) switch The value of switch function. shape: nf x nloc x nnei """ - em, sw = self._call(nlist, coord_ext) + em, sw = self._call(nlist, coord_ext, radial_only) nf, nloc, nnei = nlist.shape atype = atype_ext[:, :nloc] if davg is not None: @@ -106,12 +116,10 @@ def call( em /= dstd[atype] return em, sw - def _call( - self, - nlist, - coord_ext, - ): - em, diff, ww = _make_env_mat(nlist, coord_ext, self.rcut, self.rcut_smth) + def _call(self, nlist, coord_ext, radial_only): + em, diff, ww = _make_env_mat( + nlist, coord_ext, self.rcut, self.rcut_smth, radial_only + ) return em, ww def serialize( diff --git a/deepmd/pt/model/descriptor/__init__.py b/deepmd/pt/model/descriptor/__init__.py index e3bd37ccea..5fd644f149 100644 --- a/deepmd/pt/model/descriptor/__init__.py +++ b/deepmd/pt/model/descriptor/__init__.py @@ -11,7 +11,7 @@ DescrptDPA2, ) from .env_mat import ( - prod_env_mat_se_a, + prod_env_mat, ) from .gaussian_lcc import ( DescrptGaussianLcc, @@ -26,6 +26,9 @@ DescrptBlockSeA, DescrptSeA, ) +from .se_r import ( + DescrptSeR, +) __all__ = [ "DescriptorBlock", @@ -33,9 +36,10 @@ "DescrptBlockSeA", "DescrptBlockSeAtten", "DescrptSeA", + "DescrptSeR", "DescrptDPA1", "DescrptDPA2", - "prod_env_mat_se_a", + "prod_env_mat", "DescrptGaussianLcc", "DescrptBlockHybrid", "DescrptBlockRepformers", diff --git a/deepmd/pt/model/descriptor/descriptor.py b/deepmd/pt/model/descriptor/descriptor.py index 29c574bb1f..91e0a2527a 100644 --- a/deepmd/pt/model/descriptor/descriptor.py +++ b/deepmd/pt/model/descriptor/descriptor.py @@ -20,7 +20,7 @@ env, ) from deepmd.pt.utils.env_mat_stat import ( - EnvMatStatSeA, + EnvMatStatSe, ) from deepmd.pt.utils.plugin import ( Plugin, @@ -129,7 +129,7 @@ def share_params(self, base_class, shared_level, resume=False): # link buffers if hasattr(self, "mean") and not resume: # in case of change params during resume - base_env = EnvMatStatSeA(base_class) + base_env = EnvMatStatSe(base_class) base_env.stats = base_class.stats for kk in base_class.get_stats(): base_env.stats[kk] += self.get_stats()[kk] diff --git a/deepmd/pt/model/descriptor/env_mat.py b/deepmd/pt/model/descriptor/env_mat.py index b3235de175..4e6ffb7785 100644 --- a/deepmd/pt/model/descriptor/env_mat.py +++ b/deepmd/pt/model/descriptor/env_mat.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later + import torch from deepmd.pt.utils.preprocess import ( @@ -6,7 +7,9 @@ ) -def _make_env_mat_se_a(nlist, coord, rcut: float, ruct_smth: float): +def _make_env_mat( + nlist, coord, rcut: float, ruct_smth: float, radial_only: bool = False +): """Make smooth environment matrix.""" bsz, natoms, nnei = nlist.shape coord = coord.view(bsz, -1, 3) @@ -26,35 +29,42 @@ def _make_env_mat_se_a(nlist, coord, rcut: float, ruct_smth: float): t1 = diff / length**2 weight = compute_smooth_weight(length, ruct_smth, rcut) weight = weight * mask.unsqueeze(-1) - env_mat_se_a = torch.cat([t0, t1], dim=-1) * weight - return env_mat_se_a, diff * mask.unsqueeze(-1), weight + if radial_only: + env_mat = t0 * weight + else: + env_mat = torch.cat([t0, t1], dim=-1) * weight + return env_mat, diff * mask.unsqueeze(-1), weight -def prod_env_mat_se_a( - extended_coord, nlist, atype, mean, stddev, rcut: float, rcut_smth: float +def prod_env_mat( + extended_coord, + nlist, + atype, + mean, + stddev, + rcut: float, + rcut_smth: float, + radial_only: bool = False, ): """Generate smooth environment matrix from atom coordinates and other context. Args: - extended_coord: Copied atom coordinates with shape [nframes, nall*3]. - atype: Atom types with shape [nframes, nloc]. - - natoms: Batched atom statisics with shape [len(sec)+2]. - - box: Batched simulation box with shape [nframes, 9]. - - mean: Average value of descriptor per element type with shape [len(sec), nnei, 4]. - - stddev: Standard deviation of descriptor per element type with shape [len(sec), nnei, 4]. - - deriv_stddev: StdDev of descriptor derivative per element type with shape [len(sec), nnei, 4, 3]. + - mean: Average value of descriptor per element type with shape [len(sec), nnei, 4 or 1]. + - stddev: Standard deviation of descriptor per element type with shape [len(sec), nnei, 4 or 1]. - rcut: Cut-off radius. - rcut_smth: Smooth hyper-parameter for pair force & energy. + - radial_only: Whether to return a full description or a radial-only descriptor. Returns ------- - - env_mat_se_a: Shape is [nframes, natoms[1]*nnei*4]. + - env_mat: Shape is [nframes, natoms[1]*nnei*4]. """ - nframes = extended_coord.shape[0] - _env_mat_se_a, diff, switch = _make_env_mat_se_a( - nlist, extended_coord, rcut, rcut_smth - ) # shape [n_atom, dim, 4] - t_avg = mean[atype] # [n_atom, dim, 4] - t_std = stddev[atype] # [n_atom, dim, 4] + _env_mat_se_a, diff, switch = _make_env_mat( + nlist, extended_coord, rcut, rcut_smth, radial_only + ) # shape [n_atom, dim, 4 or 1] + t_avg = mean[atype] # [n_atom, dim, 4 or 1] + t_std = stddev[atype] # [n_atom, dim, 4 or 1] env_mat_se_a = (_env_mat_se_a - t_avg) / t_std return env_mat_se_a, diff, switch diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index 8aa1114fdc..ad523bcc2d 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -11,7 +11,7 @@ DescriptorBlock, ) from deepmd.pt.model.descriptor.env_mat import ( - prod_env_mat_se_a, + prod_env_mat, ) from deepmd.pt.model.network.network import ( SimpleLinear, @@ -20,7 +20,7 @@ env, ) from deepmd.pt.utils.env_mat_stat import ( - EnvMatStatSeA, + EnvMatStatSe, ) from deepmd.pt.utils.utils import ( get_activation_fn, @@ -100,6 +100,7 @@ def __init__( self.nlayers = nlayers sel = [sel] if isinstance(sel, int) else sel self.nnei = sum(sel) + self.ndescrpt = self.nnei * 4 # use full descriptor. assert len(sel) == 1 self.sel = sel self.sec = self.sel @@ -222,7 +223,7 @@ def forward( nall = extended_coord.view(nframes, -1).shape[1] // 3 atype = extended_atype[:, :nloc] # nb x nloc x nnei x 4, nb x nloc x nnei x 3, nb x nloc x nnei x 1 - dmatrix, diff, sw = prod_env_mat_se_a( + dmatrix, diff, sw = prod_env_mat( extended_coord, nlist, atype, @@ -279,7 +280,7 @@ def forward( def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None): """Update mean and stddev for descriptor elements.""" - env_mat_stat = EnvMatStatSeA(self) + env_mat_stat = EnvMatStatSe(self) if path is not None: path = path / env_mat_stat.get_hash() env_mat_stat.load_or_compute_stats(merged, path) diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index ea216eddfc..033d640ad8 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -13,7 +13,7 @@ from deepmd.pt.model.descriptor import ( DescriptorBlock, - prod_env_mat_se_a, + prod_env_mat, ) from deepmd.pt.utils import ( env, @@ -23,7 +23,7 @@ RESERVED_PRECISON_DICT, ) from deepmd.pt.utils.env_mat_stat import ( - EnvMatStatSeA, + EnvMatStatSe, ) from deepmd.utils.env_mat_stat import ( StatItem, @@ -384,7 +384,7 @@ def __getitem__(self, key): def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None): """Update mean and stddev for descriptor elements.""" - env_mat_stat = EnvMatStatSeA(self) + env_mat_stat = EnvMatStatSe(self) if path is not None: path = path / env_mat_stat.get_hash() env_mat_stat.load_or_compute_stats(merged, path) @@ -393,9 +393,6 @@ def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None) if not self.set_davg_zero: self.mean.copy_(torch.tensor(mean, device=env.DEVICE)) self.stddev.copy_(torch.tensor(stddev, device=env.DEVICE)) - if not self.set_davg_zero: - self.mean.copy_(torch.tensor(mean, device=env.DEVICE)) - self.stddev.copy_(torch.tensor(stddev, device=env.DEVICE)) def get_stats(self) -> Dict[str, StatItem]: """Get the statistics of the descriptor.""" @@ -428,7 +425,7 @@ def forward( del extended_atype_embd, mapping nloc = nlist.shape[1] atype = extended_atype[:, :nloc] - dmatrix, diff, sw = prod_env_mat_se_a( + dmatrix, diff, sw = prod_env_mat( extended_coord, nlist, atype, diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index 4a7469a804..0b32bd9341 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -12,7 +12,7 @@ DescriptorBlock, ) from deepmd.pt.model.descriptor.env_mat import ( - prod_env_mat_se_a, + prod_env_mat, ) from deepmd.pt.model.network.network import ( NeighborWiseAttention, @@ -22,7 +22,7 @@ env, ) from deepmd.pt.utils.env_mat_stat import ( - EnvMatStatSeA, + EnvMatStatSe, ) from deepmd.utils.env_mat_stat import ( StatItem, @@ -202,7 +202,7 @@ def dim_emb(self): def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None): """Update mean and stddev for descriptor elements.""" - env_mat_stat = EnvMatStatSeA(self) + env_mat_stat = EnvMatStatSe(self) if path is not None: path = path / env_mat_stat.get_hash() env_mat_stat.load_or_compute_stats(merged, path) @@ -247,7 +247,7 @@ def forward( atype = extended_atype[:, :nloc] nb = nframes nall = extended_coord.view(nb, -1, 3).shape[1] - dmatrix, diff, sw = prod_env_mat_se_a( + dmatrix, diff, sw = prod_env_mat( extended_coord, nlist, atype, diff --git a/deepmd/pt/model/descriptor/se_r.py b/deepmd/pt/model/descriptor/se_r.py new file mode 100644 index 0000000000..c685640426 --- /dev/null +++ b/deepmd/pt/model/descriptor/se_r.py @@ -0,0 +1,316 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Dict, + List, + Optional, + Tuple, +) + +import numpy as np +import torch + +from deepmd.dpmodel.utils import EnvMat as DPEnvMat +from deepmd.pt.model.descriptor import ( + prod_env_mat, +) +from deepmd.pt.model.network.mlp import ( + EmbeddingNet, + NetworkCollection, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + PRECISION_DICT, + RESERVED_PRECISON_DICT, +) +from deepmd.pt.utils.env_mat_stat import ( + EnvMatStatSe, +) +from deepmd.pt.utils.exclude_mask import ( + PairExcludeMask, +) +from deepmd.utils.env_mat_stat import ( + StatItem, +) +from deepmd.utils.path import ( + DPPath, +) + +from .base_descriptor import ( + BaseDescriptor, +) + + +@BaseDescriptor.register("se_e2_r") +@BaseDescriptor.register("se_r") +class DescrptSeR(BaseDescriptor, torch.nn.Module): + def __init__( + self, + rcut, + rcut_smth, + sel, + neuron=[25, 50, 100], + set_davg_zero: bool = False, + activation_function: str = "tanh", + precision: str = "float64", + resnet_dt: bool = False, + exclude_types: List[Tuple[int, int]] = [], + old_impl: bool = False, + **kwargs, + ): + super().__init__() + self.rcut = rcut + self.rcut_smth = rcut_smth + self.neuron = neuron + self.filter_neuron = self.neuron + self.set_davg_zero = set_davg_zero + self.activation_function = activation_function + self.precision = precision + self.prec = PRECISION_DICT[self.precision] + self.resnet_dt = resnet_dt + self.old_impl = False # this does not support old implementation. + self.exclude_types = exclude_types + self.ntypes = len(sel) + self.emask = PairExcludeMask(len(sel), exclude_types=exclude_types) + + self.sel = sel + self.sec = torch.tensor( + np.append([0], np.cumsum(self.sel)), dtype=int, device=env.DEVICE + ) + self.split_sel = self.sel + self.nnei = sum(sel) + self.ndescrpt = self.nnei * 1 + + wanted_shape = (self.ntypes, self.nnei, 1) + mean = torch.zeros(wanted_shape, dtype=self.prec, device=env.DEVICE) + stddev = torch.ones(wanted_shape, dtype=self.prec, device=env.DEVICE) + self.register_buffer("mean", mean) + self.register_buffer("stddev", stddev) + self.filter_layers_old = None + self.filter_layers = None + + filter_layers = NetworkCollection( + ndim=1, ntypes=len(sel), network_type="embedding_network" + ) + # TODO: ndim=2 if type_one_side=False + for ii in range(self.ntypes): + filter_layers[(ii,)] = EmbeddingNet( + 1, + self.filter_neuron, + activation_function=self.activation_function, + precision=self.precision, + resnet_dt=self.resnet_dt, + ) + self.filter_layers = filter_layers + self.stats = None + + def get_rcut(self) -> float: + """Returns the cut-off radius.""" + return self.rcut + + def get_nsel(self) -> int: + """Returns the number of selected atoms in the cut-off radius.""" + return sum(self.sel) + + def get_sel(self) -> List[int]: + """Returns the number of selected atoms for each type.""" + return self.sel + + def get_ntypes(self) -> int: + """Returns the number of element types.""" + return self.ntypes + + def get_dim_out(self) -> int: + """Returns the output dimension.""" + return self.neuron[-1] + + def get_dim_emb(self) -> int: + """Returns the output dimension.""" + raise NotImplementedError + + def get_dim_in(self) -> int: + """Returns the input dimension.""" + return 0 + + def mixed_types(self) -> bool: + """If true, the discriptor + 1. assumes total number of atoms aligned across frames; + 2. requires a neighbor list that does not distinguish different atomic types. + + If false, the discriptor + 1. assumes total number of atoms of each atom type aligned across frames; + 2. requires a neighbor list that distinguishes different atomic types. + + """ + return False + + def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None): + """Update mean and stddev for descriptor elements.""" + env_mat_stat = EnvMatStatSe(self) + if path is not None: + path = path / env_mat_stat.get_hash() + env_mat_stat.load_or_compute_stats(merged, path) + self.stats = env_mat_stat.stats + mean, stddev = env_mat_stat() + if not self.set_davg_zero: + self.mean.copy_(torch.tensor(mean, device=env.DEVICE)) + self.stddev.copy_(torch.tensor(stddev, device=env.DEVICE)) + + def get_stats(self) -> Dict[str, StatItem]: + """Get the statistics of the descriptor.""" + if self.stats is None: + raise RuntimeError( + "The statistics of the descriptor has not been computed." + ) + return self.stats + + def __setitem__(self, key, value): + if key in ("avg", "data_avg", "davg"): + self.mean = value + elif key in ("std", "data_std", "dstd"): + self.stddev = value + else: + raise KeyError(key) + + def __getitem__(self, key): + if key in ("avg", "data_avg", "davg"): + return self.mean + elif key in ("std", "data_std", "dstd"): + return self.stddev + else: + raise KeyError(key) + + def forward( + self, + coord_ext: torch.Tensor, + atype_ext: torch.Tensor, + nlist: torch.Tensor, + mapping: Optional[torch.Tensor] = None, + ): + """Compute the descriptor. + + Parameters + ---------- + coord_ext + The extended coordinates of atoms. shape: nf x (nallx3) + atype_ext + The extended aotm types. shape: nf x nall + nlist + The neighbor list. shape: nf x nloc x nnei + mapping + The index mapping, not required by this descriptor. + + Returns + ------- + descriptor + The descriptor. shape: nf x nloc x (ng x axis_neuron) + gr + The rotationally equivariant and permutationally invariant single particle + representation. shape: nf x nloc x ng x 3 + g2 + The rotationally invariant pair-partical representation. + this descriptor returns None + h2 + The rotationally equivariant pair-partical representation. + this descriptor returns None + sw + The smooth switch function. + + """ + del mapping + nloc = nlist.shape[1] + atype = atype_ext[:, :nloc] + dmatrix, diff, sw = prod_env_mat( + coord_ext, + nlist, + atype, + self.mean, + self.stddev, + self.rcut, + self.rcut_smth, + True, + ) + + assert self.filter_layers is not None + dmatrix = dmatrix.view(-1, self.nnei, 1) + dmatrix = dmatrix.to(dtype=self.prec) + nfnl = dmatrix.shape[0] + # pre-allocate a shape to pass jit + xyz_scatter = torch.zeros( + [nfnl, 1, self.filter_neuron[-1]], dtype=self.prec, device=coord_ext.device + ) + + # nfnl x nnei + exclude_mask = self.emask(nlist, atype_ext).view(nfnl, -1) + for ii, ll in enumerate(self.filter_layers.networks): + # nfnl x nt + mm = exclude_mask[:, self.sec[ii] : self.sec[ii + 1]] + # nfnl x nt x 1 + ss = dmatrix[:, self.sec[ii] : self.sec[ii + 1], :] + ss = ss * mm[:, :, None] + # nfnl x nt x ng + gg = ll.forward(ss) + gg = torch.mean(gg, dim=1).unsqueeze(1) + xyz_scatter += gg + + res_rescale = 1.0 / 10.0 + result = xyz_scatter * res_rescale + result = result.view(-1, nloc, self.filter_neuron[-1]) + return ( + result.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + None, + None, + None, + sw, + ) + + def set_stat_mean_and_stddev( + self, + mean: torch.Tensor, + stddev: torch.Tensor, + ) -> None: + self.mean = mean + self.stddev = stddev + + def serialize(self) -> dict: + return { + "@class": "Descriptor", + "type": "se_r", + "rcut": self.rcut, + "rcut_smth": self.rcut_smth, + "sel": self.sel, + "neuron": self.neuron, + "resnet_dt": self.resnet_dt, + "set_davg_zero": self.set_davg_zero, + "activation_function": self.activation_function, + # make deterministic + "precision": RESERVED_PRECISON_DICT[self.prec], + "embeddings": self.filter_layers.serialize(), + "env_mat": DPEnvMat(self.rcut, self.rcut_smth).serialize(), + "exclude_types": self.exclude_types, + "@variables": { + "davg": self["davg"].detach().cpu().numpy(), + "dstd": self["dstd"].detach().cpu().numpy(), + }, + ## to be updated when the options are supported. + "trainable": True, + "type_one_side": True, + "spin": None, + } + + @classmethod + def deserialize(cls, data: dict) -> "DescrptSeR": + data = data.copy() + variables = data.pop("@variables") + embeddings = data.pop("embeddings") + env_mat = data.pop("env_mat") + obj = cls(**data) + + def t_cvt(xx): + return torch.tensor(xx, dtype=obj.prec, device=env.DEVICE) + + obj["davg"] = t_cvt(variables["davg"]) + obj["dstd"] = t_cvt(variables["dstd"]) + obj.filter_layers = NetworkCollection.deserialize(embeddings) + return obj diff --git a/deepmd/pt/utils/env_mat_stat.py b/deepmd/pt/utils/env_mat_stat.py index 2f3c728c99..3af03bda97 100644 --- a/deepmd/pt/utils/env_mat_stat.py +++ b/deepmd/pt/utils/env_mat_stat.py @@ -13,7 +13,7 @@ get_hash, ) from deepmd.pt.model.descriptor.env_mat import ( - prod_env_mat_se_a, + prod_env_mat, ) from deepmd.pt.utils import ( env, @@ -56,8 +56,8 @@ def compute_stat(self, env_mat: Dict[str, torch.Tensor]) -> Dict[str, StatItem]: return stats -class EnvMatStatSeA(EnvMatStat): - """Environmental matrix statistics for the se_a environemntal matrix. +class EnvMatStatSe(EnvMatStat): + """Environmental matrix statistics for the se_a/se_r environemntal matrix. Parameters ---------- @@ -68,6 +68,9 @@ class EnvMatStatSeA(EnvMatStat): def __init__(self, descriptor: "DescriptorBlock"): super().__init__() self.descriptor = descriptor + self.last_dim = ( + self.descriptor.ndescrpt // self.descriptor.nnei + ) # se_r=1, se_a=4 def iter( self, data: List[Dict[str, torch.Tensor]] @@ -87,14 +90,14 @@ def iter( zero_mean = torch.zeros( self.descriptor.get_ntypes(), self.descriptor.get_nsel(), - 4, + self.last_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE, ) one_stddev = torch.ones( self.descriptor.get_ntypes(), self.descriptor.get_nsel(), - 4, + self.last_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE, ) @@ -118,7 +121,7 @@ def iter( mixed_types=self.descriptor.mixed_types(), box=box, ) - env_mat, _, _ = prod_env_mat_se_a( + env_mat, _, _ = prod_env_mat( extended_coord, nlist, atype, @@ -131,7 +134,9 @@ def iter( # reshape to nframes * nloc at the atom level, # so nframes/mixed_type do not matter env_mat = env_mat.view( - coord.shape[0] * coord.shape[1], self.descriptor.get_nsel(), 4 + coord.shape[0] * coord.shape[1], + self.descriptor.get_nsel(), + self.last_dim, ) atype = atype.view(coord.shape[0] * coord.shape[1]) # (1, nloc) eq (ntypes, 1), so broadcast is possible @@ -144,10 +149,11 @@ def iter( ) for type_i in range(self.descriptor.get_ntypes()): dd = env_mat[type_idx[type_i]] - dd = dd.reshape([-1, 4]) # typen_atoms * nnei, 4 + dd = dd.reshape([-1, self.last_dim]) # typen_atoms * nnei, 4 env_mats = {} env_mats[f"r_{type_i}"] = dd[:, :1] - env_mats[f"a_{type_i}"] = dd[:, 1:] + if self.last_dim == 4: + env_mats[f"a_{type_i}"] = dd[:, 1:] yield self.compute_stat(env_mats) def get_hash(self) -> str: @@ -158,9 +164,10 @@ def get_hash(self) -> str: str The hash of the environment matrix. """ + dscpt_type = "se_a" if self.last_dim == 4 else "se_r" return get_hash( { - "type": "se_a", + "type": dscpt_type, "ntypes": self.descriptor.get_ntypes(), "rcut": round(self.descriptor.get_rcut(), 2), "rcut_smth": round(self.descriptor.rcut_smth, 2), @@ -176,20 +183,30 @@ def __call__(self): all_davg = [] all_dstd = [] + for type_i in range(self.descriptor.get_ntypes()): - davgunit = [[avgs[f"r_{type_i}"], 0, 0, 0]] - dstdunit = [ - [ - stds[f"r_{type_i}"], - stds[f"a_{type_i}"], - stds[f"a_{type_i}"], - stds[f"a_{type_i}"], + if self.last_dim == 4: + davgunit = [[avgs[f"r_{type_i}"], 0, 0, 0]] + dstdunit = [ + [ + stds[f"r_{type_i}"], + stds[f"a_{type_i}"], + stds[f"a_{type_i}"], + stds[f"a_{type_i}"], + ] + ] + elif self.last_dim == 1: + davgunit = [[avgs[f"r_{type_i}"]]] + dstdunit = [ + [ + stds[f"r_{type_i}"], + ] ] - ] davg = np.tile(davgunit, [self.descriptor.get_nsel(), 1]) dstd = np.tile(dstdunit, [self.descriptor.get_nsel(), 1]) all_davg.append(davg) all_dstd.append(dstd) + mean = np.stack(all_davg) stddev = np.stack(all_dstd) return mean, stddev diff --git a/deepmd/tf/descriptor/se_r.py b/deepmd/tf/descriptor/se_r.py index f790d0a8fb..1a12befdf0 100644 --- a/deepmd/tf/descriptor/se_r.py +++ b/deepmd/tf/descriptor/se_r.py @@ -7,6 +7,9 @@ import numpy as np +from deepmd.dpmodel.utils.env_mat import ( + EnvMat, +) from deepmd.tf.common import ( cast_precision, get_activation_func, @@ -115,8 +118,9 @@ def __init__( self.seed_shift = embedding_net_rand_seed_shift(self.filter_neuron) self.trainable = trainable self.filter_activation_fn = get_activation_func(activation_function) + self.activation_function_name = activation_function self.filter_precision = get_precision(precision) - exclude_types = exclude_types + self.orig_exclude_types = exclude_types self.exclude_types = set() for tt in exclude_types: assert len(tt) == 2 @@ -698,3 +702,93 @@ def _filter_r( result = tf.reduce_mean(xyz_scatter, axis=1) * res_rescale return result + + @classmethod + def deserialize(cls, data: dict, suffix: str = ""): + """Deserialize the model. + + Parameters + ---------- + data : dict + The serialized data + + Returns + ------- + Model + The deserialized model + """ + if cls is not DescrptSeR: + raise NotImplementedError("Not implemented in class %s" % cls.__name__) + data = data.copy() + embedding_net_variables = cls.deserialize_network( + data.pop("embeddings"), suffix=suffix + ) + data.pop("env_mat") + variables = data.pop("@variables") + descriptor = cls(**data) + descriptor.embedding_net_variables = embedding_net_variables + descriptor.davg = variables["davg"].reshape( + descriptor.ntypes, descriptor.ndescrpt + ) + descriptor.dstd = variables["dstd"].reshape( + descriptor.ntypes, descriptor.ndescrpt + ) + return descriptor + + def serialize(self, suffix: str = "") -> dict: + """Serialize the model. + + Parameters + ---------- + suffix : str, optional + The suffix of the scope + + Returns + ------- + dict + The serialized data + """ + if type(self) is not DescrptSeR: + raise NotImplementedError( + "Not implemented in class %s" % self.__class__.__name__ + ) + if self.embedding_net_variables is None: + raise RuntimeError("init_variables must be called before serialize") + if self.spin is not None: + raise NotImplementedError("spin is unsupported") + assert self.davg is not None + assert self.dstd is not None + # TODO: not sure how to handle type embedding - type embedding is not a model parameter, + # but instead a part of the input data. Maybe the interface should be refactored... + return { + "@class": "Descriptor", + "type": "se_r", + "rcut": self.rcut, + "rcut_smth": self.rcut_smth, + "sel": self.sel_r, + "neuron": self.filter_neuron, + "resnet_dt": self.filter_resnet_dt, + "trainable": self.trainable, + "type_one_side": self.type_one_side, + "exclude_types": list(self.orig_exclude_types), + "set_davg_zero": self.set_davg_zero, + "activation_function": self.activation_function_name, + "precision": self.filter_precision.name, + "embeddings": self.serialize_network( + ntypes=self.ntypes, + ndim=(1 if self.type_one_side else 2), + in_dim=1, + neuron=self.filter_neuron, + activation_function=self.activation_function_name, + resnet_dt=self.filter_resnet_dt, + variables=self.embedding_net_variables, + excluded_types=self.exclude_types, + suffix=suffix, + ), + "env_mat": EnvMat(self.rcut, self.rcut_smth).serialize(), + "@variables": { + "davg": self.davg.reshape(self.ntypes, self.nnei_r, 1), + "dstd": self.dstd.reshape(self.ntypes, self.nnei_r, 1), + }, + "spin": self.spin, + } diff --git a/source/tests/consistent/descriptor/test_se_r.py b/source/tests/consistent/descriptor/test_se_r.py new file mode 100644 index 0000000000..354ae1cc99 --- /dev/null +++ b/source/tests/consistent/descriptor/test_se_r.py @@ -0,0 +1,185 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest +from typing import ( + Any, + Tuple, +) + +import numpy as np + +from deepmd.dpmodel.descriptor.se_r import DescrptSeR as DescrptSeRDP +from deepmd.env import ( + GLOBAL_NP_FLOAT_PRECISION, +) + +from ..common import ( + INSTALLED_PT, + INSTALLED_TF, + CommonTest, + parameterized, +) +from .common import ( + DescriptorTest, +) + +if INSTALLED_PT: + from deepmd.pt.model.descriptor.se_r import DescrptSeR as DescrptSeRPT +else: + DescrptSeAPT = None +if INSTALLED_TF: + from deepmd.tf.descriptor.se_r import DescrptSeR as DescrptSeRTF +else: + DescrptSeATF = None +from deepmd.utils.argcheck import ( + descrpt_se_r_args, +) + + +@parameterized( + (True, False), # resnet_dt + (True, False), # type_one_side + ([], [[0, 1]]), # excluded_types + ("float32", "float64"), # precision +) +class TestSeA(CommonTest, DescriptorTest, unittest.TestCase): + @property + def data(self) -> dict: + ( + resnet_dt, + type_one_side, + excluded_types, + precision, + ) = self.param + return { + "sel": [10, 10], + "rcut_smth": 5.80, + "rcut": 6.00, + "neuron": [6, 12, 24], + "resnet_dt": resnet_dt, + "type_one_side": type_one_side, + "exclude_types": excluded_types, + "precision": precision, + "seed": 1145141919810, + } + + @property + def skip_pt(self) -> bool: + ( + resnet_dt, + type_one_side, + excluded_types, + precision, + ) = self.param + return not type_one_side or CommonTest.skip_pt + + @property + def skip_dp(self) -> bool: + ( + resnet_dt, + type_one_side, + excluded_types, + precision, + ) = self.param + return not type_one_side or CommonTest.skip_dp + + tf_class = DescrptSeRTF + dp_class = DescrptSeRDP + pt_class = DescrptSeRPT + args = descrpt_se_r_args() + + def setUp(self): + CommonTest.setUp(self) + + self.ntypes = 2 + self.coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 00.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ) + self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32) + self.box = np.array( + [13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ) + self.natoms = np.array([6, 6, 2, 4], dtype=np.int32) + + def build_tf(self, obj: Any, suffix: str) -> Tuple[list, dict]: + return self.build_tf_descriptor( + obj, + self.natoms, + self.coords, + self.atype, + self.box, + suffix, + ) + + def eval_dp(self, dp_obj: Any) -> Any: + return self.eval_dp_descriptor( + dp_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + def eval_pt(self, pt_obj: Any) -> Any: + return self.eval_pt_descriptor( + pt_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + def extract_ret(self, ret: Any, backend) -> Tuple[np.ndarray, ...]: + return (ret[0],) + + @property + def rtol(self) -> float: + """Relative tolerance for comparing the return value.""" + ( + resnet_dt, + type_one_side, + excluded_types, + precision, + ) = self.param + if precision == "float64": + return 1e-10 + elif precision == "float32": + return 1e-4 + else: + raise ValueError(f"Unknown precision: {precision}") + + @property + def atol(self) -> float: + """Absolute tolerance for comparing the return value.""" + ( + resnet_dt, + type_one_side, + excluded_types, + precision, + ) = self.param + if precision == "float64": + return 1e-10 + elif precision == "float32": + return 1e-4 + else: + raise ValueError(f"Unknown precision: {precision}") diff --git a/source/tests/pt/model/test_descriptor.py b/source/tests/pt/model/test_descriptor.py index 529a83ac6d..ffad27201a 100644 --- a/source/tests/pt/model/test_descriptor.py +++ b/source/tests/pt/model/test_descriptor.py @@ -14,7 +14,7 @@ ) from deepmd.pt.model.descriptor import ( - prod_env_mat_se_a, + prod_env_mat, ) from deepmd.pt.utils import ( dp_random, @@ -155,7 +155,7 @@ def test_consistency(self): mixed_types=False, box=self.pt_batch["box"].to(env.DEVICE), ) - my_d, _, _ = prod_env_mat_se_a( + my_d, _, _ = prod_env_mat( extended_coord, nlist, atype, diff --git a/source/tests/pt/model/test_descriptor_se_r.py b/source/tests/pt/model/test_descriptor_se_r.py new file mode 100644 index 0000000000..c999f06863 --- /dev/null +++ b/source/tests/pt/model/test_descriptor_se_r.py @@ -0,0 +1,134 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import itertools +import unittest + +import numpy as np +import torch + +from deepmd.dpmodel.descriptor import DescrptSeR as DPDescrptSeR +from deepmd.pt.model.descriptor.se_r import ( + DescrptSeR, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + PRECISION_DICT, +) + +from .test_env_mat import ( + TestCaseSingleFrameWithNlist, +) +from .test_mlp import ( + get_tols, +) + +dtype = env.GLOBAL_PT_FLOAT_PRECISION + + +# to be merged with the tf test case +class TestDescrptSeR(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self): + TestCaseSingleFrameWithNlist.setUp(self) + + def test_consistency( + self, + ): + rng = np.random.default_rng() + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 1)) + dstd = rng.normal(size=(self.nt, nnei, 1)) + dstd = 0.1 + np.abs(dstd) + + for idt, prec, em in itertools.product( + [False, True], + ["float64", "float32"], + [[], [[0, 1]], [[1, 1]]], + ): + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + err_msg = f"idt={idt} prec={prec}" + # sea new impl + dd0 = DescrptSeR( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + resnet_dt=idt, + old_impl=False, + exclude_mask=em, + ).to(env.DEVICE) + dd0.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=env.DEVICE) + + rd0, _, _, _, _ = dd0( + torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), + torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), + torch.tensor(self.nlist, dtype=int, device=env.DEVICE), + ) + # serialization + dd1 = DescrptSeR.deserialize(dd0.serialize()) + rd1, _, _, _, sw1 = dd1( + torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), + torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), + torch.tensor(self.nlist, dtype=int, device=env.DEVICE), + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd1.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy()[0][self.perm[: self.nloc]], + rd0.detach().cpu().numpy()[1], + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + # dp impl + dd2 = DPDescrptSeR.deserialize(dd0.serialize()) + rd2, _, _, _, sw2 = dd2.call( + self.coord_ext, + self.atype_ext, + self.nlist, + ) + for aa, bb in zip([rd1, sw1], [rd2, sw2]): + np.testing.assert_allclose( + aa.detach().cpu().numpy(), + bb, + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + + def test_jit( + self, + ): + rng = np.random.default_rng() + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + for idt, prec in itertools.product( + [False, True], + ["float64", "float32"], + ): + dtype = PRECISION_DICT[prec] + + # sea new impl + dd0 = DescrptSeR( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + resnet_dt=idt, + old_impl=False, + ) + dd0.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=env.DEVICE) + dd1 = DescrptSeR.deserialize(dd0.serialize()) + torch.jit.script(dd0) + torch.jit.script(dd1) diff --git a/source/tests/pt/model/test_env_mat.py b/source/tests/pt/model/test_env_mat.py index ee262e7ee5..fee3fd6fea 100644 --- a/source/tests/pt/model/test_env_mat.py +++ b/source/tests/pt/model/test_env_mat.py @@ -8,7 +8,7 @@ EnvMat, ) from deepmd.pt.model.descriptor.env_mat import ( - prod_env_mat_se_a, + prod_env_mat, ) from deepmd.pt.utils import ( env, @@ -99,7 +99,7 @@ def test_consistency( dstd = 0.1 + np.abs(dstd) em0 = EnvMat(self.rcut, self.rcut_smth) mm0, ww0 = em0.call(self.coord_ext, self.atype_ext, self.nlist, davg, dstd) - mm1, _, ww1 = prod_env_mat_se_a( + mm1, _, ww1 = prod_env_mat( torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), torch.tensor(self.nlist, dtype=int, device=env.DEVICE), torch.tensor(self.atype_ext[:, :nloc], dtype=int, device=env.DEVICE),