From 545abc292b4693fa4f308e535fddf9c90e3b1d01 Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Wed, 22 Mar 2023 20:35:33 +0000 Subject: [PATCH] [aot autograd] refactor to make functionalization self-contained (#96341) This refactor should make it easier to add an export hook into aot autograd. (1) I killed `create_forward_or_joint_functionalized()` (and the functions that it called, like `forward_or_joint()`) which used to handle autograd + functionalization all-in-one-go for the joint case, and was also used in the inference case. I added a few separate helper functions: `create_functionalized_graph()`: this takes a flat fn, and returns a functionalized fx graph. It is mostly just a thin wrapper around functionalization + make_fx(), but also has some extra logic to manually append `copy_()` ops to the end of the graph. `fn_no_extra_mutations()`: this creates the fn that we want to trace in the inference code path. It takes in a function that it then calls, and returns the outputs + any (updated) mutated inputs. `joint_fn_no_external_mutations()`: this creates the fn that we want to trace in the joint code path. It takes in a function, and traces out its joint. It also does the work of cloning inputs that are mutated and require gradients, returning mutated inputs as outputs, and returning intermediate bases as outputs We should be able to add an export hook by basically adding a similar version of `joint_fn_no_external_mutations` but with a lot more restrictions (guaranteed to have no tangents, not synthetic bases, etc), and calling `create_functionalized_graph()` on it. Differential Revision: [D44204090](https://our.internmc.facebook.com/intern/diff/D44204090) Pull Request resolved: https://github.com/pytorch/pytorch/pull/96341 Approved by: https://github.com/ezyang --- torch/_functorch/aot_autograd.py | 420 +++++++++++++++---------------- 1 file changed, 199 insertions(+), 221 deletions(-) diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index 512a0466ef433..5010e32988764 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -833,6 +833,23 @@ def inner(*flat_args): return inner +@dataclasses.dataclass +class AOTConfig: + """ + Configuration for AOTDispatcher + """ + + fw_compiler: Callable + bw_compiler: Callable + partition_fn: Callable + decompositions: Dict[Callable, Callable] + num_params_buffers: int + aot_id: int + keep_inference_input_mutations: bool + dynamic_shapes: bool = False + aot_autograd_arg_pos_to_source : Optional[List[Source]] = None + inference_compiler: Optional[Callable] = None + # This function takes in a tensor t, and returns one of t, t.view(), or t.clone(). # When tracing the joint forward + backward, for any inputs in the graph that are mutated, # we need to clone them first (and similarly for metadata-only mutations, we need to view them first). @@ -862,181 +879,165 @@ def maybe_to_fresh_input(idx, t, meta): return t.view(t.shape) return t -# This function takes in a forward fn, runs it, and (optionally) runs autograd to compute the joint. -# When maybe_tangents is None, we only run the forward. Otherwise we run the "joint" forward + backward. -# Preconditions: -# - fn corresponds to the flattened user fw function, with duplicate inputs removed -# - functionalization is turned on (and inputs are wrapped in functional tensors) -# - Synthetic bases have been *removed* (we've taken views on them corresponding to the user argument views). 
-# - primals_after_cloning are what we run our forward function on. It is identical to primals_before_cloning, -# except that every input we know will be mutated in the forward has been cloned. -# We run our forward on primals_after_cloning (potentially mutating some inputs), and then compute our gradients -# w.r.t. primals_before_cloning (so we properly capture the mutation in our gradient computation). -# Importantly, due functionalization + some autograd.Function constraints, this function can return EXTRA outputs -# compared to what the original user forward returns. -# -# If we are only running the forward (and not computing the joint): -# - Our function will return (updated_inputs, fw_outs) -# -# If we are running the forward + backward (computing the joint): -# - Our function will return (updated_inputs, fw_outs, intermediate_bases), (gradients) -# -# Finally, if keep_input_mutations is set, then we will explicitly *not* return updated inputs, for any inputs -# that experienced data-only mutations. -# Instead, we are relying on the logic in create_forward_or_joint_functionalized to manually perform the input mutations, -# keeping them directly in the traced graph. -def forward_or_joint( +# This function returns a new function that returns mutated inputs as outputs. +# if keep_data_input_mutations is set, then we assume that data-only mutations +# will be left in the graph, and we only return metadata-mutated inputs as outputs. +def fn_input_mutations_to_outputs( fn: Callable, - primals_before_cloning: List[Any], - primals_after_cloning: List[Any], - maybe_tangents: Optional[List[Any]], meta: ViewAndMutationMeta, - keep_input_mutations: bool, + keep_data_input_mutations: bool, ) -> Any: - outs = fn(*primals_after_cloning) - assert len(meta.output_info) == len(outs) - - # The compiled fw will return mutated input tensors, *including* metadata-only mutation. - # However, if keep_input_mutations is set, the compiled fw only needs to return metadata-mutated inputs. - # (because data-only input mutations are handled directly in the compiled graph) - if keep_input_mutations: + def inner_fn(*args): + outs = fn(*args) + assert len(meta.output_info) == len(outs) + # The compiled fw will return mutated input tensors, *including* metadata-only mutation. + # However, if keep_data_input_mutations is set, the compiled fw only needs to return metadata-mutated inputs. + # (because data-only input mutations are handled directly in the compiled graph) mutated_inputs_to_return = [ x - for (i, x) in enumerate(primals_after_cloning) - if meta.input_info[i].mutates_metadata + for (i, x) in enumerate(args) + if meta.input_info[i].mutates_metadata or (meta.input_info[i].mutates_data and not keep_data_input_mutations) ] - else: + return *mutated_inputs_to_return, *outs + return inner_fn + +# This function takes in a fn with external aliasing and mutation, +# and returns a new fn with no external aliasing and mutation, +# as needed for autograd. +# The main transformations are: +# - Return mutated inputs as extra outputs +# - Clone mutated inputs that require gradients, +# because autograd will require us to pass the pre-mutated inputs into autograd.grad +# - Return intermediate bases of outputs as additional outputs, +# needed to appease autograd.Function +# The new function returns: +# (1) The updated outputs +# (2) A boolean mask of len(new_fn_outputs), +# that can be used to tell autograd.grad which outputs should get tangents +# if we trace the backward. 
+def fn_prepped_for_autograd( + fn: Callable, + meta: ViewAndMutationMeta, +) -> Any: + def inner_fn(*args): + args_maybe_cloned = [ + maybe_to_fresh_input(i, t, meta) for i, t in enumerate(args) + ] + + outs = fn(*args_maybe_cloned) + assert len(meta.output_info) == len(outs) + mutated_inputs_to_return = [ x - for (i, x) in enumerate(primals_after_cloning) - if meta.input_info[i].mutates_data or meta.input_info[i].mutates_metadata + for (i, x) in enumerate(args_maybe_cloned) + if meta.input_info[i].mutates_metadata or meta.input_info[i].mutates_data ] - # Case 1: We are just tracing the forward; not the joint forward + backward. - if maybe_tangents is None: - return *mutated_inputs_to_return, *outs - else: - tangents = maybe_tangents - - # Case 2: We are tracing the joint forward backward. - # This also requires us to: - # - update the graph to return intermediate bases - # - Figure out what grad_outputs to pass into the backward - # - (this includes intermediate bases in the forward, and forward inputs that had data mutations) - # - actually call autograd.grad to trace the backward. - intermediate_bases = [] - for o, info in zip(outs, meta.output_info): - if info.output_type == OutputType.alias_of_intermediate_save_as_output: - intermediate_bases.append(o._base) - - assert meta.num_intermediate_bases == len(intermediate_bases) - - # Pass any (non-aliased) outputs in as tangents, since they'll be returned as outputs in the fw - # For outputs that are aliases of intermediates, we will have returned the output's _base as an output in the graph instead, - # which we *should* send to grad() - outputs_for_grad = [ - x - for (i, x) in enumerate(outs) - if meta.output_info[i].output_type == OutputType.non_alias - # Also, only tensor outputs should participate in the backward - # (in particular, Symint outputs in the forward graph shouldn't get tangents) - and issubclass(meta.output_info[i].raw_type, torch.Tensor) - ] - # Pass any (non-aliased) mutated inputs in as tangents, since they'll be returned as outputs in the fw - # Important: the traced joint fw/bw will return updated inputs with data mutations, - # but *not* with metadata mutations. - # Instead, we shunt the updated metadata around externally - # and update the input's metadata outside of the autograd.Function - mutated_inputs_for_grad = [ - x - for (i, x) in enumerate(primals_after_cloning) - if meta.input_info[i].mutates_data - ] - # The tensors that we include in the backward graph are: - # - inputs that recieve *data* mutations (not metadata-only; those are recomputed later) - # - outputs that are not aliased (aliased outputs are recomputed later) - # - intermediate ._base tensors of aliased outputs (we use those later to recompute the aliased outputs) - fw_outs_to_grad = mutated_inputs_for_grad + outputs_for_grad + intermediate_bases - assert len(tangents) == len(fw_outs_to_grad) - - # the compiled forward should return (mutated_inputs, user_outs, intermediate_bases) - fw_outs_to_return = *mutated_inputs_to_return, *outs, *intermediate_bases - - # Take care to grab and sync the updated inputs from primals_after_cloning (the inputs we actually mutate!) 
- # and not primals_before_cloning (the preserved inputs, pre-mutation, that we pass to grad()) - for i, arg in enumerate(primals_after_cloning): - if not isinstance(arg, Tensor): - continue - torch._sync(arg) - - # Get the inputs that need gradients - grad_primals = [] - inputs_needs_grads = [] - # Note that we're not using primals_before_cloning here, - # being carefully not to pass any mutated inputs into autograd.grad() - for p in primals_before_cloning: - is_grad_tensor = isinstance(p, Tensor) and p.requires_grad - inputs_needs_grads.append(is_grad_tensor) - if is_grad_tensor: - grad_primals.append(p) - - # Get the outputs that need gradients - needed_outs = [] - needed_tangents = [] - for out, tangent in zip(fw_outs_to_grad, tangents): - if isinstance(out, Tensor) and out.requires_grad: - # A bit sketchy, but fixes e.g. test_aot_autograd_exhaustive_matmul_cpu_float32 - # The issue is that we are sensitive to decomps that don't accurately maintain - # their output's _base.shape compared to eager mode, and this helps mitigate a bit. - needed_outs.append( - out if out.shape == tangent.shape else out.view(tangent.shape) - ) - needed_tangents.append(tangent) - - setup_stacktrace_preservation_hooks([out.grad_fn for out in needed_outs]) - - backward_out = [] - # Call the backwards pass - if grad_primals: - with fx_traceback.preserve_node_meta(): - backward_out = torch.autograd.grad( - needed_outs, - grad_primals, - grad_outputs=needed_tangents, - allow_unused=True, - ) - backward_out_iter = iter(backward_out) - return fw_outs_to_return, [ - next(backward_out_iter) if i else None for i in inputs_needs_grads - ] + intermediate_bases = [] + for o, info in zip(outs, meta.output_info): + if info.output_type == OutputType.alias_of_intermediate_save_as_output: + intermediate_bases.append(o._base) -# This function adds extra clone() calls on any inputs in the forward that get mutated. -# It *only* does this if we plan on performing autograd on fn. -# The idea here is that when computing grdients w.r.t. inputs, we need to compute our gradients -# w.r.t. the inputs *before* they were mutated! -# Preconditions: -# - fn corresponds to the flattened user fw function, with duplicate inputs removed -# - primals **includes** synthetic bases. 
Importantly, if a synthetic base is mutated, -# we need to clone it *before* taking views off of it (if we clone the views they won't be views anymore) -# - functionalization is turned on (and inputs are wrapped in functional tensors) -def flat_fn_no_input_mutations( + assert meta.num_intermediate_bases == len(intermediate_bases) + + # the compiled forward should return (mutated_inputs, user_outs, intermediate_bases) + fw_outs_to_return = *mutated_inputs_to_return, *outs, *intermediate_bases + + # Also return a boolean mask specifying which outputs to this function will be used as tangents + mutated_inputs_grad_mask = [ + meta.input_info[meta.mutated_inp_indices[i]].mutates_data + for (i, x) in enumerate(mutated_inputs_to_return) + ] + + # Pass any (non-aliased) outputs in as tangents, since they'll be returned as outputs in the fw + # For outputs that are aliases of intermediates, we will have returned the output's _base as an output in the graph instead, + # which we *should* send to grad() + output_grad_mask = [ + meta.output_info[i].output_type == OutputType.non_alias + # Also, only tensor outputs should participate in the backward + # (in particular, Symint outputs in the forward graph shouldn't get tangents) + and issubclass(meta.output_info[i].raw_type, torch.Tensor) + for (i, x) in enumerate(outs) + ] + + intermediate_base_grad_mask = [True for _ in range(len(intermediate_bases))] + + out_grad_mask = mutated_inputs_grad_mask + output_grad_mask + intermediate_base_grad_mask + assert len(out_grad_mask) == len(fw_outs_to_return) + + # Take care to grab and sync the updated inputs from primals_after_cloning (the inputs we actually mutate!) + # and not primals (the preserved inputs, pre-mutation, that we pass to grad()) + # This is annoying: our joint function needs to be aware of functionalization + # (syncing mutated inputs before calling autograd.grad()) + # In theory, we could make the autograd engine do this automatically, although that probably isn't any cleaner. + for i, arg in enumerate(args_maybe_cloned): + if not isinstance(arg, Tensor): + continue + torch._sync(arg) + + return fw_outs_to_return, out_grad_mask + return inner_fn + +# Given a fn, computes the joint. +# NOTE: fn is expects the following behavior: +# (1) fn() needs to return a tuple of (outs, mask), +# where `mask` tells us which outputs are meant to have tangents. +# we don't know this info automatically, because we don't actually want to blindly +# compute tangents for every output that requires grad. +# Specifically, outputs that alias inputs won't participate in the backward and get tangents. +# (2) fn() cannot mutate any inputs that require gradient. +# otherwise, when we compute autograd.grad(), we will not take those input mutations into account +# (the way this is handled is that we ensure any inputs that normally get mutated are cloned first) +def create_joint( fn: Callable, - primals: List[Any], - maybe_tangents: Optional[List[Any]], - meta: ViewAndMutationMeta, - keep_input_mutations: bool -): - # When tracing the joint fwd + bwd, making sure to clone any inputs that are mutated first. - # We need to ensure that the inputs we pass to autograd.grad() are the *original* - # inputs, and not their mutated values. 
- if maybe_tangents is not None: - primals_after_cloning = [ - maybe_to_fresh_input(i, t, meta) for i, t in enumerate(primals) +) -> Any: + def inner_fn(primals: List[Any], tangents: List[Any]): + outs, tangent_mask = fn(*primals) + assert len(tangent_mask) == len(outs) + outs_to_grad = [o for needs_tangent, o in zip(tangent_mask, outs) if needs_tangent] + assert len(outs_to_grad) == len(tangents) + + # Get the inputs that need gradients + grad_primals = [] + inputs_needs_grads = [] + # Note that we're not using primals here, + # being carefully not to pass any mutated inputs into autograd.grad() + for p in primals: + is_grad_tensor = isinstance(p, Tensor) and p.requires_grad + inputs_needs_grads.append(is_grad_tensor) + if is_grad_tensor: + grad_primals.append(p) + + # Get the outputs that need gradients + needed_outs = [] + needed_tangents = [] + for out, tangent in zip(outs_to_grad, tangents): + if isinstance(out, Tensor) and out.requires_grad: + # A bit sketchy, but fixes e.g. test_aot_autograd_exhaustive_matmul_cpu_float32 + # The issue is that we are sensitive to decomps that don't accurately maintain + # their output's _base.shape compared to eager mode, and this helps mitigate a bit. + needed_outs.append( + out if out.shape == tangent.shape else out.view(tangent.shape) + ) + needed_tangents.append(tangent) + + setup_stacktrace_preservation_hooks([out.grad_fn for out in needed_outs]) + + backward_out = [] + # Call the backwards pass + if grad_primals: + with fx_traceback.preserve_node_meta(): + backward_out = torch.autograd.grad( + needed_outs, + grad_primals, + grad_outputs=needed_tangents, + allow_unused=True, + ) + backward_out_iter = iter(backward_out) + return outs, [ + next(backward_out_iter) if i else None for i in inputs_needs_grads ] - else: - primals_after_cloning = primals - outs = forward_or_joint(fn, primals, primals_after_cloning, maybe_tangents, meta, keep_input_mutations) - return outs + return inner_fn # This creates the final function that we want to trace using make_fx(), # in both aot_dispatch_autograd and aot_dispatch_base. @@ -1049,29 +1050,25 @@ def flat_fn_no_input_mutations( # The function returned has signature that is either: # (1) "traced_fn(primals: List[Any])" if trace_joint is False # (2) "traced_fn(primals: List[Any], tangents: List[Any])" if trace_joint is True -def create_forward_or_joint_functionalized( +def create_functionalized_graph( fn, + args, *, meta: ViewAndMutationMeta, + aot_config: AOTConfig, trace_joint: bool, - keep_input_mutations: bool ): - - def functionalized_f_helper(primals, maybe_tangents=None): - # Convention: this function is used to trace both the joint, and just the forward (for inference). - # When trace_joint is set, tangents should be passed in. - assert (maybe_tangents is not None) == trace_joint + def functionalized_f_helper(*args): # Wrap inputs into functional wrappers - f_primals = pytree.tree_map(to_fun, primals) - f_tangents = None if maybe_tangents is None else pytree.tree_map(to_fun, maybe_tangents) + f_args = pytree.tree_map(to_fun, args) torch._enable_functionalization(reapply_views=True) try: # Run the joint - f_outs = flat_fn_no_input_mutations(fn, f_primals, f_tangents, meta, keep_input_mutations) + f_outs = fn(*f_args) finally: torch._disable_functionalization() - if keep_input_mutations: + if aot_config.keep_inference_input_mutations and not trace_joint: # Note: This is a bit annoying. 
There's a layering issue here, where: # (1) functionalization needs to operate on **synthetic base** inputs, before unpacking them into the "real" inputs. # (2) For keep_input_mutations, we support tracing a call to copy_() directly on mutated inputs. @@ -1096,8 +1093,7 @@ def functionalized_f_helper(primals, maybe_tangents=None): # we will materialize an "updated" synthetic base, and copy it back to the synthetic input base. # This allows us to factor aot autograd much more nicely, since only one area of the code needs to worry # about synthetic bases. - assert len(meta.input_info) == len(f_primals) - for i, (inpt_old, inpt_f) in enumerate(zip(primals, f_primals)): + for i, (inpt_old, inpt_f) in enumerate(zip(args, f_args)): if not isinstance(inpt_f, torch.Tensor): continue torch._sync(inpt_f) @@ -1112,18 +1108,16 @@ def functionalized_f_helper(primals, maybe_tangents=None): return pytree.tree_map(from_fun, f_outs) - # the joint needs have args named "primals" and "tangents", - # which are hardcoded into the partitioning logic. - def traced_joint(primals, tangents): + # Kinda annoying, but needed to make sure that the fx graph we trace out has "primals" + # and "tangents" as its input names (which are special-cased by the partitioner) + def joint_helper(primals, tangents): return functionalized_f_helper(primals, tangents) - def traced_forward(*primals): - return functionalized_f_helper(primals) + def fwd_helper(*args): + return functionalized_f_helper(*args) - if trace_joint: - return traced_joint - else: - return traced_forward + with enable_python_dispatcher(): + return make_fx(joint_helper if trace_joint else fwd_helper, decomposition_table=aot_config.decompositions)(*args) def normalize_as_list(x): @@ -1223,23 +1217,6 @@ def call_func_with_args(f, args, steal_args=False, disable_amp=False): del guard return out -@dataclasses.dataclass -class AOTConfig: - """ - Configuration for AOTDispatcher - """ - - fw_compiler: Callable - bw_compiler: Callable - partition_fn: Callable - decompositions: Dict[Callable, Callable] - num_params_buffers: int - aot_id: int - keep_inference_input_mutations: bool - dynamic_shapes: bool = False - aot_autograd_arg_pos_to_source : Optional[List[Source]] = None - inference_compiler: Optional[Callable] = None - def aot_dispatch_base(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig, *, fw_metadata: ViewAndMutationMeta): # aot_dispatch_base requires functionalization, but doesn't need to handle as many cases as the autograd case. 
# The cases that aot_dispatch_base doesn't need to handle include: @@ -1248,16 +1225,20 @@ def aot_dispatch_base(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig, * # While cases that it does need to handle include: # - input mutations (including when inputs are aliases of each other) # - input metadata mutations - trace_fn = create_forward_or_joint_functionalized( + keep_mutations = aot_config.keep_inference_input_mutations + fn_to_trace = fn_input_mutations_to_outputs( flat_fn, + fw_metadata, + keep_data_input_mutations=aot_config.keep_inference_input_mutations, + ) + fw_module = create_functionalized_graph( + fn_to_trace, + flat_args, meta=fw_metadata, + aot_config=aot_config, trace_joint=False, - keep_input_mutations=aot_config.keep_inference_input_mutations ) - with enable_python_dispatcher(): - fw_module = make_fx(trace_fn, decomposition_table=aot_config.decompositions)(*flat_args) - if not aot_config.keep_inference_input_mutations: # As long as we opted to remove input mutations, then # there should be *NO* mutating ops in the graph at this point. @@ -2236,26 +2217,23 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Any], aot_config: AOTConfig, ) assert len(fw_metadata.requires_grad_info) == fw_metadata.num_mutated_inputs + fw_metadata.num_outputs + joint_inputs = (flat_args, traced_tangents) + disable_amp = torch._C._is_any_autocast_enabled() - joint_forward_backward = create_forward_or_joint_functionalized( + fn_prepared_for_autograd = fn_prepped_for_autograd( flat_fn, - meta=fw_metadata, - trace_joint=True, - # For now in the autograd case, we NEVER keep input mutations (we could eventually fix this for slightly better perf - # in some cases, but it's annoying to fix the partitioner) - keep_input_mutations=False, + fw_metadata, ) - - joint_inputs = (flat_args, traced_tangents) - - disable_amp = torch._C._is_any_autocast_enabled() + joint_fn_to_trace = create_joint(fn_prepared_for_autograd) if config.use_functionalize: - with enable_python_dispatcher(): - flattened_joints, _ = pytree.tree_flatten(joint_inputs) - fx_g = make_fx(joint_forward_backward, decomposition_table=aot_config.decompositions)( - *joint_inputs - ) + fx_g = create_functionalized_graph( + joint_fn_to_trace, + joint_inputs, + meta=fw_metadata, + aot_config=aot_config, + trace_joint=True, + ) # There should be *NO* mutating ops in the graph at this point. assert_functional_graph(fx_g.graph)
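To make the inference-path contract concrete, the sketch below shows the idea behind fn_input_mutations_to_outputs: mutated inputs come back as extra outputs of the traced function, and a runtime epilogue writes them into the caller-visible tensors. The names user_fn, fn_mutations_to_outputs and mutated_input_indices are made up for illustration, and the explicit clone() only stands in for what functionalization provides in the real helper; this is a sketch of the convention, not the code added by the patch.

    import torch

    # Hypothetical user function: mutates its first input and returns one output.
    def user_fn(a, b):
        a.mul_(2)                     # data mutation on an input
        return a + b

    # Return the (updated) mutated inputs alongside the real outputs, so the
    # function we trace has no visible side effects on its arguments.
    def fn_mutations_to_outputs(fn, mutated_input_indices):
        def inner(*args):
            args = [a.clone() if i in mutated_input_indices else a
                    for i, a in enumerate(args)]
            outs = fn(*args)
            outs = outs if isinstance(outs, tuple) else (outs,)
            mutated_inputs = tuple(args[i] for i in mutated_input_indices)
            return *mutated_inputs, *outs
        return inner

    a, b = torch.ones(3), torch.ones(3)
    traced_like = fn_mutations_to_outputs(user_fn, mutated_input_indices={0})
    new_a, out = traced_like(a, b)

    # Runtime epilogue (done outside the graph by the AOT wrappers): apply the
    # mutation to the tensor the caller actually passed in.
    a.copy_(new_a)                    # a is now [2., 2., 2.], out is [3., 3., 3.]

With keep_data_input_mutations set, data-mutated inputs are dropped from this extra-output list, because the copy back into the input is kept directly in the traced graph instead.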
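The joint path composes two smaller conventions. fn_prepped_for_autograd clones any data-mutated input that requires grad, so that autograd.grad sees the pre-mutation value, and returns (outs, tangent_mask); create_joint then differentiates only the masked outputs with respect to the original primals. The sketch below follows those two conventions under simplified assumptions: prepped_for_autograd and make_joint are illustrative stand-ins, the mutation metadata is hard-coded rather than coming from ViewAndMutationMeta, and the returned grads are not padded with None for inputs that don't require grad the way the real create_joint pads them.

    import torch

    def user_fn(a, b):
        a.mul_(2)                     # data mutation on an input that requires grad
        return (a * b).sum()

    # Convention (1): clone mutated inputs that require grad, run the fn, and
    # return (outputs, mask), where the mask marks outputs that should get tangents.
    def prepped_for_autograd(fn, mutated_input_indices):
        def inner(*primals):
            cloned = [p.clone() if i in mutated_input_indices and p.requires_grad else p
                      for i, p in enumerate(primals)]
            out = fn(*cloned)
            mutated_inputs = [cloned[i] for i in mutated_input_indices]
            outs = (*mutated_inputs, out)
            tangent_mask = [True] * len(outs)   # everything participates in the backward here
            return outs, tangent_mask
        return inner

    # Convention (2): given such a fn, run the joint forward + backward,
    # differentiating w.r.t. the *original* (pre-mutation) primals.
    def make_joint(fn):
        def joint(primals, tangents):
            outs, mask = fn(*primals)
            outs_to_grad = [o for o, m in zip(outs, mask) if m]
            grad_primals = [p for p in primals
                            if isinstance(p, torch.Tensor) and p.requires_grad]
            grads = torch.autograd.grad(outs_to_grad, grad_primals,
                                        grad_outputs=tangents, allow_unused=True)
            return outs, grads
        return joint

    a = torch.ones(3, requires_grad=True)
    b = torch.full((3,), 3.0)
    joint = make_joint(prepped_for_autograd(user_fn, mutated_input_indices={0}))
    tangents = [torch.ones(3), torch.tensor(1.0)]   # one tangent per masked output
    outs, grads = joint([a, b], tangents)           # grads computed w.r.t. the un-mutated a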
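Finally, create_functionalized_graph is essentially "functionalize the callable, then make_fx it", plus the manual copy_() epilogue used when keep_inference_input_mutations is set. The patch drives torch._enable_functionalization directly so it can control that epilogue; the public torch.func.functionalize + make_fx composition below is only a rough approximation of the same step, shown to illustrate the property the rest of AOT autograd relies on: the traced graph consists of out-of-place ops, with any input mutation isolated so it can be handled explicitly.

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx
    from torch.func import functionalize

    def f(a, b):
        a.add_(1)                     # input mutation
        return (a * b).sum()

    a, b = torch.zeros(3), torch.ones(3)

    # Roughly the inference path: trace a functionalized graph of f.
    gm = make_fx(functionalize(f, remove="mutations_and_views"))(a, b)
    print(gm.code)
    # The in-place add_ has been rewritten into out-of-place ops; the write-back
    # into the input is now an explicit, separate step rather than a side effect
    # buried inside the computation, which is what the partitioner and compilers need.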