Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Memoize typetree calls #1302

Merged
merged 20 commits into from
Mar 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions src/compiler.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
module Compiler

import ..Enzyme
import Enzyme: Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed,
import Enzyme: Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated,
BatchDuplicatedNoNeed,
BatchDuplicatedFunc,
Annotation, guess_activity, eltype,
API, TypeTree, typetree, only!, shift!, data0!, merge!, to_md,
API, TypeTree, typetree, TypeTreeTable, only!, shift!, data0!, merge!, to_md,
TypeAnalysis, FnTypeInfo, Logic, allocatedinline, ismutabletype
using Enzyme

Expand Down Expand Up @@ -2933,6 +2934,7 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr
push!(args_known_values, API.IntList())
end

seen = TypeTreeTable()
for (i, T) in enumerate(TT.parameters)
source_typ = eltype(T)
if isghostty(source_typ) || Core.Compiler.isconstType(source_typ)
Expand All @@ -2958,8 +2960,9 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr
else
error("illegal annotation type")
end
typeTree = typetree(source_typ, ctx, dl)
typeTree = typetree(source_typ, ctx, dl, seen)
if isboxed
typeTree = copy(typeTree)
merge!(typeTree, TypeTree(API.DT_Pointer, ctx))
only!(typeTree, -1)
end
Expand Down Expand Up @@ -3060,7 +3063,9 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr
logic = Logic()
TA = TypeAnalysis(logic, rules)

retTT = typetree((!isa(actualRetType, Union) && GPUCompiler.deserves_retbox(actualRetType)) ? Ptr{actualRetType} : actualRetType, ctx, dl)
retT = (!isa(actualRetType, Union) && GPUCompiler.deserves_retbox(actualRetType)) ?
Ptr{actualRetType} : actualRetType
retTT = typetree(retT, ctx, dl, seen)

typeInfo = FnTypeInfo(retTT, args_typeInfo, args_known_values)

Expand Down Expand Up @@ -4037,6 +4042,7 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function
push!(parameter_attributes(wrapper_f, 1), EnumAttribute("swiftself"))
end

seen = TypeTreeTable()
# emit IR performing the "conversions"
let builder = IRBuilder()
toErase = LLVM.CallInst[]
Expand Down Expand Up @@ -4100,7 +4106,8 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function
if RetActivity <: Const
metadata(sretPtr)["enzyme_inactive"] = MDNode(LLVM.Metadata[])
end
metadata(sretPtr)["enzyme_type"] = to_md(typetree(Ptr{actualRetType}, ctx, dl), ctx)
metadata(sretPtr)["enzyme_type"] = to_md(typetree(Ptr{actualRetType}, ctx,
dl, seen), ctx)
push!(wrapper_args, sretPtr)
end
if returnRoots && !in(1, parmsRemoved)
Expand Down Expand Up @@ -4128,7 +4135,8 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function
metadata(ptr)["enzyme_inactive"] = MDNode(LLVM.Metadata[])
end
ctx = LLVM.context(entry_f)
metadata(ptr)["enzyme_type"] = to_md(typetree(Ptr{arg.typ}, ctx, dl), ctx)
metadata(ptr)["enzyme_type"] = to_md(typetree(Ptr{arg.typ}, ctx, dl, seen),
ctx)
if LLVM.addrspace(ty) != 0
ptr = addrspacecast!(builder, ptr, ty)
end
Expand Down
5 changes: 3 additions & 2 deletions src/rules/customrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -850,9 +850,10 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils,

idx = 0
dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(orig)))))
Tys2 = (eltype(A) for A in activity[2+isKWCall:end] if A <: Active)
Tys2 = (eltype(A) for A in activity[(2 + isKWCall):end] if A <: Active)
seen = TypeTreeTable()
for (v, Ty) in zip(actives, Tys2)
TT = typetree(Ty, ctx, dl)
TT = typetree(Ty, ctx, dl, seen)
Typ = C_NULL
ext = extract_value!(B, res, idx)
shadowVType = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(v)))
Expand Down
4 changes: 2 additions & 2 deletions src/rules/typerules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ function alloc_obj_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.CT
ctx = LLVM.context(LLVM.Value(val))
dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(inst)))))

rest = typetree(typ, ctx, dl)
rest = typetree(typ, ctx, dl) # copy unecessary since only user of `rest`
only!(rest, -1)
API.EnzymeMergeTypeTree(ret, rest)
return UInt8(false)
Expand Down Expand Up @@ -107,7 +107,7 @@ function alloc_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.CTypeT
ctx = LLVM.context(LLVM.Value(val))
dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(inst)))))

rest = typetree(typ, ctx, dl)
rest = typetree(typ, ctx, dl) # copy unecessary since only user of `rest`
only!(rest, -1)
API.EnzymeMergeTypeTree(ret, rest)

Expand Down
84 changes: 52 additions & 32 deletions src/typetree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,73 +60,102 @@ function merge!(dst::TypeTree, src::TypeTree; consume=true)
end

