Skip to content

Commit

Permalink
Use new ocp checlpointer
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 578540634
  • Loading branch information
Conchylicultor authored and The kauldron Authors committed Nov 1, 2023
1 parent d20e4ae commit 6990659
Show file tree
Hide file tree
Showing 5 changed files with 3 additions and 176 deletions.
1 change: 0 additions & 1 deletion kauldron/checkpoints/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,3 @@
from kauldron.checkpoints.partial_loader import CkptSource
from kauldron.checkpoints.partial_loader import KauldronSource
from kauldron.checkpoints.partial_loader import PartialLoader
from kauldron.checkpoints.pytree_checkpoint import PyTreeCheckpointer
9 changes: 2 additions & 7 deletions kauldron/checkpoints/checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from flax.training import orbax_utils
import jax
from kauldron.checkpoints import partial_loader
from kauldron.checkpoints import pytree_checkpoint
from kauldron.utils import config_util
import orbax.checkpoint as ocp

Expand Down Expand Up @@ -130,7 +129,7 @@ def _ckpt_mgr(self) -> ocp.CheckpointManager:
manager_cls = ocp.CheckpointManager
ckpt_mgr = manager_cls(
epath.Path(self.workdir) / CHECKPOINT_FOLDER_NAME,
pytree_checkpoint.PyTreeCheckpointer(),
ocp.StandardCheckpointer(),
options=mgr_options,
)
return ckpt_mgr
Expand All @@ -141,18 +140,14 @@ def restore(
step: int = -1,
*,
noop_if_missing: bool = False,
restore_kwargs: Optional[dict[str, Any]] = None,
) -> _T:
"""Restore state."""
restore_kwargs = restore_kwargs or {}

state = initial_state
if self._ckpt_mgr.latest_step() is not None:
step = self._absolute_step(step)

state = self._ckpt_mgr.restore(
step, items=initial_state, restore_kwargs=restore_kwargs
)
state = self._ckpt_mgr.restore(step, initial_state)
elif self.partial_initializer is not None: # No checkpoint
if state is None:
raise ValueError(
Expand Down
2 changes: 1 addition & 1 deletion kauldron/checkpoints/partial_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def restore(self, item) -> Any:
step=self.step,
# Use `_NOT_RESTORED` sentinel value as `orbax` will silently
# forward the additional values not present in the checkpoint.
initial_state=jax.tree_map(lambda _: _NOT_RESTORED, item),
initial_state=item,
restore_kwargs=dict(
restore_args=orbax_utils.restore_args_from_target(item),
# Set `transforms={}` to indicate `orbax` to drop the keys not
Expand Down
109 changes: 0 additions & 109 deletions kauldron/checkpoints/pytree_checkpoint.py

This file was deleted.

58 changes: 0 additions & 58 deletions kauldron/checkpoints/pytree_checkpoint_test.py

This file was deleted.

0 comments on commit 6990659

Please sign in to comment.