feat(jax/array-api): se_t_tebd #4288

Merged · 1 commit · Nov 2, 2024
150 changes: 111 additions & 39 deletions deepmd/dpmodel/descriptor/se_t_tebd.py
@@ -5,12 +5,20 @@
Union,
)

import array_api_compat
import numpy as np

from deepmd.dpmodel import (
PRECISION_DICT,
NativeOP,
)
from deepmd.dpmodel.array_api import (
xp_take_along_axis,
)
from deepmd.dpmodel.common import (
get_xp_precision,
to_numpy_array,
)
from deepmd.dpmodel.utils import (
EmbeddingNet,
EnvMat,
@@ -26,9 +34,6 @@
from deepmd.dpmodel.utils.update_sel import (
UpdateSel,
)
from deepmd.env import (
GLOBAL_NP_FLOAT_PRECISION,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
@@ -318,11 +323,15 @@
sw
The smooth switch function.
"""
xp = array_api_compat.array_namespace(nlist, coord_ext, atype_ext)
del mapping
nf, nloc, nnei = nlist.shape
nall = coord_ext.reshape(nf, -1).shape[1] // 3
nall = xp.reshape(coord_ext, (nf, -1)).shape[1] // 3
# nf x nall x tebd_dim
atype_embd_ext = self.type_embedding.call()[atype_ext]
atype_embd_ext = xp.reshape(
xp.take(self.type_embedding.call(), xp.reshape(atype_ext, [-1]), axis=0),
(nf, nall, self.tebd_dim),
)
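# Editor's note: xp.take on a flattened index replaces the removed NumPy
# fancy indexing (self.type_embedding.call()[atype_ext]); integer-array
# indexing is not guaranteed by the array API standard, while xp.take is.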
# nfnl x tebd_dim
atype_embd = atype_embd_ext[:, :nloc, :]
grrg, g2, h2, rot_mat, sw = self.se_ttebd(
@@ -334,8 +343,8 @@
)
# nf x nloc x (ng + tebd_dim)
if self.concat_output_tebd:
grrg = np.concatenate(
[grrg, atype_embd.reshape(nf, nloc, self.tebd_dim)], axis=-1
grrg = xp.concat(
[grrg, xp.reshape(atype_embd, (nf, nloc, self.tebd_dim))], axis=-1
)
return grrg, rot_mat, None, None, sw
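
A minimal sketch of the namespace-dispatch pattern introduced above, assuming only that array-api-compat is installed; the function name and shapes are illustrative and not part of this PR:

    import array_api_compat
    import numpy as np

    def concat_tebd(grrg, atype_embd, nf, nloc, tebd_dim):
        # Resolve the array namespace from the inputs, so one code path
        # serves NumPy, JAX, and array-api-strict arrays alike.
        xp = array_api_compat.array_namespace(grrg, atype_embd)
        return xp.concat(
            [grrg, xp.reshape(atype_embd, (nf, nloc, tebd_dim))], axis=-1
        )

    out = concat_tebd(np.ones((1, 2, 5)), np.ones((1, 2, 3)), 1, 2, 3)
    print(out.shape)  # (1, 2, 8)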

@@ -368,8 +377,8 @@
"env_protection": obj.env_protection,
"smooth": self.smooth,
"@variables": {
"davg": obj["davg"],
"dstd": obj["dstd"],
"davg": to_numpy_array(obj["davg"]),
"dstd": to_numpy_array(obj["dstd"]),
},
"trainable": self.trainable,
}
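
Converting "@variables" with to_numpy_array keeps the serialized dict backend-neutral: a dict holding JAX arrays could not be reloaded by other backends. A hedged sketch of the idea; the real helper lives in deepmd.dpmodel.common and may differ:

    import numpy as np

    def to_numpy_array(x):
        # np.asarray handles NumPy inputs trivially and JAX arrays via
        # the __array__ protocol; None passes through for unset stats.
        return None if x is None else np.asarray(x)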
@@ -491,33 +500,35 @@
else:
self.embd_input_dim = 1

self.embeddings = NetworkCollection(
embeddings = NetworkCollection(
ndim=0,
ntypes=self.ntypes,
network_type="embedding_network",
)
self.embeddings[0] = EmbeddingNet(
embeddings[0] = EmbeddingNet(
self.embd_input_dim,
self.neuron,
self.activation_function,
self.resnet_dt,
self.precision,
seed=child_seed(seed, 0),
)
self.embeddings = embeddings
if self.tebd_input_mode in ["strip"]:
self.embeddings_strip = NetworkCollection(
embeddings_strip = NetworkCollection(
ndim=0,
ntypes=self.ntypes,
network_type="embedding_network",
)
self.embeddings_strip[0] = EmbeddingNet(
embeddings_strip[0] = EmbeddingNet(
self.tebd_dim_input,
self.neuron,
self.activation_function,
self.resnet_dt,
self.precision,
seed=child_seed(seed, 1),
)
self.embeddings_strip = embeddings_strip
else:
self.embeddings_strip = None
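
Populating each NetworkCollection in a local variable and assigning it to self once appears deliberate: the JAX and array-api-strict subclasses added below convert attributes inside __setattr__, so the collection must be complete when the assignment fires. A toy illustration of the pitfall; the class and attribute names are mine, not the PR's:

    class Converting:
        def __setattr__(self, name, value):
            if name == "nets":
                value = tuple(value)  # conversion happens at assignment time
            super().__setattr__(name, value)

    obj = Converting()
    nets = [1]
    obj.nets = nets   # a converted copy is stored
    nets.append(2)    # later mutation of the local never reaches obj.nets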

@@ -652,82 +663,85 @@
atype_embd_ext: Optional[np.ndarray] = None,
mapping: Optional[np.ndarray] = None,
):
xp = array_api_compat.array_namespace(nlist, coord_ext, atype_ext)
# nf x nloc x nnei x 4
dmatrix, diff, sw = self.env_mat.call(
coord_ext, atype_ext, nlist, self.mean, self.stddev
)
nf, nloc, nnei, _ = dmatrix.shape
exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext)
# nfnl x nnei
exclude_mask = exclude_mask.reshape(nf * nloc, nnei)
exclude_mask = xp.reshape(exclude_mask, (nf * nloc, nnei))
# nfnl x nnei
nlist = nlist.reshape(nf * nloc, nnei)
nlist = np.where(exclude_mask, nlist, -1)
nlist = xp.reshape(nlist, (nf * nloc, nnei))
nlist = xp.where(exclude_mask, nlist, xp.full_like(nlist, -1))
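# Editor's note: the array API's where() takes arrays, not Python scalars,
# for both branches; hence xp.full_like(nlist, -1) here and the explicit
# xp.zeros(...) below, where the NumPy version used bare -1 and 0.0.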
# nfnl x nnei
nlist_mask = nlist != -1
# nfnl x nnei x 1
sw = np.where(nlist_mask[:, :, None], sw.reshape(nf * nloc, nnei, 1), 0.0)
sw = xp.where(
nlist_mask[:, :, None],
xp.reshape(sw, (nf * nloc, nnei, 1)),
xp.zeros((nf * nloc, nnei, 1), dtype=sw.dtype),
)

# nfnl x nnei x 4
dmatrix = dmatrix.reshape(nf * nloc, nnei, 4)
dmatrix = xp.reshape(dmatrix, (nf * nloc, nnei, 4))
# nfnl x nnei x 4
rr = dmatrix
rr = rr * exclude_mask[:, :, None]
rr = rr * xp.astype(exclude_mask[:, :, None], rr.dtype)
# nfnl x nt_i x 3
rr_i = rr[:, :, 1:]
# nfnl x nt_j x 3
rr_j = rr[:, :, 1:]
# nfnl x nt_i x nt_j
env_ij = np.einsum("ijm,ikm->ijk", rr_i, rr_j)
# env_ij = np.einsum("ijm,ikm->ijk", rr_i, rr_j)
env_ij = xp.sum(rr_i[:, :, None, :] * rr_j[:, None, :, :], axis=-1)
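# Editor's note: the broadcasted product plus sum reproduces the removed
# einsum("ijm,ikm->ijk", rr_i, rr_j); einsum is not part of the array API
# standard, so it is spelled out here.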
# nfnl x nt_i x nt_j x 1
ss = np.expand_dims(env_ij, axis=-1)
ss = env_ij[..., None]

nlist_masked = np.where(nlist_mask, nlist, 0)
index = np.tile(nlist_masked.reshape(nf, -1, 1), (1, 1, self.tebd_dim))
nlist_masked = xp.where(nlist_mask, nlist, xp.zeros_like(nlist))
index = xp.tile(xp.reshape(nlist_masked, (nf, -1, 1)), (1, 1, self.tebd_dim))
# nfnl x nnei x tebd_dim
atype_embd_nlist = np.take_along_axis(atype_embd_ext, index, axis=1).reshape(
nf * nloc, nnei, self.tebd_dim
atype_embd_nlist = xp_take_along_axis(atype_embd_ext, index, axis=1)
atype_embd_nlist = xp.reshape(
atype_embd_nlist, (nf * nloc, nnei, self.tebd_dim)
)
# nfnl x nt_i x nt_j x tebd_dim
nlist_tebd_i = np.tile(
np.expand_dims(atype_embd_nlist, axis=2), [1, 1, self.nnei, 1]
)
nlist_tebd_j = np.tile(
np.expand_dims(atype_embd_nlist, axis=1), [1, self.nnei, 1, 1]
)
nlist_tebd_i = xp.tile(atype_embd_nlist[:, :, None, :], (1, 1, self.nnei, 1))
nlist_tebd_j = xp.tile(atype_embd_nlist[:, None, :, :], (1, self.nnei, 1, 1))
ng = self.neuron[-1]

if self.tebd_input_mode in ["concat"]:
# nfnl x nt_i x nt_j x (1 + tebd_dim * 2)
ss = np.concatenate([ss, nlist_tebd_i, nlist_tebd_j], axis=-1)
ss = xp.concat([ss, nlist_tebd_i, nlist_tebd_j], axis=-1)
# nfnl x nt_i x nt_j x ng
gg = self.cal_g(ss, 0)
elif self.tebd_input_mode in ["strip"]:
# nfnl x nt_i x nt_j x ng
gg_s = self.cal_g(ss, 0)
assert self.embeddings_strip is not None
# nfnl x nt_i x nt_j x (tebd_dim * 2)
tt = np.concatenate([nlist_tebd_i, nlist_tebd_j], axis=-1)
tt = xp.concat([nlist_tebd_i, nlist_tebd_j], axis=-1)
# nfnl x nt_i x nt_j x ng
gg_t = self.cal_g_strip(tt, 0)
if self.smooth:
gg_t = (
gg_t
* sw.reshape(nf * nloc, self.nnei, 1, 1)
* sw.reshape(nf * nloc, 1, self.nnei, 1)
* xp.reshape(sw, (nf * nloc, self.nnei, 1, 1))
* xp.reshape(sw, (nf * nloc, 1, self.nnei, 1))
)
# nfnl x nt_i x nt_j x ng
gg = gg_s * gg_t + gg_s
else:
raise NotImplementedError

# nfnl x ng
res_ij = np.einsum("ijk,ijkm->im", env_ij, gg)
# res_ij = np.einsum("ijk,ijkm->im", env_ij, gg)
res_ij = xp.sum(env_ij[:, :, :, None] * gg[:, :, :, :], axis=(1, 2))
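# Editor's note: likewise equivalent to the removed
# einsum("ijk,ijkm->im", env_ij, gg), contracting both neighbor axes.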
res_ij = res_ij * (1.0 / float(self.nnei) / float(self.nnei))
# nf x nl x ng
result = res_ij.reshape(nf, nloc, self.filter_neuron[-1]).astype(
GLOBAL_NP_FLOAT_PRECISION
)
result = xp.reshape(res_ij, (nf, nloc, self.filter_neuron[-1]))
result = xp.astype(result, get_xp_precision(xp, "global"))
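# Editor's note (assumption): get_xp_precision resolves the configured
# global precision to a dtype of the active namespace, replacing the
# NumPy-only GLOBAL_NP_FLOAT_PRECISION constant dropped from the imports.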
return (
result,
None,
Expand All @@ -743,3 +757,61 @@
def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor block needs sorted nlist when using `forward_lower`."""
return False

def serialize(self) -> dict:
"""Serialize the descriptor to dict."""
obj = self
data = {
"@class": "Descriptor",
"type": "se_e3_tebd",
"@version": 1,
"rcut": obj.rcut,
"rcut_smth": obj.rcut_smth,
"sel": obj.sel,
"ntypes": obj.ntypes,
"neuron": obj.neuron,
"tebd_dim": obj.tebd_dim,
"tebd_input_mode": obj.tebd_input_mode,
"set_davg_zero": obj.set_davg_zero,
"activation_function": obj.activation_function,
"resnet_dt": obj.resnet_dt,
# make deterministic
"precision": np.dtype(PRECISION_DICT[obj.precision]).name,
"embeddings": obj.embeddings.serialize(),
"env_mat": obj.env_mat.serialize(),
"exclude_types": obj.exclude_types,
"env_protection": obj.env_protection,
"smooth": obj.smooth,
"@variables": {
"davg": to_numpy_array(obj["davg"]),
"dstd": to_numpy_array(obj["dstd"]),
},
}
if obj.tebd_input_mode in ["strip"]:
data.update({"embeddings_strip": obj.embeddings_strip.serialize()})
return data

@classmethod
def deserialize(cls, data: dict) -> "DescrptSeTTebd":
"""Deserialize from dict."""
data = data.copy()
check_version_compatibility(data.pop("@version"), 1, 1)
data.pop("@class")
data.pop("type")
variables = data.pop("@variables")
embeddings = data.pop("embeddings")
env_mat = data.pop("env_mat")

[Code scanning / CodeQL notice] Unused local variable: Variable env_mat is not used.
tebd_input_mode = data["tebd_input_mode"]
if tebd_input_mode in ["strip"]:
embeddings_strip = data.pop("embeddings_strip")
else:
embeddings_strip = None

[Codecov / codecov/patch warning] deepmd/dpmodel/descriptor/se_t_tebd.py#L808: Added line #L808 was not covered by tests.
se_ttebd = cls(**data)

se_ttebd["davg"] = variables["davg"]
se_ttebd["dstd"] = variables["dstd"]
se_ttebd.embeddings = NetworkCollection.deserialize(embeddings)
if tebd_input_mode in ["strip"]:
se_ttebd.embeddings_strip = NetworkCollection.deserialize(embeddings_strip)

return se_ttebd
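
A hedged sketch of the round trip that serialize/deserialize enable; the constructor arguments below are illustrative, not values taken from this PR or its tests:

    from deepmd.dpmodel.descriptor.se_t_tebd import DescrptSeTTebd

    d = DescrptSeTTebd(rcut=6.0, rcut_smth=0.5, sel=[20], ntypes=2)
    data = d.serialize()                   # plain dict; "@variables" are NumPy
    d2 = DescrptSeTTebd.deserialize(data)  # rebuilds an equivalent descriptor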
56 changes: 56 additions & 0 deletions deepmd/jax/descriptor/se_t_tebd.py
@@ -0,0 +1,56 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
)

from deepmd.dpmodel.descriptor.se_t_tebd import (
DescrptBlockSeTTebd as DescrptBlockSeTTebdDP,
)
from deepmd.dpmodel.descriptor.se_t_tebd import DescrptSeTTebd as DescrptSeTTebdDP
from deepmd.jax.common import (
ArrayAPIVariable,
flax_module,
to_jax_array,
)
from deepmd.jax.descriptor.base_descriptor import (
BaseDescriptor,
)
from deepmd.jax.utils.exclude_mask import (
PairExcludeMask,
)
from deepmd.jax.utils.network import (
NetworkCollection,
)
from deepmd.jax.utils.type_embed import (
TypeEmbedNet,
)


@flax_module
class DescrptBlockSeTTebd(DescrptBlockSeTTebdDP):
def __setattr__(self, name: str, value: Any) -> None:
if name in {"mean", "stddev"}:
value = to_jax_array(value)
if value is not None:
value = ArrayAPIVariable(value)
elif name in {"embeddings", "embeddings_strip"}:
if value is not None:
value = NetworkCollection.deserialize(value.serialize())
elif name == "env_mat":
# env_mat doesn't store any value
pass
elif name == "emask":
value = PairExcludeMask(value.ntypes, value.exclude_types)

return super().__setattr__(name, value)


@BaseDescriptor.register("se_e3_tebd")
@flax_module
class DescrptSeTTebd(DescrptSeTTebdDP):
def __setattr__(self, name: str, value: Any) -> None:
if name == "se_ttebd":
value = DescrptBlockSeTTebd.deserialize(value.serialize())
elif name == "type_embedding":
value = TypeEmbedNet.deserialize(value.serialize())
return super().__setattr__(name, value)
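
Because every framework-specific attribute is converted at assignment time, a descriptor built in the dpmodel backend can be promoted to JAX through the shared serialization path. A hedged usage sketch, assuming dp_descrpt is an existing dpmodel DescrptSeTTebd instance:

    from deepmd.jax.descriptor.se_t_tebd import DescrptSeTTebd

    jax_descrpt = DescrptSeTTebd.deserialize(dp_descrpt.serialize())
    # mean/stddev arrive as JAX arrays (wrapped in ArrayAPIVariable) and the
    # embedding networks as the JAX NetworkCollection implementation.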
4 changes: 2 additions & 2 deletions doc/model/train-se-e3-tebd.md
@@ -1,7 +1,7 @@
# Descriptor `"se_e3_tebd"` {{ pytorch_icon }} {{ dpmodel_icon }}
# Descriptor `"se_e3_tebd"` {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }}

:::{note}
**Supported backends**: PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }}
**Supported backends**: PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }}
:::

The notation of `se_e3_tebd` is short for the three-body embedding descriptor with type embeddings, where the notation `se` denotes the Deep Potential Smooth Edition (DeepPot-SE).
47 changes: 47 additions & 0 deletions source/tests/array_api_strict/descriptor/se_t_tebd.py
@@ -0,0 +1,47 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
)

from deepmd.dpmodel.descriptor.se_t_tebd import (
DescrptBlockSeTTebd as DescrptBlockSeTTebdDP,
)
from deepmd.dpmodel.descriptor.se_t_tebd import DescrptSeTTebd as DescrptSeTTebdDP

from ..common import (
to_array_api_strict_array,
)
from ..utils.exclude_mask import (
PairExcludeMask,
)
from ..utils.network import (
NetworkCollection,
)
from ..utils.type_embed import (
TypeEmbedNet,
)


class DescrptBlockSeTTebd(DescrptBlockSeTTebdDP):
def __setattr__(self, name: str, value: Any) -> None:
if name in {"mean", "stddev"}:
value = to_array_api_strict_array(value)
elif name in {"embeddings", "embeddings_strip"}:
if value is not None:
value = NetworkCollection.deserialize(value.serialize())
elif name == "env_mat":
# env_mat doesn't store any value
pass
elif name == "emask":
value = PairExcludeMask(value.ntypes, value.exclude_types)

return super().__setattr__(name, value)


class DescrptSeTTebd(DescrptSeTTebdDP):
def __setattr__(self, name: str, value: Any) -> None:
if name == "se_ttebd":
value = DescrptBlockSeTTebd.deserialize(value.serialize())
elif name == "type_embedding":
value = TypeEmbedNet.deserialize(value.serialize())
return super().__setattr__(name, value)
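
The array-api-strict mirror exists purely for testing: array-api-strict implements only what the standard guarantees, so exercising the descriptor against it flags accidental NumPy-isms. A small illustration, assuming array-api-strict is installed:

    import array_api_strict as xp

    x = xp.asarray([[1.0, 2.0]])
    xp.reshape(x, (2, 1))  # fine: reshape is a namespace function in the standard
    # x.reshape(2, 1)      # AttributeError: the standard defines no reshape method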