[aot autograd] refactor to make functionalization self-contained (pytorch#96341)

This refactor should make it easier to add an export hook into aot autograd.

(1) I killed `create_forward_or_joint_functionalized()` (and the functions that it called, like `forward_or_joint()`), which previously handled autograd + functionalization all in one go for the joint case, and was also used in the inference case.

I added a few separate helper functions:

`create_functionalized_graph()`: this takes a flat fn, and returns a functionalized fx graph. It is mostly just a thin wrapper around functionalization + make_fx(), but also has some extra logic to manually append `copy_()` ops to the end of the graph.
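The appended-`copy_()` behavior can be illustrated with a small, torch-free sketch. This is not the real pytorch#96341 code: the helper name `create_functionalized_fn` and the list-based stand-ins for tensors are hypothetical, purely to show the shape of the idea.

```python
# Toy model of create_functionalized_graph()'s job (illustration only):
# run the traced body on clones of the inputs so it stays purely
# functional, then append explicit "copy_" steps that write the mutated
# values back to the original inputs -- mirroring the copy_() ops
# manually appended to the end of the fx graph.

def create_functionalized_fn(fn):
    def wrapped(inputs):
        # Functionalization: the body never mutates the real inputs.
        clones = [list(x) for x in inputs]
        outs = fn(clones)
        # Extra logic: record one copy_-like step per mutated input.
        tail = [
            ("copy_", i)
            for i, (orig, new) in enumerate(zip(inputs, clones))
            if orig != new
        ]
        # Applying the tail is the only place mutation happens.
        for _, i in tail:
            inputs[i][:] = clones[i]
        return outs, tail
    return wrapped
```

Running a mutating function through the wrapper leaves the mutation visible to the caller, but only via the explicit tail of copy_-style steps, never inside the traced body.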

`fn_no_extra_mutations()`: this creates the fn that we want to trace in the inference code path. It wraps a function so that calling it returns the outputs plus any (updated) mutated inputs.
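A minimal sketch of the inference-path idea, again torch-free and hypothetical (the name `fn_returning_mutated_inputs` and the list-based "tensors" are stand-ins, not the real helper):

```python
# Toy version of the inference-path wrapper (illustration only): call
# the user function, then return its outputs together with the updated
# values of any inputs it mutated, so the traced graph carries the
# mutations as explicit extra outputs.

def fn_returning_mutated_inputs(fn):
    def wrapped(inputs):
        snapshot = [list(x) for x in inputs]  # pre-call values
        outs = fn(inputs)
        # Any input that changed is returned as an extra output.
        mutated = [x for x, before in zip(inputs, snapshot) if x != before]
        return outs, mutated
    return wrapped
```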

`joint_fn_no_external_mutations()`: this creates the fn that we want to trace in the joint code path. It takes in a function, and traces out its joint. It also does the work of cloning inputs that are mutated and require gradients, returning mutated inputs as outputs, and returning intermediate bases as outputs.
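The clone-then-trace shape of the joint path can be sketched in the same toy style. Everything here is hypothetical (`make_joint_fn`, the list-based "tensors", and the caller-supplied `vjp`, which stands in for the backward since this sketch has no autograd engine):

```python
# Toy sketch of the joint-path wrapper (illustration only): the forward
# runs on clones of the primals, so a mutating forward does not clobber
# the pre-mutation values the backward needs; mutated inputs come back
# as extra outputs alongside the gradients.

def make_joint_fn(fwd, vjp):
    def joint(primals, tangents):
        clones = [list(p) for p in primals]  # clone inputs the fwd may mutate
        outs = fwd(clones)                   # forward mutates only the clones
        mutated = [c for c, p in zip(clones, primals) if c != p]
        grads = vjp(primals, tangents)       # backward sees pre-mutation values
        # Joint outputs: fn outputs, updated mutated inputs, gradients.
        return outs, mutated, grads
    return joint
```

The point of the clone is visible in the backward: `vjp` runs on the original `primals`, which the mutating forward never touched.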

We should be able to add an export hook by basically adding a similar version of `joint_fn_no_external_mutations` but with a lot more restrictions (guaranteed to have no tangents, no synthetic bases, etc.), and calling `create_functionalized_graph()` on it.

Differential Revision: [D44204090](https://our.internmc.facebook.com/intern/diff/D44204090)
Pull Request resolved: pytorch#96341
Approved by: https://github.com/ezyang
bdhirsh authored and pytorchmergebot committed Mar 22, 2023
1 parent e8a722b commit 545abc2
Showing 1 changed file with 199 additions and 221 deletions.