From bdefa260b2831977b4a458d9daef2b710330c78c Mon Sep 17 00:00:00 2001
From: Rohan Varma
Date: Fri, 14 Oct 2022 20:45:25 +0000
Subject: [PATCH] [RFC] Separate CPU offload activation to its own wrapper
 (#85459)

Passing `offload_to_cpu=True` to `checkpoint_wrapper` is confusing, because it causes the activation checkpointing arguments to be ignored and CPU offloading to be done instead. This isn't ideal from an API design perspective, so this PR proposes making `offload_wrapper` its own concept. CPU offload and activation checkpointing can then be composed, for example:
```
# apply activation checkpointing to transformer layers
apply_activation_checkpointing(model, checkpoint_wrapper_fn=checkpoint_wrapper, check_fn=lambda mod: isinstance(mod, TransformerLayer))

# offload the remaining activations to CPU
model = offload_wrapper(model)
```
Will polish / add tests if this proposal sounds good.

Differential Revision: [D39719854](https://our.internmc.facebook.com/intern/diff/D39719854/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85459
Approved by: https://github.com/awgu
---
 .../fsdp/test_checkpoint_wrapper.py           |  25 ++-
 test/distributed/fsdp/test_fsdp_checkpoint.py |  31 ++-
 .../_checkpoint/checkpoint_wrapper.py         | 193 ++++++++++--------
 3 files changed, 151 insertions(+), 98 deletions(-)

diff --git a/test/distributed/fsdp/test_checkpoint_wrapper.py b/test/distributed/fsdp/test_checkpoint_wrapper.py
index 329f586369157d..8bd2b74695d3bf 100644
--- a/test/distributed/fsdp/test_checkpoint_wrapper.py
+++ b/test/distributed/fsdp/test_checkpoint_wrapper.py
@@ -7,8 +7,10 @@
 import torch.nn as nn
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     checkpoint_wrapper,
+    offload_wrapper,
     apply_activation_checkpointing,
     CheckpointWrapper,
+    OffloadWrapper,
     CheckpointImpl
 )
 
@@ -21,6 +23,9 @@
 
 import unittest
 
+_SAVED_PREFIX = '_saved_'
+GRAD_FN_NEXT_FUNCTIONS = 'next_functions'
+
 class CheckpointWrapperTest(TestCase):
     def setUp(self):
         super().setUp()
@@ -72,11 +77,14 @@ def forward(self, a, b, c=None, d=None, **kwargs):
         for wrapper in [
             partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.REENTRANT),
             partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT),
-            partial(checkpoint_wrapper, offload_to_cpu=True),
+            offload_wrapper,
         ]:
             with self.subTest(wrapper=wrapper):
                 model = wrapper(MyModel())
-                self.assertTrue(isinstance(model, CheckpointWrapper))
+                if wrapper == offload_wrapper:
+                    self.assertTrue(isinstance(model, OffloadWrapper))
+                else:
+                    self.assertTrue(isinstance(model, CheckpointWrapper))
                 # Verify kwargs can be passed in
                 inp = torch.ones(4, 10, requires_grad=True)
                 out = model(inp, inp, c=inp, d=inp, e=inp, f=inp)
@@ -211,6 +219,7 @@ def check_fn(l):
         for wrapper in [
             partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.REENTRANT),
             partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT),
+            offload_wrapper,
         ]:
             model = MyModel()
             if n_linear is None:
@@ -223,12 +232,12 @@ def check_fn(l):
                 model, checkpoint_wrapper_fn=wrapper, check_fn=check_fn
             )
             n_linear_wrapped = sum(1 if isinstance(x, nn.Linear) else 0 for x in model.modules())
-            n_checkpointed = sum(1 if isinstance(x, CheckpointWrapper) else 0 for x in model.modules())
+            n_checkpointed = sum(1 if isinstance(x, (CheckpointWrapper, OffloadWrapper)) else 0 for x in model.modules())
             self.assertEqual(n_checkpointed, n_linear_wrapped)
             self.assertEqual(n_linear, n_linear_wrapped)
             for j in range(3):
-                self.assertTrue(isinstance(model.seq[j].lin, CheckpointWrapper))
-                self.assertTrue(isinstance(model.seq[j].nested_linear[0], CheckpointWrapper))
+
self.assertTrue(isinstance(model.seq[j].lin, (CheckpointWrapper, OffloadWrapper))) + self.assertTrue(isinstance(model.seq[j].nested_linear[0], (CheckpointWrapper, OffloadWrapper))) inp = torch.randn(4, 10, requires_grad=True) for i in range(6): @@ -276,7 +285,7 @@ def testing_cpu_offload_unpack_hook(packed): orig_init = torch.autograd.graph.saved_tensors_hooks.__init__ torch.autograd.graph.saved_tensors_hooks.__init__ = patched_init - model = checkpoint_wrapper(model, offload_to_cpu=True) + model = offload_wrapper(model) inp = torch.randn(3, 10, device='cuda') loss = model(inp).sum() @@ -286,7 +295,7 @@ def testing_cpu_offload_unpack_hook(packed): def dfs(grad_fn): for e in dir(grad_fn): - if not e.startswith('_saved_'): + if not e.startswith(_SAVED_PREFIX): continue saved = getattr(grad_fn, e) @@ -295,7 +304,7 @@ def dfs(grad_fn): nonlocal offload_verified offload_verified = True - if hasattr(grad_fn, 'next_functions'): + if hasattr(grad_fn, GRAD_FN_NEXT_FUNCTIONS): for next_grad_fn, _ in grad_fn.next_functions: dfs(next_grad_fn) diff --git a/test/distributed/fsdp/test_fsdp_checkpoint.py b/test/distributed/fsdp/test_fsdp_checkpoint.py index ea7ecc5b308993..14456df92f84f8 100644 --- a/test/distributed/fsdp/test_fsdp_checkpoint.py +++ b/test/distributed/fsdp/test_fsdp_checkpoint.py @@ -5,6 +5,7 @@ from functools import partial import torch +import torch.distributed as dist import torch.nn as nn from torch.distributed.fsdp.fully_sharded_data_parallel import ( FullyShardedDataParallel as FSDP, @@ -12,6 +13,7 @@ ) from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( checkpoint_wrapper, + offload_wrapper, ) from torch.testing._internal.common_distributed import ( skip_if_lt_x_gpu, @@ -65,9 +67,10 @@ def __init__( l3 = nn.Linear(3, 3).cuda() if checkpoint_layer: - ckpt_wrapper = partial( - checkpoint_wrapper, offload_to_cpu=offload_activations - ) + if offload_activations: + ckpt_wrapper = offload_wrapper + else: + ckpt_wrapper = checkpoint_wrapper l1 = ckpt_wrapper(l1) l2 = ckpt_wrapper(l2) @@ -110,11 +113,15 @@ def _verify_parity(self, losses, outputs, models): @parametrize("offload_activations", [True, False]) def test_checkpoint_fsdp_wrapping(self, cpu_offload, offload_activations): # Test checkpoint(FSDP(layer1), FSDP(layer2), ....) - ckpt_sequential_wrapped_fsdp = checkpoint_wrapper( + if offload_activations: + wrapper_to_use = offload_wrapper + else: + wrapper_to_use = checkpoint_wrapper + + ckpt_sequential_wrapped_fsdp = wrapper_to_use( TestFSDPCheckpoint.SequentialModule( wrap_fsdp=True, cpu_offload=cpu_offload ), - offload_to_cpu=offload_activations, ) # Test FSDP(checkpoint(layer1)), FSDP(checkpoint(layer2)), .... 
inner_ckpt = TestFSDPCheckpoint.SequentialModule( @@ -153,6 +160,8 @@ def test_checkpoint_fsdp_wrapping(self, cpu_offload, offload_activations): self._verify_parity(losses, outputs, models) + dist.barrier() + @skip_if_lt_x_gpu(2) @parametrize( "cpu_offload", @@ -166,13 +175,17 @@ def test_basic_checkpoint_end_to_end(self, cpu_offload, offload_activations): # Runs FSDP with no checkpointing fsdp_only_seq = FSDP(deepcopy(seq), cpu_offload=cpu_offload) # Runs checkpoint-wrapped FSDP - checkpointed_fsdp = checkpoint_wrapper( + if offload_activations: + wrapper_to_use = offload_wrapper + else: + wrapper_to_use = checkpoint_wrapper + + checkpointed_fsdp = wrapper_to_use( FSDP(deepcopy(seq), cpu_offload=cpu_offload), - offload_to_cpu=offload_activations, ) # Runs FSDP-wrapped checkpointed module fsdp_wrapped_checkpoint = FSDP( - checkpoint_wrapper(deepcopy(seq), offload_to_cpu=offload_activations), + wrapper_to_use(deepcopy(seq)), cpu_offload=cpu_offload, ) # Runs FSDP with manual calls to checkpoint. @@ -220,6 +233,8 @@ def test_basic_checkpoint_end_to_end(self, cpu_offload, offload_activations): self._verify_parity(losses, outputs, models) + dist.barrier() + instantiate_parametrized_tests(TestFSDPCheckpoint) if __name__ == "__main__": diff --git a/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py b/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py index ad1bb3f7b6241a..30c8cb4e6beb81 100644 --- a/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py +++ b/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py @@ -6,7 +6,7 @@ import torch.nn as nn from torch.autograd.graph import save_on_cpu from torch.distributed.utils import _pack_kwargs, _replace_by_prefix, _unpack_kwargs -from torch.utils.checkpoint import checkpoint +from torch.utils.checkpoint import checkpoint as torch_utils_checkpoint _CHECKPOINT_PREFIX = "_checkpoint_wrapped_module" @@ -15,42 +15,14 @@ class CheckpointImpl(Enum): NO_REENTRANT = auto() -class CheckpointWrapper(torch.nn.Module): +class ActivationWrapper(torch.nn.Module): """ - An nn.Module that wraps another nn.Module with checkpointing. Note that this - module is not meant to be used directly, but instead it is to be used - through the ``checkpoint_wrapper`` function. + Base class for Activation Checkpoint and Activation Offload. + Not meant to be instantiated directly. """ - def __init__( - self, - mod: torch.nn.Module, - checkpoint_impl: CheckpointImpl = CheckpointImpl.REENTRANT, - offload_to_cpu: bool = False, - checkpoint_fn=None, - *checkpoint_fn_args, - **checkpoint_fn_kwargs, - ): + def __init__(self, mod): super().__init__() self._checkpoint_wrapped_module = mod - self.checkpoint_impl = checkpoint_impl - self.offload_to_cpu = offload_to_cpu - if self.offload_to_cpu: - self.checkpoint_fn = None - else: - if checkpoint_fn is None: - # use torch.utils.checkpoint - self.checkpoint_fn = partial( - checkpoint, - use_reentrant=( - self.checkpoint_impl == CheckpointImpl.REENTRANT - ), - ) - else: - self.checkpoint_fn = partial( - checkpoint_fn, - *checkpoint_fn_args, - **checkpoint_fn_kwargs, - ) # state_dict post hook to remove prefix to allow loading into a # non-checkpoint wrapped module. 
self._register_state_dict_hook(self._post_state_dict_hook) @@ -60,6 +32,9 @@ def __init__( self._pre_load_state_dict_hook, with_module=True ) + def forward(self, *args, **kwargs): + raise ValueError("Subclasses should implement forward().") + def __getattr__(self, name: str) -> Any: """Forward missing attributes to wrapped module.""" try: @@ -71,44 +46,6 @@ def __getitem__(self, key: int) -> Any: """Forward indexing calls in case the module is a nn.Sequential.""" return self._checkpoint_wrapped_module.__getitem__(key) # type: ignore[operator] - def forward(self, *args, **kwargs): - if self.offload_to_cpu: - with save_on_cpu(pin_memory=True): - return self._checkpoint_wrapped_module(*args, **kwargs) - else: - # Support keyword arguments for reentrant checkpoint. Note that this - # only works if user has specified self.checkpoint_impl and is not - # using their own custom checkpoint_fn. - if self.checkpoint_impl == CheckpointImpl.REENTRANT and kwargs != {}: - # Pack the args and kwargs - flat_args, kwarg_keys = _pack_kwargs(*args, **kwargs) - - # Function that only takes (packed) args, but can unpack them - # into the original args and kwargs for the checkpointed - # function, and runs that function. - def my_function(*inputs): - # unpack back into args and kwargs - unpacked_args, unpacked_kwargs = _unpack_kwargs( - inputs, kwarg_keys - ) - # run original module - return self._checkpoint_wrapped_module( - *unpacked_args, **unpacked_kwargs - ) - - # Pass the function that only takes packed args into reentrant - # checkpoint API. - return self.checkpoint_fn( # type: ignore[misc] - my_function, - *flat_args, - ) - else: - return self.checkpoint_fn( # type: ignore[misc] - self._checkpoint_wrapped_module, - *args, - **kwargs - ) - def named_parameters( self, *args, @@ -155,10 +92,107 @@ def _pre_load_state_dict_hook( _replace_by_prefix(state_dict, prefix, prefix + f"{_CHECKPOINT_PREFIX}.") +class OffloadWrapper(ActivationWrapper): + def __init__(self, mod): + super().__init__(mod) + + def forward(self, *args, **kwargs): + with save_on_cpu(pin_memory=True): + return self._checkpoint_wrapped_module(*args, **kwargs) + + +class CheckpointWrapper(ActivationWrapper): + """ + An ``nn.Module`` that wraps another ``nn.Module`` with checkpointing. Note that this + module is not meant to be used directly, but instead it is to be used + through the ``checkpoint_wrapper`` function. + """ + def __init__( + self, + mod: torch.nn.Module, + checkpoint_impl: CheckpointImpl = CheckpointImpl.REENTRANT, + checkpoint_fn=None, + *checkpoint_fn_args, + **checkpoint_fn_kwargs, + ): + super().__init__(mod) + self.checkpoint_impl = checkpoint_impl + if checkpoint_fn is None: + # use torch.utils.checkpoint + self.checkpoint_fn = partial( + torch_utils_checkpoint, + use_reentrant=( + self.checkpoint_impl == CheckpointImpl.REENTRANT + ), + ) + else: + # Construct user-specified checkpoint function. + self.checkpoint_fn = partial( + checkpoint_fn, + *checkpoint_fn_args, + **checkpoint_fn_kwargs, + ) + + def forward(self, *args, **kwargs): + # Support keyword arguments for reentrant checkpoint. Note that this + # only works if user has specified self.checkpoint_impl and is not + # using their own custom checkpoint_fn. 
+        if self.checkpoint_impl == CheckpointImpl.REENTRANT and kwargs != {}:
+            # Pack the args and kwargs
+            flat_args, kwarg_keys = _pack_kwargs(*args, **kwargs)
+
+            # Function that only takes (packed) args, but can unpack them
+            # into the original args and kwargs for the checkpointed
+            # function, and runs that function.
+            def my_function(*inputs):
+                # unpack back into args and kwargs
+                unpacked_args, unpacked_kwargs = _unpack_kwargs(
+                    inputs, kwarg_keys
+                )
+                # run original module
+                return self._checkpoint_wrapped_module(
+                    *unpacked_args, **unpacked_kwargs
+                )
+
+            # Pass the function that only takes packed args into reentrant
+            # checkpoint API.
+            return self.checkpoint_fn(  # type: ignore[misc]
+                my_function,
+                *flat_args,
+            )
+        else:
+            return self.checkpoint_fn(  # type: ignore[misc]
+                self._checkpoint_wrapped_module,
+                *args,
+                **kwargs
+            )
+
+def offload_wrapper(
+    module: torch.nn.Module
+) -> torch.nn.Module:
+    """
+    A convenience wrapper for activation offloading to CPU. If the module is wrapped
+    with this function, all subsequent calls to the module will automatically
+    offload intermediate activations to the CPU. Wrappers with activation
+    offload can be composed with ones that do recomputation-based
+    checkpointing to trade off increased compute versus increased CPU
+    memory usage and additional H2D transfers.
+    Usage::
+        offloaded_module = offload_wrapper(module)
+        outputs = offloaded_module(inputs)
+    Args:
+        module (nn.Module):
+            The module to be wrapped
+    Returns:
+        (nn.Module):
+            Wrapped module
+    """
+    return OffloadWrapper(module)
+
+
 def checkpoint_wrapper(
     module: torch.nn.Module,
     checkpoint_impl: CheckpointImpl = CheckpointImpl.REENTRANT,
-    offload_to_cpu: bool = False,
     checkpoint_fn=None,
     *checkpoint_fn_args,
     **checkpoint_fn_kwargs,
@@ -181,14 +215,6 @@ def checkpoint_wrapper(
             specified. Note that for implementations using reentrant checkpoint
             from ``torch.utils.checkpoint``, keyword arguments will only be
             supported if ``checkpoint_impl`` is passed as ``CheckpointImpl.REENTRANT`.
-        offload_to_cpu (Optional[bool]):
-            Whether to offload activations of this wrapped module to CPU. Note
-            that if this is specified, ``checkpoint_impl`` and ``checkpoint_fn``
-            arguments will be ignored in favor of the activations being
-            offloaded to CPU. Default is ``False``. Wrappers with activation
-            offload can be composed with ones that do recomputation-based
-            checkpoint to trade off increased compute versus increased CPU
-            memory usage and additional H2D transfers.
         checkpoint_fn (Optional[Callable]):
             Functional checkpoint implementation to use. If this is specified,
             it will be used over the default ``torch.utils.checkpoint.checkpoint``
@@ -202,7 +228,7 @@
     """
 
     return CheckpointWrapper(
-        module, checkpoint_impl, offload_to_cpu, checkpoint_fn, checkpoint_fn_args, checkpoint_fn_kwargs
+        module, checkpoint_impl, checkpoint_fn, checkpoint_fn_args, checkpoint_fn_kwargs
     )
 
 
@@ -219,13 +245,16 @@ def apply_activation_checkpointing(
     their checkpoint-wrapped modules.
     Note::
        This function will not wrap the overall root module. If this is needed, please directly use
-        :class:`CheckpointWrapper`.
+        :func:`checkpoint_wrapper` or :func:`offload_wrapper`.
     Usage::
         model = nn.Sequential(
             nn.Linear(10, 10), nn.Linear(10, 10), nn.Linear(10, 10)
         )
         check_fn = lambda l: isinstance(l, nn.Linear)
+        # Checkpoint activations
         apply_activation_checkpointing(model, checkpoint_wrapper_fn=checkpoint_wrapper, check_fn=check_fn)
+        # Or offload activations to CPU
+        apply_activation_checkpointing(model, checkpoint_wrapper_fn=offload_wrapper, check_fn=check_fn)
     Args:
         model (nn.Module):
             The model whose submodules should be wrapped with activation checkpointing.
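
For reference, below is a minimal, self-contained sketch of how the two wrappers introduced by this patch are meant to compose: recompute activations of selected submodules via `apply_activation_checkpointing` with `checkpoint_wrapper`, then offload the remaining activations to CPU with `offload_wrapper`. The toy `nn.Sequential` model, tensor shapes, and the `check_fn` predicate are illustrative placeholders, not part of the patch; the wrapper calls follow the APIs shown in the diff above.

```
import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
    offload_wrapper,
)

# Toy stand-in for a real model (illustrative only).
model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10), nn.Linear(10, 10))

# Recompute the activations of every nn.Linear during backward
# instead of storing them.
apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=checkpoint_wrapper,
    check_fn=lambda m: isinstance(m, nn.Linear),
)

# Offload the remaining (outer) activations to CPU.
model = offload_wrapper(model)

inp = torch.randn(4, 10, requires_grad=True)
model(inp).sum().backward()
```

Note that `offload_wrapper` wraps the root module here, which is exactly the case `apply_activation_checkpointing` does not cover (it never wraps the root), so the two calls are complementary rather than redundant.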