callbacks.checkpoints uses old flax.training.checkpoints API #12

Open
JamesAllingham opened this issue Aug 29, 2023 · 1 comment

@JamesAllingham
Contributor

Because the checkpoint callback saves with the old flax.training.checkpoints API, trying to restore its checkpoints using Orbax:

import orbax.checkpoint
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
loaded_state = orbax_checkpointer.restore("save_dir", item=init_state)

results in an error:

FileNotFoundError: File not found: ss_model/checkpoint_-0.19632871448993683/.zarray. In many cases, this results from copying a checkpoint without using the `-a` flag.

A workaround is to use Flax to restore the checkpoint:

from flax.training import checkpoints
checkpoints.restore_checkpoint(ckpt_dir='save_dir', target=init_state)

but in the interest of using the more flexible (and future-proof) Orbax API, it would probably be a good idea to update the checkpointing callback.
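Until the callback is updated, the workaround above could be wrapped in a small helper that tries one restorer and falls back to the other. This is only a sketch: `restore_with_fallback` and its parameters are hypothetical names, and the two restorers are passed in as plain callables so the snippet does not depend on a specific Orbax or Flax version.

```python
from typing import Any, Callable

# A restorer takes a checkpoint path and returns the restored state.
Restorer = Callable[[str], Any]


def restore_with_fallback(path: str, primary: Restorer, fallback: Restorer) -> Any:
    """Try `primary`; if it cannot find its files, use `fallback` instead.

    Hypothetical helper. It only catches FileNotFoundError, the error
    reported above when Orbax is pointed at a Flax-format checkpoint.
    """
    try:
        return primary(path)
    except FileNotFoundError:
        return fallback(path)


# Possible usage with the calls from this issue (not executed here):
# state = restore_with_fallback(
#     "save_dir",
#     primary=lambda p: orbax_checkpointer.restore(p, item=init_state),
#     fallback=lambda p: checkpoints.restore_checkpoint(ckpt_dir=p, target=init_state),
# )
```

Any other exception still propagates, so a genuinely corrupted checkpoint is not silently masked.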

What are your thoughts? To be honest, I am not super familiar with the various checkpointing APIs and the current status quo, so I may have made some incorrect assumptions above.

@cgarciae
Owner

cgarciae commented Aug 29, 2023

This is a great question. I had 2 ideas:

  1. Refactor ciclo.checkpoint to use Orbax directly.
  2. Create an adapter for Orbax checkpointers so you can use them directly inside loop, similar to the existing PeriodicAction adapter, e.g.:

ciclo/ciclo/callbacks.py

Lines 571 to 588 in 1854850

# -------------------------------------------
# Adapters
# -------------------------------------------
if importlib.util.find_spec("clu") is not None:
    from clu.periodic_actions import PeriodicAction

    @dataclass(frozen=True)
    class PeriodicActionCallback(LoopCallbackBase[S]):
        action: PeriodicAction

        def __loop_callback__(self, loop_state: LoopState[S]) -> CallbackOutput[S]:
            self.action(loop_state.elapsed.steps, t=loop_state.elapsed.date)
            return Logs(), loop_state.state

    @functools.partial(register_adapter, cls=PeriodicAction)
    def periodic_action_adapter(f: PeriodicAction):
        return PeriodicActionCallback(f)
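
Following the same pattern, an Orbax adapter (idea 2) might look roughly like the sketch below. Everything here is an assumption for illustration: `CheckpointerLike`, `CheckpointerCallback`, `ckpt_dir`, and `every_n_steps` are hypothetical names, the callback only assumes the wrapped object exposes `save(directory, item)`, and a plain `dict` stands in for ciclo's `Logs()` so the snippet runs without ciclo or Orbax installed.

```python
import os
from dataclasses import dataclass
from typing import Any, Dict, Protocol, Tuple


class CheckpointerLike(Protocol):
    """Anything with a `save(directory, item)` method (assumed interface)."""

    def save(self, directory: str, item: Any) -> None: ...


@dataclass(frozen=True)
class CheckpointerCallback:
    """Hypothetical callback that saves the loop state every N steps."""

    checkpointer: CheckpointerLike
    ckpt_dir: str
    every_n_steps: int = 1

    def __loop_callback__(self, loop_state: Any) -> Tuple[Dict[str, Any], Any]:
        step = loop_state.elapsed.steps
        if step % self.every_n_steps == 0:
            # One sub-directory per saved step, e.g. ckpt_dir/step_100.
            target = os.path.join(self.ckpt_dir, f"step_{step}")
            self.checkpointer.save(target, loop_state.state)
        # Mirrors `return Logs(), loop_state.state` in the adapter above.
        return {}, loop_state.state
```

In ciclo itself this would presumably subclass `LoopCallbackBase[S]` and be registered with `register_adapter` for the Orbax checkpointer class, as is done for `PeriodicAction`.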
