Skip to content

Commit

Permalink
feat(jax/array-api): se_e3
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Oct 30, 2024
1 parent d165fee commit de84a87
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 21 deletions.
45 changes: 27 additions & 18 deletions deepmd/dpmodel/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,18 @@
Union,
)

import array_api_compat
import numpy as np

from deepmd.dpmodel import (
DEFAULT_PRECISION,
PRECISION_DICT,
NativeOP,
)
from deepmd.dpmodel.common import (
get_xp_precision,
to_numpy_array,
)
from deepmd.dpmodel.utils import (
EmbeddingNet,
EnvMat,
Expand All @@ -25,9 +30,6 @@
from deepmd.dpmodel.utils.update_sel import (
UpdateSel,
)
from deepmd.env import (
GLOBAL_NP_FLOAT_PRECISION,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
Expand Down Expand Up @@ -122,26 +124,28 @@ def __init__(
# order matters, placed after the assignment of self.ntypes
self.reinit_exclude(exclude_types)
self.trainable = trainable
self.sel_cumsum = [0, *np.cumsum(self.sel).tolist()]

in_dim = 1 # not considiering type embedding
self.embeddings = NetworkCollection(
embeddings = NetworkCollection(
ntypes=self.ntypes,
ndim=2,
network_type="embedding_network",
)
for ii, embedding_idx in enumerate(
itertools.product(range(self.ntypes), repeat=self.embeddings.ndim)
itertools.product(range(self.ntypes), repeat=embeddings.ndim)
):
self.embeddings[embedding_idx] = EmbeddingNet(
embeddings[embedding_idx] = EmbeddingNet(
in_dim,
self.neuron,
self.activation_function,
self.resnet_dt,
self.precision,
seed=child_seed(self.seed, ii),
)
self.embeddings = embeddings
self.env_mat = EnvMat(self.rcut, self.rcut_smth, protection=self.env_protection)
self.nnei = np.sum(self.sel)
self.nnei = sum(self.sel)
self.davg = np.zeros(
[self.ntypes, self.nnei, 4], dtype=PRECISION_DICT[self.precision]
)
Expand Down Expand Up @@ -299,20 +303,22 @@ def call(
The smooth switch function.
"""
del mapping
xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist)
# nf x nloc x nnei x 4
rr, diff, ww = self.env_mat.call(
coord_ext, atype_ext, nlist, self.davg, self.dstd
)
nf, nloc, nnei, _ = rr.shape
sec = np.append([0], np.cumsum(self.sel))
sec = self.sel_cumsum

ng = self.neuron[-1]
result = np.zeros([nf * nloc, ng], dtype=PRECISION_DICT[self.precision])
result = xp.zeros([nf * nloc, ng], dtype=get_xp_precision(xp, self.precision))
exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext)
# merge nf and nloc axis, so for type_one_side == False,
# we don't require atype is the same in all frames
exclude_mask = exclude_mask.reshape(nf * nloc, nnei)
rr = rr.reshape(nf * nloc, nnei, 4)
exclude_mask = xp.reshape(exclude_mask, (nf * nloc, nnei))
rr = xp.reshape(rr, (nf * nloc, nnei, 4))
rr = xp.astype(rr, get_xp_precision(xp, self.precision))

for embedding_idx in itertools.product(
range(self.ntypes), repeat=self.embeddings.ndim
Expand All @@ -325,23 +331,26 @@ def call(
# nfnl x nt_i x 3
rr_i = rr[:, sec[ti] : sec[ti + 1], 1:]
mm_i = exclude_mask[:, sec[ti] : sec[ti + 1]]
rr_i = rr_i * mm_i[:, :, None]
rr_i = rr_i * xp.astype(mm_i[:, :, None], rr_i.dtype)
# nfnl x nt_j x 3
rr_j = rr[:, sec[tj] : sec[tj + 1], 1:]
mm_j = exclude_mask[:, sec[tj] : sec[tj + 1]]
rr_j = rr_j * mm_j[:, :, None]
rr_j = rr_j * xp.astype(mm_j[:, :, None], rr_j.dtype)
# nfnl x nt_i x nt_j
env_ij = np.einsum("ijm,ikm->ijk", rr_i, rr_j)
# env_ij = np.einsum("ijm,ikm->ijk", rr_i, rr_j)
env_ij = xp.sum(rr_i[:, :, None, :] * rr_j[:, None, :, :], axis=-1)
# nfnl x nt_i x nt_j x 1
env_ij_reshape = env_ij[:, :, :, None]
# nfnl x nt_i x nt_j x ng
gg = self.embeddings[embedding_idx].call(env_ij_reshape)
# nfnl x nt_i x nt_j x ng
res_ij = np.einsum("ijk,ijkm->im", env_ij, gg)
# res_ij = np.einsum("ijk,ijkm->im", env_ij, gg)
res_ij = xp.sum(env_ij[:, :, :, None] * gg, axis=(1, 2))
res_ij = res_ij * (1.0 / float(nei_type_i) / float(nei_type_j))
result += res_ij
# nf x nloc x ng
result = result.reshape(nf, nloc, ng).astype(GLOBAL_NP_FLOAT_PRECISION)
result = xp.reshape(result, (nf, nloc, ng))
result = xp.astype(result, get_xp_precision(xp, "global"))
return result, None, None, None, ww

def serialize(self) -> dict:
Expand Down Expand Up @@ -369,8 +378,8 @@ def serialize(self) -> dict:
"exclude_types": self.exclude_types,
"env_protection": self.env_protection,
"@variables": {
"davg": self.davg,
"dstd": self.dstd,
"davg": to_numpy_array(self.davg),
"dstd": to_numpy_array(self.dstd),
},
"type_map": self.type_map,
"trainable": self.trainable,
Expand Down
7 changes: 4 additions & 3 deletions deepmd/dpmodel/utils/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,11 +572,12 @@ def call(self, x):
def clear(self):
"""Clear the network parameters to zero."""
for layer in self.layers:
layer.w.fill(0.0)
xp = array_api_compat.array_namespace(layer.w)
layer.w = xp.zeros_like(layer.w)
if layer.b is not None:
layer.b.fill(0.0)
layer.b = xp.zeros_like(layer.b)
if layer.idt is not None:
layer.idt.fill(0.0)
layer.idt = xp.zeros_like(layer.idt)

return NN

Expand Down
4 changes: 4 additions & 0 deletions deepmd/jax/descriptor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,13 @@
from deepmd.jax.descriptor.se_e2_r import (
DescrptSeR,
)
from deepmd.jax.descriptor.se_t import (
DescrptSeT,
)

__all__ = [
"DescrptSeA",
"DescrptSeR",
"DescrptSeT",
"DescrptDPA1",
]
42 changes: 42 additions & 0 deletions deepmd/jax/descriptor/se_t.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
)

from deepmd.dpmodel.descriptor.se_t import DescrptSeT as DescrptSeTDP
from deepmd.jax.common import (
ArrayAPIVariable,
flax_module,
to_jax_array,
)
from deepmd.jax.descriptor.base_descriptor import (
BaseDescriptor,
)
from deepmd.jax.utils.exclude_mask import (
PairExcludeMask,
)
from deepmd.jax.utils.network import (
NetworkCollection,
)


@BaseDescriptor.register("se_e3")
@BaseDescriptor.register("se_at")
@BaseDescriptor.register("se_a_3be")
@flax_module
class DescrptSeT(DescrptSeTDP):
def __setattr__(self, name: str, value: Any) -> None:
if name in {"dstd", "davg"}:
value = to_jax_array(value)
if value is not None:
value = ArrayAPIVariable(value)
elif name in {"embeddings"}:
if value is not None:
value = NetworkCollection.deserialize(value.serialize())
elif name == "env_mat":
# env_mat doesn't store any value
pass
elif name == "emask":
value = PairExcludeMask(value.ntypes, value.exclude_types)

return super().__setattr__(name, value)
32 changes: 32 additions & 0 deletions source/tests/array_api_strict/descriptor/se_t.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
)

from deepmd.dpmodel.descriptor.se_t import DescrptSeT as DescrptSeTDP

from ..common import (
to_array_api_strict_array,
)
from ..utils.exclude_mask import (
PairExcludeMask,
)
from ..utils.network import (
NetworkCollection,
)


class DescrptSeT(DescrptSeTDP):
def __setattr__(self, name: str, value: Any) -> None:
if name in {"dstd", "davg"}:
value = to_array_api_strict_array(value)
elif name in {"embeddings"}:
if value is not None:
value = NetworkCollection.deserialize(value.serialize())
elif name == "env_mat":
# env_mat doesn't store any value
pass
elif name == "emask":
value = PairExcludeMask(value.ntypes, value.exclude_types)

return super().__setattr__(name, value)
33 changes: 33 additions & 0 deletions source/tests/consistent/descriptor/test_se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
)

from ..common import (
INSTALLED_ARRAY_API_STRICT,
INSTALLED_JAX,
INSTALLED_PT,
INSTALLED_TF,
CommonTest,
Expand All @@ -29,6 +31,14 @@
from deepmd.tf.descriptor.se_t import DescrptSeT as DescrptSeTTF
else:
DescrptSeTTF = None
if INSTALLED_JAX:
from deepmd.jax.descriptor.se_t import DescrptSeT as DescrptSeTJAX
else:
DescrptSeTJAX = None
if INSTALLED_ARRAY_API_STRICT:
from ...array_api_strict.descriptor.se_t import DescrptSeT as DescrptSeTStrict
else:
DescrptSeTStrict = None
from deepmd.utils.argcheck import (
descrpt_se_t_args,
)
Expand Down Expand Up @@ -91,9 +101,14 @@ def skip_tf(self) -> bool:
) = self.param
return env_protection != 0.0 or excluded_types

skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT
skip_jax = not INSTALLED_JAX

tf_class = DescrptSeTTF
dp_class = DescrptSeTDP
pt_class = DescrptSeTPT
jax_class = DescrptSeTJAX
array_api_strict_class = DescrptSeTStrict
args = descrpt_se_t_args()

def setUp(self):
Expand Down Expand Up @@ -168,6 +183,24 @@ def eval_pt(self, pt_obj: Any) -> Any:
self.box,
)

def eval_jax(self, jax_obj: Any) -> Any:
return self.eval_jax_descriptor(
jax_obj,
self.natoms,
self.coords,
self.atype,
self.box,
)

def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
return self.eval_array_api_strict_descriptor(
array_api_strict_obj,
self.natoms,
self.coords,
self.atype,
self.box,
)

def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
return (ret[0],)

Expand Down

0 comments on commit de84a87

Please sign in to comment.