Skip to content

Commit

Permalink
feat(jax/array-api): se_atten_v2
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Oct 30, 2024
1 parent d165fee commit cbfb5d1
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 2 deletions.
7 changes: 5 additions & 2 deletions deepmd/dpmodel/descriptor/se_atten_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
DEFAULT_PRECISION,
PRECISION_DICT,
)
from deepmd.dpmodel.common import (
to_numpy_array,
)
from deepmd.dpmodel.utils import (
NetworkCollection,
)
Expand Down Expand Up @@ -146,8 +149,8 @@ def serialize(self) -> dict:
"exclude_types": obj.exclude_types,
"env_protection": obj.env_protection,
"@variables": {
"davg": obj["davg"],
"dstd": obj["dstd"],
"davg": to_numpy_array(obj["davg"]),
"dstd": to_numpy_array(obj["dstd"]),
},
## to be updated when the options are supported.
"trainable": self.trainable,
Expand Down
13 changes: 13 additions & 0 deletions deepmd/jax/descriptor/se_atten_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.dpmodel.descriptor.se_atten_v2 import DescrptSeAttenV2 as DescrptSeAttenV2DP
from deepmd.jax.descriptor.base_descriptor import (
BaseDescriptor,
)
from deepmd.jax.descriptor.dpa1 import (
DescrptDPA1,
)


@BaseDescriptor.register("se_atten_v2")
class DescrptSeAttenV2(DescrptDPA1, DescrptSeAttenV2DP):
pass
10 changes: 10 additions & 0 deletions source/tests/array_api_strict/descriptor/se_atten_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.dpmodel.descriptor.se_atten_v2 import DescrptSeAttenV2 as DescrptSeAttenV2DP

from .dpa1 import (
DescrptDPA1,
)


class DescrptSeAttenV2(DescrptDPA1, DescrptSeAttenV2DP):
pass
95 changes: 95 additions & 0 deletions source/tests/consistent/descriptor/test_se_atten_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
)

from ..common import (
INSTALLED_ARRAY_API_STRICT,
INSTALLED_JAX,
INSTALLED_PT,
CommonTest,
parameterized,
Expand All @@ -30,6 +32,18 @@
)
else:
DescrptSeAttenV2PT = None
if INSTALLED_JAX:
from deepmd.jax.descriptor.se_atten_v2 import (
DescrptSeAttenV2 as DescrptSeAttenV2JAX,
)
else:
DescrptSeAttenV2JAX = None
if INSTALLED_ARRAY_API_STRICT:
from ...array_api_strict.descriptor.se_atten_v2 import (
DescrptSeAttenV2 as DescrptSeAttenV2Strict,
)
else:
DescrptSeAttenV2Strict = None
DescrptSeAttenV2TF = None
from deepmd.utils.argcheck import (
descrpt_se_atten_args,
Expand Down Expand Up @@ -175,9 +189,70 @@ def skip_dp(self) -> bool:
def skip_tf(self) -> bool:
return True

@property
def skip_jax(self) -> bool:
(
tebd_dim,
resnet_dt,
type_one_side,
attn,
attn_layer,
attn_dotr,
excluded_types,
env_protection,
set_davg_zero,
scaling_factor,
normalize,
temperature,
ln_eps,
concat_output_tebd,
precision,
use_econf_tebd,
use_tebd_bias,
) = self.param
return not INSTALLED_JAX or self.is_meaningless_zero_attention_layer_tests(
attn_layer,
attn_dotr,
normalize,
temperature,
)

@property
def skip_array_api_strict(self) -> bool:
(
tebd_dim,
resnet_dt,
type_one_side,
attn,
attn_layer,
attn_dotr,
excluded_types,
env_protection,
set_davg_zero,
scaling_factor,
normalize,
temperature,
ln_eps,
concat_output_tebd,
precision,
use_econf_tebd,
use_tebd_bias,
) = self.param
return (
not INSTALLED_ARRAY_API_STRICT
or self.is_meaningless_zero_attention_layer_tests(
attn_layer,
attn_dotr,
normalize,
temperature,
)
)

tf_class = DescrptSeAttenV2TF
dp_class = DescrptSeAttenV2DP
pt_class = DescrptSeAttenV2PT
jax_class = DescrptSeAttenV2JAX
array_api_strict_class = DescrptSeAttenV2Strict
args = descrpt_se_atten_args().append(Argument("ntypes", int, optional=False))

def setUp(self):
Expand Down Expand Up @@ -244,6 +319,26 @@ def eval_pt(self, pt_obj: Any) -> Any:
mixed_types=True,
)

def eval_jax(self, jax_obj: Any) -> Any:
return self.eval_jax_descriptor(
jax_obj,
self.natoms,
self.coords,
self.atype,
self.box,
mixed_types=True,
)

def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
return self.eval_array_api_strict_descriptor(
array_api_strict_obj,
self.natoms,
self.coords,
self.atype,
self.box,
mixed_types=True,
)

def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
return (ret[0],)

Expand Down

0 comments on commit cbfb5d1

Please sign in to comment.