diff --git a/deepmd/dpmodel/descriptor/dpa2.py b/deepmd/dpmodel/descriptor/dpa2.py index 1dbb14961e..6db27e9fc9 100644 --- a/deepmd/dpmodel/descriptor/dpa2.py +++ b/deepmd/dpmodel/descriptor/dpa2.py @@ -4,11 +4,15 @@ Union, ) +import array_api_compat import numpy as np from deepmd.dpmodel import ( NativeOP, ) +from deepmd.dpmodel.common import ( + to_numpy_array, +) from deepmd.dpmodel.utils import ( EnvMat, NetworkCollection, @@ -787,6 +791,7 @@ def call( The smooth switch function. shape: nf x nloc x nnei """ + xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist) use_three_body = self.use_three_body nframes, nloc, nnei = nlist.shape nall = coord_ext.reshape(nframes, -1).shape[1] // 3 @@ -823,7 +828,7 @@ def call( g1_ext, mapping, ) - g1 = np.concatenate([g1, g1_three_body], axis=-1) + g1 = xp.concatenate([g1, g1_three_body], axis=-1) # linear to change shape g1 = self.g1_shape_tranform(g1) if self.add_tebd_to_repinit_out: @@ -831,8 +836,8 @@ def call( g1 = g1 + self.tebd_transform(g1_inp) # mapping g1 assert mapping is not None - mapping_ext = np.tile(mapping.reshape(nframes, nall, 1), (1, 1, g1.shape[-1])) - g1_ext = np.take_along_axis(g1, mapping_ext, axis=1) + mapping_ext = xp.tile(mapping.reshape(nframes, nall, 1), (1, 1, g1.shape[-1])) + g1_ext = xp.take_along_axis(g1, mapping_ext, axis=1) # repformer g1, g2, h2, rot_mat, sw = self.repformers( nlist_dict[ @@ -846,7 +851,7 @@ def call( mapping, ) if self.concat_output_tebd: - g1 = np.concatenate([g1, g1_inp], axis=-1) + g1 = xp.concatenate([g1, g1_inp], axis=-1) return g1, rot_mat, g2, h2, sw def serialize(self) -> dict: @@ -883,8 +888,8 @@ def serialize(self) -> dict: "embeddings": repinit.embeddings.serialize(), "env_mat": EnvMat(repinit.rcut, repinit.rcut_smth).serialize(), "@variables": { - "davg": repinit["davg"], - "dstd": repinit["dstd"], + "davg": to_numpy_array(repinit["davg"]), + "dstd": to_numpy_array(repinit["dstd"]), }, } if repinit.tebd_input_mode in ["strip"]: @@ -896,8 +901,8 @@ def serialize(self) -> dict: "repformer_layers": [layer.serialize() for layer in repformers.layers], "env_mat": EnvMat(repformers.rcut, repformers.rcut_smth).serialize(), "@variables": { - "davg": repformers["davg"], - "dstd": repformers["dstd"], + "davg": to_numpy_array(repformers["davg"]), + "dstd": to_numpy_array(repformers["dstd"]), }, } data.update( @@ -913,8 +918,8 @@ def serialize(self) -> dict: repinit_three_body.rcut, repinit_three_body.rcut_smth ).serialize(), "@variables": { - "davg": repinit_three_body["davg"], - "dstd": repinit_three_body["dstd"], + "davg": to_numpy_array(repinit_three_body["davg"]), + "dstd": to_numpy_array(repinit_three_body["dstd"]), }, } if repinit_three_body.tebd_input_mode in ["strip"]: diff --git a/deepmd/dpmodel/descriptor/repformers.py b/deepmd/dpmodel/descriptor/repformers.py index ef79ecdd28..254b186daf 100644 --- a/deepmd/dpmodel/descriptor/repformers.py +++ b/deepmd/dpmodel/descriptor/repformers.py @@ -5,12 +5,16 @@ Union, ) +import array_api_compat import numpy as np from deepmd.dpmodel import ( PRECISION_DICT, NativeOP, ) +from deepmd.dpmodel.common import ( + to_numpy_array, +) from deepmd.dpmodel.utils import ( EnvMat, PairExcludeMask, @@ -360,8 +364,9 @@ def call( atype_embd_ext: Optional[np.ndarray] = None, mapping: Optional[np.ndarray] = None, ): + xp = array_api_compat.array_namespace(nlist, coord_ext, atype_ext) exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) - nlist = np.where(exclude_mask, nlist, -1) + nlist = xp.where(exclude_mask, nlist, -1) # nf x nloc x nnei x 4 dmatrix, diff, sw = self.env_mat.call( coord_ext, atype_ext, nlist, self.mean, self.stddev @@ -371,7 +376,7 @@ def call( nlist_mask = nlist != -1 # nf x nloc x nnei sw = sw.reshape(nf, nloc, nnei) - sw = np.where(nlist_mask, sw, 0.0) + sw = xp.where(nlist_mask, sw, 0.0) # nf x nloc x tebd_dim atype_embd = atype_embd_ext[:, :nloc, :] assert list(atype_embd.shape) == [nf, nloc, self.g1_dim] @@ -379,22 +384,22 @@ def call( g1 = self.act(atype_embd) # nf x nloc x nnei x 1, nf x nloc x nnei x 3 if not self.direct_dist: - g2, h2 = np.split(dmatrix, [1], axis=-1) + g2, h2 = xp.split(dmatrix, [1], axis=-1) else: - g2, h2 = np.linalg.norm(diff, axis=-1, keepdims=True), diff + g2, h2 = xp.linalg.norm(diff, axis=-1, keepdims=True), diff g2 = g2 / self.rcut h2 = h2 / self.rcut # nf x nloc x nnei x ng2 g2 = self.act(self.g2_embd(g2)) # set all padding positions to index of 0 # if a neighbor is real or not is indicated by nlist_mask - nlist[nlist == -1] = 0 + nlist = xp.where(nlist == -1, xp.zeros_like(nlist), nlist) # nf x nall x ng1 - mapping = np.tile(mapping.reshape(nf, -1, 1), (1, 1, self.g1_dim)) + mapping = xp.tile(mapping.reshape(nf, -1, 1), (1, 1, self.g1_dim)) for idx, ll in enumerate(self.layers): # g1: nf x nloc x ng1 # g1_ext: nf x nall x ng1 - g1_ext = np.take_along_axis(g1, mapping, axis=1) + g1_ext = xp.take_along_axis(g1, mapping, axis=1) g1, g2, h2 = ll.call( g1_ext, g2, @@ -415,7 +420,7 @@ def call( use_sqrt_nnei=self.use_sqrt_nnei, ) # (nf x nloc) x ng2 x 3 - rot_mat = np.transpose(h2g2, (0, 1, 3, 2)) + rot_mat = xp.transpose(h2g2, (0, 1, 3, 2)) return g1, g2, h2, rot_mat.reshape(nf, nloc, self.dim_emb, 3), sw def has_message_passing(self) -> bool: @@ -426,6 +431,72 @@ def need_sorted_nlist_for_lower(self) -> bool: """Returns whether the descriptor block needs sorted nlist when using `forward_lower`.""" return False + @classmethod + def deserialize(cls, data): + """Deserialize the descriptor block.""" + data = data.copy() + g2_embd = NativeLayer.deserialize(data.pop("g2_embd")) + layers = [RepformerLayer.deserialize(dd) for dd in data.pop("repformer_layers")] + env_mat = EnvMat.deserialize(data.pop("env_mat")) + variables = data.pop("@variables") + davg = variables["davg"] + dstd = variables["dstd"] + obj = cls(**data) + obj.g2_embd = g2_embd + obj.layers = layers + obj.env_mat = env_mat + obj.mean = davg + obj.stddev = dstd + return obj + + def serialize(self): + """Serialize the descriptor block.""" + return { + "rcut": self.rcut, + "rcut_smth": self.rcut_smth, + "sel": self.sel, + "ntypes": self.ntypes, + "nlayers": self.nlayers, + "g1_dim": self.g1_dim, + "g2_dim": self.g2_dim, + "axis_neuron": self.axis_neuron, + "direct_dist": self.direct_dist, + "update_g1_has_conv": self.update_g1_has_conv, + "update_g1_has_drrd": self.update_g1_has_drrd, + "update_g1_has_grrg": self.update_g1_has_grrg, + "update_g1_has_attn": self.update_g1_has_attn, + "update_g2_has_g1g1": self.update_g2_has_g1g1, + "update_g2_has_attn": self.update_g2_has_attn, + "update_h2": self.update_h2, + "attn1_hidden": self.attn1_hidden, + "attn1_nhead": self.attn1_nhead, + "attn2_hidden": self.attn2_hidden, + "attn2_nhead": self.attn2_nhead, + "attn2_has_gate": self.attn2_has_gate, + "activation_function": self.activation_function, + "update_style": self.update_style, + "update_residual": self.update_residual, + "update_residual_init": self.update_residual_init, + "set_davg_zero": self.set_davg_zero, + "smooth": self.smooth, + "exclude_types": self.exclude_types, + "env_protection": self.env_protection, + "precision": self.precision, + "trainable_ln": self.trainable_ln, + "use_sqrt_nnei": self.use_sqrt_nnei, + "g1_out_conv": self.g1_out_conv, + "g1_out_mlp": self.g1_out_mlp, + "ln_eps": self.ln_eps, + # variables + "g2_embd": self.g2_embd.serialize(), + "repformer_layers": [layer.serialize() for layer in self.layers], + "env_mat": EnvMat(self.rcut, self.rcut_smth).serialize(), + "@variables": { + "davg": to_numpy_array(self["davg"]), + "dstd": to_numpy_array(self["dstd"]), + }, + } + # translated by GPT and modified def get_residual( @@ -487,14 +558,15 @@ def _make_nei_g1( gg1: np.ndarray Neighbor-wise atomic invariant rep, with shape [nf, nloc, nnei, ng1]. """ + xp = array_api_compat.array_namespace(g1_ext, nlist) # nlist: nf x nloc x nnei nf, nloc, nnei = nlist.shape # g1_ext: nf x nall x ng1 ng1 = g1_ext.shape[-1] # index: nf x (nloc x nnei) x ng1 - index = np.tile(nlist.reshape(nf, nloc * nnei, 1), (1, 1, ng1)) + index = xp.tile(nlist.reshape(nf, nloc * nnei, 1), (1, 1, ng1)) # gg1 : nf x (nloc x nnei) x ng1 - gg1 = np.take_along_axis(g1_ext, index, axis=1) + gg1 = xp.take_along_axis(g1_ext, index, axis=1) # gg1 : nf x nloc x nnei x ng1 gg1 = gg1.reshape(nf, nloc, nnei, ng1) return gg1 @@ -514,7 +586,8 @@ def _apply_nlist_mask( nlist_mask Neighbor list mask, where zero means no neighbor, with shape [nf, nloc, nnei]. """ - masked_gg = np.where(nlist_mask[:, :, :, None], gg, 0.0) + xp = array_api_compat.array_namespace(gg, nlist_mask) + masked_gg = xp.where(nlist_mask[:, :, :, None], gg, 0.0) return masked_gg @@ -570,6 +643,7 @@ def _cal_hg( hg The transposed rotation matrix, with shape [nf, nloc, 3, ng]. """ + xp = array_api_compat.array_namespace(g, h, nlist_mask, sw) # g: nf x nloc x nnei x ng # h: nf x nloc x nnei x 3 # msk: nf x nloc x nnei @@ -580,21 +654,21 @@ def _cal_hg( if not smooth: # nf x nloc if not use_sqrt_nnei: - invnnei = 1.0 / (epsilon + np.sum(nlist_mask, axis=-1)) + invnnei = 1.0 / (epsilon + xp.sum(nlist_mask, axis=-1)) else: - invnnei = 1.0 / (epsilon + np.sqrt(np.sum(nlist_mask, axis=-1))) + invnnei = 1.0 / (epsilon + xp.sqrt(xp.sum(nlist_mask, axis=-1))) # nf x nloc x 1 x 1 - invnnei = invnnei[:, :, np.newaxis, np.newaxis] + invnnei = invnnei[:, :, xp.newaxis, xp.newaxis] else: g = _apply_switch(g, sw) if not use_sqrt_nnei: - invnnei = (1.0 / float(nnei)) * np.ones((nf, nloc, 1, 1), dtype=g.dtype) + invnnei = (1.0 / float(nnei)) * xp.ones((nf, nloc, 1, 1), dtype=g.dtype) else: - invnnei = (1.0 / (float(nnei) ** 0.5)) * np.ones( + invnnei = (1.0 / (float(nnei) ** 0.5)) * xp.ones( (nf, nloc, 1, 1), dtype=g.dtype ) # nf x nloc x 3 x ng - hg = np.matmul(np.transpose(h, axes=(0, 1, 3, 2)), g) * invnnei + hg = xp.matmul(xp.transpose(h, axes=(0, 1, 3, 2)), g) * invnnei return hg @@ -614,12 +688,13 @@ def _cal_grrg(hg: np.ndarray, axis_neuron: int) -> np.ndarray: grrg Atomic invariant rep, with shape [nf, nloc, (axis_neuron * ng)]. """ + xp = array_api_compat.array_namespace(hg) # nf x nloc x 3 x ng nf, nloc, _, ng = hg.shape # nf x nloc x 3 x axis - hgm = np.split(hg, [axis_neuron], axis=-1)[0] + hgm = xp.split(hg, [axis_neuron], axis=-1)[0] # nf x nloc x axis_neuron x ng - grrg = np.matmul(np.transpose(hgm, axes=(0, 1, 3, 2)), hg) / (3.0**1) + grrg = xp.matmul(xp.transpose(hgm, axes=(0, 1, 3, 2)), hg) / (3.0**1) # nf x nloc x (axis_neuron * ng) grrg = grrg.reshape(nf, nloc, axis_neuron * ng) return grrg @@ -718,6 +793,7 @@ def call( nlist_mask: np.ndarray, # nf x nloc x nnei sw: np.ndarray, # nf x nloc x nnei ) -> np.ndarray: + xp = array_api_compat.array_namespace(g2, h2, nlist_mask, sw) ( nf, nloc, @@ -728,41 +804,41 @@ def call( # nf x nloc x nnei x nd x (nh x 2) g2qk = self.mapqk(g2).reshape(nf, nloc, nnei, nd, nh * 2) # nf x nloc x (nh x 2) x nnei x nd - g2qk = np.transpose(g2qk, (0, 1, 4, 2, 3)) + g2qk = xp.transpose(g2qk, (0, 1, 4, 2, 3)) # nf x nloc x nh x nnei x nd - g2q, g2k = np.split(g2qk, [nh], axis=2) + g2q, g2k = xp.split(g2qk, [nh], axis=2) # g2q = np.linalg.norm(g2q, axis=-1) # g2k = np.linalg.norm(g2k, axis=-1) # nf x nloc x nh x nnei x nnei - attnw = np.matmul(g2q, np.transpose(g2k, axes=(0, 1, 2, 4, 3))) / nd**0.5 + attnw = xp.matmul(g2q, xp.transpose(g2k, axes=(0, 1, 2, 4, 3))) / nd**0.5 if self.has_gate: - gate = np.matmul(h2, np.transpose(h2, axes=(0, 1, 3, 2))).reshape( + gate = xp.matmul(h2, xp.transpose(h2, axes=(0, 1, 3, 2))).reshape( nf, nloc, 1, nnei, nnei ) attnw = attnw * gate # mask the attenmap, nf x nloc x 1 x 1 x nnei - attnw_mask = ~np.expand_dims(np.expand_dims(nlist_mask, axis=2), axis=2) + attnw_mask = ~xp.expand_dims(xp.expand_dims(nlist_mask, axis=2), axis=2) # mask the attenmap, nf x nloc x 1 x nnei x 1 - attnw_mask_c = ~np.expand_dims(np.expand_dims(nlist_mask, axis=2), axis=-1) + attnw_mask_c = ~xp.expand_dims(xp.expand_dims(nlist_mask, axis=2), axis=-1) if self.smooth: attnw = (attnw + self.attnw_shift) * sw[:, :, None, :, None] * sw[ :, :, None, None, : ] - self.attnw_shift else: - attnw = np.where(attnw_mask, -np.inf, attnw) + attnw = xp.where(attnw_mask, -xp.inf, attnw) attnw = np_softmax(attnw, axis=-1) - attnw = np.where(attnw_mask, 0.0, attnw) + attnw = xp.where(attnw_mask, 0.0, attnw) # nf x nloc x nh x nnei x nnei - attnw = np.where(attnw_mask_c, 0.0, attnw) + attnw = xp.where(attnw_mask_c, 0.0, attnw) if self.smooth: attnw = attnw * sw[:, :, None, :, None] * sw[:, :, None, None, :] # nf x nloc x nnei x nnei - h2h2t = np.matmul(h2, np.transpose(h2, axes=(0, 1, 3, 2))) / 3.0**0.5 + h2h2t = xp.matmul(h2, xp.transpose(h2, axes=(0, 1, 3, 2))) / 3.0**0.5 # nf x nloc x nh x nnei x nnei ret = attnw * h2h2t[:, :, None, :, :] # ret = np.exp(g2qk - np.max(g2qk, axis=-1, keepdims=True)) # nf x nloc x nnei x nnei x nh - ret = np.transpose(ret, (0, 1, 3, 4, 2)) + ret = xp.transpose(ret, (0, 1, 3, 4, 2)) return ret def serialize(self) -> dict: @@ -835,19 +911,20 @@ def call( AA: np.ndarray, # nf x nloc x nnei x nnei x nh g2: np.ndarray, # nf x nloc x nnei x ng2 ) -> np.ndarray: + xp = array_api_compat.array_namespace(AA, g2) nf, nloc, nnei, ng2 = g2.shape nh = self.head_num # nf x nloc x nnei x ng2 x nh g2v = self.mapv(g2).reshape(nf, nloc, nnei, ng2, nh) # nf x nloc x nh x nnei x ng2 - g2v = np.transpose(g2v, (0, 1, 4, 2, 3)) + g2v = xp.transpose(g2v, (0, 1, 4, 2, 3)) # g2v = np.linalg.norm(g2v, axis=-1) # nf x nloc x nh x nnei x nnei - AA = np.transpose(AA, (0, 1, 4, 2, 3)) + AA = xp.transpose(AA, (0, 1, 4, 2, 3)) # nf x nloc x nh x nnei x ng2 - ret = np.matmul(AA, g2v) + ret = xp.matmul(AA, g2v) # nf x nloc x nnei x ng2 x nh - ret = np.transpose(ret, (0, 1, 3, 4, 2)).reshape(nf, nloc, nnei, (ng2 * nh)) + ret = xp.transpose(ret, (0, 1, 3, 4, 2)).reshape(nf, nloc, nnei, (ng2 * nh)) # nf x nloc x nnei x ng2 return self.head_map(ret) @@ -910,19 +987,20 @@ def call( AA: np.ndarray, # nf x nloc x nnei x nnei x nh h2: np.ndarray, # nf x nloc x nnei x 3 ) -> np.ndarray: + xp = array_api_compat.array_namespace(AA, h2) nf, nloc, nnei, _ = h2.shape nh = self.head_num # nf x nloc x nh x nnei x nnei - AA = np.transpose(AA, (0, 1, 4, 2, 3)) - h2m = np.expand_dims(h2, axis=2) + AA = xp.transpose(AA, (0, 1, 4, 2, 3)) + h2m = xp.expand_dims(h2, axis=2) # nf x nloc x nh x nnei x 3 - h2m = np.tile(h2m, (1, 1, nh, 1, 1)) + h2m = xp.tile(h2m, (1, 1, nh, 1, 1)) # nf x nloc x nh x nnei x 3 - ret = np.matmul(AA, h2m) + ret = xp.matmul(AA, h2m) # nf x nloc x nnei x 3 x nh - ret = np.transpose(ret, (0, 1, 3, 4, 2)).reshape(nf, nloc, nnei, 3, nh) + ret = xp.transpose(ret, (0, 1, 3, 4, 2)).reshape(nf, nloc, nnei, 3, nh) # nf x nloc x nnei x 3 - return np.squeeze(self.head_map(ret), axis=-1) + return xp.squeeze(self.head_map(ret), axis=-1) def serialize(self) -> dict: """Serialize the networks to a dict. @@ -1005,6 +1083,7 @@ def call( nlist_mask: np.ndarray, # nf x nloc x nnei sw: np.ndarray, # nf x nloc x nnei ) -> np.ndarray: + xp = array_api_compat.array_namespace(g1, gg1, nlist_mask, sw) nf, nloc, nnei = nlist_mask.shape ni, nd, nh = self.input_dim, self.hidden_dim, self.head_num assert ni == g1.shape[-1] @@ -1012,39 +1091,39 @@ def call( # nf x nloc x nd x nh g1q = self.mapq(g1).reshape(nf, nloc, nd, nh) # nf x nloc x nh x nd - g1q = np.transpose(g1q, (0, 1, 3, 2)) + g1q = xp.transpose(g1q, (0, 1, 3, 2)) # nf x nloc x nnei x (nd+ni) x nh gg1kv = self.mapkv(gg1).reshape(nf, nloc, nnei, nd + ni, nh) - gg1kv = np.transpose(gg1kv, (0, 1, 4, 2, 3)) + gg1kv = xp.transpose(gg1kv, (0, 1, 4, 2, 3)) # nf x nloc x nh x nnei x nd, nf x nloc x nh x nnei x ng1 - gg1k, gg1v = np.split(gg1kv, [nd], axis=-1) + gg1k, gg1v = xp.split(gg1kv, [nd], axis=-1) # nf x nloc x nh x 1 x nnei attnw = ( - np.matmul( - np.expand_dims(g1q, axis=-2), np.transpose(gg1k, axes=(0, 1, 2, 4, 3)) + xp.matmul( + xp.expand_dims(g1q, axis=-2), xp.transpose(gg1k, axes=(0, 1, 2, 4, 3)) ) / nd**0.5 ) # nf x nloc x nh x nnei - attnw = np.squeeze(attnw, axis=-2) + attnw = xp.squeeze(attnw, axis=-2) # mask the attenmap, nf x nloc x 1 x nnei - attnw_mask = ~np.expand_dims(nlist_mask, axis=-2) + attnw_mask = ~xp.expand_dims(nlist_mask, axis=-2) # nf x nloc x nh x nnei if self.smooth: - attnw = (attnw + self.attnw_shift) * np.expand_dims( + attnw = (attnw + self.attnw_shift) * xp.expand_dims( sw, axis=-2 ) - self.attnw_shift else: - attnw = np.where(attnw_mask, -np.inf, attnw) + attnw = xp.where(attnw_mask, -xp.inf, attnw) attnw = np_softmax(attnw, axis=-1) - attnw = np.where(attnw_mask, 0.0, attnw) + attnw = xp.where(attnw_mask, 0.0, attnw) if self.smooth: - attnw = attnw * np.expand_dims(sw, axis=-2) + attnw = attnw * xp.expand_dims(sw, axis=-2) # nf x nloc x nh x ng1 ret = ( - np.matmul(np.expand_dims(attnw, axis=-2), gg1v) + xp.matmul(xp.expand_dims(attnw, axis=-2), gg1v) .squeeze(-2) .reshape(nf, nloc, nh * ni) ) @@ -1178,12 +1257,12 @@ def __init__( ], "'update_residual_init' only support 'norm' or 'const'!" self.update_residual = update_residual self.update_residual_init = update_residual_init - self.g1_residual = [] - self.g2_residual = [] - self.h2_residual = [] + g1_residual = [] + g2_residual = [] + h2_residual = [] if self.update_style == "res_residual": - self.g1_residual.append( + g1_residual.append( get_residual( g1_dim, self.update_residual, @@ -1217,7 +1296,7 @@ def __init__( seed=child_seed(seed, 2), ) if self.update_style == "res_residual": - self.g2_residual.append( + g2_residual.append( get_residual( g2_dim, self.update_residual, @@ -1234,7 +1313,7 @@ def __init__( seed=child_seed(seed, 15), ) if self.update_style == "res_residual": - self.g1_residual.append( + g1_residual.append( get_residual( g1_dim, self.update_residual, @@ -1263,7 +1342,7 @@ def __init__( seed=child_seed(seed, 4), ) if self.update_style == "res_residual": - self.g1_residual.append( + g1_residual.append( get_residual( g1_dim, self.update_residual, @@ -1281,7 +1360,7 @@ def __init__( seed=child_seed(seed, 5), ) if self.update_style == "res_residual": - self.g2_residual.append( + g2_residual.append( get_residual( g2_dim, self.update_residual, @@ -1312,7 +1391,7 @@ def __init__( seed=child_seed(seed, 9), ) if self.update_style == "res_residual": - self.g2_residual.append( + g2_residual.append( get_residual( g2_dim, self.update_residual, @@ -1327,7 +1406,7 @@ def __init__( g2_dim, attn2_nhead, precision=precision, seed=child_seed(seed, 11) ) if self.update_style == "res_residual": - self.h2_residual.append( + h2_residual.append( get_residual( 1, self.update_residual, @@ -1346,7 +1425,7 @@ def __init__( seed=child_seed(seed, 13), ) if self.update_style == "res_residual": - self.g1_residual.append( + g1_residual.append( get_residual( g1_dim, self.update_residual, @@ -1356,6 +1435,10 @@ def __init__( ) ) + self.g1_residual = g1_residual + self.g2_residual = g2_residual + self.h2_residual = h2_residual + def cal_1_dim(self, g1d: int, g2d: int, ax: int) -> int: ret = g1d if not self.g1_out_mlp else 0 if self.update_g1_has_grrg: @@ -1408,6 +1491,7 @@ def _update_g1_conv( The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, and remains 0 beyond rcut, with shape nf x nloc x nnei. """ + xp = array_api_compat.array_namespace(gg1, g2, nlist_mask, sw) assert self.proj_g1g2 is not None nf, nloc, nnei, _ = g2.shape ng1 = gg1.shape[-1] @@ -1423,20 +1507,20 @@ def _update_g1_conv( if not self.smooth: # normalized by number of neighbors, not smooth # nf x nloc - invnnei = 1.0 / (self.epsilon + np.sum(nlist_mask, axis=-1)) + invnnei = 1.0 / (self.epsilon + xp.sum(nlist_mask, axis=-1)) # nf x nloc x 1 - invnnei = invnnei[:, :, np.newaxis] + invnnei = invnnei[:, :, xp.newaxis] else: gg1 = _apply_switch(gg1, sw) - invnnei = (1.0 / float(nnei)) * np.ones((nf, nloc, 1), dtype=gg1.dtype) + invnnei = (1.0 / float(nnei)) * xp.ones((nf, nloc, 1), dtype=gg1.dtype) if not self.g1_out_conv: # nf x nloc x ng2 - g1_11 = np.sum(g2 * gg1, axis=2) * invnnei + g1_11 = xp.sum(g2 * gg1, axis=2) * invnnei else: # nf x nloc x ng1 g2 = self.proj_g1g2(g2).reshape(nf, nloc, nnei, ng1) # nb x nloc x ng1 - g1_11 = np.sum(g2 * gg1, axis=2) * invnnei + g1_11 = xp.sum(g2 * gg1, axis=2) * invnnei return g1_11 def _update_g2_g1g1( @@ -1461,7 +1545,8 @@ def _update_g2_g1g1( The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, and remains 0 beyond rcut, with shape nf x nloc x nnei. """ - ret = np.expand_dims(g1, axis=-2) * gg1 + xp = array_api_compat.array_namespace(g1, gg1, nlist_mask, sw) + ret = xp.expand_dims(g1, axis=-2) * gg1 # nf x nloc x nnei x ng1 ret = _apply_nlist_mask(ret, nlist_mask) if self.smooth: @@ -1493,6 +1578,7 @@ def call( g2: nf x nloc x nnei x ng2 updated pair-atom channel, invariant h2: nf x nloc x nnei x 3 updated pair-atom channel, equivariant """ + xp = array_api_compat.array_namespace(g1_ext, g2, h2, nlist, nlist_mask, sw) cal_gg1 = ( self.update_g1_has_drrd or self.update_g1_has_conv @@ -1502,14 +1588,14 @@ def call( nf, nloc, nnei, _ = g2.shape nall = g1_ext.shape[1] - g1, _ = np.split(g1_ext, [nloc], axis=1) + g1, _ = xp.split(g1_ext, [nloc], axis=1) assert (nf, nloc) == g1.shape[:2] assert (nf, nloc, nnei) == h2.shape[:3] - g2_update: list[np.ndarray] = [g2] - h2_update: list[np.ndarray] = [h2] - g1_update: list[np.ndarray] = [g1] - g1_mlp: list[np.ndarray] = [g1] if not self.g1_out_mlp else [] + g2_update: list[xp.ndarray] = [g2] + h2_update: list[xp.ndarray] = [h2] + g1_update: list[xp.ndarray] = [g1] + g1_mlp: list[xp.ndarray] = [g1] if not self.g1_out_mlp else [] if self.g1_out_mlp: assert self.g1_self_mlp is not None g1_self_mlp = self.act(self.g1_self_mlp(g1)) @@ -1592,7 +1678,7 @@ def call( # nf x nloc x [ng1+ng2+(axisxng2)+(axisxng1)] # conv grrg drrd - g1_1 = self.act(self.linear1(np.concatenate(g1_mlp, axis=-1))) + g1_1 = self.act(self.linear1(xp.concatenate(g1_mlp, axis=-1))) g1_update.append(g1_1) if self.update_g1_has_attn: @@ -1752,9 +1838,9 @@ def serialize(self) -> dict: if self.update_style == "res_residual": data.update( { - "g1_residual": self.g1_residual, - "g2_residual": self.g2_residual, - "h2_residual": self.h2_residual, + "g1_residual": [to_numpy_array(aa) for aa in self.g1_residual], + "g2_residual": [to_numpy_array(aa) for aa in self.g2_residual], + "h2_residual": [to_numpy_array(aa) for aa in self.h2_residual], } ) return data diff --git a/deepmd/dpmodel/descriptor/se_t_tebd.py b/deepmd/dpmodel/descriptor/se_t_tebd.py index ca89c23968..298f823690 100644 --- a/deepmd/dpmodel/descriptor/se_t_tebd.py +++ b/deepmd/dpmodel/descriptor/se_t_tebd.py @@ -5,12 +5,20 @@ Union, ) +import array_api_compat import numpy as np from deepmd.dpmodel import ( PRECISION_DICT, NativeOP, ) +from deepmd.dpmodel.array_api import ( + xp_take_along_axis, +) +from deepmd.dpmodel.common import ( + get_xp_precision, + to_numpy_array, +) from deepmd.dpmodel.utils import ( EmbeddingNet, EnvMat, @@ -26,9 +34,6 @@ from deepmd.dpmodel.utils.update_sel import ( UpdateSel, ) -from deepmd.env import ( - GLOBAL_NP_FLOAT_PRECISION, -) from deepmd.utils.data_system import ( DeepmdDataSystem, ) @@ -318,11 +323,15 @@ def call( sw The smooth switch function. """ + xp = array_api_compat.array_namespace(nlist, coord_ext, atype_ext) del mapping nf, nloc, nnei = nlist.shape - nall = coord_ext.reshape(nf, -1).shape[1] // 3 + nall = xp.reshape(coord_ext, (nf, -1)).shape[1] // 3 # nf x nall x tebd_dim - atype_embd_ext = self.type_embedding.call()[atype_ext] + atype_embd_ext = xp.reshape( + xp.take(self.type_embedding.call(), xp.reshape(atype_ext, [-1]), axis=0), + (nf, nall, self.tebd_dim), + ) # nfnl x tebd_dim atype_embd = atype_embd_ext[:, :nloc, :] grrg, g2, h2, rot_mat, sw = self.se_ttebd( @@ -334,8 +343,8 @@ def call( ) # nf x nloc x (ng + tebd_dim) if self.concat_output_tebd: - grrg = np.concatenate( - [grrg, atype_embd.reshape(nf, nloc, self.tebd_dim)], axis=-1 + grrg = xp.concat( + [grrg, xp.reshape(atype_embd, (nf, nloc, self.tebd_dim))], axis=-1 ) return grrg, rot_mat, None, None, sw @@ -368,8 +377,8 @@ def serialize(self) -> dict: "env_protection": obj.env_protection, "smooth": self.smooth, "@variables": { - "davg": obj["davg"], - "dstd": obj["dstd"], + "davg": to_numpy_array(obj["davg"]), + "dstd": to_numpy_array(obj["dstd"]), }, "trainable": self.trainable, } @@ -491,12 +500,12 @@ def __init__( else: self.embd_input_dim = 1 - self.embeddings = NetworkCollection( + embeddings = NetworkCollection( ndim=0, ntypes=self.ntypes, network_type="embedding_network", ) - self.embeddings[0] = EmbeddingNet( + embeddings[0] = EmbeddingNet( self.embd_input_dim, self.neuron, self.activation_function, @@ -504,13 +513,14 @@ def __init__( self.precision, seed=child_seed(seed, 0), ) + self.embeddings = embeddings if self.tebd_input_mode in ["strip"]: - self.embeddings_strip = NetworkCollection( + embeddings_strip = NetworkCollection( ndim=0, ntypes=self.ntypes, network_type="embedding_network", ) - self.embeddings_strip[0] = EmbeddingNet( + embeddings_strip[0] = EmbeddingNet( self.tebd_dim_input, self.neuron, self.activation_function, @@ -518,6 +528,7 @@ def __init__( self.precision, seed=child_seed(seed, 1), ) + self.embeddings_strip = embeddings_strip else: self.embeddings_strip = None @@ -652,6 +663,7 @@ def call( atype_embd_ext: Optional[np.ndarray] = None, mapping: Optional[np.ndarray] = None, ): + xp = array_api_compat.array_namespace(nlist, coord_ext, atype_ext) # nf x nloc x nnei x 4 dmatrix, diff, sw = self.env_mat.call( coord_ext, atype_ext, nlist, self.mean, self.stddev @@ -659,47 +671,49 @@ def call( nf, nloc, nnei, _ = dmatrix.shape exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) # nfnl x nnei - exclude_mask = exclude_mask.reshape(nf * nloc, nnei) + exclude_mask = xp.reshape(exclude_mask, (nf * nloc, nnei)) # nfnl x nnei - nlist = nlist.reshape(nf * nloc, nnei) - nlist = np.where(exclude_mask, nlist, -1) + nlist = xp.reshape(nlist, (nf * nloc, nnei)) + nlist = xp.where(exclude_mask, nlist, xp.full_like(nlist, -1)) # nfnl x nnei nlist_mask = nlist != -1 # nfnl x nnei x 1 - sw = np.where(nlist_mask[:, :, None], sw.reshape(nf * nloc, nnei, 1), 0.0) + sw = xp.where( + nlist_mask[:, :, None], + xp.reshape(sw, (nf * nloc, nnei, 1)), + xp.zeros((nf * nloc, nnei, 1), dtype=sw.dtype), + ) # nfnl x nnei x 4 - dmatrix = dmatrix.reshape(nf * nloc, nnei, 4) + dmatrix = xp.reshape(dmatrix, (nf * nloc, nnei, 4)) # nfnl x nnei x 4 rr = dmatrix - rr = rr * exclude_mask[:, :, None] + rr = rr * xp.astype(exclude_mask[:, :, None], rr.dtype) # nfnl x nt_i x 3 rr_i = rr[:, :, 1:] # nfnl x nt_j x 3 rr_j = rr[:, :, 1:] # nfnl x nt_i x nt_j - env_ij = np.einsum("ijm,ikm->ijk", rr_i, rr_j) + # env_ij = np.einsum("ijm,ikm->ijk", rr_i, rr_j) + env_ij = xp.sum(rr_i[:, :, None, :] * rr_j[:, None, :, :], axis=-1) # nfnl x nt_i x nt_j x 1 - ss = np.expand_dims(env_ij, axis=-1) + ss = env_ij[..., None] - nlist_masked = np.where(nlist_mask, nlist, 0) - index = np.tile(nlist_masked.reshape(nf, -1, 1), (1, 1, self.tebd_dim)) + nlist_masked = xp.where(nlist_mask, nlist, xp.zeros_like(nlist)) + index = xp.tile(xp.reshape(nlist_masked, (nf, -1, 1)), (1, 1, self.tebd_dim)) # nfnl x nnei x tebd_dim - atype_embd_nlist = np.take_along_axis(atype_embd_ext, index, axis=1).reshape( - nf * nloc, nnei, self.tebd_dim + atype_embd_nlist = xp_take_along_axis(atype_embd_ext, index, axis=1) + atype_embd_nlist = xp.reshape( + atype_embd_nlist, (nf * nloc, nnei, self.tebd_dim) ) # nfnl x nt_i x nt_j x tebd_dim - nlist_tebd_i = np.tile( - np.expand_dims(atype_embd_nlist, axis=2), [1, 1, self.nnei, 1] - ) - nlist_tebd_j = np.tile( - np.expand_dims(atype_embd_nlist, axis=1), [1, self.nnei, 1, 1] - ) + nlist_tebd_i = xp.tile(atype_embd_nlist[:, :, None, :], (1, 1, self.nnei, 1)) + nlist_tebd_j = xp.tile(atype_embd_nlist[:, None, :, :], (1, self.nnei, 1, 1)) ng = self.neuron[-1] if self.tebd_input_mode in ["concat"]: # nfnl x nt_i x nt_j x (1 + tebd_dim * 2) - ss = np.concatenate([ss, nlist_tebd_i, nlist_tebd_j], axis=-1) + ss = xp.concat([ss, nlist_tebd_i, nlist_tebd_j], axis=-1) # nfnl x nt_i x nt_j x ng gg = self.cal_g(ss, 0) elif self.tebd_input_mode in ["strip"]: @@ -707,14 +721,14 @@ def call( gg_s = self.cal_g(ss, 0) assert self.embeddings_strip is not None # nfnl x nt_i x nt_j x (tebd_dim * 2) - tt = np.concatenate([nlist_tebd_i, nlist_tebd_j], axis=-1) + tt = xp.concat([nlist_tebd_i, nlist_tebd_j], axis=-1) # nfnl x nt_i x nt_j x ng gg_t = self.cal_g_strip(tt, 0) if self.smooth: gg_t = ( gg_t - * sw.reshape(nf * nloc, self.nnei, 1, 1) - * sw.reshape(nf * nloc, 1, self.nnei, 1) + * xp.reshape(sw, (nf * nloc, self.nnei, 1, 1)) + * xp.reshape(sw, (nf * nloc, 1, self.nnei, 1)) ) # nfnl x nt_i x nt_j x ng gg = gg_s * gg_t + gg_s @@ -722,12 +736,12 @@ def call( raise NotImplementedError # nfnl x ng - res_ij = np.einsum("ijk,ijkm->im", env_ij, gg) + # res_ij = np.einsum("ijk,ijkm->im", env_ij, gg) + res_ij = xp.sum(env_ij[:, :, :, None] * gg[:, :, :, :], axis=(1, 2)) res_ij = res_ij * (1.0 / float(self.nnei) / float(self.nnei)) # nf x nl x ng - result = res_ij.reshape(nf, nloc, self.filter_neuron[-1]).astype( - GLOBAL_NP_FLOAT_PRECISION - ) + result = xp.reshape(res_ij, (nf, nloc, self.filter_neuron[-1])) + result = xp.astype(result, get_xp_precision(xp, "global")) return ( result, None, @@ -743,3 +757,61 @@ def has_message_passing(self) -> bool: def need_sorted_nlist_for_lower(self) -> bool: """Returns whether the descriptor block needs sorted nlist when using `forward_lower`.""" return False + + def serialize(self) -> dict: + """Serialize the descriptor to dict.""" + obj = self + data = { + "@class": "Descriptor", + "type": "se_e3_tebd", + "@version": 1, + "rcut": obj.rcut, + "rcut_smth": obj.rcut_smth, + "sel": obj.sel, + "ntypes": obj.ntypes, + "neuron": obj.neuron, + "tebd_dim": obj.tebd_dim, + "tebd_input_mode": obj.tebd_input_mode, + "set_davg_zero": obj.set_davg_zero, + "activation_function": obj.activation_function, + "resnet_dt": obj.resnet_dt, + # make deterministic + "precision": np.dtype(PRECISION_DICT[obj.precision]).name, + "embeddings": obj.embeddings.serialize(), + "env_mat": obj.env_mat.serialize(), + "exclude_types": obj.exclude_types, + "env_protection": obj.env_protection, + "smooth": obj.smooth, + "@variables": { + "davg": to_numpy_array(obj["davg"]), + "dstd": to_numpy_array(obj["dstd"]), + }, + } + if obj.tebd_input_mode in ["strip"]: + data.update({"embeddings_strip": obj.embeddings_strip.serialize()}) + return data + + @classmethod + def deserialize(cls, data: dict) -> "DescrptSeTTebd": + """Deserialize from dict.""" + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + data.pop("type") + variables = data.pop("@variables") + embeddings = data.pop("embeddings") + env_mat = data.pop("env_mat") + tebd_input_mode = data["tebd_input_mode"] + if tebd_input_mode in ["strip"]: + embeddings_strip = data.pop("embeddings_strip") + else: + embeddings_strip = None + se_ttebd = cls(**data) + + se_ttebd["davg"] = variables["davg"] + se_ttebd["dstd"] = variables["dstd"] + se_ttebd.embeddings = NetworkCollection.deserialize(embeddings) + if tebd_input_mode in ["strip"]: + se_ttebd.embeddings_strip = NetworkCollection.deserialize(embeddings_strip) + + return se_ttebd diff --git a/deepmd/dpmodel/utils/nlist.py b/deepmd/dpmodel/utils/nlist.py index b827032588..68c9dd7a97 100644 --- a/deepmd/dpmodel/utils/nlist.py +++ b/deepmd/dpmodel/utils/nlist.py @@ -215,30 +215,30 @@ def build_multiple_neighbor_list( value being the corresponding nlist. """ + xp = array_api_compat.array_namespace(coord, nlist) assert len(rcuts) == len(nsels) if len(rcuts) == 0: return {} nb, nloc, nsel = nlist.shape if nsel < nsels[-1]: - pad = -1 * np.ones((nb, nloc, nsels[-1] - nsel), dtype=nlist.dtype) - nlist = np.concatenate([nlist, pad], axis=-1) + pad = -1 * xp.ones((nb, nloc, nsels[-1] - nsel), dtype=nlist.dtype) + nlist = xp.concat([nlist, pad], axis=-1) nsel = nsels[-1] coord1 = coord.reshape(nb, -1, 3) nall = coord1.shape[1] coord0 = coord1[:, :nloc, :] nlist_mask = nlist == -1 - tnlist_0 = nlist.copy() - tnlist_0[nlist_mask] = 0 - index = np.tile(tnlist_0.reshape(nb, nloc * nsel, 1), [1, 1, 3]) - coord2 = np.take_along_axis(coord1, index, axis=1).reshape(nb, nloc, nsel, 3) + tnlist_0 = xp.where(nlist_mask, xp.zeros_like(nlist), nlist) + index = xp.tile(tnlist_0.reshape(nb, nloc * nsel, 1), [1, 1, 3]) + coord2 = xp.take_along_axis(coord1, index, axis=1).reshape(nb, nloc, nsel, 3) diff = coord2 - coord0[:, :, None, :] - rr = np.linalg.norm(diff, axis=-1) - rr = np.where(nlist_mask, float("inf"), rr) + rr = xp.linalg.norm(diff, axis=-1) + rr = xp.where(nlist_mask, float("inf"), rr) nlist0 = nlist ret = {} for rc, ns in zip(rcuts[::-1], nsels[::-1]): - tnlist_1 = np.copy(nlist0[:, :, :ns]) - tnlist_1[rr[:, :, :ns] > rc] = -1 + tnlist_1 = nlist0[:, :, :ns] + tnlist_1 = xp.where(rr[:, :, :ns] > rc, xp.full_like(tnlist_1, -1), tnlist_1) ret[get_multiple_nlist_key(rc, ns)] = tnlist_1 return ret diff --git a/deepmd/jax/descriptor/dpa2.py b/deepmd/jax/descriptor/dpa2.py new file mode 100644 index 0000000000..0e49689e94 --- /dev/null +++ b/deepmd/jax/descriptor/dpa2.py @@ -0,0 +1,58 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.descriptor.dpa2 import DescrptDPA2 as DescrptDPA2DP +from deepmd.dpmodel.utils.network import Identity as IdentityDP +from deepmd.dpmodel.utils.network import NativeLayer as NativeLayerDP +from deepmd.jax.common import ( + ArrayAPIVariable, + to_jax_array, +) +from deepmd.jax.descriptor.base_descriptor import ( + BaseDescriptor, +) +from deepmd.jax.descriptor.dpa1 import ( + DescrptBlockSeAtten, +) +from deepmd.jax.descriptor.repformers import ( + DescrptBlockRepformers, +) +from deepmd.jax.descriptor.se_t_tebd import ( + DescrptBlockSeTTebd, +) +from deepmd.jax.utils.network import ( + NativeLayer, +) +from deepmd.jax.utils.type_embed import ( + TypeEmbedNet, +) + + +@BaseDescriptor.register("dpa2") +class DescrptDPA2(DescrptDPA2DP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"mean", "stddev"}: + value = to_jax_array(value) + if value is not None: + value = ArrayAPIVariable(value) + elif name in {"repinit"}: + value = DescrptBlockSeAtten.deserialize(value.serialize()) + elif name in {"repinit_three_body"}: + if value is not None: + value = DescrptBlockSeTTebd.deserialize(value.serialize()) + elif name in {"repformers"}: + value = DescrptBlockRepformers.deserialize(value.serialize()) + elif name in {"type_embedding"}: + value = TypeEmbedNet.deserialize(value.serialize()) + elif name in {"g1_shape_tranform", "tebd_transform"}: + if value is None: + pass + elif isinstance(value, NativeLayerDP): + value = NativeLayer.deserialize(value.serialize()) + elif isinstance(value, IdentityDP): + value = IdentityDP.deserialize(value.serialize()) + else: + raise ValueError(f"Unknown layer type: {type(value)}") + return super().__setattr__(name, value) diff --git a/deepmd/jax/descriptor/repformers.py b/deepmd/jax/descriptor/repformers.py new file mode 100644 index 0000000000..1d3a7fbb29 --- /dev/null +++ b/deepmd/jax/descriptor/repformers.py @@ -0,0 +1,100 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.descriptor.repformers import ( + Atten2EquiVarApply as Atten2EquiVarApplyDP, +) +from deepmd.dpmodel.descriptor.repformers import Atten2Map as Atten2MapDP +from deepmd.dpmodel.descriptor.repformers import ( + Atten2MultiHeadApply as Atten2MultiHeadApplyDP, +) +from deepmd.dpmodel.descriptor.repformers import ( + DescrptBlockRepformers as DescrptBlockRepformersDP, +) +from deepmd.dpmodel.descriptor.repformers import LocalAtten as LocalAttenDP +from deepmd.dpmodel.descriptor.repformers import RepformerLayer as RepformerLayerDP +from deepmd.jax.common import ( + ArrayAPIVariable, + to_jax_array, +) +from deepmd.jax.utils.exclude_mask import ( + PairExcludeMask, +) +from deepmd.jax.utils.network import ( + LayerNorm, + NativeLayer, +) + + +class DescrptBlockRepformers(DescrptBlockRepformersDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"mean", "stddev"}: + value = to_jax_array(value) + if value is not None: + value = ArrayAPIVariable(value) + elif name in {"layers"}: + value = [RepformerLayer.deserialize(layer.serialize()) for layer in value] + elif name == "g2_embd": + value = NativeLayer.deserialize(value.serialize()) + elif name == "env_mat": + # env_mat doesn't store any value + pass + elif name == "emask": + value = PairExcludeMask(value.ntypes, value.exclude_types) + + return super().__setattr__(name, value) + + +class Atten2Map(Atten2MapDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"mapqk"}: + value = NativeLayer.deserialize(value.serialize()) + return super().__setattr__(name, value) + + +class Atten2MultiHeadApply(Atten2MultiHeadApplyDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"mapv", "head_map"}: + value = NativeLayer.deserialize(value.serialize()) + return super().__setattr__(name, value) + + +class Atten2EquiVarApply(Atten2EquiVarApplyDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"head_map"}: + value = NativeLayer.deserialize(value.serialize()) + return super().__setattr__(name, value) + + +class LocalAtten(LocalAttenDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"mapq", "mapkv", "head_map"}: + value = NativeLayer.deserialize(value.serialize()) + return super().__setattr__(name, value) + + +class RepformerLayer(RepformerLayerDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"linear1", "linear2", "g1_self_mlp", "proj_g1g2", "proj_g1g1g2"}: + if value is not None: + value = NativeLayer.deserialize(value.serialize()) + elif name in {"g1_residual", "g2_residual", "h2_residual"}: + value = [ArrayAPIVariable(to_jax_array(vv)) for vv in value] + elif name in {"attn2g_map"}: + if value is not None: + value = Atten2Map.deserialize(value.serialize()) + elif name in {"attn2_mh_apply"}: + if value is not None: + value = Atten2MultiHeadApply.deserialize(value.serialize()) + elif name in {"attn2_lm"}: + if value is not None: + value = LayerNorm.deserialize(value.serialize()) + elif name in {"attn2_ev_apply"}: + if value is not None: + value = Atten2EquiVarApply.deserialize(value.serialize()) + elif name in {"loc_attn"}: + if value is not None: + value = LocalAtten.deserialize(value.serialize()) + return super().__setattr__(name, value) diff --git a/deepmd/jax/descriptor/se_t_tebd.py b/deepmd/jax/descriptor/se_t_tebd.py new file mode 100644 index 0000000000..84e3d3f084 --- /dev/null +++ b/deepmd/jax/descriptor/se_t_tebd.py @@ -0,0 +1,56 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.descriptor.se_t_tebd import ( + DescrptBlockSeTTebd as DescrptBlockSeTTebdDP, +) +from deepmd.dpmodel.descriptor.se_t_tebd import DescrptSeTTebd as DescrptSeTTebdDP +from deepmd.jax.common import ( + ArrayAPIVariable, + flax_module, + to_jax_array, +) +from deepmd.jax.descriptor.base_descriptor import ( + BaseDescriptor, +) +from deepmd.jax.utils.exclude_mask import ( + PairExcludeMask, +) +from deepmd.jax.utils.network import ( + NetworkCollection, +) +from deepmd.jax.utils.type_embed import ( + TypeEmbedNet, +) + + +@flax_module +class DescrptBlockSeTTebd(DescrptBlockSeTTebdDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"mean", "stddev"}: + value = to_jax_array(value) + if value is not None: + value = ArrayAPIVariable(value) + elif name in {"embeddings", "embeddings_strip"}: + if value is not None: + value = NetworkCollection.deserialize(value.serialize()) + elif name == "env_mat": + # env_mat doesn't store any value + pass + elif name == "emask": + value = PairExcludeMask(value.ntypes, value.exclude_types) + + return super().__setattr__(name, value) + + +@BaseDescriptor.register("se_e3_tebd") +@flax_module +class DescrptSeTTebd(DescrptSeTTebdDP): + def __setattr__(self, name: str, value: Any) -> None: + if name == "se_ttebd": + value = DescrptBlockSeTTebd.deserialize(value.serialize()) + elif name == "type_embedding": + value = TypeEmbedNet.deserialize(value.serialize()) + return super().__setattr__(name, value) diff --git a/source/tests/consistent/descriptor/test_dpa2.py b/source/tests/consistent/descriptor/test_dpa2.py index 53f9ce4200..ff46b8296c 100644 --- a/source/tests/consistent/descriptor/test_dpa2.py +++ b/source/tests/consistent/descriptor/test_dpa2.py @@ -15,6 +15,7 @@ ) from ..common import ( + INSTALLED_JAX, INSTALLED_PT, CommonTest, parameterized, @@ -28,6 +29,11 @@ else: DescrptDPA2PT = None +if INSTALLED_JAX: + from deepmd.jax.descriptor.dpa2 import DescrptDPA2 as DescrptDPA2JAX +else: + DescrptDPA2JAX = None + # not implemented DescrptDPA2TF = None @@ -269,9 +275,12 @@ def skip_tf(self) -> bool: ) = self.param return True + skip_jax = not INSTALLED_JAX + tf_class = DescrptDPA2TF dp_class = DescrptDPA2DP pt_class = DescrptDPA2PT + jax_class = DescrptDPA2JAX args = descrpt_dpa2_args().append(Argument("ntypes", int, optional=False)) def setUp(self): @@ -367,6 +376,16 @@ def eval_pt(self, pt_obj: Any) -> Any: mixed_types=True, ) + def eval_jax(self, jax_obj: Any) -> Any: + return self.eval_jax_descriptor( + jax_obj, + self.natoms, + self.coords, + self.atype, + self.box, + mixed_types=True, + ) + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: return (ret[0],)