Skip to content

Commit

Permalink
apply flax.nnx.Module
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Sep 29, 2024
1 parent e7aeca0 commit bac980e
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 3 deletions.
40 changes: 40 additions & 0 deletions deepmd/jax/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,12 @@

import numpy as np

from deepmd.dpmodel.common import (
NativeOP,
)
from deepmd.jax.env import (
jnp,
nnx,
)


Expand Down Expand Up @@ -35,3 +39,39 @@ def to_jax_array(array: Optional[np.ndarray]) -> Optional[jnp.ndarray]:
if array is None:
return None
return jnp.array(array)


def flax_module(
    module: NativeOP,
) -> nnx.Module:
    """Convert a NativeOP class to a Flax (nnx) module class.

    The returned class inherits from both the given class and
    ``nnx.Module``. Because the two bases may carry different,
    incompatible metaclasses, a combined metaclass is synthesized
    when necessary.

    Parameters
    ----------
    module : NativeOP
        The NativeOP class (not an instance) to convert.

    Returns
    -------
    flax.nnx.Module
        The Flax module class.

    Examples
    --------
    >>> @flax_module
    ... class MyModule(NativeOP):
    ...     pass
    """
    # Collect the metaclasses of both bases, skipping any that is
    # already a superclass of the other (it would be redundant in the
    # MRO). A list — not a set — keeps the base order deterministic.
    metas = []
    if not issubclass(type(nnx.Module), type(module)):
        metas.append(type(module))
    if not issubclass(type(module), type(nnx.Module)):
        metas.append(type(nnx.Module))
    if not metas:
        # Both classes share the same metaclass (each issubclass of the
        # other). Deriving from an empty base list would produce a plain
        # object subclass, which is not a valid metaclass — use the
        # common metaclass directly instead.
        metas.append(type(nnx.Module))

    class MixedMetaClass(*metas):
        pass

    class FlaxModule(module, nnx.Module, metaclass=MixedMetaClass):
        pass

    return FlaxModule
6 changes: 6 additions & 0 deletions deepmd/jax/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
NeighborGatedAttentionLayer as NeighborGatedAttentionLayerDP,
)
from deepmd.jax.common import (
flax_module,
to_jax_array,
)
from deepmd.jax.utils.exclude_mask import (
Expand All @@ -28,13 +29,15 @@
)


@flax_module
class GatedAttentionLayer(GatedAttentionLayerDP):
    """Gated attention layer whose projection sub-layers are stored as
    JAX-backed ``NativeLayer`` objects.
    """

    def __setattr__(self, name: str, value: Any) -> None:
        # Re-instantiate the dpmodel projection layers as their JAX
        # counterparts via a serialize/deserialize round trip.
        if name == "in_proj" or name == "out_proj":
            value = NativeLayer.deserialize(value.serialize())
        return super().__setattr__(name, value)


@flax_module
class NeighborGatedAttentionLayer(NeighborGatedAttentionLayerDP):
def __setattr__(self, name: str, value: Any) -> None:
if name == "attention_layer":
Expand All @@ -44,6 +47,7 @@ def __setattr__(self, name: str, value: Any) -> None:
return super().__setattr__(name, value)


@flax_module
class NeighborGatedAttention(NeighborGatedAttentionDP):
def __setattr__(self, name: str, value: Any) -> None:
if name == "attention_layers":
Expand All @@ -53,6 +57,7 @@ def __setattr__(self, name: str, value: Any) -> None:
return super().__setattr__(name, value)


@flax_module
class DescrptBlockSeAtten(DescrptBlockSeAttenDP):
def __setattr__(self, name: str, value: Any) -> None:
if name in {"mean", "stddev"}:
Expand All @@ -71,6 +76,7 @@ def __setattr__(self, name: str, value: Any) -> None:
return super().__setattr__(name, value)


@flax_module
class DescrptDPA1(DescrptDPA1DP):
def __setattr__(self, name: str, value: Any) -> None:
if name == "se_atten":
Expand Down
4 changes: 4 additions & 0 deletions deepmd/jax/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,14 @@

import jax
import jax.numpy as jnp
from flax import (
    nnx,
)

# Enable 64-bit (double) precision in JAX, which otherwise defaults to
# float32 for all arrays and computations.
jax.config.update("jax_enable_x64", True)

__all__ = [
    "jax",
    "jnp",
    "nnx",
]
2 changes: 2 additions & 0 deletions deepmd/jax/utils/exclude_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@

from deepmd.dpmodel.utils.exclude_mask import PairExcludeMask as PairExcludeMaskDP
from deepmd.jax.common import (
flax_module,
to_jax_array,
)


@flax_module
class PairExcludeMask(PairExcludeMaskDP):
def __setattr__(self, name: str, value: Any) -> None:
if name in {"type_mask"}:
Expand Down
20 changes: 17 additions & 3 deletions deepmd/jax/utils/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,35 @@
make_multilayer_network,
)
from deepmd.jax.common import (
flax_module,
to_jax_array,
)


@flax_module
class NativeLayer(NativeLayerDP):
    """Native layer whose parameters (``w``, ``b``, ``idt``) are stored
    as JAX arrays.
    """

    def __setattr__(self, name: str, value: Any) -> None:
        # Parameter tensors are converted to JAX arrays on assignment.
        if name in ("w", "b", "idt"):
            value = to_jax_array(value)
        return super().__setattr__(name, value)


NativeNet = make_multilayer_network(NativeLayer, NativeOP)
EmbeddingNet = make_embedding_network(NativeNet, NativeLayer)
FittingNet = make_fitting_network(EmbeddingNet, NativeNet, NativeLayer)
@flax_module
class NativeNet(make_multilayer_network(NativeLayer, NativeOP)):
    """Multilayer network built from the JAX ``NativeLayer``, wrapped as a Flax module."""

    pass


@flax_module
class EmbeddingNet(make_embedding_network(NativeNet, NativeLayer)):
    """Embedding network built from the JAX ``NativeNet``/``NativeLayer``, wrapped as a Flax module."""

    pass


@flax_module
class FittingNet(make_fitting_network(EmbeddingNet, NativeNet, NativeLayer)):
    """Fitting network built from the JAX embedding/native networks, wrapped as a Flax module."""

    pass


@flax_module
class NetworkCollection(NetworkCollectionDP):
NETWORK_TYPE_MAP: ClassVar[Dict[str, type]] = {
"network": NativeNet,
Expand All @@ -41,5 +54,6 @@ class NetworkCollection(NetworkCollectionDP):
}


@flax_module
class LayerNorm(LayerNormDP, NativeLayer):
    """Layer normalization with parameters stored as JAX arrays (via ``NativeLayer``)."""

    pass
2 changes: 2 additions & 0 deletions deepmd/jax/utils/type_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@

from deepmd.dpmodel.utils.type_embed import TypeEmbedNet as TypeEmbedNetDP
from deepmd.jax.common import (
flax_module,
to_jax_array,
)
from deepmd.jax.utils.network import (
EmbeddingNet,
)


@flax_module
class TypeEmbedNet(TypeEmbedNetDP):
def __setattr__(self, name: str, value: Any) -> None:
if name in {"econf_tebd"}:
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ cu12 = [
]
jax = [
'jax>=0.4.33;python_version>="3.10"',
'flax>=0.8.0;python_version>="3.10"',
]

[tool.deepmd_build_backend.scripts]
Expand Down

0 comments on commit bac980e

Please sign in to comment.