Skip to content

Commit

Permalink
Random fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
michel2323 committed Dec 15, 2023
1 parent b283a5a commit 57e884e
Showing 1 changed file with 40 additions and 30 deletions.
70 changes: 40 additions & 30 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ export Annotation, Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated,
import EnzymeCore: BatchDuplicatedFunc
export BatchDuplicatedFunc

import EnzymeCore: batch_size, get_func
import EnzymeCore: batch_size, get_func
export batch_size, get_func

import EnzymeCore: autodiff, autodiff_deferred, autodiff_thunk, autodiff_deferred_thunk, tape_type, make_zero
Expand Down Expand Up @@ -189,7 +189,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0))

tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...}
world = codegen_world_age(Core.Typeof(f.val), tt)

if A <: Active
tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...}
rt = Core.Compiler.return_type(f.val, tt)
Expand Down Expand Up @@ -313,9 +313,9 @@ f(x) = x*x
else
A
end

ModifiedBetween = Val(falses_from_args(Val(1), args...))

tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...}
world = codegen_world_age(Core.Typeof(f.val), tt)

Expand All @@ -338,9 +338,9 @@ code, as well as high-order differentiation.
throw(ErrorException("Cannot differentiate with a batch size of 0"))
end
tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...}

world = codegen_world_age(Core.Typeof(f.val), tt)

if A isa UnionAll
rt = Core.Compiler.return_type(f.val, tt)
rt = A{rt}
Expand All @@ -354,7 +354,7 @@ code, as well as high-order differentiation.
end

ModifiedBetween = Val(falses_from_args(Val(1), args...))

adjoint_ptr, primal_ptr = Compiler.deferred_codegen(Val(world), FA, Val(tt′), Val(rt), Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, Val(ReturnPrimal))
@assert primal_ptr === nothing
thunk = Compiler.CombinedAdjointThunk{Ptr{Cvoid}, FA, rt, tt′, typeof(Val(width)), Val(ReturnPrimal)}(adjoint_ptr)
Expand Down Expand Up @@ -398,9 +398,9 @@ code, as well as high-order differentiation.
A
end
tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...}

world = codegen_world_age(Core.Typeof(f.val), tt)

if RT isa UnionAll
rt = Core.Compiler.return_type(f.val, tt)
rt = RT{rt}
Expand All @@ -420,7 +420,7 @@ code, as well as high-order differentiation.
ReturnPrimal = Val(RT <: Duplicated || RT <: BatchDuplicated)
ModifiedBetween = Val(falses_from_args(Val(1), args...))


adjoint_ptr, primal_ptr = Compiler.deferred_codegen(Val(world), FA, Val(tt′), Val(rt), Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, ReturnPrimal)
@assert primal_ptr === nothing
thunk = Compiler.ForwardModeThunk{Ptr{Cvoid}, FA, rt, tt′, typeof(Val(width)), ReturnPrimal}(adjoint_ptr)
Expand Down Expand Up @@ -489,7 +489,7 @@ forward, reverse = autodiff_thunk(ReverseSplitWithPrimal, Const{typeof(f)}, Acti
tape, result, shadow_result = forward(Const(f), Duplicated(A, ∂A), Active(v))
_, ∂v = reverse(Const(f), Duplicated(A, ∂A), Active(v), 1.0, tape)[1]
result, ∂v, ∂A
result, ∂v, ∂A
# output
Expand All @@ -515,9 +515,9 @@ result, ∂v, ∂A
end

tt = Tuple{map(eltype, args)...}

world = codegen_world_age(eltype(FA), tt)

if !(A <: Const)
@assert ReturnShadow
end
Expand Down Expand Up @@ -581,9 +581,9 @@ forward = autodiff_thunk(Forward, Const{typeof(f)}, DuplicatedNoNeed, Duplicated
ModifiedBetween = Val(falses_from_args(Val(1), args...))

tt = Tuple{map(eltype, args)...}

world = codegen_world_age(eltype(FA), tt)

Enzyme.Compiler.thunk(Val(world), FA, A, Tuple{args...}, #=Mode=# Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, ReturnPrimal, #=ShadowInit=#Val(false), RABI)
end

Expand All @@ -607,7 +607,7 @@ end

@assert ReturnShadow
TT = Tuple{args...}

primal_tt = Tuple{map(eltype, args)...}
world = codegen_world_age(eltype(FA), primal_tt)
nondef = Enzyme.Compiler.thunk(Val(world), FA, A, TT, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI)
Expand All @@ -619,7 +619,12 @@ const tape_cache = Dict{UInt, Compiler.CompileResult}()

const tape_cache_lock = ReentrantLock()

@inline function tape_type(parent_job::Union{GPUCompiler.CompilerJob,Nothing}, ::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI}, ::Type{FA}, ::Type{A}, args...) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI}
import .Compiler: fspec, remove_innerty, UnknownTapeType

@inline function tape_type(
parent_job::Union{GPUCompiler.CompilerJob,Nothing}, ::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI},
::Type{FA}, ::Type{A}, args...
) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI}
width = if Width == 0
w = same_or_one(args...)
if w == 0
Expand All @@ -638,19 +643,23 @@ const tape_cache_lock = ReentrantLock()

