diff --git a/src/compiler.jl b/src/compiler.jl index 18e43cbdbc..8320d4ee1a 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1,10 +1,11 @@ module Compiler import ..Enzyme -import Enzyme: Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, +import Enzyme: Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, + BatchDuplicatedNoNeed, BatchDuplicatedFunc, Annotation, guess_activity, eltype, - API, TypeTree, typetree, only!, shift!, data0!, merge!, to_md, + API, TypeTree, typetree, TypeTreeTable, only!, shift!, data0!, merge!, to_md, TypeAnalysis, FnTypeInfo, Logic, allocatedinline, ismutabletype using Enzyme @@ -2933,6 +2934,7 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr push!(args_known_values, API.IntList()) end + seen = TypeTreeTable() for (i, T) in enumerate(TT.parameters) source_typ = eltype(T) if isghostty(source_typ) || Core.Compiler.isconstType(source_typ) @@ -2958,8 +2960,9 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr else error("illegal annotation type") end - typeTree = typetree(source_typ, ctx, dl) + typeTree = typetree(source_typ, ctx, dl, seen) if isboxed + typeTree = copy(typeTree) merge!(typeTree, TypeTree(API.DT_Pointer, ctx)) only!(typeTree, -1) end @@ -3060,7 +3063,9 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr logic = Logic() TA = TypeAnalysis(logic, rules) - retTT = typetree((!isa(actualRetType, Union) && GPUCompiler.deserves_retbox(actualRetType)) ? Ptr{actualRetType} : actualRetType, ctx, dl) + retT = (!isa(actualRetType, Union) && GPUCompiler.deserves_retbox(actualRetType)) ? + Ptr{actualRetType} : actualRetType + retTT = typetree(retT, ctx, dl, seen) typeInfo = FnTypeInfo(retTT, args_typeInfo, args_known_values) @@ -4037,6 +4042,7 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function push!(parameter_attributes(wrapper_f, 1), EnumAttribute("swiftself")) end + seen = TypeTreeTable() # emit IR performing the "conversions" let builder = IRBuilder() toErase = LLVM.CallInst[] @@ -4100,7 +4106,8 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function if RetActivity <: Const metadata(sretPtr)["enzyme_inactive"] = MDNode(LLVM.Metadata[]) end - metadata(sretPtr)["enzyme_type"] = to_md(typetree(Ptr{actualRetType}, ctx, dl), ctx) + metadata(sretPtr)["enzyme_type"] = to_md(typetree(Ptr{actualRetType}, ctx, + dl, seen), ctx) push!(wrapper_args, sretPtr) end if returnRoots && !in(1, parmsRemoved) @@ -4128,7 +4135,8 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function metadata(ptr)["enzyme_inactive"] = MDNode(LLVM.Metadata[]) end ctx = LLVM.context(entry_f) - metadata(ptr)["enzyme_type"] = to_md(typetree(Ptr{arg.typ}, ctx, dl), ctx) + metadata(ptr)["enzyme_type"] = to_md(typetree(Ptr{arg.typ}, ctx, dl, seen), + ctx) if LLVM.addrspace(ty) != 0 ptr = addrspacecast!(builder, ptr, ty) end diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index 7bf62f6d66..7ec09e2c1d 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -850,9 +850,10 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, idx = 0 dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(orig))))) - Tys2 = (eltype(A) for A in activity[2+isKWCall:end] if A <: Active) + Tys2 = (eltype(A) for A in activity[(2 + isKWCall):end] if A <: Active) + seen = TypeTreeTable() for (v, Ty) in zip(actives, Tys2) - TT = typetree(Ty, ctx, dl) + TT = typetree(Ty, ctx, dl, seen) Typ = C_NULL ext = extract_value!(B, res, idx) shadowVType = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(v))) diff --git a/src/rules/typerules.jl b/src/rules/typerules.jl index cc5cea622e..4730db8654 100644 --- a/src/rules/typerules.jl +++ b/src/rules/typerules.jl @@ -17,7 +17,7 @@ function alloc_obj_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.CT ctx = LLVM.context(LLVM.Value(val)) dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(inst))))) - rest = typetree(typ, ctx, dl) + rest = typetree(typ, ctx, dl) # copy unecessary since only user of `rest` only!(rest, -1) API.EnzymeMergeTypeTree(ret, rest) return UInt8(false) @@ -107,7 +107,7 @@ function alloc_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.CTypeT ctx = LLVM.context(LLVM.Value(val)) dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(inst))))) - rest = typetree(typ, ctx, dl) + rest = typetree(typ, ctx, dl) # copy unecessary since only user of `rest` only!(rest, -1) API.EnzymeMergeTypeTree(ret, rest) diff --git a/src/typetree.jl b/src/typetree.jl index 13578fa82b..50cd399cc0 100644 --- a/src/typetree.jl +++ b/src/typetree.jl @@ -60,73 +60,102 @@ function merge!(dst::TypeTree, src::TypeTree; consume=true) end function to_md(tt::TypeTree, ctx) - return LLVM.Metadata(LLVM.MetadataAsValue(ccall((:EnzymeTypeTreeToMD, API.libEnzyme), LLVM.API.LLVMValueRef, (API.CTypeTreeRef,LLVM.API.LLVMContextRef), tt, ctx))) + return LLVM.Metadata(LLVM.MetadataAsValue(ccall((:EnzymeTypeTreeToMD, API.libEnzyme), + LLVM.API.LLVMValueRef, + (API.CTypeTreeRef, + LLVM.API.LLVMContextRef), tt, ctx))) end -function typetree(::Type{T}, ctx, dl, seen=nothing) where T <: Integer +const TypeTreeTable = IdDict{Any,Union{Nothing,TypeTree}} + +""" + function typetree(T, ctx, dl, seen=TypeTreeTable()) + +Construct a Enzyme typetree from a Julia type. + +!!! warning + When using a memoized lookup by providing `seen` across multiple calls to typtree + the user must call `copy` on the returned value before mutating it. +""" +function typetree(@nospecialize(T), ctx, dl, seen=TypeTreeTable()) + if haskey(seen, T) + tree = seen[T] + if tree === nothing + return TypeTree() # stop recursion, but don't cache + end + else + seen[T] = nothing # place recursion marker + tree = typetree_inner(T, ctx, dl, seen) + seen[T] = tree + end + return tree::TypeTree +end + +function typetree_inner(::Type{T}, ctx, dl, seen::TypeTreeTable) where {T<:Integer} return TypeTree(API.DT_Integer, -1, ctx) end -function typetree(::Type{Char}, ctx, dl, seen=nothing) +function typetree_inner(::Type{Char}, ctx, dl, seen::TypeTreeTable) return TypeTree(API.DT_Integer, -1, ctx) end -function typetree(::Type{Float16}, ctx, dl, seen=nothing) +function typetree_inner(::Type{Float16}, ctx, dl, seen::TypeTreeTable) return TypeTree(API.DT_Half, -1, ctx) end -function typetree(::Type{Float32}, ctx, dl, seen=nothing) +function typetree_inner(::Type{Float32}, ctx, dl, seen::TypeTreeTable) return TypeTree(API.DT_Float, -1, ctx) end -function typetree(::Type{Float64}, ctx, dl, seen=nothing) +function typetree_inner(::Type{Float64}, ctx, dl, seen::TypeTreeTable) return TypeTree(API.DT_Double, -1, ctx) end -function typetree(::Type{T}, ctx, dl, seen=nothing) where T<:AbstractFloat +function typetree_inner(::Type{T}, ctx, dl, seen::TypeTreeTable) where {T<:AbstractFloat} GPUCompiler.@safe_warn "Unknown floating point type" T return TypeTree() end -function typetree(::Type{<:DataType}, ctx, dl, seen=nothing) +function typetree_inner(::Type{<:DataType}, ctx, dl, seen::TypeTreeTable) return TypeTree() end -function typetree(::Type{Any}, ctx, dl, seen=nothing) +function typetree_inner(::Type{Any}, ctx, dl, seen::TypeTreeTable) return TypeTree() end -function typetree(::Type{Symbol}, ctx, dl, seen=nothing) +function typetree_inner(::Type{Symbol}, ctx, dl, seen::TypeTreeTable) return TypeTree() end -function typetree(::Type{Core.SimpleVector}, ctx, dl, seen=nothing) +function typetree_inner(::Type{Core.SimpleVector}, ctx, dl, seen::TypeTreeTable) tt = TypeTree() - for i in 0:(sizeof(Csize_t)-1) + for i in 0:(sizeof(Csize_t) - 1) merge!(tt, TypeTree(API.DT_Integer, i, ctx)) end return tt end -function typetree(::Type{Union{}}, ctx, dl, seen=nothing) +function typetree_inner(::Type{Union{}}, ctx, dl, seen::TypeTreeTable) return TypeTree() end -function typetree(::Type{<:AbstractString}, ctx, dl, seen=nothing) +function typetree_inner(::Type{<:AbstractString}, ctx, dl, seen::TypeTreeTable) return TypeTree() end -function typetree(::Type{<:Union{Ptr{T}, Core.LLVMPtr{T}}}, ctx, dl, seen=nothing) where T - tt = typetree(T, ctx, dl, seen) +function typetree_inner(::Type{<:Union{Ptr{T},Core.LLVMPtr{T}}}, ctx, dl, + seen::TypeTreeTable) where {T} + tt = copy(typetree(T, ctx, dl, seen)) merge!(tt, TypeTree(API.DT_Pointer, ctx)) only!(tt, -1) return tt end -function typetree(::Type{<:Array{T}}, ctx, dl, seen=nothing) where T +function typetree_inner(::Type{<:Array{T}}, ctx, dl, seen::TypeTreeTable) where {T} offset = 0 - tt = typetree(T, ctx, dl, seen) + tt = copy(typetree(T, ctx, dl, seen)) if !allocatedinline(T) merge!(tt, TypeTree(API.DT_Pointer, ctx)) only!(tt, 0) @@ -153,21 +182,11 @@ else ismutabletype(T) = isa(T, DataType) && T.mutable end -function typetree(@nospecialize(T), ctx, dl, seen=nothing) +function typetree_inner(@nospecialize(T), ctx, dl, seen::TypeTreeTable) if T isa UnionAll || T isa Union || T == Union{} || Base.isabstracttype(T) return TypeTree() end - if seen !== nothing && T ∈ seen - return TypeTree() - end - if seen === nothing - seen = Set{DataType}() - else - seen = copy(seen) # need to copy otherwise we'll count siblings as recursive - end - push!(seen, T) - if T === Tuple return TypeTree() end @@ -197,11 +216,12 @@ function typetree(@nospecialize(T), ctx, dl, seen=nothing) tt = TypeTree() for f in 1:fieldcount(T) - offset = fieldoffset(T, f) - subT = fieldtype(T, f) - subtree = typetree(subT, ctx, dl, seen) + offset = fieldoffset(T, f) + subT = fieldtype(T, f) + subtree = copy(typetree(subT, ctx, dl, seen)) if subT isa UnionAll || subT isa Union || subT == Union{} + # FIXME: Handle union continue end diff --git a/test/typetree.jl b/test/typetree.jl index cf7bf2e695..51c284d6e9 100644 --- a/test/typetree.jl +++ b/test/typetree.jl @@ -21,6 +21,22 @@ struct Composite y::Atom end +struct LList2{T} + next::Union{Nothing,LList2{T}} + v::T +end + +struct Sibling{T} + a::T + b::T +end + +struct Sibling2{T} + a::T + something::Bool + b::T +end + @testset "TypeTree" begin @test tt(Float16) == "{[-1]:Float@half}" @test tt(Float32) == "{[-1]:Float@float}" @@ -38,4 +54,28 @@ end @test at2.y == 0.0 @test at2.z == 0.0 @test at2.type == 4 + + if Sys.WORD_SIZE == 64 + @test tt(LList2{Float64}) == "{[8]:Float@double}" + @test tt(Sibling{LList2{Float64}}) == "{[-1]:Pointer, [-1,8]:Float@double}" + @test tt(Sibling2{LList2{Float64}}) == + "{[0]:Pointer, [0,8]:Float@double, [8]:Integer, [16]:Pointer, [16,8]:Float@double}" + @test tt(Sibling{Tuple{Int,Float64}}) == + "{[0]:Integer, [1]:Integer, [2]:Integer, [3]:Integer, [4]:Integer, [5]:Integer, [6]:Integer, [7]:Integer, [8]:Float@double, [16]:Integer, [17]:Integer, [18]:Integer, [19]:Integer, [20]:Integer, [21]:Integer, [22]:Integer, [23]:Integer, [24]:Float@double}" + @test tt(Sibling{LList2{Tuple{Int,Float64}}}) == + "{[-1]:Pointer, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Float@double}" + @test tt(Sibling2{Sibling2{LList2{Tuple{Float32,Float64}}}}) == + "{[0]:Pointer, [0,8]:Float@float, [0,16]:Float@double, [8]:Integer, [16]:Pointer, [16,8]:Float@float, [16,16]:Float@double, [24]:Integer, [32]:Pointer, [32,8]:Float@float, [32,16]:Float@double, [40]:Integer, [48]:Pointer, [48,8]:Float@float, [48,16]:Float@double}" + else + @test tt(LList2{Float64}) == "{[4]:Float@double}" + @test tt(Sibling{LList2{Float64}}) == "{[-1]:Pointer, [-1,4]:Float@double}" + @test tt(Sibling2{LList2{Float64}}) == + "{[0]:Pointer, [0,4]:Float@double, [4]:Integer, [8]:Pointer, [8,4]:Float@double}" + @test tt(Sibling{Tuple{Int,Float64}}) == + "{[0]:Integer, [1]:Integer, [2]:Integer, [3]:Integer, [4]:Float@double, [12]:Integer, [13]:Integer, [14]:Integer, [15]:Integer, [16]:Float@double}" + @test tt(Sibling{LList2{Tuple{Int,Float64}}}) == + "{[-1]:Pointer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer, [-1,8]:Float@double}" + @test tt(Sibling2{Sibling2{LList2{Tuple{Float32,Float64}}}}) == + "{[0]:Pointer, [0,4]:Float@float, [0,8]:Float@double, [4]:Integer, [8]:Pointer, [8,4]:Float@float, [8,8]:Float@double, [12]:Integer, [16]:Pointer, [16,4]:Float@float, [16,8]:Float@double, [20]:Integer, [24]:Pointer, [24,4]:Float@float, [24,8]:Float@double}" + end end