Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(jax/array-api): se_e3 #4286

Merged
merged 2 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 27 additions & 18 deletions deepmd/dpmodel/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,18 @@
Union,
)

import array_api_compat
import numpy as np

from deepmd.dpmodel import (
DEFAULT_PRECISION,
PRECISION_DICT,
NativeOP,
)
from deepmd.dpmodel.common import (
get_xp_precision,
to_numpy_array,
)
from deepmd.dpmodel.utils import (
EmbeddingNet,
EnvMat,
Expand All @@ -25,9 +30,6 @@
from deepmd.dpmodel.utils.update_sel import (
UpdateSel,
)
from deepmd.env import (
GLOBAL_NP_FLOAT_PRECISION,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
Expand Down Expand Up @@ -122,26 +124,28 @@ def __init__(
# order matters, placed after the assignment of self.ntypes
self.reinit_exclude(exclude_types)
self.trainable = trainable
self.sel_cumsum = [0, *np.cumsum(self.sel).tolist()]

in_dim = 1 # not considiering type embedding
self.embeddings = NetworkCollection(
embeddings = NetworkCollection(
ntypes=self.ntypes,
ndim=2,
network_type="embedding_network",
)
for ii, embedding_idx in enumerate(
itertools.product(range(self.ntypes), repeat=self.embeddings.ndim)
itertools.product(range(self.ntypes), repeat=embeddings.ndim)
):
self.embeddings[embedding_idx] = EmbeddingNet(
embeddings[embedding_idx] = EmbeddingNet(
in_dim,
self.neuron,
self.activation_function,
self.resnet_dt,
self.precision,
seed=child_seed(self.seed, ii),
)
self.embeddings = embeddings
self.env_mat = EnvMat(self.rcut, self.rcut_smth, protection=self.env_protection)
self.nnei = np.sum(self.sel)
self.nnei = sum(self.sel)
self.davg = np.zeros(
[self.ntypes, self.nnei, 4], dtype=PRECISION_DICT[self.precision]
)
Expand Down Expand Up @@ -299,20 +303,22 @@ def call(
The smooth switch function.
"""
del mapping
xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist)
# nf x nloc x nnei x 4
rr, diff, ww = self.env_mat.call(
coord_ext, atype_ext, nlist, self.davg, self.dstd
)
nf, nloc, nnei, _ = rr.shape
sec = np.append([0], np.cumsum(self.sel))
sec = self.sel_cumsum

ng = self.neuron[-1]
result = np.zeros([nf * nloc, ng], dtype=PRECISION_DICT[self.precision])
result = xp.zeros([nf * nloc, ng], dtype=get_xp_precision(xp, self.precision))
exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext)
# merge nf and nloc axis, so for type_one_side == False,
# we don't require atype is the same in all frames
exclude_mask = exclude_mask.reshape(nf * nloc, nnei)
rr = rr.reshape(nf * nloc, nnei, 4)
exclude_mask = xp.reshape(exclude_mask, (nf * nloc, nnei))
rr = xp.reshape(rr, (nf * nloc, nnei, 4))
rr = xp.astype(rr, get_xp_precision(xp, self.precision))

for embedding_idx in itertools.product(
range(self.ntypes), repeat=self.embeddings.ndim
Expand All @@ -325,23 +331,26 @@ def call(
# nfnl x nt_i x 3
rr_i = rr[:, sec[ti] : sec[ti + 1], 1:]
mm_i = exclude_mask[:, sec[ti] : sec[ti + 1]]
rr_i = rr_i * mm_i[:, :, None]
rr_i = rr_i * xp.astype(mm_i[:, :, None], rr_i.dtype)
# nfnl x nt_j x 3
rr_j = rr[:, sec[tj] : sec[tj + 1], 1:]
mm_j = exclude_mask[:, sec[tj] : sec[tj + 1]]
rr_j = rr_j * mm_j[:, :, None]
rr_j = rr_j * xp.astype(mm_j[:, :, None], rr_j.dtype)
# nfnl x nt_i x nt_j
env_ij = np.einsum("ijm,ikm->ijk", rr_i, rr_j)
# env_ij = np.einsum("ijm,ikm->ijk", rr_i, rr_j)
env_ij = xp.sum(rr_i[:, :, None, :] * rr_j[:, None, :, :], axis=-1)
njzjz marked this conversation as resolved.
Show resolved Hide resolved
# nfnl x nt_i x nt_j x 1
env_ij_reshape = env_ij[:, :, :, None]
# nfnl x nt_i x nt_j x ng
gg = self.embeddings[embedding_idx].call(env_ij_reshape)
# nfnl x nt_i x nt_j x ng
res_ij = np.einsum("ijk,ijkm->im", env_ij, gg)
# res_ij = np.einsum("ijk,ijkm->im", env_ij, gg)
res_ij = xp.sum(env_ij[:, :, :, None] * gg, axis=(1, 2))
njzjz marked this conversation as resolved.
Show resolved Hide resolved
res_ij = res_ij * (1.0 / float(nei_type_i) / float(nei_type_j))
result += res_ij
# nf x nloc x ng
result = result.reshape(nf, nloc, ng).astype(GLOBAL_NP_FLOAT_PRECISION)
result = xp.reshape(result, (nf, nloc, ng))
result = xp.astype(result, get_xp_precision(xp, "global"))
return result, None, None, None, ww

def serialize(self) -> dict:
Expand Down Expand Up @@ -369,8 +378,8 @@ def serialize(self) -> dict:
"exclude_types": self.exclude_types,
"env_protection": self.env_protection,
"@variables": {
"davg": self.davg,
"dstd": self.dstd,
"davg": to_numpy_array(self.davg),
"dstd": to_numpy_array(self.dstd),
},
"type_map": self.type_map,
"trainable": self.trainable,
Expand Down
7 changes: 4 additions & 3 deletions deepmd/dpmodel/utils/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,11 +572,12 @@ def call(self, x):
def clear(self):
"""Clear the network parameters to zero."""
for layer in self.layers:
layer.w.fill(0.0)
xp = array_api_compat.array_namespace(layer.w)
layer.w = xp.zeros_like(layer.w)
if layer.b is not None:
layer.b.fill(0.0)
layer.b = xp.zeros_like(layer.b)
if layer.idt is not None:
layer.idt.fill(0.0)
layer.idt = xp.zeros_like(layer.idt)

return NN

Expand Down
4 changes: 4 additions & 0 deletions deepmd/jax/descriptor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,13 @@
from deepmd.jax.descriptor.se_e2_r import (
DescrptSeR,
)
from deepmd.jax.descriptor.se_t import (
DescrptSeT,
)

__all__ = [
"DescrptSeA",
"DescrptSeR",
"DescrptSeT",
"DescrptDPA1",
]
42 changes: 42 additions & 0 deletions deepmd/jax/descriptor/se_t.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
)

from deepmd.dpmodel.descriptor.se_t import DescrptSeT as DescrptSeTDP
from deepmd.jax.common import (
ArrayAPIVariable,
flax_module,
to_jax_array,
)
from deepmd.jax.descriptor.base_descriptor import (
BaseDescriptor,
)
from deepmd.jax.utils.exclude_mask import (
PairExcludeMask,
)
from deepmd.jax.utils.network import (
NetworkCollection,
)


@BaseDescriptor.register("se_e3")
@BaseDescriptor.register("se_at")
@BaseDescriptor.register("se_a_3be")
@flax_module
class DescrptSeT(DescrptSeTDP):
def __setattr__(self, name: str, value: Any) -> None:
if name in {"dstd", "davg"}:
value = to_jax_array(value)
if value is not None:
value = ArrayAPIVariable(value)
elif name in {"embeddings"}:
if value is not None:
value = NetworkCollection.deserialize(value.serialize())
elif name == "env_mat":
# env_mat doesn't store any value
pass
njzjz marked this conversation as resolved.
Show resolved Hide resolved
elif name == "emask":
value = PairExcludeMask(value.ntypes, value.exclude_types)

return super().__setattr__(name, value)
32 changes: 32 additions & 0 deletions source/tests/array_api_strict/descriptor/se_t.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
)

from deepmd.dpmodel.descriptor.se_t import DescrptSeT as DescrptSeTDP

from ..common import (
to_array_api_strict_array,
)
from ..utils.exclude_mask import (
PairExcludeMask,
)
from ..utils.network import (
NetworkCollection,
)


class DescrptSeT(DescrptSeTDP):
def __setattr__(self, name: str, value: Any) -> None:
if name in {"dstd", "davg"}:
value = to_array_api_strict_array(value)
elif name in {"embeddings"}:
if value is not None:
value = NetworkCollection.deserialize(value.serialize())
elif name == "env_mat":
# env_mat doesn't store any value
pass
elif name == "emask":
value = PairExcludeMask(value.ntypes, value.exclude_types)

njzjz marked this conversation as resolved.
Show resolved Hide resolved
return super().__setattr__(name, value)
33 changes: 33 additions & 0 deletions source/tests/consistent/descriptor/test_se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
)

from ..common import (
INSTALLED_ARRAY_API_STRICT,
INSTALLED_JAX,
INSTALLED_PT,
INSTALLED_TF,
CommonTest,
Expand All @@ -29,6 +31,14 @@
from deepmd.tf.descriptor.se_t import DescrptSeT as DescrptSeTTF
else:
DescrptSeTTF = None
if INSTALLED_JAX:
from deepmd.jax.descriptor.se_t import DescrptSeT as DescrptSeTJAX
else:
DescrptSeTJAX = None
if INSTALLED_ARRAY_API_STRICT:
from ...array_api_strict.descriptor.se_t import DescrptSeT as DescrptSeTStrict
else:
DescrptSeTStrict = None
from deepmd.utils.argcheck import (
descrpt_se_t_args,
)
Expand Down Expand Up @@ -91,9 +101,14 @@ def skip_tf(self) -> bool:
) = self.param
return env_protection != 0.0 or excluded_types

skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT
skip_jax = not INSTALLED_JAX

tf_class = DescrptSeTTF
dp_class = DescrptSeTDP
pt_class = DescrptSeTPT
jax_class = DescrptSeTJAX
array_api_strict_class = DescrptSeTStrict
args = descrpt_se_t_args()

def setUp(self):
Expand Down Expand Up @@ -168,6 +183,24 @@ def eval_pt(self, pt_obj: Any) -> Any:
self.box,
)

def eval_jax(self, jax_obj: Any) -> Any:
return self.eval_jax_descriptor(
jax_obj,
self.natoms,
self.coords,
self.atype,
self.box,
)

def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
return self.eval_array_api_strict_descriptor(
array_api_strict_obj,
self.natoms,
self.coords,
self.atype,
self.box,
)

def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
return (ret[0],)

Expand Down
Loading