Skip to content

Commit

Permalink
Custom rules overwritten fix
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jan 29, 2024
1 parent 3095a4f commit 9913e52
Show file tree
Hide file tree
Showing 4 changed files with 185 additions and 17 deletions.
2 changes: 2 additions & 0 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4762,6 +4762,8 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget};
push!(jlrules, fname)
end

memcpy_alloca_to_loadstore(mod)

adjointf, augmented_primalf, TapeType = enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, abiwrap, modifiedBetween, returnPrimal, jlrules, expectedTapeType, loweredArgs, boxedArgs)
toremove = []
# Inline the wrapper
Expand Down
116 changes: 116 additions & 0 deletions src/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,122 @@ function source_elem(v)
end
end


## Given code like
#   %a = alloca
#   ...
#   memcpy(cast(%a), %b, constant size == sizeof(a))
#
# turn the memcpy into an explicit load/store pair, as this is more
# amenable to caching analysis infrastructure.
#
# For every function in `mod` that has at least one basic block, each alloca found
# in the first block is examined by walking its users (looking through bitcasts and
# address-space casts). The alloca qualifies when its only users are:
#   * `llvm.lifetime.start` / `llvm.lifetime.end` calls,
#   * loads,
#   * calls to a known function where every matching argument slot is `nocapture`
#     and `readonly`/`readnone`,
#   * exactly one `llvm.memcpy` whose destination is the traced pointer and whose
#     constant size equals the size of the allocated type.
# Any other user disqualifies the alloca.
function memcpy_alloca_to_loadstore(mod)
    layout = datalayout(mod)
    for f in functions(mod)
        # Nothing to scan in a function without blocks.
        length(blocks(f)) == 0 && continue

        entry = first(blocks(f))
        # Deletions are deferred so the instruction iterators stay valid.
        doomed = Set{LLVM.Instruction}()

        for alloca in instructions(entry)
            alloca isa LLVM.AllocaInst || continue

            # Worklist of (user, value through which the alloca reached that user).
            worklist = Tuple{LLVM.Instruction, LLVM.Value}[(alloca, alloca)]
            thecopy = nothing
            rewritable = true
            elty = LLVM.LLVMType(LLVM.API.LLVMGetAllocatedType(alloca))
            starts = LLVM.Instruction[]

            while !isempty(worklist)
                cur, prev = pop!(worklist)

                # Pointer-preserving instructions: keep following their users.
                if cur isa Union{LLVM.AllocaInst, LLVM.AddrSpaceCastInst, LLVM.BitCastInst}
                    for use in LLVM.uses(cur)
                        push!(worklist, (LLVM.user(use), cur))
                    end
                    continue
                end

                directcall =
                    cur isa LLVM.CallInst && LLVM.called_value(cur) isa LLVM.Function

                if directcall
                    intr = LLVM.API.LLVMGetIntrinsicID(LLVM.called_value(cur))
                    if intr == LLVM.Intrinsic("llvm.lifetime.start").id
                        push!(starts, cur)
                        continue
                    elseif intr == LLVM.Intrinsic("llvm.lifetime.end").id
                        continue
                    elseif intr == LLVM.Intrinsic("llvm.memcpy").id
                        sz = operands(cur)[3]
                        # Accept only a full-size copy *into* the traced pointer, and
                        # only one such copy per alloca.
                        if operands(cur)[1] == prev &&
                           sz isa LLVM.ConstantInt &&
                           convert(Int, sz) == sizeof(layout, elty) &&
                           (thecopy === nothing || thecopy == cur)
                            thecopy = cur
                            continue
                        end
                    end
                end

                # Read-only uses of the pointer do not matter.
                if cur isa LLVM.LoadInst
                    continue
                end

                if directcall
                    callee = LLVM.called_value(cur)
                    harmless = true
                    # The last operand is skipped (presumably the callee itself).
                    for (i, arg) in enumerate(operands(cur)[1:end-1])
                        arg == prev || continue
                        nocapture = false
                        readonly = false
                        for a in collect(parameter_attributes(callee, i))
                            k = kind(a)
                            if k == kind(EnumAttribute("readonly")) ||
                               k == kind(EnumAttribute("readnone"))
                                readonly = true
                            end
                            if k == kind(EnumAttribute("nocapture"))
                                nocapture = true
                            end
                        end
                        if !(nocapture && readonly)
                            harmless = false
                            break
                        end
                    end
                    if harmless
                        continue
                    end
                end

                # Any other user may write to or capture the pointer.
                rewritable = false
                break
            end

            if rewritable && thecopy !== nothing
                builder = LLVM.IRBuilder()
                position!(builder, thecopy)
                rawdst = operands(thecopy)[1]
                rawsrc = operands(thecopy)[2]

                # i8* view of the destination, used for the lifetime marker.
                bytedst = bitcast!(
                    builder,
                    rawdst,
                    LLVM.PointerType(LLVM.IntType(8), addrspace(value_type(rawdst))),
                )
                typeddst = bitcast!(
                    builder,
                    rawdst,
                    LLVM.PointerType(elty, addrspace(value_type(rawdst))),
                )
                typedsrc = bitcast!(
                    builder,
                    rawsrc,
                    LLVM.PointerType(elty, addrspace(value_type(rawsrc))),
                )

                loaded = load!(builder, elty, typedsrc)
                FT = LLVM.FunctionType(
                    LLVM.VoidType(),
                    [LLVM.IntType(64), value_type(bytedst)],
                )
                lifetimestart, _ = get_function!(mod, "llvm.lifetime.start.p0i8", FT)
                call!(
                    builder,
                    FT,
                    lifetimestart,
                    LLVM.Value[LLVM.ConstantInt(Int64(sizeof(layout, elty))), bytedst],
                )
                store!(builder, loaded, typeddst)
                push!(doomed, thecopy)
            end

            # NOTE(review): the recorded lifetime.start calls are queued for deletion
            # whether or not the memcpy was rewritten -- confirm this is intended.
            union!(doomed, starts)
        end

        for dead in doomed
            unsafe_delete!(LLVM.parent(dead), dead)
        end
    end
