diff --git a/src/fairseq2/recipes/trainer.py b/src/fairseq2/recipes/trainer.py index ace25adb0..4e8ceec78 100644 --- a/src/fairseq2/recipes/trainer.py +++ b/src/fairseq2/recipes/trainer.py @@ -7,7 +7,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from collections.abc import Sequence +from collections.abc import Iterable, Sequence from contextlib import AbstractContextManager, nullcontext from itertools import count from pathlib import Path @@ -29,8 +29,9 @@ from fairseq2.checkpoint import CheckpointManager, CheckpointNotFoundError from fairseq2.datasets import DataReader from fairseq2.early_stopper import EarlyStopper +from fairseq2.error import ContractError, InternalError, InvalidOperationError from fairseq2.gang import FakeGang, Gang, broadcast_flag -from fairseq2.logging import get_log_writer +from fairseq2.logging import log from fairseq2.metrics import ( JsonFileMetricRecorder, LogMetricRecorder, @@ -57,9 +58,6 @@ from fairseq2.utils.rng import RngBag from fairseq2.utils.state import FSDPOptimizerStateHandler, StatefulObjectBag -log = get_log_writer(__name__) - - BatchT = TypeVar("BatchT") BatchT_contra = TypeVar("BatchT_contra", contravariant=True) @@ -207,6 +205,7 @@ def __init__( keep_best_n_checkpoints: int | None = None, keep_last_n_models: int | None = None, keep_best_n_models: int | None = None, + metric_recorders: Iterable[MetricRecorder] | None = None, tb_dir: Path | None = None, metrics_dir: Path | None = None, wandb_options: tuple[Path, str, str] | None = None, @@ -287,13 +286,14 @@ def __init__( The number of best checkpoint models to keep based on their validation score. Must be greater than or equal to ``keep_best_n_checkpoints``. + :param metric_recorders: + The metric recorders. :param tb_dir: - The TensorBoard log directory to dump metrics. + Legacy. Use ``metric_recorders``. :param metrics_dir: - The directory to dump metrics. + Legacy. Use ``metric_recorders``. :param wandb_options: - The directory, project name, and run name for Weights & Bias metric - logging. + Legacy. Use ``metric_recorders``. :param publish_metrics_after_n_steps: The number of steps after which to start publishing metrics. :param publish_metrics_every_n_steps: @@ -367,15 +367,17 @@ def __init__( self.register_stateful("_step_nr", 0) - if max_num_steps == 0: - raise ValueError("`max_num_steps` must be greater than zero.") + if max_num_steps is not None: + if max_num_steps <= 0: + raise ValueError("`max_num_steps` must be greater than zero.") self._max_num_steps = max_num_steps self.register_stateful("_data_epoch_nr", 1) - if max_num_data_epochs == 0: - raise ValueError("`max_num_data_epochs` must be greater than zero.") + if max_num_data_epochs is not None: + if max_num_data_epochs <= 0: + raise ValueError("`max_num_data_epochs` must be greater than zero.") self._max_num_data_epochs = max_num_data_epochs @@ -429,48 +431,54 @@ def __init__( "`valid_units` and `valid_data_readers` must be both specified." ) - if validate_every_n_steps == 0: - raise ValueError("`validate_every_n_steps` must be greater than zero.") + if validate_every_n_steps is not None: + if validate_every_n_steps <= 0: + raise ValueError("`validate_every_n_steps` must be greater than zero.") self._validate_after_n_steps = validate_after_n_steps self._validate_every_n_steps = validate_every_n_steps - if validate_every_n_data_epochs == 0: - raise ValueError( - "`validate_every_n_data_epochs` must be greater than zero." - ) + if validate_every_n_data_epochs is not None: + if validate_every_n_data_epochs <= 0: + raise ValueError( + "`validate_every_n_data_epochs` must be greater than zero." + ) self._validate_after_n_data_epochs = validate_after_n_data_epochs self._validate_every_n_data_epochs = validate_every_n_data_epochs self._checkpoint_manager = checkpoint_manager - if checkpoint_every_n_steps == 0: - raise ValueError("`checkpoint_every_n_steps` must be greater than zero.") + if checkpoint_every_n_steps is not None: + if checkpoint_every_n_steps <= 0: + raise ValueError( + "`checkpoint_every_n_steps` must be greater than zero." + ) self._checkpoint_after_n_steps = checkpoint_after_n_steps self._checkpoint_every_n_steps = checkpoint_every_n_steps - if checkpoint_every_n_data_epochs == 0: - raise ValueError( - "`checkpoint_every_n_data_epochs` must be greater than zero." - ) + if checkpoint_every_n_data_epochs is not None: + if checkpoint_every_n_data_epochs <= 0: + raise ValueError( + "`checkpoint_every_n_data_epochs` must be greater than zero." + ) self._checkpoint_after_n_data_epochs = checkpoint_after_n_data_epochs self._checkpoint_every_n_data_epochs = checkpoint_every_n_data_epochs - if keep_last_n_checkpoints is not None and keep_best_n_checkpoints is not None: - raise ValueError( - "`keep_last_n_checkpoints` and `keep_best_n_checkpoints` are mutually exclusive and must not be specified at the same time." - ) - - if keep_last_n_checkpoints == 0: - raise ValueError("`keep_last_n_checkpoints` must be greater than zero.") + if keep_last_n_checkpoints is not None: + if keep_best_n_checkpoints is not None: + raise ValueError( + "`keep_last_n_checkpoints` and `keep_best_n_checkpoints` are mutually exclusive and must not be specified at the same time." + ) - if keep_best_n_checkpoints == 0: - raise ValueError("`keep_best_n_checkpoints` must be greater than zero.") + if keep_last_n_checkpoints <= 0: + raise ValueError("`keep_last_n_checkpoints` must be greater than zero.") + elif keep_best_n_checkpoints is not None: + if keep_best_n_checkpoints <= 0: + raise ValueError("`keep_best_n_checkpoints` must be greater than zero.") - if keep_best_n_checkpoints is not None: if checkpoint_every_n_steps is not None: if score_metric_name is None: raise ValueError( @@ -521,23 +529,27 @@ def __init__( self._metric_bag = unit.metric_bag - if root_gang.rank == 0: - self._metric_recorders = [LogMetricRecorder(log)] + if metric_recorders is None: + # compat + if root_gang.rank == 0: + self._metric_recorders = [LogMetricRecorder(log)] - if tb_dir is not None: - self._metric_recorders.append(TensorBoardRecorder(tb_dir)) + if tb_dir is not None: + self._metric_recorders.append(TensorBoardRecorder(tb_dir)) - if metrics_dir is not None: - self._metric_recorders.append(JsonFileMetricRecorder(metrics_dir)) + if metrics_dir is not None: + self._metric_recorders.append(JsonFileMetricRecorder(metrics_dir)) - if wandb_options is not None: - wandb_dir, wandb_project, wandb_name = wandb_options + if wandb_options is not None: + wandb_dir, wandb_project, wandb_name = wandb_options - self._metric_recorders.append( - WandbRecorder(wandb_project, wandb_name, wandb_dir) - ) + self._metric_recorders.append( + WandbRecorder(wandb_project, wandb_name, wandb_dir) + ) + else: + self._metric_recorders = [] else: - self._metric_recorders = [] + self._metric_recorders = list(metric_recorders) if publish_metrics_every_n_steps == 0: raise ValueError( @@ -559,16 +571,20 @@ def __init__( if profile is not None and tb_dir is None: log.warning("No TensorBoard log directory provided. Profiling will be disabled.") # fmt: skip - skip_first, active_steps = 1, 0 + num_skip_steps, num_record_steps = 1, 0 profile_dir = Path() + + enabled = False else: - skip_first, active_steps = profile + num_skip_steps, num_record_steps = profile profile_dir = tb_dir + enabled = num_record_steps > 0 + self._profiler = Profiler( - skip_first, active_steps, profile_dir, root_gang, enabled=active_steps > 0 + num_skip_steps, num_record_steps, profile_dir, root_gang, enabled=enabled ) self._anomaly_detection = anomaly_detection @@ -593,7 +609,7 @@ def request_stop(self) -> None: def __call__(self) -> None: if self._run: - raise RuntimeError("The trainer can only be run once.") + raise InvalidOperationError("The trainer can only be run once.") self._run = True @@ -729,8 +745,8 @@ def _run_step(self) -> None: if num_batch_targets is not None: if num_batch_targets == 0: - raise RuntimeError( - "The train unit returned zero loss targets. Please file a bug report to the recipe author." + raise ContractError( + "The train unit returned zero loss targets." ) num_targets += num_batch_targets @@ -865,9 +881,7 @@ def _publish_metrics(self) -> None: if self._root_gang.rank == 0: if values is None: - raise RuntimeError( - "The synchronized metric values are `None`. Please file a bug report." - ) + raise InternalError("`values` is `None`.") extend_batch_metrics( values, self._num_effective_batches, self._total_step_time @@ -974,9 +988,7 @@ def _publish_validation_metrics( return None if values is None: - raise RuntimeError( - "The synchronized validation metric values are `None`. Please file a bug report." - ) + raise InternalError("`values` is `None`.") extend_batch_metrics(values, num_batches, elapsed_time) @@ -1019,8 +1031,8 @@ def _compute_valid_score(self, unit_scores: list[float]) -> float | None: if not unit_scores: if self._root_gang.rank == 0: - raise RuntimeError( - "None of the validation units returned a score metric value. Please file a bug report to the recipe author." + raise ContractError( + "None of the validation units returned a score metric value." ) return None @@ -1061,9 +1073,7 @@ def _maybe_request_early_stop(self) -> None: if self._root_gang.rank == 0: if self._valid_score is None: - raise RuntimeError( - "Early stop requested, but the validation score is `None`. Please file a bug report." - ) + raise InternalError("Early stopping, but `_valid_score` is `None`.") should_stop = self._early_stopper(self._step_nr, self._valid_score) else: @@ -1131,7 +1141,8 @@ def _checkpoint(self) -> None: nm = self._keep_last_n_models if nm is not None: - assert nc is not None + if nc is None: + raise InternalError("`_keep_last_n_checkpoints` is `None`") self._checkpoint_manager.keep_last_n_checkpoints(nm) self._checkpoint_manager.keep_last_n_checkpoints(nc, preserve_model=True) @@ -1142,7 +1153,8 @@ def _checkpoint(self) -> None: nm = self._keep_best_n_models if nm is not None: - assert nc is not None + if nc is None: + raise InternalError("`_keep_best_n_checkpoints` is `None`") self._checkpoint_manager.keep_best_n_checkpoints( nm, lower_better=self._lower_better diff --git a/src/fairseq2/utils/state.py b/src/fairseq2/utils/state.py index 02bcb07f5..1bc74e12a 100644 --- a/src/fairseq2/utils/state.py +++ b/src/fairseq2/utils/state.py @@ -168,8 +168,9 @@ def state_type_error(name: str, state: object) -> TypeError: is_explicit, state_handler = self._is_explicit(name) if is_explicit: - state = state_dict_.pop(name, None) - if state is None: + try: + state = state_dict_.pop(name) + except KeyError: missing_stateful_attrs.append(name) continue @@ -191,8 +192,9 @@ def state_type_error(name: str, state: object) -> TypeError: except (ValueError, TypeError) as ex: raise state_error(name, obj) from ex elif isinstance(obj, Stateful) and not self._is_dunder(name): - state = state_dict_.pop(name, None) - if state is None: + try: + state = state_dict_.pop(name) + except KeyError: missing_stateful_attrs.append(name) continue