Skip to content

Commit

Permalink
feat(jax/array-api): hybrid descriptor
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Oct 29, 2024
1 parent 159361d commit b9cfc7f
Show file tree
Hide file tree
Showing 10 changed files with 142 additions and 5 deletions.
13 changes: 8 additions & 5 deletions deepmd/dpmodel/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Union,
)

import array_api_compat
import numpy as np

from deepmd.dpmodel.common import (
Expand Down Expand Up @@ -66,7 +67,7 @@ def __init__(
), f"number of atom types in {ii}th descriptor {self.descrpt_list[0].__class__.__name__} does not match others"
# if hybrid sel is larger than sub sel, the nlist needs to be cut for each type
hybrid_sel = self.get_sel()
self.nlist_cut_idx: list[np.ndarray] = []
nlist_cut_idx: list[np.ndarray] = []
if self.mixed_types() and not all(
descrpt.mixed_types() for descrpt in self.descrpt_list
):
Expand All @@ -92,7 +93,8 @@ def __init__(
cut_idx = np.concatenate(
[range(ss, ee) for ss, ee in zip(start_idx, end_idx)]
)
self.nlist_cut_idx.append(cut_idx)
nlist_cut_idx.append(cut_idx)
self.nlist_cut_idx = nlist_cut_idx

def get_rcut(self) -> float:
"""Returns the cut-off radius."""
Expand Down Expand Up @@ -242,6 +244,7 @@ def call(
sw
The smooth switch function.
"""
xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist)
out_descriptor = []
out_gr = []
out_g2 = None
Expand All @@ -258,7 +261,7 @@ def call(
for descrpt, nci in zip(self.descrpt_list, self.nlist_cut_idx):
# cut the nlist to the correct length
if self.mixed_types() == descrpt.mixed_types():
nl = nlist[:, :, nci]
nl = xp.take(nlist, nci, axis=2)
else:
# mixed_types is True, but descrpt.mixed_types is False
assert nl_distinguish_types is not None
Expand All @@ -268,8 +271,8 @@ def call(
if gr is not None:
out_gr.append(gr)

out_descriptor = np.concatenate(out_descriptor, axis=-1)
out_gr = np.concatenate(out_gr, axis=-2) if out_gr else None
out_descriptor = xp.concat(out_descriptor, axis=-1)
out_gr = xp.concat(out_gr, axis=-2) if out_gr else None
return out_descriptor, out_gr, out_g2, out_h2, out_sw

@classmethod
Expand Down
4 changes: 4 additions & 0 deletions deepmd/jax/descriptor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
from deepmd.jax.descriptor.dpa1 import (
DescrptDPA1,
)
from deepmd.jax.descriptor.hybrid import (
DescrptHybrid,
)
from deepmd.jax.descriptor.se_e2_a import (
DescrptSeA,
)
Expand All @@ -13,4 +16,5 @@
"DescrptSeA",
"DescrptSeR",
"DescrptDPA1",
"DescrptHybrid",
]
26 changes: 26 additions & 0 deletions deepmd/jax/descriptor/hybrid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
)

from deepmd.dpmodel.descriptor.hybrid import DescrptHybrid as DescrptHybridDP
from deepmd.jax.common import (
ArrayAPIVariable,
flax_module,
to_jax_array,
)
from deepmd.jax.descriptor.base_descriptor import (
BaseDescriptor,
)


@BaseDescriptor.register("hybrid")
@flax_module
class DescrptHybrid(DescrptHybridDP):
def __setattr__(self, name: str, value: Any) -> None:
if name in {"nlist_cut_idx"}:
value = [ArrayAPIVariable(to_jax_array(vv)) for vv in value]
elif name in {"descrpt_list"}:
value = [BaseDescriptor.deserialize(vv.serialize()) for vv in value]

return super().__setattr__(name, value)
19 changes: 19 additions & 0 deletions source/tests/array_api_strict/descriptor/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,20 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from .dpa1 import (
DescrptDPA1,
)
from .hybrid import (
DescrptHybrid,
)
from .se_e2_a import (
DescrptSeA,
)
from .se_e2_r import (
DescrptSeR,
)

__all__ = [
"DescrptSeA",
"DescrptSeR",
"DescrptDPA1",
"DescrptHybrid",
]
11 changes: 11 additions & 0 deletions source/tests/array_api_strict/descriptor/base_descriptor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
)

from deepmd.dpmodel.descriptor.make_base_descriptor import (
make_base_descriptor,
)

# no type annotations standard in array api
BaseDescriptor = make_base_descriptor(Any)
5 changes: 5 additions & 0 deletions source/tests/array_api_strict/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
from ..utils.type_embed import (
TypeEmbedNet,
)
from .base_descriptor import (
BaseDescriptor,
)


class GatedAttentionLayer(GatedAttentionLayerDP):
Expand Down Expand Up @@ -72,6 +75,8 @@ def __setattr__(self, name: str, value: Any) -> None:
return super().__setattr__(name, value)


@BaseDescriptor.register("dpa1")
@BaseDescriptor.register("se_atten")
class DescrptDPA1(DescrptDPA1DP):
def __setattr__(self, name: str, value: Any) -> None:
if name == "se_atten":
Expand Down
24 changes: 24 additions & 0 deletions source/tests/array_api_strict/descriptor/hybrid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
)

from deepmd.dpmodel.descriptor.hybrid import DescrptHybrid as DescrptHybridDP

from ..common import (
to_array_api_strict_array,
)
from .base_descriptor import (
BaseDescriptor,
)


@BaseDescriptor.register("hybrid")
class DescrptHybrid(DescrptHybridDP):
def __setattr__(self, name: str, value: Any) -> None:
if name in {"nlist_cut_idx"}:
value = [to_array_api_strict_array(vv) for vv in value]
elif name in {"descrpt_list"}:
value = [BaseDescriptor.deserialize(vv.serialize()) for vv in value]

return super().__setattr__(name, value)
5 changes: 5 additions & 0 deletions source/tests/array_api_strict/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,13 @@
from ..utils.network import (
NetworkCollection,
)
from .base_descriptor import (
BaseDescriptor,
)


@BaseDescriptor.register("se_e2_a")
@BaseDescriptor.register("se_a")
class DescrptSeA(DescrptSeADP):
def __setattr__(self, name: str, value: Any) -> None:
if name in {"dstd", "davg"}:
Expand Down
5 changes: 5 additions & 0 deletions source/tests/array_api_strict/descriptor/se_e2_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,13 @@
from ..utils.network import (
NetworkCollection,
)
from .base_descriptor import (
BaseDescriptor,
)


@BaseDescriptor.register("se_e2_r")
@BaseDescriptor.register("se_r")
class DescrptSeR(DescrptSeRDP):
def __setattr__(self, name: str, value: Any) -> None:
if name in {"dstd", "davg"}:
Expand Down
35 changes: 35 additions & 0 deletions source/tests/consistent/descriptor/test_hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
)

from ..common import (
INSTALLED_ARRAY_API_STRICT,
INSTALLED_JAX,
INSTALLED_PT,
INSTALLED_TF,
CommonTest,
Expand All @@ -28,6 +30,16 @@
from deepmd.tf.descriptor.hybrid import DescrptHybrid as DescrptHybridTF
else:
DescrptHybridTF = None
if INSTALLED_JAX:
from deepmd.jax.descriptor.hybrid import DescrptHybrid as DescrptHybridJAX
else:
DescrptHybridJAX = None
if INSTALLED_ARRAY_API_STRICT:
from ...array_api_strict.descriptor.hybrid import (
DescrptHybrid as DescrptHybridStrict,
)
else:
DescrptHybridStrict = None
from deepmd.utils.argcheck import (
descrpt_hybrid_args,
)
Expand Down Expand Up @@ -68,8 +80,13 @@ def data(self) -> dict:
tf_class = DescrptHybridTF
dp_class = DescrptHybridDP
pt_class = DescrptHybridPT
jax_class = DescrptHybridJAX
array_api_strict_class = DescrptHybridStrict
args = descrpt_hybrid_args()

skip_jax = not INSTALLED_JAX
skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT

def setUp(self):
CommonTest.setUp(self)

Expand Down Expand Up @@ -132,5 +149,23 @@ def eval_pt(self, pt_obj: Any) -> Any:
self.box,
)

def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
return self.eval_array_api_strict_descriptor(
array_api_strict_obj,
self.natoms,
self.coords,
self.atype,
self.box,
)

def eval_jax(self, jax_obj: Any) -> Any:
return self.eval_jax_descriptor(
jax_obj,
self.natoms,
self.coords,
self.atype,
self.box,
)

def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
return (ret[0],)

0 comments on commit b9cfc7f

Please sign in to comment.