Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Nit fixes and improvements to Trainer #933

Merged
merged 1 commit into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 78 additions & 66 deletions src/fairseq2/recipes/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Sequence
from collections.abc import Iterable, Sequence
from contextlib import AbstractContextManager, nullcontext
from itertools import count
from pathlib import Path
Expand All @@ -29,8 +29,9 @@
from fairseq2.checkpoint import CheckpointManager, CheckpointNotFoundError
from fairseq2.datasets import DataReader
from fairseq2.early_stopper import EarlyStopper
from fairseq2.error import ContractError, InternalError, InvalidOperationError
from fairseq2.gang import FakeGang, Gang, broadcast_flag
from fairseq2.logging import get_log_writer
from fairseq2.logging import log
from fairseq2.metrics import (
JsonFileMetricRecorder,
LogMetricRecorder,
Expand All @@ -57,9 +58,6 @@
from fairseq2.utils.rng import RngBag
from fairseq2.utils.state import FSDPOptimizerStateHandler, StatefulObjectBag

log = get_log_writer(__name__)


BatchT = TypeVar("BatchT")

BatchT_contra = TypeVar("BatchT_contra", contravariant=True)
Expand Down Expand Up @@ -207,6 +205,7 @@ def __init__(
keep_best_n_checkpoints: int | None = None,
keep_last_n_models: int | None = None,
keep_best_n_models: int | None = None,
metric_recorders: Iterable[MetricRecorder] | None = None,
tb_dir: Path | None = None,
metrics_dir: Path | None = None,
wandb_options: tuple[Path, str, str] | None = None,
Expand Down Expand Up @@ -287,13 +286,14 @@ def __init__(
The number of best checkpoint models to keep based on their
validation score. Must be greater than or equal to
``keep_best_n_checkpoints``.
:param metric_recorders:
The metric recorders.
:param tb_dir:
The TensorBoard log directory to dump metrics.
Legacy. Use ``metric_recorders``.
:param metrics_dir:
The directory to dump metrics.
Legacy. Use ``metric_recorders``.
:param wandb_options:
The directory, project name, and run name for Weights & Biases metric
logging.
Legacy. Use ``metric_recorders``.
:param publish_metrics_after_n_steps:
The number of steps after which to start publishing metrics.
:param publish_metrics_every_n_steps:
Expand Down Expand Up @@ -367,15 +367,17 @@ def __init__(

self.register_stateful("_step_nr", 0)

if max_num_steps == 0:
raise ValueError("`max_num_steps` must be greater than zero.")
if max_num_steps is not None:
if max_num_steps <= 0:
raise ValueError("`max_num_steps` must be greater than zero.")

self._max_num_steps = max_num_steps

self.register_stateful("_data_epoch_nr", 1)

if max_num_data_epochs == 0:
raise ValueError("`max_num_data_epochs` must be greater than zero.")
if max_num_data_epochs is not None:
if max_num_data_epochs <= 0:
raise ValueError("`max_num_data_epochs` must be greater than zero.")

self._max_num_data_epochs = max_num_data_epochs

Expand Down Expand Up @@ -429,48 +431,54 @@ def __init__(
"`valid_units` and `valid_data_readers` must be both specified."
)

if validate_every_n_steps == 0:
raise ValueError("`validate_every_n_steps` must be greater than zero.")
if validate_every_n_steps is not None:
if validate_every_n_steps <= 0:
raise ValueError("`validate_every_n_steps` must be greater than zero.")

self._validate_after_n_steps = validate_after_n_steps
self._validate_every_n_steps = validate_every_n_steps

if validate_every_n_data_epochs == 0:
raise ValueError(
"`validate_every_n_data_epochs` must be greater than zero."
)
if validate_every_n_data_epochs is not None:
if validate_every_n_data_epochs <= 0:
raise ValueError(
"`validate_every_n_data_epochs` must be greater than zero."
)

self._validate_after_n_data_epochs = validate_after_n_data_epochs
self._validate_every_n_data_epochs = validate_every_n_data_epochs

self._checkpoint_manager = checkpoint_manager

if checkpoint_every_n_steps == 0:
raise ValueError("`checkpoint_every_n_steps` must be greater than zero.")
if checkpoint_every_n_steps is not None:
if checkpoint_every_n_steps <= 0:
raise ValueError(
"`checkpoint_every_n_steps` must be greater than zero."
)

self._checkpoint_after_n_steps = checkpoint_after_n_steps
self._checkpoint_every_n_steps = checkpoint_every_n_steps

if checkpoint_every_n_data_epochs == 0:
raise ValueError(
"`checkpoint_every_n_data_epochs` must be greater than zero."
)
if checkpoint_every_n_data_epochs is not None:
if checkpoint_every_n_data_epochs <= 0:
raise ValueError(
"`checkpoint_every_n_data_epochs` must be greater than zero."
)

self._checkpoint_after_n_data_epochs = checkpoint_after_n_data_epochs
self._checkpoint_every_n_data_epochs = checkpoint_every_n_data_epochs

if keep_last_n_checkpoints is not None and keep_best_n_checkpoints is not None:
raise ValueError(
"`keep_last_n_checkpoints` and `keep_best_n_checkpoints` are mutually exclusive and must not be specified at the same time."
)

if keep_last_n_checkpoints == 0:
raise ValueError("`keep_last_n_checkpoints` must be greater than zero.")
if keep_last_n_checkpoints is not None:
if keep_best_n_checkpoints is not None:
raise ValueError(
"`keep_last_n_checkpoints` and `keep_best_n_checkpoints` are mutually exclusive and must not be specified at the same time."
)

if keep_best_n_checkpoints == 0:
raise ValueError("`keep_best_n_checkpoints` must be greater than zero.")
if keep_last_n_checkpoints <= 0:
raise ValueError("`keep_last_n_checkpoints` must be greater than zero.")
elif keep_best_n_checkpoints is not None:
if keep_best_n_checkpoints <= 0:
raise ValueError("`keep_best_n_checkpoints` must be greater than zero.")

if keep_best_n_checkpoints is not None:
if checkpoint_every_n_steps is not None:
if score_metric_name is None:
raise ValueError(
Expand Down Expand Up @@ -521,23 +529,27 @@ def __init__(

self._metric_bag = unit.metric_bag

if root_gang.rank == 0:
self._metric_recorders = [LogMetricRecorder(log)]
if metric_recorders is None:
# compat
if root_gang.rank == 0:
self._metric_recorders = [LogMetricRecorder(log)]

if tb_dir is not None:
self._metric_recorders.append(TensorBoardRecorder(tb_dir))
if tb_dir is not None:
self._metric_recorders.append(TensorBoardRecorder(tb_dir))

if metrics_dir is not None:
self._metric_recorders.append(JsonFileMetricRecorder(metrics_dir))
if metrics_dir is not None:
self._metric_recorders.append(JsonFileMetricRecorder(metrics_dir))

if wandb_options is not None:
wandb_dir, wandb_project, wandb_name = wandb_options
if wandb_options is not None:
wandb_dir, wandb_project, wandb_name = wandb_options

self._metric_recorders.append(
WandbRecorder(wandb_project, wandb_name, wandb_dir)
)
self._metric_recorders.append(
WandbRecorder(wandb_project, wandb_name, wandb_dir)
)
else:
self._metric_recorders = []
else:
self._metric_recorders = []
self._metric_recorders = list(metric_recorders)

if publish_metrics_every_n_steps == 0:
raise ValueError(
Expand All @@ -559,16 +571,20 @@ def __init__(
if profile is not None and tb_dir is None:
log.warning("No TensorBoard log directory provided. Profiling will be disabled.") # fmt: skip

skip_first, active_steps = 1, 0
num_skip_steps, num_record_steps = 1, 0

profile_dir = Path()

enabled = False
else:
skip_first, active_steps = profile
num_skip_steps, num_record_steps = profile

profile_dir = tb_dir

enabled = num_record_steps > 0

self._profiler = Profiler(
skip_first, active_steps, profile_dir, root_gang, enabled=active_steps > 0
num_skip_steps, num_record_steps, profile_dir, root_gang, enabled=enabled
)

self._anomaly_detection = anomaly_detection
Expand All @@ -593,7 +609,7 @@ def request_stop(self) -> None:

def __call__(self) -> None:
if self._run:
raise RuntimeError("The trainer can only be run once.")
raise InvalidOperationError("The trainer can only be run once.")

self._run = True

Expand Down Expand Up @@ -729,8 +745,8 @@ def _run_step(self) -> None:

if num_batch_targets is not None:
if num_batch_targets == 0:
raise RuntimeError(
"The train unit returned zero loss targets. Please file a bug report to the recipe author."
raise ContractError(
"The train unit returned zero loss targets."
)

num_targets += num_batch_targets
Expand Down Expand Up @@ -865,9 +881,7 @@ def _publish_metrics(self) -> None:

if self._root_gang.rank == 0:
if values is None:
raise RuntimeError(
"The synchronized metric values are `None`. Please file a bug report."
)
raise InternalError("`values` is `None`.")

extend_batch_metrics(
values, self._num_effective_batches, self._total_step_time
Expand Down Expand Up @@ -974,9 +988,7 @@ def _publish_validation_metrics(
return None

if values is None:
raise RuntimeError(
"The synchronized validation metric values are `None`. Please file a bug report."
)
raise InternalError("`values` is `None`.")

extend_batch_metrics(values, num_batches, elapsed_time)

Expand Down Expand Up @@ -1019,8 +1031,8 @@ def _compute_valid_score(self, unit_scores: list[float]) -> float | None:

if not unit_scores:
if self._root_gang.rank == 0:
raise RuntimeError(
"None of the validation units returned a score metric value. Please file a bug report to the recipe author."
raise ContractError(
"None of the validation units returned a score metric value."
)

return None
Expand Down Expand Up @@ -1061,9 +1073,7 @@ def _maybe_request_early_stop(self) -> None:

if self._root_gang.rank == 0:
if self._valid_score is None:
raise RuntimeError(
"Early stop requested, but the validation score is `None`. Please file a bug report."
)
raise InternalError("Early stopping, but `_valid_score` is `None`.")

should_stop = self._early_stopper(self._step_nr, self._valid_score)
else:
Expand Down Expand Up @@ -1131,7 +1141,8 @@ def _checkpoint(self) -> None:
nm = self._keep_last_n_models

if nm is not None:
assert nc is not None
if nc is None:
raise InternalError("`_keep_last_n_checkpoints` is `None`")

self._checkpoint_manager.keep_last_n_checkpoints(nm)
self._checkpoint_manager.keep_last_n_checkpoints(nc, preserve_model=True)
Expand All @@ -1142,7 +1153,8 @@ def _checkpoint(self) -> None:
nm = self._keep_best_n_models

if nm is not None:
assert nc is not None
if nc is None:
raise InternalError("`_keep_best_n_checkpoints` is `None`")

self._checkpoint_manager.keep_best_n_checkpoints(
nm, lower_better=self._lower_better
Expand Down
10 changes: 6 additions & 4 deletions src/fairseq2/utils/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,9 @@ def state_type_error(name: str, state: object) -> TypeError:
is_explicit, state_handler = self._is_explicit(name)

if is_explicit:
state = state_dict_.pop(name, None)
if state is None:
try:
state = state_dict_.pop(name)
except KeyError:
missing_stateful_attrs.append(name)

continue
Expand All @@ -191,8 +192,9 @@ def state_type_error(name: str, state: object) -> TypeError:
except (ValueError, TypeError) as ex:
raise state_error(name, obj) from ex
elif isinstance(obj, Stateful) and not self._is_dunder(name):
state = state_dict_.pop(name, None)
if state is None:
try:
state = state_dict_.pop(name)
except KeyError:
missing_stateful_attrs.append(name)

continue
Expand Down
Loading