function to_md(tt::TypeTree, ctx)
return LLVM.Metadata(LLVM.MetadataAsValue(ccall((:EnzymeTypeTreeToMD, API.libEnzyme), LLVM.API.LLVMValueRef, (API.CTypeTreeRef,LLVM.API.LLVMContextRef), tt, ctx)))
return LLVM.Metadata(LLVM.MetadataAsValue(ccall((:EnzymeTypeTreeToMD, API.libEnzyme),
LLVM.API.LLVMValueRef,
(API.CTypeTreeRef,
LLVM.API.LLVMContextRef), tt, ctx)))
end

function typetree(::Type{T}, ctx, dl, seen=nothing) where T <: Integer
const TypeTreeTable = IdDict{Any,Union{Nothing,TypeTree}}

"""
function typetree(T, ctx, dl, seen=TypeTreeTable())

Construct a Enzyme typetree from a Julia type.

!!! warning
When using a memoized lookup by providing `seen` across multiple calls to typtree
the user must call `copy` on the returned value before mutating it.
"""
function typetree(@nospecialize(T), ctx, dl, seen=TypeTreeTable())
if haskey(seen, T)
tree = seen[T]
if tree === nothing
return TypeTree() # stop recursion, but don't cache
end
else
seen[T] = nothing # place recursion marker
tree = typetree_inner(T, ctx, dl, seen)
seen[T] = tree
end
return tree::TypeTree
end

function typetree_inner(::Type{T}, ctx, dl, seen::TypeTreeTable) where {T<:Integer}
return TypeTree(API.DT_Integer, -1, ctx)
end

function typetree(::Type{Char}, ctx, dl, seen=nothing)
function typetree_inner(::Type{Char}, ctx, dl, seen::TypeTreeTable)
return TypeTree(API.DT_Integer, -1, ctx)
end

function typetree(::Type{Float16}, ctx, dl, seen=nothing)
function typetree_inner(::Type{Float16}, ctx, dl, seen::TypeTreeTable)
return TypeTree(API.DT_Half, -1, ctx)
end

function typetree(::Type{Float32}, ctx, dl, seen=nothing)
function typetree_inner(::Type{Float32}, ctx, dl, seen::TypeTreeTable)
return TypeTree(API.DT_Float, -1, ctx)
end

function typetree(::Type{Float64}, ctx, dl, seen=nothing)
function typetree_inner(::Type{Float64}, ctx, dl, seen::TypeTreeTable)
return TypeTree(API.DT_Double, -1, ctx)
end

function typetree(::Type{T}, ctx, dl, seen=nothing) where T<:AbstractFloat
function typetree_inner(::Type{T}, ctx, dl, seen::TypeTreeTable) where {T<:AbstractFloat}
GPUCompiler.@safe_warn "Unknown floating point type" T
return TypeTree()
end

function typetree(::Type{<:DataType}, ctx, dl, seen=nothing)
function typetree_inner(::Type{<:DataType}, ctx, dl, seen::TypeTreeTable)
return TypeTree()
end

function typetree(::Type{Any}, ctx, dl, seen=nothing)
function typetree_inner(::Type{Any}, ctx, dl, seen::TypeTreeTable)
return TypeTree()
end

function typetree(::Type{Symbol}, ctx, dl, seen=nothing)
function typetree_inner(::Type{Symbol}, ctx, dl, seen::TypeTreeTable)
return TypeTree()
end

function typetree(::Type{Core.SimpleVector}, ctx, dl, seen=nothing)
function typetree_inner(::Type{Core.SimpleVector}, ctx, dl, seen::TypeTreeTable)
tt = TypeTree()
for i in 0:(sizeof(Csize_t)-1)
for i in 0:(sizeof(Csize_t) - 1)
merge!(tt, TypeTree(API.DT_Integer, i, ctx))
end
return tt
end

function typetree(::Type{Union{}}, ctx, dl, seen=nothing)
function typetree_inner(::Type{Union{}}, ctx, dl, seen::TypeTreeTable)
return TypeTree()
end

function typetree(::Type{<:AbstractString}, ctx, dl, seen=nothing)
function typetree_inner(::Type{<:AbstractString}, ctx, dl, seen::TypeTreeTable)
return TypeTree()
end

function typetree(::Type{<:Union{Ptr{T}, Core.LLVMPtr{T}}}, ctx, dl, seen=nothing) where T
tt = typetree(T, ctx, dl, seen)
function typetree_inner(::Type{<:Union{Ptr{T},Core.LLVMPtr{T}}}, ctx, dl,
seen::TypeTreeTable) where {T}
tt = copy(typetree(T, ctx, dl, seen))
merge!(tt, TypeTree(API.DT_Pointer, ctx))
only!(tt, -1)
return tt
end

function typetree(::Type{<:Array{T}}, ctx, dl, seen=nothing) where T
function typetree_inner(::Type{<:Array{T}}, ctx, dl, seen::TypeTreeTable) where {T}
offset = 0

