Skip to content

Commit

Permalink
fix metaclass
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Sep 29, 2024
1 parent bac980e commit 723d6ed
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
10 changes: 8 additions & 2 deletions deepmd/jax/common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
Optional,
overload,
)
Expand Down Expand Up @@ -69,9 +70,14 @@ def flax_module(
metas.add(type(nnx.Module))

class MixedMetaClass(*metas):
pass
def __call__(self, *args, **kwargs):
return type(nnx.Module).__call__(self, *args, **kwargs)

class FlaxModule(module, nnx.Module, metaclass=MixedMetaClass):
pass
def __init_subclass__(cls, **kwargs) -> None:
return super().__init_subclass__(**kwargs)

def __setattr__(self, name: str, value: Any) -> None:
return super().__setattr__(name, value)

return FlaxModule
3 changes: 0 additions & 3 deletions deepmd/jax/utils/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,10 @@ class NativeNet(make_multilayer_network(NativeLayer, NativeOP)):
pass


@flax_module
class EmbeddingNet(make_embedding_network(NativeNet, NativeLayer)):
pass


@flax_module
class FittingNet(make_fitting_network(EmbeddingNet, NativeNet, NativeLayer)):
pass

Expand All @@ -54,6 +52,5 @@ class NetworkCollection(NetworkCollectionDP):
}


@flax_module
class LayerNorm(LayerNormDP, NativeLayer):
pass

0 comments on commit 723d6ed

Please sign in to comment.