From 86dc34e13eeaf73d8d54c1572ffeac64711a77c3 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Wed, 21 Feb 2024 09:34:20 -0500 Subject: [PATCH 01/17] memoize typetree --- src/typetree.jl | 65 +++++++++++++++++++++++++++---------------------- 1 file changed, 36 insertions(+), 29 deletions(-) diff --git a/src/typetree.jl b/src/typetree.jl index 13578fa82b..84d5ae6347 100644 --- a/src/typetree.jl +++ b/src/typetree.jl @@ -12,7 +12,7 @@ LLVM.dispose(tt::TypeTree) = API.EnzymeFreeTypeTree(tt) TypeTree() = TypeTree(API.EnzymeNewTypeTree()) TypeTree(CT, ctx) = TypeTree(API.EnzymeNewTypeTreeCT(CT, ctx)) -function TypeTree(CT, idx, ctx) +function typetree_inner(CT, idx, ctx) tt = TypeTree(CT, ctx) only!(tt, idx) return tt @@ -63,44 +63,61 @@ function to_md(tt::TypeTree, ctx) return LLVM.Metadata(LLVM.MetadataAsValue(ccall((:EnzymeTypeTreeToMD, API.libEnzyme), LLVM.API.LLVMValueRef, (API.CTypeTreeRef,LLVM.API.LLVMContextRef), tt, ctx))) end -function typetree(::Type{T}, ctx, dl, seen=nothing) where T <: Integer +const TypeTreeTable = IdDict{DataType, Union{Nothing, TypeTree}} + +function typetree(@nospecialize(T), ctx, dl, seen=TypeTreeTable()) + if haskey(seen, T) + tree = seen[T] + if tree === nothing + return TypeTree() # stop recursion, but don't cache + else + return tree::TypeTree + end + else + seen[T] = nothing # place recursion marker + tree = typetree_inner(T, ctx, dl, seen) + seen[T] = tree + end +end + +function typetree_inner(::Type{T}, ctx, dl, seen::TypeTreeTable) where T <: Integer return TypeTree(API.DT_Integer, -1, ctx) end -function typetree(::Type{Char}, ctx, dl, seen=nothing) +function typetree_inner(::Type{Char}, ctx, dl, seen::TypeTreeTable) return TypeTree(API.DT_Integer, -1, ctx) end -function typetree(::Type{Float16}, ctx, dl, seen=nothing) +function typetree_inner(::Type{Float16}, ctx, dl, seen::TypeTreeTable) return TypeTree(API.DT_Half, -1, ctx) end -function typetree(::Type{Float32}, ctx, dl, seen=nothing) +function typetree_inner(::Type{Float32}, ctx, dl, seen::TypeTreeTable) return TypeTree(API.DT_Float, -1, ctx) end -function typetree(::Type{Float64}, ctx, dl, seen=nothing) +function typetree_inner(::Type{Float64}, ctx, dl, seen::TypeTreeTable) return TypeTree(API.DT_Double, -1, ctx) end -function typetree(::Type{T}, ctx, dl, seen=nothing) where T<:AbstractFloat +function typetree_inner(::Type{T}, ctx, dl, seen::TypeTreeTable) where T<:AbstractFloat GPUCompiler.@safe_warn "Unknown floating point type" T return TypeTree() end -function typetree(::Type{<:DataType}, ctx, dl, seen=nothing) +function typetree_inner(::Type{<:DataType}, ctx, dl, seen::TypeTreeTable) return TypeTree() end -function typetree(::Type{Any}, ctx, dl, seen=nothing) +function typetree_inner(::Type{Any}, ctx, dl, seen::TypeTreeTable) return TypeTree() end -function typetree(::Type{Symbol}, ctx, dl, seen=nothing) +function typetree_inner(::Type{Symbol}, ctx, dl, seen::TypeTreeTable) return TypeTree() end -function typetree(::Type{Core.SimpleVector}, ctx, dl, seen=nothing) +function typetree_inner(::Type{Core.SimpleVector}, ctx, dl, seen::TypeTreeTable) tt = TypeTree() for i in 0:(sizeof(Csize_t)-1) merge!(tt, TypeTree(API.DT_Integer, i, ctx)) @@ -108,25 +125,25 @@ function typetree(::Type{Core.SimpleVector}, ctx, dl, seen=nothing) return tt end -function typetree(::Type{Union{}}, ctx, dl, seen=nothing) +function typetree_inner(::Type{Union{}}, ctx, dl, seen::TypeTreeTable) return TypeTree() end -function typetree(::Type{<:AbstractString}, ctx, dl, seen=nothing) +function typetree_inner(::Type{<:AbstractString}, ctx, dl, seen::TypeTreeTable) return TypeTree() end -function typetree(::Type{<:Union{Ptr{T}, Core.LLVMPtr{T}}}, ctx, dl, seen=nothing) where T - tt = typetree(T, ctx, dl, seen) +function typetree_inner(::Type{<:Union{Ptr{T}, Core.LLVMPtr{T}}}, ctx, dl, seen::TypeTreeTable) where T + tt = copy(typetree(T, ctx, dl, seen)) merge!(tt, TypeTree(API.DT_Pointer, ctx)) only!(tt, -1) return tt end -function typetree(::Type{<:Array{T}}, ctx, dl, seen=nothing) where T +function typetree_inner(::Type{<:Array{T}}, ctx, dl, seen::TypeTreeTable) where T offset = 0 - tt = typetree(T, ctx, dl, seen) + tt = copy(typetree(T, ctx, dl, seen)) if !allocatedinline(T) merge!(tt, TypeTree(API.DT_Pointer, ctx)) only!(tt, 0) @@ -153,21 +170,11 @@ else ismutabletype(T) = isa(T, DataType) && T.mutable end -function typetree(@nospecialize(T), ctx, dl, seen=nothing) +function typetree_inner(@nospecialize(T), ctx, dl, seen::TypeTreeTable) if T isa UnionAll || T isa Union || T == Union{} || Base.isabstracttype(T) return TypeTree() end - if seen !== nothing && T ∈ seen - return TypeTree() - end - if seen === nothing - seen = Set{DataType}() - else - seen = copy(seen) # need to copy otherwise we'll count siblings as recursive - end - push!(seen, T) - if T === Tuple return TypeTree() end @@ -199,7 +206,7 @@ function typetree(@nospecialize(T), ctx, dl, seen=nothing) for f in 1:fieldcount(T) offset = fieldoffset(T, f) subT = fieldtype(T, f) - subtree = typetree(subT, ctx, dl, seen) + subtree = copy(typetree(subT, ctx, dl, seen)) if subT isa UnionAll || subT isa Union || subT == Union{} continue From b6823fe7ef93753da21b41ce413670b8a124b5ea Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Wed, 21 Feb 2024 09:37:20 -0500 Subject: [PATCH 02/17] fixup! memoize typetree --- src/typetree.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/typetree.jl b/src/typetree.jl index 84d5ae6347..ce041cd5b4 100644 --- a/src/typetree.jl +++ b/src/typetree.jl @@ -12,11 +12,11 @@ LLVM.dispose(tt::TypeTree) = API.EnzymeFreeTypeTree(tt) TypeTree() = TypeTree(API.EnzymeNewTypeTree()) TypeTree(CT, ctx) = TypeTree(API.EnzymeNewTypeTreeCT(CT, ctx)) -function typetree_inner(CT, idx, ctx) - tt = TypeTree(CT, ctx) - only!(tt, idx) - return tt -end +# function typetree_inner(CT, idx, ctx) +# tt = TypeTree(CT, ctx) +# only!(tt, idx) +# return tt +# end Base.copy(tt::TypeTree) = TypeTree(API.EnzymeNewTypeTreeTR(tt)) Base.copy!(dst::TypeTree, src::TypeTree) = API.EnzymeSetTypeTree(dst, src) From 740600ece84ff3734f556be25d1a962d4ca50eba Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Wed, 21 Feb 2024 09:39:23 -0500 Subject: [PATCH 03/17] fixup! memoize typetree --- src/typetree.jl | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/typetree.jl b/src/typetree.jl index ce041cd5b4..1ee50847f3 100644 --- a/src/typetree.jl +++ b/src/typetree.jl @@ -12,11 +12,11 @@ LLVM.dispose(tt::TypeTree) = API.EnzymeFreeTypeTree(tt) TypeTree() = TypeTree(API.EnzymeNewTypeTree()) TypeTree(CT, ctx) = TypeTree(API.EnzymeNewTypeTreeCT(CT, ctx)) -# function typetree_inner(CT, idx, ctx) -# tt = TypeTree(CT, ctx) -# only!(tt, idx) -# return tt -# end +function TypeTree(CT, idx, ctx) + tt = TypeTree(CT, ctx) + only!(tt, idx) + return tt +end Base.copy(tt::TypeTree) = TypeTree(API.EnzymeNewTypeTreeTR(tt)) Base.copy!(dst::TypeTree, src::TypeTree) = API.EnzymeSetTypeTree(dst, src) @@ -76,7 +76,8 @@ function typetree(@nospecialize(T), ctx, dl, seen=TypeTreeTable()) else seen[T] = nothing # place recursion marker tree = typetree_inner(T, ctx, dl, seen) - seen[T] = tree + seen[T] = tree + return tree::TypeTree end end From c88a9bf993cd6a4601364c06d3dd99dbd1b5cba3 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Wed, 21 Feb 2024 09:50:34 -0500 Subject: [PATCH 04/17] fixup! memoize typetree --- src/compiler.jl | 1 + src/rules/typerules.jl | 5 +++-- src/typetree.jl | 6 ++---- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 8405ccbdec..34fff5f38d 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -2955,6 +2955,7 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr end typeTree = typetree(source_typ, ctx, dl) if isboxed + typetree = copy(typetree) merge!(typeTree, TypeTree(API.DT_Pointer, ctx)) only!(typeTree, -1) end diff --git a/src/rules/typerules.jl b/src/rules/typerules.jl index 2c3e51e2ff..88097389ce 100644 --- a/src/rules/typerules.jl +++ b/src/rules/typerules.jl @@ -17,7 +17,7 @@ function alloc_obj_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.CT ctx = LLVM.context(LLVM.Value(val)) dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(inst))))) - rest = typetree(typ, ctx, dl) + rest = copy(typetree(typ, ctx, dl)) only!(rest, -1) API.EnzymeMergeTypeTree(ret, rest) return UInt8(false) @@ -107,7 +107,7 @@ function alloc_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.CTypeT ctx = LLVM.context(LLVM.Value(val)) dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(inst))))) - rest = typetree(typ, ctx, dl) + rest = copy(typetree(typ, ctx, dl)) only!(rest, -1) API.EnzymeMergeTypeTree(ret, rest) @@ -162,6 +162,7 @@ function julia_type_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.C rest = typetree(arg.typ, ctx, dl) @assert arg.cc == byref if byref == GPUCompiler.BITS_REF || byref == GPUCompiler.MUT_REF + rest = copy(rest) # adjust first path to size of type since if arg.typ is {[-1]:Int}, that doesn't mean the broader # object passing this in by ref isnt a {[-1]:Pointer, [-1,-1]:Int} # aka the next field after this in the bigger object isn't guaranteed to also be the same. diff --git a/src/typetree.jl b/src/typetree.jl index 1ee50847f3..a9748f3ebf 100644 --- a/src/typetree.jl +++ b/src/typetree.jl @@ -63,22 +63,20 @@ function to_md(tt::TypeTree, ctx) return LLVM.Metadata(LLVM.MetadataAsValue(ccall((:EnzymeTypeTreeToMD, API.libEnzyme), LLVM.API.LLVMValueRef, (API.CTypeTreeRef,LLVM.API.LLVMContextRef), tt, ctx))) end -const TypeTreeTable = IdDict{DataType, Union{Nothing, TypeTree}} +const TypeTreeTable = IdDict{Any, Union{Nothing, TypeTree}} function typetree(@nospecialize(T), ctx, dl, seen=TypeTreeTable()) if haskey(seen, T) tree = seen[T] if tree === nothing return TypeTree() # stop recursion, but don't cache - else - return tree::TypeTree end else seen[T] = nothing # place recursion marker tree = typetree_inner(T, ctx, dl, seen) seen[T] = tree - return tree::TypeTree end + return tree::TypeTree end function typetree_inner(::Type{T}, ctx, dl, seen::TypeTreeTable) where T <: Integer From 0690f12dbe19c7fd2296cef6d9972435f902e343 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Wed, 21 Feb 2024 10:08:41 -0500 Subject: [PATCH 05/17] fix typo --- src/compiler.jl | 14 ++++++++------ src/rules/customrules.jl | 3 ++- src/rules/typerules.jl | 13 +++++++------ src/typetree.jl | 9 +++++++++ 4 files changed, 26 insertions(+), 13 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 34fff5f38d..11bf61a92a 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -4,7 +4,7 @@ import ..Enzyme import Enzyme: Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, BatchDuplicatedFunc, Annotation, guess_activity, eltype, - API, TypeTree, typetree, only!, shift!, data0!, merge!, to_md, + API, TypeTree, typetree, TypeTreeTable, only!, shift!, data0!, merge!, to_md, TypeAnalysis, FnTypeInfo, Logic, allocatedinline, ismutabletype using Enzyme @@ -2928,6 +2928,7 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr push!(args_known_values, API.IntList()) end + seen = TypeTreeTable() for (i, T) in enumerate(TT.parameters) source_typ = eltype(T) if isghostty(source_typ) || Core.Compiler.isconstType(source_typ) @@ -2953,9 +2954,9 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr else error("illegal annotation type") end - typeTree = typetree(source_typ, ctx, dl) + typeTree = typetree(source_typ, ctx, dl, seen) if isboxed - typetree = copy(typetree) + typeTree = copy(typeTree) merge!(typeTree, TypeTree(API.DT_Pointer, ctx)) only!(typeTree, -1) end @@ -3061,7 +3062,7 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr logic = Logic() TA = TypeAnalysis(logic, rules) - retTT = typetree((!isa(actualRetType, Union) && GPUCompiler.deserves_retbox(actualRetType)) ? Ptr{actualRetType} : actualRetType, ctx, dl) + retTT = typetree((!isa(actualRetType, Union) && GPUCompiler.deserves_retbox(actualRetType)) ? Ptr{actualRetType} : actualRetType, ctx, dl, seen) typeInfo = FnTypeInfo(retTT, args_typeInfo, args_known_values) @@ -4038,6 +4039,7 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function push!(parameter_attributes(wrapper_f, 1), EnumAttribute("swiftself")) end + seen = TypeTreeTable() # emit IR performing the "conversions" let builder = IRBuilder() toErase = LLVM.CallInst[] @@ -4101,7 +4103,7 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function if RetActivity <: Const metadata(sretPtr)["enzyme_inactive"] = MDNode(LLVM.Metadata[]) end - metadata(sretPtr)["enzyme_type"] = to_md(typetree(Ptr{actualRetType}, ctx, dl), ctx) + metadata(sretPtr)["enzyme_type"] = to_md(typetree(Ptr{actualRetType}, ctx, dl, seen), ctx) push!(wrapper_args, sretPtr) end if returnRoots && !in(1, parmsRemoved) @@ -4129,7 +4131,7 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function metadata(ptr)["enzyme_inactive"] = MDNode(LLVM.Metadata[]) end ctx = LLVM.context(entry_f) - metadata(ptr)["enzyme_type"] = to_md(typetree(Ptr{arg.typ}, ctx, dl), ctx) + metadata(ptr)["enzyme_type"] = to_md(typetree(Ptr{arg.typ}, ctx, dl, seen), ctx) if LLVM.addrspace(ty) != 0 ptr = addrspacecast!(builder, ptr, ty) end diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index 7bf62f6d66..f012400c87 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -851,8 +851,9 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, idx = 0 dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(orig))))) Tys2 = (eltype(A) for A in activity[2+isKWCall:end] if A <: Active) + seen = TypeTreeTable() for (v, Ty) in zip(actives, Tys2) - TT = typetree(Ty, ctx, dl) + TT = typetree(Ty, ctx, dl, seen) Typ = C_NULL ext = extract_value!(B, res, idx) shadowVType = LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(v))) diff --git a/src/rules/typerules.jl b/src/rules/typerules.jl index 88097389ce..e46b6d0dfa 100644 --- a/src/rules/typerules.jl +++ b/src/rules/typerules.jl @@ -17,7 +17,7 @@ function alloc_obj_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.CT ctx = LLVM.context(LLVM.Value(val)) dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(inst))))) - rest = copy(typetree(typ, ctx, dl)) + rest = typetree(typ, ctx, dl) # copy unecessary since only user of `rest` only!(rest, -1) API.EnzymeMergeTypeTree(ret, rest) return UInt8(false) @@ -107,7 +107,7 @@ function alloc_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.CTypeT ctx = LLVM.context(LLVM.Value(val)) dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(inst))))) - rest = copy(typetree(typ, ctx, dl)) + rest = typetree(typ, ctx, dl) # copy unecessary since only user of `rest` only!(rest, -1) API.EnzymeMergeTypeTree(ret, rest) @@ -149,6 +149,7 @@ function julia_type_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.C swiftself = any(any(map(k->kind(k)==kind(EnumAttribute("swiftself")), collect(parameter_attributes(f, i)))) for i in 1:length(collect(parameters(f)))) jlargs = classify_arguments(mi.specTypes, called_type(inst), sret !== nothing, returnRoots !== nothing, swiftself, parmsRemoved) + seen = TypeTreeTable() for arg in jlargs if arg.cc == GPUCompiler.GHOST || arg.cc == RemovedParam @@ -159,7 +160,7 @@ function julia_type_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.C @assert typ == arg.typ op_idx = arg.codegen.i - rest = typetree(arg.typ, ctx, dl) + rest = typetree(arg.typ, ctx, dl, seen) @assert arg.cc == byref if byref == GPUCompiler.BITS_REF || byref == GPUCompiler.MUT_REF rest = copy(rest) @@ -202,13 +203,13 @@ function julia_type_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.C if sret !== nothing idx = 0 if !in(0, parmsRemoved) - API.EnzymeMergeTypeTree(unsafe_load(args, idx+1), typetree(sret, ctx, dl)) + API.EnzymeMergeTypeTree(unsafe_load(args, idx+1), typetree(sret, ctx, dl, seen)) idx+=1 end if returnRoots !== nothing if !in(1, parmsRemoved) allpointer = TypeTree(API.DT_Pointer, -1, ctx) - API.EnzymeMergeTypeTree(unsafe_load(args, idx+1), typetree(returnRoots, ctx, dl)) + API.EnzymeMergeTypeTree(unsafe_load(args, idx+1), typetree(returnRoots, ctx, dl, seen)) end end end @@ -217,7 +218,7 @@ function julia_type_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.C if llRT !== nothing && value_type(inst) != LLVM.VoidType() @assert !retRemoved - API.EnzymeMergeTypeTree(ret, typetree(llRT, ctx, dl)) + API.EnzymeMergeTypeTree(ret, typetree(llRT, ctx, dl, seen)) end return UInt8(false) diff --git a/src/typetree.jl b/src/typetree.jl index a9748f3ebf..c79c49643c 100644 --- a/src/typetree.jl +++ b/src/typetree.jl @@ -65,6 +65,15 @@ end const TypeTreeTable = IdDict{Any, Union{Nothing, TypeTree}} +""" + function typetree(T, ctx, dl, seen=TypeTreeTable()) + +Construct a Enzyme typetree from a Julia type. + +!!! warning + When using a memoized lookup by providing `seen` across multiple calls to typtree + the user must call `copy` on the returned value before mutating it. +""" function typetree(@nospecialize(T), ctx, dl, seen=TypeTreeTable()) if haskey(seen, T) tree = seen[T] From 47732ef1d3858b9503ba6423e8d46b0e61097f26 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Wed, 21 Feb 2024 10:30:18 -0500 Subject: [PATCH 06/17] add test --- test/typetree.jl | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/test/typetree.jl b/test/typetree.jl index cf7bf2e695..23a26c5538 100644 --- a/test/typetree.jl +++ b/test/typetree.jl @@ -21,6 +21,16 @@ struct Composite y::Atom end +struct LList2{T} + next::Union{Nothing, LList2{T}} + v::T +end + +struct Sibling + a::LList2{Float64} + b::LList2{Float64} +end + @testset "TypeTree" begin @test tt(Float16) == "{[-1]:Float@half}" @test tt(Float32) == "{[-1]:Float@float}" @@ -38,4 +48,7 @@ end @test at2.y == 0.0 @test at2.z == 0.0 @test at2.type == 4 + + @test tt(LList2{Float64}) == "{[8]:Float@double}" + @test tt(Sibling) == "{[-1]:Pointer, [-1,8]:Float@double}" end From 6a11645746978a5045ebfdabd1f358a1e7117566 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Wed, 21 Feb 2024 11:31:19 -0500 Subject: [PATCH 07/17] more tests --- src/typetree.jl | 1 + test/typetree.jl | 21 ++++++++++++++++----- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/src/typetree.jl b/src/typetree.jl index c79c49643c..01b46297d3 100644 --- a/src/typetree.jl +++ b/src/typetree.jl @@ -217,6 +217,7 @@ function typetree_inner(@nospecialize(T), ctx, dl, seen::TypeTreeTable) subtree = copy(typetree(subT, ctx, dl, seen)) if subT isa UnionAll || subT isa Union || subT == Union{} + # FIXME: Handle union continue end diff --git a/test/typetree.jl b/test/typetree.jl index 23a26c5538..333f839a99 100644 --- a/test/typetree.jl +++ b/test/typetree.jl @@ -26,9 +26,15 @@ struct LList2{T} v::T end -struct Sibling - a::LList2{Float64} - b::LList2{Float64} +struct Sibling{T} + a::T + b::T +end + +struct Sibling2{T} + a::T + something::Bool + b::T end @testset "TypeTree" begin @@ -50,5 +56,10 @@ end @test at2.type == 4 @test tt(LList2{Float64}) == "{[8]:Float@double}" - @test tt(Sibling) == "{[-1]:Pointer, [-1,8]:Float@double}" -end + @test tt(Sibling{LList2{Float64}}) == "{[-1]:Pointer, [-1,8]:Float@double}" + @test tt(Sibling2{LList2{Float64}}) == "{[0]:Pointer, [0,8]:Float@double, [8]:Integer, [16]:Pointer, [16,8]:Float@double}" + @test tt(Sibling{Tuple{Int, Float64}}) == "{[0]:Integer, [1]:Integer, [2]:Integer, [3]:Integer, [4]:Integer, [5]:Integer, [6]:Integer, [7]:Integer, [8]:Float@double, [16]:Integer, [17]:Integer, [18]:Integer, [19]:Integer, [20]:Integer, [21]:Integer, [22]:Integer, [23]:Integer, [24]:Float@double}" + @test tt(Sibling{LList2{Tuple{Int, Float64}}}) == "{[-1]:Pointer, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Float@double}" + @test tt(Sibling2{Sibling2{LList2{Tuple{Float32, Float64}}}}) == "{[0]:Pointer, [0,8]:Float@float, [0,16]:Float@double, [8]:Integer, [16]:Pointer, [16,8]:Float@float, [16,16]:Float@double, [24]:Integer, [32]:Pointer, [32,8]:Float@float, [32,16]:Float@double, [40]:Integer, [48]:Pointer, [48,8]:Float@float, [48,16]:Float@double}" +"{[0]:Pointer, [0,8]:Float@float, [0,16]:Float@double, [8]:Integer, [16]:Pointer, [16,8]:Float@float, [16,16]:Float@double, [24]:Integer, [32]:Pointer, [32,8]:Float@float, [32,16]:Float@double, [40]:Integer, [48]:Pointer, [48,8]:Float@float, [48,16]:Float@double}" + end From 4ba37440fa73252e1613b135997e44bca849d1dd Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Tue, 19 Mar 2024 12:23:31 -0400 Subject: [PATCH 08/17] Use YAS style --- .JuliaFormatter.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml index c7439503e1..7b1d4c2a25 100644 --- a/.JuliaFormatter.toml +++ b/.JuliaFormatter.toml @@ -1 +1 @@ -style = "blue" \ No newline at end of file +style = "YAS" \ No newline at end of file From baedd17442817d073b6440afce7e6bb811731f9f Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Tue, 19 Mar 2024 12:23:49 -0400 Subject: [PATCH 09/17] fixup --- .JuliaFormatter.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml index 7b1d4c2a25..9613e0542e 100644 --- a/.JuliaFormatter.toml +++ b/.JuliaFormatter.toml @@ -1 +1 @@ -style = "YAS" \ No newline at end of file +style = "yas" \ No newline at end of file From 567a7e0edf865caf31f14997732bbf741357845f Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Tue, 19 Mar 2024 12:36:23 -0400 Subject: [PATCH 10/17] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/compiler.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 11bf61a92a..ff127abf23 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1,7 +1,8 @@ module Compiler import ..Enzyme -import Enzyme: Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, +import Enzyme: Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, + BatchDuplicatedNoNeed, BatchDuplicatedFunc, Annotation, guess_activity, eltype, API, TypeTree, typetree, TypeTreeTable, only!, shift!, data0!, merge!, to_md, @@ -3062,7 +3063,8 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr logic = Logic() TA = TypeAnalysis(logic, rules) - retTT = typetree((!isa(actualRetType, Union) && GPUCompiler.deserves_retbox(actualRetType)) ? Ptr{actualRetType} : actualRetType, ctx, dl, seen) + retT = (!isa(actualRetType, Union) && GPUCompiler.deserves_retbox(actualRetType)) ? Ptr{actualRetType} : actualRetType + retTT = typetree(retT, ctx, dl, seen) typeInfo = FnTypeInfo(retTT, args_typeInfo, args_known_values) From cbabfd94ba7bd3f832f1738bf85f33d61f076f8a Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Tue, 19 Mar 2024 12:38:42 -0400 Subject: [PATCH 11/17] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/compiler.jl | 9 ++++++--- src/rules/customrules.jl | 2 +- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index ff127abf23..30f9c29e98 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -3063,7 +3063,8 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr logic = Logic() TA = TypeAnalysis(logic, rules) - retT = (!isa(actualRetType, Union) && GPUCompiler.deserves_retbox(actualRetType)) ? Ptr{actualRetType} : actualRetType + retT = (!isa(actualRetType, Union) && GPUCompiler.deserves_retbox(actualRetType)) ? + Ptr{actualRetType} : actualRetType retTT = typetree(retT, ctx, dl, seen) typeInfo = FnTypeInfo(retTT, args_typeInfo, args_known_values) @@ -4105,7 +4106,8 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function if RetActivity <: Const metadata(sretPtr)["enzyme_inactive"] = MDNode(LLVM.Metadata[]) end - metadata(sretPtr)["enzyme_type"] = to_md(typetree(Ptr{actualRetType}, ctx, dl, seen), ctx) + metadata(sretPtr)["enzyme_type"] = to_md(typetree(Ptr{actualRetType}, ctx, + dl, seen), ctx) push!(wrapper_args, sretPtr) end if returnRoots && !in(1, parmsRemoved) @@ -4133,7 +4135,8 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function metadata(ptr)["enzyme_inactive"] = MDNode(LLVM.Metadata[]) end ctx = LLVM.context(entry_f) - metadata(ptr)["enzyme_type"] = to_md(typetree(Ptr{arg.typ}, ctx, dl, seen), ctx) + metadata(ptr)["enzyme_type"] = to_md(typetree(Ptr{arg.typ}, ctx, dl, seen), + ctx) if LLVM.addrspace(ty) != 0 ptr = addrspacecast!(builder, ptr, ty) end diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index f012400c87..7ec09e2c1d 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -850,7 +850,7 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, idx = 0 dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(orig))))) - Tys2 = (eltype(A) for A in activity[2+isKWCall:end] if A <: Active) + Tys2 = (eltype(A) for A in activity[(2 + isKWCall):end] if A <: Active) seen = TypeTreeTable() for (v, Ty) in zip(actives, Tys2) TT = typetree(Ty, ctx, dl, seen) From 26937757f24c72d69dc9d503014a425294afa6cc Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Tue, 19 Mar 2024 12:39:04 -0400 Subject: [PATCH 12/17] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/typetree.jl | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/src/typetree.jl b/src/typetree.jl index 01b46297d3..50cd399cc0 100644 --- a/src/typetree.jl +++ b/src/typetree.jl @@ -60,10 +60,13 @@ function merge!(dst::TypeTree, src::TypeTree; consume=true) end function to_md(tt::TypeTree, ctx) - return LLVM.Metadata(LLVM.MetadataAsValue(ccall((:EnzymeTypeTreeToMD, API.libEnzyme), LLVM.API.LLVMValueRef, (API.CTypeTreeRef,LLVM.API.LLVMContextRef), tt, ctx))) + return LLVM.Metadata(LLVM.MetadataAsValue(ccall((:EnzymeTypeTreeToMD, API.libEnzyme), + LLVM.API.LLVMValueRef, + (API.CTypeTreeRef, + LLVM.API.LLVMContextRef), tt, ctx))) end -const TypeTreeTable = IdDict{Any, Union{Nothing, TypeTree}} +const TypeTreeTable = IdDict{Any,Union{Nothing,TypeTree}} """ function typetree(T, ctx, dl, seen=TypeTreeTable()) @@ -88,7 +91,7 @@ function typetree(@nospecialize(T), ctx, dl, seen=TypeTreeTable()) return tree::TypeTree end -function typetree_inner(::Type{T}, ctx, dl, seen::TypeTreeTable) where T <: Integer +function typetree_inner(::Type{T}, ctx, dl, seen::TypeTreeTable) where {T<:Integer} return TypeTree(API.DT_Integer, -1, ctx) end @@ -108,7 +111,7 @@ function typetree_inner(::Type{Float64}, ctx, dl, seen::TypeTreeTable) return TypeTree(API.DT_Double, -1, ctx) end -function typetree_inner(::Type{T}, ctx, dl, seen::TypeTreeTable) where T<:AbstractFloat +function typetree_inner(::Type{T}, ctx, dl, seen::TypeTreeTable) where {T<:AbstractFloat} GPUCompiler.@safe_warn "Unknown floating point type" T return TypeTree() end @@ -127,7 +130,7 @@ end function typetree_inner(::Type{Core.SimpleVector}, ctx, dl, seen::TypeTreeTable) tt = TypeTree() - for i in 0:(sizeof(Csize_t)-1) + for i in 0:(sizeof(Csize_t) - 1) merge!(tt, TypeTree(API.DT_Integer, i, ctx)) end return tt @@ -141,14 +144,15 @@ function typetree_inner(::Type{<:AbstractString}, ctx, dl, seen::TypeTreeTable) return TypeTree() end -function typetree_inner(::Type{<:Union{Ptr{T}, Core.LLVMPtr{T}}}, ctx, dl, seen::TypeTreeTable) where T +function typetree_inner(::Type{<:Union{Ptr{T},Core.LLVMPtr{T}}}, ctx, dl, + seen::TypeTreeTable) where {T} tt = copy(typetree(T, ctx, dl, seen)) merge!(tt, TypeTree(API.DT_Pointer, ctx)) only!(tt, -1) return tt end -function typetree_inner(::Type{<:Array{T}}, ctx, dl, seen::TypeTreeTable) where T +function typetree_inner(::Type{<:Array{T}}, ctx, dl, seen::TypeTreeTable) where {T} offset = 0 tt = copy(typetree(T, ctx, dl, seen)) @@ -212,8 +216,8 @@ function typetree_inner(@nospecialize(T), ctx, dl, seen::TypeTreeTable) tt = TypeTree() for f in 1:fieldcount(T) - offset = fieldoffset(T, f) - subT = fieldtype(T, f) + offset = fieldoffset(T, f) + subT = fieldtype(T, f) subtree = copy(typetree(subT, ctx, dl, seen)) if subT isa UnionAll || subT isa Union || subT == Union{} From 2fc75893057ac59dd82f615f1bdcca33a241608d Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Tue, 19 Mar 2024 12:39:23 -0400 Subject: [PATCH 13/17] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/typetree.jl | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/test/typetree.jl b/test/typetree.jl index 333f839a99..19a4e7b6b0 100644 --- a/test/typetree.jl +++ b/test/typetree.jl @@ -22,19 +22,19 @@ struct Composite end struct LList2{T} - next::Union{Nothing, LList2{T}} - v::T + next::Union{Nothing,LList2{T}} + v::T end struct Sibling{T} - a::T - b::T + a::T + b::T end struct Sibling2{T} - a::T - something::Bool - b::T + a::T + something::Bool + b::T end @testset "TypeTree" begin @@ -57,9 +57,13 @@ end @test tt(LList2{Float64}) == "{[8]:Float@double}" @test tt(Sibling{LList2{Float64}}) == "{[-1]:Pointer, [-1,8]:Float@double}" - @test tt(Sibling2{LList2{Float64}}) == "{[0]:Pointer, [0,8]:Float@double, [8]:Integer, [16]:Pointer, [16,8]:Float@double}" - @test tt(Sibling{Tuple{Int, Float64}}) == "{[0]:Integer, [1]:Integer, [2]:Integer, [3]:Integer, [4]:Integer, [5]:Integer, [6]:Integer, [7]:Integer, [8]:Float@double, [16]:Integer, [17]:Integer, [18]:Integer, [19]:Integer, [20]:Integer, [21]:Integer, [22]:Integer, [23]:Integer, [24]:Float@double}" - @test tt(Sibling{LList2{Tuple{Int, Float64}}}) == "{[-1]:Pointer, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Float@double}" - @test tt(Sibling2{Sibling2{LList2{Tuple{Float32, Float64}}}}) == "{[0]:Pointer, [0,8]:Float@float, [0,16]:Float@double, [8]:Integer, [16]:Pointer, [16,8]:Float@float, [16,16]:Float@double, [24]:Integer, [32]:Pointer, [32,8]:Float@float, [32,16]:Float@double, [40]:Integer, [48]:Pointer, [48,8]:Float@float, [48,16]:Float@double}" -"{[0]:Pointer, [0,8]:Float@float, [0,16]:Float@double, [8]:Integer, [16]:Pointer, [16,8]:Float@float, [16,16]:Float@double, [24]:Integer, [32]:Pointer, [32,8]:Float@float, [32,16]:Float@double, [40]:Integer, [48]:Pointer, [48,8]:Float@float, [48,16]:Float@double}" - end + @test tt(Sibling2{LList2{Float64}}) == + "{[0]:Pointer, [0,8]:Float@double, [8]:Integer, [16]:Pointer, [16,8]:Float@double}" + @test tt(Sibling{Tuple{Int,Float64}}) == + "{[0]:Integer, [1]:Integer, [2]:Integer, [3]:Integer, [4]:Integer, [5]:Integer, [6]:Integer, [7]:Integer, [8]:Float@double, [16]:Integer, [17]:Integer, [18]:Integer, [19]:Integer, [20]:Integer, [21]:Integer, [22]:Integer, [23]:Integer, [24]:Float@double}" + @test tt(Sibling{LList2{Tuple{Int,Float64}}}) == + "{[-1]:Pointer, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Float@double}" + @test tt(Sibling2{Sibling2{LList2{Tuple{Float32,Float64}}}}) == + "{[0]:Pointer, [0,8]:Float@float, [0,16]:Float@double, [8]:Integer, [16]:Pointer, [16,8]:Float@float, [16,16]:Float@double, [24]:Integer, [32]:Pointer, [32,8]:Float@float, [32,16]:Float@double, [40]:Integer, [48]:Pointer, [48,8]:Float@float, [48,16]:Float@double}" + "{[0]:Pointer, [0,8]:Float@float, [0,16]:Float@double, [8]:Integer, [16]:Pointer, [16,8]:Float@float, [16,16]:Float@double, [24]:Integer, [32]:Pointer, [32,8]:Float@float, [32,16]:Float@double, [40]:Integer, [48]:Pointer, [48,8]:Float@float, [48,16]:Float@double}" +end From 9f34ed6056db2d53df07c865f0f28a311dddc63c Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Mon, 25 Mar 2024 17:27:58 -0400 Subject: [PATCH 14/17] remove julia_type_tree --- src/rules/typerules.jl | 106 ----------------------------------------- 1 file changed, 106 deletions(-) diff --git a/src/rules/typerules.jl b/src/rules/typerules.jl index 50063d4a64..ef78291d0a 100644 --- a/src/rules/typerules.jl +++ b/src/rules/typerules.jl @@ -117,109 +117,3 @@ function alloc_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.CTypeT return UInt8(false) end -function julia_type_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.CTypeTreeRef}, known_values::Ptr{API.IntList}, numArgs::Csize_t, val::LLVM.API.LLVMValueRef)::UInt8 - inst = LLVM.Instruction(val) - ctx = LLVM.context(inst) - - mi, RT = enzyme_custom_extract_mi(inst) - - ops = collect(operands(inst))[1:end-1] - called = LLVM.called_operand(inst) - - - llRT, sret, returnRoots = get_return_info(RT) - retRemoved, parmsRemoved = removed_ret_parms(inst) - - dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(inst))))) - - - expectLen = (sret !== nothing) + (returnRoots !== nothing) - for source_typ in mi.specTypes.parameters - if isghostty(source_typ) || Core.Compiler.isconstType(source_typ) - continue - end - expectLen+=1 - end - expectLen -= length(parmsRemoved) - - # TODO fix the attributor inlining such that this can assert always true - if expectLen == length(ops) - - f = LLVM.called_operand(inst) - swiftself = any(any(map(k->kind(k)==kind(EnumAttribute("swiftself")), collect(parameter_attributes(f, i)))) for i in 1:length(collect(parameters(f)))) - jlargs = classify_arguments(mi.specTypes, called_type(inst), sret !== nothing, returnRoots !== nothing, swiftself, parmsRemoved) - - seen = TypeTreeTable() - - for arg in jlargs - if arg.cc == GPUCompiler.GHOST || arg.cc == RemovedParam - continue - end - - typ, byref = enzyme_extract_parm_type(f, arg.codegen.i) - @assert typ == arg.typ - - op_idx = arg.codegen.i - rest = typetree(arg.typ, ctx, dl, seen) - @assert arg.cc == byref - if byref == GPUCompiler.BITS_REF || byref == GPUCompiler.MUT_REF - rest = copy(rest) - # adjust first path to size of type since if arg.typ is {[-1]:Int}, that doesn't mean the broader - # object passing this in by ref isnt a {[-1]:Pointer, [-1,-1]:Int} - # aka the next field after this in the bigger object isn't guaranteed to also be the same. - if allocatedinline(arg.typ) - shift!(rest, dl, 0, sizeof(arg.typ), 0) - end - merge!(rest, TypeTree(API.DT_Pointer, ctx)) - only!(rest, -1) - else - # canonicalize wrt size - end - PTT = unsafe_load(args, op_idx) - changed, legal = API.EnzymeCheckedMergeTypeTree(PTT, rest) - if !legal - function c(io) - println(io, "Illegal type analysis update from julia rule of method ", mi) - println(io, "Found type ", arg.typ, " at index ", arg.codegen.i, " of ", string(rest)) - t = API.EnzymeTypeTreeToString(PTT) - println(io, "Prior type ", Base.unsafe_string(t)) - println(io, inst) - API.EnzymeStringFree(t) - end - msg = sprint(c) - - bt = GPUCompiler.backtrace(inst) - ir = sprint(io->show(io, parent_scope(inst))) - - sval = "" - # data = API.EnzymeTypeAnalyzerRef(data) - # ip = API.EnzymeTypeAnalyzerToString(data) - # sval = Base.unsafe_string(ip) - # API.EnzymeStringFree(ip) - throw(IllegalTypeAnalysisException(msg, sval, ir, bt)) - end - end - - if sret !== nothing - idx = 0 - if !in(0, parmsRemoved) - API.EnzymeMergeTypeTree(unsafe_load(args, idx+1), typetree(sret, ctx, dl, seen)) - idx+=1 - end - if returnRoots !== nothing - if !in(1, parmsRemoved) - allpointer = TypeTree(API.DT_Pointer, -1, ctx) - API.EnzymeMergeTypeTree(unsafe_load(args, idx+1), typetree(returnRoots, ctx, dl, seen)) - end - end - end - - end - - if llRT !== nothing && value_type(inst) != LLVM.VoidType() - @assert !retRemoved - API.EnzymeMergeTypeTree(ret, typetree(llRT, ctx, dl, seen)) - end - - return UInt8(false) -end \ No newline at end of file From 275c6e63473ae2f049fe587baf2a6eb786e4f481 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Mon, 25 Mar 2024 22:32:06 -0400 Subject: [PATCH 15/17] Update src/rules/typerules.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/rules/typerules.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/rules/typerules.jl b/src/rules/typerules.jl index ef78291d0a..4730db8654 100644 --- a/src/rules/typerules.jl +++ b/src/rules/typerules.jl @@ -116,4 +116,3 @@ function alloc_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.CTypeT end return UInt8(false) end - From 41673301f6a62ada69d52bdbdbd56a1fba2410b6 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Fri, 29 Mar 2024 12:29:38 -0400 Subject: [PATCH 16/17] Fix 32bit tests --- test/typetree.jl | 34 +++++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/test/typetree.jl b/test/typetree.jl index 19a4e7b6b0..d2801e0038 100644 --- a/test/typetree.jl +++ b/test/typetree.jl @@ -55,15 +55,27 @@ end @test at2.z == 0.0 @test at2.type == 4 - @test tt(LList2{Float64}) == "{[8]:Float@double}" - @test tt(Sibling{LList2{Float64}}) == "{[-1]:Pointer, [-1,8]:Float@double}" - @test tt(Sibling2{LList2{Float64}}) == - "{[0]:Pointer, [0,8]:Float@double, [8]:Integer, [16]:Pointer, [16,8]:Float@double}" - @test tt(Sibling{Tuple{Int,Float64}}) == - "{[0]:Integer, [1]:Integer, [2]:Integer, [3]:Integer, [4]:Integer, [5]:Integer, [6]:Integer, [7]:Integer, [8]:Float@double, [16]:Integer, [17]:Integer, [18]:Integer, [19]:Integer, [20]:Integer, [21]:Integer, [22]:Integer, [23]:Integer, [24]:Float@double}" - @test tt(Sibling{LList2{Tuple{Int,Float64}}}) == - "{[-1]:Pointer, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Float@double}" - @test tt(Sibling2{Sibling2{LList2{Tuple{Float32,Float64}}}}) == - "{[0]:Pointer, [0,8]:Float@float, [0,16]:Float@double, [8]:Integer, [16]:Pointer, [16,8]:Float@float, [16,16]:Float@double, [24]:Integer, [32]:Pointer, [32,8]:Float@float, [32,16]:Float@double, [40]:Integer, [48]:Pointer, [48,8]:Float@float, [48,16]:Float@double}" - "{[0]:Pointer, [0,8]:Float@float, [0,16]:Float@double, [8]:Integer, [16]:Pointer, [16,8]:Float@float, [16,16]:Float@double, [24]:Integer, [32]:Pointer, [32,8]:Float@float, [32,16]:Float@double, [40]:Integer, [48]:Pointer, [48,8]:Float@float, [48,16]:Float@double}" + if Sys.WORD_SIZE == 64 + @test tt(LList2{Float64}) == "{[8]:Float@double}" + @test tt(Sibling{LList2{Float64}}) == "{[-1]:Pointer, [-1,8]:Float@double}" + @test tt(Sibling2{LList2{Float64}}) == + "{[0]:Pointer, [0,8]:Float@double, [8]:Integer, [16]:Pointer, [16,8]:Float@double}" + @test tt(Sibling{Tuple{Int,Float64}}) == + "{[0]:Integer, [1]:Integer, [2]:Integer, [3]:Integer, [4]:Integer, [5]:Integer, [6]:Integer, [7]:Integer, [8]:Float@double, [16]:Integer, [17]:Integer, [18]:Integer, [19]:Integer, [20]:Integer, [21]:Integer, [22]:Integer, [23]:Integer, [24]:Float@double}" + @test tt(Sibling{LList2{Tuple{Int,Float64}}}) == + "{[-1]:Pointer, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Float@double}" + @test tt(Sibling2{Sibling2{LList2{Tuple{Float32,Float64}}}}) == + "{[0]:Pointer, [0,8]:Float@float, [0,16]:Float@double, [8]:Integer, [16]:Pointer, [16,8]:Float@float, [16,16]:Float@double, [24]:Integer, [32]:Pointer, [32,8]:Float@float, [32,16]:Float@double, [40]:Integer, [48]:Pointer, [48,8]:Float@float, [48,16]:Float@double}" + else + @test tt(LList2{Float64}) == "{[4]:Float@double}" + @test tt(Sibling{LList2{Float64}}) == "{[-1]:Pointer, [-1,4]:Float@double}" + @test tt(Sibling2{LList2{Float64}}) == + "{[0]:Pointer, [0,4]:Float@double, [4]:Integer, [8]:Pointer, [8,4]:Float@double}" + @test tt(Sibling{Tuple{Int,Float64}}) == + "{[0]:Integer, [1]:Integer, [2]:Integer, [3]:Integer, [4]:Float@double, [12]:Integer, [13]:Integer, [14]:Integer, [15]:Integer, [16]:Float@double}" + @test tt(Sibling{LList2{Tuple{Int,Float64}}}) == + "{[-1]:Pointer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer, [-1,8]:Float@double}" + @test tt(Sibling2{Sibling2{LList2{Tuple{Float32,Float64}}}}) == + "{[0]:Pointer, [0,4]:Float@float, [0,8]:Float@double, [4]:Integer, [8]:Pointer, [8,4]:Float@float, [8,8]:Float@double, [12]:Integer, [16]:Pointer, [16,4]:Float@float, [16,8]:Float@double, [20]:Integer, [24]:Pointer, [24,4]:Float@float, [24,8]:Float@double}" + end end From 795eaa773b21e65253bef573ff212d02ab97f8b3 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Fri, 29 Mar 2024 12:36:32 -0400 Subject: [PATCH 17/17] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/typetree.jl | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/test/typetree.jl b/test/typetree.jl index d2801e0038..51c284d6e9 100644 --- a/test/typetree.jl +++ b/test/typetree.jl @@ -59,23 +59,23 @@ end @test tt(LList2{Float64}) == "{[8]:Float@double}" @test tt(Sibling{LList2{Float64}}) == "{[-1]:Pointer, [-1,8]:Float@double}" @test tt(Sibling2{LList2{Float64}}) == - "{[0]:Pointer, [0,8]:Float@double, [8]:Integer, [16]:Pointer, [16,8]:Float@double}" + "{[0]:Pointer, [0,8]:Float@double, [8]:Integer, [16]:Pointer, [16,8]:Float@double}" @test tt(Sibling{Tuple{Int,Float64}}) == - "{[0]:Integer, [1]:Integer, [2]:Integer, [3]:Integer, [4]:Integer, [5]:Integer, [6]:Integer, [7]:Integer, [8]:Float@double, [16]:Integer, [17]:Integer, [18]:Integer, [19]:Integer, [20]:Integer, [21]:Integer, [22]:Integer, [23]:Integer, [24]:Float@double}" + "{[0]:Integer, [1]:Integer, [2]:Integer, [3]:Integer, [4]:Integer, [5]:Integer, [6]:Integer, [7]:Integer, [8]:Float@double, [16]:Integer, [17]:Integer, [18]:Integer, [19]:Integer, [20]:Integer, [21]:Integer, [22]:Integer, [23]:Integer, [24]:Float@double}" @test tt(Sibling{LList2{Tuple{Int,Float64}}}) == - "{[-1]:Pointer, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Float@double}" + "{[-1]:Pointer, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Float@double}" @test tt(Sibling2{Sibling2{LList2{Tuple{Float32,Float64}}}}) == - "{[0]:Pointer, [0,8]:Float@float, [0,16]:Float@double, [8]:Integer, [16]:Pointer, [16,8]:Float@float, [16,16]:Float@double, [24]:Integer, [32]:Pointer, [32,8]:Float@float, [32,16]:Float@double, [40]:Integer, [48]:Pointer, [48,8]:Float@float, [48,16]:Float@double}" + "{[0]:Pointer, [0,8]:Float@float, [0,16]:Float@double, [8]:Integer, [16]:Pointer, [16,8]:Float@float, [16,16]:Float@double, [24]:Integer, [32]:Pointer, [32,8]:Float@float, [32,16]:Float@double, [40]:Integer, [48]:Pointer, [48,8]:Float@float, [48,16]:Float@double}" else @test tt(LList2{Float64}) == "{[4]:Float@double}" - @test tt(Sibling{LList2{Float64}}) == "{[-1]:Pointer, [-1,4]:Float@double}" + @test tt(Sibling{LList2{Float64}}) == "{[-1]:Pointer, [-1,4]:Float@double}" @test tt(Sibling2{LList2{Float64}}) == - "{[0]:Pointer, [0,4]:Float@double, [4]:Integer, [8]:Pointer, [8,4]:Float@double}" + "{[0]:Pointer, [0,4]:Float@double, [4]:Integer, [8]:Pointer, [8,4]:Float@double}" @test tt(Sibling{Tuple{Int,Float64}}) == - "{[0]:Integer, [1]:Integer, [2]:Integer, [3]:Integer, [4]:Float@double, [12]:Integer, [13]:Integer, [14]:Integer, [15]:Integer, [16]:Float@double}" + "{[0]:Integer, [1]:Integer, [2]:Integer, [3]:Integer, [4]:Float@double, [12]:Integer, [13]:Integer, [14]:Integer, [15]:Integer, [16]:Float@double}" @test tt(Sibling{LList2{Tuple{Int,Float64}}}) == - "{[-1]:Pointer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer, [-1,8]:Float@double}" + "{[-1]:Pointer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer, [-1,8]:Float@double}" @test tt(Sibling2{Sibling2{LList2{Tuple{Float32,Float64}}}}) == - "{[0]:Pointer, [0,4]:Float@float, [0,8]:Float@double, [4]:Integer, [8]:Pointer, [8,4]:Float@float, [8,8]:Float@double, [12]:Integer, [16]:Pointer, [16,4]:Float@float, [16,8]:Float@double, [20]:Integer, [24]:Pointer, [24,4]:Float@float, [24,8]:Float@double}" + "{[0]:Pointer, [0,4]:Float@float, [0,8]:Float@double, [4]:Integer, [8]:Pointer, [8,4]:Float@float, [8,8]:Float@double, [12]:Integer, [16]:Pointer, [16,4]:Float@float, [16,8]:Float@double, [20]:Integer, [24]:Pointer, [24,4]:Float@float, [24,8]:Float@double}" end end