[RFC] Separate CPU offload activation to its own wrapper (pytorch#85459)
Passing `offload_to_cpu=True` to `checkpoint_wrapper` is a bit confusing, because it causes the activation checkpointing args to be ignored and CPU offloading to be done instead. This isn't ideal from an API design perspective, so this proposes making `offload_wrapper` its own concept.

CPU offload and activation checkpointing can now be composed, for example:

```
# apply activation checkpointing to transformer layers
apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=checkpoint_wrapper,
    check_fn=lambda mod: isinstance(mod, TransformerLayer),
)
# offload the remaining activations to CPU
model = offload_wrapper(model)
```

Will polish / add tests if this proposal sounds good.
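For background, CPU offload of activations in PyTorch autograd is built on saved-tensor hooks. Below is a minimal sketch of that mechanism using the public `torch.autograd.graph.save_on_cpu` context manager; it illustrates the idea, not necessarily the exact implementation behind `offload_wrapper`, and the module and shapes are placeholders (assumes a CUDA device is available).

```
import torch
import torch.nn as nn

# Illustrative module; not part of this PR.
model = nn.Linear(10, 10).cuda()
inp = torch.randn(3, 10, device="cuda", requires_grad=True)

# save_on_cpu packs tensors saved for backward onto the CPU and unpacks
# them back to their original device when backward needs them.
with torch.autograd.graph.save_on_cpu(pin_memory=True):
    out = model(inp)

out.sum().backward()  # saved activations are copied back to the GPU here
```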

Differential Revision: [D39719854](https://our.internmc.facebook.com/intern/diff/D39719854/)
Pull Request resolved: pytorch#85459
Approved by: https://github.com/awgu
rohan-varma authored and pytorchmergebot committed Oct 15, 2022
1 parent 100113b commit bdefa26
Showing 3 changed files with 151 additions and 98 deletions.
25 changes: 17 additions & 8 deletions test/distributed/fsdp/test_checkpoint_wrapper.py
@@ -7,8 +7,10 @@
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper,
offload_wrapper,
apply_activation_checkpointing,
CheckpointWrapper,
OffloadWrapper,
CheckpointImpl
)

@@ -21,6 +23,9 @@

import unittest

_SAVED_PREFIX = '_saved_'
GRAD_FN_NEXT_FUNCTIONS = 'next_functions'

class CheckpointWrapperTest(TestCase):
def setUp(self):
super().setUp()
@@ -72,11 +77,14 @@ def forward(self, a, b, c=None, d=None, **kwargs):
for wrapper in [
partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.REENTRANT),
partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT),
partial(checkpoint_wrapper, offload_to_cpu=True),
offload_wrapper,
]:
with self.subTest(wrapper=wrapper):
model = wrapper(MyModel())
self.assertTrue(isinstance(model, CheckpointWrapper))
if wrapper == offload_wrapper:
self.assertTrue(isinstance(model, OffloadWrapper))
else:
self.assertTrue(isinstance(model, CheckpointWrapper))
# Verify kwargs can be passed in
inp = torch.ones(4, 10, requires_grad=True)
out = model(inp, inp, c=inp, d=inp, e=inp, f=inp)
@@ -211,6 +219,7 @@ def check_fn(l):
for wrapper in [
partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.REENTRANT),
partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT),
offload_wrapper,
]:
model = MyModel()
if n_linear is None:
@@ -223,12 +232,12 @@ def check_fn(l):
model, checkpoint_wrapper_fn=wrapper, check_fn=check_fn
)
n_linear_wrapped = sum(1 if isinstance(x, nn.Linear) else 0 for x in model.modules())
n_checkpointed = sum(1 if isinstance(x, CheckpointWrapper) else 0 for x in model.modules())
n_checkpointed = sum(1 if isinstance(x, (CheckpointWrapper, OffloadWrapper)) else 0 for x in model.modules())
self.assertEqual(n_checkpointed, n_linear_wrapped)
self.assertEqual(n_linear, n_linear_wrapped)
for j in range(3):
self.assertTrue(isinstance(model.seq[j].lin, CheckpointWrapper))
self.assertTrue(isinstance(model.seq[j].nested_linear[0], CheckpointWrapper))
self.assertTrue(isinstance(model.seq[j].lin, (CheckpointWrapper, OffloadWrapper)))
self.assertTrue(isinstance(model.seq[j].nested_linear[0], (CheckpointWrapper, OffloadWrapper)))

inp = torch.randn(4, 10, requires_grad=True)
for i in range(6):
@@ -276,7 +285,7 @@ def testing_cpu_offload_unpack_hook(packed):
orig_init = torch.autograd.graph.saved_tensors_hooks.__init__
torch.autograd.graph.saved_tensors_hooks.__init__ = patched_init

model = checkpoint_wrapper(model, offload_to_cpu=True)
model = offload_wrapper(model)

inp = torch.randn(3, 10, device='cuda')
loss = model(inp).sum()
@@ -286,7 +295,7 @@ def dfs(grad_fn):

def dfs(grad_fn):
for e in dir(grad_fn):
if not e.startswith('_saved_'):
if not e.startswith(_SAVED_PREFIX):
continue

saved = getattr(grad_fn, e)
Expand All @@ -295,7 +304,7 @@ def dfs(grad_fn):
nonlocal offload_verified
offload_verified = True

if hasattr(grad_fn, 'next_functions'):
if hasattr(grad_fn, GRAD_FN_NEXT_FUNCTIONS):
for next_grad_fn, _ in grad_fn.next_functions:
dfs(next_grad_fn)

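As a side note on what the patched hook in the test above intercepts: `torch.autograd.graph.saved_tensors_hooks` takes a pack/unpack pair, and a CPU-offloading pair looks roughly like the sketch below. This is an illustration rather than the code from this PR, and it assumes a CUDA device is available.

```
import torch

def pack_to_cpu(tensor):
    # remember the original device and keep the saved tensor on CPU
    return tensor.device, tensor.cpu()

def unpack_from_cpu(packed):
    # the test's patched unpack hook asserts the saved tensor really lives on CPU
    device, cpu_tensor = packed
    return cpu_tensor.to(device)

x = torch.randn(4, 4, device="cuda", requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_to_cpu, unpack_from_cpu):
    y = (x * x).sum()
y.backward()  # unpack_from_cpu runs here and restores tensors to the GPU
```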
31 changes: 23 additions & 8 deletions test/distributed/fsdp/test_fsdp_checkpoint.py
@@ -5,13 +5,15 @@
from functools import partial

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp.fully_sharded_data_parallel import (
FullyShardedDataParallel as FSDP,
CPUOffload,
)
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper,
offload_wrapper,
)
from torch.testing._internal.common_distributed import (
skip_if_lt_x_gpu,
@@ -65,9 +67,10 @@ def __init__(
l3 = nn.Linear(3, 3).cuda()

if checkpoint_layer:
ckpt_wrapper = partial(
checkpoint_wrapper, offload_to_cpu=offload_activations
)
if offload_activations:
ckpt_wrapper = offload_wrapper
else:
ckpt_wrapper = checkpoint_wrapper

l1 = ckpt_wrapper(l1)
l2 = ckpt_wrapper(l2)
@@ -110,11 +113,15 @@ def _verify_parity(self, losses, outputs, models):
@parametrize("offload_activations", [True, False])
def test_checkpoint_fsdp_wrapping(self, cpu_offload, offload_activations):
# Test checkpoint(FSDP(layer1), FSDP(layer2), ....)
ckpt_sequential_wrapped_fsdp = checkpoint_wrapper(
if offload_activations:
wrapper_to_use = offload_wrapper
else:
wrapper_to_use = checkpoint_wrapper

ckpt_sequential_wrapped_fsdp = wrapper_to_use(
TestFSDPCheckpoint.SequentialModule(
wrap_fsdp=True, cpu_offload=cpu_offload
),
offload_to_cpu=offload_activations,
)
# Test FSDP(checkpoint(layer1)), FSDP(checkpoint(layer2)), ....
inner_ckpt = TestFSDPCheckpoint.SequentialModule(
@@ -153,6 +160,8 @@ def test_checkpoint_fsdp_wrapping(self, cpu_offload, offload_activations):

self._verify_parity(losses, outputs, models)

dist.barrier()

@skip_if_lt_x_gpu(2)
@parametrize(
"cpu_offload",
@@ -166,13 +175,17 @@ def test_basic_checkpoint_end_to_end(self, cpu_offload, offload_activations):
# Runs FSDP with no checkpointing
fsdp_only_seq = FSDP(deepcopy(seq), cpu_offload=cpu_offload)
# Runs checkpoint-wrapped FSDP
checkpointed_fsdp = checkpoint_wrapper(
if offload_activations:
wrapper_to_use = offload_wrapper
else:
wrapper_to_use = checkpoint_wrapper

checkpointed_fsdp = wrapper_to_use(
FSDP(deepcopy(seq), cpu_offload=cpu_offload),
offload_to_cpu=offload_activations,
)
# Runs FSDP-wrapped checkpointed module
fsdp_wrapped_checkpoint = FSDP(
checkpoint_wrapper(deepcopy(seq), offload_to_cpu=offload_activations),
wrapper_to_use(deepcopy(seq)),
cpu_offload=cpu_offload,
)
# Runs FSDP with manual calls to checkpoint.
@@ -220,6 +233,8 @@ def test_basic_checkpoint_end_to_end(self, cpu_offload, offload_activations):

self._verify_parity(losses, outputs, models)

dist.barrier()

instantiate_parametrized_tests(TestFSDPCheckpoint)

if __name__ == "__main__":
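For reference, the two wrapping orders exercised above can be condensed to the following sketch (assumes `torch.distributed` is already initialized, a CUDA device is available, and layer sizes are arbitrary):

```
import torch.nn as nn
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    FullyShardedDataParallel as FSDP,
)
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    offload_wrapper,
)

# activation offload applied outside FSDP
offload_outside = offload_wrapper(FSDP(nn.Linear(3, 3).cuda()))

# FSDP applied outside the activation checkpointing wrapper
fsdp_outside = FSDP(checkpoint_wrapper(nn.Linear(3, 3).cuda()))
```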
