Skip to content

Commit

Permalink
Avoid inference_mode with torch.compile (Lightning-AI#17215)
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca authored Mar 29, 2023
1 parent 713bb32 commit b97b3ac
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 2 deletions.
2 changes: 2 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Pickling the `LightningModule` no longer pickles the `Trainer` ([#17133](https://github.com/Lightning-AI/lightning/pull/17133))


- Disable `torch.inference_mode` with `torch.compile` in PyTorch 2.0 ([#17215](https://github.com/Lightning-AI/lightning/pull/17215))

### Deprecated

-
Expand Down
5 changes: 4 additions & 1 deletion src/lightning/pytorch/loops/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from torch import Tensor

import lightning.pytorch as pl
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_13
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_0, _TORCH_GREATER_EQUAL_1_13
from lightning.fabric.utilities.warnings import PossibleUserWarning
from lightning.pytorch.accelerators import TPUAccelerator
from lightning.pytorch.callbacks.timer import Timer
Expand Down Expand Up @@ -166,6 +166,9 @@ def _decorator(self: _Loop, *args: Any, **kwargs: Any) -> Any:
elif _TORCH_GREATER_EQUAL_1_13 and isinstance(self.trainer.strategy, FSDPStrategy):
# https://github.com/pytorch/pytorch/issues/95957
context_manager = torch.no_grad
elif _TORCH_EQUAL_2_0 and self.trainer.lightning_module._compiler_ctx is not None:
# avoid: `RuntimeError: Inference tensors do not track version counter` fixed in v2.1
context_manager = torch.no_grad
elif self.inference_mode:
context_manager = torch.inference_mode
else:
Expand Down
1 change: 1 addition & 0 deletions tests/tests_pytorch/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def restore_env_variables():
"KMP_INIT_AT_FORK", # leaked since PyTorch 1.13
"KMP_DUPLICATE_LIB_OK", # leaked since PyTorch 1.13
"CRC32C_SW_MODE", # leaked by tensorboardX
"TRITON_CACHE_DIR", # leaked by torch.compile
# leaked by XLA
"ALLOW_MULTIPLE_LIBTPU_LOAD",
"GRPC_VERBOSITY",
Expand Down
4 changes: 3 additions & 1 deletion tests/tests_pytorch/trainer/flags/test_inference_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import pytest
import torch

from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_0
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.loops import _Loop
Expand Down Expand Up @@ -86,4 +87,5 @@ def run(self):
f.inference_mode = True
with mock.patch("torch.inference_mode") as inference_mode_mock:
f.run()
inference_mode_mock.assert_called_once_with()
if not _TORCH_EQUAL_2_0:
inference_mode_mock.assert_called_once_with()
18 changes: 18 additions & 0 deletions tests/tests_pytorch/utilities/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,21 @@ def training_step(self, batch, batch_idx):
trainer.fit(compiled_model)

assert set(trainer.callback_metrics) == {"loss"}


@pytest.mark.skipif(sys.platform == "darwin", reason="https://github.com/pytorch/pytorch/issues/95708")
@RunIf(min_torch="2.0.0")
def test_trainer_compiled_model_test(tmp_path):
    """Smoke test: ``Trainer.test`` runs end-to-end on a ``torch.compile``-wrapped model."""
    skip_if_unsupported()

    # Compile the boring model; the Trainer should handle the OptimizedModule wrapper.
    compiled_model = torch.compile(BoringModel())

    trainer = Trainer(
        default_root_dir=tmp_path,
        fast_dev_run=True,
        enable_progress_bar=False,
        enable_model_summary=False,
        enable_checkpointing=False,
    )
    trainer.test(compiled_model)

0 comments on commit b97b3ac

Please sign in to comment.