Skip to content

Commit

Permalink
feat(jax/array-api): se_e2_a
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Oct 15, 2024
1 parent 5c092e6 commit 21393f4
Show file tree
Hide file tree
Showing 5 changed files with 230 additions and 16 deletions.
107 changes: 100 additions & 7 deletions deepmd/dpmodel/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,17 @@
Union,
)

import array_api_compat
import numpy as np

from deepmd.dpmodel import (
DEFAULT_PRECISION,
PRECISION_DICT,
NativeOP,
)
from deepmd.dpmodel.common import (
to_numpy_array,
)
from deepmd.dpmodel.utils import (
EmbeddingNet,
EnvMat,
Expand Down Expand Up @@ -186,31 +190,33 @@ def __init__(
self.reinit_exclude(exclude_types)

in_dim = 1 # not considiering type embedding
self.embeddings = NetworkCollection(
embeddings = NetworkCollection(
ntypes=self.ntypes,
ndim=(1 if self.type_one_side else 2),
network_type="embedding_network",
)
for ii, embedding_idx in enumerate(
itertools.product(range(self.ntypes), repeat=self.embeddings.ndim)
itertools.product(range(self.ntypes), repeat=embeddings.ndim)
):
self.embeddings[embedding_idx] = EmbeddingNet(
embeddings[embedding_idx] = EmbeddingNet(
in_dim,
self.neuron,
self.activation_function,
self.resnet_dt,
self.precision,
seed=child_seed(seed, ii),
)
self.embeddings = embeddings
self.env_mat = EnvMat(self.rcut, self.rcut_smth, protection=self.env_protection)
self.nnei = np.sum(self.sel)
self.nnei = np.sum(self.sel).item()
self.davg = np.zeros(
[self.ntypes, self.nnei, 4], dtype=PRECISION_DICT[self.precision]
)
self.dstd = np.ones(
[self.ntypes, self.nnei, 4], dtype=PRECISION_DICT[self.precision]
)
self.orig_sel = self.sel
self.sel_cumsum = [0, *np.cumsum(self.sel).tolist()]

def __setitem__(self, key, value):
if key in ("avg", "data_avg", "davg"):
Expand Down Expand Up @@ -321,8 +327,9 @@ def cal_g(
ss,
embedding_idx,
):
xp = array_api_compat.array_namespace(ss)
nf_times_nloc, nnei = ss.shape[0:2]
ss = ss.reshape(nf_times_nloc, nnei, 1)
ss = xp.reshape(ss, (nf_times_nloc, nnei, 1))
# (nf x nloc) x nnei x ng
gg = self.embeddings[embedding_idx].call(ss)
return gg
Expand Down Expand Up @@ -444,8 +451,8 @@ def serialize(self) -> dict:
"env_mat": self.env_mat.serialize(),
"embeddings": self.embeddings.serialize(),
"@variables": {
"davg": self.davg,
"dstd": self.dstd,
"davg": to_numpy_array(self.davg),
"dstd": to_numpy_array(self.dstd),
},
"type_map": self.type_map,
}
Expand Down Expand Up @@ -497,3 +504,89 @@ def update_sel(
train_data, type_map, local_jdata_cpy["rcut"], local_jdata_cpy["sel"], False
)
return local_jdata_cpy, min_nbor_dist


class DescrptSeAArrayAPI(DescrptSeA):
def call(
self,
coord_ext,
atype_ext,
nlist,
mapping: Optional[np.ndarray] = None,
):
"""Compute the descriptor.
Parameters
----------
coord_ext
The extended coordinates of atoms. shape: nf x (nallx3)
atype_ext
The extended aotm types. shape: nf x nall
nlist
The neighbor list. shape: nf x nloc x nnei
mapping
The index mapping from extended to lcoal region. not used by this descriptor.
Returns
-------
descriptor
The descriptor. shape: nf x nloc x (ng x axis_neuron)
gr
The rotationally equivariant and permutationally invariant single particle
representation. shape: nf x nloc x ng x 3
g2
The rotationally invariant pair-partical representation.
this descriptor returns None
h2
The rotationally equivariant pair-partical representation.
this descriptor returns None
sw
The smooth switch function.
"""
if not self.type_one_side:
raise NotImplementedError(
"type_one_side == False is not supported in DescrptSeAArrayAPI"
)
del mapping
xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist)
input_dtype = coord_ext.dtype
# nf x nloc x nnei x 4
rr, diff, ww = self.env_mat.call(
coord_ext, atype_ext, nlist, self.davg, self.dstd
)
nf, nloc, nnei, _ = rr.shape
sec = xp.asarray(self.sel_cumsum)

ng = self.neuron[-1]
gr = xp.zeros([nf * nloc, ng, 4], dtype=self.dstd.dtype)
exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext)
# merge nf and nloc axis, so for type_one_side == False,
# we don't require atype is the same in all frames
exclude_mask = xp.reshape(exclude_mask, (nf * nloc, nnei))
rr = xp.reshape(rr, (nf * nloc, nnei, 4))
rr = xp.astype(rr, self.dstd.dtype)

for embedding_idx in itertools.product(
range(self.ntypes), repeat=self.embeddings.ndim
):
(tt,) = embedding_idx
mm = exclude_mask[:, sec[tt] : sec[tt + 1]]
tr = rr[:, sec[tt] : sec[tt + 1], :]
tr = tr * xp.astype(mm[:, :, None], tr.dtype)
ss = tr[..., 0:1]
gg = self.cal_g(ss, embedding_idx)
# gr_tmp = xp.einsum("lni,lnj->lij", gg, tr)
gr_tmp = xp.sum(gg[:, :, :, None] * tr[:, :, None, :], axis=1)
gr += gr_tmp
gr = xp.reshape(gr, (nf, nloc, ng, 4))
# nf x nloc x ng x 4
gr /= self.nnei
gr1 = gr[:, :, : self.axis_neuron, :]
# nf x nloc x ng x ng1
# grrg = xp.einsum("flid,fljd->flij", gr, gr1)
grrg = xp.sum(gr[:, :, :, None, :] * gr1[:, :, None, :, :], axis=4)
# nf x nloc x (ng x ng1)
grrg = xp.astype(
xp.reshape(grrg, (nf, nloc, ng * self.axis_neuron)), input_dtype
)
return grrg, gr[..., 1:], None, None, ww
19 changes: 10 additions & 9 deletions deepmd/dpmodel/utils/nlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,20 +163,21 @@ def nlist_distinguish_types(
xp = array_api_compat.array_namespace(nlist, atype)
nf, nloc, _ = nlist.shape
ret_nlist = []
tmp_atype = xp.tile(atype[:, None], [1, nloc, 1])
tmp_atype = xp.tile(atype[:, None, :], (1, nloc, 1))
mask = nlist == -1
tnlist_0 = nlist.copy()
tnlist_0[mask] = 0
tnlist = xp_take_along_axis(tmp_atype, tnlist_0, axis=2).squeeze()
tnlist = xp.where(mask, -1, tnlist)
tnlist_0 = xp.where(mask, xp.zeros_like(nlist), nlist)
tnlist = xp_take_along_axis(tmp_atype, tnlist_0, axis=2)
tnlist = xp.where(mask, xp.full_like(tnlist, -1), tnlist)
snsel = tnlist.shape[2]
for ii, ss in enumerate(sel):
pick_mask = (tnlist == ii).astype(xp.int32)
sorted_indices = xp.argsort(-pick_mask, kind="stable", axis=-1)
pick_mask = xp.astype(tnlist == ii, xp.int32)
sorted_indices = xp.argsort(-pick_mask, stable=True, axis=-1)
pick_mask_sorted = -xp.sort(-pick_mask, axis=-1)
inlist = xp_take_along_axis(nlist, sorted_indices, axis=2)
inlist = xp.where(~pick_mask_sorted.astype(bool), -1, inlist)
ret_nlist.append(xp.split(inlist, [ss, snsel - ss], axis=-1)[0])
inlist = xp.where(
~xp.astype(pick_mask_sorted, xp.bool), xp.full_like(inlist, -1), inlist
)
ret_nlist.append(inlist[..., :ss])
ret = xp.concat(ret_nlist, axis=-1)
return ret

Expand Down
33 changes: 33 additions & 0 deletions deepmd/jax/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
)

from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeAArrayAPI as DescrptSeADP
from deepmd.jax.common import (
flax_module,
to_jax_array,
)
from deepmd.jax.utils.exclude_mask import (
PairExcludeMask,
)
from deepmd.jax.utils.network import (
NetworkCollection,
)


@flax_module
class DescrptSeA(DescrptSeADP):
def __setattr__(self, name: str, value: Any) -> None:
if name in {"dstd", "davg"}:
value = to_jax_array(value)
elif name in {"embeddings"}:
if value is not None:
value = NetworkCollection.deserialize(value.serialize())
elif name == "env_mat":
# env_mat doesn't store any value
pass
elif name == "emask":
value = PairExcludeMask(value.ntypes, value.exclude_types)

return super().__setattr__(name, value)
32 changes: 32 additions & 0 deletions source/tests/array_api_strict/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
)

