From 723d6ed9888abb885ac5df49b997a54c6fb00c3e Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 29 Sep 2024 06:36:09 -0400 Subject: [PATCH] fix metaclass Signed-off-by: Jinzhe Zeng --- deepmd/jax/common.py | 10 ++++++++-- deepmd/jax/utils/network.py | 3 --- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/deepmd/jax/common.py b/deepmd/jax/common.py index 0af7921c6d..9c144a41d1 100644 --- a/deepmd/jax/common.py +++ b/deepmd/jax/common.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( + Any, Optional, overload, ) @@ -69,9 +70,14 @@ def flax_module( metas.add(type(nnx.Module)) class MixedMetaClass(*metas): - pass + def __call__(self, *args, **kwargs): + return type(nnx.Module).__call__(self, *args, **kwargs) class FlaxModule(module, nnx.Module, metaclass=MixedMetaClass): - pass + def __init_subclass__(cls, **kwargs) -> None: + return super().__init_subclass__(**kwargs) + + def __setattr__(self, name: str, value: Any) -> None: + return super().__setattr__(name, value) return FlaxModule diff --git a/deepmd/jax/utils/network.py b/deepmd/jax/utils/network.py index fc6e168c7b..bbd6419663 100644 --- a/deepmd/jax/utils/network.py +++ b/deepmd/jax/utils/network.py @@ -35,12 +35,10 @@ class NativeNet(make_multilayer_network(NativeLayer, NativeOP)): pass -@flax_module class EmbeddingNet(make_embedding_network(NativeNet, NativeLayer)): pass -@flax_module class FittingNet(make_fitting_network(EmbeddingNet, NativeNet, NativeLayer)): pass @@ -54,6 +52,5 @@ class NetworkCollection(NetworkCollectionDP): } -@flax_module class LayerNorm(LayerNormDP, NativeLayer): pass