diff --git a/deepmd/dpmodel/descriptor/dpa2.py b/deepmd/dpmodel/descriptor/dpa2.py index 6db27e9fc9..200747c0ef 100644 --- a/deepmd/dpmodel/descriptor/dpa2.py +++ b/deepmd/dpmodel/descriptor/dpa2.py @@ -10,6 +10,9 @@ from deepmd.dpmodel import ( NativeOP, ) +from deepmd.dpmodel.array_api import ( + xp_take_along_axis, +) from deepmd.dpmodel.common import ( to_numpy_array, ) @@ -794,7 +797,7 @@ def call( xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist) use_three_body = self.use_three_body nframes, nloc, nnei = nlist.shape - nall = coord_ext.reshape(nframes, -1).shape[1] // 3 + nall = xp.reshape(coord_ext, (nframes, -1)).shape[1] // 3 # nlists nlist_dict = build_multiple_neighbor_list( coord_ext, @@ -803,7 +806,10 @@ def call( self.nsel_list, ) # repinit - g1_ext = self.type_embedding.call()[atype_ext] + g1_ext = xp.reshape( + xp.take(self.type_embedding.call(), xp.reshape(atype_ext, [-1]), axis=0), + (nframes, nall, self.tebd_dim), + ) g1_inp = g1_ext[:, :nloc, :] g1, _, _, _, _ = self.repinit( nlist_dict[ @@ -828,7 +834,7 @@ def call( g1_ext, mapping, ) - g1 = xp.concatenate([g1, g1_three_body], axis=-1) + g1 = xp.concat([g1, g1_three_body], axis=-1) # linear to change shape g1 = self.g1_shape_tranform(g1) if self.add_tebd_to_repinit_out: @@ -836,8 +842,10 @@ def call( g1 = g1 + self.tebd_transform(g1_inp) # mapping g1 assert mapping is not None - mapping_ext = xp.tile(mapping.reshape(nframes, nall, 1), (1, 1, g1.shape[-1])) - g1_ext = xp.take_along_axis(g1, mapping_ext, axis=1) + mapping_ext = xp.tile( + xp.reshape(mapping, (nframes, nall, 1)), (1, 1, g1.shape[-1]) + ) + g1_ext = xp_take_along_axis(g1, mapping_ext, axis=1) # repformer g1, g2, h2, rot_mat, sw = self.repformers( nlist_dict[ @@ -851,7 +859,7 @@ def call( mapping, ) if self.concat_output_tebd: - g1 = xp.concatenate([g1, g1_inp], axis=-1) + g1 = xp.concat([g1, g1_inp], axis=-1) return g1, rot_mat, g2, h2, sw def serialize(self) -> dict: diff --git 
def xp_transpose_01423(x):
    """Reorder the axes of a 5-D array from (0, 1, 2, 3, 4) to (0, 1, 4, 2, 3).

    The permutation is expressed with array-API primitives only: axes 2 and 3
    are merged by a reshape, the two trailing axes are swapped with
    ``matrix_transpose``, and the merged axis is split again.

    Parameters
    ----------
    x
        Array with five axes, belonging to any namespace that
        ``array_api_compat.array_namespace`` can resolve.

    Returns
    -------
    Array of shape ``(s0, s1, s4, s2, s3)`` where ``s0..s4`` is the input shape.
    """
    xp = array_api_compat.array_namespace(x)
    n0, n1, n2, n3, n4 = (x.shape[axis] for axis in range(5))
    # (n0, n1, n2, n3, n4) -> (n0, n1, n2 * n3, n4)
    merged = xp.reshape(x, (n0, n1, n2 * n3, n4))
    # swap the last two axes -> (n0, n1, n4, n2 * n3)
    swapped = xp.matrix_transpose(merged)
    # split the merged axis again -> (n0, n1, n4, n2, n3)
    return xp.reshape(swapped, (n0, n1, n4, n2, n3))


