Skip to content

Commit

Permalink
Merge branch 'main' into fix-cholesky
Browse files Browse the repository at this point in the history
  • Loading branch information
simsurace authored Mar 5, 2024
2 parents 91ee976 + 5e4e2ef commit c531aae
Show file tree
Hide file tree
Showing 12 changed files with 222 additions and 129 deletions.
4 changes: 1 addition & 3 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,7 @@ jobs:
- run: |
julia --project=docs -e '
using Pkg
Pkg.develop(path="lib/EnzymeCore")
Pkg.develop(path="lib/EnzymeTestUtils")
Pkg.develop(PackageSpec(path=pwd()))
Pkg.develop([PackageSpec(path="lib/EnzymeCore"), PackageSpec(path="lib/EnzymeTestUtils"), PackageSpec(path=pwd())])
Pkg.instantiate()'
env:
JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager
Expand Down
5 changes: 1 addition & 4 deletions .github/workflows/scripts_deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,7 @@ jobs:
- run: |
julia --project=docs -e '
using Pkg
Pkg.develop(path="lib/EnzymeCore")
Pkg.develop(PackageSpec(path=pwd()))
Pkg.develop(path="lib/EnzymeTestUtils")
Pkg.develop(PackageSpec(path=pwd()))
Pkg.develop([PackageSpec(path="lib/EnzymeCore"), PackageSpec(path=pwd()), PackageSpec(path="lib/EnzymeTestUtils")])
Pkg.instantiate()'
env:
JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager
Expand Down
6 changes: 3 additions & 3 deletions examples/box.jl
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ autodiff(Reverse,
Duplicated([Tbar; Sbar], dstate_old),
Duplicated(out_now, dout_now),
Duplicated(out_old, dout_old),
parameters,
Const(parameters),
Const(10*parameters.day)
)

Expand Down Expand Up @@ -373,7 +373,7 @@ autodiff(Reverse,
Duplicated([Tbar; Sbar], dstate_old_new),
Duplicated(out_now, dout_now),
Duplicated(out_old, dout_old),
parameters,
Const(parameters),
Const(10*parameters.day)
)

Expand Down Expand Up @@ -438,7 +438,7 @@ function compute_adjoint_values(states_before_smoother, states_after_smoother, M
Duplicated(states_after_smoother[j], dstate_old),
Duplicated(zeros(6), dout_now),
Duplicated(zeros(6), dout_old),
parameters,
Const(parameters),
Const(10*parameters.day)
)

Expand Down
95 changes: 77 additions & 18 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -414,8 +414,8 @@ code, as well as high-order differentiation.

ModifiedBetween = Val(falses_from_args(Nargs+1))

adjoint_ptr, primal_ptr = Compiler.deferred_codegen(Val(world), FA, Val(tt′), Val(rt), Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, Val(ReturnPrimal))
@assert primal_ptr === nothing
adjoint_ptr = Compiler.deferred_codegen(Val(world), FA, Val(tt′), Val(rt), Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, Val(ReturnPrimal))

thunk = Compiler.CombinedAdjointThunk{Ptr{Cvoid}, FA, rt, tt′, typeof(Val(width)), Val(ReturnPrimal)}(adjoint_ptr)
if rt <: Active
args = (args..., Compiler.default_adjoint(eltype(rt)))
Expand Down Expand Up @@ -478,8 +478,7 @@ code, as well as high-order differentiation.
ReturnPrimal = Val(RT <: Duplicated || RT <: BatchDuplicated)
ModifiedBetween = Val(falses_from_args(Nargs+1))

adjoint_ptr, primal_ptr = Compiler.deferred_codegen(Val(world), FA, Val(tt′), Val(rt), Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, ReturnPrimal)
@assert primal_ptr === nothing
adjoint_ptr = Compiler.deferred_codegen(Val(world), FA, Val(tt′), Val(rt), Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, ReturnPrimal)
thunk = Compiler.ForwardModeThunk{Ptr{Cvoid}, FA, rt, tt′, typeof(Val(width)), ReturnPrimal}(adjoint_ptr)
thunk(f, args...)
end
Expand Down Expand Up @@ -672,6 +671,71 @@ end
return TapeType
end

