From 69906599ee4f51b79dd775bb7046811dcc94c775 Mon Sep 17 00:00:00 2001 From: Etienne Pot Date: Wed, 1 Nov 2023 09:20:29 -0700 Subject: [PATCH] Use new ocp checlpointer PiperOrigin-RevId: 578540634 --- kauldron/checkpoints/__init__.py | 1 - kauldron/checkpoints/checkpointer.py | 9 +- kauldron/checkpoints/partial_loader.py | 2 +- kauldron/checkpoints/pytree_checkpoint.py | 109 ------------------ .../checkpoints/pytree_checkpoint_test.py | 58 ---------- 5 files changed, 3 insertions(+), 176 deletions(-) delete mode 100644 kauldron/checkpoints/pytree_checkpoint.py delete mode 100644 kauldron/checkpoints/pytree_checkpoint_test.py diff --git a/kauldron/checkpoints/__init__.py b/kauldron/checkpoints/__init__.py index d1dad5cd..35952187 100644 --- a/kauldron/checkpoints/__init__.py +++ b/kauldron/checkpoints/__init__.py @@ -22,4 +22,3 @@ from kauldron.checkpoints.partial_loader import CkptSource from kauldron.checkpoints.partial_loader import KauldronSource from kauldron.checkpoints.partial_loader import PartialLoader -from kauldron.checkpoints.pytree_checkpoint import PyTreeCheckpointer diff --git a/kauldron/checkpoints/checkpointer.py b/kauldron/checkpoints/checkpointer.py index ca61b1b3..5a9cddf8 100644 --- a/kauldron/checkpoints/checkpointer.py +++ b/kauldron/checkpoints/checkpointer.py @@ -26,7 +26,6 @@ from flax.training import orbax_utils import jax from kauldron.checkpoints import partial_loader -from kauldron.checkpoints import pytree_checkpoint from kauldron.utils import config_util import orbax.checkpoint as ocp @@ -130,7 +129,7 @@ def _ckpt_mgr(self) -> ocp.CheckpointManager: manager_cls = ocp.CheckpointManager ckpt_mgr = manager_cls( epath.Path(self.workdir) / CHECKPOINT_FOLDER_NAME, - pytree_checkpoint.PyTreeCheckpointer(), + ocp.StandardCheckpointer(), options=mgr_options, ) return ckpt_mgr @@ -141,18 +140,14 @@ def restore( step: int = -1, *, noop_if_missing: bool = False, - restore_kwargs: Optional[dict[str, Any]] = None, ) -> _T: """Restore state.""" - restore_kwargs = restore_kwargs or {} state = initial_state if self._ckpt_mgr.latest_step() is not None: step = self._absolute_step(step) - state = self._ckpt_mgr.restore( - step, items=initial_state, restore_kwargs=restore_kwargs - ) + state = self._ckpt_mgr.restore(step, initial_state) elif self.partial_initializer is not None: # No checkpoint if state is None: raise ValueError( diff --git a/kauldron/checkpoints/partial_loader.py b/kauldron/checkpoints/partial_loader.py index d1e36447..26ced918 100644 --- a/kauldron/checkpoints/partial_loader.py +++ b/kauldron/checkpoints/partial_loader.py @@ -222,7 +222,7 @@ def restore(self, item) -> Any: step=self.step, # Use `_NOT_RESTORED` sentinel value as `orbax` will silently # forward the additional values not present in the checkpoint. - initial_state=jax.tree_map(lambda _: _NOT_RESTORED, item), + initial_state=item, restore_kwargs=dict( restore_args=orbax_utils.restore_args_from_target(item), # Set `transforms={}` to indicate `orbax` to drop the keys not diff --git a/kauldron/checkpoints/pytree_checkpoint.py b/kauldron/checkpoints/pytree_checkpoint.py deleted file mode 100644 index 65d6df35..00000000 --- a/kauldron/checkpoints/pytree_checkpoint.py +++ /dev/null @@ -1,109 +0,0 @@ -# Copyright 2023 The kauldron Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Checkpoint handler that support arbitrary `PyTree`.""" - -from __future__ import annotations - -from typing import Any, Optional, TypeVar - -from etils import epath -from flax.training import orbax_utils -import jax -import orbax.checkpoint as ocp - -# TODO(epot): Replace by `StandardCheckpointHandler` - -_T = TypeVar('_T') - - -class PyTreeCheckpointer(ocp.Checkpointer): - """Low level checkpointer that support arbitrary pytree. - - ```python - ckpt = kd.train.PyTreeCheckpointer() - - ckpt.save(path, state) - state = ckpt.restore(path) - ``` - """ - - def __init__(self): - super().__init__(PyTreeCheckpointHandler()) - - -class PyTreeCheckpointHandler(ocp.PyTreeCheckpointHandler): - """PyTree checkpoint handler that support arbitrary pytree.""" - - def __init__(self, **kwargs): - # Use OCDBT for faster performance (see - # https://orbax.readthedocs.io/en/latest/optimized_checkpointing.html) - super().__init__( - use_ocdbt=True, - write_tree_metadata=True, - **kwargs, - ) - - def restore( # pytype: disable=signature-mismatch - self, - directory: epath.Path, - item: Optional[_T] = None, - *, - restore_args: Optional[Any] = None, - **kwargs, - ) -> _T: - """Restores the provided item synchronously. - - Args: - directory: the directory to restore from. - item: an item with the same structure as that to be restored. If missing - the tree will be restored from the saved structure. - restore_args: Restore args - **kwargs: additional arguments for restore. - - Returns: - The restored item. - """ - if self.metadata(directory) is None: # Legacy checkpoint - print( - 'Loading DEPRECATED checkpoint (values flattened). Kauldron will ' - 'drop support for those at some point.' - ) - # Untested, not sure this works as the `TrainState.tree_flatten` has - # changed - - target, structure = jax.tree_util.tree_flatten(item) - restore_args = orbax_utils.restore_args_from_target(target) - state = super().restore( - directory=directory, - item=target, - restore_args=restore_args, - ) - return jax.tree_util.tree_unflatten(structure, state) - - if item is not None and restore_args is None: - # Restore args make sure the restored arrays match the `item` ( - # jax vs numpy array,...) - restore_args = orbax_utils.restore_args_from_target(item) - if item is None and restore_args is None: - raise ValueError( - 'Please comment on b/295122555 so orbax team prioritise this bug. ' - 'Currently cannot restore without `item=` or `restore_args=`.' - ) - return super().restore( - directory=directory, - item=item, - restore_args=restore_args, - **kwargs, - ) diff --git a/kauldron/checkpoints/pytree_checkpoint_test.py b/kauldron/checkpoints/pytree_checkpoint_test.py deleted file mode 100644 index 39be4326..00000000 --- a/kauldron/checkpoints/pytree_checkpoint_test.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright 2023 The kauldron Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Test.""" - -import pathlib -from typing import Any - -import chex -import flax -import jax.numpy as jnp -from kauldron import kd -from kauldron.utils import assert_utils -import numpy as np -import orbax.checkpoint as ocp - - -@flax.struct.dataclass -class A: - a: Any - - -def test_checkpoint(tmp_path: pathlib.Path): - obj = A({ - 'a': A(a=np.asarray([1, 2, 3])), - 'b': { - 'c': 42, - 'd': jnp.asarray([4, 5]), - }, - }) - ckpt_mgr = ocp.CheckpointManager( - tmp_path / 'checkpoints', - kd.ckpts.PyTreeCheckpointer(), - ) - - ckpt_mgr.save(0, obj) - - new_obj = ckpt_mgr.restore(0, obj) - assert isinstance(new_obj, A) - chex.assert_trees_all_close(obj, new_obj) - assert_utils.assert_trees_all_same_type(obj, new_obj) - - # TODO(b/295122555): Restore once the orbax bug is fixed - # new_obj = ckpt_mgr.restore(0) - # assert isinstance(new_obj, dict) - # chex.assert_trees_all_close(ocp.utils.serialize_tree(obj), new_obj) - # assert_utils.assert_trees_all_same_type(obj, new_obj)