From 7a8b691388d31dc86ff188bf339181dd4b5dcc13 Mon Sep 17 00:00:00 2001
From: soulitzer
Date: Wed, 22 Mar 2023 12:49:26 -0400
Subject: [PATCH] Make early stop the default for checkpoint and expose a way
 to disable (#96866)

Why did I choose a context manager instead of a per-call argument? Early
stopping is not part of the model definition, and depending on how a
particular model is used, e.g., with PT2 or not, we may or may not want to
disable early stopping.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96866
Approved by: https://github.com/albanD
---
 test/test_autograd.py     | 118 ++++++++++++++++++++++++++++++++++----
 torch/utils/checkpoint.py |  49 ++++++++++++----
 2 files changed, 146 insertions(+), 21 deletions(-)

diff --git a/test/test_autograd.py b/test/test_autograd.py
index b1367ad4cbc44..538660e7e97e0 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -5606,18 +5606,18 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
 
         def context_fn():
             return verbose_mode, contextlib.nullcontext()
-        out = checkpoint(lambda x: x.sin(), x, use_reentrant=False, context_fn=context_fn)
-        self.assertEqual(verbose_mode.operators, ['sin.default'])
+        out = checkpoint(lambda x: x.exp(), x, use_reentrant=False, context_fn=context_fn)
+        self.assertEqual(verbose_mode.operators, ['exp.default'])
 
         verbose_mode.operators = []
 
         def context_fn():
             return contextlib.nullcontext(), verbose_mode
-        out = checkpoint(lambda x: x.sin(), x, use_reentrant=False, context_fn=context_fn)
+        out = checkpoint(lambda x: x.exp(), x, use_reentrant=False, context_fn=context_fn)
         out.backward()
         self.assertEqual(
             verbose_mode.operators,
-            ['detach.default', 'detach.default', 'detach.default', 'detach.default', 'sin.default']
+            ['exp.default', 'detach.default', 'detach.default']
         )
 
         with self.assertRaisesRegex(Exception, "only supported when use_reentrant=False"):
@@ -10677,7 +10677,7 @@ def scope():
 
     @parametrize("early_stop", [True, False])
     def test_nested_checkpoint(self, early_stop):
-        with torch.utils.checkpoint._set_checkpoint_early_stop(early_stop):
+        with torch.utils.checkpoint.set_checkpoint_early_stop(early_stop):
             x = torch.randn((), requires_grad=True)
 
             def f(x):
@@ -10703,7 +10703,7 @@ def g(x):
 
     @parametrize("early_stop", [True, False])
     def test_nested_checkpoint_two_children(self, early_stop):
-        with torch.utils.checkpoint._set_checkpoint_early_stop(early_stop):
+        with torch.utils.checkpoint.set_checkpoint_early_stop(early_stop):
             grad, sum, c = self.grad, self.sum, self.checkpoint
 
             def f(x):
@@ -10742,7 +10742,7 @@ def fn(k, a, b, f):
         def f(x):
             return x.sin()
 
-        with torch.utils.checkpoint._set_checkpoint_early_stop(early_stop):
+        with torch.utils.checkpoint.set_checkpoint_early_stop(early_stop):
             out, _unused1, _unused2 = checkpoint(fn, k, a, b, f, use_reentrant=False)
 
         actual_grads = torch.autograd.grad(out, (a, b))
@@ -10762,7 +10762,7 @@ def fn(a, blah=None):
         a = torch.tensor(2., requires_grad=True)
         b = torch.tensor(3., requires_grad=True)
 
-        with torch.utils.checkpoint._set_checkpoint_early_stop(early_stop):
+        with torch.utils.checkpoint.set_checkpoint_early_stop(early_stop):
             out = checkpoint(fn, a, blah=b, use_reentrant=False)
 
         actual_grads = torch.autograd.grad(out, (a, b))
@@ -10783,7 +10783,7 @@ def fn(a):
 
         a = torch.tensor(1., requires_grad=True)
 
-        with torch.utils.checkpoint._set_checkpoint_early_stop(early_stop):
+        with torch.utils.checkpoint.set_checkpoint_early_stop(early_stop):
             out = checkpoint(fn, a, use_reentrant=False)
         # The hook is registered on the original graph
         out.grad_fn.next_functions[0][0].register_hook(hook)
@@ -10805,11 +10805,109 @@ def hook(*_unused_args):
             x.backward(retain_graph=True)
 
         a = torch.tensor(1., requires_grad=True)
-        with torch.utils.checkpoint._set_checkpoint_early_stop(early_stop):
+        with torch.utils.checkpoint.set_checkpoint_early_stop(early_stop):
             x, out = checkpoint(fn, a, use_reentrant=False)
         out.grad_fn.register_hook(hook)
         out.backward(retain_graph=True)
 
+    def test_nested_checkpoint_set_early_stop(self):
+        counter = [0]
+
+        def clone(x):
+            counter[0] += 1
+            return x.clone()
+
+        def fn(x):
+            # Since clone does not save anything, it is not recomputed iff
+            # early stop is enabled.
+            return clone(x.sin().cos())
+
+        # Early stopping is enabled by default
+        a = torch.tensor(1., requires_grad=True)
+        out = checkpoint(fn, a, use_reentrant=False)
+        out.backward()
+        self.assertEqual(counter[0], 1)
+
+        # Try using the context manager to set early stopping to False.
+        # Expect early stopping to be disabled for all checkpoints run under
+        # the context manager, even though the context manager is no longer
+        # active when backward/recomputation is performed.
+        counter = [0]
+        a = torch.tensor(1., requires_grad=True)
+        with torch.utils.checkpoint.set_checkpoint_early_stop(False):
+            out = checkpoint(fn, a, use_reentrant=False)
+
+        out.backward()
+        self.assertEqual(counter[0], 2)
+
+    def test_nested_checkpoint_set_early_stop_no_recompution_needed(self):
+        # Case 1: We have one tensor saved and it's the input
+
+        # We have two different counters here because in this case we actually
+        # do call into x.sin() at the Python level during recomputation whether
+        # or not early stop is enabled. This is because the early stopping
+        # only happens at the autograd level (preventing us from reaching the
+        # backend).
+        python_dispatch_counter = [0]
+        counter = [0]
+
+        class SinCounterMode(TorchDispatchMode):
+            def __init__(self):
+                self.count = 0
+
+            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+                kwargs = {} if kwargs is None else kwargs
+                if func is torch.ops.aten.sin.default:
+                    self.count += 1
+                return func(*args, **kwargs)
+
+        def fn(x):
+            counter[0] += 1
+            return x.sin()
+
+        # With early stopping (enabled by default)
+        a = torch.tensor(1., requires_grad=True)
+        with SinCounterMode() as python_dispatch_counter:
+            out = checkpoint(fn, a, use_reentrant=False)
+            out.backward()
+        self.assertEqual(counter[0], 2)
+        self.assertEqual(python_dispatch_counter.count, 1)
+
+        # Without early stopping
+        counter = [0]
+        a = torch.tensor(1., requires_grad=True)
+        with SinCounterMode() as python_dispatch_counter:
+            with torch.utils.checkpoint.set_checkpoint_early_stop(False):
+                out = checkpoint(fn, a, use_reentrant=False)
+            out.backward()
+        self.assertEqual(counter[0], 2)
+        self.assertEqual(python_dispatch_counter.count, 2)
+
+        # Case 2: Forward saves no tensors
+
+        # Since unpack isn't even called, counter is 1 whether or not early stop
+        # is enabled!
+        counter = [0]
+
+        def fn2(x):
+            counter[0] += 1
+            return x.clone()
+
+        # With early stopping (enabled by default)
+        a = torch.tensor(1., requires_grad=True)
+        out = checkpoint(fn2, a, use_reentrant=False)
+        out.backward()
+        self.assertEqual(counter[0], 1)
+
+        # Without early stopping
+        counter = [0]
+        a = torch.tensor(1., requires_grad=True)
+        with torch.utils.checkpoint.set_checkpoint_early_stop(False):
+            out = checkpoint(fn2, a, use_reentrant=False)
+        out.backward()
+        self.assertEqual(counter[0], 1)
+
+
 class TestAutogradMultipleDispatch(TestCase):
     def test_autograd_multiple_dispatch_registrations(self, device):
         t = torch.randn(3, 3, device=device, requires_grad=True)
diff --git a/torch/utils/checkpoint.py b/torch/utils/checkpoint.py
index 96b671a10d290..7cf8c2e253a24 100644
--- a/torch/utils/checkpoint.py
+++ b/torch/utils/checkpoint.py
@@ -10,7 +10,7 @@
 __all__ = [
     "checkpoint", "checkpoint_sequential", "CheckpointFunction",
     "check_backward_validity", "detach_variable", "get_device_states",
-    "set_device_states", "noop_context_fn"
+    "set_device_states", "noop_context_fn", "set_checkpoint_early_stop"
 ]
 
 def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
@@ -208,7 +208,13 @@ def checkpoint(
     the non-reentrant variant of checkpoint (``use_reentrant=False``)
     differ in the following ways:
 
-    * The reentrant variant does not record the autograd graph during the
+    * Non-reentrant checkpoint stops recomputation as soon as all needed
+      intermediate activations have been recomputed. This feature is enabled
+      by default, but can be disabled with :func:`set_checkpoint_early_stop`.
+      Reentrant checkpoint always recomputes :attr:`function` in its
+      entirety during the backward pass.
+
+    * The reentrant variant does not record the autograd graph during the
       forward pass, as it runs with the forward pass under
       :func:`torch.no_grad`. The non-reentrant version does record the
       autograd graph, allowing one to perform backward on the graph within
@@ -492,13 +498,31 @@ def forward(input):
 #      We save x and w, however.
 #   7. Continue with returning
 
-# NB: This is temporary and should be removed in a follow up PR. Early stopping
-# is currently disabled by default. Since some nested test cases require
-# ealry stopping to pass, _set_checkpoint_early_stop can be used to enable.
-_enable_checkpoint_early_stop = False
+_enable_checkpoint_early_stop = True
 
 @contextlib.contextmanager
-def _set_checkpoint_early_stop(enable):
+def set_checkpoint_early_stop(enable: bool):
+    """Context manager that sets whether checkpoint should stop recomputation
+    early.
+
+    By default, non-reentrant checkpoint stops recomputation as soon as it
+    has computed all needed Tensors. This context manager can be used to disable
+    that feature if it is problematic for your specific application.
+
+    This context manager only needs to be active when forward is run. It does
+    not need to be active during backward.
+
+    Example::
+
+        >>> # xdoctest: +SKIP(failing)
+        >>> message = "saved tensors default hooks are disabled"
+        >>> with set_checkpoint_early_stop(False):
+        ...     # Any checkpoint under this context manager will respect this
+        ...     # context manager, even if its backward is performed outside.
+        ...     out = checkpoint(fn, inputs)
+        ...
+        >>> out.backward()
+    """
     global _enable_checkpoint_early_stop
     try:
         prev = _enable_checkpoint_early_stop
@@ -546,7 +570,7 @@ def backward(ctx, *grad_outputs):
         raise AssertionError("Did not expect to backward on this graph")
 
 class _CheckpointFrame():
-    def __init__(self, recompute_fn):
+    def __init__(self, recompute_fn, early_stop):
         self.recompute_fn = recompute_fn
         self.input_saver = None
         self.weak_holders: List[ReferenceType] = []
@@ -560,6 +584,9 @@ def __init__(self, recompute_fn):
         self.recomp_counter: DefaultDict[int, int] = defaultdict(int)
         self.is_recomputed: DefaultDict[int, bool] = defaultdict(bool)
 
+        # See Rule 5
+        self.early_stop = early_stop
+
 # See Rule 5
 class _StopRecomputationError(Exception):
     pass
@@ -584,7 +611,7 @@ def pack_hook(x):
                 holder.handles[gid] = _Handle()
             target_frame.recomputed[gid][holder.handles[gid]] = x.detach()
 
-            if _enable_checkpoint_early_stop and \
+            if target_frame.early_stop and \
                target_frame.recomp_counter[gid] == len(target_frame.weak_holders):
                 raise _StopRecomputationError()
             # See Rule 6: [ Basic case ] above
@@ -618,7 +645,7 @@ def unpack_hook(holder):
             try:
                 with _recomputation_hook(weakref.ref(frame), gid), torch.autograd.enable_grad():
                     frame.recompute_fn(*args)
-                    if _enable_checkpoint_early_stop:
+                    if frame.early_stop:
                         raise AssertionError("if early stop is enabled, we don't expect to reach here")
             except _StopRecomputationError:
                 pass
@@ -700,7 +727,7 @@ def recompute_fn(*inputs):
                  recompute_context:
                 fn(*args, **kwargs)
 
-    new_frame = _CheckpointFrame(recompute_fn)
+    new_frame = _CheckpointFrame(recompute_fn, _enable_checkpoint_early_stop)
     dummy = torch.empty((0,), requires_grad=True)
     new_frame.input_saver = _NoopSaveInputs.apply(dummy, kwargs, *args)
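
Below is a minimal usage sketch of the behavior introduced by this patch; it
is illustrative only and not part of the diff above. It shows how the new
default (early stopping on) and the set_checkpoint_early_stop context manager
are intended to interact. The function fn is a toy placeholder, not something
defined by this change.

    import torch
    from torch.utils.checkpoint import checkpoint, set_checkpoint_early_stop

    def fn(x):
        # With early stopping (the new default), recomputation halts as soon
        # as the tensors needed for this backward have been re-saved.
        return x.sin().cos()

    # Early stopping is on by default for use_reentrant=False checkpoints.
    a = torch.tensor(1., requires_grad=True)
    out = checkpoint(fn, a, use_reentrant=False)
    out.backward()

    # Disable it only for checkpoints whose forward runs under the context
    # manager; the corresponding backward may run outside of it.
    a = torch.tensor(1., requires_grad=True)
    with set_checkpoint_early_stop(False):
        out = checkpoint(fn, a, use_reentrant=False)
    out.backward()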