diff --git a/src/compiler.jl b/src/compiler.jl index 7507e9af8f..2f66c407f9 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -4762,6 +4762,8 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; push!(jlrules, fname) end + memcpy_alloca_to_loadstore(mod) + adjointf, augmented_primalf, TapeType = enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, abiwrap, modifiedBetween, returnPrimal, jlrules, expectedTapeType, loweredArgs, boxedArgs) toremove = [] # Inline the wrapper diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index b0b5b78179..cf7c77b0e5 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -70,6 +70,122 @@ function source_elem(v) end end + +## given code like +# % a = alloca +# ... +# memcpy(cast(%a), %b, constant size == sizeof(a)) +# +# turn this into load/store, as this is more +# amenable to caching analysis infrastructure +function memcpy_alloca_to_loadstore(mod) + dl = datalayout(mod) + for f in functions(mod) + if length(blocks(f)) != 0 + bb = first(blocks(f)) + todel = Set{LLVM.Instruction}() + for alloca in instructions(bb) + if !isa(alloca, LLVM.AllocaInst) + continue + end + todo = Tuple{LLVM.Instruction, LLVM.Value}[(alloca, alloca)] + copy = nothing + legal = true + elty = LLVM.LLVMType(LLVM.API.LLVMGetAllocatedType(alloca)) + lifetimestarts = LLVM.Instruction[] + while length(todo) > 0 + cur, prev = pop!(todo) + if isa(cur, LLVM.AllocaInst) || isa(cur, LLVM.AddrSpaceCastInst) || isa(cur, LLVM.BitCastInst) + for u in LLVM.uses(cur) + u = LLVM.user(u) + push!(todo, (u, cur)) + end + continue + end + if isa(cur, LLVM.CallInst) && isa(LLVM.called_operand(cur), LLVM.Function) + intr = LLVM.API.LLVMGetIntrinsicID(LLVM.called_operand(cur)) + if intr == LLVM.Intrinsic("llvm.lifetime.start").id + push!(lifetimestarts, cur) + continue + end + if intr == LLVM.Intrinsic("llvm.lifetime.end").id + continue + end + if intr == LLVM.Intrinsic("llvm.memcpy").id + sz = 
operands(cur)[3] + if operands(cur)[1] == prev && isa(sz, LLVM.ConstantInt) && convert(Int, sz) == sizeof(dl, elty) + if copy === nothing || copy == cur + copy = cur + continue + end + end + end + end + + # read only insts of arg, don't matter + if isa(cur, LLVM.LoadInst) + continue + end + if isa(cur, LLVM.CallInst) && isa(LLVM.called_operand(cur), LLVM.Function) + legalc = true + for (i, ci) in enumerate(operands(cur)[1:end-1]) + if ci == prev + nocapture = false + readonly = false + for a in collect(parameter_attributes(LLVM.called_operand(cur), i)) + if kind(a) == kind(EnumAttribute("readonly")) + readonly = true + end + if kind(a) == kind(EnumAttribute("readnone")) + readonly = true + end + if kind(a) == kind(EnumAttribute("nocapture")) + nocapture = true + end + end + if !nocapture || !readonly + legalc = false + break + end + end + end + if legalc + continue + end + end + + legal = false + break + end + + if legal && copy !== nothing + B = LLVM.IRBuilder() + position!(B, copy) + dst = operands(copy)[1] + src = operands(copy)[2] + dst0 = bitcast!(B, dst, LLVM.PointerType(LLVM.IntType(8), addrspace(value_type(dst)))) + + dst = bitcast!(B, dst, LLVM.PointerType(elty, addrspace(value_type(dst)))) + src = bitcast!(B, src, LLVM.PointerType(elty, addrspace(value_type(src)))) + + src = load!(B, elty, src) + FT = LLVM.FunctionType(LLVM.VoidType(), [LLVM.IntType(64), value_type(dst0)]) + lifetimestart, _ = get_function!(mod, "llvm.lifetime.start.p0i8", FT) + call!(B, FT, lifetimestart, LLVM.Value[LLVM.ConstantInt(Int64(sizeof(dl, elty))), dst0]) + store!(B, src, dst) + push!(todel, copy) + end + for lt in lifetimestarts + push!(todel, lt) + end + end + for inst in todel + unsafe_delete!(LLVM.parent(inst), inst) + end + end + end +end + # If there is a phi node of a decayed value, Enzyme may need to cache it # Here we force all decayed pointer phis to first addrspace from 10 function nodecayed_phis!(mod::LLVM.Module) diff --git a/src/rules/customrules.jl 
b/src/rules/customrules.jl index 9d9d0e5ef5..7bf62f6d66 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -490,6 +490,12 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, args, activity, overwritten, actives, kwtup = enzyme_custom_setup_args(B, orig, gutils, mi, RealRt, #=reverse=#!forward, isKWCall) RT, needsPrimal, needsShadow, origNeedsPrimal = enzyme_custom_setup_ret(gutils, orig, mi, RealRt) + needsShadowJL = if RT <: Active + false + else + needsShadow + end + alloctx = LLVM.IRBuilder() position!(alloctx, LLVM.BasicBlock(API.EnzymeGradientUtilsAllocationBlock(gutils))) @@ -497,7 +503,7 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, fn = LLVM.parent(curent_bb) world = enzyme_extract_world(fn) - C = EnzymeRules.Config{Bool(needsPrimal), Bool(needsShadow), Int(width), overwritten} + C = EnzymeRules.Config{Bool(needsPrimal), Bool(needsShadowJL), Int(width), overwritten} mode = get_mode(gutils) @@ -774,7 +780,7 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, if width != 1 ShadT = NTuple{Int(width), RealRt} end - ST = EnzymeRules.AugmentedReturn{needsPrimal ? RealRt : Nothing, needsShadow ? ShadT : Nothing, TapeT} + ST = EnzymeRules.AugmentedReturn{needsPrimal ? RealRt : Nothing, needsShadowJL ? ShadT : Nothing, TapeT} if aug_RT != ST if aug_RT <: EnzymeRules.AugmentedReturnFlexShadow if convert(LLVMType, EnzymeRules.shadow_type(aug_RT); allow_boxed=true) != @@ -782,11 +788,11 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, emit_error(B, orig, "Enzyme: Augmented forward pass custom rule " * string(augprimal_TT) * " flex shadow ABI return type mismatch, expected "*string(ST)*" found "* string(aug_RT)) return tapeV end - ST = EnzymeRules.AugmentedReturnFlexShadow{needsPrimal ? RealRt : Nothing, needsShadow ? 
EnzymeRules.shadow_type(aug_RT) : Nothing, TapeT} + ST = EnzymeRules.AugmentedReturnFlexShadow{needsPrimal ? RealRt : Nothing, needsShadowJL ? EnzymeRules.shadow_type(aug_RT) : Nothing, TapeT} end end if aug_RT != ST - ST = EnzymeRules.AugmentedReturn{needsPrimal ? RealRt : Nothing, needsShadow ? ShadT : Nothing, Any} + ST = EnzymeRules.AugmentedReturn{needsPrimal ? RealRt : Nothing, needsShadowJL ? ShadT : Nothing, Any} emit_error(B, orig, "Enzyme: Augmented forward pass custom rule " * string(augprimal_TT) * " return type mismatch, expected "*string(ST)*" found "* string(aug_RT)) return tapeV end @@ -805,24 +811,26 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, idx+=1 end if needsShadow - @assert !isghostty(RealRt) - shadowV = extract_value!(B, res, idx) - if get_return_info(RealRt)[2] !== nothing - dval = invert_pointer(gutils, operands(orig)[1], B) + if needsShadowJL + @assert !isghostty(RealRt) + shadowV = extract_value!(B, res, idx) + if get_return_info(RealRt)[2] !== nothing + dval = invert_pointer(gutils, operands(orig)[1], B) - for idx in 1:width - to_store = (width == 1) ? shadowV : extract_value!(B, shadowV, idx-1) + for idx in 1:width + to_store = (width == 1) ? shadowV : extract_value!(B, shadowV, idx-1) - store_ptr = (width == 1) ? dval : extract_value!(B, dval, idx-1) + store_ptr = (width == 1) ? 
dval : extract_value!(B, dval, idx-1) - store!(B, to_store, store_ptr) + store!(B, to_store, store_ptr) + end + shadowV = C_NULL + else + @assert value_type(shadowV) == shadowType + shadowV = shadowV.ref end - shadowV = C_NULL - else - @assert value_type(shadowV) == shadowType - shadowV = shadowV.ref + idx+=1 end - idx+=1 end if needsTape tapeV = extract_value!(B, res, idx).ref diff --git a/test/rrules.jl b/test/rrules.jl index 17db1cc412..54c7c22802 100644 --- a/test/rrules.jl +++ b/test/rrules.jl @@ -258,4 +258,48 @@ end autodiff(Reverse, Const(cprimal), Active, Duplicated(x, dx), Duplicated(y, dy)) end +function remultr(arg) + arg * arg +end + +function EnzymeRules.augmented_primal(config::ConfigWidth{1}, func::Const{typeof(remultr)}, + ::Type{<:Active}, args::Vararg{Active,N}) where {N} + primal = if EnzymeRules.needs_primal(config) + func.val(args[1].val) + else + nothing + end + return AugmentedReturn(primal, nothing, nothing) +end + +function EnzymeRules.reverse(config::ConfigWidth{1}, func::Const{typeof(remultr)}, + dret::Active, tape, args::Vararg{Active,N}) where {N} + + dargs = ntuple(Val(N)) do i + 7 * args[1].val * dret.val + end + return dargs +end + +function plaquette_sum(U) + p = eltype(U)(0) + + for site in 1:length(U) + p += remultr(@inbounds U[site]) + end + + return p +end + + +@static if VERSION >= v"1.9" +@testset "No caching byref julia" begin + U = Complex{Float64}[3.0 + 4.0im] + dU = Complex{Float64}[0.0] + + autodiff(Reverse, plaquette_sum, Active, Duplicated(U, dU)) + + @test dU[1] ≈ 7 * ( 3.0 + 4.0im ) +end +end end # ReverseRules