From 8784d1f79cf9e84028bc04c7455493d1b9dcbd31 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Thu, 29 Feb 2024 21:56:35 +0100 Subject: [PATCH 1/7] Fix triangular solve rule for `Adjoint`s (#1306) * Fix rule for nested wraps * Add test without constructor * Revert "Fix rule for nested wraps" This reverts commit e2643b7b72b7847f6212e9141b1ed06798c04338. * Start regression test (still working) * Make autodiff call through adjoint (broken) * Fix rule for nested wraps * Check derivatives * Make call signature follow bang convention * Rename testset and add explanation --- src/internal_rules.jl | 10 +++++----- test/internal_rules.jl | 22 +++++++++++++++++++++- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index bea48baed4..9bcce5925c 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -493,7 +493,7 @@ function EnzymeRules.reverse( end if !isa(A, Const) dA = EnzymeRules.width(config) == 1 ? A.dval : A.dval[b] - dA.data .-= _zero_unused_elements!(AT(z * adjoint(cache_Yout))) + dA.data .-= _zero_unused_elements!(z * adjoint(cache_Yout), A.val) end dY .= zero(eltype(dY)) end @@ -501,10 +501,10 @@ function EnzymeRules.reverse( return (nothing, nothing, nothing) end -_zero_unused_elements!(A::UpperTriangular) = triu!(A.data) -_zero_unused_elements!(A::LowerTriangular) = tril!(A.data) -_zero_unused_elements!(A::UnitUpperTriangular) = triu!(A.data, 1) -_zero_unused_elements!(A::UnitLowerTriangular) = tril!(A.data, -1) +_zero_unused_elements!(X, ::UpperTriangular) = triu!(X) +_zero_unused_elements!(X, ::LowerTriangular) = tril!(X) +_zero_unused_elements!(X, ::UnitUpperTriangular) = triu!(X, 1) +_zero_unused_elements!(X, ::UnitLowerTriangular) = tril!(X, -1) @static if VERSION >= v"1.7-" # Force a rule around hvcat_fill as it is type unstable if the tuple is not of the same type (e.g., int, float, int, float) diff --git a/test/internal_rules.jl b/test/internal_rules.jl index b2ee39de22..f9b2aca957 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -396,7 +396,7 @@ end B = rand(TE, sizeB...) Y = zeros(TE, sizeB...) A = T(M) - @testset "test against EnzymeTestUtils through constructor" begin + @testset "test through constructor" begin _A = T(A) function f!(Y, A, B, ::T) where T ldiv!(Y, T(A), B) @@ -409,6 +409,26 @@ end test_reverse(f!, Const, (Y, TY), (M, TM), (B, TB), (_A, Const)) end end + @testset "test through `Adjoint` wrapper (regression test for #1306)" begin + # Test that we get the same derivative for `M` as for the adjoint of its + # (materialized) transpose. It's the same matrix, but represented differently + function f!(Y, A, B) + ldiv!(Y, A, B) + return nothing + end + A1 = T(M) + A2 = T(conj(permutedims(M))') + dA1 = make_zero(A1) + dA2 = make_zero(A2) + dB1 = make_zero(B) + dB2 = make_zero(B) + dY1 = rand(TE, sizeB...) + dY2 = copy(dY1) + autodiff(Reverse, f!, Duplicated(Y, dY1), Duplicated(A1, dA1), Duplicated(B, dB1)) + autodiff(Reverse, f!, Duplicated(Y, dY2), Duplicated(A2, dA2), Duplicated(B, dB2)) + @test dA1.data ≈ dA2.data + @test dB1 ≈ dB2 + end end end end From c6642cde36916269450a8b7295fd61be5f476108 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 4 Mar 2024 16:58:08 -0600 Subject: [PATCH 2/7] Fix docs (#1318) --- .github/workflows/CI.yml | 4 +--- .github/workflows/scripts_deploy.yml | 5 +---- examples/box.jl | 6 +++--- 3 files changed, 5 insertions(+), 10 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index a84c588f44..8c0ddd7380 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -220,9 +220,7 @@ jobs: - run: | julia --project=docs -e ' using Pkg - Pkg.develop(path="lib/EnzymeCore") - Pkg.develop(path="lib/EnzymeTestUtils") - Pkg.develop(PackageSpec(path=pwd())) + Pkg.develop([PackageSpec(path="lib/EnzymeCore"), PackageSpec(path="lib/EnzymeTestUtils"), PackageSpec(path=pwd())]) Pkg.instantiate()' env: JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager diff --git a/.github/workflows/scripts_deploy.yml b/.github/workflows/scripts_deploy.yml index 21ddbe248c..961a0bd3a6 100644 --- a/.github/workflows/scripts_deploy.yml +++ b/.github/workflows/scripts_deploy.yml @@ -19,10 +19,7 @@ jobs: - run: | julia --project=docs -e ' using Pkg - Pkg.develop(path="lib/EnzymeCore") - Pkg.develop(PackageSpec(path=pwd())) - Pkg.develop(path="lib/EnzymeTestUtils") - Pkg.develop(PackageSpec(path=pwd())) + Pkg.develop([PackageSpec(path="lib/EnzymeCore"), PackageSpec(path=pwd()), PackageSpec(path="lib/EnzymeTestUtils")]) Pkg.instantiate()' env: JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager diff --git a/examples/box.jl b/examples/box.jl index 5e3b5916f4..2d1b48b7ae 100644 --- a/examples/box.jl +++ b/examples/box.jl @@ -329,7 +329,7 @@ autodiff(Reverse, Duplicated([Tbar; Sbar], dstate_old), Duplicated(out_now, dout_now), Duplicated(out_old, dout_old), - parameters, + Const(parameters), Const(10*parameters.day) ) @@ -373,7 +373,7 @@ autodiff(Reverse, Duplicated([Tbar; Sbar], dstate_old_new), Duplicated(out_now, dout_now), Duplicated(out_old, dout_old), - parameters, + Const(parameters), Const(10*parameters.day) ) @@ -438,7 +438,7 @@ function compute_adjoint_values(states_before_smoother, states_after_smoother, M Duplicated(states_after_smoother[j], dstate_old), Duplicated(zeros(6), dout_now), Duplicated(zeros(6), dout_old), - parameters, + Const(parameters), Const(10*parameters.day) ) From ead0dc5b00cd10b4bd50cc5f70ce46aa9a8ab0da Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 4 Mar 2024 17:45:24 -0600 Subject: [PATCH 3/7] Add tape_type query given a parent compiler job context (#1104) --- src/Enzyme.jl | 95 +++++++++++++++++++++++++++++++++++-------- src/absint.jl | 5 +++ src/compiler.jl | 25 ++++-------- src/compiler/orcv1.jl | 66 +++++++++--------------------- src/compiler/orcv2.jl | 50 +++++++++-------------- test/runtests.jl | 36 ++++++++++++++++ 6 files changed, 165 insertions(+), 112 deletions(-) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index f22411a8bd..c002ba4988 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -414,8 +414,8 @@ code, as well as high-order differentiation. ModifiedBetween = Val(falses_from_args(Nargs+1)) - adjoint_ptr, primal_ptr = Compiler.deferred_codegen(Val(world), FA, Val(tt′), Val(rt), Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, Val(ReturnPrimal)) - @assert primal_ptr === nothing + adjoint_ptr = Compiler.deferred_codegen(Val(world), FA, Val(tt′), Val(rt), Val(API.DEM_ReverseModeCombined), Val(width), ModifiedBetween, Val(ReturnPrimal)) + thunk = Compiler.CombinedAdjointThunk{Ptr{Cvoid}, FA, rt, tt′, typeof(Val(width)), Val(ReturnPrimal)}(adjoint_ptr) if rt <: Active args = (args..., Compiler.default_adjoint(eltype(rt))) @@ -478,8 +478,7 @@ code, as well as high-order differentiation. ReturnPrimal = Val(RT <: Duplicated || RT <: BatchDuplicated) ModifiedBetween = Val(falses_from_args(Nargs+1)) - adjoint_ptr, primal_ptr = Compiler.deferred_codegen(Val(world), FA, Val(tt′), Val(rt), Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, ReturnPrimal) - @assert primal_ptr === nothing + adjoint_ptr = Compiler.deferred_codegen(Val(world), FA, Val(tt′), Val(rt), Val(API.DEM_ForwardMode), Val(width), ModifiedBetween, ReturnPrimal) thunk = Compiler.ForwardModeThunk{Ptr{Cvoid}, FA, rt, tt′, typeof(Val(width)), ReturnPrimal}(adjoint_ptr) thunk(f, args...) end @@ -672,6 +671,71 @@ end return TapeType end +const tape_cache = Dict{UInt, Type}() + +const tape_cache_lock = ReentrantLock() + +import .Compiler: fspec, remove_innerty, UnknownTapeType + +@inline function tape_type( + parent_job::Union{GPUCompiler.CompilerJob,Nothing}, ::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI}, + ::Type{FA}, ::Type{A}, args... +) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI} + width = if Width == 0 + w = same_or_one(args...) + if w == 0 + throw(ErrorException("Cannot differentiate with a batch size of 0")) + end + w + else + Width + end + + if ModifiedBetweenT === true + ModifiedBetween = falses_from_args(Val(1), args...) + else + ModifiedBetween = ModifiedBetweenT + end + + @assert ReturnShadow + TT = Tuple{args...} + + primal_tt = Tuple{map(eltype, args)...} + + world = codegen_world_age(eltype(FA), primal_tt) + + mi = Compiler.fspec(eltype(FA), TT, world) + + target = Compiler.EnzymeTarget() + params = Compiler.EnzymeCompilerParams( + Tuple{FA, TT.parameters...}, API.DEM_ReverseModeGradient, width, + Compiler.remove_innerty(A), true, #=abiwrap=#false, ModifiedBetweenT, + ReturnPrimal, #=ShadowInit=#false, Compiler.UnknownTapeType, RABI + ) + job = Compiler.CompilerJob(mi, Compiler.CompilerConfig(target, params; kernel=false)) + + + key = hash(parent_job, hash(job)) + + # NOTE: no use of lock(::Function)/@lock/get! to keep stack traces clean + lock(tape_cache_lock) + + try + obj = get(tape_cache, key, nothing) + if obj === nothing + + Compiler.JuliaContext() do ctx + _, meta = Compiler.codegen(:llvm, job; optimize=false, parent_job) + obj = meta.TapeType + tape_cache[key] = obj + end + end + obj + finally + unlock(tape_cache_lock) + end +end + """ autodiff_deferred_thunk(::ReverseModeSplit, ftype, Activity, argtypes::Vararg{Type{<:Annotation}, Nargs}) @@ -703,7 +767,8 @@ function f(A, v) res end -forward, reverse = autodiff_deferred_thunk(ReverseSplitWithPrimal, Const{typeof(f)}, Active, Duplicated{typeof(A)}, Active{typeof(v)}) +TapeType = tape_type(ReverseSplitWithPrimal, Const{typeof(f)}, Active, Duplicated{typeof(A)}, Active{typeof(v)}) +forward, reverse = autodiff_deferred_thunk(ReverseSplitWithPrimal, TapeType, Const{typeof(f)}, Active, Active{Float64}, Duplicated{typeof(A)}, Active{typeof(v)}) tape, result, shadow_result = forward(Const(f), Duplicated(A, ∂A), Active(v)) _, ∂v = reverse(Const(f), Duplicated(A, ∂A), Active(v), 1.0, tape)[1] @@ -715,7 +780,7 @@ result, ∂v, ∂A (7.26, 2.2, [3.3]) ``` """ -@inline function autodiff_deferred_thunk(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI, Nargs} +@inline function autodiff_deferred_thunk(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI}, ::Type{TapeType}, ::Type{FA}, ::Type{A}, ::Type{A2}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, A2, TapeType, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI, Nargs} @assert RABI == FFIABI width = if Width == 0 w = same_or_one(args...) @@ -735,21 +800,15 @@ result, ∂v, ∂A @assert ReturnShadow TT = Tuple{args...} - + primal_tt = Tuple{map(eltype, args)...} world = codegen_world_age(eltype(FA), primal_tt) - # TODO this assumes that the thunk here has the correct parent/etc things for getting the right cuda instructions -> same caching behavior - nondef = Enzyme.Compiler.thunk(Val(world), FA, A, TT, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(false), RABI) - TapeType = EnzymeRules.tape_type(nondef[1]) - A2 = Compiler.return_type(typeof(nondef[1])) - - adjoint_ptr, primal_ptr = Compiler.deferred_codegen(Val(world), FA, Val(TT), Val(A2), Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), TapeType) - AugT = Compiler.AugmentedForwardThunk{Ptr{Cvoid}, FA, A2, TT, Val{width}, Val(ReturnPrimal), TapeType} - @assert AugT == typeof(nondef[1]) - AdjT = Compiler.AdjointThunk{Ptr{Cvoid}, FA, A2, TT, Val{width}, TapeType} - @assert AdjT == typeof(nondef[2]) - AugT(primal_ptr), AdjT(adjoint_ptr) + primal_ptr = Compiler.deferred_codegen(Val(world), FA, Val(TT), Val(A2), Val(API.DEM_ReverseModePrimal), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), TapeType) + adjoint_ptr = Compiler.deferred_codegen(Val(world), FA, Val(TT), Val(A2), Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, Val(ReturnPrimal), #=ShadowInit=#Val(false), TapeType) + aug_thunk = Compiler.AugmentedForwardThunk{Ptr{Cvoid}, FA, A2, TT, Val{width}, Val(ReturnPrimal), TapeType}(primal_ptr) + adj_thunk = Compiler.AdjointThunk{Ptr{Cvoid}, FA, A2, TT, Val{width}, TapeType}(adjoint_ptr) + aug_thunk, adj_thunk end # White lie, should be `Core.LLVMPtr{Cvoid, 0}` but that's not supported by ccallable diff --git a/src/absint.jl b/src/absint.jl index bcdd635e99..6216c1a769 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -112,6 +112,11 @@ function absint(arg::LLVM.Value, partial::Bool=false) return (false, nothing) end ptr = unsafe_load(reinterpret(Ptr{Ptr{Cvoid}}, convert(UInt, ce))) + if ptr == C_NULL + # XXX: Is this correct? + @error "Found null pointer" arg + return (false, nothing) + end typ = Base.unsafe_pointer_to_objref(ptr) return (true, typ) end diff --git a/src/compiler.jl b/src/compiler.jl index 3cdf51f03d..19ac839fc9 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -4984,7 +4984,9 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; linkage!(fn, LLVM.API.LLVMLinkerPrivateLinkage) end - return mod, (;adjointf, augmented_primalf, entry=adjointf, compiled=meta.compiled, TapeType) + use_primal = mode == API.DEM_ReverseModePrimal + entry = use_primal ? augmented_primalf : adjointf + return mod, (;adjointf, augmented_primalf, entry, compiled=meta.compiled, TapeType) end # Compiler result @@ -5653,26 +5655,13 @@ import GPUCompiler: deferred_codegen_jobs params = EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, rt2, true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit,ExpectedTapeType, FFIABI) job = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World) - adjoint_addr, primal_addr = get_trampoline(job) - adjoint_id = Base.reinterpret(Int, pointer(adjoint_addr)) - deferred_codegen_jobs[adjoint_id] = job - - if primal_addr !== nothing - primal_id = Base.reinterpret(Int, pointer(primal_addr)) - deferred_codegen_jobs[primal_id] = job - else - primal_id = 0 - end + addr = get_trampoline(job) + id = Base.reinterpret(Int, pointer(addr)) + deferred_codegen_jobs[id] = job quote Base.@_inline_meta - adjoint = ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Ptr{Cvoid},), $(reinterpret(Ptr{Cvoid}, adjoint_id))) - primal = if $(primal_addr !== nothing) - ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Ptr{Cvoid},), $(reinterpret(Ptr{Cvoid}, primal_id))) - else - nothing - end - adjoint, primal + ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Ptr{Cvoid},), $(reinterpret(Ptr{Cvoid}, id))) end end end diff --git a/src/compiler/orcv1.jl b/src/compiler/orcv1.jl index 2af56896cb..1b6bd2fe81 100644 --- a/src/compiler/orcv1.jl +++ b/src/compiler/orcv1.jl @@ -39,16 +39,15 @@ function __init__() end mutable struct CallbackContext - tag::Symbol job::CompilerJob stub::Symbol l_job::ReentrantLock addr::Ptr{Cvoid} - CallbackContext(tag, job, stub, l_job) = new(tag, job, stub, l_job, C_NULL) + CallbackContext(job, stub, l_job) = new(job, stub, l_job, C_NULL) end const l_outstanding = Base.ReentrantLock() -const outstanding = Dict{Symbol, Tuple{CallbackContext, Union{Nothing, CallbackContext}}}() +const outstanding = Base.IdSet{CallbackContext}() # Setup the lazy callback for creating a module function callback(orc_ref::LLVM.API.LLVMOrcJITStackRef, callback_ctx::Ptr{Cvoid}) @@ -61,35 +60,27 @@ function callback(orc_ref::LLVM.API.LLVMOrcJITStackRef, callback_ctx::Ptr{Cvoid} # 2. lookup if we are the first lock(l_outstanding) - if haskey(outstanding, cc.tag) - ccs = outstanding[cc.tag] - delete!(outstanding, cc.tag) + if in(cc, outstanding) + delete!(outstanding, cc) else - ccs = nothing - end - unlock(l_outstanding) - - # 3. We are the second callback to run, but we raced the other one - # thus we return the addr from them. - if ccs === nothing + unlock(l_outstanding) unlock(cc.l_job) + + # 3. We are the second callback to run, but we raced the other one + # thus we return the addr from them. @assert cc.addr != C_NULL return UInt64(reinterpret(UInt, cc.addr)) end + unlock(l_outstanding) - cc_adjoint, cc_primal = ccs try thunk = Compiler._link(cc.job, Compiler._thunk(cc.job)) - cc_adjoint.addr = thunk.adjoint - if cc_primal !== nothing - cc_primal.addr = thunk.primal - end + mode = cc.job.config.params.mode + use_primal = mode == API.DEM_ReverseModePrimal + cc.addr = use_primal ? thunk.primal : thunk.adjoint # 4. Update the stub pointer to point to the recently compiled module - set_stub!(orc, string(cc_adjoint.stub), thunk.adjoint) - if cc_primal !== nothing - set_stub!(orc, string(cc_primal.stub), thunk.primal) - end + set_stub!(orc, string(cc.stub), cc.addr) finally unlock(cc.l_job) end @@ -101,37 +92,20 @@ function callback(orc_ref::LLVM.API.LLVMOrcJITStackRef, callback_ctx::Ptr{Cvoid} end function get_trampoline(job) - tag = gensym(:tag) l_job = Base.ReentrantLock() - mode = job.config.params.mode - needs_augmented_primal = mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient - - cc_adjoint = CallbackContext(tag, job, gensym(:adjoint), l_job) - if needs_augmented_primal - cc_primal = CallbackContext(tag, job, gensym(:primal), l_job) - else - cc_primal = nothing - end - lock(l_outstanding) do - outstanding[tag] = (cc_adjoint, cc_primal) - end + cc = CallbackContext(job, gensym(:func), l_job) + lock(l_outstanding) + push!(outstanding, cc) + unlock(l_outstanding) c_callback = @cfunction(callback, UInt64, (LLVM.API.LLVMOrcJITStackRef, Ptr{Cvoid})) orc = jit[] - addr_adjoint = callback!(orc, c_callback, pointer_from_objref(cc_adjoint)) - create_stub!(orc, string(cc_adjoint.stub), addr_adjoint) - - if needs_augmented_primal - addr_primal = callback!(orc, c_callback, pointer_from_objref(cc_primal)) - create_stub!(orc, string(cc_primal.stub), addr_primal) - addr_primal_stub = address(orc, string(cc_primal.stub)) - else - addr_primal_stub = nothing - end + addr_adjoint = callback!(orc, c_callback, pointer_from_objref(cc)) + create_stub!(orc, string(cc.stub), addr_adjoint) - return address(orc, string(cc_adjoint.stub)), addr_primal_stub + return address(orc, string(cc.stub)) end diff --git a/src/compiler/orcv2.jl b/src/compiler/orcv2.jl index 4037fd70d0..644186407d 100644 --- a/src/compiler/orcv2.jl +++ b/src/compiler/orcv2.jl @@ -177,25 +177,15 @@ function get_trampoline(job) end mode = job.config.params.mode - needs_augmented_primal = mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient + use_primal = mode == API.DEM_ReverseModePrimal # We could also use one dylib per job jd = JITDylib(lljit) - adjoint_sym = String(gensym(:adjoint)) - _adjoint_sym = String(gensym(:adjoint)) - adjoint_addr = add_trampoline!(jd, (lljit, lctm, ism), - _adjoint_sym, adjoint_sym) - - if needs_augmented_primal - primal_sym = String(gensym(:augmented_primal)) - _primal_sym = String(gensym(:augmented_primal)) - primal_addr = add_trampoline!(jd, (lljit, lctm, ism), - _primal_sym, primal_sym) - else - primal_sym = nothing - primal_addr = nothing - end + sym = String(gensym(:func)) + _sym = String(gensym(:func)) + addr = add_trampoline!(jd, (lljit, lctm, ism), + _sym, sym) # 3. add MU that will call back into the compiler function materialize(mr) @@ -207,14 +197,18 @@ function get_trampoline(job) # 2. Call MR.replace(symbolAliases({"my_deferred_decision_sym.1" -> "foo.rt.impl"})). GPUCompiler.JuliaContext() do ctx mod, adjoint_name, primal_name = Compiler._thunk(job) - adjointf = functions(mod)[adjoint_name] - LLVM.name!(adjointf, adjoint_sym) - if needs_augmented_primal - primalf = functions(mod)[primal_name] - LLVM.name!(primalf, primal_sym) - else - @assert primal_name === nothing - primalf = nothing + func_name = use_primal ? primal_name : adjoint_name + other_name = !use_primal ? primal_name : adjoint_name + + func = functions(mod)[func_name] + LLVM.name!(func, sym) + + if other_name !== nothing + # Otherwise MR will complain -- we could claim responsibilty, + # but it would be nicer if _thunk just codegen'd the half + # we need. + other_func = functions(mod)[other_name] + LLVM.unsafe_delete!(mod, other_func) end tsm = move_to_threadsafe(mod) @@ -237,17 +231,13 @@ function get_trampoline(job) symbols = [ LLVM.API.LLVMOrcCSymbolFlagsMapPair( - mangle(lljit, adjoint_sym), flags), + mangle(lljit, sym), flags), ] - if needs_augmented_primal - push!(symbols, LLVM.API.LLVMOrcCSymbolFlagsMapPair( - mangle(lljit, primal_sym), flags),) - end - mu = LLVM.CustomMaterializationUnit(adjoint_sym, symbols, + mu = LLVM.CustomMaterializationUnit(sym, symbols, materialize, discard) LLVM.define(jd, mu) - return adjoint_addr, primal_addr + return addr end function add!(mod) diff --git a/test/runtests.jl b/test/runtests.jl index d4b541733a..8dc13b5da6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -299,6 +299,42 @@ make3() = (1.0, 2.0, 3.0) end +@testset "Deferred and deferred thunk" begin + function dot(A) + return A[1] * A[1] + A[2] * A[2] + end + dA = zeros(2) + A = [3.0, 5.0] + thunk_dA, def_dA = copy(dA), copy(dA) + def_A, thunk_A = copy(A), copy(A) + primal = Enzyme.autodiff(ReverseWithPrimal, dot, Active, Duplicated(A, dA))[2] + @test primal == 34.0 + primal = Enzyme.autodiff_deferred(ReverseWithPrimal, dot, Active, Duplicated(def_A, def_dA))[2] + @test primal == 34.0 + + dup = Duplicated(thunk_A, thunk_dA) + TapeType = Enzyme.EnzymeCore.tape_type( + ReverseSplitWithPrimal, + Const{typeof(dot)}, Active, Duplicated{typeof(thunk_A)} + ) + @test Tuple{Float64,Float64} === TapeType + fwd, rev = Enzyme.autodiff_deferred_thunk( + ReverseSplitWithPrimal, + TapeType, + Const{typeof(dot)}, + Active, + Active{Float64}, + Duplicated{typeof(thunk_A)} + ) + tape, primal, _ = fwd(Const(dot), dup) + @test isa(tape, Tuple{Float64,Float64}) + rev(Const(dot), dup, 1.0, tape) + @test all(primal == 34) + @test all(dA .== [6.0, 10.0]) + @test all(dA .== def_dA) + @test all(dA .== thunk_dA) +end + @testset "Simple Complex tests" begin mul2(z) = 2 * z square(z) = z * z From b71fc1c83f4aa27ce85fdd74ef1e5f82b9830b0d Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 4 Mar 2024 18:51:01 -0600 Subject: [PATCH 4/7] Update compiler.jl (#1322) --- src/compiler.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/compiler.jl b/src/compiler.jl index 19ac839fc9..771dc896db 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -112,6 +112,7 @@ Dict{DataType, Tuple{Symbol, Int, Union{Nothing, Tuple{Symbol, DataType}}}}( end const nofreefns = Set{String}(( + "ijl_field_index", "jl_field_index", "julia.call", "julia.call2", "ijl_tagged_gensym", "jl_tagged_gensym", "ijl_array_ptr_copy", "jl_array_ptr_copy", From 8ff4a01fe1b4854dbb27b0008474c75e1284cf9b Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 4 Mar 2024 20:29:18 -0600 Subject: [PATCH 5/7] Loosen new struct requirements for jl_new_struct (#1323) --- src/rules/typeunstablerules.jl | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl index 0a1337e4bc..8a6a6beab4 100644 --- a/src/rules/typeunstablerules.jl +++ b/src/rules/typeunstablerules.jl @@ -68,6 +68,15 @@ function common_newstructv_augfwd(offset, B, orig, gutils, normalR, shadowR, tap common_newstructv_fwd(offset, B, orig, gutils, normalR, shadowR) end +function error_if_active_newstruct(::Type{T}, ::Type{Y}) where {T, Y} + seen = () + areg = active_reg_inner(T, seen, nothing, #=justActive=#Val(true)) + if areg == ActiveState + throw(AssertionError("Found unhandled active variable ($T) in reverse mode of jl_newstruct constructor for $Y")) + end + nothing +end + function common_newstructv_rev(offset, B, orig, gutils, tape) if is_constant_value(gutils, orig) return true @@ -81,7 +90,22 @@ function common_newstructv_rev(offset, B, orig, gutils, tape) if !needsShadow return end - emit_error(B, orig, "Enzyme: Not yet implemented reverse for jl_new_struct "*string(orig)*" "*string(operands(orig)[offset])*"\n"*string(LLVM.parent(orig))) + + origops = collect(operands(orig)) + width = get_width(gutils) + + world = enzyme_extract_world(LLVM.parent(position(B))) + + @assert is_constant_value(gutils, origops[offset]) + icvs = [is_constant_value(gutils, v) for v in origops[offset+1:end-1]] + abs = [abs_typeof(v, true) for v in origops[offset+1:end-1]] + + + ty = new_from_original(gutils, origops[offset]) + for v in origops[offset+1:end-1] + emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(err_if_active_newstruct), emit_jltypeof!(B, v), ty]) + end + return nothing end From c8a4bf92678b4693cf6302f409c96cd022637697 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 4 Mar 2024 22:36:57 -0600 Subject: [PATCH 6/7] embarassing bugfix --- src/rules/typeunstablerules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl index 8a6a6beab4..ac24145102 100644 --- a/src/rules/typeunstablerules.jl +++ b/src/rules/typeunstablerules.jl @@ -103,7 +103,7 @@ function common_newstructv_rev(offset, B, orig, gutils, tape) ty = new_from_original(gutils, origops[offset]) for v in origops[offset+1:end-1] - emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(err_if_active_newstruct), emit_jltypeof!(B, v), ty]) + emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(error_if_active_newstruct), emit_jltypeof!(B, v), ty]) end return nothing From 5e4e2ef2ef9b8add6dd56e8afaf5a32039ac9f83 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 4 Mar 2024 22:42:15 -0600 Subject: [PATCH 7/7] embarassing bugfix 2 --- src/rules/typeunstablerules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl index ac24145102..149ed46893 100644 --- a/src/rules/typeunstablerules.jl +++ b/src/rules/typeunstablerules.jl @@ -103,7 +103,7 @@ function common_newstructv_rev(offset, B, orig, gutils, tape) ty = new_from_original(gutils, origops[offset]) for v in origops[offset+1:end-1] - emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(error_if_active_newstruct), emit_jltypeof!(B, v), ty]) + emit_apply_generic!(B, LLVM.Value[unsafe_to_llvm(error_if_active_newstruct), emit_jltypeof!(B, new_from_original(gutils, v)), ty]) end return nothing