diff --git a/src/lightning/pytorch/loops/evaluation_loop.py b/src/lightning/pytorch/loops/evaluation_loop.py
index fa78972ee4d0f..32e3472d03473 100644
--- a/src/lightning/pytorch/loops/evaluation_loop.py
+++ b/src/lightning/pytorch/loops/evaluation_loop.py
@@ -37,7 +37,7 @@
     _resolve_overfit_batches,
 )
 from lightning.pytorch.trainer.connectors.logger_connector.result import _OUT_DICT, _ResultCollection
-from lightning.pytorch.trainer.states import TrainerFn
+from lightning.pytorch.trainer.states import RunningStage, TrainerFn
 from lightning.pytorch.utilities.combined_loader import _Sequential, CombinedLoader
 from lightning.pytorch.utilities.data import has_len_all_ranks
 from lightning.pytorch.utilities.exceptions import SIGTERMException
@@ -51,7 +51,14 @@
 class _EvaluationLoop(_Loop):
     """Top-level loop where validation/testing starts."""

-    def __init__(self, trainer: "pl.Trainer", verbose: bool = True, inference_mode: bool = True) -> None:
+    def __init__(
+        self,
+        trainer: "pl.Trainer",
+        trainer_fn: TrainerFn,
+        stage: RunningStage,
+        verbose: bool = True,
+        inference_mode: bool = True,
+    ) -> None:
         super().__init__(trainer)
         self.verbose = verbose
         self.inference_mode = inference_mode
@@ -61,7 +68,9 @@ def __init__(self, trainer: "pl.Trainer", verbose: bool = True, inference_mode:
         self._results = _ResultCollection(training=False)
         self._logged_outputs: List[_OUT_DICT] = []
         self._has_run: bool = False
-        self._data_source = _DataLoaderSource(None, "")
+        self._trainer_fn = trainer_fn
+        self._stage = stage
+        self._data_source = _DataLoaderSource(None, f"{stage.dataloader_prefix}_dataloader")
         self._combined_loader: Optional[CombinedLoader] = None
         self._data_fetcher: Optional[_DataFetcher] = None
         self._seen_batches_per_dataloader: DefaultDict[int, int] = defaultdict(int)
@@ -123,9 +132,7 @@ def run(self) -> List[_OUT_DICT]:
     def setup_data(self) -> None:
         trainer = self.trainer

-        trainer_fn = trainer.state.fn
-        assert trainer_fn is not None
-
+        trainer_fn = self._trainer_fn
         if self._combined_loader is not None and trainer_fn == "fit" and not self._should_reload_val_dl:
             return

@@ -143,9 +150,7 @@ def setup_data(self) -> None:
         ):
             self._last_val_dl_reload_epoch = trainer.current_epoch

-        stage = trainer.state.stage
-        assert stage is not None
-
+        stage = self._stage
         source = self._data_source
         dataloaders = _request_dataloader(source)
         trainer.strategy.barrier(f"{stage.dataloader_prefix}_dataloader()")
@@ -166,7 +171,7 @@ def setup_data(self) -> None:
         self._max_batches = []
         for dl in combined_loader.flattened:
             _check_dataloader_iterable(dl, source, trainer_fn)
-            dl = _process_dataloader(trainer, dl)
+            dl = _process_dataloader(trainer, trainer_fn, stage, dl)
             dataloaders.append(dl)

             # determine number of batches
@@ -261,8 +266,7 @@ def on_run_end(self) -> List[_OUT_DICT]:
         self._on_evaluation_model_train()

         if self.verbose and self.trainer.is_global_zero:
-            assert self.trainer.state.stage is not None
-            self._print_results(logged_outputs, self.trainer.state.stage)
+            self._print_results(logged_outputs, self._stage)

         return logged_outputs

@@ -333,18 +337,12 @@ def _store_dataloader_outputs(self) -> None:
         self._logged_outputs.append(trainer._logger_connector.update_eval_epoch_metrics())

     def _on_before_fetch(self) -> None:
-        stage = self.trainer.state.stage
-        assert stage is not None
-        stage = stage.dataloader_prefix
-        self.trainer.profiler.start(f"[{type(self).__name__}].{stage}_next")
+        self.trainer.profiler.start(f"[{type(self).__name__}].{self._stage.dataloader_prefix}_next")

     def _on_after_fetch(self) -> None:
-        stage = self.trainer.state.stage
-        assert stage is not None
-        stage = stage.dataloader_prefix
         # the dataloader_idx cannot be easily included here because it might be different from the index used on
         # profiler start, since the `__next__` call might use a different iterator
-        self.trainer.profiler.stop(f"[{type(self).__name__}].{stage}_next")
+        self.trainer.profiler.stop(f"[{type(self).__name__}].{self._stage.dataloader_prefix}_next")

     def _evaluation_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
         """Runs the actual evaluation step together with all the necessary bookkeeping and the hooks tied to it.
@@ -417,11 +415,10 @@ def _verify_dataloader_idx_requirement(self) -> None:
         batch_start_hook = "on_test_batch_start" if trainer.testing else "on_validation_batch_start"
         batch_end_hook = "on_test_batch_end" if trainer.testing else "on_validation_batch_end"
         assert self._combined_loader is not None
-        assert trainer.state.stage is not None
         _verify_dataloader_idx_requirement(
             (step_hook, batch_start_hook, batch_end_hook),
             self._combined_loader._mode == "sequential" and self.num_dataloaders > 1,
-            trainer.state.stage,
+            self._stage,
             trainer.lightning_module,
         )

diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py
index 4149a11ad34ed..c317e92840abe 100644
--- a/src/lightning/pytorch/loops/fit_loop.py
+++ b/src/lightning/pytorch/loops/fit_loop.py
@@ -230,10 +230,11 @@ def setup_data(self) -> None:
             _resolve_overfit_batches(combined_loader, mode=RunningStage.TRAINING)

         trainer_fn = TrainerFn.FITTING
+        stage = RunningStage.TRAINING
         dataloaders = []
         for dl in combined_loader.flattened:
             _check_dataloader_iterable(dl, source, trainer_fn)
-            dl = _process_dataloader(trainer, dl)
+            dl = _process_dataloader(trainer, trainer_fn, stage, dl)
             dataloaders.append(dl)
         combined_loader.flattened = dataloaders
         self._combined_loader = combined_loader
@@ -247,7 +248,6 @@ def setup_data(self) -> None:
         if self.max_batches == 0:
             return

-        stage = RunningStage.TRAINING
         self.max_batches = _parse_num_batches(stage, self.max_batches, trainer.limit_train_batches)

         # store epoch of dataloader reset for reload_dataloaders_every_n_epochs
@@ -303,7 +303,6 @@ def on_run_start(self) -> None:

         # reload the evaluation dataloaders too for proper display in the progress bar
         if self.epoch_loop._should_check_val_epoch() and trainer.val_dataloaders is None:
-            # TODO(carmocca): avoid having to set validating
             trainer.validating = True
             self.epoch_loop.val_loop.setup_data()
             trainer.training = True
diff --git a/src/lightning/pytorch/loops/prediction_loop.py b/src/lightning/pytorch/loops/prediction_loop.py
index b6ee280f88296..36ccf564ca1be 100644
--- a/src/lightning/pytorch/loops/prediction_loop.py
+++ b/src/lightning/pytorch/loops/prediction_loop.py
@@ -142,7 +142,7 @@ def setup_data(self) -> None:
         self.max_batches = []
         for dl in combined_loader.flattened:
             _check_dataloader_iterable(dl, source, trainer_fn)
-            dl = _process_dataloader(trainer, dl)
+            dl = _process_dataloader(trainer, trainer_fn, stage, dl)
             dataloaders.append(dl)

             # determine number of batches
@@ -336,10 +336,9 @@ def _on_predict_end(self) -> None:
     def _verify_dataloader_idx_requirement(self) -> None:
         trainer = self.trainer
         assert self._combined_loader is not None
-        assert trainer.state.stage is not None
         _verify_dataloader_idx_requirement(
             ("predict_step", "on_predict_batch_start", "on_predict_batch_end"),
             self._combined_loader._mode == "sequential" and self.num_dataloaders > 1,
-            trainer.state.stage,
+            RunningStage.PREDICTING,
             trainer.lightning_module,
         )
diff --git a/src/lightning/pytorch/loops/training_epoch_loop.py b/src/lightning/pytorch/loops/training_epoch_loop.py
index fe95beb636ef4..208faf8f16d05 100644
--- a/src/lightning/pytorch/loops/training_epoch_loop.py
+++ b/src/lightning/pytorch/loops/training_epoch_loop.py
@@ -25,6 +25,7 @@
 from lightning.pytorch.loops.utilities import _is_max_limit_reached
 from lightning.pytorch.trainer import call
 from lightning.pytorch.trainer.connectors.logger_connector.result import _ResultCollection
+from lightning.pytorch.trainer.states import RunningStage, TrainerFn
 from lightning.pytorch.utilities.exceptions import MisconfigurationException, SIGTERMException
 from lightning.pytorch.utilities.rank_zero import rank_zero_warn, WarningCache
 from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature
@@ -69,7 +70,9 @@ def __init__(self, trainer: "pl.Trainer", min_steps: Optional[int] = None, max_s
         self.automatic_optimization = _AutomaticOptimization(trainer)
         self.manual_optimization = _ManualOptimization(trainer)

-        self.val_loop = loops._EvaluationLoop(trainer, verbose=False, inference_mode=False)
+        self.val_loop = loops._EvaluationLoop(
+            trainer, TrainerFn.FITTING, RunningStage.VALIDATING, verbose=False, inference_mode=False
+        )

         self._results = _ResultCollection(training=True)
         self._warning_cache = WarningCache()
@@ -244,6 +247,7 @@ def on_advance_end(self) -> None:
         # -----------------------------------------
         should_check_val = self._should_check_val_fx()
         if should_check_val:
+            # this needs to be set so the correct `trainer._active_loop` is picked
             self.trainer.validating = True
             self.val_loop.run()
             self.trainer.training = True
diff --git a/src/lightning/pytorch/trainer/connectors/data_connector.py b/src/lightning/pytorch/trainer/connectors/data_connector.py
index e1f6628509dd8..bd0238f463e2f 100644
--- a/src/lightning/pytorch/trainer/connectors/data_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/data_connector.py
@@ -141,11 +141,8 @@ def attach_dataloaders(
         trainer.fit_loop.epoch_loop.val_loop._data_source.instance = (
             val_dataloaders if val_dataloaders is not None else model
         )
-        trainer.fit_loop.epoch_loop.val_loop._data_source.name = "val_dataloader"
         trainer.validate_loop._data_source.instance = val_dataloaders if val_dataloaders is not None else model
-        trainer.validate_loop._data_source.name = "val_dataloader"
         trainer.test_loop._data_source.instance = test_dataloaders if test_dataloaders is not None else model
-        trainer.test_loop._data_source.name = "test_dataloader"
         trainer.predict_loop._data_source.instance = predict_dataloaders if predict_dataloaders is not None else model

     def attach_datamodule(
@@ -160,11 +157,8 @@ def attach_datamodule(
         trainer = self.trainer
         trainer.fit_loop._data_source.instance = datamodule
         trainer.fit_loop.epoch_loop.val_loop._data_source.instance = datamodule
-        trainer.fit_loop.epoch_loop.val_loop._data_source.name = "val_dataloader"
         trainer.validate_loop._data_source.instance = datamodule
-        trainer.validate_loop._data_source.name = "val_dataloader"
         trainer.test_loop._data_source.instance = datamodule
-        trainer.test_loop._data_source.name = "test_dataloader"
         trainer.predict_loop._data_source.instance = datamodule
         trainer.datamodule = datamodule
@@ -465,12 +459,9 @@ def _parse_num_batches(
     return num_batches


-def _process_dataloader(trainer: "pl.Trainer", dataloader: object) -> object:
-    trainer_fn = trainer.state.fn
-    stage = trainer.state.stage
-    if trainer_fn is None or stage is None:
-        raise RuntimeError("Unexpected state")
-
+def _process_dataloader(
+    trainer: "pl.Trainer", trainer_fn: TrainerFn, stage: RunningStage, dataloader: object
+) -> object:
     if stage != RunningStage.TRAINING:
         is_shuffled = _is_dataloader_shuffled(dataloader)
         # limit this warning only for samplers assigned automatically when shuffle is set
diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py
index abd1f563f019f..5a1ab0cffd142 100644
--- a/src/lightning/pytorch/trainer/trainer.py
+++ b/src/lightning/pytorch/trainer/trainer.py
@@ -410,8 +410,10 @@ def __init__(
         # init loops
         self.fit_loop = _FitLoop(self, min_epochs=min_epochs, max_epochs=max_epochs)
         self.fit_loop.epoch_loop = _TrainingEpochLoop(self, min_steps=min_steps, max_steps=max_steps)
-        self.validate_loop = _EvaluationLoop(self, inference_mode=inference_mode)
-        self.test_loop = _EvaluationLoop(self, inference_mode=inference_mode)
+        self.validate_loop = _EvaluationLoop(
+            self, TrainerFn.VALIDATING, RunningStage.VALIDATING, inference_mode=inference_mode
+        )
+        self.test_loop = _EvaluationLoop(self, TrainerFn.TESTING, RunningStage.TESTING, inference_mode=inference_mode)
         self.predict_loop = _PredictionLoop(self, inference_mode=inference_mode)

         self.accumulate_grad_batches = accumulate_grad_batches
@@ -1548,11 +1550,7 @@ def configure_optimizers(self):

         if self.train_dataloader is None:
             rank_zero_info("Loading `train_dataloader` to estimate number of stepping batches.")
-            state = self.state
-            self.state.fn = TrainerFn.FITTING
-            self.training = True
             self.fit_loop.setup_data()
-            self.state = state

         total_batches = self.num_training_batches
diff --git a/tests/tests_pytorch/strategies/test_ddp_strategy.py b/tests/tests_pytorch/strategies/test_ddp_strategy.py
index 7586c5ef29c4d..ca2e3b3ce6e86 100644
--- a/tests/tests_pytorch/strategies/test_ddp_strategy.py
+++ b/tests/tests_pytorch/strategies/test_ddp_strategy.py
@@ -248,7 +248,6 @@ def test_ddp_strategy_set_timeout(mock_init_process_group):
         strategy=ddp_strategy,
     )
     # test wrap the model if fitting
-    trainer.state.fn = TrainerFn.FITTING
     trainer.strategy.connect(model)
     trainer.lightning_module.trainer = trainer
     trainer.strategy.setup_environment()
diff --git a/tests/tests_pytorch/strategies/test_single_device_strategy.py b/tests/tests_pytorch/strategies/test_single_device_strategy.py
index 9759358700e69..0f05b1f160ef2 100644
--- a/tests/tests_pytorch/strategies/test_single_device_strategy.py
+++ b/tests/tests_pytorch/strategies/test_single_device_strategy.py
@@ -119,20 +119,12 @@ def test_process_dataloader_gets_called_as_expected(keyword, value, monkeypatch)
     monkeypatch.setattr(strategy, "process_dataloader", process_dataloader_mock)

     if "train" in keyword:
-        trainer.state.fn = "fit"
-        trainer.training = True
         fn = trainer.fit_loop.setup_data
     elif "val" in keyword:
-        trainer.state.fn = "validate"
-        trainer.validating = True
         fn = trainer.validate_loop.setup_data
     elif "test" in keyword:
-        trainer.state.fn = "test"
-        trainer.testing = True
         fn = trainer.test_loop.setup_data
     else:
-        trainer.state.fn = "predict"
-        trainer.predicting = True
         fn = trainer.predict_loop.setup_data

     trainer._data_connector.attach_dataloaders(model, **{keyword: value})
diff --git a/tests/tests_pytorch/trainer/connectors/test_data_connector.py b/tests/tests_pytorch/trainer/connectors/test_data_connector.py
index 03ad099ec451f..3a76bc1bec56d 100644
--- a/tests/tests_pytorch/trainer/connectors/test_data_connector.py
+++ b/tests/tests_pytorch/trainer/connectors/test_data_connector.py
@@ -370,8 +370,6 @@ def test_error_raised_with_float_limited_eval_batches():
     trainer = Trainer(limit_val_batches=limit_val_batches)
     trainer.strategy.connect(model)
     trainer._data_connector.attach_data(model)
-    trainer.state.fn = TrainerFn.VALIDATING
-    trainer.state.stage = RunningStage.VALIDATING
     with pytest.raises(
         MisconfigurationException,
         match=rf"{limit_val_batches} \* {dl_size} < 1. Please increase the `limit_val_batches`",
@@ -408,8 +406,6 @@ def test_non_sequential_sampler_warning_is_raised_for_eval_dataloader(val_dl, wa
     trainer.strategy.connect(model)
     trainer._data_connector.attach_data(model, val_dataloaders=val_dl)
     context = pytest.warns if warns else no_warning_call
-    trainer.state.fn = TrainerFn.VALIDATING
-    trainer.state.stage = RunningStage.VALIDATING
     with context(PossibleUserWarning, match="recommended .* turn shuffling off for val/test"):
         trainer.validate_loop.setup_data()

@@ -542,12 +538,10 @@ def test_eval_distributed_sampler_warning(devices, warn_context):
     trainer._data_connector.attach_data(model)

     trainer.state.fn = TrainerFn.VALIDATING
-    trainer.state.stage = RunningStage.VALIDATING
     with warn_context(PossibleUserWarning, match="multi-device settings use `DistributedSampler`"):
         trainer.validate_loop.setup_data()

     trainer.state.fn = TrainerFn.TESTING
-    trainer.state.stage = RunningStage.TESTING
     with warn_context(PossibleUserWarning, match="multi-device settings use `DistributedSampler`"):
         trainer.test_loop.setup_data()

@@ -564,8 +558,6 @@ def val_dataloader(self):
     model = CustomModel()
     trainer.strategy.connect(model)
     trainer._data_connector.attach_data(model)
-    trainer.state.fn = TrainerFn.FITTING
-    trainer.state.stage = RunningStage.VALIDATING
     trainer.fit_loop.epoch_loop.val_loop.setup_data()
     assert trainer.val_dataloaders.sampler.shuffle == shuffle
diff --git a/tests/tests_pytorch/trainer/flags/test_limit_batches.py b/tests/tests_pytorch/trainer/flags/test_limit_batches.py
index 0ae0aac644203..481664fc0c2cc 100644
--- a/tests/tests_pytorch/trainer/flags/test_limit_batches.py
+++ b/tests/tests_pytorch/trainer/flags/test_limit_batches.py
@@ -17,7 +17,7 @@

 from lightning.pytorch import Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel
-from lightning.pytorch.trainer.states import RunningStage
+from lightning.pytorch.trainer.states import TrainerFn


 def test_num_dataloader_batches(tmpdir):
@@ -46,15 +46,15 @@ def test_num_dataloader_batches(tmpdir):


 @pytest.mark.parametrize(
-    ["stage", "mode"],
+    "mode",
     [
-        (RunningStage.VALIDATING, "val"),
-        (RunningStage.TESTING, "test"),
-        (RunningStage.PREDICTING, "predict"),
+        "val",
+        "test",
+        "predict",
     ],
 )
 @pytest.mark.parametrize("limit_batches", [0.1, 10])
-def test_eval_limit_batches(stage, mode, limit_batches):
+def test_eval_limit_batches(mode, limit_batches):
     limit_eval_batches = f"limit_{mode}_batches"
     dl_hook = f"{mode}_dataloader"
     model = BoringModel()
@@ -65,16 +65,17 @@ def test_eval_limit_batches(mode, limit_batches):
     trainer.strategy.connect(model)
     trainer._data_connector.attach_dataloaders(model)

-    trainer.state.stage = stage
-    trainer.state.fn = stage.value
-    trainer._active_loop.setup_data()
-    if stage == RunningStage.VALIDATING:
+    if mode == "val":
+        trainer.validate_loop.setup_data()
+        trainer.state.fn = TrainerFn.VALIDATING
         loader_num_batches = trainer.num_val_batches
         dataloaders = trainer.val_dataloaders
-    elif stage == RunningStage.TESTING:
+    elif mode == "test":
+        trainer.test_loop.setup_data()
         loader_num_batches = trainer.num_test_batches
         dataloaders = trainer.test_dataloaders
-    elif stage == RunningStage.PREDICTING:
+    elif mode == "predict":
+        trainer.predict_loop.setup_data()
         loader_num_batches = trainer.num_predict_batches
         dataloaders = trainer.predict_dataloaders

diff --git a/tests/tests_pytorch/trainer/flags/test_overfit_batches.py b/tests/tests_pytorch/trainer/flags/test_overfit_batches.py
index efbb37a9cb150..a26e6b4ec9ce7 100644
--- a/tests/tests_pytorch/trainer/flags/test_overfit_batches.py
+++ b/tests/tests_pytorch/trainer/flags/test_overfit_batches.py
@@ -20,7 +20,7 @@

 from lightning.pytorch import Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
-from lightning.pytorch.trainer.states import RunningStage, TrainerFn
+from lightning.pytorch.trainer.states import RunningStage
 from tests_pytorch.helpers.datamodules import ClassifDataModule
 from tests_pytorch.helpers.datasets import SklearnDataset
 from tests_pytorch.helpers.runif import RunIf
@@ -94,20 +94,19 @@ def test_overfit_batch_limits_eval(stage, mode, overfit_batches):
     trainer.strategy.connect(model)
     trainer._data_connector.attach_datamodule(model, datamodule=dm)

-    trainer.state.stage = stage
-    trainer.state.fn = stage.value
-    trainer._active_loop.setup_data()
-
     if stage == RunningStage.VALIDATING:
+        trainer.fit_loop.epoch_loop.val_loop.setup_data()
         assert (
             trainer.num_val_batches[0] == overfit_batches
             if isinstance(overfit_batches, int)
             else len(dm.val_dataloader()) * overfit_batches
         )
     elif stage == RunningStage.TESTING:
+        trainer.test_loop.setup_data()
         assert trainer.num_test_batches[0] == len(eval_loader)
         assert isinstance(trainer.test_dataloaders.sampler, SequentialSampler)
     elif stage == RunningStage.PREDICTING:
+        trainer.predict_loop.setup_data()
         assert trainer.num_predict_batches[0] == len(eval_loader)
         assert isinstance(trainer.predict_dataloaders.sampler, SequentialSampler)

@@ -144,8 +143,6 @@ def train_dataloader(self):
     model.trainer = trainer
     trainer.strategy.connect(model)
     trainer._data_connector.attach_dataloaders(model=model)
-    trainer.state.fn = TrainerFn.FITTING
-    trainer.training = True
     trainer.fit_loop.setup_data()
     expected_batches = (
         int(overfit_batches * full_train_samples) if isinstance(overfit_batches, float) else overfit_batches
     )
@@ -170,8 +167,6 @@ def test_distributed_sampler_with_overfit_batches():
     model.trainer = trainer
     trainer.strategy.connect(model)
     trainer._data_connector.attach_dataloaders(model)
-    trainer.state.fn = TrainerFn.FITTING
-    trainer.training = True
     trainer.fit_loop.setup_data()
     train_sampler = trainer.train_dataloader.sampler
     assert isinstance(train_sampler, DistributedSampler)
diff --git a/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py b/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py
index 045cc38c18132..0826e045a93a5 100644
--- a/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py
+++ b/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py
@@ -23,7 +23,6 @@

 from lightning.pytorch import Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomIterableDataset
 from lightning.pytorch.strategies.ipu import IPUStrategy
-from lightning.pytorch.trainer.states import TrainerFn
 from tests_pytorch.conftest import mock_cuda_count
 from tests_pytorch.helpers.runif import RunIf
@@ -47,8 +46,6 @@ def test_num_stepping_batches_raises_info_with_no_dataloaders_loaded(caplog):
     trainer.strategy.connect(model)

     # artificially setup the data
-    trainer.state.fn = TrainerFn.FITTING
-    trainer.training = True
     trainer.fit_loop.setup_data()

     with caplog.at_level(logging.INFO):
diff --git a/tests/tests_pytorch/trainer/test_dataloaders.py b/tests/tests_pytorch/trainer/test_dataloaders.py
index e0016ebf11477..ea3eb8a3c2847 100644
--- a/tests/tests_pytorch/trainer/test_dataloaders.py
+++ b/tests/tests_pytorch/trainer/test_dataloaders.py
@@ -526,14 +526,11 @@ def test_warning_on_zero_len_dataloader():
     val_dataloader = DataLoader(RandomDataset(32, 0))
     trainer._data_connector.attach_data(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

-    trainer.state.fn = "fit"
-    trainer.training = True
     with pytest.warns(UserWarning, match="Total length of `CombinedLoader` across ranks is zero"):
         trainer.fit_loop.setup_data()
     assert trainer.num_training_batches == 0

     trainer.state.fn = "validate"
-    trainer.validating = True
     with pytest.warns(UserWarning, match="Total length of `DataLoader` across ranks is zero"):
         trainer.validate_loop.setup_data()
     assert trainer.num_val_batches == [0]
diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py
index 31e0938921d5f..2e079dc672fce 100644
--- a/tests/tests_pytorch/trainer/test_trainer.py
+++ b/tests/tests_pytorch/trainer/test_trainer.py
@@ -1961,16 +1961,12 @@ def test_dataloaders_are_not_loaded_if_disabled_through_limit_batches(running_st
     trainer.state.stage = running_stage

     if running_stage == "train":
-        trainer.state.fn = "fit"
         fn = trainer.fit_loop.setup_data
     elif running_stage == "validate":
-        trainer.state.fn = "validate"
         fn = trainer.validate_loop.setup_data
     elif running_stage == "test":
-        trainer.state.fn = "test"
         fn = trainer.test_loop.setup_data
     else:
-        trainer.state.fn = "predict"
         fn = trainer.predict_loop.setup_data

     # with no limit, the attribute is None
diff --git a/tests/tests_pytorch/utilities/test_combined_loader.py b/tests/tests_pytorch/utilities/test_combined_loader.py
index e2297456eeb6f..26f7e5359b882 100644
--- a/tests/tests_pytorch/utilities/test_combined_loader.py
+++ b/tests/tests_pytorch/utilities/test_combined_loader.py
@@ -26,7 +26,6 @@

 from lightning.pytorch import Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
-from lightning.pytorch.trainer.states import RunningStage
 from lightning.pytorch.utilities.combined_loader import (
     _LITERAL_SUPPORTED_MODES,
     _MaxSize,
@@ -401,8 +400,6 @@ def __init__(self, data_source, name) -> None:
     trainer = Trainer(use_distributed_sampler=use_distributed_sampler, strategy="ddp", accelerator="cpu", devices=2)
     trainer.strategy.connect(model)
     trainer._data_connector.attach_data(model, train_dataloaders=combined_loader)
-    trainer.state.fn = "fit"
-    trainer.state.stage = RunningStage.TRAINING
     trainer.fit_loop.setup_data()

     samplers_flattened = tree_flatten(combined_loader.sampler)[0]
@@ -431,8 +428,6 @@ def test_combined_data_loader_with_max_size_cycle_and_ddp(monkeypatch, accelerat
         {"a": DataLoader(RandomDataset(32, 8), batch_size=1), "b": DataLoader(RandomDataset(32, 8), batch_size=1)},
     )
     trainer.strategy.connect(model)
-    trainer.state.fn = "fit"
-    trainer.state.stage = RunningStage.TRAINING
     trainer._data_connector.attach_data(model, train_dataloaders=combined_loader)
     trainer.fit_loop.setup_data()
@@ -526,8 +521,6 @@ def test_combined_dataloader_for_training_with_ddp(use_distributed_sampler, mode
         if use_distributed_sampler
         else expected_length_before_ddp
     )
-    trainer.state.fn = "fit"
-    trainer.state.stage = RunningStage.TRAINING
     trainer.fit_loop.setup_data()
     assert trainer.train_dataloader is not None
     assert isinstance(trainer.fit_loop._combined_loader, CombinedLoader)
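Not part of the patch: a minimal sketch of the call pattern this diff moves to, mirroring the updated `test_eval_limit_batches`. Since each loop now receives its `TrainerFn`/`RunningStage` at construction, `setup_data()` no longer requires pre-setting `trainer.state.fn`/`trainer.state.stage`. `BoringModel` and `limit_val_batches=2` are illustrative choices, not values from the diff.

```python
# Illustrative sketch only, assuming the new constructor signature from this diff.
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.trainer.states import TrainerFn

model = BoringModel()
trainer = Trainer(limit_val_batches=2)  # example limit, chosen only for this sketch
trainer.strategy.connect(model)
trainer._data_connector.attach_dataloaders(model)

# Previously the tests had to do: trainer.state.fn = "validate"; trainer.validating = True
trainer.validate_loop.setup_data()  # the loop was built with TrainerFn.VALIDATING / RunningStage.VALIDATING
trainer.state.fn = TrainerFn.VALIDATING  # still needed so `num_val_batches` reads from validate_loop (as in the test)
assert trainer.num_val_batches == [2]
```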