diff --git a/deepmd/dpmodel/descriptor/se_atten_v2.py b/deepmd/dpmodel/descriptor/se_atten_v2.py index e0ac222524..897863ec0f 100644 --- a/deepmd/dpmodel/descriptor/se_atten_v2.py +++ b/deepmd/dpmodel/descriptor/se_atten_v2.py @@ -11,6 +11,9 @@ DEFAULT_PRECISION, PRECISION_DICT, ) +from deepmd.dpmodel.common import ( + to_numpy_array, +) from deepmd.dpmodel.utils import ( NetworkCollection, ) @@ -146,8 +149,8 @@ def serialize(self) -> dict: "exclude_types": obj.exclude_types, "env_protection": obj.env_protection, "@variables": { - "davg": obj["davg"], - "dstd": obj["dstd"], + "davg": to_numpy_array(obj["davg"]), + "dstd": to_numpy_array(obj["dstd"]), }, ## to be updated when the options are supported. "trainable": self.trainable, diff --git a/deepmd/jax/descriptor/se_atten_v2.py b/deepmd/jax/descriptor/se_atten_v2.py new file mode 100644 index 0000000000..a7ef4035cd --- /dev/null +++ b/deepmd/jax/descriptor/se_atten_v2.py @@ -0,0 +1,13 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.dpmodel.descriptor.se_atten_v2 import DescrptSeAttenV2 as DescrptSeAttenV2DP +from deepmd.jax.descriptor.base_descriptor import ( + BaseDescriptor, +) +from deepmd.jax.descriptor.dpa1 import ( + DescrptDPA1, +) + + +@BaseDescriptor.register("se_atten_v2") +class DescrptSeAttenV2(DescrptDPA1, DescrptSeAttenV2DP): + pass diff --git a/source/tests/array_api_strict/descriptor/se_atten_v2.py b/source/tests/array_api_strict/descriptor/se_atten_v2.py new file mode 100644 index 0000000000..a2e06ac0e2 --- /dev/null +++ b/source/tests/array_api_strict/descriptor/se_atten_v2.py @@ -0,0 +1,10 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.dpmodel.descriptor.se_atten_v2 import DescrptSeAttenV2 as DescrptSeAttenV2DP + +from .dpa1 import ( + DescrptDPA1, +) + + +class DescrptSeAttenV2(DescrptDPA1, DescrptSeAttenV2DP): + pass diff --git a/source/tests/consistent/descriptor/test_se_atten_v2.py b/source/tests/consistent/descriptor/test_se_atten_v2.py index a3fe4e98b4..f4a8119ca3 100644 --- a/source/tests/consistent/descriptor/test_se_atten_v2.py +++ b/source/tests/consistent/descriptor/test_se_atten_v2.py @@ -16,6 +16,8 @@ ) from ..common import ( + INSTALLED_ARRAY_API_STRICT, + INSTALLED_JAX, INSTALLED_PT, CommonTest, parameterized, @@ -30,6 +32,18 @@ ) else: DescrptSeAttenV2PT = None +if INSTALLED_JAX: + from deepmd.jax.descriptor.se_atten_v2 import ( + DescrptSeAttenV2 as DescrptSeAttenV2JAX, + ) +else: + DescrptSeAttenV2JAX = None +if INSTALLED_ARRAY_API_STRICT: + from ...array_api_strict.descriptor.se_atten_v2 import ( + DescrptSeAttenV2 as DescrptSeAttenV2Strict, + ) +else: + DescrptSeAttenV2Strict = None DescrptSeAttenV2TF = None from deepmd.utils.argcheck import ( descrpt_se_atten_args, @@ -175,9 +189,70 @@ def skip_dp(self) -> bool: def skip_tf(self) -> bool: return True + @property + def skip_jax(self) -> bool: + ( + tebd_dim, + resnet_dt, + type_one_side, + attn, + attn_layer, + attn_dotr, + excluded_types, + env_protection, + set_davg_zero, + scaling_factor, + normalize, + temperature, + ln_eps, + concat_output_tebd, + precision, + use_econf_tebd, + use_tebd_bias, + ) = self.param + return not INSTALLED_JAX or self.is_meaningless_zero_attention_layer_tests( + attn_layer, + attn_dotr, + normalize, + temperature, + ) + + @property + def skip_array_api_strict(self) -> bool: + ( + tebd_dim, + resnet_dt, + type_one_side, + attn, + attn_layer, + attn_dotr, + excluded_types, + env_protection, + set_davg_zero, + scaling_factor, + normalize, + temperature, + ln_eps, + concat_output_tebd, + precision, + use_econf_tebd, + use_tebd_bias, + ) = self.param + return ( + not INSTALLED_ARRAY_API_STRICT + or self.is_meaningless_zero_attention_layer_tests( + attn_layer, + attn_dotr, + normalize, + temperature, + ) + ) + tf_class = DescrptSeAttenV2TF dp_class = DescrptSeAttenV2DP pt_class = DescrptSeAttenV2PT + jax_class = DescrptSeAttenV2JAX + array_api_strict_class = DescrptSeAttenV2Strict args = descrpt_se_atten_args().append(Argument("ntypes", int, optional=False)) def setUp(self): @@ -244,6 +319,26 @@ def eval_pt(self, pt_obj: Any) -> Any: mixed_types=True, ) + def eval_jax(self, jax_obj: Any) -> Any: + return self.eval_jax_descriptor( + jax_obj, + self.natoms, + self.coords, + self.atype, + self.box, + mixed_types=True, + ) + + def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: + return self.eval_array_api_strict_descriptor( + array_api_strict_obj, + self.natoms, + self.coords, + self.atype, + self.box, + mixed_types=True, + ) + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: return (ret[0],)