diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py index aa41e35f69..a07dc5e2df 100644 --- a/deepmd/jax/utils/serialization.py +++ b/deepmd/jax/utils/serialization.py @@ -73,6 +73,7 @@ def serialize_from_file(model_file: str) -> dict: state = data.state model_def_script = data.model_def_script model = get_model(model_def_script) + nnx.update(model, state) model_dict = model.serialize() data = { "backend": "JAX",