Skip to content

Commit

Permalink
Remove usages of orbax_utils.save_args_from_target, as this function …
Browse files Browse the repository at this point in the history
…does nothing (it used to control a checkpointing behavior that has since been optimized away).

PiperOrigin-RevId: 716314210
  • Loading branch information
cpgaffney1 authored and maxtext authors committed Jan 16, 2025
1 parent a580eb5 commit 3ad02ba
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions MaxText/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from typing import Any, Optional, Union
from absl import flags
from etils import epath
from flax.training import orbax_utils, train_state
from flax.training import train_state
import grain.python as grain
import jax
import max_logging
Expand Down Expand Up @@ -330,6 +330,5 @@ def save_params_to_path(checkpoint_dir, params):
"""Save decode params in checkpoint at specified path."""
assert checkpoint_dir, "checkpoint_dir is not defined."
orbax_checkpointer = ocp.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target({"params": params})
orbax_checkpointer.save(checkpoint_dir, {"params": params}, save_args=save_args, force=True)
orbax_checkpointer.save(checkpoint_dir, {"params": params}, force=True)
print(f"Quantized params checkpoint saved at: {checkpoint_dir}")

0 comments on commit 3ad02ba

Please sign in to comment.