Auto upgrade to autodiff_deferred in nested AD #1839
@aviatesk mind taking a quick look at this to see if it makes sense/should work (locally it runs successfully!). Basically we want to upgrade autodiff into autodiff_deferred when running within an EnzymeInterpreter.
Does that mean
Likely not since it is still necessary for the GPU interface.
if f === Enzyme.autodiff && length(argtypes) >= 4
if widenconst(argtypes[2]) <: Enzyme.Mode && widenconst(argtypes[3]) <: Enzyme.Annotation && widenconst(argtypes[4]) <: Type{<:Enzyme.Annotation}
arginfo2 = ArgInfo(
fargs isa Nothing ? nothing : [:(Enzyme.autodiff_deferred), fargs[2:end]...],
You must actually change the IR, I believe.
A simple version would be to just add an "autodiff" -> "autodiff_deferred" entry to the Overlay Table.
Yeah, I agree.
Base.Experimental.@MethodTable ENZYME_TABLE
Base.Experimental.@overlay ENZYME_TABLE function autodiff(...)
[... the implementation for the deferred autodiff ...]
end
Core.Compiler.method_table(interp::EnzymeInterpreter) =
Core.Compiler.OverlayMethodTable(get_inference_world(interp), ENZYME_TABLE)
This allows EnzymeInterpreter to use the deferred autodiff implementation instead of the usual autodiff implementation automatically. I think that would work, since when a user calls autodiff(...) it dispatches to the usual autodiff implementation, which then kicks off the entire Enzyme compilation.
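For readers unfamiliar with overlay tables, here is a minimal sketch of the mechanism using only Base. The names OverlayDemo, toy_autodiff and TOY_TABLE are hypothetical stand-ins, not Enzyme's API, and the sketch does not include the custom interpreter that would actually consult the table.

```julia
# Toy stand-ins only: none of these names exist in Enzyme.
module OverlayDemo

# The "usual" entry point, analogous to the ordinary autodiff.
toy_autodiff(x) = :eager

# A separate method table that a custom abstract interpreter could consult.
Base.Experimental.@MethodTable TOY_TABLE

# Overlay definition: only visible to method lookups that go through TOY_TABLE.
Base.Experimental.@overlay TOY_TABLE toy_autodiff(x) = :deferred

end # module

# Ordinary dispatch never consults TOY_TABLE, so a user-level call
# still reaches the usual method.
result = OverlayDemo.toy_autodiff(1.0)
```

Only an interpreter whose method_table returns an OverlayMethodTable built from TOY_TABLE would see the :deferred method; everything else keeps the original behaviour.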
so one issue here is that we already receive a method table from GPUCompiler [e.g. for CUDA/etc]
Also I agree methodtable is cleaner, but the current code here does work [as also confirmed by CI]
I am still surprised that this works... You are presenting conflicting information to the abstract interpreter. The CallInfo says one thing, and the IR says another. This may currently "work", but you might also encounter a situation where the code uses the IR as the source of truth.
That's why in my various attempts at something similar, during inlining I replace the IR with something else.
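To make the "change the IR itself" idea concrete, here is a toy illustration that rewrites the call in the code being analysed rather than only in side metadata. It works on surface-level Expr trees, not on the compiler's typed IR, and rewrite_autodiff is a hypothetical helper; it is not how Enzyme's interpreter or any inlining pass is implemented.

```julia
# Hypothetical helper: rewrite calls to `autodiff` into `autodiff_deferred`.
# Operates on surface syntax (Expr), not on Core.Compiler IR.
rewrite_autodiff(x) = x  # literals, symbols, etc. pass through unchanged

function rewrite_autodiff(ex::Expr)
    # Rewrite nested expressions first so inner calls are handled too.
    args = Any[rewrite_autodiff(a) for a in ex.args]
    # Only bare `autodiff(...)` calls are matched; qualified names such as
    # `Enzyme.autodiff` are left alone in this sketch.
    if ex.head === :call && !isempty(args) && args[1] === :autodiff
        args[1] = :autodiff_deferred
    end
    return Expr(ex.head, args...)
end

# Body of a function that would itself be differentiated by an outer call.
inner_body = :(autodiff(Reverse, f, Active, Active(x))[1][1])
rewritten = rewrite_autodiff(inner_body)
```

After the rewrite, the code and any information derived from it agree on which function is being called, which is the consistency the comment above asks for.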
In concrete terms, how can I figure out if autodiff deferred is needed or not? At the moment I cannot test DI on GPU but I'd like it to work.
Deferred codegen is a concept from GPUCompiler.jl. Any time you use a GPUCompiler tool within another GPUCompiler tool, the inner one needs to use deferred codegen. This happens when calling Enzyme from within Enzyme (e.g. higher order AD), or Enzyme from within CUDA. What we did here is to fix any use of Enzyme.autodiff from within an Enzyme.autodiff to become deferred for the inner call. I imagine a similar fix to CUDA.jl's interpreter would also remedy |
@vchuravy @aviatesk I'm going to merge this for now since we're about to get a bunch of other breaking things which conflict, and it at least runs, if not necessarily optimally. I tried to get the MethodTable to work, but failed. Do you have any code examples for the nested method table needed per above? Then we can open a follow-up PR (which also notably wouldn't be as urgent, since it wouldn't be a semantically breaking change).
Would the fix for CUDA have to be Enzyme-specific, like a package extension? If so, wouldn't it be easier to fix it here?
We cannot override CUDA.jl's abstract interpreter or method table in Enzyme.jl; that would be akin to (but arguably worse than) type piracy. Yeah, probably a package extension that does whatever the end solution here arrives at.
A version of #1443 just focusing on doing the autodiff -> autodiff_deferred and not the improved effects/callinfo