diff --git a/deepmd/jax/common.py b/deepmd/jax/common.py
index 8c3860cf39..0af7921c6d 100644
--- a/deepmd/jax/common.py
+++ b/deepmd/jax/common.py
@@ -6,8 +6,12 @@
 
 import numpy as np
 
+from deepmd.dpmodel.common import (
+    NativeOP,
+)
 from deepmd.jax.env import (
     jnp,
+    nnx,
 )
 
 
@@ -35,3 +39,39 @@ def to_jax_array(array: Optional[np.ndarray]) -> Optional[jnp.ndarray]:
     if array is None:
         return None
     return jnp.array(array)
+
+
+def flax_module(
+    module: NativeOP,
+) -> nnx.Module:
+    """Convert a NativeOP to a Flax module.
+
+    Parameters
+    ----------
+    module : NativeOP
+        The NativeOP to convert.
+
+    Returns
+    -------
+    flax.nnx.Module
+        The Flax module.
+
+    Examples
+    --------
+    >>> @flax_module
+    ... class MyModule(NativeOP):
+    ...     pass
+    """
+    metas = set()
+    if not issubclass(type(nnx.Module), type(module)):
+        metas.add(type(module))
+    if not issubclass(type(module), type(nnx.Module)):
+        metas.add(type(nnx.Module))
+
+    class MixedMetaClass(*metas):
+        pass
+
+    class FlaxModule(module, nnx.Module, metaclass=MixedMetaClass):
+        pass
+
+    return FlaxModule
diff --git a/deepmd/jax/descriptor/dpa1.py b/deepmd/jax/descriptor/dpa1.py
index 73ef6055e5..a9b0404970 100644
--- a/deepmd/jax/descriptor/dpa1.py
+++ b/deepmd/jax/descriptor/dpa1.py
@@ -13,6 +13,7 @@
     NeighborGatedAttentionLayer as NeighborGatedAttentionLayerDP,
 )
 from deepmd.jax.common import (
+    flax_module,
     to_jax_array,
 )
 from deepmd.jax.utils.exclude_mask import (
@@ -28,6 +29,7 @@
 )
 
 
+@flax_module
 class GatedAttentionLayer(GatedAttentionLayerDP):
     def __setattr__(self, name: str, value: Any) -> None:
         if name in {"in_proj", "out_proj"}:
@@ -35,6 +37,7 @@ def __setattr__(self, name: str, value: Any) -> None:
         return super().__setattr__(name, value)
 
 
+@flax_module
 class NeighborGatedAttentionLayer(NeighborGatedAttentionLayerDP):
     def __setattr__(self, name: str, value: Any) -> None:
         if name == "attention_layer":
@@ -44,6 +47,7 @@ def __setattr__(self, name: str, value: Any) -> None:
         return super().__setattr__(name, value)
 
 
+@flax_module
 class NeighborGatedAttention(NeighborGatedAttentionDP):
     def __setattr__(self, name: str, value: Any) -> None:
         if name == "attention_layers":
@@ -53,6 +57,7 @@ def __setattr__(self, name: str, value: Any) -> None:
         return super().__setattr__(name, value)
 
 
+@flax_module
 class DescrptBlockSeAtten(DescrptBlockSeAttenDP):
     def __setattr__(self, name: str, value: Any) -> None:
         if name in {"mean", "stddev"}:
@@ -71,6 +76,7 @@ def __setattr__(self, name: str, value: Any) -> None:
         return super().__setattr__(name, value)
 
 
+@flax_module
 class DescrptDPA1(DescrptDPA1DP):
     def __setattr__(self, name: str, value: Any) -> None:
         if name == "se_atten":
diff --git a/deepmd/jax/env.py b/deepmd/jax/env.py
index 34e4aa6240..5a5a7f6bf0 100644
--- a/deepmd/jax/env.py
+++ b/deepmd/jax/env.py
@@ -5,10 +5,14 @@
 
 import jax
 import jax.numpy as jnp
+from flax import (
+    nnx,
+)
 
 jax.config.update("jax_enable_x64", True)
 
 __all__ = [
     "jax",
     "jnp",
+    "nnx",
 ]
diff --git a/deepmd/jax/utils/exclude_mask.py b/deepmd/jax/utils/exclude_mask.py
index 6519648514..cac4cee092 100644
--- a/deepmd/jax/utils/exclude_mask.py
+++ b/deepmd/jax/utils/exclude_mask.py
@@ -5,10 +5,12 @@
 
 from deepmd.dpmodel.utils.exclude_mask import PairExcludeMask as PairExcludeMaskDP
 from deepmd.jax.common import (
+    flax_module,
     to_jax_array,
 )
 
 
+@flax_module
 class PairExcludeMask(PairExcludeMaskDP):
     def __setattr__(self, name: str, value: Any) -> None:
         if name in {"type_mask"}:
diff --git a/deepmd/jax/utils/network.py b/deepmd/jax/utils/network.py
index 6517573b38..fc6e168c7b 100644
--- a/deepmd/jax/utils/network.py
+++ b/deepmd/jax/utils/network.py
@@ -17,10 +17,12 @@
     make_multilayer_network,
 )
 from deepmd.jax.common import (
+    flax_module,
     to_jax_array,
 )
 
 
+@flax_module
 class NativeLayer(NativeLayerDP):
     def __setattr__(self, name: str, value: Any) -> None:
         if name in {"w", "b", "idt"}:
@@ -28,11 +30,22 @@ def __setattr__(self, name: str, value: Any) -> None:
         return super().__setattr__(name, value)
 
 
-NativeNet = make_multilayer_network(NativeLayer, NativeOP)
-EmbeddingNet = make_embedding_network(NativeNet, NativeLayer)
-FittingNet = make_fitting_network(EmbeddingNet, NativeNet, NativeLayer)
+@flax_module
+class NativeNet(make_multilayer_network(NativeLayer, NativeOP)):
+    pass
+
+
+@flax_module
+class EmbeddingNet(make_embedding_network(NativeNet, NativeLayer)):
+    pass
+
+
+@flax_module
+class FittingNet(make_fitting_network(EmbeddingNet, NativeNet, NativeLayer)):
+    pass
 
 
+@flax_module
 class NetworkCollection(NetworkCollectionDP):
     NETWORK_TYPE_MAP: ClassVar[Dict[str, type]] = {
         "network": NativeNet,
@@ -41,5 +54,6 @@ class NetworkCollection(NetworkCollectionDP):
     }
 
 
+@flax_module
 class LayerNorm(LayerNormDP, NativeLayer):
     pass
diff --git a/deepmd/jax/utils/type_embed.py b/deepmd/jax/utils/type_embed.py
index bc7c469524..3143460244 100644
--- a/deepmd/jax/utils/type_embed.py
+++ b/deepmd/jax/utils/type_embed.py
@@ -5,6 +5,7 @@
 
 from deepmd.dpmodel.utils.type_embed import TypeEmbedNet as TypeEmbedNetDP
 from deepmd.jax.common import (
+    flax_module,
     to_jax_array,
 )
 from deepmd.jax.utils.network import (
@@ -12,6 +13,7 @@
 )
 
 
+@flax_module
 class TypeEmbedNet(TypeEmbedNetDP):
     def __setattr__(self, name: str, value: Any) -> None:
         if name in {"econf_tebd"}:
diff --git a/pyproject.toml b/pyproject.toml
index 28fe114e01..9fa1425c2b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -134,6 +134,7 @@ cu12 = [
 ]
 jax = [
     'jax>=0.4.33;python_version>="3.10"',
+    'flax>=0.8.0;python_version>="3.10"',
 ]
 
 [tool.deepmd_build_backend.scripts]
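Note on the flax_module decorator added above: it merges the metaclass of the wrapped NativeOP class with the metaclass of nnx.Module so the two class hierarchies can be combined without a metaclass conflict. Below is a minimal, self-contained sketch of that metaclass-mixing pattern; every name in it (MetaA, MetaB, NativeBase, FrameworkModule, mix) is a hypothetical stand-in for illustration, not the deepmd or flax API.

# Sketch only: hypothetical stand-ins, not deepmd or flax classes.
class MetaA(type):
    """Metaclass of the 'native' side (stands in for type(NativeOP))."""


class MetaB(type):
    """Metaclass of the framework side (stands in for type(nnx.Module))."""


class NativeBase(metaclass=MetaA):
    """Stand-in for a NativeOP-style base class."""


class FrameworkModule(metaclass=MetaB):
    """Stand-in for nnx.Module."""


def mix(cls):
    """Subclass ``cls`` and ``FrameworkModule`` under a merged metaclass.

    Mirrors the shape of ``flax_module``: collect the metaclasses that are
    not already related by inheritance, derive a combined metaclass from
    them, and build the mixed class with it.
    """
    metas = set()
    if not issubclass(type(FrameworkModule), type(cls)):
        metas.add(type(cls))
    if not issubclass(type(cls), type(FrameworkModule)):
        metas.add(type(FrameworkModule))

    class MixedMeta(*metas):
        pass

    class Mixed(cls, FrameworkModule, metaclass=MixedMeta):
        pass

    return Mixed


@mix
class MyOp(NativeBase):
    pass


# The decorated class now belongs to both hierarchies.
assert isinstance(MyOp(), NativeBase)
assert isinstance(MyOp(), FrameworkModule)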