end

# If there is a phi node of a decayed value, Enzyme may need to cache it
# Here we force all decayed pointer phis to first addrspace from 10
function nodecayed_phis!(mod::LLVM.Module)
Expand Down
42 changes: 25 additions & 17 deletions src/rules/customrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -490,14 +490,20 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils,
args, activity, overwritten, actives, kwtup = enzyme_custom_setup_args(B, orig, gutils, mi, RealRt, #=reverse=#!forward, isKWCall)
RT, needsPrimal, needsShadow, origNeedsPrimal = enzyme_custom_setup_ret(gutils, orig, mi, RealRt)

needsShadowJL = if RT <: Active
false
else
needsShadow
end

alloctx = LLVM.IRBuilder()
position!(alloctx, LLVM.BasicBlock(API.EnzymeGradientUtilsAllocationBlock(gutils)))

curent_bb = position(B)
fn = LLVM.parent(curent_bb)
world = enzyme_extract_world(fn)

C = EnzymeRules.Config{Bool(needsPrimal), Bool(needsShadow), Int(width), overwritten}
C = EnzymeRules.Config{Bool(needsPrimal), Bool(needsShadowJL), Int(width), overwritten}

mode = get_mode(gutils)

Expand Down Expand Up @@ -774,19 +780,19 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils,
if width != 1
ShadT = NTuple{Int(width), RealRt}
end
ST = EnzymeRules.AugmentedReturn{needsPrimal ? RealRt : Nothing, needsShadow ? ShadT : Nothing, TapeT}
ST = EnzymeRules.AugmentedReturn{needsPrimal ? RealRt : Nothing, needsShadowJL ? ShadT : Nothing, TapeT}
if aug_RT != ST
if aug_RT <: EnzymeRules.AugmentedReturnFlexShadow
if convert(LLVMType, EnzymeRules.shadow_type(aug_RT); allow_boxed=true) !=
convert(LLVMType, EnzymeRules.shadow_type(ST) ; allow_boxed=true)
emit_error(B, orig, "Enzyme: Augmented forward pass custom rule " * string(augprimal_TT) * " flex shadow ABI return type mismatch, expected "*string(ST)*" found "* string(aug_RT))
return tapeV
end
ST = EnzymeRules.AugmentedReturnFlexShadow{needsPrimal ? RealRt : Nothing, needsShadow ? EnzymeRules.shadow_type(aug_RT) : Nothing, TapeT}
ST = EnzymeRules.AugmentedReturnFlexShadow{needsPrimal ? RealRt : Nothing, needsShadowJL ? EnzymeRules.shadow_type(aug_RT) : Nothing, TapeT}
end
end
if aug_RT != ST
ST = EnzymeRules.AugmentedReturn{needsPrimal ? RealRt : Nothing, needsShadow ? ShadT : Nothing, Any}
ST = EnzymeRules.AugmentedReturn{needsPrimal ? RealRt : Nothing, needsShadowJL ? ShadT : Nothing, Any}
emit_error(B, orig, "Enzyme: Augmented forward pass custom rule " * string(augprimal_TT) * " return type mismatch, expected "*string(ST)*" found "* string(aug_RT))
return tapeV
end
Expand All @@ -805,24 +811,26 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils,
idx+=1
end
if needsShadow
@assert !isghostty(RealRt)
shadowV = extract_value!(B, res, idx)
if get_return_info(RealRt)[2] !== nothing
dval = invert_pointer(gutils, operands(orig)[1], B)
if needsShadowJL
@assert !isghostty(RealRt)
shadowV = extract_value!(B, res, idx)
if get_return_info(RealRt)[2] !== nothing
dval = invert_pointer(gutils, operands(orig)[1], B)

for idx in 1:width
to_store = (width == 1) ? shadowV : extract_value!(B, shadowV, idx-1)
for idx in 1:width
to_store = (width == 1) ? shadowV : extract_value!(B, shadowV, idx-1)

store_ptr = (width == 1) ? dval : extract_value!(B, dval, idx-1)
store_ptr = (width == 1) ? dval : extract_value!(B, dval, idx-1)

store!(B, to_store, store_ptr)
store!(B, to_store, store_ptr)
end
shadowV = C_NULL
else
@assert value_type(shadowV) == shadowType
shadowV = shadowV.ref
end
shadowV = C_NULL
else
@assert value_type(shadowV) == shadowType
shadowV = shadowV.ref
idx+=1
end
idx+=1
end
if needsTape
tapeV = extract_value!(B, res, idx).ref
Expand Down
42 changes: 42 additions & 0 deletions test/rrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -258,4 +258,46 @@ end
autodiff(Reverse, Const(cprimal), Active, Duplicated(x, dx), Duplicated(y, dy))
end

# Primal used by the custom-rule regression test below: multiplies `arg` by itself.
remultr(arg) = arg * arg

# Custom augmented forward pass for `remultr` with an `Active` return and `Active`
# arguments. The primal is evaluated only when the config asks for it; the rule
# supplies neither a shadow nor a tape.
function EnzymeRules.augmented_primal(
    config::ConfigWidth{1},
    func::Const{typeof(remultr)},
    ::Type{<:Active},
    args::Vararg{Active,N},
) where {N}
    primal = EnzymeRules.needs_primal(config) ? func.val(args[1].val) : nothing
    return AugmentedReturn(primal, nothing, nothing)
end

# Custom reverse pass for `remultr`. Every one of the `N` argument derivatives is
# set to `7 * args[1].val * dret.val`; the factor 7 is not the true derivative,
# presumably so the test can tell that this rule was the one applied.
function EnzymeRules.reverse(
    config::ConfigWidth{1},
    func::Const{typeof(remultr)},
    dret::Active,
    tape,
    args::Vararg{Active,N},
) where {N}
    return ntuple(_ -> 7 * args[1].val * dret.val, Val(N))
end

# Sum `remultr` over every element of `U`, starting from a zero of `U`'s element type.
# NOTE(review): the index loop and the `@inbounds U[...]` argument are kept exactly as
# written; presumably this shape is what reproduces the regression -- do not modernise
# without confirming.
function plaquette_sum(U)
    total = eltype(U)(0)

    for k in 1:length(U)
        total += remultr(@inbounds U[k])
    end

    return total
end

@testset "No caching byref julia" begin
    U = Complex{Float64}[3.0 + 4.0im]
    dU = Complex{Float64}[0.0]

    autodiff(Reverse, plaquette_sum, Active, Duplicated(U, dU))

    # The custom reverse rule for `remultr` yields 7 * arg * dret, so the test
    # expects the accumulated shadow to equal 7 * U[1]. The comparison operator
    # was missing here, which made this an invalid `@test` call.
    @test dU[1] ≈ 7 * (3.0 + 4.0im)
end

end # ReverseRules

0 comments on commit 9913e52

Please sign in to comment.