Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
jlamypoirier committed Jan 24, 2025
1 parent 3cc977f commit b8f434f
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 5 deletions.
4 changes: 2 additions & 2 deletions fast_llm/engine/checkpoint/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import yaml

from fast_llm.config import Config, Field, FieldHint, FieldUpdate, check_field, config_class
from fast_llm.config import Config, Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none
from fast_llm.engine.config_utils.data_type import DataType
from fast_llm.utils import Assert

Expand Down Expand Up @@ -183,7 +183,7 @@ class CheckpointPathConfigBase(CheckpointConfigBase):
default=None,
desc="Custom timeout for lengthy operations.",
hint=FieldHint.optional,
valid=check_field(Assert.gt, 0),
valid=skip_valid_if_none(check_field(Assert.gt, 0)),
)


Expand Down
2 changes: 1 addition & 1 deletion fast_llm/engine/training/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ class TrainingConfig(Config):
desc="Timeout for lengthy operations such as checkpoint saving and loading,"
" and dataset preparation and sampling.",
hint=FieldHint.feature,
valid=check_field(Assert.gt, 0),
valid=skip_valid_if_none(check_field(Assert.gt, 0)),
)

def _validate(self) -> None:
Expand Down
8 changes: 6 additions & 2 deletions fast_llm/engine/training/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,9 @@ def _save_checkpoint(
}
if metrics is not None:
metadata["metrics"] = {key.value: value for key, value in metrics.items()}
self._multi_stage.save_checkpoint(config.get_save_config(checkpoint_directory), metadata)
self._multi_stage.save_checkpoint(
config.get_save_config(checkpoint_directory, timeout=self._config.training.timeout), metadata
)

# Barrier to ensure everyone is done.
safe_barrier(
Expand Down Expand Up @@ -447,7 +449,9 @@ def _load_checkpoint(self, config: TrainingCheckpointConfig, iteration: int) ->
checkpoint_directory = config.get_save_directory(self._run.experiment_directory) / str(iteration)
Assert.custom(pathlib.Path.is_file, checkpoint_directory / "ok")

metadata = self._multi_stage.load_checkpoint(config.get_load_config(checkpoint_directory))
metadata = self._multi_stage.load_checkpoint(
config.get_load_config(checkpoint_directory, timeout=self._config.training.timeout)
)
self._optimizer.load(metadata["optimizer"])
if "schedules" in metadata:
# Backward compatibility.
Expand Down

0 comments on commit b8f434f

Please sign in to comment.