tt = typetree(T, ctx, dl, seen)
tt = copy(typetree(T, ctx, dl, seen))
if !allocatedinline(T)
merge!(tt, TypeTree(API.DT_Pointer, ctx))
only!(tt, 0)
Expand All @@ -153,21 +182,11 @@ else
ismutabletype(T) = isa(T, DataType) && T.mutable
end

function typetree(@nospecialize(T), ctx, dl, seen=nothing)
function typetree_inner(@nospecialize(T), ctx, dl, seen::TypeTreeTable)
if T isa UnionAll || T isa Union || T == Union{} || Base.isabstracttype(T)
return TypeTree()
end

if seen !== nothing && T ∈ seen
return TypeTree()
end
vchuravy marked this conversation as resolved.
Show resolved Hide resolved
if seen === nothing
seen = Set{DataType}()
else
seen = copy(seen) # need to copy otherwise we'll count siblings as recursive
end
push!(seen, T)

if T === Tuple
return TypeTree()
end
Expand Down Expand Up @@ -197,11 +216,12 @@ function typetree(@nospecialize(T), ctx, dl, seen=nothing)

tt = TypeTree()
for f in 1:fieldcount(T)
offset = fieldoffset(T, f)
subT = fieldtype(T, f)
subtree = typetree(subT, ctx, dl, seen)
offset = fieldoffset(T, f)
subT = fieldtype(T, f)
subtree = copy(typetree(subT, ctx, dl, seen))

if subT isa UnionAll || subT isa Union || subT == Union{}
# FIXME: Handle union
continue
end

Expand Down
40 changes: 40 additions & 0 deletions test/typetree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,22 @@ struct Composite
y::Atom
end

struct LList2{T}
next::Union{Nothing,LList2{T}}
v::T
end

struct Sibling{T}
a::T
b::T
end

struct Sibling2{T}
a::T
something::Bool
b::T
end

@testset "TypeTree" begin
@test tt(Float16) == "{[-1]:Float@half}"
@test tt(Float32) == "{[-1]:Float@float}"
Expand All @@ -38,4 +54,28 @@ end
@test at2.y == 0.0
@test at2.z == 0.0
@test at2.type == 4

if Sys.WORD_SIZE == 64
@test tt(LList2{Float64}) == "{[8]:Float@double}"
@test tt(Sibling{LList2{Float64}}) == "{[-1]:Pointer, [-1,8]:Float@double}"
@test tt(Sibling2{LList2{Float64}}) ==
"{[0]:Pointer, [0,8]:Float@double, [8]:Integer, [16]:Pointer, [16,8]:Float@double}"
@test tt(Sibling{Tuple{Int,Float64}}) ==
"{[0]:Integer, [1]:Integer, [2]:Integer, [3]:Integer, [4]:Integer, [5]:Integer, [6]:Integer, [7]:Integer, [8]:Float@double, [16]:Integer, [17]:Integer, [18]:Integer, [19]:Integer, [20]:Integer, [21]:Integer, [22]:Integer, [23]:Integer, [24]:Float@double}"
@test tt(Sibling{LList2{Tuple{Int,Float64}}}) ==
"{[-1]:Pointer, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Float@double}"
@test tt(Sibling2{Sibling2{LList2{Tuple{Float32,Float64}}}}) ==
"{[0]:Pointer, [0,8]:Float@float, [0,16]:Float@double, [8]:Integer, [16]:Pointer, [16,8]:Float@float, [16,16]:Float@double, [24]:Integer, [32]:Pointer, [32,8]:Float@float, [32,16]:Float@double, [40]:Integer, [48]:Pointer, [48,8]:Float@float, [48,16]:Float@double}"
else
@test tt(LList2{Float64}) == "{[4]:Float@double}"
@test tt(Sibling{LList2{Float64}}) == "{[-1]:Pointer, [-1,4]:Float@double}"
@test tt(Sibling2{LList2{Float64}}) ==
"{[0]:Pointer, [0,4]:Float@double, [4]:Integer, [8]:Pointer, [8,4]:Float@double}"
@test tt(Sibling{Tuple{Int,Float64}}) ==
"{[0]:Integer, [1]:Integer, [2]:Integer, [3]:Integer, [4]:Float@double, [12]:Integer, [13]:Integer, [14]:Integer, [15]:Integer, [16]:Float@double}"
@test tt(Sibling{LList2{Tuple{Int,Float64}}}) ==
"{[-1]:Pointer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer, [-1,8]:Float@double}"
@test tt(Sibling2{Sibling2{LList2{Tuple{Float32,Float64}}}}) ==
"{[0]:Pointer, [0,4]:Float@float, [0,8]:Float@double, [4]:Integer, [8]:Pointer, [8,4]:Float@float, [8,8]:Float@double, [12]:Integer, [16]:Pointer, [16,4]:Float@float, [16,8]:Float@double, [20]:Integer, [24]:Pointer, [24,4]:Float@float, [24,8]:Float@double}"
end
end
Loading