const tape_cache = Dict{UInt, Type}()

const tape_cache_lock = ReentrantLock()

import .Compiler: fspec, remove_innerty, UnknownTapeType

@inline function tape_type(
parent_job::Union{GPUCompiler.CompilerJob,Nothing}, ::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI},
::Type{FA}, ::Type{A}, args...
) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI}
width = if Width == 0
w = same_or_one(args...)
if w == 0
throw(ErrorException("Cannot differentiate with a batch size of 0"))
end
w
else
Width
end

if ModifiedBetweenT === true
ModifiedBetween = falses_from_args(Val(1), args...)
else
ModifiedBetween = ModifiedBetweenT
end

@assert ReturnShadow
TT = Tuple{args...}

primal_tt = Tuple{map(eltype, args)...}

world = codegen_world_age(eltype(FA), primal_tt)

mi = Compiler.fspec(eltype(FA), TT, world)

target = Compiler.EnzymeTarget()
params = Compiler.EnzymeCompilerParams(
Tuple{FA, TT.parameters...}, API.DEM_ReverseModeGradient, width,
Compiler.remove_innerty(A), true, #=abiwrap=#false, ModifiedBetweenT,
ReturnPrimal, #=ShadowInit=#false, Compiler.UnknownTapeType, RABI
)
job = Compiler.CompilerJob(mi, Compiler.CompilerConfig(target, params; kernel=false))


key = hash(parent_job, hash(job))

# NOTE: no use of lock(::Function)/@lock/get! to keep stack traces clean
lock(tape_cache_lock)

try
obj = get(tape_cache, key, nothing)
if obj === nothing

Compiler.JuliaContext() do ctx
_, meta = Compiler.codegen(:llvm, job; optimize=false, parent_job)
obj = meta.TapeType
tape_cache[key] = obj
end
end
obj
finally
unlock(tape_cache_lock)
end
end

"""
autodiff_deferred_thunk(::ReverseModeSplit, ftype, Activity, argtypes::Vararg{Type{<:Annotation}, Nargs})
Expand Down Expand Up @@ -703,7 +767,8 @@ function f(A, v)
res
end
forward, reverse = autodiff_deferred_thunk(ReverseSplitWithPrimal, Const{typeof(f)}, Active, Duplicated{typeof(A)}, Active{typeof(v)})
TapeType = tape_type(ReverseSplitWithPrimal, Const{typeof(f)}, Active, Duplicated{typeof(A)}, Active{typeof(v)})
forward, reverse = autodiff_deferred_thunk(ReverseSplitWithPrimal, TapeType, Const{typeof(f)}, Active, Active{Float64}, Duplicated{typeof(A)}, Active{typeof(v)})
tape, result, shadow_result = forward(Const(f), Duplicated(A, ∂A), Active(v))
_, ∂v = reverse(Const(f), Duplicated(A, ∂A), Active(v), 1.0, tape)[1]
Expand All @@ -715,7 +780,7 @@ result, ∂v, ∂A
(7.26, 2.2, [3.3])
```
"""
@inline function autodiff_deferred_thunk(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI, Nargs}
@inline function autodiff_deferred_thunk(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI}, ::Type{TapeType}, ::Type{FA}, ::Type{A}, ::Type{A2}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, A2, TapeType, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI, Nargs}
@assert RABI == FFIABI
width = if Width == 0
w = same_or_one(args...)
Expand All @@ -735,21 +800,15 @@ result, ∂v, ∂A

@assert ReturnShadow
TT = Tuple{args...}

primal_tt = Tuple{map(eltype, args)...}
world = codegen_world_age(eltype(FA), primal_tt)

# TODO this assumes that the thunk here has the correct parent/etc things for getting the right cuda instructions -> same caching behavior
nondef = Enzyme.Compiler.thunk(Val(world), FA, A, TT, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI)
TapeType = EnzymeRules.tape_type(nondef[1])
A2 = Compiler.return_type(typeof(nondef[1]))

