diff --git a/src/compiler.jl b/src/compiler.jl index 61eeb81105..1c01624cf9 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -594,19 +594,85 @@ function julia_undef_value_for_type( throw(AssertionError("Unknown type to val: $(Ty)")) end -function shadow_alloc_rewrite(V::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradientUtilsRef) +function create_recursive_stores(B::LLVM.IRBuilder, @nospecialize(Ty::DataType), @nospecialize(prev::LLVM.Value))::Nothing + if Base.datatype_pointerfree(Ty) + return + end + + isboxed_ref = Ref{Bool}() + LLVMType = LLVM.LLVMType(ccall(:jl_type_to_llvm, LLVM.API.LLVMTypeRef, + (Any, LLVM.Context, Ptr{Bool}), Ty, LLVM.context(), isboxed_ref)) + + if !isboxed_ref[] + zeroAll = false + T_int64 = LLVM.Int64Type() + prev = bitcast!(B, prev, LLVM.PointerType(LLVMType, addrspace(value_type(prev)))) + prev = addrspacecast!(B, prev, LLVM.PointerType(LLVMType, Derived)) + zero_single_allocation(B, Ty, LLVMType, prev, zeroAll, LLVM.ConstantInt(T_int64, 0); atomic=true) + else + @assert fieldcount(Ty) != 0 + + T_jlvalue = LLVM.StructType(LLVM.LLVMType[]) + T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) + + T_int8 = LLVM.Int8Type() + T_int64 = LLVM.Int64Type() + + T_pint8 = LLVM.PointerType(T_int8) + + prev2 = bitcast!(B, prev, LLVM.PointerType(T_int8, addrspace(value_type(prev)))) + + for i in 1:fieldcount(Ty) + Ty2 = fieldtype(Ty, i) + off = fieldoffset(Ty, i) + + if Ty2 <: DataType && Base.datatype_pointerfree(Ty2) + continue + end + + prev3 = inbounds_gep!( + B, + T_int8, + prev2, + LLVM.Value[LLVM.ConstantInt(Int64(off))], + ) + + fallback = Base.isabstracttype(Ty2) || Ty2 isa Union + + @static if VERSION < v"1.11-" + fallback |= Ty2 <: Array + else + fallback |= Ty2 <: GenericMemory + end + + if fallback + Ty2 = Any + zeroAll = false + prev3 = bitcast!(B, prev3, LLVM.PointerType(T_prjlvalue, addrspace(value_type(prev3)))) + if addrspace(value_type(prev3)) != Derived + prev3 = addrspacecast!(B, prev3, LLVM.PointerType(T_prjlvalue, Derived)) + end + zero_single_allocation(B, Ty2, T_prjlvalue, prev3, zeroAll, LLVM.ConstantInt(T_int64, 0); atomic=true) + else + create_recursive_stores(B, Ty2, prev3) + end + end + end +end + +function shadow_alloc_rewrite(V::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradientUtilsRef, Orig::LLVM.API.LLVMValueRef, idx::UInt64, prev::API.LLVMValueRef) V = LLVM.CallInst(V) gutils = GradientUtils(gutils) mode = get_mode(gutils) + has, Ty, byref = abs_typeof(V) + if !has + throw(AssertionError("$(string(fn))\n Allocation could not have its type statically determined $(string(V))")) + end if mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient || mode == API.DEM_ReverseModeCombined fn = LLVM.parent(LLVM.parent(V)) world = enzyme_extract_world(fn) - has, Ty, byref = abs_typeof(V) - if !has - throw(AssertionError("$(string(fn))\n Allocation could not have its type statically determined $(string(V))")) - end rt = active_reg_inner(Ty, (), world) if rt == ActiveState || rt == MixedState B = LLVM.IRBuilder() @@ -614,6 +680,26 @@ function shadow_alloc_rewrite(V::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradie operands(V)[3] = unsafe_to_llvm(B, Base.RefValue{Ty}) end end + + if mode == API.DEM_ForwardMode + # Zero any jlvalue_t inner elements of preceeding allocation. + # Specifically in forward mode, you will first run the original allocation, + # then all shadow allocations. These allocations will thus all run before + # any value may store into them. For example, as follows: + # %orig = julia.gc_alloc(...) + # %"orig'" = julia.gcalloc(...) + # store orig[0] = jlvaluet + # store "orig'"[0] = jlvaluet' + # As a result, by the time of the subsequent GC allocation, the memory in the preceeding + # allocation might be undefined, and trigger a GC error. To avoid this, + # we will explicitly zero the GC'd fields of the previous allocation. + prev = LLVM.Instruction(prev) + B = LLVM.IRBuilder() + position!(B, LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(prev))) + + create_recursive_stores(B, Ty, prev) + end + nothing end @@ -671,7 +757,7 @@ function zero_allocation(B::LLVM.API.LLVMBuilderRef, LLVMType::LLVM.API.LLVMType return nothing end -function zero_single_allocation(builder::LLVM.IRBuilder, @nospecialize(jlType::DataType), @nospecialize(LLVMType::LLVM.LLVMType), @nospecialize(nobj::LLVM.Value), zeroAll::Bool, @nospecialize(idx::LLVM.Value)) +function zero_single_allocation(builder::LLVM.IRBuilder, @nospecialize(jlType::DataType), @nospecialize(LLVMType::LLVM.LLVMType), @nospecialize(nobj::LLVM.Value), zeroAll::Bool, @nospecialize(idx::LLVM.Value); write_barrier=false, atomic=false) T_jlvalue = LLVM.StructType(LLVM.LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) T_prjlvalue_UT = LLVM.PointerType(T_jlvalue) @@ -682,6 +768,7 @@ function zero_single_allocation(builder::LLVM.IRBuilder, @nospecialize(jlType::D jlType, )] + addedvals = LLVM.Value[] while length(todo) != 0 path, ty, jlty = popfirst!(todo) if isa(ty, LLVM.PointerType) @@ -689,12 +776,18 @@ function zero_single_allocation(builder::LLVM.IRBuilder, @nospecialize(jlType::D loc = gep!(builder, LLVMType, nobj, path) mod = LLVM.parent(LLVM.parent(Base.position(builder))) fill_val = unsafe_nothing_to_llvm(mod) + push!(addedvals, fill_val) loc = bitcast!( builder, loc, LLVM.PointerType(T_prjlvalue, addrspace(value_type(loc))), ) - store!(builder, fill_val, loc) + st = store!(builder, fill_val, loc) + if atomic + ordering!(st, LLVM.API.LLVMAtomicOrderingRelease) + syncscope!(st, LLVM.SyncScope("singlethread")) + metadata(st)["enzymejl_atomicgc"] = LLVM.MDNode(LLVM.Metadata[]) + end elseif zeroAll loc = gep!(builder, LLVMType, nobj, path) store!(builder, LLVM.null(ty), loc) @@ -741,6 +834,10 @@ function zero_single_allocation(builder::LLVM.IRBuilder, @nospecialize(jlType::D continue end end + if length(addedvals) != 0 && write_barrier + pushfirst!(addedvals, get_base_and_offset(nobj; offsetAllowed=false, inttoptr=false)[1]) + emit_writebarrier!(builder, addedvals) + end return nothing end @@ -1126,7 +1223,7 @@ function __init__() @cfunction( shadow_alloc_rewrite, Cvoid, - (LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef) + (LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, LLVM.API.LLVMValueRef, UInt64, LLVM.API.LLVMValueRef) ) ) register_alloc_rules() diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index f9769881a4..d3c4375abf 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -652,6 +652,7 @@ function addOptimizationPasses!(pm::LLVM.ModulePassManager, tm::LLVM.TargetMachi jl_inst_simplify!(pm) jump_threading!(pm) dead_store_elimination!(pm) + add!(pm, FunctionPass("SafeAtomicToRegularStore", safe_atomic_to_regular_store!)) # More dead allocation (store) deletion before loop optimization # consider removing this: diff --git a/src/llvm/transforms.jl b/src/llvm/transforms.jl index 2f9c61c0b4..3ef341706d 100644 --- a/src/llvm/transforms.jl +++ b/src/llvm/transforms.jl @@ -2399,3 +2399,21 @@ function removeDeadArgs!(mod::LLVM.Module, tm::LLVM.TargetMachine) eraseInst(mod, func) end +function safe_atomic_to_regular_store!(f::LLVM.Function) + changed = false + for bb in blocks(f), inst in instructions(bb) + if isa(inst, LLVM.StoreInst) + continue + end + if !haskey(metadata(inst), "enzymejl_atomicgc") + continue + end + Base.delete!(metadata(inst), "enzymejl_atomicgc") + syncscope!(inst, LLVM.SyncScope("system")) + ordering!(inst, LLVM.API.LLVMAtomicOrderingNotAtomic) + changed = true + end + return changed +end + +