Skip to content

Commit

Permalink
Custom rules overwritten fix (#1258)
Browse files Browse the repository at this point in the history
* Custom rules overwritten fix

* caching v1.9+
  • Loading branch information
wsmoses authored Jan 29, 2024
1 parent d4f6400 commit ba58bcb
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 17 deletions.
2 changes: 2 additions & 0 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4762,6 +4762,8 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget};
push!(jlrules, fname)
end

memcpy_alloca_to_loadstore(mod)

adjointf, augmented_primalf, TapeType = enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, abiwrap, modifiedBetween, returnPrimal, jlrules, expectedTapeType, loweredArgs, boxedArgs)
toremove = []
# Inline the wrapper
Expand Down
116 changes: 116 additions & 0 deletions src/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,122 @@ function source_elem(v)
end
end


## given code like
# % a = alloca
# ...
# memref(cast(%a), %b, constant size == sizeof(a))
#
# turn this into load/store, as this is more
# amenable to caching analysis infrastructure
function memcpy_alloca_to_loadstore(mod)
dl = datalayout(mod)
for f in functions(mod)
if length(blocks(f)) != 0
bb = first(blocks(f))
todel = Set{LLVM.Instruction}()
for alloca in instructions(bb)
if !isa(alloca, LLVM.AllocaInst)
continue
end
todo = Tuple{LLVM.Instruction, LLVM.Value}[(alloca, alloca)]
copy = nothing
legal = true
elty = LLVM.LLVMType(LLVM.API.LLVMGetAllocatedType(alloca))
lifetimestarts = LLVM.Instruction[]
while length(todo) > 0
cur, prev = pop!(todo)
if isa(cur, LLVM.AllocaInst) || isa(cur, LLVM.AddrSpaceCastInst) || isa(cur, LLVM.BitCastInst)
for u in LLVM.uses(cur)
u = LLVM.user(u)
push!(todo, (u, cur))
end
continue
end
if isa(cur, LLVM.CallInst) && isa(LLVM.called_operand(cur), LLVM.Function)
intr = LLVM.API.LLVMGetIntrinsicID(LLVM.called_operand(cur))
if intr == LLVM.Intrinsic("llvm.lifetime.start").id
push!(lifetimestarts, cur)
continue
end
if intr == LLVM.Intrinsic("llvm.lifetime.end").id
continue
end
if intr == LLVM.Intrinsic("llvm.memcpy").id
sz = operands(cur)[3]
if operands(cur)[1] == prev && isa(sz, LLVM.ConstantInt) && convert(Int, sz) == sizeof(dl, elty)
if copy === nothing || copy == cur
copy = cur
continue
end
end
end
end

# read only insts of arg, don't matter
if isa(cur, LLVM.LoadInst)
continue
end
if isa(cur, LLVM.CallInst) && isa(LLVM.called_operand(cur), LLVM.Function)
legalc = true
for (i, ci) in enumerate(operands(cur)[1:end-1])
if ci == prev
nocapture = false
readonly = false
for a in collect(parameter_attributes(LLVM.called_operand(cur), i))
if kind(a) == kind(EnumAttribute("readonly"))
readonly = true
end
if kind(a) == kind(EnumAttribute("readnone"))
readonly = true
end
if kind(a) == kind(EnumAttribute("nocapture"))
nocapture = true
end
end
if !nocapture || !readonly
legalc = false
break
end
end
end
if legalc
continue
end
end

legal = false
break
end

if legal && copy !== nothing
B = LLVM.IRBuilder()
position!(B, copy)
dst = operands(copy)[1]
src = operands(copy)[2]
dst0 = bitcast!(B, dst, LLVM.PointerType(LLVM.IntType(8), addrspace(value_type(dst))))

dst = bitcast!(B, dst, LLVM.PointerType(elty, addrspace(value_type(dst))))
src = bitcast!(B, src, LLVM.PointerType(elty, addrspace(value_type(src))))

src = load!(B, elty, src)
FT = LLVM.FunctionType(LLVM.VoidType(), [LLVM.IntType(64), value_type(dst0)])
lifetimestart, _ = get_function!(mod, "llvm.lifetime.start.p0i8", FT)
call!(B, FT, lifetimestart, LLVM.Value[LLVM.ConstantInt(Int64(sizeof(dl, elty))), dst0])
store!(B, src, dst)
push!(todel, copy)
end
for lt in lifetimestarts
push!(todel, lt)
end
end
for inst in todel
unsafe_delete!(LLVM.parent(inst), inst)
end
end
end
end

