Merge pull request #920 from AI-Hypercomputer:rdyro-upstream-expose-checkpointing-params

PiperOrigin-RevId: 679813127
maxtext authors committed Sep 28, 2024
2 parents 5a8f75b + 6cea505 commit 94e6907
Showing 2 changed files with 20 additions and 9 deletions.
9 changes: 9 additions & 0 deletions MaxText/configs/base.yml
@@ -45,6 +45,15 @@ checkpoint_period: 10_000
enable_single_replica_ckpt_restoring: False

force_unroll: False # during generate_param_only_checkpoint should we unroll the loop?

# checkpointing using orbax has two important parameters: the array driver
# and its underlying storage - the kvstore (preferably ocdbt)
# orbax supports setting a target file size; chunking a single large array
# into smaller physical files (<2GB) can speed up distributed and
# over-the-network loading enormously
checkpoint_storage_target_data_file_size_bytes: 2147483648
checkpoint_storage_use_ocdbt: True
checkpoint_storage_use_zarr3: True
############################### END CHECKPOINTING ##################################
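
For reference, a minimal, self-contained sketch (not MaxText code) of what these knobs control at the Orbax level: chunk_byte_size asks Orbax to split each large array into pieces no bigger than the given size, and ocdbt_target_data_file_size caps the physical files the OCDBT kvstore writes. The checkpoint directory and the state pytree below are made up for illustration, and CheckpointManager defaults vary by orbax-checkpoint version.

import jax
import jax.numpy as jnp
import orbax.checkpoint as ocp

# Stand-in pytree; in MaxText this would be the full train state.
state = {"params": jnp.ones((2048, 2048)), "step": jnp.array(0)}

# 2147483648 bytes == 2 * 1024**3 == 2 GiB, the default above.
chunk_byte_size = 2 * 1024**3

# One SaveArgs per leaf, giving Orbax the maximum chunk size for each array.
save_args = jax.tree.map(lambda _: ocp.SaveArgs(chunk_byte_size=chunk_byte_size), state)

manager = ocp.CheckpointManager("/tmp/ckpt_demo")  # illustrative path
manager.save(
    0,
    args=ocp.args.PyTreeSave(
        item=state,
        save_args=save_args,
        ocdbt_target_data_file_size=chunk_byte_size,
    ),
)
manager.wait_until_finished()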


20 changes: 11 additions & 9 deletions MaxText/train.py
@@ -70,7 +70,7 @@

Transformer = models.Transformer
EPS = 1e-8
_CHUNK_BYTE_SIZE = 2 * 1024**3
_DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE = 2 * 1024**3


def validate_train_config(config):
@@ -210,20 +210,23 @@ def save_checkpoint(
)

# specify chunk_byte_size to force orbax to control maximum file size in checkpoint
save_args = jax.tree.map(lambda _: orbax.checkpoint.SaveArgs(chunk_byte_size=_CHUNK_BYTE_SIZE), state)
chunk_byte_size = _DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE
if config:
chunk_byte_size = config.checkpoint_storage_target_data_file_size_bytes
save_args = jax.tree.map(lambda _: orbax.checkpoint.SaveArgs(chunk_byte_size=chunk_byte_size), state)

if isinstance(checkpoint_manager, emergency_checkpoint_manager.CheckpointManager):
return checkpoint_manager.save(
step,
args=orbax.checkpoint.args.PyTreeSave(item=state, save_args=save_args, ocdbt_target_data_file_size=_CHUNK_BYTE_SIZE),
args=orbax.checkpoint.args.PyTreeSave(item=state, save_args=save_args, ocdbt_target_data_file_size=chunk_byte_size),
)

if dataset_type == "grain":
return checkpoint_manager.save(
step,
args=orbax.checkpoint.args.Composite(
items=orbax.checkpoint.args.PyTreeSave(
item=state, save_args=save_args, ocdbt_target_data_file_size=_CHUNK_BYTE_SIZE
item=state, save_args=save_args, ocdbt_target_data_file_size=chunk_byte_size
),
iter=grain.PyGrainCheckpointSave(data_iterator.local_iterator),
),
@@ -233,7 +236,7 @@ save_checkpoint(
step,
args=orbax.checkpoint.args.Composite(
items=orbax.checkpoint.args.PyTreeSave(
item=state, save_args=save_args, ocdbt_target_data_file_size=_CHUNK_BYTE_SIZE
item=state, save_args=save_args, ocdbt_target_data_file_size=chunk_byte_size
)
),
)
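
For completeness, restoring a checkpoint written with these settings goes through the ordinary PyTree restore arguments; the chunking and file layout are transparent to the reader. A hedged sketch, reusing the illustrative directory and pytree shape from the save sketch above:

import jax.numpy as jnp
import orbax.checkpoint as ocp

manager = ocp.CheckpointManager("/tmp/ckpt_demo")  # same illustrative directory
# Template pytree describing the structure and shapes to restore into.
template = {"params": jnp.ones((2048, 2048)), "step": jnp.array(0)}
restored = manager.restore(0, args=ocp.args.PyTreeRestore(item=template))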
@@ -486,12 +489,11 @@ def setup_mesh_and_model(config):
logger,
)
else:
use_ocdbt = True
use_zarr3 = True
# TODO(b/368121306): Remove this once zarr3 support is plumbed on the backend
use_ocdbt = config.checkpoint_storage_use_ocdbt
use_zarr3 = config.checkpoint_storage_use_zarr3
if config.enable_single_controller:
use_ocdbt = False
use_zarr3 = False
use_ocdbt, use_zarr3 = False, False
checkpoint_manager = checkpointing.create_orbax_checkpoint_manager(
config.checkpoint_dir,
config.enable_checkpointing,
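
The use_ocdbt and use_zarr3 flags select the kvstore and the array storage format used by the PyTree checkpoint handler. A rough, illustrative sketch of that wiring, with a hypothetical make_manager helper standing in for MaxText's checkpointing.create_orbax_checkpoint_manager (constructor details differ across orbax-checkpoint versions):

import orbax.checkpoint as ocp

def make_manager(directory, use_ocdbt=True, use_zarr3=True):
  # Hypothetical helper, not MaxText's implementation.
  # use_ocdbt selects the OCDBT kvstore; use_zarr3 selects the zarr3 array driver.
  handler = ocp.PyTreeCheckpointHandler(use_ocdbt=use_ocdbt, use_zarr3=use_zarr3)
  return ocp.CheckpointManager(
      directory,
      item_names=("items",),
      item_handlers={"items": handler},
  )

Saves through such a manager use Composite args keyed by item name (items=PyTreeSave(...)), matching the shape of the calls in save_checkpoint above; with config.enable_single_controller both flags fall back to False, per the TODO in the diff.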