adjoint_ptr, primal_ptr = Compiler.deferred_codegen(Val(world), FA, Val(TT), Val(A2), Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), TapeType)
AugT = Compiler.AugmentedForwardThunk{Ptr{Cvoid}, FA, A2, TT, Val{width}, Val(ReturnPrimal), TapeType}
@assert AugT == typeof(nondef[1])
AdjT = Compiler.AdjointThunk{Ptr{Cvoid}, FA, A2, TT, Val{width}, TapeType}
@assert AdjT == typeof(nondef[2])
AugT(primal_ptr), AdjT(adjoint_ptr)
primal_ptr = Compiler.deferred_codegen(Val(world), FA, Val(TT), Val(A2), Val(API.DEM_ReverseModePrimal), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), TapeType)
adjoint_ptr = Compiler.deferred_codegen(Val(world), FA, Val(TT), Val(A2), Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), TapeType)
aug_thunk = Compiler.AugmentedForwardThunk{Ptr{Cvoid}, FA, A2, TT, Val{width}, Val(ReturnPrimal), TapeType}(primal_ptr)
adj_thunk = Compiler.AdjointThunk{Ptr{Cvoid}, FA, A2, TT, Val{width}, TapeType}(adjoint_ptr)
aug_thunk, adj_thunk
end

# White lie, should be `Core.LLVMPtr{Cvoid, 0}` but that's not supported by ccallable
Expand Down
5 changes: 5 additions & 0 deletions src/absint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,11 @@ function absint(arg::LLVM.Value, partial::Bool=false)
return (false, nothing)
end
ptr = unsafe_load(reinterpret(Ptr{Ptr{Cvoid}}, convert(UInt, ce)))
if ptr == C_NULL
# XXX: Is this correct?
@error "Found null pointer" arg
return (false, nothing)
end
typ = Base.unsafe_pointer_to_objref(ptr)
return (true, typ)
end
Expand Down
26 changes: 8 additions & 18 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ Dict{DataType, Tuple{Symbol, Int, Union{Nothing, Tuple{Symbol, DataType}}}}(
end

const nofreefns = Set{String}((
"ijl_field_index", "jl_field_index",
"julia.call", "julia.call2",
"ijl_tagged_gensym", "jl_tagged_gensym",
"ijl_array_ptr_copy", "jl_array_ptr_copy",
Expand Down Expand Up @@ -4984,7 +4985,9 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget};
linkage!(fn, LLVM.API.LLVMLinkerPrivateLinkage)
end

return mod, (;adjointf, augmented_primalf, entry=adjointf, compiled=meta.compiled, TapeType)
use_primal = mode == API.DEM_ReverseModePrimal
entry = use_primal ? augmented_primalf : adjointf
return mod, (;adjointf, augmented_primalf, entry, compiled=meta.compiled, TapeType)
end

# Compiler result
Expand Down Expand Up @@ -5653,26 +5656,13 @@ import GPUCompiler: deferred_codegen_jobs
params = EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, rt2, true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit,ExpectedTapeType, FFIABI)
job = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World)

adjoint_addr, primal_addr = get_trampoline(job)
adjoint_id = Base.reinterpret(Int, pointer(adjoint_addr))
deferred_codegen_jobs[adjoint_id] = job

if primal_addr !== nothing
primal_id = Base.reinterpret(Int, pointer(primal_addr))
deferred_codegen_jobs[primal_id] = job
else
primal_id = 0
end
addr = get_trampoline(job)
id = Base.reinterpret(Int, pointer(addr))
deferred_codegen_jobs[id] = job

quote
Base.@_inline_meta
adjoint = ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Ptr{Cvoid},), $(reinterpret(Ptr{Cvoid}, adjoint_id)))
primal = if $(primal_addr !== nothing)
ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Ptr{Cvoid},), $(reinterpret(Ptr{Cvoid}, primal_id)))
else
nothing
end
adjoint, primal
ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Ptr{Cvoid},), $(reinterpret(Ptr{Cvoid}, id)))
end
end
end
Expand Down
66 changes: 20 additions & 46 deletions src/compiler/orcv1.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,15 @@ function __init__()
end

