diff --git a/deepmd/jax/utils/network.py b/deepmd/jax/utils/network.py index 2fce6831fe..887ad7147e 100644 --- a/deepmd/jax/utils/network.py +++ b/deepmd/jax/utils/network.py @@ -26,6 +26,9 @@ class ArrayAPIParam(nnx.Param): + def __array__(self, *args, **kwargs): + return self.value.__array__(*args, **kwargs) + def __array_namespace__(self, *args, **kwargs): return self.value.__array_namespace__(*args, **kwargs)