Never pickle the Trainer with the LightningModule (Lightning-AI#1…
carmocca authored Mar 20, 2023
1 parent c575efb commit d17d4f4
Showing 3 changed files with 4 additions and 17 deletions.
4 changes: 2 additions & 2 deletions src/lightning/pytorch/CHANGELOG.md
@@ -14,7 +14,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

--
+- Pickling the `LightningModule` no longer pickles the `Trainer` ([#17133](https://github.com/Lightning-AI/lightning/pull/17133))


### Deprecated
@@ -29,7 +29,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

-



## [2.0.0] - 2023-03-15
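A minimal sketch (not part of this commit) of the behavior the changelog entry describes: pickling a `LightningModule` now drops its `Trainer` reference. `TinyModule` is a hypothetical module, and the direct `_trainer` assignment stands in for the attachment that `Trainer.fit` performs.

import pickle

import torch
from lightning.pytorch import LightningModule, Trainer


class TinyModule(LightningModule):
    # hypothetical module used only to demonstrate pickling behavior
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)


model = TinyModule()
# stand-in for the attachment Trainer.fit performs on the module
model._trainer = Trainer(logger=False, enable_checkpointing=False)

clone = pickle.loads(pickle.dumps(model))
assert clone._trainer is None      # the Trainer is not pickled along
assert model._trainer is not None  # the live module keeps its Trainer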
3 changes: 1 addition & 2 deletions src/lightning/pytorch/callbacks/stochastic_weight_avg.py
@@ -150,8 +150,7 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: s
            raise MisconfigurationException("SWA does not currently support sharded models.")

        # copy the model before moving it to accelerator device.
-        with pl_module._prevent_trainer_and_dataloaders_deepcopy():
-            self._average_model = deepcopy(pl_module)
+        self._average_model = deepcopy(pl_module)

    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if len(trainer.optimizers) != 1:
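For the SWA callback, the dedicated context manager is no longer needed because `copy.deepcopy` also routes through `__getstate__`, so the averaged copy is created without a `Trainer` reference. A small sketch under the same assumptions as above (hypothetical `TinyModule`, `_trainer` assigned directly instead of by `Trainer.fit`):

from copy import deepcopy

import torch
from lightning.pytorch import LightningModule, Trainer


class TinyModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)


model = TinyModule()
model._trainer = Trainer(logger=False, enable_checkpointing=False)  # as Trainer.fit would

# deepcopy uses the same __getstate__ hook (via __reduce_ex__), so the copy has no Trainer
average_model = deepcopy(model)
assert average_model._trainer is None
assert model._trainer is not None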
14 changes: 1 addition & 13 deletions src/lightning/pytorch/core/module.py
@@ -124,7 +124,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
        self._automatic_optimization: bool = True
        self._param_requires_grad_state: Dict[str, bool] = {}
        self._metric_attributes: Optional[Dict[int, str]] = None
-        self._should_prevent_trainer_and_dataloaders_deepcopy: bool = False
        self._register_sharded_tensor_state_dict_hooks_if_available()
        self._compiler_ctx: Optional[Dict[str, Any]] = None

@@ -1539,20 +1538,9 @@ def load_from_checkpoint(
        )
        return cast(Self, loaded)

-    @contextmanager
-    def _prevent_trainer_and_dataloaders_deepcopy(self) -> Generator[None, None, None]:
-        self._should_prevent_trainer_and_dataloaders_deepcopy = True
-        yield
-        self._should_prevent_trainer_and_dataloaders_deepcopy = False
-
    def __getstate__(self) -> Dict[str, Any]:
        state = dict(self.__dict__)
-        if self._should_prevent_trainer_and_dataloaders_deepcopy:
-            state["_trainer"] = None
-            state.pop("train_dataloader", None)
-            state.pop("val_dataloader", None)
-            state.pop("test_dataloader", None)
-            state.pop("predict_dataloader", None)
+        state["_trainer"] = None
        return state

    def _register_sharded_tensor_state_dict_hooks_if_available(self) -> None:
