Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Dec 1, 2024
1 parent 62de00b commit 94862cf
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 84 deletions.
9 changes: 8 additions & 1 deletion src/absint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -737,14 +737,21 @@ function abs_typeof(
return (false, nothing, nothing)
end

@inline function is_zero(@nospecialize(x::LLVM.Value))::Bool
if x isa LLVM.ConstantInt
return convert(UInt, x) == 0
end
return false
end

function abs_cstring(@nospecialize(arg::LLVM.Value))::Tuple{Bool,String}
if isa(arg, ConstantExpr)
ce = arg
while isa(ce, ConstantExpr)
if opcode(ce) == LLVM.API.LLVMAddrSpaceCast || opcode(ce) == LLVM.API.LLVMBitCast || opcode(ce) == LLVM.API.LLVMIntToPtr
ce = operands(ce)[1]
elseif opcode(ce) == LLVM.API.LLVMGetElementPtr
if all(x -> x isa LLVM.ConstantInt && convert(UInt, x) == 0, operands(ce)[2:end])
if all(is_zero, operands(ce)[2:end])
ce = operands(ce)[1]
else
break
Expand Down
28 changes: 19 additions & 9 deletions src/compiler/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -329,14 +329,13 @@ function get_pgcstack(func::LLVM.Function)
end

function reinsert_gcmarker!(func::LLVM.Function, @nospecialize(PB::Union{Nothing, LLVM.IRBuilder}) = nothing)
for (i, v) in enumerate(parameters(func))
if any(
map(
k -> kind(k) == kind(EnumAttribute("swiftself")),
collect(parameter_attributes(func, i)),
),
)
return v
for i in 1:length(LLVM.parameters(fn))
for attr in collect(LLVM.parameter_attributes(fn, i))
if attr isa LLVM.EnumAttribute
if kind(attr) == swiftself_kind
return parameters(fn)[i]
end
end
end
end

Expand Down Expand Up @@ -366,7 +365,7 @@ end
const swiftself_kind = enum_attr_kind("swiftself")

function has_swiftself(fn::LLVM.Function)::Bool
for i in 1:size(LLVM.parameters(fn))
for i in 1:length(LLVM.parameters(fn))
for attr in collect(LLVM.parameter_attributes(fn, i))
if attr isa LLVM.EnumAttribute
if kind(attr) == swiftself_kind
Expand All @@ -388,6 +387,17 @@ function has_fn_attr(fn::LLVM.Function, attr::LLVM.EnumAttribute)::Bool
end
return false
end
function has_fn_attr(fn::LLVM.Function, attr::LLVM.StringAttribute)::Bool
ekind = LLVM.kind(attr)
for attr in collect(function_attributes(fn))
if attr isa LLVM.StringAttribute
if kind(attr) == ekind
return true
end
end
end
return false
end

function eraseInst(bb::LLVM.BasicBlock, @nospecialize(inst::LLVM.Instruction))
@static if isdefined(LLVM, Symbol("erase!"))
Expand Down
37 changes: 5 additions & 32 deletions src/rules/customrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -542,14 +542,7 @@ end

push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0))

swiftself = any(
any(
map(
k -> kind(k) == kind(EnumAttribute("swiftself")),
collect(parameter_attributes(llvmf, i)),
),
) for i = 1:length(collect(parameters(llvmf)))
)
swiftself = has_swiftself(llvmf)
if swiftself
pushfirst!(reinsert_gcmarker!(fn, B))
end
Expand Down Expand Up @@ -596,12 +589,7 @@ end
debug_from_orig!(gutils, res, orig)
callconv!(res, callconv(llvmf))

hasNoRet = any(
map(
k -> kind(k) == kind(EnumAttribute("noreturn")),
collect(function_attributes(llvmf)),
),
)
hasNoRet = has_fn_attribute(llvmf, EnumAttribute("noreturn"))