mutable struct CallbackContext
tag::Symbol
job::CompilerJob
stub::Symbol
l_job::ReentrantLock
addr::Ptr{Cvoid}
CallbackContext(tag, job, stub, l_job) = new(tag, job, stub, l_job, C_NULL)
CallbackContext(job, stub, l_job) = new(job, stub, l_job, C_NULL)
end

const l_outstanding = Base.ReentrantLock()
const outstanding = Dict{Symbol, Tuple{CallbackContext, Union{Nothing, CallbackContext}}}()
const outstanding = Base.IdSet{CallbackContext}()

# Setup the lazy callback for creating a module
function callback(orc_ref::LLVM.API.LLVMOrcJITStackRef, callback_ctx::Ptr{Cvoid})
Expand All @@ -61,35 +60,27 @@ function callback(orc_ref::LLVM.API.LLVMOrcJITStackRef, callback_ctx::Ptr{Cvoid}

# 2. lookup if we are the first
lock(l_outstanding)
if haskey(outstanding, cc.tag)
ccs = outstanding[cc.tag]
delete!(outstanding, cc.tag)
if in(cc, outstanding)
delete!(outstanding, cc)
else
ccs = nothing
end
unlock(l_outstanding)

# 3. We are the second callback to run, but we raced the other one
# thus we return the addr from them.
if ccs === nothing
unlock(l_outstanding)
unlock(cc.l_job)

# 3. We are the second callback to run, but we raced the other one
# thus we return the addr from them.
@assert cc.addr != C_NULL
return UInt64(reinterpret(UInt, cc.addr))
end
unlock(l_outstanding)

cc_adjoint, cc_primal = ccs
try
thunk = Compiler._link(cc.job, Compiler._thunk(cc.job))
cc_adjoint.addr = thunk.adjoint
if cc_primal !== nothing
cc_primal.addr = thunk.primal
end
mode = cc.job.config.params.mode
use_primal = mode == API.DEM_ReverseModePrimal
cc.addr = use_primal ? thunk.primal : thunk.adjoint

# 4. Update the stub pointer to point to the recently compiled module
set_stub!(orc, string(cc_adjoint.stub), thunk.adjoint)
if cc_primal !== nothing
set_stub!(orc, string(cc_primal.stub), thunk.primal)
end
set_stub!(orc, string(cc.stub), cc.addr)
finally
unlock(cc.l_job)
end
Expand All @@ -101,37 +92,20 @@ function callback(orc_ref::LLVM.API.LLVMOrcJITStackRef, callback_ctx::Ptr{Cvoid}
end

function get_trampoline(job)
tag = gensym(:tag)
l_job = Base.ReentrantLock()

mode = job.config.params.mode
needs_augmented_primal = mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient

cc_adjoint = CallbackContext(tag, job, gensym(:adjoint), l_job)
if needs_augmented_primal
cc_primal = CallbackContext(tag, job, gensym(:primal), l_job)
else
cc_primal = nothing
end
lock(l_outstanding) do
outstanding[tag] = (cc_adjoint, cc_primal)
end
cc = CallbackContext(job, gensym(:func), l_job)
lock(l_outstanding)
push!(outstanding, cc)
unlock(l_outstanding)

c_callback = @cfunction(callback, UInt64, (LLVM.API.LLVMOrcJITStackRef, Ptr{Cvoid}))

orc = jit[]
addr_adjoint = callback!(orc, c_callback, pointer_from_objref(cc_adjoint))
create_stub!(orc, string(cc_adjoint.stub), addr_adjoint)

if needs_augmented_primal
addr_primal = callback!(orc, c_callback, pointer_from_objref(cc_primal))
create_stub!(orc, string(cc_primal.stub), addr_primal)
addr_primal_stub = address(orc, string(cc_primal.stub))
else
addr_primal_stub = nothing
end
addr_adjoint = callback!(orc, c_callback, pointer_from_objref(cc))
create_stub!(orc, string(cc.stub), addr_adjoint)

return address(orc, string(cc_adjoint.stub)), addr_primal_stub
return address(orc, string(cc.stub))
end


Expand Down
Loading

0 comments on commit c531aae

Please sign in to comment.