From de84a876a7e544a5b65ce7a10c6af55342ae6aef Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 30 Oct 2024 17:40:48 -0400 Subject: [PATCH] feat(jax/array-api): se_e3 Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/descriptor/se_t.py | 45 +++++++++++-------- deepmd/dpmodel/utils/network.py | 7 +-- deepmd/jax/descriptor/__init__.py | 4 ++ deepmd/jax/descriptor/se_t.py | 42 +++++++++++++++++ .../tests/array_api_strict/descriptor/se_t.py | 32 +++++++++++++ .../tests/consistent/descriptor/test_se_t.py | 33 ++++++++++++++ 6 files changed, 142 insertions(+), 21 deletions(-) create mode 100644 deepmd/jax/descriptor/se_t.py create mode 100644 source/tests/array_api_strict/descriptor/se_t.py diff --git a/deepmd/dpmodel/descriptor/se_t.py b/deepmd/dpmodel/descriptor/se_t.py index 4dc4c965fb..5bc5970a87 100644 --- a/deepmd/dpmodel/descriptor/se_t.py +++ b/deepmd/dpmodel/descriptor/se_t.py @@ -6,6 +6,7 @@ Union, ) +import array_api_compat import numpy as np from deepmd.dpmodel import ( @@ -13,6 +14,10 @@ PRECISION_DICT, NativeOP, ) +from deepmd.dpmodel.common import ( + get_xp_precision, + to_numpy_array, +) from deepmd.dpmodel.utils import ( EmbeddingNet, EnvMat, @@ -25,9 +30,6 @@ from deepmd.dpmodel.utils.update_sel import ( UpdateSel, ) -from deepmd.env import ( - GLOBAL_NP_FLOAT_PRECISION, -) from deepmd.utils.data_system import ( DeepmdDataSystem, ) @@ -122,17 +124,18 @@ def __init__( # order matters, placed after the assignment of self.ntypes self.reinit_exclude(exclude_types) self.trainable = trainable + self.sel_cumsum = [0, *np.cumsum(self.sel).tolist()] in_dim = 1 # not considiering type embedding - self.embeddings = NetworkCollection( + embeddings = NetworkCollection( ntypes=self.ntypes, ndim=2, network_type="embedding_network", ) for ii, embedding_idx in enumerate( - itertools.product(range(self.ntypes), repeat=self.embeddings.ndim) + itertools.product(range(self.ntypes), repeat=embeddings.ndim) ): - self.embeddings[embedding_idx] = EmbeddingNet( + embeddings[embedding_idx] = EmbeddingNet( in_dim, self.neuron, self.activation_function, @@ -140,8 +143,9 @@ def __init__( self.precision, seed=child_seed(self.seed, ii), ) + self.embeddings = embeddings self.env_mat = EnvMat(self.rcut, self.rcut_smth, protection=self.env_protection) - self.nnei = np.sum(self.sel) + self.nnei = sum(self.sel) self.davg = np.zeros( [self.ntypes, self.nnei, 4], dtype=PRECISION_DICT[self.precision] ) @@ -299,20 +303,22 @@ def call( The smooth switch function. """ del mapping + xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist) # nf x nloc x nnei x 4 rr, diff, ww = self.env_mat.call( coord_ext, atype_ext, nlist, self.davg, self.dstd ) nf, nloc, nnei, _ = rr.shape - sec = np.append([0], np.cumsum(self.sel)) + sec = self.sel_cumsum ng = self.neuron[-1] - result = np.zeros([nf * nloc, ng], dtype=PRECISION_DICT[self.precision]) + result = xp.zeros([nf * nloc, ng], dtype=get_xp_precision(xp, self.precision)) exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) # merge nf and nloc axis, so for type_one_side == False, # we don't require atype is the same in all frames - exclude_mask = exclude_mask.reshape(nf * nloc, nnei) - rr = rr.reshape(nf * nloc, nnei, 4) + exclude_mask = xp.reshape(exclude_mask, (nf * nloc, nnei)) + rr = xp.reshape(rr, (nf * nloc, nnei, 4)) + rr = xp.astype(rr, get_xp_precision(xp, self.precision)) for embedding_idx in itertools.product( range(self.ntypes), repeat=self.embeddings.ndim @@ -325,23 +331,26 @@ def call( # nfnl x nt_i x 3 rr_i = rr[:, sec[ti] : sec[ti + 1], 1:] mm_i = exclude_mask[:, sec[ti] : sec[ti + 1]] - rr_i = rr_i * mm_i[:, :, None] + rr_i = rr_i * xp.astype(mm_i[:, :, None], rr_i.dtype) # nfnl x nt_j x 3 rr_j = rr[:, sec[tj] : sec[tj + 1], 1:] mm_j = exclude_mask[:, sec[tj] : sec[tj + 1]] - rr_j = rr_j * mm_j[:, :, None] + rr_j = rr_j * xp.astype(mm_j[:, :, None], rr_j.dtype) # nfnl x nt_i x nt_j - env_ij = np.einsum("ijm,ikm->ijk", rr_i, rr_j) + # env_ij = np.einsum("ijm,ikm->ijk", rr_i, rr_j) + env_ij = xp.sum(rr_i[:, :, None, :] * rr_j[:, None, :, :], axis=-1) # nfnl x nt_i x nt_j x 1 env_ij_reshape = env_ij[:, :, :, None] # nfnl x nt_i x nt_j x ng gg = self.embeddings[embedding_idx].call(env_ij_reshape) # nfnl x nt_i x nt_j x ng - res_ij = np.einsum("ijk,ijkm->im", env_ij, gg) + # res_ij = np.einsum("ijk,ijkm->im", env_ij, gg) + res_ij = xp.sum(env_ij[:, :, :, None] * gg, axis=(1, 2)) res_ij = res_ij * (1.0 / float(nei_type_i) / float(nei_type_j)) result += res_ij # nf x nloc x ng - result = result.reshape(nf, nloc, ng).astype(GLOBAL_NP_FLOAT_PRECISION) + result = xp.reshape(result, (nf, nloc, ng)) + result = xp.astype(result, get_xp_precision(xp, "global")) return result, None, None, None, ww def serialize(self) -> dict: @@ -369,8 +378,8 @@ def serialize(self) -> dict: "exclude_types": self.exclude_types, "env_protection": self.env_protection, "@variables": { - "davg": self.davg, - "dstd": self.dstd, + "davg": to_numpy_array(self.davg), + "dstd": to_numpy_array(self.dstd), }, "type_map": self.type_map, "trainable": self.trainable, diff --git a/deepmd/dpmodel/utils/network.py b/deepmd/dpmodel/utils/network.py index 5140a88c97..a81ddb69a6 100644 --- a/deepmd/dpmodel/utils/network.py +++ b/deepmd/dpmodel/utils/network.py @@ -572,11 +572,12 @@ def call(self, x): def clear(self): """Clear the network parameters to zero.""" for layer in self.layers: - layer.w.fill(0.0) + xp = array_api_compat.array_namespace(layer.w) + layer.w = xp.zeros_like(layer.w) if layer.b is not None: - layer.b.fill(0.0) + layer.b = xp.zeros_like(layer.b) if layer.idt is not None: - layer.idt.fill(0.0) + layer.idt = xp.zeros_like(layer.idt) return NN diff --git a/deepmd/jax/descriptor/__init__.py b/deepmd/jax/descriptor/__init__.py index 3ed096f9c1..c1f0d9bc22 100644 --- a/deepmd/jax/descriptor/__init__.py +++ b/deepmd/jax/descriptor/__init__.py @@ -8,9 +8,13 @@ from deepmd.jax.descriptor.se_e2_r import ( DescrptSeR, ) +from deepmd.jax.descriptor.se_t import ( + DescrptSeT, +) __all__ = [ "DescrptSeA", "DescrptSeR", + "DescrptSeT", "DescrptDPA1", ] diff --git a/deepmd/jax/descriptor/se_t.py b/deepmd/jax/descriptor/se_t.py new file mode 100644 index 0000000000..029f4231fe --- /dev/null +++ b/deepmd/jax/descriptor/se_t.py @@ -0,0 +1,42 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.descriptor.se_t import DescrptSeT as DescrptSeTDP +from deepmd.jax.common import ( + ArrayAPIVariable, + flax_module, + to_jax_array, +) +from deepmd.jax.descriptor.base_descriptor import ( + BaseDescriptor, +) +from deepmd.jax.utils.exclude_mask import ( + PairExcludeMask, +) +from deepmd.jax.utils.network import ( + NetworkCollection, +) + + +@BaseDescriptor.register("se_e3") +@BaseDescriptor.register("se_at") +@BaseDescriptor.register("se_a_3be") +@flax_module +class DescrptSeT(DescrptSeTDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"dstd", "davg"}: + value = to_jax_array(value) + if value is not None: + value = ArrayAPIVariable(value) + elif name in {"embeddings"}: + if value is not None: + value = NetworkCollection.deserialize(value.serialize()) + elif name == "env_mat": + # env_mat doesn't store any value + pass + elif name == "emask": + value = PairExcludeMask(value.ntypes, value.exclude_types) + + return super().__setattr__(name, value) diff --git a/source/tests/array_api_strict/descriptor/se_t.py b/source/tests/array_api_strict/descriptor/se_t.py new file mode 100644 index 0000000000..13e650aa17 --- /dev/null +++ b/source/tests/array_api_strict/descriptor/se_t.py @@ -0,0 +1,32 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.descriptor.se_t import DescrptSeT as DescrptSeTDP + +from ..common import ( + to_array_api_strict_array, +) +from ..utils.exclude_mask import ( + PairExcludeMask, +) +from ..utils.network import ( + NetworkCollection, +) + + +class DescrptSeT(DescrptSeTDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"dstd", "davg"}: + value = to_array_api_strict_array(value) + elif name in {"embeddings"}: + if value is not None: + value = NetworkCollection.deserialize(value.serialize()) + elif name == "env_mat": + # env_mat doesn't store any value + pass + elif name == "emask": + value = PairExcludeMask(value.ntypes, value.exclude_types) + + return super().__setattr__(name, value) diff --git a/source/tests/consistent/descriptor/test_se_t.py b/source/tests/consistent/descriptor/test_se_t.py index 833b76f6e1..1e6110705a 100644 --- a/source/tests/consistent/descriptor/test_se_t.py +++ b/source/tests/consistent/descriptor/test_se_t.py @@ -12,6 +12,8 @@ ) from ..common import ( + INSTALLED_ARRAY_API_STRICT, + INSTALLED_JAX, INSTALLED_PT, INSTALLED_TF, CommonTest, @@ -29,6 +31,14 @@ from deepmd.tf.descriptor.se_t import DescrptSeT as DescrptSeTTF else: DescrptSeTTF = None +if INSTALLED_JAX: + from deepmd.jax.descriptor.se_t import DescrptSeT as DescrptSeTJAX +else: + DescrptSeTJAX = None +if INSTALLED_ARRAY_API_STRICT: + from ...array_api_strict.descriptor.se_t import DescrptSeT as DescrptSeTStrict +else: + DescrptSeTStrict = None from deepmd.utils.argcheck import ( descrpt_se_t_args, ) @@ -91,9 +101,14 @@ def skip_tf(self) -> bool: ) = self.param return env_protection != 0.0 or excluded_types + skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT + skip_jax = not INSTALLED_JAX + tf_class = DescrptSeTTF dp_class = DescrptSeTDP pt_class = DescrptSeTPT + jax_class = DescrptSeTJAX + array_api_strict_class = DescrptSeTStrict args = descrpt_se_t_args() def setUp(self): @@ -168,6 +183,24 @@ def eval_pt(self, pt_obj: Any) -> Any: self.box, ) + def eval_jax(self, jax_obj: Any) -> Any: + return self.eval_jax_descriptor( + jax_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: + return self.eval_array_api_strict_descriptor( + array_api_strict_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: return (ret[0],)