# If there is a phi node of a decayed value, Enzyme may need to cache it
# Here we force all decayed pointer phis to first addrspace from 10
function nodecayed_phis!(mod::LLVM.Module)
Expand Down
42 changes: 25 additions & 17 deletions src/rules/customrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -490,14 +490,20 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils,
args, activity, overwritten, actives, kwtup = enzyme_custom_setup_args(B, orig, gutils, mi, RealRt, #=reverse=#!forward, isKWCall)
RT, needsPrimal, needsShadow, origNeedsPrimal = enzyme_custom_setup_ret(gutils, orig, mi, RealRt)

needsShadowJL = if RT <: Active
false
else
needsShadow
end

alloctx = LLVM.IRBuilder()
position!(alloctx, LLVM.BasicBlock(API.EnzymeGradientUtilsAllocationBlock(gutils)))

curent_bb = position(B)
fn = LLVM.parent(curent_bb)
world = enzyme_extract_world(fn)

C = EnzymeRules.Config{Bool(needsPrimal), Bool(needsShadow), Int(width), overwritten}
C = EnzymeRules.Config{Bool(needsPrimal), Bool(needsShadowJL), Int(width), overwritten}

mode = get_mode(gutils)

Expand Down Expand Up @@ -774,19 +780,19 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils,
if width != 1
ShadT = NTuple{Int(width), RealRt}
end
ST = EnzymeRules.AugmentedReturn{needsPrimal ? RealRt : Nothing, needsShadow ? ShadT : Nothing, TapeT}
ST = EnzymeRules.AugmentedReturn{needsPrimal ? RealRt : Nothing, needsShadowJL ? ShadT : Nothing, TapeT}
if aug_RT != ST
if aug_RT <: EnzymeRules.AugmentedReturnFlexShadow
if convert(LLVMType, EnzymeRules.shadow_type(aug_RT); allow_boxed=true) !=
convert(LLVMType, EnzymeRules.shadow_type(ST) ; allow_boxed=true)
emit_error(B, orig, "Enzyme: Augmented forward pass custom rule " * string(augprimal_TT) * " flex shadow ABI return type mismatch, expected "*string(ST)*" found "* string(aug_RT))
return tapeV
end
ST = EnzymeRules.AugmentedReturnFlexShadow{needsPrimal ? RealRt : Nothing, needsShadow ? EnzymeRules.shadow_type(aug_RT) : Nothing, TapeT}
ST = EnzymeRules.AugmentedReturnFlexShadow{needsPrimal ? RealRt : Nothing, needsShadowJL ? EnzymeRules.shadow_type(aug_RT) : Nothing, TapeT}
end
end
if aug_RT != ST
ST = EnzymeRules.AugmentedReturn{needsPrimal ? RealRt : Nothing, needsShadow ? ShadT : Nothing, Any}
ST = EnzymeRules.AugmentedReturn{needsPrimal ? RealRt : Nothing, needsShadowJL ? ShadT : Nothing, Any}
emit_error(B, orig, "Enzyme: Augmented forward pass custom rule " * string(augprimal_TT) * " return type mismatch, expected "*string(ST)*" found "* string(aug_RT))
return tapeV
end
Expand All @@ -805,24 +811,26 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils,
idx+=1
end
if needsShadow
@assert !isghostty(RealRt)
shadowV = extract_value!(B, res, idx)
if get_return_info(RealRt)[2] !== nothing
dval = invert_pointer(gutils, operands(orig)[1], B)
if needsShadowJL
@assert !isghostty(RealRt)
shadowV = extract_value!(B, res, idx)
if get_return_info(RealRt)[2] !== nothing
dval = invert_pointer(gutils, operands(orig)[1], B)

for idx in 1:width
to_store = (width == 1) ? shadowV : extract_value!(B, shadowV, idx-1)
for idx in 1:width
to_store = (width == 1) ? shadowV : extract_value!(B, shadowV, idx-1)

store_ptr = (width == 1) ? dval : extract_value!(B, dval, idx-1)
store_ptr = (width == 1) ? dval : extract_value!(B, dval, idx-1)

store!(B, to_store, store_ptr)
store!(B, to_store, store_ptr)
end
shadowV = C_NULL
else
@assert value_type(shadowV) == shadowType
shadowV = shadowV.ref
end
shadowV = C_NULL
else
@assert value_type(shadowV) == shadowType
shadowV = shadowV.ref
idx+=1
end
idx+=1
end
if needsTape
tapeV = extract_value!(B, res, idx).ref
Expand Down
44 changes: 44 additions & 0 deletions test/rrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -258,4 +258,48 @@ end
autodiff(Reverse, Const(cprimal), Active, Duplicated(x, dx), Duplicated(y, dy))
end

function remultr(arg)
arg * arg
end

function EnzymeRules.augmented_primal(config::ConfigWidth{1}, func::Const{typeof(remultr)},
::Type{<:Active}, args::Vararg{Active,N}) where {N}
primal = if EnzymeRules.needs_primal(config)
func.val(args[1].val)
else
nothing
end
return AugmentedReturn(primal, nothing, nothing)
end

function EnzymeRules.reverse(config::ConfigWidth{1}, func::Const{typeof(remultr)},
dret::Active, tape, args::Vararg{Active,N}) where {N}

dargs = ntuple(Val(N)) do i
7 * args[1].val * dret.val
end
return dargs
end

function plaquette_sum(U)
p = eltype(U)(0)

for site in 1:length(U)
p += remultr(@inbounds U[site])
end

return p
end


@static if VERSION >= v"1.9"
@testset "No caching byref julia" begin
U = Complex{Float64}[3.0 + 4.0im]
dU = Complex{Float64}[0.0]

autodiff(Reverse, plaquette_sum, Active, Duplicated(U, dU))

@test dU[1] 7 * ( 3.0 + 4.0im )
end
end
end # ReverseRules

2 comments on commit ba58bcb

@wsmoses
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/99733

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.11.14 -m "<description of version>" ba58bcbb1a64cb131335862371e3059c147884ee
git push origin v0.11.14

Please sign in to comment.