diff --git a/examples/autodiff.jl b/examples/autodiff.jl index 669f3b6809..6bd0b74fb5 100644 --- a/examples/autodiff.jl +++ b/examples/autodiff.jl @@ -98,7 +98,7 @@ dby = [0.0] Enzyme.autodiff( Forward, - (x,y) -> Enzyme.autodiff_deferred(Reverse, f, x, y), + (x,y) -> Enzyme.autodiff(Reverse, f, x, y), Duplicated(Duplicated(x, bx), Duplicated(dx, dbx)), Duplicated(Duplicated(y, by), Duplicated(dy, dby)), ) @@ -121,7 +121,7 @@ dbx[2] == 1.0 # \end{aligned} # ``` function grad(x, dx, y, dy) - Enzyme.autodiff_deferred(Reverse, f, Duplicated(x, dx), DuplicatedNoNeed(y, dy)) + Enzyme.autodiff(Reverse, f, Duplicated(x, dx), DuplicatedNoNeed(y, dy)) nothing end diff --git a/src/Enzyme.jl b/src/Enzyme.jl index bb86a33fc7..583035593d 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -1084,31 +1084,6 @@ grad = gradient(Reverse, only ∘ f, (a = 2.0, b = [3.0], c = "str")) end end -""" - gradient_deferred(::ReverseMode, f, x) - -Like [`gradient`](@ref), except it using deferred mode. -""" -@inline function gradient_deferred(rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, f::F, x::X) where {F, X, ReturnPrimal, RuntimeActivity, ABI, Holomorphic, ErrIfFuncWritten} - if Compiler.active_reg_inner(X, #=seen=#(), #=world=#nothing, #=justActive=#Val(true)) == Compiler.ActiveState - dx = Ref(make_zero(x)) - autodiff_deferred(rm, f, Active, MixedDuplicated(x, dx)) - if ReturnPrimal - return (only(dx), res[2]) - else - return only(dx) - end - else - dx = make_zero(x) - autodiff_deferred(rm, f, Active, Duplicated(x, dx)) - if ReturnPrimal - (dx, res[2]) - else - dx - end - end -end - """ gradient!(::ReverseMode, dx, f, x) @@ -1149,22 +1124,6 @@ gradient!(ReverseWithPrimal, dx, f, [2.0, 3.0]) end end - -""" - gradient_deferred!(::ReverseMode, f, x) - -Like [`gradient!`](@ref), except it using deferred mode. -""" -@inline function gradient_deferred!(rm::ReverseMode{ReturnPrimal,RuntimeActivity,ABI,Holomorphic,ErrIfFuncWritten}, dx::X, f::F, x::X) where {X<:Array, F, ReturnPrimal, RuntimeActivity, ABI, Holomorphic, ErrIfFuncWritten} - make_zero!(dx) - autodiff_deferred(rm, f, Active, Duplicated(x, dx)) - return if ReturnPrimal - (dx, res[2]) - else - dx - end -end - """ gradient(::ForwardMode, f, x; shadow=onehot(x)) @@ -1605,7 +1564,7 @@ res """ @inline function hvp!(res::X, f::F, x::X, v::X) where {F, X} grad = make_zero(x) - Enzyme.autodiff(Forward, gradient_deferred!, Const(Reverse), DuplicatedNoNeed(grad, res), Const(f), Duplicated(x, v)) + Enzyme.autodiff(Forward, gradient!, Const(Reverse), DuplicatedNoNeed(grad, res), Const(f), Duplicated(x, v)) return nothing end @@ -1640,7 +1599,7 @@ grad ``` """ @inline function hvp_and_gradient!(res::X, grad::X, f::F, x::X, v::X) where {F, X} - Enzyme.autodiff(Forward, gradient_deferred!, Const(Reverse), Duplicated(grad, res), Const(f), Duplicated(x, v)) + Enzyme.autodiff(Forward, gradient!, Const(Reverse), Duplicated(grad, res), Const(f), Duplicated(x, v)) return nothing end diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index 2ef66a1571..482690e20f 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -212,4 +212,34 @@ let # overload `inlining_policy` end end -end # module Interpreter +import Core.Compiler: abstract_call, abstract_call_known, ArgInfo, StmtInfo, AbsIntState, get_max_methods, + CallMeta, Effects, NoCallInfo, widenconst, mapany + +struct AutodiffCallInfo <: CallInfo + # ... + info::CallInfo +end + +function abstract_call_known(interp::EnzymeInterpreter, @nospecialize(f), + arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, + max_methods::Int = get_max_methods(interp, f, sv)) + + (; fargs, argtypes) = arginfo + + if f === Enzyme.autodiff && length(argtypes) >= 4 + if widenconst(argtypes[2]) <: Enzyme.Mode && widenconst(argtypes[3]) <: Enzyme.Annotation && widenconst(argtypes[4]) <: Type{<:Enzyme.Annotation} + arginfo2 = ArgInfo( + fargs isa Nothing ? nothing : [:(Enzyme.autodiff_deferred), fargs[2:end]...], + [Core.Const(Enzyme.autodiff_deferred), argtypes[2:end]...] + ) + return abstract_call_known( + interp, Enzyme.autodiff_deferred, arginfo2, + si, sv, max_methods) + end + end + return Base.@invoke abstract_call_known( + interp::AbstractInterpreter, f, arginfo::ArgInfo, + si::StmtInfo, sv::AbsIntState, max_methods::Int) +end + +end diff --git a/test/runtests.jl b/test/runtests.jl index 18d765938d..b079c0f540 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -486,6 +486,14 @@ end end +@testset "Deferred upgrade" begin + function gradsin(x) + return gradient(Reverse, sin, x) + end + res = Enzyme.gradient(Reverse, gradsin, 3.1) + @test res ≈ -sin(3.1) +end + @testset "Simple Complex tests" begin mul2(z) = 2 * z square(z) = z * z