diff --git a/MaxText/checkpointing.py b/MaxText/checkpointing.py index efe4dd5d6..aff173c23 100644 --- a/MaxText/checkpointing.py +++ b/MaxText/checkpointing.py @@ -19,7 +19,7 @@ from typing import Any, Optional, Union from absl import flags from etils import epath -from flax.training import orbax_utils, train_state +from flax.training import train_state import grain.python as grain import jax import max_logging @@ -330,6 +330,5 @@ def save_params_to_path(checkpoint_dir, params): """Save decode params in checkpoint at specified path.""" assert checkpoint_dir, "checkpoint_dir is not defined." orbax_checkpointer = ocp.PyTreeCheckpointer() - save_args = orbax_utils.save_args_from_target({"params": params}) - orbax_checkpointer.save(checkpoint_dir, {"params": params}, save_args=save_args, force=True) + orbax_checkpointer.save(checkpoint_dir, {"params": params}, force=True) print(f"Quantized params checkpoint saved at: {checkpoint_dir}")