Don't increment generation if forward or backward exists, and warning on deallocation of live tensors (pytorch#97168)

Refines the logic for when it is okay to ignore previously live outputs from cudagraphs. If a forward has been invoked without invocation of its corresponding backward, don't allow its outputs to be overwritten.
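
A minimal sketch of the bookkeeping this adds, using a hypothetical standalone helper (PendingBackwardTracker and its methods are illustrative names, not part of the change); the real logic lives in CUDAGraphTreeManager.run and can_start_new_generation in the diff below:

    # Illustrative sketch only -- not the CUDAGraphTreeManager implementation.
    # A forward recording's outputs may still be needed by its backward, so a new
    # generation (which would allow overwriting those outputs) can only begin once
    # every pending backward has run.
    class PendingBackwardTracker:
        def __init__(self):
            self.forwards_with_pending_backwards = 0
            self.current_gen = -1

        def on_run(self, mode: str) -> None:
            # mirrors the CompilationMode.FORWARD / BACKWARD accounting in the diff
            if mode == "forward":
                self.forwards_with_pending_backwards += 1
            elif mode == "backward":
                self.forwards_with_pending_backwards -= 1
            # inference runs create no pending backwards

        def can_start_new_generation(self, curr_generation: int) -> bool:
            # refuse to overwrite outputs while any backward is still pending
            if self.forwards_with_pending_backwards != 0:
                return False
            return self.current_gen != curr_generation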

Differential Revision: [D44228369](https://our.internmc.facebook.com/intern/diff/D44228369)
Pull Request resolved: pytorch#97168
Approved by: https://github.com/ezyang, https://github.com/jansel
eellison authored and pytorchmergebot committed Mar 22, 2023
1 parent 9370f25 commit 9c144bc
Showing 5 changed files with 151 additions and 15 deletions.
40 changes: 38 additions & 2 deletions test/inductor/test_cudagraph_trees.py
@@ -125,7 +125,13 @@ def get_root_children(self):
return [root.num_descendants() for root in self.get_roots()]

def cudagraphify_impl(self, *args, **kwargs):
return tree_cudagraphify_impl(*args, **kwargs, device_index=self.device_idx)
return tree_cudagraphify_impl(
*args,
**kwargs,
device_index=self.device_idx,
is_backward=False,
is_inference=True,
)

@staticmethod
def run_twc(fn, *args, **kwargs):
@@ -517,7 +523,9 @@ def foo(args):

inp = torch.rand([20, 20], device="cuda:1")

foo_cg = tree_cudagraphify_impl(foo, [inp], (), device_index=1)
foo_cg = tree_cudagraphify_impl(
foo, [inp], (), device_index=1, is_backward=False, is_inference=True
)
self.assertEqual(foo_cg([inp]), foo([inp]))

self.assertTrue(self.get_manager(device_index=0) is None)
@@ -526,6 +534,34 @@ def foo(args):
test()
self.assertTrue(self.get_manager(device_index=1) is None)

def test_forward_generation(self):
def foo(x):
return x * x * x

def foo2(x):
return x * 12

foo_opt = torch._dynamo.optimize()(foo)
foo2_opt = torch._dynamo.optimize()(foo2)
ones = torch.ones([4, 4], device="cuda", requires_grad=True)

out = foo_opt(ones)
out2 = foo2_opt(out)

self.assertEqual(all_live_block_count(), 2)

self.assertEqual(self.get_manager().forwards_with_pending_backwards, 2)

out2.sum().backward()

self.assertEqual(self.get_manager().forwards_with_pending_backwards, 0)

del out
del out2

out = foo_opt(ones.detach())
self.assertEqual(self.get_manager().forwards_with_pending_backwards, 0)


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
3 changes: 3 additions & 0 deletions torch/_dynamo/backends/common.py
@@ -41,6 +41,9 @@ def _wrapped_bw_compiler(*args, **kwargs):

bw_compiler = kwargs.get("bw_compiler") or kwargs["fw_compiler"]
kwargs["bw_compiler"] = _wrapped_bw_compiler
kwargs["inference_compiler"] = (
kwargs.get("inference_compiler") or kwargs["fw_compiler"]
)

from functorch.compile import nop

20 changes: 17 additions & 3 deletions torch/_functorch/aot_autograd.py
@@ -1238,6 +1238,7 @@ class AOTConfig:
keep_inference_input_mutations: bool
dynamic_shapes: bool = False
aot_autograd_arg_pos_to_source : Optional[List[Source]] = None
inference_compiler: Optional[Callable] = None

def aot_dispatch_base(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig, *, fw_metadata: ViewAndMutationMeta):
# aot_dispatch_base requires functionalization, but doesn't need to handle as many cases as the autograd case.
@@ -1274,7 +1275,8 @@ def aot_dispatch_base(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig, *
context = disable_autocast_manager if disable_amp else nullcontext

with context(), track_graph_compiling(aot_config, "inference"):
compiled_fw = aot_config.fw_compiler(fw_module, flat_args)
compiler = aot_config.inference_compiler if aot_config.inference_compiler is not None else aot_config.fw_compiler
compiled_fw = compiler(fw_module, flat_args)

compiled_fn = create_runtime_wrapper(
compiled_fw,
@@ -2509,7 +2511,8 @@ def call_compiled_backward():
aot_config.bw_compiler, None, None,
aot_config.decompositions, 0, aot_config.aot_id,
aot_config.keep_inference_input_mutations,
aot_config.dynamic_shapes
aot_config.dynamic_shapes,
inference_compiler=None,
)
)
else:
@@ -2747,6 +2750,7 @@ def aot_function(
hasher_type=None, # deprecated
static_argnums: Optional[Tuple[int]] = None, # deprecated
keep_inference_input_mutations: bool = False,
inference_compiler: Optional[Callable] = None,
*,
# Whether or not to trace with dynamic shapes
dynamic=False,
@@ -2784,7 +2788,10 @@
backward graphs.
decompositions (Dict): A dictionary to define the decomposition of
larger Aten ops into simpler or core Aten ops.
inference_compiler (Optional[Callable]): A Python function that accepts an
Fx graph with Aten ops and input args, and returns a Callable that
semantically is equivalent to the input Fx graph. Default: None
(when None, it defaults to the :attr:`fw_compiler`)
Returns:
Returns a ``Callable`` that retains the eager behavior of the original
:attr:`fn`, but with forward and backward graph compiled via
@@ -2808,9 +2815,12 @@

if bw_compiler is None:
bw_compiler = fw_compiler
if inference_compiler is None:
inference_compiler = fw_compiler
aot_config = AOTConfig(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
inference_compiler=inference_compiler,
partition_fn=partition_fn,
decompositions=decompositions,
num_params_buffers=num_params_buffers,
@@ -2934,6 +2944,7 @@ def aot_module_simplified(
hasher_type=None,
static_argnums=None,
keep_inference_input_mutations=False,
inference_compiler: Optional[Callable] = None,
) -> nn.Module:
"""
This is the simplified or low overhead version of aot_module. For frontends
@@ -2999,6 +3010,8 @@ def functional_call(*args, **kwargs):
assert static_argnums is None
if bw_compiler is None:
bw_compiler = fw_compiler
if inference_compiler is None:
inference_compiler = fw_compiler

full_args = []
# First, the params
@@ -3040,6 +3053,7 @@ def functional_call(*args, **kwargs):
aot_config = AOTConfig(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
inference_compiler=inference_compiler,
partition_fn=partition_fn,
decompositions=decompositions,
num_params_buffers=params_len,
28 changes: 25 additions & 3 deletions torch/_inductor/compile_fx.py
@@ -139,6 +139,7 @@ def compile_fx_inner(
is_backward=False,
graph_id=None,
aot_mode=False,
is_inference=False,
):
if is_tf32_warning_applicable(gm):
_warn_tf32_disabled()
@@ -162,6 +163,7 @@
cudagraphs = config.triton.cudagraphs

shape_env = _shape_env_from_inputs(example_inputs)

fake_mode = fake_mode_from_tensors(
example_inputs
) or torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
@@ -183,6 +185,8 @@
if aot_mode:
return compiled_fn

output = list(gm.graph.nodes)[-1]
assert len(output.args) == 1
if cudagraphs:
complex_memory_overlap_inputs = any(
complex_memory_overlap(t) for t in example_inputs
@@ -200,6 +204,8 @@
example_inputs,
static_input_idxs=range(num_fixed),
device_index=next(iter(graph.device_idxs)),
is_backward=is_backward,
is_inference=is_inference,
)
else:
BoxedBool.disable(cudagraphs)
@@ -267,14 +273,25 @@ def run(new_inputs):


@dynamo_utils.dynamo_timed
def cudagraphify(model, inputs, static_input_idxs=(), *, device_index: int):
def cudagraphify(
model,
inputs,
static_input_idxs=(),
*,
device_index: int,
is_backward: bool,
is_inference: bool,
):
from torch._inductor.cudagraph_trees import (
cudagraphify_impl as new_cudagraphify_impl,
)

if config.triton.cudagraph_trees:
cudagraphify_fn = functools.partial(
new_cudagraphify_impl, device_index=device_index
new_cudagraphify_impl,
device_index=device_index,
is_backward=is_backward,
is_inference=is_inference,
)
else:
cudagraphify_fn = cudagraphify_impl
@@ -493,7 +510,7 @@ def compile_fx(
graph_id = next(_graph_counter)

@dynamo_utils.dynamo_timed
def fw_compiler(model: torch.fx.GraphModule, example_inputs):
def fw_compiler_base(model: torch.fx.GraphModule, example_inputs, is_inference):
fixed = len(example_inputs) - num_example_inputs
# Why convert outplace op to inplace? Inductor can support inplace operations well and for custom
# inplace ops which are lowered as ExternKernel, it is beneficial to performance when the inplace
@@ -505,8 +522,12 @@ def fw_compiler(model: torch.fx.GraphModule, example_inputs):
num_fixed=fixed,
cudagraphs=cudagraphs,
graph_id=graph_id,
is_inference=is_inference,
)

fw_compiler = functools.partial(fw_compiler_base, is_inference=False)
inference_compiler = functools.partial(fw_compiler_base, is_inference=True)

# Save and restore dynamic shapes setting for backwards, as it is
# sometimes done as a context manager which won't be set when we
# hit backwards compile
@@ -534,6 +555,7 @@ def bw_compiler(model: torch.fx.GraphModule, example_inputs):
return aot_autograd(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
inference_compiler=inference_compiler,
decompositions=decompositions,
partition_fn=functools.partial(
min_cut_rematerialization_partition, compiler="inductor"
75 changes: 68 additions & 7 deletions torch/_inductor/cudagraph_trees.py
@@ -250,9 +250,23 @@ def get_container(device_index: int):
return container_dict[device_index]


def cudagraphify_impl(model, inputs, static_input_idxs=(), *, device_index: int):
def cudagraphify_impl(
model,
inputs,
static_input_idxs=(),
*,
device_index: int,
is_backward: bool,
is_inference: bool,
):
manager = get_container(device_index).get_tree_manager()
return manager.add_function(model, inputs, static_input_idxs)
return manager.add_function(
model,
inputs,
static_input_idxs,
is_backward,
is_inference,
)


def is_live(weak_ref):
@@ -848,12 +862,20 @@ def data_ptrs_dead_since_invocation(self) -> List[int]:
return ptrs_to_deallocate

def path_live_weakrefs(self) -> Generator[StorageWeakRefWrapper]:
"Returns all live storages weakrefs that created by nodes in this path"
for i, j in self.live_indices_after_graph:
out = self.path_weakrefs[i][j]
if is_live(out):
yield out

def path_live_weakrefs_and_stacktraces(
self,
) -> Generator[Tuple[StorageWeakRefWrapper, Optional[str]]]:
"Returns all live storages weakrefs that created by nodes in this path"
for i, j in self.live_indices_after_graph:
out = self.path_weakrefs[i][j]
if is_live(out):
yield out, self.path_stacktraces[i][j]

def clear_path_outputs(self):
"Clear the output lists of all nodes in the path"
for li in self.path_weakrefs:
@@ -1003,6 +1025,12 @@ class ExecutionState(Enum):
EXECUTION = auto()


class CompilationMode(Enum):
FORWARD = auto()
BACKWARD = auto()
INFERENCE = auto()


class CUDAGraphTreeManager:
"""
Groups individual recordings or executions of cuda graphs into a tree of recordings,
@@ -1083,7 +1111,18 @@ def __init__(self, device_index: int):
# number of instances we had to checkpoint the function
self.debug_checkpointing_counter = 0

self.id_to_mode: Dict[int, CompilationMode] = {}

# forwards that have been invoked without invocation of their corresponding backwards
self.forwards_with_pending_backwards: int = 0

def run(self, new_inputs: List[Tensor], function_id: FunctionID):
mode = self.id_to_mode[function_id]
if mode == CompilationMode.FORWARD:
self.forwards_with_pending_backwards += 1
elif mode == CompilationMode.BACKWARD:
self.forwards_with_pending_backwards -= 1

# we will try to end the current execution lazily, since
# we don't want to do unnecessary checking of the existing outputs
# on the hot path, but both recording and warmup only happen once
@@ -1196,11 +1235,27 @@ def new_graph_id(self) -> GraphID:
def new_func_id(self) -> FunctionID:
return FunctionID(next(self.func_counter))

def add_function(self, model, inputs, static_input_idxs) -> Callable:
def add_function(
self,
model,
inputs,
static_input_idxs,
is_backward,
is_inference,
) -> Callable:
id = self.new_func_id()
self.ids_to_funcs[id] = WrappedFunction(
model, remove_unaligned_input_idxs(inputs, static_input_idxs), id
)
self.id_to_mode[id] = (
CompilationMode.BACKWARD
if is_backward
else (
CompilationMode.INFERENCE if is_inference else CompilationMode.FORWARD
)
)

comp_context = torch._functorch.aot_autograd.get_graph_being_compiled()
fn = functools.partial(self.run, function_id=id)

# container needs to set clean up when fn dies
@@ -1234,6 +1289,12 @@ def current_node(self, value):
def get_curr_generation() -> int:
return GenerationTracker.generation

def can_start_new_generation(self) -> bool:
if self.forwards_with_pending_backwards != 0:
return False

return self.current_gen != self.get_curr_generation()

def try_end_curr_recording(self) -> None:
"""
Check if the current recording can be terminated, either because all outputs of the
@@ -1244,7 +1305,7 @@ def try_end_curr_recording(self) -> None:
assert self.current_node is not None

# multiple invocations, allow overwriting the previous generation
if self.current_gen != self.get_curr_generation():
if self.can_start_new_generation():
self.dealloc_current_path_weakrefs()
self.clear_current_node_outputs_and_set_to_none()
return
Expand All @@ -1264,15 +1325,15 @@ def try_end_curr_execution(self) -> None:
if self.current_node is None:
return

if self.current_gen != self.get_curr_generation():
if self.can_start_new_generation():
self.clear_current_node_outputs_and_set_to_none()
return

if self.current_node.all_outputs_are_dead():
self.clear_current_node_outputs_and_set_to_none()

def try_end_curr_warmup(self):
if self.current_gen != self.get_curr_generation():
if self.can_start_new_generation():
self.dealloc_current_path_weakrefs()
self.current_node = None
return