diff --git a/deepmd/jax/atomic_model/base_atomic_model.py b/deepmd/jax/atomic_model/base_atomic_model.py index 90920879c2..ffd58daf5e 100644 --- a/deepmd/jax/atomic_model/base_atomic_model.py +++ b/deepmd/jax/atomic_model/base_atomic_model.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from deepmd.jax.common import ( + ArrayAPIVariable, to_jax_array, ) from deepmd.jax.utils.exclude_mask import ( @@ -11,6 +12,8 @@ def base_atomic_model_set_attr(name, value): if name in {"out_bias", "out_std"}: value = to_jax_array(value) + if value is not None: + value = ArrayAPIVariable(value) elif name == "pair_excl" and value is not None: value = PairExcludeMask(value.ntypes, value.exclude_types) elif name == "atom_excl" and value is not None: diff --git a/deepmd/jax/common.py b/deepmd/jax/common.py index 9c144a41d1..f372e97eb5 100644 --- a/deepmd/jax/common.py +++ b/deepmd/jax/common.py @@ -81,3 +81,17 @@ def __setattr__(self, name: str, value: Any) -> None: return super().__setattr__(name, value) return FlaxModule + + +class ArrayAPIVariable(nnx.Variable): + def __array__(self, *args, **kwargs): + return self.value.__array__(*args, **kwargs) + + def __array_namespace__(self, *args, **kwargs): + return self.value.__array_namespace__(*args, **kwargs) + + def __dlpack__(self, *args, **kwargs): + return self.value.__dlpack__(*args, **kwargs) + + def __dlpack_device__(self, *args, **kwargs): + return self.value.__dlpack_device__(*args, **kwargs) diff --git a/deepmd/jax/descriptor/dpa1.py b/deepmd/jax/descriptor/dpa1.py index 0528e4bb93..fef9bd5448 100644 --- a/deepmd/jax/descriptor/dpa1.py +++ b/deepmd/jax/descriptor/dpa1.py @@ -13,6 +13,7 @@ NeighborGatedAttentionLayer as NeighborGatedAttentionLayerDP, ) from deepmd.jax.common import ( + ArrayAPIVariable, flax_module, to_jax_array, ) @@ -65,6 +66,8 @@ class DescrptBlockSeAtten(DescrptBlockSeAttenDP): def __setattr__(self, name: str, value: Any) -> None: if name in {"mean", "stddev"}: value = to_jax_array(value) + if value is not None: + value = ArrayAPIVariable(value) elif name in {"embeddings", "embeddings_strip"}: if value is not None: value = NetworkCollection.deserialize(value.serialize()) diff --git a/deepmd/jax/descriptor/se_e2_a.py b/deepmd/jax/descriptor/se_e2_a.py index d1a6e9a8d9..31c147ad9d 100644 --- a/deepmd/jax/descriptor/se_e2_a.py +++ b/deepmd/jax/descriptor/se_e2_a.py @@ -5,6 +5,7 @@ from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeAArrayAPI as DescrptSeADP from deepmd.jax.common import ( + ArrayAPIVariable, flax_module, to_jax_array, ) @@ -26,6 +27,8 @@ class DescrptSeA(DescrptSeADP): def __setattr__(self, name: str, value: Any) -> None: if name in {"dstd", "davg"}: value = to_jax_array(value) + if value is not None: + value = ArrayAPIVariable(value) elif name in {"embeddings"}: if value is not None: value = NetworkCollection.deserialize(value.serialize()) diff --git a/deepmd/jax/fitting/fitting.py b/deepmd/jax/fitting/fitting.py index f979db4d41..cef1f667b3 100644 --- a/deepmd/jax/fitting/fitting.py +++ b/deepmd/jax/fitting/fitting.py @@ -6,6 +6,7 @@ from deepmd.dpmodel.fitting.dos_fitting import DOSFittingNet as DOSFittingNetDP from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP from deepmd.jax.common import ( + ArrayAPIVariable, flax_module, to_jax_array, ) @@ -29,6 +30,8 @@ def setattr_for_general_fitting(name: str, value: Any) -> Any: "aparam_inv_std", }: value = to_jax_array(value) + if value is not None: + value = ArrayAPIVariable(value) elif name == "emask": value = AtomExcludeMask(value.ntypes, value.exclude_types) elif name == "nets": diff --git a/deepmd/jax/utils/exclude_mask.py b/deepmd/jax/utils/exclude_mask.py index a6cf210f94..18d13d9400 100644 --- a/deepmd/jax/utils/exclude_mask.py +++ b/deepmd/jax/utils/exclude_mask.py @@ -6,6 +6,7 @@ from deepmd.dpmodel.utils.exclude_mask import AtomExcludeMask as AtomExcludeMaskDP from deepmd.dpmodel.utils.exclude_mask import PairExcludeMask as PairExcludeMaskDP from deepmd.jax.common import ( + ArrayAPIVariable, flax_module, to_jax_array, ) @@ -16,6 +17,8 @@ class AtomExcludeMask(AtomExcludeMaskDP): def __setattr__(self, name: str, value: Any) -> None: if name in {"type_mask"}: value = to_jax_array(value) + if value is not None: + value = ArrayAPIVariable(value) return super().__setattr__(name, value) @@ -24,4 +27,6 @@ class PairExcludeMask(PairExcludeMaskDP): def __setattr__(self, name: str, value: Any) -> None: if name in {"type_mask"}: value = to_jax_array(value) + if value is not None: + value = ArrayAPIVariable(value) return super().__setattr__(name, value) diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py index a07dc5e2df..43070f8a07 100644 --- a/deepmd/jax/utils/serialization.py +++ b/deepmd/jax/utils/serialization.py @@ -13,9 +13,6 @@ BaseModel, get_model, ) -from deepmd.jax.utils.network import ( - ArrayAPIParam, -) def deserialize_to_file(model_file: str, data: dict) -> None: @@ -31,14 +28,14 @@ def deserialize_to_file(model_file: str, data: dict) -> None: if model_file.endswith(".jax"): model = BaseModel.deserialize(data["model"]) model_def_script = data["model_def_script"] - state = nnx.state(model, ArrayAPIParam) + _, state = nnx.split(model) with ocp.Checkpointer( ocp.CompositeCheckpointHandler("state", "model_def_script") ) as checkpointer: checkpointer.save( Path(model_file).absolute(), ocp.args.Composite( - state=ocp.args.StandardSave(state), + state=ocp.args.StandardSave(state.to_pure_dict()), model_def_script=ocp.args.JsonSave(model_def_script), ), ) @@ -71,9 +68,22 @@ def serialize_from_file(model_file: str) -> dict: ), ) state = data.state + + # convert str "1" to int 1 key + def convert_str_to_int_key(item: dict): + for key, value in item.copy().items(): + if isinstance(value, dict): + convert_str_to_int_key(value) + if key.isdigit(): + item[int(key)] = item.pop(key) + + convert_str_to_int_key(state) + model_def_script = data.model_def_script - model = get_model(model_def_script) - nnx.update(model, state) + abstract_model = get_model(model_def_script) + graphdef, abstract_state = nnx.split(abstract_model) + abstract_state.replace_by_pure_dict(state) + model = nnx.merge(graphdef, abstract_state) model_dict = model.serialize() data = { "backend": "JAX", diff --git a/deepmd/jax/utils/type_embed.py b/deepmd/jax/utils/type_embed.py index 3143460244..30cd9f45a9 100644 --- a/deepmd/jax/utils/type_embed.py +++ b/deepmd/jax/utils/type_embed.py @@ -5,6 +5,7 @@ from deepmd.dpmodel.utils.type_embed import TypeEmbedNet as TypeEmbedNetDP from deepmd.jax.common import ( + ArrayAPIVariable, flax_module, to_jax_array, ) @@ -18,6 +19,8 @@ class TypeEmbedNet(TypeEmbedNetDP): def __setattr__(self, name: str, value: Any) -> None: if name in {"econf_tebd"}: value = to_jax_array(value) + if value is not None: + value = ArrayAPIVariable(value) if name in {"embedding_net"}: value = EmbeddingNet.deserialize(value.serialize()) return super().__setattr__(name, value) diff --git a/pyproject.toml b/pyproject.toml index 2e95d70614..6bc1065ac8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -137,7 +137,7 @@ cu12 = [ ] jax = [ 'jax>=0.4.33;python_version>="3.10"', - 'flax>=0.8.0;python_version>="3.10"', + 'flax>=0.10.0;python_version>="3.10"', 'orbax-checkpoint;python_version>="3.10"', # The pinning of ml_dtypes may conflict with TF # 'jax-ai-stack;python_version>="3.10"',