Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[aot autograd] refactor to make functionalization self-contained (pyt…
…orch#96341) This refactor should make it easier to add an export hook into aot autograd. (1) I killed `create_forward_or_joint_functionalized()` (and the functions that it called, like `forward_or_joint()`) which used to handle autograd + functionalization all-in-one-go for the joint case, and was also used in the inference case. I added a few separate helper functions: `create_functionalized_graph()`: this takes a flat fn, and returns a functionalized fx graph. It is mostly just a thin wrapper around functionalization + make_fx(), but also has some extra logic to manually append `copy_()` ops to the end of the graph. `fn_no_extra_mutations()`: this creates the fn that we want to trace in the inference code path. It takes in a function that it then calls, and returns the outputs + any (updated) mutated inputs. `joint_fn_no_external_mutations()`: this creates the fn that we want to trace in the joint code path. It takes in a function, and traces out its joint. It also does the work of cloning inputs that are mutated and require gradients, returning mutated inputs as outputs, and returning intermediate bases as outputs We should be able to add an export hook by basically adding a similar version of `joint_fn_no_external_mutations` but with a lot more restrictions (guaranteed to have no tangents, not synthetic bases, etc), and calling `create_functionalized_graph()` on it. Differential Revision: [D44204090](https://our.internmc.facebook.com/intern/diff/D44204090) Pull Request resolved: pytorch#96341 Approved by: https://github.com/ezyang
- Loading branch information