@assert ReturnShadow
TT = Tuple{args...}

primal_tt = Tuple{map(eltype, args)...}

# world = codegen_world_age(eltype(FA), primal_tt)
world = codegen_world_age(eltype(FA), primal_tt)

mi = fspec(eltype(FA), TT)
mi = Compiler.fspec(eltype(FA), TT, world)

target = Compiler.EnzymeTarget()
params = Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, API.DEM_ReverseModeGradient, width, remove_innerty(A), true, #=abiwrap=#false, ModifiedBetweenT, ReturnPrimal, #=ShadowInit=#false, UnknownTapeType, RABI)
job = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false))
params = Compiler.EnzymeCompilerParams(
Tuple{FA, TT.parameters...}, API.DEM_ReverseModeGradient, width,
Compiler.remove_innerty(A), true, #=abiwrap=#false, ModifiedBetweenT,
ReturnPrimal, #=ShadowInit=#false, Compiler.UnknownTapeType, RABI
)
job = Compiler.CompilerJob(mi, Compiler.CompilerConfig(target, params; kernel=false))


key = hash(hash(job), parent_job)
key = hash((job, parent_job))

# NOTE: no use of lock(::Function)/@lock/get! to keep stack traces clean
lock(tape_cache_lock)
Expand All @@ -659,16 +668,17 @@ const tape_cache_lock = ReentrantLock()
obj = get(tape_cache, key, nothing)
if obj === nothing

JuliaContext() do ctx
_, meta = codegen(:llvm, job; optimize=false, parent_job)
Compiler.JuliaContext() do ctx
_, meta = Compiler.codegen(:llvm, job; optimize=false, parent_job)
end
obj = meta.TapeType
cache[key] = meta.TapeType
tape_cache[key] = meta.TapeType
end
obj
finally
unlock(tape_cache_lock)
end
return meta.TapeType
end

"""
Expand Down Expand Up @@ -707,7 +717,7 @@ forward, reverse = autodiff_deferred_thunk(ReverseSplitWithPrimal, Const{typeof(
tape, result, shadow_result = forward(Const(f), Duplicated(A, ∂A), Active(v))
_, ∂v = reverse(Const(f), Duplicated(A, ∂A), Active(v), 1.0, tape)[1]
result, ∂v, ∂A
result, ∂v, ∂A
# output
Expand Down Expand Up @@ -735,7 +745,7 @@ result, ∂v, ∂A

@assert ReturnShadow
TT = Tuple{args...}

primal_tt = Tuple{map(eltype, args)...}
world = codegen_world_age(eltype(FA), primal_tt)

Expand Down Expand Up @@ -1039,7 +1049,7 @@ grad = jacobian(Reverse, f, [2.0, 3.0], Val(2))
@inline function jacobian(::ReverseMode{ReturnPrimal,RABI}, f::F, x::X, n_outs::Val{n_out_val}, ::Val{chunk}) where {F, X, chunk, n_out_val, ReturnPrimal, RABI<:ABI}
@assert !ReturnPrimal
num = ((n_out_val + chunk - 1) ÷ chunk)

if chunk == 0
throw(ErrorException("Cannot differentiate with a batch size of 0"))
end
Expand All @@ -1052,7 +1062,7 @@ grad = jacobian(Reverse, f, [2.0, 3.0], Val(2))
FA = Const{Core.Typeof(f)}
World = Val(nothing)
primal, adjoint = Enzyme.Compiler.thunk(Val(world), FA, BatchDuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(chunk), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI)

if num * chunk == n_out_val
last_size = chunk
primal2, adjoint2 = primal, adjoint
Expand Down

0 comments on commit 57e884e

Please sign in to comment.