Make early stop the default for checkpoint and expose a way to disable (pytorch#96866)

Why did I choose a context manager instead of a per-call argument? Early stopping is not part of the model definition, and depending on how a particular model is used (e.g., with PT2 or not), we may or may not want to disable early stopping.
Pull Request resolved: pytorch#96866
Approved by: https://github.com/albanD
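
A minimal usage sketch of the new context manager (based on the docstring and tests added in this commit; `fn` and the scalar input below are illustrative, not part of the PR):

```python
import torch
from torch.utils.checkpoint import checkpoint, set_checkpoint_early_stop

def fn(x):
    return x.sin().cos()

a = torch.tensor(1., requires_grad=True)

# Early stopping is now on by default; disable it only for checkpoints whose
# forward pass runs under this context manager.
with set_checkpoint_early_stop(False):
    out = checkpoint(fn, a, use_reentrant=False)

# The setting captured at forward time is still respected here, even though
# the context manager has already exited.
out.backward()
```

Because each checkpoint captures the flag when its forward runs, the context manager does not need to stay active during backward.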
soulitzer authored and pytorchmergebot committed Mar 22, 2023
1 parent 546835c commit 7a8b691
Showing 2 changed files with 146 additions and 21 deletions.
118 changes: 108 additions & 10 deletions test/test_autograd.py
@@ -5606,18 +5606,18 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):

def context_fn():
return verbose_mode, contextlib.nullcontext()
out = checkpoint(lambda x: x.sin(), x, use_reentrant=False, context_fn=context_fn)
self.assertEqual(verbose_mode.operators, ['sin.default'])
out = checkpoint(lambda x: x.exp(), x, use_reentrant=False, context_fn=context_fn)
self.assertEqual(verbose_mode.operators, ['exp.default'])

verbose_mode.operators = []

def context_fn():
return contextlib.nullcontext(), verbose_mode
out = checkpoint(lambda x: x.sin(), x, use_reentrant=False, context_fn=context_fn)
out = checkpoint(lambda x: x.exp(), x, use_reentrant=False, context_fn=context_fn)
out.backward()
self.assertEqual(
verbose_mode.operators,
['detach.default', 'detach.default', 'detach.default', 'detach.default', 'sin.default']
['exp.default', 'detach.default', 'detach.default']
)

with self.assertRaisesRegex(Exception, "only supported when use_reentrant=False"):
@@ -10677,7 +10677,7 @@ def scope():

@parametrize("early_stop", [True, False])
def test_nested_checkpoint(self, early_stop):
with torch.utils.checkpoint._set_checkpoint_early_stop(early_stop):
with torch.utils.checkpoint.set_checkpoint_early_stop(early_stop):
x = torch.randn((), requires_grad=True)

def f(x):
@@ -10703,7 +10703,7 @@ def g(x):

@parametrize("early_stop", [True, False])
def test_nested_checkpoint_two_children(self, early_stop):
with torch.utils.checkpoint._set_checkpoint_early_stop(early_stop):
with torch.utils.checkpoint.set_checkpoint_early_stop(early_stop):
grad, sum, c = self.grad, self.sum, self.checkpoint

def f(x):
@@ -10742,7 +10742,7 @@ def fn(k, a, b, f):
def f(x):
return x.sin()

with torch.utils.checkpoint._set_checkpoint_early_stop(early_stop):
with torch.utils.checkpoint.set_checkpoint_early_stop(early_stop):
out, _unused1, _unused2 = checkpoint(fn, k, a, b, f, use_reentrant=False)
actual_grads = torch.autograd.grad(out, (a, b))

@@ -10762,7 +10762,7 @@ def fn(a, blah=None):
a = torch.tensor(2., requires_grad=True)
b = torch.tensor(3., requires_grad=True)

with torch.utils.checkpoint._set_checkpoint_early_stop(early_stop):
with torch.utils.checkpoint.set_checkpoint_early_stop(early_stop):
out = checkpoint(fn, a, blah=b, use_reentrant=False)
actual_grads = torch.autograd.grad(out, (a, b))

@@ -10783,7 +10783,7 @@ def fn(a):

a = torch.tensor(1., requires_grad=True)

with torch.utils.checkpoint._set_checkpoint_early_stop(early_stop):
with torch.utils.checkpoint.set_checkpoint_early_stop(early_stop):
out = checkpoint(fn, a, use_reentrant=False)
# The hook is registered on the original graph
out.grad_fn.next_functions[0][0].register_hook(hook)
@@ -10805,11 +10805,109 @@ def hook(*_unused_args):
x.backward(retain_graph=True)

a = torch.tensor(1., requires_grad=True)
with torch.utils.checkpoint._set_checkpoint_early_stop(early_stop):
with torch.utils.checkpoint.set_checkpoint_early_stop(early_stop):
x, out = checkpoint(fn, a, use_reentrant=False)
out.grad_fn.register_hook(hook)
out.backward(retain_graph=True)

def test_nested_checkpoint_set_early_stop(self):
counter = [0]

def clone(x):
counter[0] += 1
return x.clone()

def fn(x):
# Since clone does not save anything, it is not recomputed iff
# early stop is enabled.
return clone(x.sin().cos())

# Early stopping is enabled by default
a = torch.tensor(1., requires_grad=True)
out = checkpoint(fn, a, use_reentrant=False)
out.backward()
self.assertEqual(counter[0], 1)

# Try using the context manager to set early stopping to False.
# Expect early stopping to be disabled for all checkpoints run under
# the context manager, even though the context manager is no longer active
# when backward/recomputation is performed.
counter = [0]
a = torch.tensor(1., requires_grad=True)
with torch.utils.checkpoint.set_checkpoint_early_stop(False):
out = checkpoint(fn, a, use_reentrant=False)

out.backward()
self.assertEqual(counter[0], 2)

def test_nested_checkpoint_set_early_stop_no_recomputation_needed(self):
# Case 1: We have one tensor saved and it's the input

# We have two different counters here because in this case we actually
# do call into x.sin() at the python level during recomputation whether
# or not early stop is enabled. This is because the early stopping
# only happens at the autograd level (preventing us from reaching the
# backend).
python_dispatch_counter = [0]
counter = [0]

class SinCounterMode(TorchDispatchMode):
def __init__(self):
self.count = 0

def __torch_dispatch__(self, func, types, args=(), kwargs=None):
kwargs = {} if kwargs is None else kwargs
if func is torch.ops.aten.sin.default:
self.count += 1
return func(*args, **kwargs)

def fn(x):
counter[0] += 1
return x.sin()

# With early stopping (enabled by default)
a = torch.tensor(1., requires_grad=True)
with SinCounterMode() as python_dispatch_counter:
out = checkpoint(fn, a, use_reentrant=False)
out.backward()
self.assertEqual(counter[0], 2)
self.assertEqual(python_dispatch_counter.count, 1)

# Without early stopping
counter = [0]
a = torch.tensor(1., requires_grad=True)
with SinCounterMode() as python_dispatch_counter:
with torch.utils.checkpoint.set_checkpoint_early_stop(False):
out = checkpoint(fn, a, use_reentrant=False)
out.backward()
self.assertEqual(counter[0], 2)
self.assertEqual(python_dispatch_counter.count, 2)

# Case 2: Forward saves no tensors

# Since unpack isn't even called, counter is 1 whether or not early stop
# is enabled!
counter = [0]

def fn2(x):
counter[0] += 1
return x.clone()

# With early stopping (enabled by default)
a = torch.tensor(1., requires_grad=True)
out = checkpoint(fn2, a, use_reentrant=False)
out.backward()
self.assertEqual(counter[0], 1)

# Without early stopping
counter = [0]
a = torch.tensor(1., requires_grad=True)
with torch.utils.checkpoint.set_checkpoint_early_stop(False):
out = checkpoint(fn2, a, use_reentrant=False)
out.backward()
self.assertEqual(counter[0], 1)


class TestAutogradMultipleDispatch(TestCase):
def test_autograd_multiple_dispatch_registrations(self, device):
t = torch.randn(3, 3, device=device, requires_grad=True)
49 changes: 38 additions & 11 deletions torch/utils/checkpoint.py
Expand Up @@ -10,7 +10,7 @@
__all__ = [
"checkpoint", "checkpoint_sequential", "CheckpointFunction",
"check_backward_validity", "detach_variable", "get_device_states",
"set_device_states", "noop_context_fn"
"set_device_states", "noop_context_fn", "set_checkpoint_early_stop"
]

def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
@@ -208,7 +208,13 @@ def checkpoint(
the non-reentrant variant of checkpoint (``use_reentrant=False``)
differ in the following ways:
* Non-reentrant checkpoint stops recomputation as soon as all needed
intermediate activations have been recomputed. This feature is enabled
by default, but can be disabled with :func:`set_checkpoint_early_stop`.
Reentrant checkpoint always recomputes :attr:`function` in its
entirety during the backward pass.
* The reentrant variant does not record the autograd graph during the
forward pass, as it runs with the forward pass under
:func:`torch.no_grad`. The non-reentrant version does record the
autograd graph, allowing one to perform backward on the graph within
@@ -492,13 +498,31 @@ def forward(input):
# We save x and w, however.
# 7. Continue with returning

# NB: This is temporary and should be removed in a follow up PR. Early stopping
# is currently disabled by default. Since some nested test cases require
# early stopping to pass, _set_checkpoint_early_stop can be used to enable.
_enable_checkpoint_early_stop = False
_enable_checkpoint_early_stop = True

@contextlib.contextmanager
def _set_checkpoint_early_stop(enable):
def set_checkpoint_early_stop(enable: bool):
"""Context manager that sets whether checkpoint should stop recomputation
early.
By default, non-reentrant checkpoint stops recomputation as soon as it
has computed all needed Tensors. This context manager can be used to disable
that feature if it is problematic for your specific application.
This context manager only needs to be active when forward is run. It does
not need to be active during backward.
Example::
>>> # xdoctest: +SKIP(failing)
>>> with set_checkpoint_early_stop(False):
... # Any checkpoint under this context manager will respect this
... # context manager, even if its backward is performed outside.
... out = checkpoint(fn, inputs)
...
>>> out.backward()
"""
global _enable_checkpoint_early_stop
try:
prev = _enable_checkpoint_early_stop
@@ -546,7 +570,7 @@ def backward(ctx, *grad_outputs):
raise AssertionError("Did not expect to backward on this graph")

class _CheckpointFrame():
def __init__(self, recompute_fn):
def __init__(self, recompute_fn, early_stop):
self.recompute_fn = recompute_fn
self.input_saver = None
self.weak_holders: List[ReferenceType] = []
@@ -560,6 +584,9 @@ def __init__(self, recompute_fn):
self.recomp_counter: DefaultDict[int, int] = defaultdict(int)
self.is_recomputed: DefaultDict[int, bool] = defaultdict(bool)

# See Rule 5
self.early_stop = early_stop

# See Rule 5
class _StopRecomputationError(Exception):
pass
@@ -584,7 +611,7 @@ def pack_hook(x):
holder.handles[gid] = _Handle()
target_frame.recomputed[gid][holder.handles[gid]] = x.detach()

if _enable_checkpoint_early_stop and \
if target_frame.early_stop and \
target_frame.recomp_counter[gid] == len(target_frame.weak_holders):
raise _StopRecomputationError()
# See Rule 6: [ Basic case ] above
@@ -618,7 +645,7 @@ def unpack_hook(holder):
try:
with _recomputation_hook(weakref.ref(frame), gid), torch.autograd.enable_grad():
frame.recompute_fn(*args)
if _enable_checkpoint_early_stop:
if frame.early_stop:
raise AssertionError("if early stop is enabled, we don't expect to reach here")
except _StopRecomputationError:
pass
@@ -700,7 +727,7 @@ def recompute_fn(*inputs):
recompute_context:
fn(*args, **kwargs)

new_frame = _CheckpointFrame(recompute_fn)
new_frame = _CheckpointFrame(recompute_fn, _enable_checkpoint_early_stop)
dummy = torch.empty((0,), requires_grad=True)
new_frame.input_saver = _NoopSaveInputs.apply(dummy, kwargs, *args)

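
For reference, a condensed, standalone sketch of the behavior verified by the new `test_nested_checkpoint_set_early_stop_no_recomputation_needed` test above: with early stopping (the new default), the stop happens at the autograd level, so the backend `sin` kernel is not dispatched a second time during recomputation. `OpCounter` is a hypothetical stand-in for the test's `SinCounterMode`; the expected counts mirror the test's assertions.

```python
import torch
from torch.utils.checkpoint import checkpoint, set_checkpoint_early_stop
from torch.utils._python_dispatch import TorchDispatchMode

class OpCounter(TorchDispatchMode):
    """Counts backend dispatches of aten.sin (mirrors the test's SinCounterMode)."""
    def __init__(self):
        super().__init__()
        self.count = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs
        if func is torch.ops.aten.sin.default:
            self.count += 1
        return func(*args, **kwargs)

def fn(x):
    return x.sin()

# With early stopping (the new default): the only saved tensor is the input,
# so recomputation is cut off at the autograd level before sin reaches the
# backend a second time.
a = torch.tensor(1., requires_grad=True)
with OpCounter() as counter:
    out = checkpoint(fn, a, use_reentrant=False)
    out.backward()
print(counter.count)  # expected: 1

# With early stopping disabled at forward time: sin is dispatched to the
# backend again during recomputation, even though backward runs outside the
# set_checkpoint_early_stop context manager.
a = torch.tensor(1., requires_grad=True)
with OpCounter() as counter:
    with set_checkpoint_early_stop(False):
        out = checkpoint(fn, a, use_reentrant=False)
    out.backward()
print(counter.count)  # expected: 2
```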