if hasNoRet
return false
Expand Down Expand Up @@ -1083,14 +1071,7 @@ function enzyme_custom_common_rev(
# llvmf = nested_codegen!(mode, mod, rev_func, Tuple{argTys...}, world)
# end

swiftself = any(
any(
map(
k -> kind(k) == kind(EnumAttribute("swiftself")),
collect(parameter_attributes(llvmf, i)),
),
) for i = 1:length(collect(parameters(llvmf)))
)
swiftself = has_swiftself(llvmf)

miRT = enzyme_custom_extract_mi(llvmf)[2]
_, sret, returnRoots = get_return_info(miRT)
Expand Down Expand Up @@ -1302,12 +1283,7 @@ function enzyme_custom_common_rev(
debug_from_orig!(gutils, res, orig)
callconv!(res, callconv(llvmf))

hasNoRet = any(
map(
k -> kind(k) == kind(EnumAttribute("noreturn")),
collect(function_attributes(llvmf)),
),
)
hasNoRet = has_fn_attr(llvmf, EnumAttribute("noreturn"))

if hasNoRet
return tapeV
Expand Down Expand Up @@ -1599,10 +1575,7 @@ end
fop = called_operand(orig)::LLVM.Function
for (i, v) in enumerate(operands(orig)[1:end-1])
if v == val
if !any(
a -> kind(a) == kind(StringAttribute("enzymejl_returnRoots")),
collect(parameter_attributes(fop, i)),
)
if !has_fn_attr(fop, StringAttribute("enzymejl_returnRoots"))
non_rooting_use = true
break
end
Expand Down
42 changes: 6 additions & 36 deletions src/rules/llvmrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,12 +160,7 @@ include("parallelrules.jl")
if in(name, ("ijl_f_finalizer", "jl_f_finalizer"))
return common_finalizer_fwd(2, B, orig, gutils, normalR, shadowR)
end
if any(
map(
k -> kind(k) == kind(StringAttribute("enzyme_inactive")),
collect(function_attributes(F)),
),
)
if has_fn_attr(F, StringAttribute("enzyme_inactive"))
return true
end
end
Expand Down Expand Up @@ -234,12 +229,7 @@ end
if in(name, ("ijl_f_finalizer", "jl_f_finalizer"))
return common_finalizer_augfwd(2, B, orig, gutils, normalR, shadowR, tapeR)
end
if any(
map(
k -> kind(k) == kind(StringAttribute("enzyme_inactive")),
collect(function_attributes(F)),
),
)
if has_fn_attribute(F, StringAttribute("enzyme_inactive"))
return true
end
end
Expand Down Expand Up @@ -317,12 +307,7 @@ end
common_finalizer_rev(2, B, orig, gutils, tape)
return nothing
end
if any(
map(
k -> kind(k) == kind(StringAttribute("enzyme_inactive")),
collect(function_attributes(F)),
),
)
if has_fn_attribute(F, StringAttribute("enzyme_inactive"))
return nothing
end
end
Expand All @@ -343,12 +328,7 @@ end
if in(name, ("ijl_invoke", "jl_invoke"))
return common_invoke_fwd(2, B, orig, gutils, normalR, shadowR)
end
if any(
map(
k -> kind(k) == kind(StringAttribute("enzyme_inactive")),
collect(function_attributes(F)),
),
)
if has_fn_attribute(F, StringAttribute("enzyme_inactive"))
return true
end
end
Expand All @@ -365,12 +345,7 @@ end
if in(name, ("ijl_invoke", "jl_invoke"))
return common_invoke_augfwd(2, B, orig, gutils, normalR, shadowR, tapeR)
end
if any(
map(
k -> kind(k) == kind(StringAttribute("enzyme_inactive")),
collect(function_attributes(F)),
),
)
if has_fn_attribute(F, StringAttribute("enzyme_inactive"))
return true
end
end
Expand All @@ -388,12 +363,7 @@ end
common_invoke_rev(2, B, orig, gutils, tape)
return nothing
end
if any(
map(
k -> kind(k) == kind(StringAttribute("enzyme_inactive")),
collect(function_attributes(F)),
),
)
if has_fn_attribute(F, StringAttribute("enzyme_inactive"))
return nothing
end
end
Expand Down
7 changes: 1 addition & 6 deletions src/sugar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,7 @@ end
end
end
end
hasNoRet = any(
map(
k -> kind(k) == kind(LLVM.EnumAttribute("noreturn")),
collect(LLVM.function_attributes(copysetfn)),
),
)
hasNoRet = has_fn_attr(Compiler.copysetfn, LLVM.EnumAttribute("noreturn"))
@assert !hasNoRet
if !hasNoRet
push!(LLVM.function_attributes(copysetfn), LLVM.EnumAttribute("alwaysinline", 0))
Expand Down

0 comments on commit 94862cf

Please sign in to comment.