diff --git a/deepmd/dpmodel/descriptor/se_t_tebd.py b/deepmd/dpmodel/descriptor/se_t_tebd.py index ca89c23968..298f823690 100644 --- a/deepmd/dpmodel/descriptor/se_t_tebd.py +++ b/deepmd/dpmodel/descriptor/se_t_tebd.py @@ -5,12 +5,20 @@ Union, ) +import array_api_compat import numpy as np from deepmd.dpmodel import ( PRECISION_DICT, NativeOP, ) +from deepmd.dpmodel.array_api import ( + xp_take_along_axis, +) +from deepmd.dpmodel.common import ( + get_xp_precision, + to_numpy_array, +) from deepmd.dpmodel.utils import ( EmbeddingNet, EnvMat, @@ -26,9 +34,6 @@ from deepmd.dpmodel.utils.update_sel import ( UpdateSel, ) -from deepmd.env import ( - GLOBAL_NP_FLOAT_PRECISION, -) from deepmd.utils.data_system import ( DeepmdDataSystem, ) @@ -318,11 +323,15 @@ def call( sw The smooth switch function. """ + xp = array_api_compat.array_namespace(nlist, coord_ext, atype_ext) del mapping nf, nloc, nnei = nlist.shape - nall = coord_ext.reshape(nf, -1).shape[1] // 3 + nall = xp.reshape(coord_ext, (nf, -1)).shape[1] // 3 # nf x nall x tebd_dim - atype_embd_ext = self.type_embedding.call()[atype_ext] + atype_embd_ext = xp.reshape( + xp.take(self.type_embedding.call(), xp.reshape(atype_ext, [-1]), axis=0), + (nf, nall, self.tebd_dim), + ) # nfnl x tebd_dim atype_embd = atype_embd_ext[:, :nloc, :] grrg, g2, h2, rot_mat, sw = self.se_ttebd( @@ -334,8 +343,8 @@ def call( ) # nf x nloc x (ng + tebd_dim) if self.concat_output_tebd: - grrg = np.concatenate( - [grrg, atype_embd.reshape(nf, nloc, self.tebd_dim)], axis=-1 + grrg = xp.concat( + [grrg, xp.reshape(atype_embd, (nf, nloc, self.tebd_dim))], axis=-1 ) return grrg, rot_mat, None, None, sw @@ -368,8 +377,8 @@ def serialize(self) -> dict: "env_protection": obj.env_protection, "smooth": self.smooth, "@variables": { - "davg": obj["davg"], - "dstd": obj["dstd"], + "davg": to_numpy_array(obj["davg"]), + "dstd": to_numpy_array(obj["dstd"]), }, "trainable": self.trainable, } @@ -491,12 +500,12 @@ def __init__( else: self.embd_input_dim = 1 - self.embeddings = NetworkCollection( + embeddings = NetworkCollection( ndim=0, ntypes=self.ntypes, network_type="embedding_network", ) - self.embeddings[0] = EmbeddingNet( + embeddings[0] = EmbeddingNet( self.embd_input_dim, self.neuron, self.activation_function, @@ -504,13 +513,14 @@ def __init__( self.precision, seed=child_seed(seed, 0), ) + self.embeddings = embeddings if self.tebd_input_mode in ["strip"]: - self.embeddings_strip = NetworkCollection( + embeddings_strip = NetworkCollection( ndim=0, ntypes=self.ntypes, network_type="embedding_network", ) - self.embeddings_strip[0] = EmbeddingNet( + embeddings_strip[0] = EmbeddingNet( self.tebd_dim_input, self.neuron, self.activation_function, @@ -518,6 +528,7 @@ def __init__( self.precision, seed=child_seed(seed, 1), ) + self.embeddings_strip = embeddings_strip else: self.embeddings_strip = None @@ -652,6 +663,7 @@ def call( atype_embd_ext: Optional[np.ndarray] = None, mapping: Optional[np.ndarray] = None, ): + xp = array_api_compat.array_namespace(nlist, coord_ext, atype_ext) # nf x nloc x nnei x 4 dmatrix, diff, sw = self.env_mat.call( coord_ext, atype_ext, nlist, self.mean, self.stddev @@ -659,47 +671,49 @@ def call( nf, nloc, nnei, _ = dmatrix.shape exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) # nfnl x nnei - exclude_mask = exclude_mask.reshape(nf * nloc, nnei) + exclude_mask = xp.reshape(exclude_mask, (nf * nloc, nnei)) # nfnl x nnei - nlist = nlist.reshape(nf * nloc, nnei) - nlist = np.where(exclude_mask, nlist, -1) + nlist = xp.reshape(nlist, (nf * nloc, nnei)) + nlist = xp.where(exclude_mask, nlist, xp.full_like(nlist, -1)) # nfnl x nnei nlist_mask = nlist != -1 # nfnl x nnei x 1 - sw = np.where(nlist_mask[:, :, None], sw.reshape(nf * nloc, nnei, 1), 0.0) + sw = xp.where( + nlist_mask[:, :, None], + xp.reshape(sw, (nf * nloc, nnei, 1)), + xp.zeros((nf * nloc, nnei, 1), dtype=sw.dtype), + ) # nfnl x nnei x 4 - dmatrix = dmatrix.reshape(nf * nloc, nnei, 4) + dmatrix = xp.reshape(dmatrix, (nf * nloc, nnei, 4)) # nfnl x nnei x 4 rr = dmatrix - rr = rr * exclude_mask[:, :, None] + rr = rr * xp.astype(exclude_mask[:, :, None], rr.dtype) # nfnl x nt_i x 3 rr_i = rr[:, :, 1:] # nfnl x nt_j x 3 rr_j = rr[:, :, 1:] # nfnl x nt_i x nt_j - env_ij = np.einsum("ijm,ikm->ijk", rr_i, rr_j) + # env_ij = np.einsum("ijm,ikm->ijk", rr_i, rr_j) + env_ij = xp.sum(rr_i[:, :, None, :] * rr_j[:, None, :, :], axis=-1) # nfnl x nt_i x nt_j x 1 - ss = np.expand_dims(env_ij, axis=-1) + ss = env_ij[..., None] - nlist_masked = np.where(nlist_mask, nlist, 0) - index = np.tile(nlist_masked.reshape(nf, -1, 1), (1, 1, self.tebd_dim)) + nlist_masked = xp.where(nlist_mask, nlist, xp.zeros_like(nlist)) + index = xp.tile(xp.reshape(nlist_masked, (nf, -1, 1)), (1, 1, self.tebd_dim)) # nfnl x nnei x tebd_dim - atype_embd_nlist = np.take_along_axis(atype_embd_ext, index, axis=1).reshape( - nf * nloc, nnei, self.tebd_dim + atype_embd_nlist = xp_take_along_axis(atype_embd_ext, index, axis=1) + atype_embd_nlist = xp.reshape( + atype_embd_nlist, (nf * nloc, nnei, self.tebd_dim) ) # nfnl x nt_i x nt_j x tebd_dim - nlist_tebd_i = np.tile( - np.expand_dims(atype_embd_nlist, axis=2), [1, 1, self.nnei, 1] - ) - nlist_tebd_j = np.tile( - np.expand_dims(atype_embd_nlist, axis=1), [1, self.nnei, 1, 1] - ) + nlist_tebd_i = xp.tile(atype_embd_nlist[:, :, None, :], (1, 1, self.nnei, 1)) + nlist_tebd_j = xp.tile(atype_embd_nlist[:, None, :, :], (1, self.nnei, 1, 1)) ng = self.neuron[-1] if self.tebd_input_mode in ["concat"]: # nfnl x nt_i x nt_j x (1 + tebd_dim * 2) - ss = np.concatenate([ss, nlist_tebd_i, nlist_tebd_j], axis=-1) + ss = xp.concat([ss, nlist_tebd_i, nlist_tebd_j], axis=-1) # nfnl x nt_i x nt_j x ng gg = self.cal_g(ss, 0) elif self.tebd_input_mode in ["strip"]: @@ -707,14 +721,14 @@ def call( gg_s = self.cal_g(ss, 0) assert self.embeddings_strip is not None # nfnl x nt_i x nt_j x (tebd_dim * 2) - tt = np.concatenate([nlist_tebd_i, nlist_tebd_j], axis=-1) + tt = xp.concat([nlist_tebd_i, nlist_tebd_j], axis=-1) # nfnl x nt_i x nt_j x ng gg_t = self.cal_g_strip(tt, 0) if self.smooth: gg_t = ( gg_t - * sw.reshape(nf * nloc, self.nnei, 1, 1) - * sw.reshape(nf * nloc, 1, self.nnei, 1) + * xp.reshape(sw, (nf * nloc, self.nnei, 1, 1)) + * xp.reshape(sw, (nf * nloc, 1, self.nnei, 1)) ) # nfnl x nt_i x nt_j x ng gg = gg_s * gg_t + gg_s @@ -722,12 +736,12 @@ def call( raise NotImplementedError # nfnl x ng - res_ij = np.einsum("ijk,ijkm->im", env_ij, gg) + # res_ij = np.einsum("ijk,ijkm->im", env_ij, gg) + res_ij = xp.sum(env_ij[:, :, :, None] * gg[:, :, :, :], axis=(1, 2)) res_ij = res_ij * (1.0 / float(self.nnei) / float(self.nnei)) # nf x nl x ng - result = res_ij.reshape(nf, nloc, self.filter_neuron[-1]).astype( - GLOBAL_NP_FLOAT_PRECISION - ) + result = xp.reshape(res_ij, (nf, nloc, self.filter_neuron[-1])) + result = xp.astype(result, get_xp_precision(xp, "global")) return ( result, None, @@ -743,3 +757,61 @@ def has_message_passing(self) -> bool: def need_sorted_nlist_for_lower(self) -> bool: """Returns whether the descriptor block needs sorted nlist when using `forward_lower`.""" return False + + def serialize(self) -> dict: + """Serialize the descriptor to dict.""" + obj = self + data = { + "@class": "Descriptor", + "type": "se_e3_tebd", + "@version": 1, + "rcut": obj.rcut, + "rcut_smth": obj.rcut_smth, + "sel": obj.sel, + "ntypes": obj.ntypes, + "neuron": obj.neuron, + "tebd_dim": obj.tebd_dim, + "tebd_input_mode": obj.tebd_input_mode, + "set_davg_zero": obj.set_davg_zero, + "activation_function": obj.activation_function, + "resnet_dt": obj.resnet_dt, + # make deterministic + "precision": np.dtype(PRECISION_DICT[obj.precision]).name, + "embeddings": obj.embeddings.serialize(), + "env_mat": obj.env_mat.serialize(), + "exclude_types": obj.exclude_types, + "env_protection": obj.env_protection, + "smooth": obj.smooth, + "@variables": { + "davg": to_numpy_array(obj["davg"]), + "dstd": to_numpy_array(obj["dstd"]), + }, + } + if obj.tebd_input_mode in ["strip"]: + data.update({"embeddings_strip": obj.embeddings_strip.serialize()}) + return data + + @classmethod + def deserialize(cls, data: dict) -> "DescrptSeTTebd": + """Deserialize from dict.""" + data = data.copy() + check_version_compatibility(data.pop("@version"), 1, 1) + data.pop("@class") + data.pop("type") + variables = data.pop("@variables") + embeddings = data.pop("embeddings") + env_mat = data.pop("env_mat") + tebd_input_mode = data["tebd_input_mode"] + if tebd_input_mode in ["strip"]: + embeddings_strip = data.pop("embeddings_strip") + else: + embeddings_strip = None + se_ttebd = cls(**data) + + se_ttebd["davg"] = variables["davg"] + se_ttebd["dstd"] = variables["dstd"] + se_ttebd.embeddings = NetworkCollection.deserialize(embeddings) + if tebd_input_mode in ["strip"]: + se_ttebd.embeddings_strip = NetworkCollection.deserialize(embeddings_strip) + + return se_ttebd diff --git a/deepmd/jax/descriptor/se_t_tebd.py b/deepmd/jax/descriptor/se_t_tebd.py new file mode 100644 index 0000000000..84e3d3f084 --- /dev/null +++ b/deepmd/jax/descriptor/se_t_tebd.py @@ -0,0 +1,56 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.descriptor.se_t_tebd import ( + DescrptBlockSeTTebd as DescrptBlockSeTTebdDP, +) +from deepmd.dpmodel.descriptor.se_t_tebd import DescrptSeTTebd as DescrptSeTTebdDP +from deepmd.jax.common import ( + ArrayAPIVariable, + flax_module, + to_jax_array, +) +from deepmd.jax.descriptor.base_descriptor import ( + BaseDescriptor, +) +from deepmd.jax.utils.exclude_mask import ( + PairExcludeMask, +) +from deepmd.jax.utils.network import ( + NetworkCollection, +) +from deepmd.jax.utils.type_embed import ( + TypeEmbedNet, +) + + +@flax_module +class DescrptBlockSeTTebd(DescrptBlockSeTTebdDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"mean", "stddev"}: + value = to_jax_array(value) + if value is not None: + value = ArrayAPIVariable(value) + elif name in {"embeddings", "embeddings_strip"}: + if value is not None: + value = NetworkCollection.deserialize(value.serialize()) + elif name == "env_mat": + # env_mat doesn't store any value + pass + elif name == "emask": + value = PairExcludeMask(value.ntypes, value.exclude_types) + + return super().__setattr__(name, value) + + +@BaseDescriptor.register("se_e3_tebd") +@flax_module +class DescrptSeTTebd(DescrptSeTTebdDP): + def __setattr__(self, name: str, value: Any) -> None: + if name == "se_ttebd": + value = DescrptBlockSeTTebd.deserialize(value.serialize()) + elif name == "type_embedding": + value = TypeEmbedNet.deserialize(value.serialize()) + return super().__setattr__(name, value) diff --git a/doc/model/train-se-e3-tebd.md b/doc/model/train-se-e3-tebd.md index 5935a8920a..49d0d80f42 100644 --- a/doc/model/train-se-e3-tebd.md +++ b/doc/model/train-se-e3-tebd.md @@ -1,7 +1,7 @@ -# Descriptor `"se_e3_tebd"` {{ pytorch_icon }} {{ dpmodel_icon }} +# Descriptor `"se_e3_tebd"` {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }} :::{note} -**Supported backends**: PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }} +**Supported backends**: PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }} ::: The notation of `se_e3_tebd` is short for the three-body embedding descriptor with type embeddings, where the notation `se` denotes the Deep Potential Smooth Edition (DeepPot-SE). diff --git a/source/tests/array_api_strict/descriptor/se_t_tebd.py b/source/tests/array_api_strict/descriptor/se_t_tebd.py new file mode 100644 index 0000000000..12fc04e69e --- /dev/null +++ b/source/tests/array_api_strict/descriptor/se_t_tebd.py @@ -0,0 +1,47 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.descriptor.se_t_tebd import ( + DescrptBlockSeTTebd as DescrptBlockSeTTebdDP, +) +from deepmd.dpmodel.descriptor.se_t_tebd import DescrptSeTTebd as DescrptSeTTebdDP + +from ..common import ( + to_array_api_strict_array, +) +from ..utils.exclude_mask import ( + PairExcludeMask, +) +from ..utils.network import ( + NetworkCollection, +) +from ..utils.type_embed import ( + TypeEmbedNet, +) + + +class DescrptBlockSeTTebd(DescrptBlockSeTTebdDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"mean", "stddev"}: + value = to_array_api_strict_array(value) + elif name in {"embeddings", "embeddings_strip"}: + if value is not None: + value = NetworkCollection.deserialize(value.serialize()) + elif name == "env_mat": + # env_mat doesn't store any value + pass + elif name == "emask": + value = PairExcludeMask(value.ntypes, value.exclude_types) + + return super().__setattr__(name, value) + + +class DescrptSeTTebd(DescrptSeTTebdDP): + def __setattr__(self, name: str, value: Any) -> None: + if name == "se_ttebd": + value = DescrptBlockSeTTebd.deserialize(value.serialize()) + elif name == "type_embedding": + value = TypeEmbedNet.deserialize(value.serialize()) + return super().__setattr__(name, value) diff --git a/source/tests/consistent/descriptor/test_se_t_tebd.py b/source/tests/consistent/descriptor/test_se_t_tebd.py index 3299a04c78..4712c28e53 100644 --- a/source/tests/consistent/descriptor/test_se_t_tebd.py +++ b/source/tests/consistent/descriptor/test_se_t_tebd.py @@ -15,6 +15,8 @@ ) from ..common import ( + INSTALLED_ARRAY_API_STRICT, + INSTALLED_JAX, INSTALLED_PT, CommonTest, parameterized, @@ -28,6 +30,16 @@ else: DescrptSeTTebdPT = None DescrptSeTTebdTF = None +if INSTALLED_JAX: + from deepmd.jax.descriptor.se_t_tebd import DescrptSeTTebd as DescrptSeTTebdJAX +else: + DescrptSeTTebdJAX = None +if INSTALLED_ARRAY_API_STRICT: + from ...array_api_strict.descriptor.se_t_tebd import ( + DescrptSeTTebd as DescrptSeTTebdStrict, + ) +else: + DescrptSeTTebdStrict = None from deepmd.utils.argcheck import ( descrpt_se_e3_tebd_args, ) @@ -134,9 +146,14 @@ def skip_tf(self) -> bool: ) = self.param return True + skip_jax = not INSTALLED_JAX + skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT + tf_class = DescrptSeTTebdTF dp_class = DescrptSeTTebdDP pt_class = DescrptSeTTebdPT + jax_class = DescrptSeTTebdJAX + array_api_strict_class = DescrptSeTTebdStrict args = descrpt_se_e3_tebd_args().append(Argument("ntypes", int, optional=False)) def setUp(self): @@ -216,6 +233,26 @@ def eval_pt(self, pt_obj: Any) -> Any: mixed_types=True, ) + def eval_jax(self, jax_obj: Any) -> Any: + return self.eval_jax_descriptor( + jax_obj, + self.natoms, + self.coords, + self.atype, + self.box, + mixed_types=True, + ) + + def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: + return self.eval_array_api_strict_descriptor( + array_api_strict_obj, + self.natoms, + self.coords, + self.atype, + self.box, + mixed_types=True, + ) + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: return (ret[0],)