diff --git a/docs/src/index.md b/docs/src/index.md
index 5da76ccb96..61e918f4dd 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -312,3 +312,183 @@ da2
 Sometimes, determining how to perform this zeroing can be complicated.
 That is why Enzyme provides a helper function `Enzyme.make_zero` that does
 this automatically.
+
+### Complex Numbers
+
+Differentiation of a function which returns a complex number is ambiguous, because there are several different gradients which may be desired. Rather than assume one of these conventions, and potentially cause user error when the resulting derivative is not the desired one, Enzyme requires users to specify the desired convention, for example by rewriting the function to return a real number instead.
+
+Consider the function `f(z) = z*z`. If we were to differentiate this with real inputs and outputs, the derivative `f'(z)` would be unambiguously `2*z`. However, consider breaking a complex number down into its real and imaginary parts. Suppose now we were to call `f` with the explicit real and imaginary components, `z = x + i y`. This means that `f` is a function that takes two inputs and returns two outputs, `f(x, y) = u(x, y) + i v(x, y)`. In the case of `z*z` this means that `u(x,y) = x*x-y*y` and `v(x,y) = 2*x*y`.
+
+
+If we were to look at all first-order derivatives in total, we would end up with a 2x2 matrix (i.e. the Jacobian), containing the derivative of each output with respect to each input. Let's try to compute this, first by hand, then with Enzyme.
+
+```
+grad u(x, y) = [d/dx u, d/dy u] = [d/dx x*x-y*y, d/dy x*x-y*y] = [2*x, -2*y];
+grad v(x, y) = [d/dx v, d/dy v] = [d/dx 2*x*y, d/dy 2*x*y] = [2*y, 2*x];
+```
+
+Reverse mode differentiation computes the derivative of a single output with respect to all inputs by propagating the derivative of the return back to those inputs. Here, we can explicitly differentiate with respect to the real and imaginary results, respectively, to find this matrix.
+
+```jldoctest complex
+f(z) = z * z
+
+# a fixed input to use for testing
+z = 3.1 + 2.7im
+
+grad_u = Enzyme.autodiff(Reverse, z->real(f(z)), Active, Active(z))[1][1]
+grad_v = Enzyme.autodiff(Reverse, z->imag(f(z)), Active, Active(z))[1][1]
+
+(grad_u, grad_v)
+# output
+(6.2 - 5.4im, 5.4 + 6.2im)
+```
+
+This is somewhat inefficient, since we need to call the forward pass twice, once for the real part and once for the imaginary part. We can solve this using batched derivatives in Enzyme, which compute several derivatives for the same function all in one go. To make it work, we're going to need to use split mode, which allows us to provide a custom derivative return value.
+
+```jldoctest complex
+fwd, rev = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(f)}, Active, Active{ComplexF64})
+
+# Compute the reverse pass seeded with a differential return of 1.0 + 0.0im
+grad_u = rev(Const(f), Active(z), 1.0 + 0.0im, fwd(Const(f), Active(z))[1])[1][1]
+# Compute the reverse pass seeded with a differential return of 0.0 + 1.0im
+grad_v = rev(Const(f), Active(z), 0.0 + 1.0im, fwd(Const(f), Active(z))[1])[1][1]
+
+(grad_u, grad_v)
+
+# output
+(6.2 - 5.4im, 5.4 + 6.2im)
+```
+
+Now let's make this batched:
+
+```jldoctest complex
+fwd, rev = Enzyme.autodiff_thunk(ReverseSplitWidth(ReverseSplitNoPrimal, Val(2)), Const{typeof(f)}, Active, Active{ComplexF64})
+
+# Compute the reverse pass seeded with a differential return of 1.0 + 0.0im and 0.0 + 1.0im in one go!
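+# The seed is now a tuple of the two differential returns, and the derivative result is a matching 2-tuple (grad_u, grad_v).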
+rev(Const(f), Active(z), (1.0 + 0.0im, 0.0 + 1.0im), fwd(Const(f), Active(z))[1])[1][1]
+
+# output
+(6.2 - 5.4im, 5.4 + 6.2im)
+```
+
+In contrast, Forward mode differentiation computes the derivative of all outputs with respect to a single input by providing a differential input. Thus we need to seed the shadow input with either 1.0 or 1.0im, respectively. This will compute the transpose of the matrix we found earlier.
+
+```
+d/dx f(x, y) = d/dx [u(x,y), v(x,y)] = d/dx [x*x-y*y, 2*x*y] = [ 2*x, 2*y];
+d/dy f(x, y) = d/dy [u(x,y), v(x,y)] = d/dy [x*x-y*y, 2*x*y] = [-2*y, 2*x];
+```
+
+```jldoctest complex
+d_dx = Enzyme.autodiff(Forward, f, Duplicated(z, 1.0+0.0im))[1]
+d_dy = Enzyme.autodiff(Forward, f, Duplicated(z, 0.0+1.0im))[1]
+
+(d_dx, d_dy)
+
+# output
+(6.2 + 5.4im, -5.4 + 6.2im)
+```
+
+Again, we can go ahead and batch this.
+
+```jldoctest complex
+Enzyme.autodiff(Forward, f, BatchDuplicated(z, (1.0+0.0im, 0.0+1.0im)))[1]
+
+# output
+(var"1" = 6.2 + 5.4im, var"2" = -5.4 + 6.2im)
+```
+
+Taking Jacobians with respect to the real and imaginary results is fine, but for a complex scalar function it would be really nice to have a single complex derivative. More concretely, in this case when differentiating `z*z`, it would be nice to simply return `2*z`. However, there are four independent entries in the 2x2 Jacobian, but only two components in a complex number.
+
+Complex differentiation is often viewed through the lens of directional derivatives: for example, what is the derivative of the function as the real input increases, or as the imaginary input increases? Consider the derivative along the real axis, $\lim_{\Delta x \rightarrow 0} \frac{f(x+\Delta x, y)-f(x, y)}{\Delta x}$. This simplifies to $\lim_{\Delta x \rightarrow 0} \frac{u(x+\Delta x, y)-u(x, y) + i \left[ v(x+\Delta x, y)-v(x, y)\right]}{\Delta x} = \frac{\partial}{\partial x} u(x,y) + i\frac{\partial}{\partial x} v(x,y)$. This is exactly what we computed by seeding forward mode with a shadow of `1.0 + 0.0im`.
+
+For completeness, we can also consider the derivative along the imaginary axis, $\lim_{\Delta y \rightarrow 0} \frac{f(x, y+\Delta y)-f(x, y)}{i\Delta y}$. Here this simplifies to $\lim_{\Delta y \rightarrow 0} \frac{u(x, y+\Delta y)-u(x, y) + i \left[ v(x, y+\Delta y)-v(x, y)\right]}{i\Delta y} = -i\frac{\partial}{\partial y} u(x,y) + \frac{\partial}{\partial y} v(x,y)$. Except for the $i$ in the denominator of the limit, this is the same as the result of Forward mode when seeding `z` with a shadow of `0.0 + 1.0im`. We can thus compute the derivative along the imaginary axis by multiplying our second Forward mode call by `-im`.
+
+```jldoctest complex
+d_real = Enzyme.autodiff(Forward, f, Duplicated(z, 1.0+0.0im))[1]
+d_im = -im * Enzyme.autodiff(Forward, f, Duplicated(z, 0.0+1.0im))[1]
+
+(d_real, d_im)
+
+# output
+(6.2 + 5.4im, 6.2 + 5.4im)
+```
+
+Interestingly, the derivative of `z*z` is the same when computed along either axis. That is because this function is part of a special class of functions whose derivative is invariant to the direction of differentiation; such functions are called holomorphic.
+
+Thus, for holomorphic functions, we can simply seed Forward-mode AD with a shadow of one for whatever input we are differentiating. This is nice since a shadow of one is exactly what we'd use for real-valued functions as well.
+
+Reverse-mode AD, however, is trickier. This is because holomorphic functions are invariant to the direction of differentiation (i.e., the derivative inputs), not the direction of the differential return.
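+
+That said, the two reverse-mode gradients of a holomorphic function are still tightly coupled. As a quick numerical check, reusing the `grad_u` and `grad_v` computed above for the holomorphic `f`, the two seeds yield gradients that differ only by a factor of `im`:
+
+```jldoctest complex
+grad_v ≈ im * grad_u
+
+# output
+true
+```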
+
+However, if a function is holomorphic, the two directional derivatives we computed in forward mode must be the same. Equating their real and imaginary parts yields $\frac{\partial}{\partial x} u = \frac{\partial}{\partial y} v$ and $\frac{\partial}{\partial y} u = -\frac{\partial}{\partial x} v$, the Cauchy-Riemann equations.
+
+We saw earlier that performing reverse-mode AD with a return seed of `1.0 + 0.0im` yielded `[d/dx u, d/dy u]`. Thus, for a holomorphic function, real-seeded Reverse-mode AD computes `[d/dx u, -d/dx v]`, which is the complex conjugate of the derivative.
+
+
+```jldoctest complex
+conj(grad_u)
+
+# output
+
+6.2 + 5.4im
+```
+
+In the case of a scalar-input scalar-output function, that's sufficient. However, most uses of reverse mode involve several inputs or outputs, perhaps passed via memory. This case requires additional handling to properly sum all the partial derivatives from the use of each input, and to apply the conjugate operator to only the ones relevant to the differential return.
+
+For simplicity, Enzyme provides a helper utility `ReverseHolomorphic` which performs Reverse mode properly here, assuming that the function is indeed holomorphic and thus has a well-defined single derivative.
+
+```jldoctest complex
+Enzyme.autodiff(ReverseHolomorphic, f, Active, Active(z))[1][1]
+
+# output
+
+6.2 + 5.4im
+```
+
+Even for non-holomorphic functions, complex analysis allows us to define $\frac{\partial}{\partial z} = \frac{1}{2}\left(\frac{\partial}{\partial x} - i \frac{\partial}{\partial y} \right)$ (the Wirtinger derivative), which gives a well-defined `d/dz`. Let's consider `myabs2(z) = z * conj(z)`. We can compute its derivative with respect to `z` in Forward mode as follows; as one would expect, the result is `conj(z)`:
+
+```jldoctest complex
+myabs2(z) = z * conj(z)
+
+dabs2_dx, dabs2_dy = Enzyme.autodiff(Forward, myabs2, BatchDuplicated(z, (1.0 + 0.0im, 0.0 + 1.0im)))[1]
+(dabs2_dx - im * dabs2_dy) / 2
+
+# output
+
+3.1 - 2.7im
+```
+
+Similarly, we can compute `d/d conj(z) = (d/dx + i d/dy)/2`.
+
+```jldoctest complex
+(dabs2_dx + im * dabs2_dy) / 2
+
+# output
+
+3.1 + 2.7im
+```
+
+Computing this in Reverse mode is trickier. Let's expand the function in terms of `u` and `v`. $\frac{\partial}{\partial z} f = \frac12 \left( [u_x + i v_x] - i [u_y + i v_y] \right) = \frac12 \left( [u_x + v_y] + i [v_x - u_y] \right)$. Thus `d/dz = (conj(grad_u) + im * conj(grad_v))/2`.
+
+```jldoctest complex
+abs2_fwd, abs2_rev = Enzyme.autodiff_thunk(ReverseSplitWidth(ReverseSplitNoPrimal, Val(2)), Const{typeof(myabs2)}, Active, Active{ComplexF64})
+
+# Compute the reverse pass seeded with a differential return of 1.0 + 0.0im and 0.0 + 1.0im in one go!
+gradabs2_u, gradabs2_v = abs2_rev(Const(myabs2), Active(z), (1.0 + 0.0im, 0.0 + 1.0im), abs2_fwd(Const(myabs2), Active(z))[1])[1][1]
+
+(conj(gradabs2_u) + im * conj(gradabs2_v)) / 2
+
+# output
+
+3.1 - 2.7im
+```
+
+For `d/d conj(z)`, $\frac12 \left( [u_x + i v_x] + i [u_y + i v_y] \right) = \frac12 \left( [u_x - v_y] + i [v_x + u_y] \right)$. Thus `d/d conj(z) = (grad_u + im * grad_v)/2`.
+
+```jldoctest complex
+(gradabs2_u + im * gradabs2_v) / 2
+
+# output
+
+3.1 + 2.7im
+```
+
+Note: when writing rules for complex scalar functions, in reverse mode one needs to conjugate the differential return, and similarly the true result will be the conjugate of that value (in essence you can think of reverse-mode AD as working in the conjugate space).
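+
+For example, for a holomorphic function with complex derivative `dg` at the point of interest, this convention means a reverse rule transforms a differential return `s` as below (a sketch of the algebra only, with hypothetical names, not the actual `EnzymeRules` signature):
+
+```julia
+# conjugate the incoming seed, apply the derivative, conjugate the result
+pullback(dg, s) = conj(conj(s) * dg)  # equivalently: s * conj(dg)
+```
+
+With `dg = 2*z` for `f(z) = z*z`, a seed of `1.0 + 0.0im` reproduces the `grad_u` computed earlier.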
\ No newline at end of file diff --git a/lib/EnzymeCore/Project.toml b/lib/EnzymeCore/Project.toml index cad5466a10..ad6e42e371 100644 --- a/lib/EnzymeCore/Project.toml +++ b/lib/EnzymeCore/Project.toml @@ -1,7 +1,7 @@ name = "EnzymeCore" uuid = "f151be2c-9106-41f4-ab19-57ee4f262869" authors = ["William Moses ", "Valentin Churavy "] -version = "0.6.5" +version = "0.7.0" [compat] Adapt = "3, 4" diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index afccdf6744..56c200cd61 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -1,7 +1,7 @@ module EnzymeCore export Forward, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal -export ReverseSplitModified, ReverseSplitWidth +export ReverseSplitModified, ReverseSplitWidth, ReverseHolomorphic, ReverseHolomorphicWithPrimal export Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed export DefaultABI, FFIABI, InlineABI export BatchDuplicatedFunc @@ -49,6 +49,7 @@ struct Active{T} <: Annotation{T} end Active(i::Integer) = Active(float(i)) +Active(ci::Complex{T}) where T <: Integer = Active(float(ci)) """ Duplicated(x, ∂f_∂x) @@ -178,14 +179,18 @@ Abstract type for what differentiation mode will be used. abstract type Mode{ABI} end """ - struct ReverseMode{ReturnPrimal,ABI} <: Mode{ABI} + struct ReverseMode{ReturnPrimal,ABI,Holomorphic} <: Mode{ABI} Reverse mode differentiation. - `ReturnPrimal`: Should Enzyme return the primal return value from the augmented-forward. -""" -struct ReverseMode{ReturnPrimal,ABI} <: Mode{ABI} end -const Reverse = ReverseMode{false,DefaultABI}() -const ReverseWithPrimal = ReverseMode{true,DefaultABI}() +- `ABI`: What runtime ABI to use +- `Holomorphic`: Whether the complex result function is holomorphic and we should compute d/dz +""" +struct ReverseMode{ReturnPrimal,ABI,Holomorphic} <: Mode{ABI} end +const Reverse = ReverseMode{false,DefaultABI, false}() +const ReverseWithPrimal = ReverseMode{true,DefaultABI, false}() +const ReverseHolomorphic = ReverseMode{false,DefaultABI, true}() +const ReverseHolomorphicWithPrimal = ReverseMode{true,DefaultABI, true}() """ struct ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetween,ABI} <: Mode{ABI} diff --git a/lib/EnzymeTestUtils/Project.toml b/lib/EnzymeTestUtils/Project.toml index 373b45cf85..2e9f3f1ae6 100644 --- a/lib/EnzymeTestUtils/Project.toml +++ b/lib/EnzymeTestUtils/Project.toml @@ -13,8 +13,8 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] ConstructionBase = "1.4.1" -Enzyme = "0.11" -EnzymeCore = "0.5, 0.6" +Enzyme = "0.11, 0.12" +EnzymeCore = "0.5, 0.6, 0.7" FiniteDifferences = "0.12.12" MetaTesting = "0.1" Quaternions = "0.7" diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 759fc5d50e..f22411a8bd 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -2,8 +2,8 @@ module Enzyme import EnzymeCore -import EnzymeCore: Forward, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal, ReverseSplitModified, ReverseSplitWidth, ReverseMode, ForwardMode -export Forward, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal, ReverseSplitModified, ReverseSplitWidth, ReverseMode, ForwardMode +import EnzymeCore: Forward, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal, ReverseSplitModified, ReverseSplitWidth, ReverseMode, ForwardMode, ReverseHolomorphic, ReverseHolomorphicWithPrimal +export Forward, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal, 
ReverseSplitModified, ReverseSplitWidth, ReverseMode, ForwardMode, ReverseHolomorphic, ReverseHolomorphicWithPrimal import EnzymeCore: Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, ABI, DefaultABI, FFIABI, InlineABI export Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, DefaultABI, FFIABI, InlineABI @@ -48,30 +48,14 @@ include("internal_rules.jl") import .Compiler: CompilationException -# @inline annotate() = () -# @inline annotate(arg::A, args::Vararg{Any, N}) where {A<:Annotation, N} = (arg, annotate(args...)...) -# @inline annotate(arg, args::Vararg{Any, N}) where N = (Const(arg), annotate(args...)...) - -@inline function falses_from_args(::Val{add}, args::Vararg{Any, N}) where {add,N} - ntuple(Val(add+N)) do i - Base.@_inline_meta - false - end -end - -@inline function annotate(args::Vararg{Any, N}) where N +@inline function falses_from_args(N) ntuple(Val(N)) do i Base.@_inline_meta - arg = @inbounds args[i] - if arg isa Annotation - return arg - else - return Const(arg) - end + false end end -@inline function any_active(args::Vararg{Any, N}) where N +@inline function any_active(args::Vararg{Annotation, N}) where N any(ntuple(Val(N)) do i Base.@_inline_meta arg = @inbounds args[i] @@ -118,7 +102,7 @@ end end """ - autodiff(::ReverseMode, f, Activity, args...) + autodiff(::ReverseMode, f, Activity, args::Vararg{Annotation, Nargs}) Auto-differentiate function `f` at arguments `args` using reverse mode. @@ -135,7 +119,7 @@ on. Enzyme will only differentiate in respect to arguments that are wrapped in an [`Active`](@ref) (for arguments whose derivative result must be returned rather than mutated in place, such as primitive types and structs thereof) or [`Duplicated`](@ref) (for mutable arguments like arrays, `Ref`s and structs -thereof). Non-annotated arguments will automatically be treated as [`Const`](@ref). +thereof). `Activity` is the Activity of the return value, it may be `Const` or `Active`. @@ -147,7 +131,7 @@ b = [2.2, 3.3]; ∂f_∂b = zero(b) c = 55; d = 9 f(a, b, c, d) = a * √(b[1]^2 + b[2]^2) + c^2 * d^2 -∂f_∂a, _, _, ∂f_∂d = autodiff(Reverse, f, Active, Active(a), Duplicated(b, ∂f_∂b), c, Active(d))[1] +∂f_∂a, _, _, ∂f_∂d = autodiff(Reverse, f, Active, Active(a), Duplicated(b, ∂f_∂b), Const(c), Active(d))[1] # output @@ -177,76 +161,153 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) [`Active`](@ref) will automatically convert plain integers to floating point values, but cannot do so for integer values in tuples and structs. """ -@inline function autodiff(::ReverseMode{ReturnPrimal, RABI}, f::FA, ::Type{A}, args...) where {FA<:Annotation, A<:Annotation, ReturnPrimal, RABI<:ABI} - args′ = annotate(args...) - tt′ = Tuple{map(Core.Typeof, args′)...} +@inline function autodiff(::ReverseMode{ReturnPrimal, RABI,Holomorphic}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal, RABI<:ABI, Nargs,Holomorphic} + tt′ = Tuple{map(Core.Typeof, args)...} width = same_or_one(args...) 
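+    # width is the batch width implied by the argument annotations (1 for Duplicated, N for BatchDuplicated)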
if width == 0 throw(ErrorException("Cannot differentiate with a batch size of 0")) end - ModifiedBetween = Val(falses_from_args(Val(1), args...)) + ModifiedBetween = Val(falses_from_args(Nargs+1)) - tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...} + tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...} world = codegen_world_age(Core.Typeof(f.val), tt) + rt = if A isa UnionAll + Core.Compiler.return_type(f.val, tt) + else + eltype(A) + end + if A <: Active - tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...} - rt = Core.Compiler.return_type(f.val, tt) if !allocatedinline(rt) || rt isa Union forward, adjoint = Enzyme.Compiler.thunk(Val(world), FA, Duplicated{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(true), RABI) - res = forward(f, args′...) + res = forward(f, args...) tape = res[1] if ReturnPrimal - return (adjoint(f, args′..., tape)[1], res[2]) + return (adjoint(f, args..., tape)[1], res[2]) else - return adjoint(f, args′..., tape) + return adjoint(f, args..., tape) end end elseif A <: Duplicated || A<: DuplicatedNoNeed || A <: BatchDuplicated || A<: BatchDuplicatedNoNeed || A <: BatchDuplicatedFunc throw(ErrorException("Duplicated Returns not yet handled")) end + + if A <: Active && rt <: Complex + if Holomorphic + seen = IdDict() + seen2 = IdDict() + + f = if f isa Const || f isa Active + f + elseif f isa Duplicated || f isa DuplicatedNoNeed + BatchDuplicated(f.val, (f.dval, make_zero(typeof(f), seen, f.dval), make_zero(typeof(f), seen2, f.dval))) + else + throw(ErrorException("Active Complex return does not yet support batching in combined reverse mode")) + end + + args = ntuple(Val(Nargs)) do i + Base.@_inline_meta + arg = args[i] + if arg isa Const || arg isa Active + arg + elseif arg isa Duplicated || arg isa DuplicatedNoNeed + RT = eltype(Core.Typeof(arg)) + BatchDuplicated(arg.val, (arg.dval, make_zero(RT, seen, arg.dval), make_zero(RT, seen2, arg.dval))) + else + throw(ErrorException("Active Complex return does not yet support batching in combined reverse mode")) + end + end + width = same_or_one_rec(3, args...) + tt′ = Tuple{map(Core.Typeof, args)...} + + thunk = Enzyme.Compiler.thunk(Val(world), typeof(f), A, tt′, #=Split=# Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI) + + results = thunk(f, args..., (rt(0), rt(1), rt(im))) + + @inline function refn(x::T) where T + if T <: Complex + return conj(x) / 2 + else + return x + end + end + + @inline function imfn(x::T) where T + if T <: Complex + return im * conj(x) / 2 + else + return T(0) + end + end + + # compute the correct complex derivative in reverse mode by propagating the conjugate return values + # then subtracting twice the imaginary component to get the correct result + + for (k, v) in seen + Compiler.recursive_accumulate(k, v, refn) + end + for (k, v) in seen2 + Compiler.recursive_accumulate(k, v, imfn) + end + + fused = ntuple(Val(Nargs)) do i + Base.@_inline_meta + if args[i] isa Active + Compiler.recursive_add(Compiler.recursive_add(results[1][i][1], results[1][i][2], refn), results[1][i][3], imfn) + else + results[1][i] + end + end + + return (fused, results[2:end]...) + end + + throw(ErrorException("Reverse-mode Active Complex return is ambiguous and requires more information to specify the desired result. 
See https://enzyme.mit.edu/julia/stable/#Complex for more details.")) + end + thunk = Enzyme.Compiler.thunk(Val(world), FA, A, tt′, #=Split=# Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI) + if A <: Active - tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...} - rt = Core.Compiler.return_type(f.val, tt) - args′ = (args′..., one(rt)) + args = (args..., Compiler.default_adjoint(rt)) end - thunk(f, args′...) + thunk(f, args...) end """ - autodiff(mode::Mode, f, ::Type{A}, args...) + autodiff(mode::Mode, f, ::Type{A}, args::Vararg{Annotation, Nargs}) Like [`autodiff`](@ref) but will try to extend f to an annotation, if needed. """ -@inline function autodiff(mode::CMode, f::F, args...) where {F, CMode<:Mode} +@inline function autodiff(mode::CMode, f::F, args::Vararg{Annotation, Nargs}) where {F, CMode<:Mode, Nargs} autodiff(mode, Const(f), args...) end +@inline function autodiff(mode::CMode, f::F, ::Type{RT}, args::Vararg{Annotation, Nargs}) where {F, RT<:Annotation, CMode<:Mode, Nargs} + autodiff(mode, Const(f), RT, args...) +end """ - autodiff(mode::Mode, f, args...) + autodiff(mode::Mode, f, args::Vararg{Annotation, Nargs}) Like [`autodiff`](@ref) but will try to guess the activity of the return value. """ -@inline function autodiff(mode::CMode, f::FA, args...) where {FA<:Annotation, CMode<:Mode} - args′ = annotate(args...) - tt′ = Tuple{map(Core.Typeof, args′)...} - tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...} +@inline function autodiff(mode::CMode, f::FA, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, CMode<:Mode, Nargs} + tt′ = Tuple{map(Core.Typeof, args)...} + tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...} rt = Core.Compiler.return_type(f.val, tt) A = guess_activity(rt, mode) - autodiff(mode, f, A, args′...) + autodiff(mode, f, A, args...) end """ - autodiff(::ForwardMode, f, Activity, args...) + autodiff(::ForwardMode, f, Activity, args::Vararg{Annotation, Nargs}) Auto-differentiate function `f` at arguments `args` using forward mode. `args` may be numbers, arrays, structs of numbers, structs of arrays and so on. Enzyme will only differentiate in respect to arguments that are wrapped -in a [`Duplicated`](@ref) or similar argument. Non-annotated arguments will -automatically be treated as [`Const`](@ref). Unlike reverse mode in +in a [`Duplicated`](@ref) or similar argument. Unlike reverse mode in [`autodiff`](@ref), [`Active`](@ref) arguments are not allowed here, since all derivative results of immutable objects will be returned and should instead use [`Duplicated`](@ref) or variants like [`DuplicatedNoNeed`](@ref). @@ -284,13 +345,12 @@ f(x) = x*x (6.28,) ``` """ -@inline function autodiff(::ForwardMode{RABI}, f::FA, ::Type{A}, args...) where {FA<:Annotation, A<:Annotation} where {RABI <: ABI} - args′ = annotate(args...) - if any_active(args′...) +@inline function autodiff(::ForwardMode{RABI}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation} where {RABI <: ABI, Nargs} + if any_active(args...) throw(ErrorException("Active arguments not allowed in forward mode")) end - tt′ = Tuple{map(Core.Typeof, args′)...} - width = same_or_one(args′...) + tt′ = Tuple{map(Core.Typeof, args)...} + width = same_or_one(args...) 
if width == 0 throw(ErrorException("Cannot differentiate with a batch size of 0")) end @@ -314,30 +374,29 @@ f(x) = x*x A end - ModifiedBetween = Val(falses_from_args(Val(1), args...)) + ModifiedBetween = Val(falses_from_args(Nargs+1)) - tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...} + tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...} world = codegen_world_age(Core.Typeof(f.val), tt) thunk = Enzyme.Compiler.thunk(Val(world), FA, RT, tt′, #=Mode=# Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), RABI) - thunk(f, args′...) + thunk(f, args...) end """ - autodiff_deferred(::ReverseMode, f, Activity, args...) + autodiff_deferred(::ReverseMode, f, Activity, args::Vararg{Annotation, Nargs}) Same as [`autodiff`](@ref) but uses deferred compilation to support usage in GPU code, as well as high-order differentiation. """ -@inline function autodiff_deferred(::ReverseMode{ReturnPrimal}, f::FA, ::Type{A}, args...) where {FA<:Annotation, A<:Annotation, ReturnPrimal} - args′ = annotate(args...) - tt′ = Tuple{map(Core.Typeof, args′)...} +@inline function autodiff_deferred(::ReverseMode{ReturnPrimal}, f::FA, ::Type{A}, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal, Nargs} + tt′ = Tuple{map(Core.Typeof, args)...} width = same_or_one(args...) if width == 0 throw(ErrorException("Cannot differentiate with a batch size of 0")) end - tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...} + tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...} world = codegen_world_age(Core.Typeof(f.val), tt) @@ -353,31 +412,30 @@ code, as well as high-order differentiation. error("Return type inferred to be Union{}. Giving up.") end - ModifiedBetween = Val(falses_from_args(Val(1), args...)) - + ModifiedBetween = Val(falses_from_args(Nargs+1)) + adjoint_ptr, primal_ptr = Compiler.deferred_codegen(Val(world), FA, Val(tt′), Val(rt), Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, Val(ReturnPrimal)) @assert primal_ptr === nothing thunk = Compiler.CombinedAdjointThunk{Ptr{Cvoid}, FA, rt, tt′, typeof(Val(width)), Val(ReturnPrimal)}(adjoint_ptr) if rt <: Active - args′ = (args′..., one(eltype(rt))) + args = (args..., Compiler.default_adjoint(eltype(rt))) elseif A <: Duplicated || A<: DuplicatedNoNeed || A <: BatchDuplicated || A<: BatchDuplicatedNoNeed throw(ErrorException("Duplicated Returns not yet handled")) end - thunk(f, args′...) + thunk(f, args...) end """ - autodiff_deferred(::ForwardMode, f, Activity, args...) + autodiff_deferred(::ForwardMode, f, Activity, args::Vararg{Annotation, Nargs}) -Same as `autodiff(::ForwardMode, ...)` but uses deferred compilation to support usage in GPU +Same as `autodiff(::ForwardMode, f, Activity, args)` but uses deferred compilation to support usage in GPU code, as well as high-order differentiation. """ -@inline function autodiff_deferred(::ForwardMode, f::FA, ::Type{A}, args...) where {FA<:Annotation, A<:Annotation} - args′ = annotate(args...) - if any_active(args′...) +@inline function autodiff_deferred(::ForwardMode, f::FA, ::Type{A}, args::Vararg{Annotation,Nargs}) where {FA<:Annotation, A<:Annotation, Nargs} + if any_active(args...) throw(ErrorException("Active arguments not allowed in forward mode")) end - tt′ = Tuple{map(Core.Typeof, args′)...} + tt′ = Tuple{map(Core.Typeof, args)...} width = same_or_one(args...) if width == 0 throw(ErrorException("Cannot differentiate with a batch size of 0")) @@ -397,7 +455,7 @@ code, as well as high-order differentiation. 
else A end - tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...} + tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...} world = codegen_world_age(Core.Typeof(f.val), tt) @@ -418,43 +476,45 @@ code, as well as high-order differentiation. end ReturnPrimal = Val(RT <: Duplicated || RT <: BatchDuplicated) - ModifiedBetween = Val(falses_from_args(Val(1), args...)) - + ModifiedBetween = Val(falses_from_args(Nargs+1)) adjoint_ptr, primal_ptr = Compiler.deferred_codegen(Val(world), FA, Val(tt′), Val(rt), Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, ReturnPrimal) @assert primal_ptr === nothing thunk = Compiler.ForwardModeThunk{Ptr{Cvoid}, FA, rt, tt′, typeof(Val(width)), ReturnPrimal}(adjoint_ptr) - thunk(f, args′...) + thunk(f, args...) end """ - autodiff_deferred(mode::Mode, f, ::Type{A}, args...) + autodiff_deferred(mode::Mode, f, ::Type{A}, args::Vararg{Annotation, Nargs}) Like [`autodiff_deferred`](@ref) but will try to extend f to an annotation, if needed. """ -@inline function autodiff_deferred(mode::CMode, f::F, args...) where {F, CMode<:Mode} +@inline function autodiff_deferred(mode::CMode, f::F, args::Vararg{Annotation, Nargs}) where {F, CMode<:Mode, Nargs} autodiff_deferred(mode, Const(f), args...) end +@inline function autodiff_deferred(mode::CMode, f::F, ::Type{RT}, args::Vararg{Annotation, Nargs}) where {F, RT<:Annotation, CMode<:Mode, Nargs} + autodiff_deferred(mode, Const(f), RT, args...) +end + """ - autodiff_deferred(mode, f, args...) + autodiff_deferred(mode, f, args::Vararg{Annotation, Nargs}) Like [`autodiff_deferred`](@ref) but will try to guess the activity of the return value. """ -@inline function autodiff_deferred(mode::M, f::FA, args...) where {FA<:Annotation, M<:Mode} - args′ = annotate(args...) - tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...} +@inline function autodiff_deferred(mode::M, f::FA, args::Vararg{Annotation, Nargs}) where {FA<:Annotation, M<:Mode, Nargs} + tt = Tuple{map(T->eltype(Core.Typeof(T)), args)...} world = codegen_world_age(Core.Typeof(f.val), tt) rt = Core.Compiler.return_type(f.val, tt) if rt === Union{} error("return type is Union{}, giving up.") end rt = guess_activity(rt, mode) - autodiff_deferred(mode, f, rt, args′...) + autodiff_deferred(mode, f, rt, args...) end """ - autodiff_thunk(::ReverseModeSplit, ftype, Activity, argtypes...) + autodiff_thunk(::ReverseModeSplit, ftype, Activity, argtypes::Vararg{Type{<:Annotation}, Nargs}) Provide the split forward and reverse pass functions for annotated function type ftype when called with args of type `argtypes` when using reverse mode. @@ -496,8 +556,7 @@ result, ∂v, ∂A (7.26, 2.2, [3.3]) ``` """ -@inline function autodiff_thunk(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT,RABI}, ::Type{FA}, ::Type{A}, args...) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT,RABI<:ABI} - # args′ = annotate(args...) +@inline function autodiff_thunk(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT,RABI}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT,RABI<:ABI, Nargs} width = if Width == 0 w = same_or_one(args...) 
if w == 0 @@ -509,7 +568,7 @@ result, ∂v, ∂A end if ModifiedBetweenT === true - ModifiedBetween = Val(falses_from_args(Val(1), args...)) + ModifiedBetween = Val(falses_from_args(Nargs+1)) else ModifiedBetween = Val(ModifiedBetweenT) end @@ -525,7 +584,7 @@ result, ∂v, ∂A end """ - autodiff_thunk(::ForwardMode, ftype, Activity, argtypes...) + autodiff_thunk(::ForwardMode, ftype, Activity, argtypes::Vararg{Type{<:Annotation}, Nargs}) Provide the thunk forward mode function for annotated function type ftype when called with args of type `argtypes`. @@ -568,8 +627,7 @@ forward = autodiff_thunk(Forward, Const{typeof(f)}, DuplicatedNoNeed, Duplicated (6.28,) ``` """ -@inline function autodiff_thunk(::ForwardMode{RABI}, ::Type{FA}, ::Type{A}, args...) where {FA<:Annotation, A<:Annotation, RABI<:ABI} - # args′ = annotate(args...) +@inline function autodiff_thunk(::ForwardMode{RABI}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, RABI<:ABI, Nargs} width = same_or_one(A, args...) if width == 0 throw(ErrorException("Cannot differentiate with a batch size of 0")) @@ -578,7 +636,7 @@ forward = autodiff_thunk(Forward, Const{typeof(f)}, DuplicatedNoNeed, Duplicated throw(ErrorException("Active Returns not allowed in forward mode")) end ReturnPrimal = Val(A <: Duplicated || A <: BatchDuplicated) - ModifiedBetween = Val(falses_from_args(Val(1), args...)) + ModifiedBetween = Val(falses_from_args(Nargs+1)) tt = Tuple{map(eltype, args)...} @@ -587,8 +645,7 @@ forward = autodiff_thunk(Forward, Const{typeof(f)}, DuplicatedNoNeed, Duplicated Enzyme.Compiler.thunk(Val(world), FA, A, Tuple{args...}, #=Mode=# Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), RABI) end -@inline function tape_type(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI}, ::Type{FA}, ::Type{A}, args...) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI} - # args′ = annotate(args...) +@inline function tape_type(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI, Nargs} width = if Width == 0 w = same_or_one(args...) if w == 0 @@ -600,7 +657,7 @@ end end if ModifiedBetweenT === true - ModifiedBetween = Val(falses_from_args(Val(1), args...)) + ModifiedBetween = Val(falses_from_args(Nargs+1)) else ModifiedBetween = Val(ModifiedBetweenT) end @@ -616,7 +673,7 @@ end end """ - autodiff_deferred_thunk(::ReverseModeSplit, ftype, Activity, argtypes...) + autodiff_deferred_thunk(::ReverseModeSplit, ftype, Activity, argtypes::Vararg{Type{<:Annotation}, Nargs}) Provide the split forward and reverse pass functions for annotated function type ftype when called with args of type `argtypes` when using reverse mode. @@ -658,9 +715,8 @@ result, ∂v, ∂A (7.26, 2.2, [3.3]) ``` """ -@inline function autodiff_deferred_thunk(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI}, ::Type{FA}, ::Type{A}, args...) 
where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI} +@inline function autodiff_deferred_thunk(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI, Nargs} @assert RABI == FFIABI - # args′ = annotate(args...) width = if Width == 0 w = same_or_one(args...) if w == 0 @@ -672,7 +728,7 @@ result, ∂v, ∂A end if ModifiedBetweenT === true - ModifiedBetween = Val(falses_from_args(Val(1), args...)) + ModifiedBetween = Val(falses_from_args(Nargs+1)) else ModifiedBetween = Val(ModifiedBetweenT) end @@ -1017,7 +1073,7 @@ grad = jacobian(Reverse, f, [2.0, 3.0], Val(2)) j = 0 for shadow in res[3] j += 1 - @inbounds shadow[(i-1)*chunk+j] += one(eltype(typeof(shadow))) + @inbounds shadow[(i-1)*chunk+j] += Compiler.default_adjoint(eltype(typeof(shadow))) end (i == num ? adjoint2 : adjoint)(Const(f), BatchDuplicated(x, dx), tape) return dx @@ -1040,7 +1096,7 @@ end dx = zero(x) res = primal(Const(f), Duplicated(x, dx)) tape = res[1] - @inbounds res[3][i] += one(eltype(typeof(res[3]))) + @inbounds res[3][i] += Compiler.default_adjoint(eltype(typeof(res[3]))) adjoint(Const(f), Duplicated(x, dx), tape) return dx end diff --git a/src/compiler.jl b/src/compiler.jl index 158f1ddc26..21a0e85556 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -379,7 +379,7 @@ end return AnyState end - if T <: Complex + if T <: Complex && !(T isa UnionAll) return active_reg_inner(ptreltype(T), seen, world, Val(justActive), Val(UnionSret)) end @@ -5024,28 +5024,88 @@ end end end -@inline function recursive_add(x::T, y::T) where T - if guaranteed_const(T) +# Recursively return x + f(y), where y is active, otherwise x + +@inline function recursive_add(x::T, y::T, f::F=identity, forcelhs::F2=guaranteed_const) where {T, F, F2} + if forcelhs(T) return x end splatnew(T, ntuple(Val(fieldcount(T))) do i Base.@_inline_meta prev = getfield(x, i) next = getfield(y, i) - recursive_add(prev, next) + recursive_add(prev, next, f, forcelhs) end) end -@inline function recursive_add(x::T, y::T) where {T<:AbstractFloat} - return x + y +@inline function recursive_add(x::T, y::T, f::F=identity, forcelhs::F2=guaranteed_const) where {T<:AbstractFloat, F, F2} + if forcelhs(T) + return x + end + return x + f(y) +end + +@inline function recursive_add(x::T, y::T, f::F=identity, forcelhs::F2=guaranteed_const) where {T<:Complex, F, F2} + if forcelhs(T) + return x + end + return x + f(y) +end + +@inline mutable_register(::Type{T}) where T <: Integer = true +@inline mutable_register(::Type{T}) where T <: AbstractFloat = false +@inline mutable_register(::Type{Complex{T}}) where T <: AbstractFloat = false +@inline mutable_register(::Type{T}) where T <: Tuple = false +@inline mutable_register(::Type{T}) where T <: NamedTuple = false +@inline mutable_register(::Type{Core.Box}) = true +@inline mutable_register(::Type{T}) where T <: Array = true +@inline mutable_register(::Type{T}) where T = ismutable(T) + +# Recursively In-place accumulate(aka +=). E.g. generalization of x .+= f(y) +@inline function recursive_accumulate(x::Array{T}, y::Array{T}, f::F=identity) where {T, F} + if !mutable_register(T) + for I in eachindex(x) + prev = x[I] + @inbounds x[I] = recursive_add(x[I], (@inbounds y[I]), f, mutable_register) + end + end +end + + +# Recursively In-place accumulate(aka +=). E.g. 
generalization of x .+= f(y) +@inline function recursive_accumulate(x::Core.Box, y::Core.Box, f::F=identity) where {F} + recursive_accumulate(x.contents, y.contents, f) +end + +@inline function recursive_accumulate(x::T, y::T, f::F=identity) where {T, F} + @assert !Base.isabstracttype(T) + @assert Base.isconcretetype(T) + nf = fieldcount(T) + + for i in 1:nf + if isdefined(x, i) + xi = getfield(x, i) + ST = Core.Typeof(xi) + if !mutable_register(ST) + @assert ismutable(x) + yi = getfield(y, i) + nexti = recursive_add(xi, yi, f, mutable_register) + ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), x, i-1, nexti) + end + end + end end +@inline default_adjoint(::Type{T}) where T = error("Active return values with automatic pullback (differential return value) deduction only supported for floating-like values and not type $T. If mutable memory, please use Duplicated. Otherwise, you can explicitly specify a pullback by using split mode, e.g. autodiff_thunk(ReverseSplitWithPrimal, ...)") +@inline default_adjoint(::Type{T}) where T<:AbstractFloat = one(T) +@inline default_adjoint(::Type{Complex{T}}) where T = error("Attempted to use automatic pullback (differential return value) deduction on either a type unstable function returning an active complex number, or autodiff_deferred returning an active complex number. For the first case, please type stabilize your code, e.g. by specifying autodiff(Reverse, f->f(x)::Complex, ...). For the second case, please use regular non-deferred autodiff") + function add_one_in_place(x) ty = typeof(x) # ptr = Base.pointer_from_objref(x) ptr = unsafe_to_pointer(x) if ty <: Base.RefValue || ty == Base.RefValue{Float64} - x[] = recursive_add(x[], one(eltype(ty))) + x[] = recursive_add(x[], default_adjoint(eltype(ty))) else error("Enzyme Mutability Error: Cannot add one in place to immutable value "*string(x)) end diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index c3f9f0bcdc..a2900b3356 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -77,6 +77,11 @@ else # WorldOverlayMethodTable(interp.world) end +function is_alwaysinline_func(@nospecialize(TT)) + isa(TT, DataType) || return false + return false +end + function is_primitive_func(@nospecialize(TT)) isa(TT, DataType) || return false ft = TT.parameters[1] @@ -88,6 +93,13 @@ function is_primitive_func(@nospecialize(TT)) return true end end + + if ft == typeof(Base.inv) + if TT <: Tuple{ft, Complex{Float32}} || TT <: Tuple{ft, Complex{Float64}} + return true + end + end + @static if VERSION >= v"1.9-" if ft === typeof(Base.rem) if TT <: Tuple{ft, Float32, Float32} || TT <: Tuple{ft, Float64, Float64} @@ -185,6 +197,12 @@ function Core.Compiler.inlining_policy(interp::EnzymeInterpreter, return nothing end + if is_alwaysinline_func(specTypes) + @safe_debug "Forcing inlining for primitive func" mi.specTypes + @assert src !== nothing + return src + end + if EnzymeRules.is_inactive_from_sig(specTypes; world = interp.world, method_table) @safe_debug "Blocking inlining due to inactive rule" mi.specTypes return nothing end @@ -218,6 +236,12 @@ function Core.Compiler.inlining_policy(interp::EnzymeInterpreter, if is_primitive_func(specTypes) return nothing end + + if is_alwaysinline_func(specTypes) + @assert src !== nothing + return src + end + if EnzymeRules.is_inactive_from_sig(specTypes; world = interp.world, method_table) return nothing end @@ -251,6 +275,11 @@ function Core.Compiler.resolve_todo(todo::InliningTodo, state::InliningState{S, if
is_primitive_func(specTypes) return Core.Compiler.compileable_specialization(state.et, todo.spec.match) end + + if is_alwaysinline_func(specTypes) + @assert false "Need to mark resolve_todo function as alwaysinline, but don't know how" + end + interp = state.policy.interp method_table = Core.Compiler.method_table(interp) if EnzymeRules.is_inactive_from_sig(specTypes; world = interp.world, method_table) diff --git a/test/abi.jl b/test/abi.jl index 2f30dd4c75..ef0db2fa22 100644 --- a/test/abi.jl +++ b/test/abi.jl @@ -34,18 +34,21 @@ using Test @test () === autodiff_deferred(Forward, f, Const(Int)) # Complex numbers - cres, = autodiff(Reverse, f, Active, Active(1.5 + 0.7im))[1] + @test_throws ErrorException autodiff(Reverse, f, Active, Active(1.5 + 0.7im)) + cres, = autodiff(ReverseHolomorphic, f, Active, Active(1.5 + 0.7im))[1] @test cres ≈ 1.0 + 0.0im cres, = autodiff(Forward, f, DuplicatedNoNeed, Duplicated(1.5 + 0.7im, 1.0 + 0im)) @test cres ≈ 1.0 + 0.0im - cres, = autodiff(Reverse, f, Active(1.5 + 0.7im))[1] + @test_throws ErrorException autodiff(Reverse, f, Active(1.5 + 0.7im)) + cres, = autodiff(ReverseHolomorphic, f, Active(1.5 + 0.7im))[1] @test cres ≈ 1.0 + 0.0im cres, = autodiff(Forward, f, Duplicated(1.5 + 0.7im, 1.0+0im)) @test cres ≈ 1.0 + 0.0im - cres, = autodiff_deferred(Reverse, f, Active(1.5 + 0.7im))[1] - @test cres ≈ 1.0 + 0.0im + @test_throws ErrorException autodiff_deferred(Reverse, f, Active(1.5 + 0.7im)) + @test_throws ErrorException autodiff_deferred(ReverseHolomorphic, f, Active(1.5 + 0.7im)) + cres, = autodiff_deferred(Forward, f, Duplicated(1.5 + 0.7im, 1.0+0im)) @test cres ≈ 1.0 + 0.0im @@ -207,16 +210,16 @@ using Test @test 7*3.4 + 9 * 1.2 ≈ first(autodiff(Forward, h, Duplicated(Foo(3, 1.2), Foo(0, 7.0)), Duplicated(Foo(5, 3.4), Foo(0, 9.0)))) caller(f, x) = f(x) - _, res4 = autodiff(Reverse, caller, Active, (x)->x, Active(3.0))[1] + _, res4 = autodiff(Reverse, caller, Active, Const((x)->x), Active(3.0))[1] @test res4 ≈ 1.0 - res4, = autodiff(Forward, caller, DuplicatedNoNeed, (x)->x, Duplicated(3.0, 1.0)) + res4, = autodiff(Forward, caller, DuplicatedNoNeed, Const((x)->x), Duplicated(3.0, 1.0)) @test res4 ≈ 1.0 - _, res4 = autodiff(Reverse, caller, (x)->x, Active(3.0))[1] + _, res4 = autodiff(Reverse, caller, Const((x)->x), Active(3.0))[1] @test res4 ≈ 1.0 - res4, = autodiff(Forward, caller, (x)->x, Duplicated(3.0, 1.0)) + res4, = autodiff(Forward, caller, Const((x)->x), Duplicated(3.0, 1.0)) @test res4 ≈ 1.0 struct LList @@ -257,16 +260,16 @@ using Test dy = Ref(7.0) @test 5.0*3.0 + 2.0*7.0≈ first(autodiff(Forward, mulr, DuplicatedNoNeed, Duplicated(x, dx), Duplicated(y, dy))) - _, mid = Enzyme.autodiff(Reverse, (fs, x) -> fs[1](x), Active, (x->x*x,), Active(2.0))[1] + _, mid = Enzyme.autodiff(Reverse, (fs, x) -> fs[1](x), Active, Const((x->x*x,)), Active(2.0))[1] @test mid ≈ 4.0 - _, mid = Enzyme.autodiff(Reverse, (fs, x) -> fs[1](x), Active, [x->x*x], Active(2.0))[1] + _, mid = Enzyme.autodiff(Reverse, (fs, x) -> fs[1](x), Active, Const([x->x*x]), Active(2.0))[1] @test mid ≈ 4.0 - mid, = Enzyme.autodiff(Forward, (fs, x) -> fs[1](x), DuplicatedNoNeed, (x->x*x,), Duplicated(2.0, 1.0)) + mid, = Enzyme.autodiff(Forward, (fs, x) -> fs[1](x), DuplicatedNoNeed, Const((x->x*x,)), Duplicated(2.0, 1.0)) @test mid ≈ 4.0 - mid, = Enzyme.autodiff(Forward, (fs, x) -> fs[1](x), DuplicatedNoNeed, [x->x*x], Duplicated(2.0, 1.0)) + mid, = Enzyme.autodiff(Forward, (fs, x) -> fs[1](x), DuplicatedNoNeed, Const([x->x*x]), Duplicated(2.0, 1.0)) @test mid ≈ 4.0 @@ -373,10 +376,10 @@ 
end return f.x * x end - @test Enzyme.autodiff(Reverse, method, Active, AFoo(2.0), Active(3.0))[1][2] ≈ 2.0 + @test Enzyme.autodiff(Reverse, method, Active, Const(AFoo(2.0)), Active(3.0))[1][2] ≈ 2.0 @test Enzyme.autodiff(Reverse, AFoo(2.0), Active, Active(3.0))[1][1] ≈ 2.0 - @test Enzyme.autodiff(Forward, method, DuplicatedNoNeed, AFoo(2.0), Duplicated(3.0, 1.0))[1] ≈ 2.0 + @test Enzyme.autodiff(Forward, method, DuplicatedNoNeed, Const(AFoo(2.0)), Duplicated(3.0, 1.0))[1] ≈ 2.0 @test Enzyme.autodiff(Forward, AFoo(2.0), DuplicatedNoNeed, Duplicated(3.0, 1.0))[1] ≈ 2.0 struct ABar @@ -386,10 +389,10 @@ end return 2.0 * x end - @test Enzyme.autodiff(Reverse, method, Active, ABar(), Active(3.0))[1][2] ≈ 2.0 + @test Enzyme.autodiff(Reverse, method, Active, Const(ABar()), Active(3.0))[1][2] ≈ 2.0 @test Enzyme.autodiff(Reverse, ABar(), Active, Active(3.0))[1][1] ≈ 2.0 - @test Enzyme.autodiff(Forward, method, DuplicatedNoNeed, ABar(), Duplicated(3.0, 1.0))[1] ≈ 2.0 + @test Enzyme.autodiff(Forward, method, DuplicatedNoNeed, Const(ABar()), Duplicated(3.0, 1.0))[1] ≈ 2.0 @test Enzyme.autodiff(Forward, ABar(), DuplicatedNoNeed, Duplicated(3.0, 1.0))[1] ≈ 2.0 end diff --git a/test/internal_rules.jl b/test/internal_rules.jl index fe65ce6279..b2ee39de22 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -199,7 +199,7 @@ end Forward, driver, dx, - A, + Const(A), db ) adJ[i, :] = dx.dval @@ -269,7 +269,7 @@ end Reverse, driver, dx, - A, + Const(A), db ) adJ[i, :] = db.dval diff --git a/test/rrules.jl b/test/rrules.jl index 54c7c22802..1322895924 100644 --- a/test/rrules.jl +++ b/test/rrules.jl @@ -164,7 +164,10 @@ function EnzymeRules.reverse( end @testset "Complex values" begin - @test Enzyme.autodiff(Enzyme.Reverse, foo, Active(1.0+3im))[1][1] ≈ 1.0+13.0im + fwd, rev = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(foo)}, Active, Active{ComplexF64}) + z = 1.0+3im + grad_u = rev(Const(foo), Active(z), 1.0 + 0.0im, fwd(Const(foo), Active(z))[1])[1][1] + @test grad_u ≈ 1.0+13.0im end _scalar_dot(x, y) = conj(x) * y @@ -288,7 +291,7 @@ function plaquette_sum(U) p += remultr(@inbounds U[site]) end - return p + return real(p) end diff --git a/test/runtests.jl b/test/runtests.jl index 687acbc3b3..d4b541733a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -26,22 +26,31 @@ using Enzyme_jll # Test against FiniteDifferences function test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), kwargs...) - ∂x, = autodiff(Reverse, f, Active, Active(x))[1] - if typeof(x) <: Complex + ∂x, = autodiff(ReverseHolomorphic, f, Active, Active(x))[1] + + finite_diff = if typeof(x) <: Complex + RT = typeof(x).parameters[1] + (fdm(dx -> f(x+dx), RT(0)) - im * fdm(dy -> f(x+im*dy), RT(0)))/2 else - @test isapprox(∂x, fdm(f, x); rtol=rtol, atol=atol, kwargs...) + fdm(f, x) end - rm = ∂x + @test isapprox(∂x, finite_diff; rtol=rtol, atol=atol, kwargs...) + if typeof(x) <: Integer x = Float64(x) end - ∂x, = autodiff(Forward, f, Duplicated(x, one(typeof(x)))) + if typeof(x) <: Complex - @test ∂x ≈ rm + ∂re, = autodiff(Forward, f, Duplicated(x, one(typeof(x)))) + ∂im, = autodiff(Forward, f, Duplicated(x, im*one(typeof(x)))) + ∂x = (∂re - im*∂im)/2 else - @test isapprox(∂x, fdm(f, x); rtol=rtol, atol=atol, kwargs...) + ∂x, = autodiff(Forward, f, Duplicated(x, one(typeof(x)))) end + + @test isapprox(∂x, finite_diff; rtol=rtol, atol=atol, kwargs...) + end function test_matrix_to_number(f, x; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), kwargs...) 
@@ -290,11 +299,130 @@ make3() = (1.0, 2.0, 3.0) end +@testset "Simple Complex tests" begin + mul2(z) = 2 * z + square(z) = z * z + + z = 1.0+1.0im + + @test_throws ErrorException autodiff(Reverse, mul2, Active, Active(z)) + @test_throws ErrorException autodiff(ReverseWithPrimal, mul2, Active, Active(z)) + @test autodiff(ReverseHolomorphic, mul2, Active, Active(z))[1][1] ≈ 2.0 + 0.0im + @test autodiff(ReverseHolomorphicWithPrimal, mul2, Active, Active(z))[1][1] ≈ 2.0 + 0.0im + @test autodiff(ReverseHolomorphicWithPrimal, mul2, Active, Active(z))[2] ≈ 2 * z + + z = 3.4 + 2.7im + @test autodiff(ReverseHolomorphic, square, Active, Active(z))[1][1] ≈ 2 * z + @test autodiff(ReverseHolomorphic, identity, Active, Active(z))[1][1] ≈ 1 + + @test autodiff(ReverseHolomorphic, Base.inv, Active, Active(3.0 + 4.0im))[1][1] ≈ 0.0112 + 0.0384im + + mul3(z) = Base.inferencebarrier(2 * z) + + @test_throws ErrorException autodiff(ReverseHolomorphic, mul3, Active, Active(z)) + @test_throws ErrorException autodiff(ReverseHolomorphic, mul3, Active{Complex}, Active(z)) + + vals = Complex{Float64}[3.4 + 2.7im] + dvals = Complex{Float64}[0.0] + autodiff(ReverseHolomorphic, sum, Active, Duplicated(vals, dvals)) + @test vals[1] ≈ 3.4 + 2.7im + @test dvals[1] ≈ 1.0 + + sumsq(x) = sum(x .* x) + + vals = Complex{Float64}[3.4 + 2.7im] + dvals = Complex{Float64}[0.0] + autodiff(ReverseHolomorphic, sumsq, Active, Duplicated(vals, dvals)) + @test vals[1] ≈ 3.4 + 2.7im + @test dvals[1] ≈ 2 * (3.4 + 2.7im) + + sumsq2(x) = sum(abs2.(x)) + vals = Complex{Float64}[3.4 + 2.7im] + dvals = Complex{Float64}[0.0] + autodiff(ReverseHolomorphic, sumsq2, Active, Duplicated(vals, dvals)) + @test vals[1] ≈ 3.4 + 2.7im + @test dvals[1] ≈ 2 * (3.4 + 2.7im) + + sumsq2C(x) = Complex{Float64}(sum(abs2.(x))) + vals = Complex{Float64}[3.4 + 2.7im] + dvals = Complex{Float64}[0.0] + autodiff(ReverseHolomorphic, sumsq2C, Active, Duplicated(vals, dvals)) + @test vals[1] ≈ 3.4 + 2.7im + @test dvals[1] ≈ 3.4 - 2.7im + + sumsq3(x) = sum(x .* conj(x)) + vals = Complex{Float64}[3.4 + 2.7im] + dvals = Complex{Float64}[0.0] + autodiff(ReverseHolomorphic, sumsq3, Active, Duplicated(vals, dvals)) + @test vals[1] ≈ 3.4 + 2.7im + @test dvals[1] ≈ 3.4 - 2.7im + + sumsq3R(x) = Float64(sum(x .* conj(x))) + vals = Complex{Float64}[3.4 + 2.7im] + dvals = Complex{Float64}[0.0] + autodiff(ReverseHolomorphic, sumsq3R, Active, Duplicated(vals, dvals)) + @test vals[1] ≈ 3.4 + 2.7im + @test dvals[1] ≈ 2 * (3.4 + 2.7im) + + function setinact(z) + z[1] *= 2 + nothing + end + + vals = Complex{Float64}[3.4 + 2.7im] + dvals = Complex{Float64}[0.0] + autodiff(ReverseHolomorphic, setinact, Const, Duplicated(vals, dvals)) + @test vals[1] ≈ 2 * (3.4 + 2.7im) + @test dvals[1] ≈ 0.0 + + + function setinact2(z) + z[1] *= 2 + return 0.0+1.0im + end + + vals = Complex{Float64}[3.4 + 2.7im] + dvals = Complex{Float64}[0.0] + autodiff(ReverseHolomorphic, setinact2, Const, Duplicated(vals, dvals)) + @test vals[1] ≈ 2 * (3.4 + 2.7im) + @test dvals[1] ≈ 0.0 + + vals = Complex{Float64}[3.4 + 2.7im] + dvals = Complex{Float64}[0.0] + autodiff(ReverseHolomorphic, setinact2, Active, Duplicated(vals, dvals)) + @test vals[1] ≈ 2 * (3.4 + 2.7im) + @test dvals[1] ≈ 0.0 + + + function setact(z) + z[1] *= 2 + return z[1] + end + + vals = Complex{Float64}[3.4 + 2.7im] + dvals = Complex{Float64}[0.0] + autodiff(ReverseHolomorphic, setact, Const, Duplicated(vals, dvals)) + @test vals[1] ≈ 2 * (3.4 + 2.7im) + @test dvals[1] ≈ 0.0 + + vals = Complex{Float64}[3.4 + 2.7im] + dvals = Complex{Float64}[0.0] + 
autodiff(ReverseHolomorphic, setact, Active, Duplicated(vals, dvals)) + @test vals[1] ≈ 2 * (3.4 + 2.7im) + @test dvals[1] ≈ 2.0 + + function upgrade(z) + z = ComplexF64(z) + return z*z + end + @test autodiff(ReverseHolomorphic, upgrade, Active, Active(3.1))[1][1] ≈ 6.2 +end + @testset "Simple Exception" begin f_simple_exc(x, i) = ccall(:jl_, Cvoid, (Any,), x[i]) y = [1.0, 2.0] f_x = zero.(y) - @test_throws BoundsError autodiff(Reverse, f_simple_exc, Duplicated(y, f_x), 0) + @test_throws BoundsError autodiff(Reverse, f_simple_exc, Duplicated(y, f_x), Const(0)) end @@ -653,7 +781,7 @@ end B = Float64[4.0, 5.0] dB = Float64[0.0, 0.0] f = (X, Y) -> sum(X .* Y) - Enzyme.autodiff(Reverse, f, Active, A, Duplicated(B, dB)) + Enzyme.autodiff(Reverse, f, Active, Const(A), Duplicated(B, dB)) function gc_copy(x) # Basically g(x) = x^2 a = x * ones(10) @@ -884,8 +1012,8 @@ end # @test fd ≈ first(autodiff(Forward, foo, Duplicated(x, 1))) f74(a, c) = a * √c - @test √3 ≈ first(autodiff(Reverse, f74, Active, Active(2), 3))[1] - @test √3 ≈ first(autodiff(Forward, f74, Duplicated(2.0, 1.0), 3)) + @test √3 ≈ first(autodiff(Reverse, f74, Active, Active(2), Const(3)))[1] + @test √3 ≈ first(autodiff(Forward, f74, Duplicated(2.0, 1.0), Const(3))) end @testset "SinCos" begin @@ -927,9 +1055,9 @@ mybesselj1(z) = mybesselj(1, z) @testset "Bessel" begin autodiff(Reverse, mybesselj, Active, Const(0), Active(1.0)) - autodiff(Reverse, mybesselj, Active, 0, Active(1.0)) + autodiff(Reverse, mybesselj, Active, Const(0), Active(1.0)) + autodiff(Forward, mybesselj, Const(0), Duplicated(1.0, 1.0)) autodiff(Forward, mybesselj, Const(0), Duplicated(1.0, 1.0)) - autodiff(Forward, mybesselj, 0, Duplicated(1.0, 1.0)) @testset "besselj0/besselj1" for x in (1.0, -1.0, 0.0, 0.5, 10, -17.1,) # 1.5 + 0.7im) test_scalar(mybesselj0, x, rtol=1e-5, atol=1e-5) test_scalar(mybesselj1, x, rtol=1e-5, atol=1e-5) @@ -1446,7 +1574,7 @@ end u_v_eta = [0.0] - v = autodiff(Reverse, incopy, Active, Const(u_v_eta), Active(3.14), 1)[1][2] + v = autodiff(Reverse, incopy, Active, Const(u_v_eta), Active(3.14), Const(1))[1][2] @test v ≈ 1.0 @test u_v_eta[1] ≈ 0.0 @@ -1456,7 +1584,7 @@ end return @inbounds eta[i] end - v = autodiff(Reverse, incopy2, Active, Active(3.14), 1)[1][1] + v = autodiff(Reverse, incopy2, Active, Active(3.14), Const(1))[1][1] @test v ≈ 1.0 end @@ -1490,11 +1618,11 @@ end end y end - @test 1.0 ≈ autodiff(Reverse, f_undef, false, Active(2.14))[1][2] - @test_throws Base.UndefVarError autodiff(Reverse, f_undef, true, Active(2.14)) + @test 1.0 ≈ autodiff(Reverse, f_undef, Const(false), Active(2.14))[1][2] + @test_throws Base.UndefVarError autodiff(Reverse, f_undef, Const(true), Active(2.14)) - @test 1.0 ≈ autodiff(Forward, f_undef, false, Duplicated(2.14, 1.0))[1] - @test_throws Base.UndefVarError autodiff(Forward, f_undef, true, Duplicated(2.14, 1.0)) + @test 1.0 ≈ autodiff(Forward, f_undef, Const(false), Duplicated(2.14, 1.0))[1] + @test_throws Base.UndefVarError autodiff(Forward, f_undef, Const(true), Duplicated(2.14, 1.0)) end @testset "Return GC error" begin @@ -1508,8 +1636,8 @@ end end end - @test 0.0 ≈ autodiff(Reverse, tobedifferentiated, true, Active(2.1))[1][2] - @test 0.0 ≈ autodiff(Forward, tobedifferentiated, true, Duplicated(2.1, 1.0))[1] + @test 0.0 ≈ autodiff(Reverse, tobedifferentiated, Const(true), Active(2.1))[1][2] + @test 0.0 ≈ autodiff(Forward, tobedifferentiated, Const(true), Duplicated(2.1, 1.0))[1] function tobedifferentiated2(cond, a)::Float64 if cond @@ -1519,8 +1647,8 @@ end end end - @test 1.0 ≈ 
autodiff(Reverse, tobedifferentiated2, true, Active(2.1))[1][2] - @test 1.0 ≈ autodiff(Forward, tobedifferentiated2, true, Duplicated(2.1, 1.0))[1] + @test 1.0 ≈ autodiff(Reverse, tobedifferentiated2, Const(true), Active(2.1))[1][2] + @test 1.0 ≈ autodiff(Forward, tobedifferentiated2, Const(true), Duplicated(2.1, 1.0))[1] @noinline function copy(dest, p1, cond) bc = convert(Broadcast.Broadcasted{Nothing}, Broadcast.instantiate(p1)) @@ -1550,8 +1678,8 @@ end F_H = [1.0, 0.0] F = [1.0, 0.0] - autodiff(Reverse, mer, Duplicated(F, L), Duplicated(F_H, L_H), true) - autodiff(Forward, mer, Duplicated(F, L), Duplicated(F_H, L_H), true) + autodiff(Reverse, mer, Duplicated(F, L), Duplicated(F_H, L_H), Const(true)) + autodiff(Forward, mer, Duplicated(F, L), Duplicated(F_H, L_H), Const(true)) end @testset "GC Sret" begin @@ -1713,8 +1841,8 @@ end -t nothing end - autodiff(Reverse, tobedifferentiated, Duplicated(F, L), false) - autodiff(Forward, tobedifferentiated, Duplicated(F, L), false) + autodiff(Reverse, tobedifferentiated, Duplicated(F, L), Const(false)) + autodiff(Forward, tobedifferentiated, Duplicated(F, L), Const(false)) end main() @@ -1946,9 +2074,9 @@ end f_union(cond, x) = cond ? x : 0 g_union(cond, x) = f_union(cond,x)*x if sizeof(Int) == sizeof(Int64) - @test_throws Enzyme.Compiler.IllegalTypeAnalysisException autodiff(Reverse, g_union, Active, true, Active(1.0)) + @test_throws Enzyme.Compiler.IllegalTypeAnalysisException autodiff(Reverse, g_union, Active, Const(true), Active(1.0)) else - @test_throws Enzyme.Compiler.IllegalTypeAnalysisException autodiff(Reverse, g_union, Active, true, Active(1.0f0)) + @test_throws Enzyme.Compiler.IllegalTypeAnalysisException autodiff(Reverse, g_union, Active, Const(true), Active(1.0f0)) end # TODO: Add test for NoShadowException end @@ -1985,7 +2113,7 @@ end; loss = Ref(0.0) dloss = Ref(1.0) - autodiff(Reverse, objective!, Duplicated(x, zero(x)), Duplicated(loss, dloss), R) + autodiff(Reverse, objective!, Duplicated(x, zero(x)), Duplicated(loss, dloss), Const(R)) @test loss[] ≈ 0.0 @show dloss[] ≈ 0.0 @@ -2000,7 +2128,7 @@ end out = Ref(0.0) dout = Ref(1.0) - @test 2.0 ≈ Enzyme.autodiff(Reverse, unionret, Active, Active(2.0), Duplicated(out, dout), true)[1][1] + @test 2.0 ≈ Enzyme.autodiff(Reverse, unionret, Active, Active(2.0), Duplicated(out, dout), Const(true))[1][1] end struct MyFlux diff --git a/test/threads.jl b/test/threads.jl index b3c2cc17ce..5fe80916d3 100644 --- a/test/threads.jl +++ b/test/threads.jl @@ -132,9 +132,9 @@ end end y end - @test 1.0 ≈ autodiff(Reverse, thr_inactive, false, Active(2.14))[1][2] - @test 1.0 ≈ autodiff(Forward, thr_inactive, false, Duplicated(2.14, 1.0))[1] + @test 1.0 ≈ autodiff(Reverse, thr_inactive, Const(false), Active(2.14))[1][2] + @test 1.0 ≈ autodiff(Forward, thr_inactive, Const(false), Duplicated(2.14, 1.0))[1] - @test 1.0 ≈ autodiff(Reverse, thr_inactive, true, Active(2.14))[1][2] - @test 1.0 ≈ autodiff(Forward, thr_inactive, true, Duplicated(2.14, 1.0))[1] + @test 1.0 ≈ autodiff(Reverse, thr_inactive, Const(true), Active(2.14))[1][2] + @test 1.0 ≈ autodiff(Forward, thr_inactive, Const(true), Duplicated(2.14, 1.0))[1] end