def xp_transpose_01342(x):
    """Reorder the axes of a 5-D array from (0, 1, 2, 3, 4) to (0, 1, 3, 4, 2).

    The permutation is expressed with array-API primitives only: axes 3 and 4
    are merged by a reshape, the two trailing axes are swapped with
    ``matrix_transpose``, and the merged axis is split again.

    Parameters
    ----------
    x
        Array with five axes, belonging to any namespace that
        ``array_api_compat.array_namespace`` can resolve.

    Returns
    -------
    Array of shape ``(s0, s1, s3, s4, s2)`` where ``s0..s4`` is the input shape.
    """
    xp = array_api_compat.array_namespace(x)
    n0, n1, n2, n3, n4 = (x.shape[axis] for axis in range(5))
    # (n0, n1, n2, n3, n4) -> (n0, n1, n2, n3 * n4)
    merged = xp.reshape(x, (n0, n1, n2, n3 * n4))
    # swap the last two axes -> (n0, n1, n3 * n4, n2)
    swapped = xp.matrix_transpose(merged)
    # split the merged axis again -> (n0, n1, n3, n4, n2)
    return xp.reshape(swapped, (n0, n1, n3, n4, n2))
def call( if not self.direct_dist: g2, h2 = xp.split(dmatrix, [1], axis=-1) else: - g2, h2 = xp.linalg.norm(diff, axis=-1, keepdims=True), diff + g2, h2 = xp.linalg.vector_norm(diff, axis=-1, keepdims=True), diff g2 = g2 / self.rcut h2 = h2 / self.rcut # nf x nloc x nnei x ng2 @@ -395,11 +420,11 @@ def call( # if a neighbor is real or not is indicated by nlist_mask nlist = xp.where(nlist == -1, xp.zeros_like(nlist), nlist) # nf x nall x ng1 - mapping = xp.tile(mapping.reshape(nf, -1, 1), (1, 1, self.g1_dim)) + mapping = xp.tile(xp.reshape(mapping, (nf, -1, 1)), (1, 1, self.g1_dim)) for idx, ll in enumerate(self.layers): # g1: nf x nloc x ng1 # g1_ext: nf x nall x ng1 - g1_ext = xp.take_along_axis(g1, mapping, axis=1) + g1_ext = xp_take_along_axis(g1, mapping, axis=1) g1, g2, h2 = ll.call( g1_ext, g2, @@ -420,8 +445,9 @@ def call( use_sqrt_nnei=self.use_sqrt_nnei, ) # (nf x nloc) x ng2 x 3 - rot_mat = xp.transpose(h2g2, (0, 1, 3, 2)) - return g1, g2, h2, rot_mat.reshape(nf, nloc, self.dim_emb, 3), sw + # rot_mat = xp.transpose(h2g2, (0, 1, 3, 2)) + rot_mat = xp.matrix_transpose(h2g2) + return g1, g2, h2, xp.reshape(rot_mat, (nf, nloc, self.dim_emb, 3)), sw def has_message_passing(self) -> bool: """Returns whether the descriptor block has message passing.""" @@ -564,11 +590,11 @@ def _make_nei_g1( # g1_ext: nf x nall x ng1 ng1 = g1_ext.shape[-1] # index: nf x (nloc x nnei) x ng1 - index = xp.tile(nlist.reshape(nf, nloc * nnei, 1), (1, 1, ng1)) + index = xp.tile(xp.reshape(nlist, (nf, nloc * nnei, 1)), (1, 1, ng1)) # gg1 : nf x (nloc x nnei) x ng1 - gg1 = xp.take_along_axis(g1_ext, index, axis=1) + gg1 = xp_take_along_axis(g1_ext, index, axis=1) # gg1 : nf x nloc x nnei x ng1 - gg1 = gg1.reshape(nf, nloc, nnei, ng1) + gg1 = xp.reshape(gg1, (nf, nloc, nnei, ng1)) return gg1 @@ -587,7 +613,7 @@ def _apply_nlist_mask( Neighbor list mask, where zero means no neighbor, with shape [nf, nloc, nnei]. 
""" xp = array_api_compat.array_namespace(gg, nlist_mask) - masked_gg = xp.where(nlist_mask[:, :, :, None], gg, 0.0) + masked_gg = xp.where(nlist_mask[:, :, :, None], gg, xp.zeros_like(gg)) return masked_gg @@ -654,9 +680,11 @@ def _cal_hg( if not smooth: # nf x nloc if not use_sqrt_nnei: - invnnei = 1.0 / (epsilon + xp.sum(nlist_mask, axis=-1)) + invnnei = 1.0 / (epsilon + xp.sum(xp.astype(nlist_mask, g.dtype), axis=-1)) else: - invnnei = 1.0 / (epsilon + xp.sqrt(xp.sum(nlist_mask, axis=-1))) + invnnei = 1.0 / ( + epsilon + xp.sqrt(xp.sum(xp.astype(nlist_mask, g.dtype), axis=-1)) + ) # nf x nloc x 1 x 1 invnnei = invnnei[:, :, xp.newaxis, xp.newaxis] else: @@ -668,7 +696,7 @@ def _cal_hg( (nf, nloc, 1, 1), dtype=g.dtype ) # nf x nloc x 3 x ng - hg = xp.matmul(xp.transpose(h, axes=(0, 1, 3, 2)), g) * invnnei + hg = xp.matmul(xp.matrix_transpose(h), g) * invnnei return hg @@ -692,11 +720,11 @@ def _cal_grrg(hg: np.ndarray, axis_neuron: int) -> np.ndarray: # nf x nloc x 3 x ng nf, nloc, _, ng = hg.shape # nf x nloc x 3 x axis - hgm = xp.split(hg, [axis_neuron], axis=-1)[0] + hgm = hg[..., :axis_neuron] # nf x nloc x axis_neuron x ng - grrg = xp.matmul(xp.transpose(hgm, axes=(0, 1, 3, 2)), hg) / (3.0**1) + grrg = xp.matmul(xp.matrix_transpose(hgm), hg) / (3.0**1) # nf x nloc x (axis_neuron * ng) - grrg = grrg.reshape(nf, nloc, axis_neuron * ng) + grrg = xp.reshape(grrg, (nf, nloc, axis_neuron * ng)) return grrg @@ -802,19 +830,22 @@ def call( ) = g2.shape nd, nh = self.hidden_dim, self.head_num # nf x nloc x nnei x nd x (nh x 2) - g2qk = self.mapqk(g2).reshape(nf, nloc, nnei, nd, nh * 2) + g2qk = self.mapqk(g2) + g2qk = xp.reshape(g2qk, (nf, nloc, nnei, nd, nh * 2)) # nf x nloc x (nh x 2) x nnei x nd - g2qk = xp.transpose(g2qk, (0, 1, 4, 2, 3)) + # g2qk = xp.transpose(g2qk, (0, 1, 4, 2, 3)) + g2qk = xp_transpose_01423(g2qk) # nf x nloc x nh x nnei x nd - g2q, g2k = xp.split(g2qk, [nh], axis=2) + # g2q, g2k = xp.split(g2qk, [nh], axis=2) + g2q = g2qk[:, :, :nh, :, :] + 
g2k = g2qk[:, :, nh:, :, :] # g2q = np.linalg.norm(g2q, axis=-1) # g2k = np.linalg.norm(g2k, axis=-1) # nf x nloc x nh x nnei x nnei - attnw = xp.matmul(g2q, xp.transpose(g2k, axes=(0, 1, 2, 4, 3))) / nd**0.5 + attnw = xp.matmul(g2q, xp.matrix_transpose(g2k)) / nd**0.5 if self.has_gate: - gate = xp.matmul(h2, xp.transpose(h2, axes=(0, 1, 3, 2))).reshape( - nf, nloc, 1, nnei, nnei - ) + gate = xp.matmul(h2, xp.matrix_transpose(h2)) + gate = xp.reshape(gate, (nf, nloc, 1, nnei, nnei)) attnw = attnw * gate # mask the attenmap, nf x nloc x 1 x 1 x nnei attnw_mask = ~xp.expand_dims(xp.expand_dims(nlist_mask, axis=2), axis=2) @@ -825,20 +856,21 @@ def call( :, :, None, None, : ] - self.attnw_shift else: - attnw = xp.where(attnw_mask, -xp.inf, attnw) + attnw = xp.where(attnw_mask, xp.full_like(attnw, -xp.inf), attnw) attnw = np_softmax(attnw, axis=-1) - attnw = xp.where(attnw_mask, 0.0, attnw) + attnw = xp.where(attnw_mask, xp.zeros_like(attnw), attnw) # nf x nloc x nh x nnei x nnei - attnw = xp.where(attnw_mask_c, 0.0, attnw) + attnw = xp.where(attnw_mask_c, xp.zeros_like(attnw), attnw) if self.smooth: attnw = attnw * sw[:, :, None, :, None] * sw[:, :, None, None, :] # nf x nloc x nnei x nnei - h2h2t = xp.matmul(h2, xp.transpose(h2, axes=(0, 1, 3, 2))) / 3.0**0.5 + h2h2t = xp.matmul(h2, xp.matrix_transpose(h2)) / 3.0**0.5 # nf x nloc x nh x nnei x nnei ret = attnw * h2h2t[:, :, None, :, :] # ret = np.exp(g2qk - np.max(g2qk, axis=-1, keepdims=True)) # nf x nloc x nnei x nnei x nh - ret = xp.transpose(ret, (0, 1, 3, 4, 2)) + # ret = xp.transpose(ret, (0, 1, 3, 4, 2)) + ret = xp_transpose_01342(ret) return ret def serialize(self) -> dict: @@ -915,16 +947,18 @@ def call( nf, nloc, nnei, ng2 = g2.shape nh = self.head_num # nf x nloc x nnei x ng2 x nh - g2v = self.mapv(g2).reshape(nf, nloc, nnei, ng2, nh) + g2v = self.mapv(g2) + g2v = xp.reshape(g2v, (nf, nloc, nnei, ng2, nh)) # nf x nloc x nh x nnei x ng2 - g2v = xp.transpose(g2v, (0, 1, 4, 2, 3)) + g2v = 
xp_transpose_01423(g2v) # g2v = np.linalg.norm(g2v, axis=-1) # nf x nloc x nh x nnei x nnei - AA = xp.transpose(AA, (0, 1, 4, 2, 3)) + AA = xp_transpose_01423(AA) # nf x nloc x nh x nnei x ng2 ret = xp.matmul(AA, g2v) # nf x nloc x nnei x ng2 x nh - ret = xp.transpose(ret, (0, 1, 3, 4, 2)).reshape(nf, nloc, nnei, (ng2 * nh)) + ret = xp_transpose_01342(ret) + ret = xp.reshape(ret, (nf, nloc, nnei, (ng2 * nh))) # nf x nloc x nnei x ng2 return self.head_map(ret) @@ -991,14 +1025,15 @@ def call( nf, nloc, nnei, _ = h2.shape nh = self.head_num # nf x nloc x nh x nnei x nnei - AA = xp.transpose(AA, (0, 1, 4, 2, 3)) + AA = xp_transpose_01423(AA) h2m = xp.expand_dims(h2, axis=2) # nf x nloc x nh x nnei x 3 h2m = xp.tile(h2m, (1, 1, nh, 1, 1)) # nf x nloc x nh x nnei x 3 ret = xp.matmul(AA, h2m) # nf x nloc x nnei x 3 x nh - ret = xp.transpose(ret, (0, 1, 3, 4, 2)).reshape(nf, nloc, nnei, 3, nh) + ret = xp_transpose_01342(ret) + ret = xp.reshape(ret, (nf, nloc, nnei, 3, nh)) # nf x nloc x nnei x 3 return xp.squeeze(self.head_map(ret), axis=-1) @@ -1089,21 +1124,22 @@ def call( assert ni == g1.shape[-1] assert ni == gg1.shape[-1] # nf x nloc x nd x nh - g1q = self.mapq(g1).reshape(nf, nloc, nd, nh) + g1q = self.mapq(g1) + g1q = xp.reshape(g1q, (nf, nloc, nd, nh)) # nf x nloc x nh x nd - g1q = xp.transpose(g1q, (0, 1, 3, 2)) + g1q = xp.matrix_transpose(g1q) # nf x nloc x nnei x (nd+ni) x nh - gg1kv = self.mapkv(gg1).reshape(nf, nloc, nnei, nd + ni, nh) - gg1kv = xp.transpose(gg1kv, (0, 1, 4, 2, 3)) + gg1kv = self.mapkv(gg1) + gg1kv = xp.reshape(gg1kv, (nf, nloc, nnei, nd + ni, nh)) + gg1kv = xp_transpose_01423(gg1kv) # nf x nloc x nh x nnei x nd, nf x nloc x nh x nnei x ng1 - gg1k, gg1v = xp.split(gg1kv, [nd], axis=-1) + # gg1k, gg1v = xp.split(gg1kv, [nd], axis=-1) + gg1k = gg1kv[:, :, :, :, :nd] + gg1v = gg1kv[:, :, :, :, nd:] # nf x nloc x nh x 1 x nnei attnw = ( - xp.matmul( - xp.expand_dims(g1q, axis=-2), xp.transpose(gg1k, axes=(0, 1, 2, 4, 3)) - ) - / nd**0.5 + 
xp.matmul(xp.expand_dims(g1q, axis=-2), xp.matrix_transpose(gg1k)) / nd**0.5 ) # nf x nloc x nh x nnei attnw = xp.squeeze(attnw, axis=-2) @@ -1115,18 +1151,16 @@ def call( sw, axis=-2 ) - self.attnw_shift else: - attnw = xp.where(attnw_mask, -xp.inf, attnw) + attnw = xp.where(attnw_mask, xp.full_like(attnw, -xp.inf), attnw) attnw = np_softmax(attnw, axis=-1) - attnw = xp.where(attnw_mask, 0.0, attnw) + attnw = xp.where(attnw_mask, xp.zeros_like(attnw), attnw) if self.smooth: attnw = attnw * xp.expand_dims(sw, axis=-2) # nf x nloc x nh x ng1 - ret = ( - xp.matmul(xp.expand_dims(attnw, axis=-2), gg1v) - .squeeze(-2) - .reshape(nf, nloc, nh * ni) - ) + ret = xp.matmul(xp.expand_dims(attnw, axis=-2), gg1v) + ret = xp.squeeze(ret, axis=-2) + ret = xp.reshape(ret, (nf, nloc, nh * ni)) # nf x nloc x ng1 ret = self.head_map(ret) return ret @@ -1498,16 +1532,19 @@ def _update_g1_conv( ng2 = g2.shape[-1] if not self.g1_out_conv: # gg1 : nf x nloc x nnei x ng2 - gg1 = self.proj_g1g2(gg1).reshape(nf, nloc, nnei, ng2) + gg1 = self.proj_g1g2(gg1) + gg1 = xp.reshape(gg1, (nf, nloc, nnei, ng2)) else: # gg1 : nf x nloc x nnei x ng1 - gg1 = gg1.reshape(nf, nloc, nnei, ng1) + gg1 = xp.reshape(gg1, (nf, nloc, nnei, ng1)) # nf x nloc x nnei x ng2/ng1 gg1 = _apply_nlist_mask(gg1, nlist_mask) if not self.smooth: # normalized by number of neighbors, not smooth # nf x nloc - invnnei = 1.0 / (self.epsilon + xp.sum(nlist_mask, axis=-1)) + invnnei = 1.0 / ( + self.epsilon + xp.sum(xp.astype(nlist_mask, gg1.dtype), axis=-1) + ) # nf x nloc x 1 invnnei = invnnei[:, :, xp.newaxis] else: @@ -1518,7 +1555,8 @@ def _update_g1_conv( g1_11 = xp.sum(g2 * gg1, axis=2) * invnnei else: # nf x nloc x ng1 - g2 = self.proj_g1g2(g2).reshape(nf, nloc, nnei, ng1) + g2 = self.proj_g1g2(g2) + g2 = xp.reshape(g2, (nf, nloc, nnei, ng1)) # nb x nloc x ng1 g1_11 = xp.sum(g2 * gg1, axis=2) * invnnei return g1_11 @@ -1588,14 +1626,15 @@ def call( nf, nloc, nnei, _ = g2.shape nall = g1_ext.shape[1] - g1, _ = 
xp.split(g1_ext, [nloc], axis=1) + # g1, _ = xp.split(g1_ext, [nloc], axis=1) + g1 = g1_ext[:, :nloc, :] assert (nf, nloc) == g1.shape[:2] assert (nf, nloc, nnei) == h2.shape[:3] - g2_update: list[xp.ndarray] = [g2] - h2_update: list[xp.ndarray] = [h2] - g1_update: list[xp.ndarray] = [g1] - g1_mlp: list[xp.ndarray] = [g1] if not self.g1_out_mlp else [] + g2_update: list[np.ndarray] = [g2] + h2_update: list[np.ndarray] = [h2] + g1_update: list[np.ndarray] = [g1] + g1_mlp: list[np.ndarray] = [g1] if not self.g1_out_mlp else [] if self.g1_out_mlp: assert self.g1_self_mlp is not None g1_self_mlp = self.act(self.g1_self_mlp(g1)) @@ -1678,7 +1717,7 @@ def call( # nf x nloc x [ng1+ng2+(axisxng2)+(axisxng1)] # conv grrg drrd - g1_1 = self.act(self.linear1(xp.concatenate(g1_mlp, axis=-1))) + g1_1 = self.act(self.linear1(xp.concat(g1_mlp, axis=-1))) g1_update.append(g1_1) if self.update_g1_has_attn: diff --git a/deepmd/dpmodel/utils/nlist.py b/deepmd/dpmodel/utils/nlist.py index 68c9dd7a97..7b3b25df36 100644 --- a/deepmd/dpmodel/utils/nlist.py +++ b/deepmd/dpmodel/utils/nlist.py @@ -224,16 +224,17 @@ def build_multiple_neighbor_list( pad = -1 * xp.ones((nb, nloc, nsels[-1] - nsel), dtype=nlist.dtype) nlist = xp.concat([nlist, pad], axis=-1) nsel = nsels[-1] - coord1 = coord.reshape(nb, -1, 3) + coord1 = xp.reshape(coord, (nb, -1, 3)) nall = coord1.shape[1] coord0 = coord1[:, :nloc, :] nlist_mask = nlist == -1 tnlist_0 = xp.where(nlist_mask, xp.zeros_like(nlist), nlist) - index = xp.tile(tnlist_0.reshape(nb, nloc * nsel, 1), [1, 1, 3]) - coord2 = xp.take_along_axis(coord1, index, axis=1).reshape(nb, nloc, nsel, 3) + index = xp.tile(xp.reshape(tnlist_0, (nb, nloc * nsel, 1)), (1, 1, 3)) + coord2 = xp_take_along_axis(coord1, index, axis=1) + coord2 = xp.reshape(coord2, (nb, nloc, nsel, 3)) diff = coord2 - coord0[:, :, None, :] - rr = xp.linalg.norm(diff, axis=-1) - rr = xp.where(nlist_mask, float("inf"), rr) + rr = xp.linalg.vector_norm(diff, axis=-1) + rr = 
xp.where(nlist_mask, xp.full_like(rr, float("inf")), rr) nlist0 = nlist ret = {} for rc, ns in zip(rcuts[::-1], nsels[::-1]): diff --git a/deepmd/jax/descriptor/dpa2.py b/deepmd/jax/descriptor/dpa2.py index 0e49689e94..6f4fe691e1 100644 --- a/deepmd/jax/descriptor/dpa2.py +++ b/deepmd/jax/descriptor/dpa2.py @@ -8,6 +8,7 @@ from deepmd.dpmodel.utils.network import NativeLayer as NativeLayerDP from deepmd.jax.common import ( ArrayAPIVariable, + flax_module, to_jax_array, ) from deepmd.jax.descriptor.base_descriptor import ( @@ -31,6 +32,7 @@ @BaseDescriptor.register("dpa2") +@flax_module class DescrptDPA2(DescrptDPA2DP): def __setattr__(self, name: str, value: Any) -> None: if name in {"mean", "stddev"}: diff --git a/deepmd/jax/descriptor/repformers.py b/deepmd/jax/descriptor/repformers.py index 1d3a7fbb29..77ca4a9a6b 100644 --- a/deepmd/jax/descriptor/repformers.py +++ b/deepmd/jax/descriptor/repformers.py @@ -17,6 +17,7 @@ from deepmd.dpmodel.descriptor.repformers import RepformerLayer as RepformerLayerDP from deepmd.jax.common import ( ArrayAPIVariable, + flax_module, to_jax_array, ) from deepmd.jax.utils.exclude_mask import ( @@ -28,6 +29,7 @@ ) +@flax_module class DescrptBlockRepformers(DescrptBlockRepformersDP): def __setattr__(self, name: str, value: Any) -> None: if name in {"mean", "stddev"}: @@ -47,6 +49,7 @@ def __setattr__(self, name: str, value: Any) -> None: return super().__setattr__(name, value) +@flax_module class Atten2Map(Atten2MapDP): def __setattr__(self, name: str, value: Any) -> None: if name in {"mapqk"}: @@ -54,6 +57,7 @@ def __setattr__(self, name: str, value: Any) -> None: return super().__setattr__(name, value) +@flax_module class Atten2MultiHeadApply(Atten2MultiHeadApplyDP): def __setattr__(self, name: str, value: Any) -> None: if name in {"mapv", "head_map"}: @@ -61,6 +65,7 @@ def __setattr__(self, name: str, value: Any) -> None: return super().__setattr__(name, value) +@flax_module class Atten2EquiVarApply(Atten2EquiVarApplyDP): def 
@BaseDescriptor.register("dpa2")
class DescrptDPA2(DescrptDPA2DP):
    """DPA-2 descriptor for the array-api-strict test backend.

    Behaves like the reference ``DescrptDPA2DP`` except that selected
    attributes are converted on assignment: statistics become
    array-api-strict arrays and sub-modules are rebuilt from their serialized
    form as the classes imported by this module.
    """

    def __setattr__(self, name: str, value: Any) -> None:
        """Convert ``value`` according to the attribute ``name``, then store it.

        Raises
        ------
        ValueError
            If ``g1_shape_tranform`` or ``tebd_transform`` is assigned a
            non-``None`` object that is neither a ``NativeLayerDP`` nor an
            ``IdentityDP``.
        """
        if name in ("mean", "stddev"):
            value = to_array_api_strict_array(value)
        elif name == "repinit":
            value = DescrptBlockSeAtten.deserialize(value.serialize())
        elif name == "repinit_three_body" and value is not None:
            # The three-body block is optional; None is stored untouched.
            value = DescrptBlockSeTTebd.deserialize(value.serialize())
        elif name == "repformers":
            value = DescrptBlockRepformers.deserialize(value.serialize())
        elif name == "type_embedding":
            value = TypeEmbedNet.deserialize(value.serialize())
        elif name in ("g1_shape_tranform", "tebd_transform") and value is not None:
            # These attributes hold a linear layer or an identity; None is
            # stored untouched.
            if isinstance(value, NativeLayerDP):
                value = NativeLayer.deserialize(value.serialize())
            elif isinstance(value, IdentityDP):
                value = IdentityDP.deserialize(value.serialize())
            else:
                raise ValueError(f"Unknown layer type: {type(value)}")
        return super().__setattr__(name, value)
class Atten2Map(Atten2MapDP):
    """``Atten2Map`` whose ``mapqk`` layer is rebuilt as this module's ``NativeLayer``."""

    def __setattr__(self, name: str, value: Any) -> None:
        """Rebuild ``mapqk`` from its serialized form; store other attributes as given."""
        if name == "mapqk":
            value = NativeLayer.deserialize(value.serialize())
        return super().__setattr__(name, value)


class Atten2MultiHeadApply(Atten2MultiHeadApplyDP):
    """``Atten2MultiHeadApply`` whose ``mapv`` and ``head_map`` layers are rebuilt as ``NativeLayer``."""

    def __setattr__(self, name: str, value: Any) -> None:
        """Rebuild ``mapv``/``head_map`` from serialized form; store other attributes as given."""
        if name in ("mapv", "head_map"):
            value = NativeLayer.deserialize(value.serialize())
        return super().__setattr__(name, value)


class Atten2EquiVarApply(Atten2EquiVarApplyDP):
    """``Atten2EquiVarApply`` whose ``head_map`` layer is rebuilt as ``NativeLayer``."""

    def __setattr__(self, name: str, value: Any) -> None:
        """Rebuild ``head_map`` from its serialized form; store other attributes as given."""
        if name == "head_map":
            value = NativeLayer.deserialize(value.serialize())
        return super().__setattr__(name, value)


class LocalAtten(LocalAttenDP):
    """``LocalAtten`` whose ``mapq``, ``mapkv`` and ``head_map`` layers are rebuilt as ``NativeLayer``."""

    def __setattr__(self, name: str, value: Any) -> None:
        """Rebuild the three projection layers from serialized form; store other attributes as given."""
        if name in ("mapq", "mapkv", "head_map"):
            value = NativeLayer.deserialize(value.serialize())
        return super().__setattr__(name, value)
class DescrptBlockSeTTebd(DescrptBlockSeTTebdDP):
    """``se_t_tebd`` descriptor block for the array-api-strict test backend.

    Selected attributes are converted on assignment; everything else,
    including ``env_mat`` (which stores no values), is stored unchanged.
    """

    def __setattr__(self, name: str, value: Any) -> None:
        """Convert ``value`` according to the attribute ``name``, then store it."""
        if name in ("mean", "stddev"):
            value = to_array_api_strict_array(value)
        elif name in ("embeddings", "embeddings_strip") and value is not None:
            # None is stored untouched for these attributes.
            value = NetworkCollection.deserialize(value.serialize())
        elif name == "emask":
            # Re-create the exclusion mask from its defining parameters.
            value = PairExcludeMask(value.ntypes, value.exclude_types)
        return super().__setattr__(name, value)


class DescrptSeTTebd(DescrptSeTTebdDP):
    """``se_t_tebd`` descriptor whose sub-modules are rebuilt as this backend's classes."""

    def __setattr__(self, name: str, value: Any) -> None:
        """Rebuild ``se_ttebd`` and ``type_embedding`` from serialized form; store the rest as given."""
        if name == "se_ttebd":
            value = DescrptBlockSeTTebd.deserialize(value.serialize())
        elif name == "type_embedding":
            value = TypeEmbedNet.deserialize(value.serialize())
        return super().__setattr__(name, value)
def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
    """Evaluate the array-api-strict descriptor on this test case's system.

    Forwards the descriptor object together with ``self.natoms``,
    ``self.coords``, ``self.atype`` and ``self.box`` to
    ``self.eval_array_api_strict_descriptor`` with ``mixed_types=True`` and
    returns whatever that helper returns.
    """
    system_inputs = (self.natoms, self.coords, self.atype, self.box)
    return self.eval_array_api_strict_descriptor(
        array_api_strict_obj,
        *system_inputs,
        mixed_types=True,
    )