diff --git a/test/inductor/test_cudagraph_trees.py b/test/inductor/test_cudagraph_trees.py
index 8de34efe968db9..9a4f3c5c8d08f7 100644
--- a/test/inductor/test_cudagraph_trees.py
+++ b/test/inductor/test_cudagraph_trees.py
@@ -125,7 +125,13 @@ def get_root_children(self):
         return [root.num_descendants() for root in self.get_roots()]

     def cudagraphify_impl(self, *args, **kwargs):
-        return tree_cudagraphify_impl(*args, **kwargs, device_index=self.device_idx)
+        return tree_cudagraphify_impl(
+            *args,
+            **kwargs,
+            device_index=self.device_idx,
+            is_backward=False,
+            is_inference=True,
+        )

     @staticmethod
     def run_twc(fn, *args, **kwargs):
@@ -517,7 +523,9 @@ def foo(args):

             inp = torch.rand([20, 20], device="cuda:1")

-            foo_cg = tree_cudagraphify_impl(foo, [inp], (), device_index=1)
+            foo_cg = tree_cudagraphify_impl(
+                foo, [inp], (), device_index=1, is_backward=False, is_inference=True
+            )
             self.assertEqual(foo_cg([inp]), foo([inp]))

             self.assertTrue(self.get_manager(device_index=0) is None)
@@ -526,6 +534,34 @@ def foo(args):
             test()
             self.assertTrue(self.get_manager(device_index=1) is None)

+    def test_forward_generation(self):
+        def foo(x):
+            return x * x * x
+
+        def foo2(x):
+            return x * 12
+
+        foo_opt = torch._dynamo.optimize()(foo)
+        foo2_opt = torch._dynamo.optimize()(foo2)
+        ones = torch.ones([4, 4], device="cuda", requires_grad=True)
+
+        out = foo_opt(ones)
+        out2 = foo2_opt(out)
+
+        self.assertEqual(all_live_block_count(), 2)
+
+        self.assertEqual(self.get_manager().forwards_with_pending_backwards, 2)
+
+        out2.sum().backward()
+
+        self.assertEqual(self.get_manager().forwards_with_pending_backwards, 0)
+
+        del out
+        del out2
+
+        out = foo_opt(ones.detach())
+        self.assertEqual(self.get_manager().forwards_with_pending_backwards, 0)
+

 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
diff --git a/torch/_dynamo/backends/common.py b/torch/_dynamo/backends/common.py
index e3d4078fd27e88..55c8d5066d4d4c 100644
--- a/torch/_dynamo/backends/common.py
+++ b/torch/_dynamo/backends/common.py
@@ -41,6 +41,9 @@ def _wrapped_bw_compiler(*args, **kwargs):

         bw_compiler = kwargs.get("bw_compiler") or kwargs["fw_compiler"]
         kwargs["bw_compiler"] = _wrapped_bw_compiler
+        kwargs["inference_compiler"] = (
+            kwargs.get("inference_compiler") or kwargs["fw_compiler"]
+        )

         from functorch.compile import nop

diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py
index 8670f3ef364d68..e7704d6275976f 100644
--- a/torch/_functorch/aot_autograd.py
+++ b/torch/_functorch/aot_autograd.py
@@ -1238,6 +1238,7 @@ class AOTConfig:
     keep_inference_input_mutations: bool
     dynamic_shapes: bool = False
     aot_autograd_arg_pos_to_source : Optional[List[Source]] = None
+    inference_compiler: Optional[Callable] = None

 def aot_dispatch_base(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig, *, fw_metadata: ViewAndMutationMeta):
     # aot_dispatch_base requires functionalization, but doesn't need to handle as many cases as the autograd case.
@@ -1274,7 +1275,8 @@ def aot_dispatch_base(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig, *
     context = disable_autocast_manager if disable_amp else nullcontext

     with context(), track_graph_compiling(aot_config, "inference"):
-        compiled_fw = aot_config.fw_compiler(fw_module, flat_args)
+        compiler = aot_config.inference_compiler if aot_config.inference_compiler is not None else aot_config.fw_compiler
+        compiled_fw = compiler(fw_module, flat_args)

     compiled_fn = create_runtime_wrapper(
         compiled_fw,
@@ -2509,7 +2511,8 @@ def call_compiled_backward():
                         aot_config.bw_compiler, None, None,
                         aot_config.decompositions, 0, aot_config.aot_id,
                         aot_config.keep_inference_input_mutations,
-                        aot_config.dynamic_shapes
+                        aot_config.dynamic_shapes,
+                        inference_compiler=None,
                     )
                 )
             else:
@@ -2747,6 +2750,7 @@ def aot_function(
     hasher_type=None,  # deprecated
     static_argnums: Optional[Tuple[int]] = None,  # deprecated
     keep_inference_input_mutations: bool = False,
+    inference_compiler: Optional[Callable] = None,
     *,
     # Whether or not to trace with dynamic shapes
     dynamic=False,
@@ -2784,7 +2788,10 @@ def aot_function(
             backward graphs.
         decompositions (Dict): A dictionary to define the decomposition of
             larger Aten ops into simpler or core Aten ops.
-
+        inference_compiler (Optional[Callable]): A Python function that accepts an
+            Fx graph with Aten ops and input args, and returns a Callable that
+            semantically is equivalent to the input Fx graph. Default: None
+            (when None, it defaults to the :attr:`fw_compiler`)
     Returns:
         Returns a ``Callable`` that retains the eager behavior of the original
         :attr:`fn`, but with forward and backward graph compiled via
@@ -2808,9 +2815,12 @@ def aot_function(

     if bw_compiler is None:
         bw_compiler = fw_compiler
+    if inference_compiler is None:
+        inference_compiler = fw_compiler
     aot_config = AOTConfig(
         fw_compiler=fw_compiler,
         bw_compiler=bw_compiler,
+        inference_compiler=inference_compiler,
         partition_fn=partition_fn,
         decompositions=decompositions,
         num_params_buffers=num_params_buffers,
@@ -2934,6 +2944,7 @@ def aot_module_simplified(
     hasher_type=None,
     static_argnums=None,
     keep_inference_input_mutations=False,
+    inference_compiler: Optional[Callable] = None,
 ) -> nn.Module:
     """
     This is the simplified or low overhead version of aot_module. For frontends
@@ -2999,6 +3010,8 @@ def functional_call(*args, **kwargs):
     assert static_argnums is None
     if bw_compiler is None:
         bw_compiler = fw_compiler
+    if inference_compiler is None:
+        inference_compiler = fw_compiler

     full_args = []
     # First, the params
@@ -3040,6 +3053,7 @@ def functional_call(*args, **kwargs):
     aot_config = AOTConfig(
         fw_compiler=fw_compiler,
         bw_compiler=bw_compiler,
+        inference_compiler=inference_compiler,
         partition_fn=partition_fn,
         decompositions=decompositions,
         num_params_buffers=params_len,
diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py
index 40a14543b0891a..6c3bd95e0a7537 100644
--- a/torch/_inductor/compile_fx.py
+++ b/torch/_inductor/compile_fx.py
@@ -139,6 +139,7 @@ def compile_fx_inner(
     is_backward=False,
     graph_id=None,
     aot_mode=False,
+    is_inference=False,
 ):
     if is_tf32_warning_applicable(gm):
         _warn_tf32_disabled()
@@ -162,6 +163,7 @@ def compile_fx_inner(
     cudagraphs = config.triton.cudagraphs

     shape_env = _shape_env_from_inputs(example_inputs)
+
     fake_mode = fake_mode_from_tensors(
         example_inputs
     ) or torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
@@ -183,6 +185,8 @@ def compile_fx_inner(
     if aot_mode:
         return compiled_fn

+    output = list(gm.graph.nodes)[-1]
+    assert len(output.args) == 1
     if cudagraphs:
         complex_memory_overlap_inputs = any(
             complex_memory_overlap(t) for t in example_inputs
@@ -200,6 +204,8 @@ def compile_fx_inner(
             example_inputs,
             static_input_idxs=range(num_fixed),
             device_index=next(iter(graph.device_idxs)),
+            is_backward=is_backward,
+            is_inference=is_inference,
         )
     else:
         BoxedBool.disable(cudagraphs)
@@ -267,14 +273,25 @@ def run(new_inputs):


 @dynamo_utils.dynamo_timed
-def cudagraphify(model, inputs, static_input_idxs=(), *, device_index: int):
+def cudagraphify(
+    model,
+    inputs,
+    static_input_idxs=(),
+    *,
+    device_index: int,
+    is_backward: bool,
+    is_inference: bool,
+):
     from torch._inductor.cudagraph_trees import (
         cudagraphify_impl as new_cudagraphify_impl,
     )

     if config.triton.cudagraph_trees:
         cudagraphify_fn = functools.partial(
-            new_cudagraphify_impl, device_index=device_index
+            new_cudagraphify_impl,
+            device_index=device_index,
+            is_backward=is_backward,
+            is_inference=is_inference,
         )
     else:
         cudagraphify_fn = cudagraphify_impl
@@ -493,7 +510,7 @@ def compile_fx(
     graph_id = next(_graph_counter)

     @dynamo_utils.dynamo_timed
-    def fw_compiler(model: torch.fx.GraphModule, example_inputs):
+    def fw_compiler_base(model: torch.fx.GraphModule, example_inputs, is_inference):
         fixed = len(example_inputs) - num_example_inputs
         # Why convert outplace op to inplace? Inductor can support inplace operations well and for custom
         # inplace ops which are lowered as ExternKernel, it is beneficial to performance when the inplace
@@ -505,8 +522,12 @@ def fw_compiler(model: torch.fx.GraphModule, example_inputs):
             num_fixed=fixed,
             cudagraphs=cudagraphs,
             graph_id=graph_id,
+            is_inference=is_inference,
         )

+    fw_compiler = functools.partial(fw_compiler_base, is_inference=False)
+    inference_compiler = functools.partial(fw_compiler_base, is_inference=True)
+
     # Save and restore dynamic shapes setting for backwards, as it is
     # sometimes done as a context manager which won't be set when we
     # hit backwards compile
@@ -534,6 +555,7 @@ def bw_compiler(model: torch.fx.GraphModule, example_inputs):
     return aot_autograd(
         fw_compiler=fw_compiler,
         bw_compiler=bw_compiler,
+        inference_compiler=inference_compiler,
         decompositions=decompositions,
         partition_fn=functools.partial(
             min_cut_rematerialization_partition, compiler="inductor"
diff --git a/torch/_inductor/cudagraph_trees.py b/torch/_inductor/cudagraph_trees.py
index 4b5c6e3d460c68..f3c58df95ed015 100644
--- a/torch/_inductor/cudagraph_trees.py
+++ b/torch/_inductor/cudagraph_trees.py
@@ -250,9 +250,23 @@ def get_container(device_index: int):
     return container_dict[device_index]


-def cudagraphify_impl(model, inputs, static_input_idxs=(), *, device_index: int):
+def cudagraphify_impl(
+    model,
+    inputs,
+    static_input_idxs=(),
+    *,
+    device_index: int,
+    is_backward: bool,
+    is_inference: bool,
+):
     manager = get_container(device_index).get_tree_manager()
-    return manager.add_function(model, inputs, static_input_idxs)
+    return manager.add_function(
+        model,
+        inputs,
+        static_input_idxs,
+        is_backward,
+        is_inference,
+    )


 def is_live(weak_ref):
@@ -848,12 +862,20 @@ def data_ptrs_dead_since_invocation(self) -> List[int]:
         return ptrs_to_deallocate

     def path_live_weakrefs(self) -> Generator[StorageWeakRefWrapper]:
-        "Returns all live storages weakrefs that created by nodes in this path"
         for i, j in self.live_indices_after_graph:
             out = self.path_weakrefs[i][j]
             if is_live(out):
                 yield out

+    def path_live_weakrefs_and_stacktraces(
+        self,
+    ) -> Generator[Tuple[StorageWeakRefWrapper, Optional[str]]]:
+        "Returns all live storage weakrefs created by nodes in this path"
+        for i, j in self.live_indices_after_graph:
+            out = self.path_weakrefs[i][j]
+            if is_live(out):
+                yield out, self.path_stacktraces[i][j]
+
     def clear_path_outputs(self):
         "Clear the output lists of all nodes in the path"
         for li in self.path_weakrefs:
@@ -1003,6 +1025,12 @@ class ExecutionState(Enum):
     EXECUTION = auto()


+class CompilationMode(Enum):
+    FORWARD = auto()
+    BACKWARD = auto()
+    INFERENCE = auto()
+
+
 class CUDAGraphTreeManager:
     """
     Groups individual recordings or executions of cuda graphs into a tree of recordings,
@@ -1083,7 +1111,18 @@ def __init__(self, device_index: int):
         # number of instances we had to checkpoint the function
         self.debug_checkpointing_counter = 0

+        self.id_to_mode: Dict[int, CompilationMode] = {}
+
+        # forwards that have been invoked without invocation of their corresponding backwards
+        self.forwards_with_pending_backwards: int = 0
+
     def run(self, new_inputs: List[Tensor], function_id: FunctionID):
+        mode = self.id_to_mode[function_id]
+        if mode == CompilationMode.FORWARD:
+            self.forwards_with_pending_backwards += 1
+        elif mode == CompilationMode.BACKWARD:
+            self.forwards_with_pending_backwards -= 1
+
         # we will try to end the current execution lazily, since
         # we dont want to do unnecessary checking of the existing outputs
         # on the hot path, but both recording and warmup only happen once
@@ -1196,11 +1235,27 @@ def new_graph_id(self) -> GraphID:
     def new_func_id(self) -> FunctionID:
         return FunctionID(next(self.func_counter))

-    def add_function(self, model, inputs, static_input_idxs) -> Callable:
+    def add_function(
+        self,
+        model,
+        inputs,
+        static_input_idxs,
+        is_backward,
+        is_inference,
+    ) -> Callable:
         id = self.new_func_id()
         self.ids_to_funcs[id] = WrappedFunction(
             model, remove_unaligned_input_idxs(inputs, static_input_idxs), id
         )
+        self.id_to_mode[id] = (
+            CompilationMode.BACKWARD
+            if is_backward
+            else (
+                CompilationMode.INFERENCE if is_inference else CompilationMode.FORWARD
+            )
+        )
+
+        comp_context = torch._functorch.aot_autograd.get_graph_being_compiled()
         fn = functools.partial(self.run, function_id=id)

         # container needs to set clean up when fn dies
@@ -1234,6 +1289,12 @@ def current_node(self, value):
     def get_curr_generation() -> int:
         return GenerationTracker.generation

+    def can_start_new_generation(self) -> bool:
+        if self.forwards_with_pending_backwards != 0:
+            return False
+
+        return self.current_gen != self.get_curr_generation()
+
     def try_end_curr_recording(self) -> None:
         """
         Check if the current recording can be terminated, either because all outputs of the
@@ -1244,7 +1305,7 @@ def try_end_curr_recording(self) -> None:
         assert self.current_node is not None

         # multiple invocations, allow overwriting the previous generation
-        if self.current_gen != self.get_curr_generation():
+        if self.can_start_new_generation():
             self.dealloc_current_path_weakrefs()
             self.clear_current_node_outputs_and_set_to_none()
             return
@@ -1264,7 +1325,7 @@ def try_end_curr_execution(self) -> None:
         if self.current_node is None:
             return

-        if self.current_gen != self.get_curr_generation():
+        if self.can_start_new_generation():
             self.clear_current_node_outputs_and_set_to_none()
             return

@@ -1272,7 +1333,7 @@ def try_end_curr_execution(self) -> None:
         self.clear_current_node_outputs_and_set_to_none()

     def try_end_curr_warmup(self):
-        if self.current_gen != self.get_curr_generation():
+        if self.can_start_new_generation():
             self.dealloc_current_path_weakrefs()
             self.current_node = None
             return
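
For reference, a minimal usage sketch of the inference_compiler hook this patch threads through aot_function. The compiler callables, function, and tensor shapes below are illustrative and not part of the patch: when no input requires grad, AOTAutograd takes the aot_dispatch_base path and hands the traced graph to inference_compiler instead of fw_compiler.

# Illustrative only -- not part of the patch. Assumes the inference_compiler
# argument added to aot_function above.
import torch
from functorch.compile import aot_function


def make_compiler(name):
    def compiler(fx_g, example_inputs):
        print(f"compiling with {name}")
        # Returning the traced GraphModule unchanged is a valid "compiler".
        return fx_g

    return compiler


def fn(x):
    return torch.sin(x) + x


compiled_fn = aot_function(
    fn,
    fw_compiler=make_compiler("fw_compiler"),
    bw_compiler=make_compiler("bw_compiler"),
    inference_compiler=make_compiler("inference_compiler"),
)

# No input requires grad, so the inference path (aot_dispatch_base) runs and
# "inference_compiler" is invoked; with requires_grad inputs, fw/bw_compiler run.
out = compiled_fn(torch.randn(4))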