diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index d590a28bdf077..2cea008313568 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -14,7 +14,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
--
+- Pickling the `LightningModule` no longer pickles the `Trainer` ([#17133](https://github.com/Lightning-AI/lightning/pull/17133))
 
 ### Deprecated
 
@@ -29,7 +29,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
--
+
 
 
 ## [2.0.0] - 2023-03-15
diff --git a/src/lightning/pytorch/callbacks/stochastic_weight_avg.py b/src/lightning/pytorch/callbacks/stochastic_weight_avg.py
index a4f58efd45dc8..11d78dc49ef26 100644
--- a/src/lightning/pytorch/callbacks/stochastic_weight_avg.py
+++ b/src/lightning/pytorch/callbacks/stochastic_weight_avg.py
@@ -150,8 +150,7 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: s
             raise MisconfigurationException("SWA does not currently support sharded models.")
 
         # copy the model before moving it to accelerator device.
-        with pl_module._prevent_trainer_and_dataloaders_deepcopy():
-            self._average_model = deepcopy(pl_module)
+        self._average_model = deepcopy(pl_module)
 
     def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         if len(trainer.optimizers) != 1:
diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py
index bf596203ab4db..feef9eb87028b 100644
--- a/src/lightning/pytorch/core/module.py
+++ b/src/lightning/pytorch/core/module.py
@@ -124,7 +124,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         self._automatic_optimization: bool = True
         self._param_requires_grad_state: Dict[str, bool] = {}
         self._metric_attributes: Optional[Dict[int, str]] = None
-        self._should_prevent_trainer_and_dataloaders_deepcopy: bool = False
         self._register_sharded_tensor_state_dict_hooks_if_available()
         self._compiler_ctx: Optional[Dict[str, Any]] = None
@@ -1539,20 +1538,9 @@ def load_from_checkpoint(
         )
         return cast(Self, loaded)
 
-    @contextmanager
-    def _prevent_trainer_and_dataloaders_deepcopy(self) -> Generator[None, None, None]:
-        self._should_prevent_trainer_and_dataloaders_deepcopy = True
-        yield
-        self._should_prevent_trainer_and_dataloaders_deepcopy = False
-
     def __getstate__(self) -> Dict[str, Any]:
         state = dict(self.__dict__)
-        if self._should_prevent_trainer_and_dataloaders_deepcopy:
-            state["_trainer"] = None
-            state.pop("train_dataloader", None)
-            state.pop("val_dataloader", None)
-            state.pop("test_dataloader", None)
-            state.pop("predict_dataloader", None)
+        state["_trainer"] = None
         return state
 
     def _register_sharded_tensor_state_dict_hooks_if_available(self) -> None:
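
For downstream users, the visible effect is that a `LightningModule` can now be pickled (or deep-copied) without dragging its `Trainer` along. Below is a minimal sketch of that behavior; it is not part of the PR, `BoringModel` is a hypothetical toy module, and the direct `_trainer` assignment only stands in for the attachment the `Trainer` normally performs during `fit()`.

# Illustrative only: BoringModel and the manual `_trainer` assignment are
# stand-ins for how a module gets attached to a Trainer, not code from this PR.
import pickle

import torch
import lightning.pytorch as pl


class BoringModel(pl.LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(4, 2)


model = BoringModel()
trainer = pl.Trainer()
model._trainer = trainer  # normally set by the Trainer when fit() starts

# __getstate__ now always nulls the trainer reference ...
assert model.__getstate__()["_trainer"] is None

# ... so a pickle round-trip yields a module detached from any Trainer,
# while the in-memory module keeps its reference.
restored = pickle.loads(pickle.dumps(model))
assert restored._trainer is None
assert model._trainer is trainer

Since `copy.deepcopy` also goes through `__getstate__`, the SWA callback no longer needs the `_prevent_trainer_and_dataloaders_deepcopy` context manager when copying the module, which is why the diff can delete it entirely.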