diff --git a/deepmd/jax/common.py b/deepmd/jax/common.py index 0af7921c6d..9c144a41d1 100644 --- a/deepmd/jax/common.py +++ b/deepmd/jax/common.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( + Any, Optional, overload, ) @@ -69,9 +70,14 @@ def flax_module( metas.add(type(nnx.Module)) class MixedMetaClass(*metas): - pass + def __call__(self, *args, **kwargs): + return type(nnx.Module).__call__(self, *args, **kwargs) class FlaxModule(module, nnx.Module, metaclass=MixedMetaClass): - pass + def __init_subclass__(cls, **kwargs) -> None: + return super().__init_subclass__(**kwargs) + + def __setattr__(self, name: str, value: Any) -> None: + return super().__setattr__(name, value) return FlaxModule diff --git a/deepmd/jax/utils/network.py b/deepmd/jax/utils/network.py index fc6e168c7b..bbd6419663 100644 --- a/deepmd/jax/utils/network.py +++ b/deepmd/jax/utils/network.py @@ -35,12 +35,10 @@ class NativeNet(make_multilayer_network(NativeLayer, NativeOP)): pass -@flax_module class EmbeddingNet(make_embedding_network(NativeNet, NativeLayer)): pass -@flax_module class FittingNet(make_fitting_network(EmbeddingNet, NativeNet, NativeLayer)): pass @@ -54,6 +52,5 @@ class NetworkCollection(NetworkCollectionDP): } -@flax_module class LayerNorm(LayerNormDP, NativeLayer): pass