from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeAArrayAPI as DescrptSeADP

from ..common import (
to_array_api_strict_array,
)
from ..utils.exclude_mask import (
PairExcludeMask,
)
from ..utils.network import (
NetworkCollection,
)


class DescrptSeA(DescrptSeADP):
def __setattr__(self, name: str, value: Any) -> None:
if name in {"dstd", "davg"}:
value = to_array_api_strict_array(value)
elif name in {"embeddings"}:
if value is not None:
value = NetworkCollection.deserialize(value.serialize())
elif name == "env_mat":
# env_mat doesn't store any value
pass
elif name == "emask":
value = PairExcludeMask(value.ntypes, value.exclude_types)

return super().__setattr__(name, value)
55 changes: 55 additions & 0 deletions source/tests/consistent/descriptor/test_se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
)

from ..common import (
INSTALLED_ARRAY_API_STRICT,
INSTALLED_JAX,
INSTALLED_PT,
INSTALLED_TF,
CommonTest,
Expand All @@ -33,6 +35,17 @@
descrpt_se_a_args,
)

if INSTALLED_JAX:
from deepmd.jax.descriptor.se_e2_a import DescrptSeA as DescrptSeAJAX
else:
DescrptSeAJAX = None
if INSTALLED_ARRAY_API_STRICT:
from ...array_api_strict.descriptor.se_e2_a import (
DescrptSeA as DescrptSeAArrayAPIStrict,
)
else:
DescrptSeAArrayAPI = None


