diff --git a/deepmd/dpmodel/array_api.py b/deepmd/dpmodel/array_api.py index e4af2ad627..360df78a7b 100644 --- a/deepmd/dpmodel/array_api.py +++ b/deepmd/dpmodel/array_api.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later """Utilities for the array API.""" +import array_api_compat + def support_array_api(version: str) -> callable: """Mark a function as supporting the specific version of the array API. @@ -27,3 +29,41 @@ def set_version(func: callable) -> callable: return func return set_version + + +# array api adds take_along_axis in https://github.com/data-apis/array-api/pull/816 +# but it hasn't been released yet +# below is a pure Python implementation of take_along_axis +# https://github.com/data-apis/array-api/issues/177#issuecomment-2093630595 +def xp_swapaxes(a, axis1, axis2): + xp = array_api_compat.array_namespace(a) + axes = list(range(a.ndim)) + axes[axis1], axes[axis2] = axes[axis2], axes[axis1] + a = xp.permute_dims(a, axes) + return a + + +def xp_take_along_axis(arr, indices, axis): + xp = array_api_compat.array_namespace(arr) + arr = xp_swapaxes(arr, axis, -1) + indices = xp_swapaxes(indices, axis, -1) + + m = arr.shape[-1] + n = indices.shape[-1] + + shape = list(arr.shape) + shape.pop(-1) + shape = [*shape, n] + + arr = xp.reshape(arr, (-1,)) + if n != 0: + indices = xp.reshape(indices, (-1, n)) + else: + indices = xp.reshape(indices, (0, 0)) + + offset = (xp.arange(indices.shape[0]) * m)[:, xp.newaxis] + indices = xp.reshape(offset + indices, (-1,)) + + out = xp.take(arr, indices) + out = xp.reshape(out, shape) + return xp_swapaxes(out, axis, -1) diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index 5ba3fc11b2..add9cb9f71 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -6,6 +6,7 @@ Union, ) +import array_api_compat import numpy as np from deepmd.dpmodel import ( @@ -13,6 +14,9 @@ PRECISION_DICT, NativeOP, ) +from deepmd.dpmodel.array_api import ( + xp_take_along_axis, +) from deepmd.dpmodel.utils import ( EmbeddingNet, EnvMat, @@ -32,9 +36,6 @@ from deepmd.dpmodel.utils.update_sel import ( UpdateSel, ) -from deepmd.env import ( - GLOBAL_NP_FLOAT_PRECISION, -) from deepmd.utils.data_system import ( DeepmdDataSystem, ) @@ -59,13 +60,16 @@ def np_softmax(x, axis=-1): - x = np.nan_to_num(x) # to avoid value warning - e_x = np.exp(x - np.max(x, axis=axis, keepdims=True)) - return e_x / np.sum(e_x, axis=axis, keepdims=True) + xp = array_api_compat.array_namespace(x) + # x = xp.nan_to_num(x) # to avoid value warning + x = xp.where(xp.isnan(x), xp.zeros_like(x), x) + e_x = xp.exp(x - xp.max(x, axis=axis, keepdims=True)) + return e_x / xp.sum(e_x, axis=axis, keepdims=True) def np_normalize(x, axis=-1): - return x / np.linalg.norm(x, axis=axis, keepdims=True) + xp = array_api_compat.array_namespace(x) + return x / xp.linalg.vector_norm(x, axis=axis, keepdims=True) @BaseDescriptor.register("se_atten") @@ -474,10 +478,14 @@ def call( The smooth switch function. """ del mapping + xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist) nf, nloc, nnei = nlist.shape - nall = coord_ext.reshape(nf, -1).shape[1] // 3 + nall = xp.reshape(coord_ext, (nf, -1)).shape[1] // 3 # nf x nall x tebd_dim - atype_embd_ext = self.type_embedding.call()[atype_ext] + atype_embd_ext = xp.reshape( + xp.take(self.type_embedding.call(), xp.reshape(atype_ext, [-1]), axis=0), + (nf, nall, self.tebd_dim), + ) # nfnl x tebd_dim atype_embd = atype_embd_ext[:, :nloc, :] grrg, g2, h2, rot_mat, sw = self.se_atten( @@ -489,8 +497,8 @@ def call( ) # nf x nloc x (ng x ng1 + tebd_dim) if self.concat_output_tebd: - grrg = np.concatenate( - [grrg, atype_embd.reshape(nf, nloc, self.tebd_dim)], axis=-1 + grrg = xp.concat( + [grrg, xp.reshape(atype_embd, (nf, nloc, self.tebd_dim))], axis=-1 ) return grrg, rot_mat, None, None, sw @@ -536,8 +544,8 @@ def serialize(self) -> dict: "exclude_types": obj.exclude_types, "env_protection": obj.env_protection, "@variables": { - "davg": obj["davg"], - "dstd": obj["dstd"], + "davg": np.array(obj["davg"]), + "dstd": np.array(obj["dstd"]), }, ## to be updated when the options are supported. "trainable": self.trainable, @@ -683,12 +691,12 @@ def __init__( self.embd_input_dim = 1 + self.tebd_dim_input else: self.embd_input_dim = 1 - self.embeddings = NetworkCollection( + embeddings = NetworkCollection( ndim=0, ntypes=self.ntypes, network_type="embedding_network", ) - self.embeddings[0] = EmbeddingNet( + embeddings[0] = EmbeddingNet( self.embd_input_dim, self.neuron, self.activation_function, @@ -696,13 +704,14 @@ def __init__( self.precision, seed=child_seed(seed, 0), ) + self.embeddings = embeddings if self.tebd_input_mode in ["strip"]: - self.embeddings_strip = NetworkCollection( + embeddings_strip = NetworkCollection( ndim=0, ntypes=self.ntypes, network_type="embedding_network", ) - self.embeddings_strip[0] = EmbeddingNet( + embeddings_strip[0] = EmbeddingNet( self.tebd_dim_input, self.neuron, self.activation_function, @@ -710,6 +719,7 @@ def __init__( self.precision, seed=child_seed(seed, 1), ) + self.embeddings_strip = embeddings_strip else: self.embeddings_strip = None self.dpa1_attention = NeighborGatedAttention( @@ -837,9 +847,10 @@ def cal_g( ss, embedding_idx, ): + xp = array_api_compat.array_namespace(ss) nfnl, nnei = ss.shape[0:2] - shape2 = np.prod(ss.shape[2:]) - ss = ss.reshape(nfnl, nnei, shape2) + shape2 = xp.prod(xp.asarray(ss.shape[2:])) + ss = xp.reshape(ss, (nfnl, nnei, shape2)) # nfnl x nnei x ng gg = self.embeddings[embedding_idx].call(ss) return gg @@ -850,9 +861,10 @@ def cal_g_strip( embedding_idx, ): assert self.embeddings_strip is not None + xp = array_api_compat.array_namespace(ss) nfnl, nnei = ss.shape[0:2] - shape2 = np.prod(ss.shape[2:]) - ss = ss.reshape(nfnl, nnei, shape2) + shape2 = xp.prod(xp.asarray(ss.shape[2:])) + ss = xp.reshape(ss, (nfnl, nnei, shape2)) # nfnl x nnei x ng gg = self.embeddings_strip[embedding_idx].call(ss) return gg @@ -865,6 +877,7 @@ def call( atype_embd_ext: Optional[np.ndarray] = None, mapping: Optional[np.ndarray] = None, ): + xp = array_api_compat.array_namespace(nlist, coord_ext, atype_ext) # nf x nloc x nnei x 4 dmatrix, diff, sw = self.env_mat.call( coord_ext, atype_ext, nlist, self.mean, self.stddev @@ -872,41 +885,42 @@ def call( nf, nloc, nnei, _ = dmatrix.shape exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) # nfnl x nnei - exclude_mask = exclude_mask.reshape(nf * nloc, nnei) + exclude_mask = xp.reshape(exclude_mask, (nf * nloc, nnei)) # nfnl x nnei - nlist = nlist.reshape(nf * nloc, nnei) - nlist = np.where(exclude_mask, nlist, -1) + nlist = xp.reshape(nlist, (nf * nloc, nnei)) + nlist = xp.where(exclude_mask, nlist, xp.full_like(nlist, -1)) # nfnl x nnei x 4 - dmatrix = dmatrix.reshape(nf * nloc, nnei, 4) + dmatrix = xp.reshape(dmatrix, (nf * nloc, nnei, 4)) # nfnl x nnei x 1 - sw = sw.reshape(nf * nloc, nnei, 1) + sw = xp.reshape(sw, (nf * nloc, nnei, 1)) # nfnl x tebd_dim - atype_embd = atype_embd_ext[:, :nloc, :].reshape(nf * nloc, self.tebd_dim) + atype_embd = xp.reshape(atype_embd_ext[:, :nloc, :], (nf * nloc, self.tebd_dim)) # nfnl x nnei x tebd_dim - atype_embd_nnei = np.tile(atype_embd[:, np.newaxis, :], (1, nnei, 1)) + atype_embd_nnei = xp.tile(atype_embd[:, xp.newaxis, :], (1, nnei, 1)) # nfnl x nnei nlist_mask = nlist != -1 # nfnl x nnei x 1 - sw = np.where(nlist_mask[:, :, None], sw, 0.0) - nlist_masked = np.where(nlist_mask, nlist, 0) - index = np.tile(nlist_masked.reshape(nf, -1, 1), (1, 1, self.tebd_dim)) + sw = xp.where(nlist_mask[:, :, None], sw, xp.full_like(sw, 0.0)) + nlist_masked = xp.where(nlist_mask, nlist, xp.zeros_like(nlist)) + index = xp.tile(xp.reshape(nlist_masked, (nf, -1, 1)), (1, 1, self.tebd_dim)) # nfnl x nnei x tebd_dim - atype_embd_nlist = np.take_along_axis(atype_embd_ext, index, axis=1).reshape( - nf * nloc, nnei, self.tebd_dim + atype_embd_nlist = xp_take_along_axis(atype_embd_ext, index, axis=1) + atype_embd_nlist = xp.reshape( + atype_embd_nlist, (nf * nloc, nnei, self.tebd_dim) ) ng = self.neuron[-1] # nfnl x nnei x 4 - rr = dmatrix.reshape(nf * nloc, nnei, 4) - rr = rr * exclude_mask[:, :, None] + rr = xp.reshape(dmatrix, (nf * nloc, nnei, 4)) + rr = rr * xp.astype(exclude_mask[:, :, None], rr.dtype) # nfnl x nnei x 1 ss = rr[..., 0:1] if self.tebd_input_mode in ["concat"]: if not self.type_one_side: # nfnl x nnei x (1 + 2 * tebd_dim) - ss = np.concatenate([ss, atype_embd_nlist, atype_embd_nnei], axis=-1) + ss = xp.concat([ss, atype_embd_nlist, atype_embd_nnei], axis=-1) else: # nfnl x nnei x (1 + tebd_dim) - ss = np.concatenate([ss, atype_embd_nlist], axis=-1) + ss = xp.concat([ss, atype_embd_nlist], axis=-1) # calculate gg # nfnl x nnei x ng gg = self.cal_g(ss, 0) @@ -916,42 +930,47 @@ def call( assert self.embeddings_strip is not None if not self.type_one_side: # nfnl x nnei x (tebd_dim * 2) - tt = np.concatenate([atype_embd_nlist, atype_embd_nnei], axis=-1) + tt = xp.concat([atype_embd_nlist, atype_embd_nnei], axis=-1) else: # nfnl x nnei x tebd_dim tt = atype_embd_nlist # nfnl x nnei x ng gg_t = self.cal_g_strip(tt, 0) if self.smooth: - gg_t = gg_t * sw.reshape(-1, self.nnei, 1) + gg_t = gg_t * xp.reshape(sw, (-1, self.nnei, 1)) # nfnl x nnei x ng gg = gg_s * gg_t + gg_s else: raise NotImplementedError - input_r = rr.reshape(-1, nnei, 4)[:, :, 1:4] / np.maximum( - np.linalg.norm(rr.reshape(-1, nnei, 4)[:, :, 1:4], axis=-1, keepdims=True), - 1e-12, + normed = xp.linalg.vector_norm( + xp.reshape(rr, (-1, nnei, 4))[:, :, 1:4], axis=-1, keepdims=True + ) + input_r = xp.reshape(rr, (-1, nnei, 4))[:, :, 1:4] / xp.maximum( + normed, + xp.full_like(normed, 1e-12), ) gg = self.dpa1_attention( gg, nlist_mask, input_r=input_r, sw=sw ) # shape is [nframes*nloc, self.neei, out_size] # nfnl x ng x 4 - gr = np.einsum("lni,lnj->lij", gg, rr) + # gr = xp.einsum("lni,lnj->lij", gg, rr) + gr = xp.sum(gg[:, :, :, None] * rr[:, :, None, :], axis=1) gr /= self.nnei gr1 = gr[:, : self.axis_neuron, :] # nfnl x ng x ng1 - grrg = np.einsum("lid,ljd->lij", gr, gr1) + # grrg = xp.einsum("lid,ljd->lij", gr, gr1) + grrg = xp.sum(gr[:, :, None, :] * gr1[:, None, :, :], axis=3) # nf x nloc x (ng x ng1) - grrg = grrg.reshape(nf, nloc, ng * self.axis_neuron).astype( - GLOBAL_NP_FLOAT_PRECISION + grrg = xp.astype( + xp.reshape(grrg, (nf, nloc, ng * self.axis_neuron)), coord_ext.dtype ) return ( - grrg.reshape(nf, nloc, self.filter_neuron[-1] * self.axis_neuron), - gg.reshape(nf, nloc, self.nnei, self.filter_neuron[-1]), - dmatrix.reshape(nf, nloc, self.nnei, 4)[..., 1:], - gr[..., 1:].reshape(nf, nloc, self.filter_neuron[-1], 3), - sw, + xp.reshape(grrg, (nf, nloc, self.filter_neuron[-1] * self.axis_neuron)), + xp.reshape(gg, (nf, nloc, self.nnei, self.filter_neuron[-1])), + xp.reshape(dmatrix, (nf, nloc, self.nnei, 4))[..., 1:], + xp.reshape(gr[..., 1:], (nf, nloc, self.filter_neuron[-1], 3)), + xp.reshape(sw, (nf, nloc, nnei, 1)), ) def has_message_passing(self) -> bool: @@ -962,6 +981,77 @@ def need_sorted_nlist_for_lower(self) -> bool: """Returns whether the descriptor block needs sorted nlist when using `forward_lower`.""" return False + def serialize(self) -> dict: + """Serialize the descriptor to dict.""" + obj = self + data = { + "@class": "DescriptorBlock", + "type": "dpa1", + "@version": 1, + "rcut": obj.rcut, + "rcut_smth": obj.rcut_smth, + "sel": obj.sel, + "ntypes": obj.ntypes, + "neuron": obj.neuron, + "axis_neuron": obj.axis_neuron, + "tebd_dim": obj.tebd_dim, + "tebd_input_mode": obj.tebd_input_mode, + "set_davg_zero": obj.set_davg_zero, + "attn": obj.attn, + "attn_layer": obj.attn_layer, + "attn_dotr": obj.attn_dotr, + "attn_mask": obj.attn_mask, + "activation_function": obj.activation_function, + "resnet_dt": obj.resnet_dt, + "scaling_factor": obj.scaling_factor, + "normalize": obj.normalize, + "temperature": obj.temperature, + "trainable_ln": obj.trainable_ln, + "ln_eps": obj.ln_eps, + "smooth": obj.smooth, + "type_one_side": obj.type_one_side, + # make deterministic + "precision": np.dtype(PRECISION_DICT[obj.precision]).name, + "embeddings": obj.embeddings.serialize(), + "attention_layers": obj.dpa1_attention.serialize(), + "env_mat": obj.env_mat.serialize(), + "exclude_types": obj.exclude_types, + "env_protection": obj.env_protection, + "@variables": { + "davg": np.array(obj["davg"]), + "dstd": np.array(obj["dstd"]), + }, + } + if obj.tebd_input_mode in ["strip"]: + data.update({"embeddings_strip": obj.embeddings_strip.serialize()}) + return data + + @classmethod + def deserialize(cls, data: dict) -> "DescrptDPA1": + """Deserialize from dict.""" + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + data.pop("type") + variables = data.pop("@variables") + embeddings = data.pop("embeddings") + attention_layers = data.pop("attention_layers") + env_mat = data.pop("env_mat") + tebd_input_mode = data["tebd_input_mode"] + if tebd_input_mode in ["strip"]: + embeddings_strip = data.pop("embeddings_strip") + else: + embeddings_strip = None + obj = cls(**data) + + obj["davg"] = variables["davg"] + obj["dstd"] = variables["dstd"] + obj.embeddings = NetworkCollection.deserialize(embeddings) + if tebd_input_mode in ["strip"]: + obj.embeddings_strip = NetworkCollection.deserialize(embeddings_strip) + obj.dpa1_attention = NeighborGatedAttention.deserialize(attention_layers) + return obj + class NeighborGatedAttention(NativeOP): def __init__( @@ -1254,18 +1344,23 @@ def __init__( ) def call(self, query, nei_mask, input_r=None, sw=None, attnw_shift=20.0): + xp = array_api_compat.array_namespace(query, nei_mask) # Linear projection - q, k, v = np.split(self.in_proj(query), 3, axis=-1) + # q, k, v = xp.split(self.in_proj(query), 3, axis=-1) + _query = self.in_proj(query) + q = _query[..., 0 : self.head_dim] + k = _query[..., self.head_dim : self.head_dim * 2] + v = _query[..., self.head_dim * 2 : self.head_dim * 3] # Reshape and normalize # (nf x nloc) x num_heads x nnei x head_dim - q = q.reshape(-1, self.nnei, self.num_heads, self.head_dim).transpose( - 0, 2, 1, 3 + q = xp.permute_dims( + xp.reshape(q, (-1, self.nnei, self.num_heads, self.head_dim)), (0, 2, 1, 3) ) - k = k.reshape(-1, self.nnei, self.num_heads, self.head_dim).transpose( - 0, 2, 1, 3 + k = xp.permute_dims( + xp.reshape(k, (-1, self.nnei, self.num_heads, self.head_dim)), (0, 2, 1, 3) ) - v = v.reshape(-1, self.nnei, self.num_heads, self.head_dim).transpose( - 0, 2, 1, 3 + v = xp.permute_dims( + xp.reshape(v, (-1, self.nnei, self.num_heads, self.head_dim)), (0, 2, 1, 3) ) if self.normalize: q = np_normalize(q, axis=-1) @@ -1274,29 +1369,38 @@ def call(self, query, nei_mask, input_r=None, sw=None, attnw_shift=20.0): q = q * self.scaling # Attention weights # (nf x nloc) x num_heads x nnei x nnei - attn_weights = q @ k.transpose(0, 1, 3, 2) - nei_mask = nei_mask.reshape(-1, self.nnei) + attn_weights = q @ xp.permute_dims(k, (0, 1, 3, 2)) + nei_mask = xp.reshape(nei_mask, (-1, self.nnei)) if self.smooth: - sw = sw.reshape(-1, 1, self.nnei) + sw = xp.reshape(sw, (-1, 1, self.nnei)) attn_weights = (attn_weights + attnw_shift) * sw[:, :, :, None] * sw[ :, :, None, : ] - attnw_shift else: - attn_weights = np.where(nei_mask[:, None, None, :], attn_weights, -np.inf) + attn_weights = xp.where( + nei_mask[:, None, None, :], + attn_weights, + xp.full_like(attn_weights, -xp.inf), + ) attn_weights = np_softmax(attn_weights, axis=-1) - attn_weights = np.where(nei_mask[:, None, :, None], attn_weights, 0.0) + attn_weights = xp.where( + nei_mask[:, None, :, None], attn_weights, xp.zeros_like(attn_weights) + ) if self.smooth: attn_weights = attn_weights * sw[:, :, :, None] * sw[:, :, None, :] if self.dotr: - angular_weight = (input_r @ input_r.transpose(0, 2, 1)).reshape( - -1, 1, self.nnei, self.nnei + angular_weight = xp.reshape( + input_r @ xp.permute_dims(input_r, (0, 2, 1)), + (-1, 1, self.nnei, self.nnei), ) attn_weights = attn_weights * angular_weight # Output projection # (nf x nloc) x num_heads x nnei x head_dim o = attn_weights @ v # (nf x nloc) x nnei x (num_heads x head_dim) - o = o.transpose(0, 2, 1, 3).reshape(-1, self.nnei, self.hidden_dim) + o = xp.reshape( + xp.permute_dims(o, (0, 2, 1, 3)), (-1, self.nnei, self.hidden_dim) + ) output = self.out_proj(o) return output, attn_weights diff --git a/deepmd/dpmodel/utils/env_mat.py b/deepmd/dpmodel/utils/env_mat.py index 41f2591279..f4bc333a03 100644 --- a/deepmd/dpmodel/utils/env_mat.py +++ b/deepmd/dpmodel/utils/env_mat.py @@ -12,6 +12,7 @@ ) from deepmd.dpmodel.array_api import ( support_array_api, + xp_take_along_axis, ) @@ -44,33 +45,34 @@ def _make_env_mat( protection: float = 0.0, ): """Make smooth environment matrix.""" + xp = array_api_compat.array_namespace(nlist) nf, nloc, nnei = nlist.shape # nf x nall x 3 - coord = coord.reshape(nf, -1, 3) + coord = xp.reshape(coord, (nf, -1, 3)) mask = nlist >= 0 - nlist = nlist * mask + nlist = nlist * xp.astype(mask, nlist.dtype) # nf x (nloc x nnei) x 3 - index = np.tile(nlist.reshape(nf, -1, 1), (1, 1, 3)) - coord_r = np.take_along_axis(coord, index, 1) + index = xp.tile(xp.reshape(nlist, (nf, -1, 1)), (1, 1, 3)) + coord_r = xp_take_along_axis(coord, index, 1) # nf x nloc x nnei x 3 - coord_r = coord_r.reshape(nf, nloc, nnei, 3) + coord_r = xp.reshape(coord_r, (nf, nloc, nnei, 3)) # nf x nloc x 1 x 3 - coord_l = coord[:, :nloc].reshape(nf, -1, 1, 3) + coord_l = xp.reshape(coord[:, :nloc, ...], (nf, -1, 1, 3)) # nf x nloc x nnei x 3 diff = coord_r - coord_l # nf x nloc x nnei - length = np.linalg.norm(diff, axis=-1, keepdims=True) + length = xp.linalg.vector_norm(diff, axis=-1, keepdims=True) # for index 0 nloc atom - length = length + ~np.expand_dims(mask, -1) + length = length + xp.astype(~xp.expand_dims(mask, axis=-1), length.dtype) t0 = 1 / (length + protection) t1 = diff / (length + protection) ** 2 weight = compute_smooth_weight(length, ruct_smth, rcut) - weight = weight * np.expand_dims(mask, -1) + weight = weight * xp.astype(xp.expand_dims(mask, axis=-1), weight.dtype) if radial_only: env_mat = t0 * weight else: - env_mat = np.concatenate([t0, t1], axis=-1) * weight - return env_mat, diff * np.expand_dims(mask, -1), weight + env_mat = xp.concat([t0, t1], axis=-1) * weight + return env_mat, diff * xp.astype(xp.expand_dims(mask, axis=-1), diff.dtype), weight class EnvMat(NativeOP): @@ -122,13 +124,14 @@ def call( switch The value of switch function. shape: nf x nloc x nnei """ + xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist) em, diff, sw = self._call(nlist, coord_ext, radial_only) nf, nloc, nnei = nlist.shape atype = atype_ext[:, :nloc] if davg is not None: - em -= davg[atype] + em -= xp.reshape(xp.take(davg, xp.reshape(atype, (-1,)), axis=0), em.shape) if dstd is not None: - em /= dstd[atype] + em /= xp.reshape(xp.take(dstd, xp.reshape(atype, (-1,)), axis=0), em.shape) return em, diff, sw def _call(self, nlist, coord_ext, radial_only): diff --git a/deepmd/dpmodel/utils/exclude_mask.py b/deepmd/dpmodel/utils/exclude_mask.py index d0a739b9d4..5469e66d97 100644 --- a/deepmd/dpmodel/utils/exclude_mask.py +++ b/deepmd/dpmodel/utils/exclude_mask.py @@ -1,7 +1,12 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import array_api_compat import numpy as np +from deepmd.dpmodel.array_api import ( + xp_take_along_axis, +) + class AtomExcludeMask: """Computes the type exclusion mask for atoms.""" @@ -45,8 +50,9 @@ def build_type_exclude_mask( otherwise being 1. """ + xp = array_api_compat.array_namespace(atype) nf, natom = atype.shape - return self.type_mask[atype].reshape(nf, natom) + return xp.reshape(self.type_mask[atype], (nf, natom)) class PairExcludeMask: @@ -64,7 +70,7 @@ def __init__( self.exclude_types.add((tt[0], tt[1])) self.exclude_types.add((tt[1], tt[0])) # ntypes + 1 for nlist masks - self.type_mask = np.array( + type_mask = np.array( [ [ 1 if (tt_i, tt_j) not in self.exclude_types else 0 @@ -75,7 +81,7 @@ def __init__( dtype=np.int32, ) # (ntypes+1 x ntypes+1) - self.type_mask = self.type_mask.reshape([-1]) + self.type_mask = type_mask.reshape([-1]) def get_exclude_types(self): return self.exclude_types @@ -102,23 +108,29 @@ def build_type_exclude_mask( otherwise being 1. """ + xp = array_api_compat.array_namespace(nlist, atype_ext) if len(self.exclude_types) == 0: # safely return 1 if nothing is excluded. - return np.ones_like(nlist, dtype=np.int32) + return xp.ones_like(nlist, dtype=xp.int32) nf, nloc, nnei = nlist.shape nall = atype_ext.shape[1] # add virtual atom of type ntypes. nf x nall+1 - ae = np.concatenate( - [atype_ext, self.ntypes * np.ones([nf, 1], dtype=atype_ext.dtype)], axis=-1 + ae = xp.concat( + [atype_ext, self.ntypes * xp.ones([nf, 1], dtype=atype_ext.dtype)], axis=-1 ) - type_i = atype_ext[:, :nloc].reshape(nf, nloc) * (self.ntypes + 1) + type_i = xp.reshape(atype_ext[:, :nloc], (nf, nloc)) * (self.ntypes + 1) # nf x nloc x nnei - index = np.where(nlist == -1, nall, nlist).reshape(nf, nloc * nnei) - type_j = np.take_along_axis(ae, index, axis=1).reshape(nf, nloc, nnei) + index = xp.reshape( + xp.where(nlist == -1, xp.full_like(nlist, nall), nlist), (nf, nloc * nnei) + ) + type_j = xp_take_along_axis(ae, index, axis=1) + type_j = xp.reshape(type_j, (nf, nloc, nnei)) type_ij = type_i[:, :, None] + type_j # nf x (nloc x nnei) - type_ij = type_ij.reshape(nf, nloc * nnei) - mask = self.type_mask[type_ij].reshape(nf, nloc, nnei) + type_ij = xp.reshape(type_ij, (nf, nloc * nnei)) + mask = xp.reshape( + xp.take(self.type_mask, xp.reshape(type_ij, (-1,))), (nf, nloc, nnei) + ) return mask def __contains__(self, item): diff --git a/deepmd/dpmodel/utils/network.py b/deepmd/dpmodel/utils/network.py index e1242c3669..339035ff4e 100644 --- a/deepmd/dpmodel/utils/network.py +++ b/deepmd/dpmodel/utils/network.py @@ -146,15 +146,18 @@ def deserialize(cls, data: dict) -> "NativeLayer": num_out, **data, ) - obj.w, obj.b, obj.idt = ( + w, b, idt = ( variables["w"], variables.get("b", None), variables.get("idt", None), ) - if obj.b is not None: - obj.b = obj.b.ravel() - if obj.idt is not None: - obj.idt = obj.idt.ravel() + if b is not None: + b = b.ravel() + if idt is not None: + idt = idt.ravel() + obj.w = w + obj.b = b + obj.idt = idt obj.check_shape_consistency() return obj @@ -175,8 +178,11 @@ def check_type_consistency(self): def check_var(var): if var is not None: + # array api standard doesn't provide a API to get the dtype name + # this is really hacked + dtype_name = str(var.dtype).split(".")[-1] # assertion "float64" == "double" would fail - assert PRECISION_DICT[var.dtype.name] is PRECISION_DICT[precision] + assert PRECISION_DICT[dtype_name] is PRECISION_DICT[precision] check_var(self.w) check_var(self.b) @@ -249,7 +255,7 @@ def call(self, x: np.ndarray) -> np.ndarray: if self.resnet and self.w.shape[1] == self.w.shape[0]: y += x elif self.resnet and self.w.shape[1] == 2 * self.w.shape[0]: - y += xp.concatenate([x, x], axis=-1) + y += xp.concat([x, x], axis=-1) return y @@ -360,10 +366,11 @@ def __init__( precision=precision, seed=seed, ) - self.w = self.w.squeeze(0) # keep the weight shape to be [num_in] + xp = array_api_compat.array_namespace(self.w, self.b) + self.w = xp.squeeze(self.w, 0) # keep the weight shape to be [num_in] if self.uni_init: - self.w = np.ones_like(self.w) - self.b = np.zeros_like(self.b) + self.w = xp.ones_like(self.w) + self.b = xp.zeros_like(self.b) # only to keep consistent with other backends self.trainable = trainable @@ -376,8 +383,8 @@ def serialize(self) -> dict: The serialized layer. """ data = { - "w": self.w, - "b": self.b, + "w": to_numpy_array(self.w), + "b": to_numpy_array(self.b), } return { "@class": "LayerNorm", @@ -471,11 +478,12 @@ def call(self, x: np.ndarray) -> np.ndarray: @staticmethod def layer_norm_numpy(x, shape, weight=None, bias=None, eps=1e-5): + xp = array_api_compat.array_namespace(x) # mean and variance - mean = np.mean(x, axis=tuple(range(-len(shape), 0)), keepdims=True) - var = np.var(x, axis=tuple(range(-len(shape), 0)), keepdims=True) + mean = xp.mean(x, axis=tuple(range(-len(shape), 0)), keepdims=True) + var = xp.var(x, axis=tuple(range(-len(shape), 0)), keepdims=True) # normalize - x_normalized = (x - mean) / np.sqrt(var + eps) + x_normalized = (x - mean) / xp.sqrt(var + eps) # shift and scale if weight is not None and bias is not None: x_normalized = x_normalized * weight + bias diff --git a/deepmd/dpmodel/utils/nlist.py b/deepmd/dpmodel/utils/nlist.py index 4d0b3e3286..4806fa4cd8 100644 --- a/deepmd/dpmodel/utils/nlist.py +++ b/deepmd/dpmodel/utils/nlist.py @@ -4,8 +4,13 @@ Union, ) +import array_api_compat import numpy as np +from deepmd.dpmodel.array_api import ( + xp_take_along_axis, +) + from .region import ( normalize_coord, to_face_distance, @@ -88,34 +93,36 @@ def build_neighbor_list( For virtual atoms all neighboring positions are filled with -1. """ + xp = array_api_compat.array_namespace(coord, atype) batch_size = coord.shape[0] - coord = coord.reshape(batch_size, -1) + coord = xp.reshape(coord, (batch_size, -1)) nall = coord.shape[1] // 3 # fill virtual atoms with large coords so they are not neighbors of any # real atom. if coord.size > 0: - xmax = np.max(coord) + 2.0 * rcut + xmax = xp.max(coord) + 2.0 * rcut else: xmax = 2.0 * rcut # nf x nall is_vir = atype < 0 - coord1 = np.where( - is_vir[:, :, None], xmax, coord.reshape(batch_size, nall, 3) - ).reshape(batch_size, nall * 3) + coord1 = xp.where( + is_vir[:, :, None], xmax, xp.reshape(coord, (batch_size, nall, 3)) + ) + coord1 = xp.reshape(coord1, (batch_size, nall * 3)) if isinstance(sel, int): sel = [sel] nsel = sum(sel) coord0 = coord1[:, : nloc * 3] diff = ( - coord1.reshape([batch_size, -1, 3])[:, None, :, :] - - coord0.reshape([batch_size, -1, 3])[:, :, None, :] + xp.reshape(coord1, [batch_size, -1, 3])[:, None, :, :] + - xp.reshape(coord0, [batch_size, -1, 3])[:, :, None, :] ) assert list(diff.shape) == [batch_size, nloc, nall, 3] - rr = np.linalg.norm(diff, axis=-1) + rr = xp.linalg.vector_norm(diff, axis=-1) # if central atom has two zero distances, sorting sometimes can not exclude itself - rr -= np.eye(nloc, nall, dtype=diff.dtype)[np.newaxis, :, :] - nlist = np.argsort(rr, axis=-1) - rr = np.sort(rr, axis=-1) + rr -= xp.eye(nloc, nall, dtype=diff.dtype)[xp.newaxis, :, :] + nlist = xp.argsort(rr, axis=-1) + rr = xp.sort(rr, axis=-1) rr = rr[:, :, 1:] nlist = nlist[:, :, 1:] nnei = rr.shape[2] @@ -123,16 +130,20 @@ def build_neighbor_list( rr = rr[:, :, :nsel] nlist = nlist[:, :, :nsel] else: - rr = np.concatenate( - [rr, np.ones([batch_size, nloc, nsel - nnei]) + rcut], # pylint: disable=no-explicit-dtype + rr = xp.concatenate( + [rr, xp.ones([batch_size, nloc, nsel - nnei]) + rcut], # pylint: disable=no-explicit-dtype axis=-1, ) - nlist = np.concatenate( - [nlist, np.ones([batch_size, nloc, nsel - nnei], dtype=nlist.dtype)], + nlist = xp.concatenate( + [nlist, xp.ones([batch_size, nloc, nsel - nnei], dtype=nlist.dtype)], axis=-1, ) assert list(nlist.shape) == [batch_size, nloc, nsel] - nlist = np.where(np.logical_or((rr > rcut), is_vir[:, :nloc, None]), -1, nlist) + nlist = xp.where( + xp.logical_or((rr > rcut), is_vir[:, :nloc, None]), + xp.full_like(nlist, -1), + nlist, + ) if distinguish_types: return nlist_distinguish_types(nlist, atype, sel) @@ -149,23 +160,24 @@ def nlist_distinguish_types( distinguish atom types. """ + xp = array_api_compat.array_namespace(nlist, atype) nf, nloc, _ = nlist.shape ret_nlist = [] - tmp_atype = np.tile(atype[:, None], [1, nloc, 1]) + tmp_atype = xp.tile(atype[:, None], [1, nloc, 1]) mask = nlist == -1 tnlist_0 = nlist.copy() tnlist_0[mask] = 0 - tnlist = np.take_along_axis(tmp_atype, tnlist_0, axis=2).squeeze() - tnlist = np.where(mask, -1, tnlist) + tnlist = xp_take_along_axis(tmp_atype, tnlist_0, axis=2).squeeze() + tnlist = xp.where(mask, -1, tnlist) snsel = tnlist.shape[2] for ii, ss in enumerate(sel): - pick_mask = (tnlist == ii).astype(np.int32) - sorted_indices = np.argsort(-pick_mask, kind="stable", axis=-1) - pick_mask_sorted = -np.sort(-pick_mask, axis=-1) - inlist = np.take_along_axis(nlist, sorted_indices, axis=2) - inlist = np.where(~pick_mask_sorted.astype(bool), -1, inlist) - ret_nlist.append(np.split(inlist, [ss, snsel - ss], axis=-1)[0]) - ret = np.concatenate(ret_nlist, axis=-1) + pick_mask = (tnlist == ii).astype(xp.int32) + sorted_indices = xp.argsort(-pick_mask, kind="stable", axis=-1) + pick_mask_sorted = -xp.sort(-pick_mask, axis=-1) + inlist = xp_take_along_axis(nlist, sorted_indices, axis=2) + inlist = xp.where(~pick_mask_sorted.astype(bool), -1, inlist) + ret_nlist.append(xp.split(inlist, [ss, snsel - ss], axis=-1)[0]) + ret = xp.concat(ret_nlist, axis=-1) return ret @@ -263,36 +275,46 @@ def extend_coord_with_ghosts( maping extended index to the local index """ + xp = array_api_compat.array_namespace(coord, atype) nf, nloc = atype.shape - aidx = np.tile(np.arange(nloc)[np.newaxis, :], (nf, 1)) # pylint: disable=no-explicit-dtype + aidx = xp.tile(xp.arange(nloc)[xp.newaxis, :], (nf, 1)) # pylint: disable=no-explicit-dtype if cell is None: nall = nloc - extend_coord = coord.copy() - extend_atype = atype.copy() - extend_aidx = aidx.copy() + extend_coord = coord + extend_atype = atype + extend_aidx = aidx else: - coord = coord.reshape((nf, nloc, 3)) - cell = cell.reshape((nf, 3, 3)) + coord = xp.reshape(coord, (nf, nloc, 3)) + cell = xp.reshape(cell, (nf, 3, 3)) to_face = to_face_distance(cell) - nbuff = np.ceil(rcut / to_face).astype(int) - nbuff = np.max(nbuff, axis=0) - xi = np.arange(-nbuff[0], nbuff[0] + 1, 1) # pylint: disable=no-explicit-dtype - yi = np.arange(-nbuff[1], nbuff[1] + 1, 1) # pylint: disable=no-explicit-dtype - zi = np.arange(-nbuff[2], nbuff[2] + 1, 1) # pylint: disable=no-explicit-dtype - xyz = np.outer(xi, np.array([1, 0, 0]))[:, np.newaxis, np.newaxis, :] - xyz = xyz + np.outer(yi, np.array([0, 1, 0]))[np.newaxis, :, np.newaxis, :] - xyz = xyz + np.outer(zi, np.array([0, 0, 1]))[np.newaxis, np.newaxis, :, :] - xyz = xyz.reshape(-1, 3) - shift_idx = xyz[np.argsort(np.linalg.norm(xyz, axis=1))] + nbuff = xp.astype(xp.ceil(rcut / to_face), xp.int64) + nbuff = xp.max(nbuff, axis=0) + xi = xp.arange(-int(nbuff[0]), int(nbuff[0]) + 1, 1) # pylint: disable=no-explicit-dtype + yi = xp.arange(-int(nbuff[1]), int(nbuff[1]) + 1, 1) # pylint: disable=no-explicit-dtype + zi = xp.arange(-int(nbuff[2]), int(nbuff[2]) + 1, 1) # pylint: disable=no-explicit-dtype + xyz = xp.linalg.outer(xi, xp.asarray([1, 0, 0]))[:, xp.newaxis, xp.newaxis, :] + xyz = ( + xyz + + xp.linalg.outer(yi, xp.asarray([0, 1, 0]))[xp.newaxis, :, xp.newaxis, :] + ) + xyz = ( + xyz + + xp.linalg.outer(zi, xp.asarray([0, 0, 1]))[xp.newaxis, xp.newaxis, :, :] + ) + xyz = xp.reshape(xyz, (-1, 3)) + xyz = xp.astype(xyz, coord.dtype) + shift_idx = xp.take(xyz, xp.argsort(xp.linalg.vector_norm(xyz, axis=1)), axis=0) ns, _ = shift_idx.shape nall = ns * nloc - shift_vec = np.einsum("sd,fdk->fsk", shift_idx, cell) + # shift_vec = xp.einsum("sd,fdk->fsk", shift_idx, cell) + shift_vec = xp.tensordot(shift_idx, cell, axes=([1], [1])) + shift_vec = xp.permute_dims(shift_vec, (1, 0, 2)) extend_coord = coord[:, None, :, :] + shift_vec[:, :, None, :] - extend_atype = np.tile(atype[:, :, np.newaxis], (1, ns, 1)) - extend_aidx = np.tile(aidx[:, :, np.newaxis], (1, ns, 1)) + extend_atype = xp.tile(atype[:, :, xp.newaxis], (1, ns, 1)) + extend_aidx = xp.tile(aidx[:, :, xp.newaxis], (1, ns, 1)) return ( - extend_coord.reshape((nf, nall * 3)), - extend_atype.reshape((nf, nall)), - extend_aidx.reshape((nf, nall)), + xp.reshape(extend_coord, (nf, nall * 3)), + xp.reshape(extend_atype, (nf, nall)), + xp.reshape(extend_aidx, (nf, nall)), ) diff --git a/deepmd/dpmodel/utils/region.py b/deepmd/dpmodel/utils/region.py index ddbc4b29b8..8102020827 100644 --- a/deepmd/dpmodel/utils/region.py +++ b/deepmd/dpmodel/utils/region.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import array_api_compat import numpy as np @@ -21,8 +22,9 @@ def phys2inter( the internal coordinates """ - rec_cell = np.linalg.inv(cell) - return np.matmul(coord, rec_cell) + xp = array_api_compat.array_namespace(coord, cell) + rec_cell = xp.linalg.inv(cell) + return xp.matmul(coord, rec_cell) def inter2phys( @@ -44,7 +46,8 @@ def inter2phys( the physical coordinates """ - return np.matmul(coord, cell) + xp = array_api_compat.array_namespace(coord, cell) + return xp.matmul(coord, cell) def normalize_coord( @@ -66,8 +69,9 @@ def normalize_coord( wrapped coordinates of shape [*, na, 3]. """ + xp = array_api_compat.array_namespace(coord, cell) icoord = phys2inter(coord, cell) - icoord = np.remainder(icoord, 1.0) + icoord = xp.remainder(icoord, 1.0) return inter2phys(icoord, cell) @@ -87,17 +91,19 @@ def to_face_distance( the to face distances of shape [*, 3] """ + xp = array_api_compat.array_namespace(cell) cshape = cell.shape - dist = b_to_face_distance(cell.reshape([-1, 3, 3])) - return dist.reshape(list(cshape[:-2]) + [3]) # noqa:RUF005 + dist = b_to_face_distance(xp.reshape(cell, [-1, 3, 3])) + return xp.reshape(dist, list(cshape[:-2]) + [3]) # noqa:RUF005 def b_to_face_distance(cell): - volume = np.linalg.det(cell) - c_yz = np.cross(cell[:, 1], cell[:, 2], axis=-1) - _h2yz = volume / np.linalg.norm(c_yz, axis=-1) - c_zx = np.cross(cell[:, 2], cell[:, 0], axis=-1) - _h2zx = volume / np.linalg.norm(c_zx, axis=-1) - c_xy = np.cross(cell[:, 0], cell[:, 1], axis=-1) - _h2xy = volume / np.linalg.norm(c_xy, axis=-1) - return np.stack([_h2yz, _h2zx, _h2xy], axis=1) + xp = array_api_compat.array_namespace(cell) + volume = xp.linalg.det(cell) + c_yz = xp.linalg.cross(cell[:, 1, ...], cell[:, 2, ...], axis=-1) + _h2yz = volume / xp.linalg.vector_norm(c_yz, axis=-1) + c_zx = xp.linalg.cross(cell[:, 2, ...], cell[:, 0, ...], axis=-1) + _h2zx = volume / xp.linalg.vector_norm(c_zx, axis=-1) + c_xy = xp.linalg.cross(cell[:, 0, ...], cell[:, 1, ...], axis=-1) + _h2xy = volume / xp.linalg.vector_norm(c_xy, axis=-1) + return xp.stack([_h2yz, _h2zx, _h2xy], axis=1) diff --git a/deepmd/dpmodel/utils/type_embed.py b/deepmd/dpmodel/utils/type_embed.py index d67d8e50fd..e28b6abb31 100644 --- a/deepmd/dpmodel/utils/type_embed.py +++ b/deepmd/dpmodel/utils/type_embed.py @@ -106,7 +106,7 @@ def call(self) -> np.ndarray: embed = self.embedding_net(self.econf_tebd) if self.padding: embed_pad = xp.zeros((1, embed.shape[-1]), dtype=embed.dtype) - embed = xp.concatenate([embed, embed_pad], axis=0) + embed = xp.concat([embed, embed_pad], axis=0) return embed @classmethod diff --git a/deepmd/jax/common.py b/deepmd/jax/common.py index 550b168b29..9c144a41d1 100644 --- a/deepmd/jax/common.py +++ b/deepmd/jax/common.py @@ -1,13 +1,18 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( - Union, + Any, + Optional, overload, ) import numpy as np +from deepmd.dpmodel.common import ( + NativeOP, +) from deepmd.jax.env import ( jnp, + nnx, ) @@ -19,7 +24,7 @@ def to_jax_array(array: np.ndarray) -> jnp.ndarray: ... def to_jax_array(array: None) -> None: ... -def to_jax_array(array: Union[np.ndarray]) -> Union[jnp.ndarray]: +def to_jax_array(array: Optional[np.ndarray]) -> Optional[jnp.ndarray]: """Convert a numpy array to a JAX array. Parameters @@ -35,3 +40,44 @@ def to_jax_array(array: Union[np.ndarray]) -> Union[jnp.ndarray]: if array is None: return None return jnp.array(array) + + +def flax_module( + module: NativeOP, +) -> nnx.Module: + """Convert a NativeOP to a Flax module. + + Parameters + ---------- + module : NativeOP + The NativeOP to convert. + + Returns + ------- + flax.nnx.Module + The Flax module. + + Examples + -------- + >>> @flax_module + ... class MyModule(NativeOP): + ... pass + """ + metas = set() + if not issubclass(type(nnx.Module), type(module)): + metas.add(type(module)) + if not issubclass(type(module), type(nnx.Module)): + metas.add(type(nnx.Module)) + + class MixedMetaClass(*metas): + def __call__(self, *args, **kwargs): + return type(nnx.Module).__call__(self, *args, **kwargs) + + class FlaxModule(module, nnx.Module, metaclass=MixedMetaClass): + def __init_subclass__(cls, **kwargs) -> None: + return super().__init_subclass__(**kwargs) + + def __setattr__(self, name: str, value: Any) -> None: + return super().__setattr__(name, value) + + return FlaxModule diff --git a/deepmd/jax/descriptor/__init__.py b/deepmd/jax/descriptor/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/deepmd/jax/descriptor/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/deepmd/jax/descriptor/dpa1.py b/deepmd/jax/descriptor/dpa1.py new file mode 100644 index 0000000000..a9b0404970 --- /dev/null +++ b/deepmd/jax/descriptor/dpa1.py @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.descriptor.dpa1 import DescrptBlockSeAtten as DescrptBlockSeAttenDP +from deepmd.dpmodel.descriptor.dpa1 import DescrptDPA1 as DescrptDPA1DP +from deepmd.dpmodel.descriptor.dpa1 import GatedAttentionLayer as GatedAttentionLayerDP +from deepmd.dpmodel.descriptor.dpa1 import ( + NeighborGatedAttention as NeighborGatedAttentionDP, +) +from deepmd.dpmodel.descriptor.dpa1 import ( + NeighborGatedAttentionLayer as NeighborGatedAttentionLayerDP, +) +from deepmd.jax.common import ( + flax_module, + to_jax_array, +) +from deepmd.jax.utils.exclude_mask import ( + PairExcludeMask, +) +from deepmd.jax.utils.network import ( + LayerNorm, + NativeLayer, + NetworkCollection, +) +from deepmd.jax.utils.type_embed import ( + TypeEmbedNet, +) + + +@flax_module +class GatedAttentionLayer(GatedAttentionLayerDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"in_proj", "out_proj"}: + value = NativeLayer.deserialize(value.serialize()) + return super().__setattr__(name, value) + + +@flax_module +class NeighborGatedAttentionLayer(NeighborGatedAttentionLayerDP): + def __setattr__(self, name: str, value: Any) -> None: + if name == "attention_layer": + value = GatedAttentionLayer.deserialize(value.serialize()) + elif name == "attn_layer_norm": + value = LayerNorm.deserialize(value.serialize()) + return super().__setattr__(name, value) + + +@flax_module +class NeighborGatedAttention(NeighborGatedAttentionDP): + def __setattr__(self, name: str, value: Any) -> None: + if name == "attention_layers": + value = [ + NeighborGatedAttentionLayer.deserialize(ii.serialize()) for ii in value + ] + return super().__setattr__(name, value) + + +@flax_module +class DescrptBlockSeAtten(DescrptBlockSeAttenDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"mean", "stddev"}: + value = to_jax_array(value) + elif name in {"embeddings", "embeddings_strip"}: + if value is not None: + value = NetworkCollection.deserialize(value.serialize()) + elif name == "dpa1_attention": + value = NeighborGatedAttention.deserialize(value.serialize()) + elif name == "env_mat": + # env_mat doesn't store any value + pass + elif name == "emask": + value = PairExcludeMask(value.ntypes, value.exclude_types) + + return super().__setattr__(name, value) + + +@flax_module +class DescrptDPA1(DescrptDPA1DP): + def __setattr__(self, name: str, value: Any) -> None: + if name == "se_atten": + value = DescrptBlockSeAtten.deserialize(value.serialize()) + elif name == "type_embedding": + value = TypeEmbedNet.deserialize(value.serialize()) + return super().__setattr__(name, value) diff --git a/deepmd/jax/env.py b/deepmd/jax/env.py index 34e4aa6240..5a5a7f6bf0 100644 --- a/deepmd/jax/env.py +++ b/deepmd/jax/env.py @@ -5,10 +5,14 @@ import jax import jax.numpy as jnp +from flax import ( + nnx, +) jax.config.update("jax_enable_x64", True) __all__ = [ "jax", "jnp", + "nnx", ] diff --git a/deepmd/jax/utils/exclude_mask.py b/deepmd/jax/utils/exclude_mask.py new file mode 100644 index 0000000000..cac4cee092 --- /dev/null +++ b/deepmd/jax/utils/exclude_mask.py @@ -0,0 +1,18 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.utils.exclude_mask import PairExcludeMask as PairExcludeMaskDP +from deepmd.jax.common import ( + flax_module, + to_jax_array, +) + + +@flax_module +class PairExcludeMask(PairExcludeMaskDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"type_mask"}: + value = to_jax_array(value) + return super().__setattr__(name, value) diff --git a/deepmd/jax/utils/network.py b/deepmd/jax/utils/network.py index 629b51b8cd..2c406095cd 100644 --- a/deepmd/jax/utils/network.py +++ b/deepmd/jax/utils/network.py @@ -1,29 +1,74 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( Any, + ClassVar, ) from deepmd.dpmodel.common import ( NativeOP, ) +from deepmd.dpmodel.utils.network import LayerNorm as LayerNormDP from deepmd.dpmodel.utils.network import NativeLayer as NativeLayerDP +from deepmd.dpmodel.utils.network import NetworkCollection as NetworkCollectionDP from deepmd.dpmodel.utils.network import ( make_embedding_network, make_fitting_network, make_multilayer_network, ) from deepmd.jax.common import ( + flax_module, to_jax_array, ) +from deepmd.jax.env import ( + nnx, +) + + +class ArrayAPIParam(nnx.Param): + def __array__(self, *args, **kwargs): + return self.value.__array__(*args, **kwargs) + + def __array_namespace__(self, *args, **kwargs): + return self.value.__array_namespace__(*args, **kwargs) + def __dlpack__(self, *args, **kwargs): + return self.value.__dlpack__(*args, **kwargs) + def __dlpack_device__(self, *args, **kwargs): + return self.value.__dlpack_device__(*args, **kwargs) + + +@flax_module class NativeLayer(NativeLayerDP): def __setattr__(self, name: str, value: Any) -> None: if name in {"w", "b", "idt"}: value = to_jax_array(value) + if value is not None: + value = ArrayAPIParam(value) return super().__setattr__(name, value) -NativeNet = make_multilayer_network(NativeLayer, NativeOP) -EmbeddingNet = make_embedding_network(NativeNet, NativeLayer) -FittingNet = make_fitting_network(EmbeddingNet, NativeNet, NativeLayer) +@flax_module +class NativeNet(make_multilayer_network(NativeLayer, NativeOP)): + pass + + +class EmbeddingNet(make_embedding_network(NativeNet, NativeLayer)): + pass + + +class FittingNet(make_fitting_network(EmbeddingNet, NativeNet, NativeLayer)): + pass + + +@flax_module +class NetworkCollection(NetworkCollectionDP): + NETWORK_TYPE_MAP: ClassVar[dict[str, type]] = { + "network": NativeNet, + "embedding_network": EmbeddingNet, + "fitting_network": FittingNet, + } + + +class LayerNorm(LayerNormDP, NativeLayer): + pass diff --git a/deepmd/jax/utils/type_embed.py b/deepmd/jax/utils/type_embed.py index bc7c469524..3143460244 100644 --- a/deepmd/jax/utils/type_embed.py +++ b/deepmd/jax/utils/type_embed.py @@ -5,6 +5,7 @@ from deepmd.dpmodel.utils.type_embed import TypeEmbedNet as TypeEmbedNetDP from deepmd.jax.common import ( + flax_module, to_jax_array, ) from deepmd.jax.utils.network import ( @@ -12,6 +13,7 @@ ) +@flax_module class TypeEmbedNet(TypeEmbedNetDP): def __setattr__(self, name: str, value: Any) -> None: if name in {"econf_tebd"}: diff --git a/pyproject.toml b/pyproject.toml index 6932960ace..b13dceeb07 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -137,6 +137,7 @@ cu12 = [ ] jax = [ 'jax>=0.4.33;python_version>="3.10"', + 'flax>=0.8.0;python_version>="3.10"', ] [tool.deepmd_build_backend.scripts] diff --git a/source/tests/array_api_strict/__init__.py b/source/tests/array_api_strict/__init__.py new file mode 100644 index 0000000000..27785c2fd5 --- /dev/null +++ b/source/tests/array_api_strict/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Synchronize with deepmd.jax for test purpose only.""" diff --git a/source/tests/array_api_strict/common.py b/source/tests/array_api_strict/common.py new file mode 100644 index 0000000000..28f67a97f6 --- /dev/null +++ b/source/tests/array_api_strict/common.py @@ -0,0 +1,25 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Optional, +) + +import array_api_strict +import numpy as np + + +def to_array_api_strict_array(array: Optional[np.ndarray]): + """Convert a numpy array to a JAX array. + + Parameters + ---------- + array : np.ndarray + The numpy array to convert. + + Returns + ------- + jnp.ndarray + The JAX tensor. + """ + if array is None: + return None + return array_api_strict.asarray(array) diff --git a/source/tests/array_api_strict/descriptor/__init__.py b/source/tests/array_api_strict/descriptor/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/array_api_strict/descriptor/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/array_api_strict/descriptor/dpa1.py b/source/tests/array_api_strict/descriptor/dpa1.py new file mode 100644 index 0000000000..ebd688e303 --- /dev/null +++ b/source/tests/array_api_strict/descriptor/dpa1.py @@ -0,0 +1,81 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.descriptor.dpa1 import DescrptBlockSeAtten as DescrptBlockSeAttenDP +from deepmd.dpmodel.descriptor.dpa1 import DescrptDPA1 as DescrptDPA1DP +from deepmd.dpmodel.descriptor.dpa1 import GatedAttentionLayer as GatedAttentionLayerDP +from deepmd.dpmodel.descriptor.dpa1 import ( + NeighborGatedAttention as NeighborGatedAttentionDP, +) +from deepmd.dpmodel.descriptor.dpa1 import ( + NeighborGatedAttentionLayer as NeighborGatedAttentionLayerDP, +) + +from ..common import ( + to_array_api_strict_array, +) +from ..utils.exclude_mask import ( + PairExcludeMask, +) +from ..utils.network import ( + LayerNorm, + NativeLayer, + NetworkCollection, +) +from ..utils.type_embed import ( + TypeEmbedNet, +) + + +class GatedAttentionLayer(GatedAttentionLayerDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"in_proj", "out_proj"}: + value = NativeLayer.deserialize(value.serialize()) + return super().__setattr__(name, value) + + +class NeighborGatedAttentionLayer(NeighborGatedAttentionLayerDP): + def __setattr__(self, name: str, value: Any) -> None: + if name == "attention_layer": + value = GatedAttentionLayer.deserialize(value.serialize()) + elif name == "attn_layer_norm": + value = LayerNorm.deserialize(value.serialize()) + return super().__setattr__(name, value) + + +class NeighborGatedAttention(NeighborGatedAttentionDP): + def __setattr__(self, name: str, value: Any) -> None: + if name == "attention_layers": + value = [ + NeighborGatedAttentionLayer.deserialize(ii.serialize()) for ii in value + ] + return super().__setattr__(name, value) + + +class DescrptBlockSeAtten(DescrptBlockSeAttenDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"mean", "stddev"}: + value = to_array_api_strict_array(value) + elif name in {"embeddings", "embeddings_strip"}: + if value is not None: + value = NetworkCollection.deserialize(value.serialize()) + elif name == "dpa1_attention": + value = NeighborGatedAttention.deserialize(value.serialize()) + elif name == "env_mat": + # env_mat doesn't store any value + pass + elif name == "emask": + value = PairExcludeMask(value.ntypes, value.exclude_types) + + return super().__setattr__(name, value) + + +class DescrptDPA1(DescrptDPA1DP): + def __setattr__(self, name: str, value: Any) -> None: + if name == "se_atten": + value = DescrptBlockSeAtten.deserialize(value.serialize()) + elif name == "type_embedding": + value = TypeEmbedNet.deserialize(value.serialize()) + return super().__setattr__(name, value) diff --git a/source/tests/array_api_strict/utils/__init__.py b/source/tests/array_api_strict/utils/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/array_api_strict/utils/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/array_api_strict/utils/exclude_mask.py b/source/tests/array_api_strict/utils/exclude_mask.py new file mode 100644 index 0000000000..06f2e94b52 --- /dev/null +++ b/source/tests/array_api_strict/utils/exclude_mask.py @@ -0,0 +1,17 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.utils.exclude_mask import PairExcludeMask as PairExcludeMaskDP + +from ..common import ( + to_array_api_strict_array, +) + + +class PairExcludeMask(PairExcludeMaskDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"type_mask"}: + value = to_array_api_strict_array(value) + return super().__setattr__(name, value) diff --git a/source/tests/array_api_strict/utils/network.py b/source/tests/array_api_strict/utils/network.py new file mode 100644 index 0000000000..42b0bb5c61 --- /dev/null +++ b/source/tests/array_api_strict/utils/network.py @@ -0,0 +1,45 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, + ClassVar, +) + +from deepmd.dpmodel.common import ( + NativeOP, +) +from deepmd.dpmodel.utils.network import LayerNorm as LayerNormDP +from deepmd.dpmodel.utils.network import NativeLayer as NativeLayerDP +from deepmd.dpmodel.utils.network import NetworkCollection as NetworkCollectionDP +from deepmd.dpmodel.utils.network import ( + make_embedding_network, + make_fitting_network, + make_multilayer_network, +) + +from ..common import ( + to_array_api_strict_array, +) + + +class NativeLayer(NativeLayerDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"w", "b", "idt"}: + value = to_array_api_strict_array(value) + return super().__setattr__(name, value) + + +NativeNet = make_multilayer_network(NativeLayer, NativeOP) +EmbeddingNet = make_embedding_network(NativeNet, NativeLayer) +FittingNet = make_fitting_network(EmbeddingNet, NativeNet, NativeLayer) + + +class NetworkCollection(NetworkCollectionDP): + NETWORK_TYPE_MAP: ClassVar[dict[str, type]] = { + "network": NativeNet, + "embedding_network": EmbeddingNet, + "fitting_network": FittingNet, + } + + +class LayerNorm(LayerNormDP, NativeLayer): + pass diff --git a/source/tests/array_api_strict/utils/type_embed.py b/source/tests/array_api_strict/utils/type_embed.py new file mode 100644 index 0000000000..7551279002 --- /dev/null +++ b/source/tests/array_api_strict/utils/type_embed.py @@ -0,0 +1,22 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.utils.type_embed import TypeEmbedNet as TypeEmbedNetDP + +from ..common import ( + to_array_api_strict_array, +) +from ..utils.network import ( + EmbeddingNet, +) + + +class TypeEmbedNet(TypeEmbedNetDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"econf_tebd"}: + value = to_array_api_strict_array(value) + if name in {"embedding_net"}: + value = EmbeddingNet.deserialize(value.serialize()) + return super().__setattr__(name, value) diff --git a/source/tests/common/dpmodel/test_descriptor_dpa1.py b/source/tests/common/dpmodel/test_descriptor_dpa1.py index 317f4c3d3d..f441895f15 100644 --- a/source/tests/common/dpmodel/test_descriptor_dpa1.py +++ b/source/tests/common/dpmodel/test_descriptor_dpa1.py @@ -36,3 +36,22 @@ def test_self_consistency( mm1 = em1.call(self.coord_ext, self.atype_ext, self.nlist) for ii in [0, 1, 4]: np.testing.assert_allclose(mm0[ii], mm1[ii]) + + def test_multiple_frames(self): + rng = np.random.default_rng(GLOBAL_SEED) + nf, nloc, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + em0 = DescrptDPA1(self.rcut, self.rcut_smth, self.sel, ntypes=2) + em0.davg = davg + em0.dstd = dstd + two_coord_ext = np.concatenate([self.coord_ext, self.coord_ext], axis=0) + two_atype_ext = np.concatenate([self.atype_ext, self.atype_ext], axis=0) + two_nlist = np.concatenate([self.nlist, self.nlist], axis=0) + + mm0 = em0.call(two_coord_ext, two_atype_ext, two_nlist) + for ii in [0, 1, 4]: + np.testing.assert_allclose(mm0[ii][0], mm0[ii][2], err_msg=f"{ii} 0~2") + np.testing.assert_allclose(mm0[ii][1], mm0[ii][3], err_msg=f"{ii} 1~3") diff --git a/source/tests/consistent/common.py b/source/tests/consistent/common.py index c64b14c273..1070fe0f79 100644 --- a/source/tests/consistent/common.py +++ b/source/tests/consistent/common.py @@ -10,6 +10,9 @@ from enum import ( Enum, ) +from importlib.util import ( + find_spec, +) from typing import ( Any, Callable, @@ -33,6 +36,7 @@ INSTALLED_TF = Backend.get_backend("tensorflow")().is_available() INSTALLED_PT = Backend.get_backend("pytorch")().is_available() INSTALLED_JAX = Backend.get_backend("jax")().is_available() +INSTALLED_ARRAY_API_STRICT = find_spec("array_api_strict") is not None if os.environ.get("CI") and not (INSTALLED_TF and INSTALLED_PT): raise ImportError("TensorFlow or PyTorch should be tested in the CI") @@ -56,6 +60,7 @@ "INSTALLED_TF", "INSTALLED_PT", "INSTALLED_JAX", + "INSTALLED_ARRAY_API_STRICT", ] @@ -72,6 +77,7 @@ class CommonTest(ABC): """PyTorch model class.""" jax_class: ClassVar[Optional[type]] """JAX model class.""" + array_api_strict_class: ClassVar[Optional[type]] args: ClassVar[Optional[Union[Argument, list[Argument]]]] """Arguments that maps to the `data`.""" skip_dp: ClassVar[bool] = False @@ -83,6 +89,8 @@ class CommonTest(ABC): # we may usually skip jax before jax is fully supported skip_jax: ClassVar[bool] = True """Whether to skip the JAX model.""" + skip_array_api_strict: ClassVar[bool] = True + """Whether to skip the array_api_strict model.""" rtol = 1e-10 """Relative tolerance for comparing the return value. Override for float32.""" atol = 1e-10 @@ -163,6 +171,16 @@ def eval_jax(self, jax_obj: Any) -> Any: """ raise NotImplementedError("Not implemented") + def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: + """Evaluate the return value of array_api_strict. + + Parameters + ---------- + array_api_strict_obj : Any + The object of array_api_strict + """ + raise NotImplementedError("Not implemented") + class RefBackend(Enum): """Reference backend.""" @@ -170,6 +188,7 @@ class RefBackend(Enum): DP = 2 PT = 3 JAX = 5 + ARRAY_API_STRICT = 6 @abstractmethod def extract_ret(self, ret: Any, backend: RefBackend) -> tuple[np.ndarray, ...]: @@ -235,6 +254,11 @@ def get_jax_ret_serialization_from_cls(self, obj): data = obj.serialize() return ret, data + def get_array_api_strict_ret_serialization_from_cls(self, obj): + ret = self.eval_array_api_strict(obj) + data = obj.serialize() + return ret, data + def get_reference_backend(self): """Get the reference backend. @@ -248,6 +272,8 @@ def get_reference_backend(self): return self.RefBackend.PT if not self.skip_jax: return self.RefBackend.JAX + if not self.skip_array_api_strict: + return self.RefBackend.ARRAY_API_STRICT raise ValueError("No available reference") def get_reference_ret_serialization(self, ref: RefBackend): @@ -261,6 +287,12 @@ def get_reference_ret_serialization(self, ref: RefBackend): if ref == self.RefBackend.PT: obj = self.init_backend_cls(self.pt_class) return self.get_pt_ret_serialization_from_cls(obj) + if ref == self.RefBackend.JAX: + obj = self.init_backend_cls(self.jax_class) + return self.get_jax_ret_serialization_from_cls(obj) + if ref == self.RefBackend.ARRAY_API_STRICT: + obj = self.init_backend_cls(self.array_api_strict_class) + return self.get_array_api_strict_ret_serialization_from_cls(obj) raise ValueError("No available reference") def test_tf_consistent_with_ref(self): @@ -415,6 +447,40 @@ def test_jax_self_consistent(self): else: self.assertEqual(rr1, rr2) + def test_array_api_strict_consistent_with_ref(self): + """Test whether array_api_strict and reference are consistent.""" + if self.skip_array_api_strict: + self.skipTest("Unsupported backend") + ref_backend = self.get_reference_backend() + if ref_backend == self.RefBackend.ARRAY_API_STRICT: + self.skipTest("Reference is self") + ret1, data1 = self.get_reference_ret_serialization(ref_backend) + ret1 = self.extract_ret(ret1, ref_backend) + array_api_strict_obj = self.array_api_strict_class.deserialize(data1) + ret2 = self.eval_array_api_strict(array_api_strict_obj) + ret2 = self.extract_ret(ret2, self.RefBackend.ARRAY_API_STRICT) + data2 = array_api_strict_obj.serialize() + np.testing.assert_equal(data1, data2) + for rr1, rr2 in zip(ret1, ret2): + np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol) + assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}" + + def test_array_api_strict_self_consistent(self): + """Test whether array_api_strict is self consistent.""" + if self.skip_array_api_strict: + self.skipTest("Unsupported backend") + obj1 = self.init_backend_cls(self.array_api_strict_class) + ret1, data1 = self.get_array_api_strict_ret_serialization_from_cls(obj1) + obj1 = self.array_api_strict_class.deserialize(data1) + ret2, data2 = self.get_array_api_strict_ret_serialization_from_cls(obj1) + np.testing.assert_equal(data1, data2) + for rr1, rr2 in zip(ret1, ret2): + if isinstance(rr1, np.ndarray) and isinstance(rr2, np.ndarray): + np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol) + assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}" + else: + self.assertEqual(rr1, rr2) + def tearDown(self) -> None: """Clear the TF session.""" if not self.skip_tf: diff --git a/source/tests/consistent/descriptor/common.py b/source/tests/consistent/descriptor/common.py index 74fc3d9b07..e0ca30c799 100644 --- a/source/tests/consistent/descriptor/common.py +++ b/source/tests/consistent/descriptor/common.py @@ -3,6 +3,8 @@ Any, ) +import numpy as np + from deepmd.common import ( make_default_mesh, ) @@ -12,6 +14,8 @@ ) from ..common import ( + INSTALLED_ARRAY_API_STRICT, + INSTALLED_JAX, INSTALLED_PT, INSTALLED_TF, ) @@ -29,6 +33,12 @@ GLOBAL_TF_FLOAT_PRECISION, tf, ) +if INSTALLED_JAX: + from deepmd.jax.env import ( + jnp, + ) +if INSTALLED_ARRAY_API_STRICT: + import array_api_strict class DescriptorTest: @@ -99,3 +109,56 @@ def eval_pt_descriptor( x.detach().cpu().numpy() if torch.is_tensor(x) else x for x in pt_obj(ext_coords, ext_atype, nlist=nlist, mapping=mapping) ] + + def eval_jax_descriptor( + self, jax_obj: Any, natoms, coords, atype, box, mixed_types: bool = False + ) -> Any: + ext_coords, ext_atype, mapping = extend_coord_with_ghosts( + jnp.array(coords).reshape(1, -1, 3), + jnp.array(atype).reshape(1, -1), + jnp.array(box).reshape(1, 3, 3), + jax_obj.get_rcut(), + ) + nlist = build_neighbor_list( + ext_coords, + ext_atype, + natoms[0], + jax_obj.get_rcut(), + jax_obj.get_sel(), + distinguish_types=(not mixed_types), + ) + return [ + np.asarray(x) if isinstance(x, jnp.ndarray) else x + for x in jax_obj(ext_coords, ext_atype, nlist=nlist, mapping=mapping) + ] + + def eval_array_api_strict_descriptor( + self, + array_api_strict_obj: Any, + natoms, + coords, + atype, + box, + mixed_types: bool = False, + ) -> Any: + array_api_strict.set_array_api_strict_flags(api_version="2023.12") + ext_coords, ext_atype, mapping = extend_coord_with_ghosts( + array_api_strict.asarray(coords.reshape(1, -1, 3)), + array_api_strict.asarray(atype.reshape(1, -1)), + array_api_strict.asarray(box.reshape(1, 3, 3)), + array_api_strict_obj.get_rcut(), + ) + nlist = build_neighbor_list( + ext_coords, + ext_atype, + natoms[0], + array_api_strict_obj.get_rcut(), + array_api_strict_obj.get_sel(), + distinguish_types=(not mixed_types), + ) + return [ + np.asarray(x) if hasattr(x, "__array_namespace__") else x + for x in array_api_strict_obj( + ext_coords, ext_atype, nlist=nlist, mapping=mapping + ) + ] diff --git a/source/tests/consistent/descriptor/test_dpa1.py b/source/tests/consistent/descriptor/test_dpa1.py index 59d7369753..ed7884adb9 100644 --- a/source/tests/consistent/descriptor/test_dpa1.py +++ b/source/tests/consistent/descriptor/test_dpa1.py @@ -16,6 +16,8 @@ ) from ..common import ( + INSTALLED_ARRAY_API_STRICT, + INSTALLED_JAX, INSTALLED_PT, INSTALLED_TF, CommonTest, @@ -33,6 +35,14 @@ from deepmd.tf.descriptor.se_atten import DescrptDPA1Compat as DescrptDPA1TF else: DescrptDPA1TF = None +if INSTALLED_JAX: + from deepmd.jax.descriptor.dpa1 import DescrptDPA1 as DescriptorDPA1JAX +else: + DescriptorDPA1JAX = None +if INSTALLED_ARRAY_API_STRICT: + from ...array_api_strict.descriptor.dpa1 import DescrptDPA1 as DescriptorDPA1Strict +else: + DescriptorDPA1Strict = None from deepmd.utils.argcheck import ( descrpt_se_atten_args, ) @@ -183,6 +193,69 @@ def skip_dp(self) -> bool: temperature, ) + @property + def skip_jax(self) -> bool: + ( + tebd_dim, + tebd_input_mode, + resnet_dt, + type_one_side, + attn, + attn_layer, + attn_dotr, + excluded_types, + env_protection, + set_davg_zero, + scaling_factor, + normalize, + temperature, + ln_eps, + smooth_type_embedding, + concat_output_tebd, + precision, + use_econf_tebd, + use_tebd_bias, + ) = self.param + return not INSTALLED_JAX or self.is_meaningless_zero_attention_layer_tests( + attn_layer, + attn_dotr, + normalize, + temperature, + ) + + @property + def skip_array_api_strict(self) -> bool: + ( + tebd_dim, + tebd_input_mode, + resnet_dt, + type_one_side, + attn, + attn_layer, + attn_dotr, + excluded_types, + env_protection, + set_davg_zero, + scaling_factor, + normalize, + temperature, + ln_eps, + smooth_type_embedding, + concat_output_tebd, + precision, + use_econf_tebd, + use_tebd_bias, + ) = self.param + return ( + not INSTALLED_ARRAY_API_STRICT + or self.is_meaningless_zero_attention_layer_tests( + attn_layer, + attn_dotr, + normalize, + temperature, + ) + ) + @property def skip_tf(self) -> bool: ( @@ -226,6 +299,9 @@ def skip_tf(self) -> bool: tf_class = DescrptDPA1TF dp_class = DescrptDPA1DP pt_class = DescrptDPA1PT + jax_class = DescriptorDPA1JAX + array_api_strict_class = DescriptorDPA1Strict + args = descrpt_se_atten_args().append(Argument("ntypes", int, optional=False)) def setUp(self): @@ -313,6 +389,26 @@ def eval_pt(self, pt_obj: Any) -> Any: mixed_types=True, ) + def eval_jax(self, jax_obj: Any) -> Any: + return self.eval_jax_descriptor( + jax_obj, + self.natoms, + self.coords, + self.atype, + self.box, + mixed_types=True, + ) + + def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: + return self.eval_array_api_strict_descriptor( + array_api_strict_obj, + self.natoms, + self.coords, + self.atype, + self.box, + mixed_types=True, + ) + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: return (ret[0],) diff --git a/source/tests/consistent/test_type_embedding.py b/source/tests/consistent/test_type_embedding.py index 1464517581..e2836c7a6c 100644 --- a/source/tests/consistent/test_type_embedding.py +++ b/source/tests/consistent/test_type_embedding.py @@ -12,6 +12,7 @@ ) from .common import ( + INSTALLED_ARRAY_API_STRICT, INSTALLED_JAX, INSTALLED_PT, INSTALLED_TF, @@ -37,6 +38,10 @@ from deepmd.jax.utils.type_embed import TypeEmbedNet as TypeEmbedNetJAX else: TypeEmbedNetJAX = object +if INSTALLED_ARRAY_API_STRICT: + from ..array_api_strict.utils.type_embed import TypeEmbedNet as TypeEmbedNetStrict +else: + TypeEmbedNetStrict = None @parameterized( @@ -71,8 +76,10 @@ def data(self) -> dict: dp_class = TypeEmbedNetDP pt_class = TypeEmbedNetPT jax_class = TypeEmbedNetJAX + array_api_strict_class = TypeEmbedNetStrict args = type_embedding_args() skip_jax = not INSTALLED_JAX + skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT @property def addtional_data(self) -> dict: @@ -120,6 +127,12 @@ def eval_jax(self, jax_obj: Any) -> Any: raise ValueError("Output is numpy array") return [np.array(x) if isinstance(x, jnp.ndarray) else x for x in (out,)] + def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: + out = array_api_strict_obj() + return [ + np.asarray(x) if hasattr(x, "__array_namespace__") else x for x in (out,) + ] + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: return (ret[0],)