Set a fixed stage in the evaluation loops (Lightning-AI#17094)
carmocca authored Mar 27, 2023
1 parent 260c1bd commit 8b1baf2
Showing 15 changed files with 52 additions and 102 deletions.
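In short: the evaluation loops now receive a fixed `TrainerFn` and `RunningStage` when they are constructed, instead of reading (and asserting on) `trainer.state` while running, so `setup_data()`, the profiler hooks, and the tests no longer have to mutate `trainer.state.fn` / `trainer.state.stage` first. A minimal sanity check of the new wiring is sketched below; it assumes a Lightning build that already contains this commit and pokes at private internals (`_trainer_fn`, `_stage`, `_data_source`) taken from the diff, so it is illustrative only and not part of the change itself.

```python
# Illustrative sketch, not part of the commit: attribute names are private
# internals taken from the diff below and may change between releases.
from lightning.pytorch import Trainer
from lightning.pytorch.trainer.states import RunningStage, TrainerFn

trainer = Trainer(logger=False, enable_checkpointing=False)

# The fn/stage are fixed when the loops are built in `Trainer.__init__` ...
assert trainer.validate_loop._trainer_fn is TrainerFn.VALIDATING
assert trainer.validate_loop._stage is RunningStage.VALIDATING
assert trainer.test_loop._stage is RunningStage.TESTING

# ... so the dataloader source name is derived from the stage at construction,
# instead of being assigned later by the data connector.
assert trainer.validate_loop._data_source.name == "val_dataloader"
assert trainer.test_loop._data_source.name == "test_dataloader"
```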
41 changes: 19 additions & 22 deletions src/lightning/pytorch/loops/evaluation_loop.py
@@ -37,7 +37,7 @@
     _resolve_overfit_batches,
 )
 from lightning.pytorch.trainer.connectors.logger_connector.result import _OUT_DICT, _ResultCollection
-from lightning.pytorch.trainer.states import TrainerFn
+from lightning.pytorch.trainer.states import RunningStage, TrainerFn
 from lightning.pytorch.utilities.combined_loader import _Sequential, CombinedLoader
 from lightning.pytorch.utilities.data import has_len_all_ranks
 from lightning.pytorch.utilities.exceptions import SIGTERMException
@@ -51,7 +51,14 @@
 class _EvaluationLoop(_Loop):
     """Top-level loop where validation/testing starts."""
 
-    def __init__(self, trainer: "pl.Trainer", verbose: bool = True, inference_mode: bool = True) -> None:
+    def __init__(
+        self,
+        trainer: "pl.Trainer",
+        trainer_fn: TrainerFn,
+        stage: RunningStage,
+        verbose: bool = True,
+        inference_mode: bool = True,
+    ) -> None:
         super().__init__(trainer)
         self.verbose = verbose
         self.inference_mode = inference_mode
@@ -61,7 +68,9 @@ def __init__(self, trainer: "pl.Trainer", verbose: bool = True, inference_mode:
         self._results = _ResultCollection(training=False)
         self._logged_outputs: List[_OUT_DICT] = []
         self._has_run: bool = False
-        self._data_source = _DataLoaderSource(None, "")
+        self._trainer_fn = trainer_fn
+        self._stage = stage
+        self._data_source = _DataLoaderSource(None, f"{stage.dataloader_prefix}_dataloader")
         self._combined_loader: Optional[CombinedLoader] = None
         self._data_fetcher: Optional[_DataFetcher] = None
         self._seen_batches_per_dataloader: DefaultDict[int, int] = defaultdict(int)
@@ -123,9 +132,7 @@ def run(self) -> List[_OUT_DICT]:
 
     def setup_data(self) -> None:
         trainer = self.trainer
-        trainer_fn = trainer.state.fn
-        assert trainer_fn is not None
-
+        trainer_fn = self._trainer_fn
         if self._combined_loader is not None and trainer_fn == "fit" and not self._should_reload_val_dl:
             return
 
@@ -143,9 +150,7 @@ def setup_data(self) -> None:
         ):
             self._last_val_dl_reload_epoch = trainer.current_epoch
 
-        stage = trainer.state.stage
-        assert stage is not None
-
+        stage = self._stage
         source = self._data_source
         dataloaders = _request_dataloader(source)
         trainer.strategy.barrier(f"{stage.dataloader_prefix}_dataloader()")
@@ -166,7 +171,7 @@ def setup_data(self) -> None:
         self._max_batches = []
         for dl in combined_loader.flattened:
             _check_dataloader_iterable(dl, source, trainer_fn)
-            dl = _process_dataloader(trainer, dl)
+            dl = _process_dataloader(trainer, trainer_fn, stage, dl)
             dataloaders.append(dl)
 
             # determine number of batches
@@ -261,8 +266,7 @@ def on_run_end(self) -> List[_OUT_DICT]:
         self._on_evaluation_model_train()
 
         if self.verbose and self.trainer.is_global_zero:
-            assert self.trainer.state.stage is not None
-            self._print_results(logged_outputs, self.trainer.state.stage)
+            self._print_results(logged_outputs, self._stage)
 
         return logged_outputs
 
@@ -333,18 +337,12 @@ def _store_dataloader_outputs(self) -> None:
         self._logged_outputs.append(trainer._logger_connector.update_eval_epoch_metrics())
 
     def _on_before_fetch(self) -> None:
-        stage = self.trainer.state.stage
-        assert stage is not None
-        stage = stage.dataloader_prefix
-        self.trainer.profiler.start(f"[{type(self).__name__}].{stage}_next")
+        self.trainer.profiler.start(f"[{type(self).__name__}].{self._stage.dataloader_prefix}_next")
 
     def _on_after_fetch(self) -> None:
-        stage = self.trainer.state.stage
-        assert stage is not None
-        stage = stage.dataloader_prefix
         # the dataloader_idx cannot be easily included here because it might be different from the index used on
         # profiler start, since the `__next__` call might use a different iterator
-        self.trainer.profiler.stop(f"[{type(self).__name__}].{stage}_next")
+        self.trainer.profiler.stop(f"[{type(self).__name__}].{self._stage.dataloader_prefix}_next")
 
     def _evaluation_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
         """Runs the actual evaluation step together with all the necessary bookkeeping and the hooks tied to it.
@@ -417,11 +415,10 @@ def _verify_dataloader_idx_requirement(self) -> None:
         batch_start_hook = "on_test_batch_start" if trainer.testing else "on_validation_batch_start"
         batch_end_hook = "on_test_batch_end" if trainer.testing else "on_validation_batch_end"
         assert self._combined_loader is not None
-        assert trainer.state.stage is not None
         _verify_dataloader_idx_requirement(
             (step_hook, batch_start_hook, batch_end_hook),
             self._combined_loader._mode == "sequential" and self.num_dataloaders > 1,
-            trainer.state.stage,
+            self._stage,
             trainer.lightning_module,
         )
 
5 changes: 2 additions & 3 deletions src/lightning/pytorch/loops/fit_loop.py
@@ -230,10 +230,11 @@ def setup_data(self) -> None:
             _resolve_overfit_batches(combined_loader, mode=RunningStage.TRAINING)
 
         trainer_fn = TrainerFn.FITTING
+        stage = RunningStage.TRAINING
         dataloaders = []
         for dl in combined_loader.flattened:
             _check_dataloader_iterable(dl, source, trainer_fn)
-            dl = _process_dataloader(trainer, dl)
+            dl = _process_dataloader(trainer, trainer_fn, stage, dl)
             dataloaders.append(dl)
         combined_loader.flattened = dataloaders
         self._combined_loader = combined_loader
@@ -247,7 +248,6 @@ def setup_data(self) -> None:
         if self.max_batches == 0:
             return
 
-        stage = RunningStage.TRAINING
         self.max_batches = _parse_num_batches(stage, self.max_batches, trainer.limit_train_batches)
 
         # store epoch of dataloader reset for reload_dataloaders_every_n_epochs
@@ -303,7 +303,6 @@ def on_run_start(self) -> None:
 
         # reload the evaluation dataloaders too for proper display in the progress bar
         if self.epoch_loop._should_check_val_epoch() and trainer.val_dataloaders is None:
-            # TODO(carmocca): avoid having to set validating
             trainer.validating = True
             self.epoch_loop.val_loop.setup_data()
             trainer.training = True
5 changes: 2 additions & 3 deletions src/lightning/pytorch/loops/prediction_loop.py
@@ -142,7 +142,7 @@ def setup_data(self) -> None:
         self.max_batches = []
         for dl in combined_loader.flattened:
             _check_dataloader_iterable(dl, source, trainer_fn)
-            dl = _process_dataloader(trainer, dl)
+            dl = _process_dataloader(trainer, trainer_fn, stage, dl)
             dataloaders.append(dl)
 
             # determine number of batches
@@ -336,10 +336,9 @@ def _on_predict_end(self) -> None:
     def _verify_dataloader_idx_requirement(self) -> None:
         trainer = self.trainer
         assert self._combined_loader is not None
-        assert trainer.state.stage is not None
         _verify_dataloader_idx_requirement(
             ("predict_step", "on_predict_batch_start", "on_predict_batch_end"),
             self._combined_loader._mode == "sequential" and self.num_dataloaders > 1,
-            trainer.state.stage,
+            RunningStage.PREDICTING,
             trainer.lightning_module,
         )
6 changes: 5 additions & 1 deletion src/lightning/pytorch/loops/training_epoch_loop.py
@@ -25,6 +25,7 @@
 from lightning.pytorch.loops.utilities import _is_max_limit_reached
 from lightning.pytorch.trainer import call
 from lightning.pytorch.trainer.connectors.logger_connector.result import _ResultCollection
+from lightning.pytorch.trainer.states import RunningStage, TrainerFn
 from lightning.pytorch.utilities.exceptions import MisconfigurationException, SIGTERMException
 from lightning.pytorch.utilities.rank_zero import rank_zero_warn, WarningCache
 from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature
@@ -69,7 +70,9 @@ def __init__(self, trainer: "pl.Trainer", min_steps: Optional[int] = None, max_s
         self.automatic_optimization = _AutomaticOptimization(trainer)
         self.manual_optimization = _ManualOptimization(trainer)
 
-        self.val_loop = loops._EvaluationLoop(trainer, verbose=False, inference_mode=False)
+        self.val_loop = loops._EvaluationLoop(
+            trainer, TrainerFn.FITTING, RunningStage.VALIDATING, verbose=False, inference_mode=False
+        )
 
         self._results = _ResultCollection(training=True)
         self._warning_cache = WarningCache()
@@ -244,6 +247,7 @@ def on_advance_end(self) -> None:
         # -----------------------------------------
         should_check_val = self._should_check_val_fx()
         if should_check_val:
+            # this needs to be set so the correct `trainer._active_loop` is picked
             self.trainer.validating = True
             self.val_loop.run()
             self.trainer.training = True
15 changes: 3 additions & 12 deletions src/lightning/pytorch/trainer/connectors/data_connector.py
@@ -141,11 +141,8 @@ def attach_dataloaders(
         trainer.fit_loop.epoch_loop.val_loop._data_source.instance = (
             val_dataloaders if val_dataloaders is not None else model
         )
-        trainer.fit_loop.epoch_loop.val_loop._data_source.name = "val_dataloader"
         trainer.validate_loop._data_source.instance = val_dataloaders if val_dataloaders is not None else model
-        trainer.validate_loop._data_source.name = "val_dataloader"
         trainer.test_loop._data_source.instance = test_dataloaders if test_dataloaders is not None else model
-        trainer.test_loop._data_source.name = "test_dataloader"
         trainer.predict_loop._data_source.instance = predict_dataloaders if predict_dataloaders is not None else model
 
     def attach_datamodule(
@@ -160,11 +157,8 @@ def attach_datamodule(
         trainer = self.trainer
         trainer.fit_loop._data_source.instance = datamodule
         trainer.fit_loop.epoch_loop.val_loop._data_source.instance = datamodule
-        trainer.fit_loop.epoch_loop.val_loop._data_source.name = "val_dataloader"
         trainer.validate_loop._data_source.instance = datamodule
-        trainer.validate_loop._data_source.name = "val_dataloader"
         trainer.test_loop._data_source.instance = datamodule
-        trainer.test_loop._data_source.name = "test_dataloader"
         trainer.predict_loop._data_source.instance = datamodule
 
         trainer.datamodule = datamodule
@@ -465,12 +459,9 @@ def _parse_num_batches(
     return num_batches
 
 
-def _process_dataloader(trainer: "pl.Trainer", dataloader: object) -> object:
-    trainer_fn = trainer.state.fn
-    stage = trainer.state.stage
-    if trainer_fn is None or stage is None:
-        raise RuntimeError("Unexpected state")
-
+def _process_dataloader(
+    trainer: "pl.Trainer", trainer_fn: TrainerFn, stage: RunningStage, dataloader: object
+) -> object:
     if stage != RunningStage.TRAINING:
         is_shuffled = _is_dataloader_shuffled(dataloader)
         # limit this warning only for samplers assigned automatically when shuffle is set
10 changes: 4 additions & 6 deletions src/lightning/pytorch/trainer/trainer.py
@@ -410,8 +410,10 @@ def __init__(
         # init loops
         self.fit_loop = _FitLoop(self, min_epochs=min_epochs, max_epochs=max_epochs)
         self.fit_loop.epoch_loop = _TrainingEpochLoop(self, min_steps=min_steps, max_steps=max_steps)
-        self.validate_loop = _EvaluationLoop(self, inference_mode=inference_mode)
-        self.test_loop = _EvaluationLoop(self, inference_mode=inference_mode)
+        self.validate_loop = _EvaluationLoop(
+            self, TrainerFn.VALIDATING, RunningStage.VALIDATING, inference_mode=inference_mode
+        )
+        self.test_loop = _EvaluationLoop(self, TrainerFn.TESTING, RunningStage.TESTING, inference_mode=inference_mode)
         self.predict_loop = _PredictionLoop(self, inference_mode=inference_mode)
 
         self.accumulate_grad_batches = accumulate_grad_batches
@@ -1548,11 +1550,7 @@ def configure_optimizers(self):
 
         if self.train_dataloader is None:
             rank_zero_info("Loading `train_dataloader` to estimate number of stepping batches.")
-            state = self.state
-            self.state.fn = TrainerFn.FITTING
-            self.training = True
             self.fit_loop.setup_data()
-            self.state = state
 
         total_batches = self.num_training_batches
 
1 change: 0 additions & 1 deletion tests/tests_pytorch/strategies/test_ddp_strategy.py
@@ -248,7 +248,6 @@ def test_ddp_strategy_set_timeout(mock_init_process_group):
         strategy=ddp_strategy,
     )
     # test wrap the model if fitting
-    trainer.state.fn = TrainerFn.FITTING
     trainer.strategy.connect(model)
     trainer.lightning_module.trainer = trainer
     trainer.strategy.setup_environment()
8 changes: 0 additions & 8 deletions tests/tests_pytorch/strategies/test_single_device_strategy.py
@@ -119,20 +119,12 @@ def test_process_dataloader_gets_called_as_expected(keyword, value, monkeypatch)
     monkeypatch.setattr(strategy, "process_dataloader", process_dataloader_mock)
 
     if "train" in keyword:
-        trainer.state.fn = "fit"
-        trainer.training = True
         fn = trainer.fit_loop.setup_data
     elif "val" in keyword:
-        trainer.state.fn = "validate"
-        trainer.validating = True
         fn = trainer.validate_loop.setup_data
     elif "test" in keyword:
-        trainer.state.fn = "test"
-        trainer.testing = True
         fn = trainer.test_loop.setup_data
     else:
-        trainer.state.fn = "predict"
-        trainer.predicting = True
         fn = trainer.predict_loop.setup_data
 
     trainer._data_connector.attach_dataloaders(model, **{keyword: value})
8 changes: 0 additions & 8 deletions tests/tests_pytorch/trainer/connectors/test_data_connector.py
@@ -370,8 +370,6 @@ def test_error_raised_with_float_limited_eval_batches():
     trainer = Trainer(limit_val_batches=limit_val_batches)
     trainer.strategy.connect(model)
     trainer._data_connector.attach_data(model)
-    trainer.state.fn = TrainerFn.VALIDATING
-    trainer.state.stage = RunningStage.VALIDATING
     with pytest.raises(
         MisconfigurationException,
         match=rf"{limit_val_batches} \* {dl_size} < 1. Please increase the `limit_val_batches`",
@@ -408,8 +406,6 @@ def test_non_sequential_sampler_warning_is_raised_for_eval_dataloader(val_dl, wa
     trainer.strategy.connect(model)
     trainer._data_connector.attach_data(model, val_dataloaders=val_dl)
     context = pytest.warns if warns else no_warning_call
-    trainer.state.fn = TrainerFn.VALIDATING
-    trainer.state.stage = RunningStage.VALIDATING
     with context(PossibleUserWarning, match="recommended .* turn shuffling off for val/test"):
         trainer.validate_loop.setup_data()
 
@@ -542,12 +538,10 @@ def test_eval_distributed_sampler_warning(devices, warn_context):
     trainer._data_connector.attach_data(model)
 
     trainer.state.fn = TrainerFn.VALIDATING
-    trainer.state.stage = RunningStage.VALIDATING
     with warn_context(PossibleUserWarning, match="multi-device settings use `DistributedSampler`"):
         trainer.validate_loop.setup_data()
 
     trainer.state.fn = TrainerFn.TESTING
-    trainer.state.stage = RunningStage.TESTING
     with warn_context(PossibleUserWarning, match="multi-device settings use `DistributedSampler`"):
         trainer.test_loop.setup_data()
 
@@ -564,8 +558,6 @@ def val_dataloader(self):
     model = CustomModel()
     trainer.strategy.connect(model)
     trainer._data_connector.attach_data(model)
-    trainer.state.fn = TrainerFn.FITTING
-    trainer.state.stage = RunningStage.VALIDATING
     trainer.fit_loop.epoch_loop.val_loop.setup_data()
     assert trainer.val_dataloaders.sampler.shuffle == shuffle
 
25 changes: 13 additions & 12 deletions tests/tests_pytorch/trainer/flags/test_limit_batches.py
@@ -17,7 +17,7 @@
 
 from lightning.pytorch import Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel
-from lightning.pytorch.trainer.states import RunningStage
+from lightning.pytorch.trainer.states import TrainerFn
 
 
 def test_num_dataloader_batches(tmpdir):
@@ -46,15 +46,15 @@ def test_num_dataloader_batches(tmpdir):
 
 
 @pytest.mark.parametrize(
-    ["stage", "mode"],
+    "mode",
     [
-        (RunningStage.VALIDATING, "val"),
-        (RunningStage.TESTING, "test"),
-        (RunningStage.PREDICTING, "predict"),
+        "val",
+        "test",
+        "predict",
     ],
 )
 @pytest.mark.parametrize("limit_batches", [0.1, 10])
-def test_eval_limit_batches(stage, mode, limit_batches):
+def test_eval_limit_batches(mode, limit_batches):
     limit_eval_batches = f"limit_{mode}_batches"
     dl_hook = f"{mode}_dataloader"
     model = BoringModel()
@@ -65,16 +65,17 @@ def test_eval_limit_batches(stage, mode, limit_batches):
     trainer.strategy.connect(model)
     trainer._data_connector.attach_dataloaders(model)
 
-    trainer.state.stage = stage
-    trainer.state.fn = stage.value
-    trainer._active_loop.setup_data()
-    if stage == RunningStage.VALIDATING:
+    if mode == "val":
+        trainer.validate_loop.setup_data()
+        trainer.state.fn = TrainerFn.VALIDATING
         loader_num_batches = trainer.num_val_batches
         dataloaders = trainer.val_dataloaders
-    elif stage == RunningStage.TESTING:
+    elif mode == "test":
+        trainer.test_loop.setup_data()
         loader_num_batches = trainer.num_test_batches
         dataloaders = trainer.test_dataloaders
-    elif stage == RunningStage.PREDICTING:
+    elif mode == "predict":
+        trainer.predict_loop.setup_data()
         loader_num_batches = trainer.num_predict_batches
         dataloaders = trainer.predict_dataloaders
 