@parameterized(
(True, False), # resnet_dt
Expand Down Expand Up @@ -98,9 +111,33 @@ def skip_tf(self) -> bool:
) = self.param
return env_protection != 0.0

@property
def skip_jax(self) -> bool:
(
resnet_dt,
type_one_side,
excluded_types,
precision,
env_protection,
) = self.param
return not type_one_side or not INSTALLED_JAX

@property
def skip_array_api_strict(self) -> bool:
(
resnet_dt,
type_one_side,
excluded_types,
precision,
env_protection,
) = self.param
return not type_one_side or not INSTALLED_JAX

tf_class = DescrptSeATF
dp_class = DescrptSeADP
pt_class = DescrptSeAPT
jax_class = DescrptSeAJAX
array_api_strict_class = DescrptSeAArrayAPIStrict
args = descrpt_se_a_args()

def setUp(self):
Expand Down Expand Up @@ -177,6 +214,24 @@ def eval_pt(self, pt_obj: Any) -> Any:
self.box,
)

def eval_jax(self, jax_obj: Any) -> Any:
return self.eval_jax_descriptor(
jax_obj,
self.natoms,
self.coords,
self.atype,
self.box,
)

def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
return self.eval_array_api_strict_descriptor(
array_api_strict_obj,
self.natoms,
self.coords,
self.atype,
self.box,
)

def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
return (ret[0],)

Expand Down

0 comments on commit 21393f4

Please sign in to comment.