From 94862cf3aeefa5887bcdef59c779eba8a7719629 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sat, 30 Nov 2024 19:05:05 -0500 Subject: [PATCH] cleanup --- src/absint.jl | 9 ++++++++- src/compiler/utils.jl | 28 ++++++++++++++++++--------- src/rules/customrules.jl | 37 +++++------------------------------ src/rules/llvmrules.jl | 42 ++++++---------------------------------- src/sugar.jl | 7 +------ 5 files changed, 39 insertions(+), 84 deletions(-) diff --git a/src/absint.jl b/src/absint.jl index 5a72fa0873..e7c8fc6bf4 100644 --- a/src/absint.jl +++ b/src/absint.jl @@ -737,6 +737,13 @@ function abs_typeof( return (false, nothing, nothing) end +@inline function is_zero(@nospecialize(x::LLVM.Value))::Bool + if x isa LLVM.ConstantInt + return convert(UInt, x) == 0 + end + return false +end + function abs_cstring(@nospecialize(arg::LLVM.Value))::Tuple{Bool,String} if isa(arg, ConstantExpr) ce = arg @@ -744,7 +751,7 @@ function abs_cstring(@nospecialize(arg::LLVM.Value))::Tuple{Bool,String} if opcode(ce) == LLVM.API.LLVMAddrSpaceCast || opcode(ce) == LLVM.API.LLVMBitCast || opcode(ce) == LLVM.API.LLVMIntToPtr ce = operands(ce)[1] elseif opcode(ce) == LLVM.API.LLVMGetElementPtr - if all(x -> x isa LLVM.ConstantInt && convert(UInt, x) == 0, operands(ce)[2:end]) + if all(is_zero, operands(ce)[2:end]) ce = operands(ce)[1] else break diff --git a/src/compiler/utils.jl b/src/compiler/utils.jl index 402fef5780..340499006f 100644 --- a/src/compiler/utils.jl +++ b/src/compiler/utils.jl @@ -329,14 +329,13 @@ function get_pgcstack(func::LLVM.Function) end function reinsert_gcmarker!(func::LLVM.Function, @nospecialize(PB::Union{Nothing, LLVM.IRBuilder}) = nothing) - for (i, v) in enumerate(parameters(func)) - if any( - map( - k -> kind(k) == kind(EnumAttribute("swiftself")), - collect(parameter_attributes(func, i)), - ), - ) - return v + for i in 1:length(LLVM.parameters(fn)) + for attr in collect(LLVM.parameter_attributes(fn, i)) + if attr isa LLVM.EnumAttribute + if kind(attr) == swiftself_kind + return parameters(fn)[i] + end + end end end @@ -366,7 +365,7 @@ end const swiftself_kind = enum_attr_kind("swiftself") function has_swiftself(fn::LLVM.Function)::Bool - for i in 1:size(LLVM.parameters(fn)) + for i in 1:length(LLVM.parameters(fn)) for attr in collect(LLVM.parameter_attributes(fn, i)) if attr isa LLVM.EnumAttribute if kind(attr) == swiftself_kind @@ -388,6 +387,17 @@ function has_fn_attr(fn::LLVM.Function, attr::LLVM.EnumAttribute)::Bool end return false end +function has_fn_attr(fn::LLVM.Function, attr::LLVM.StringAttribute)::Bool + ekind = LLVM.kind(attr) + for attr in collect(function_attributes(fn)) + if attr isa LLVM.StringAttribute + if kind(attr) == ekind + return true + end + end + end + return false +end function eraseInst(bb::LLVM.BasicBlock, @nospecialize(inst::LLVM.Instruction)) @static if isdefined(LLVM, Symbol("erase!")) diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index 8a397ac2e8..2bc3a5eace 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -542,14 +542,7 @@ end push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0)) - swiftself = any( - any( - map( - k -> kind(k) == kind(EnumAttribute("swiftself")), - collect(parameter_attributes(llvmf, i)), - ), - ) for i = 1:length(collect(parameters(llvmf))) - ) + swiftself = has_swiftself(llvmf) if swiftself pushfirst!(reinsert_gcmarker!(fn, B)) end @@ -596,12 +589,7 @@ end debug_from_orig!(gutils, res, orig) callconv!(res, callconv(llvmf)) - hasNoRet = any( - map( - k -> kind(k) == kind(EnumAttribute("noreturn")), - collect(function_attributes(llvmf)), - ), - ) + hasNoRet = has_fn_attribute(llvmf, EnumAttribute("noreturn")) if hasNoRet return false @@ -1083,14 +1071,7 @@ function enzyme_custom_common_rev( # llvmf = nested_codegen!(mode, mod, rev_func, Tuple{argTys...}, world) # end - swiftself = any( - any( - map( - k -> kind(k) == kind(EnumAttribute("swiftself")), - collect(parameter_attributes(llvmf, i)), - ), - ) for i = 1:length(collect(parameters(llvmf))) - ) + swiftself = has_swiftself(llvmf) miRT = enzyme_custom_extract_mi(llvmf)[2] _, sret, returnRoots = get_return_info(miRT) @@ -1302,12 +1283,7 @@ function enzyme_custom_common_rev( debug_from_orig!(gutils, res, orig) callconv!(res, callconv(llvmf)) - hasNoRet = any( - map( - k -> kind(k) == kind(EnumAttribute("noreturn")), - collect(function_attributes(llvmf)), - ), - ) + hasNoRet = has_fn_attr(llvmf, EnumAttribute("noreturn")) if hasNoRet return tapeV @@ -1599,10 +1575,7 @@ end fop = called_operand(orig)::LLVM.Function for (i, v) in enumerate(operands(orig)[1:end-1]) if v == val - if !any( - a -> kind(a) == kind(StringAttribute("enzymejl_returnRoots")), - collect(parameter_attributes(fop, i)), - ) + if !has_fn_attr(fop, StringAttribute("enzymejl_returnRoots")) non_rooting_use = true break end diff --git a/src/rules/llvmrules.jl b/src/rules/llvmrules.jl index 35399d0b24..b4f396d7de 100644 --- a/src/rules/llvmrules.jl +++ b/src/rules/llvmrules.jl @@ -160,12 +160,7 @@ include("parallelrules.jl") if in(name, ("ijl_f_finalizer", "jl_f_finalizer")) return common_finalizer_fwd(2, B, orig, gutils, normalR, shadowR) end - if any( - map( - k -> kind(k) == kind(StringAttribute("enzyme_inactive")), - collect(function_attributes(F)), - ), - ) + if has_fn_attr(F, StringAttribute("enzyme_inactive")) return true end end @@ -234,12 +229,7 @@ end if in(name, ("ijl_f_finalizer", "jl_f_finalizer")) return common_finalizer_augfwd(2, B, orig, gutils, normalR, shadowR, tapeR) end - if any( - map( - k -> kind(k) == kind(StringAttribute("enzyme_inactive")), - collect(function_attributes(F)), - ), - ) + if has_fn_attribute(F, StringAttribute("enzyme_inactive")) return true end end @@ -317,12 +307,7 @@ end common_finalizer_rev(2, B, orig, gutils, tape) return nothing end - if any( - map( - k -> kind(k) == kind(StringAttribute("enzyme_inactive")), - collect(function_attributes(F)), - ), - ) + if has_fn_attribute(F, StringAttribute("enzyme_inactive")) return nothing end end @@ -343,12 +328,7 @@ end if in(name, ("ijl_invoke", "jl_invoke")) return common_invoke_fwd(2, B, orig, gutils, normalR, shadowR) end - if any( - map( - k -> kind(k) == kind(StringAttribute("enzyme_inactive")), - collect(function_attributes(F)), - ), - ) + if has_fn_attribute(F, StringAttribute("enzyme_inactive")) return true end end @@ -365,12 +345,7 @@ end if in(name, ("ijl_invoke", "jl_invoke")) return common_invoke_augfwd(2, B, orig, gutils, normalR, shadowR, tapeR) end - if any( - map( - k -> kind(k) == kind(StringAttribute("enzyme_inactive")), - collect(function_attributes(F)), - ), - ) + if has_fn_attribute(F, StringAttribute("enzyme_inactive")) return true end end @@ -388,12 +363,7 @@ end common_invoke_rev(2, B, orig, gutils, tape) return nothing end - if any( - map( - k -> kind(k) == kind(StringAttribute("enzyme_inactive")), - collect(function_attributes(F)), - ), - ) + if has_fn_attribute(F, StringAttribute("enzyme_inactive")) return nothing end end diff --git a/src/sugar.jl b/src/sugar.jl index ba53f46a00..5ae59dd32d 100644 --- a/src/sugar.jl +++ b/src/sugar.jl @@ -40,12 +40,7 @@ end end end end - hasNoRet = any( - map( - k -> kind(k) == kind(LLVM.EnumAttribute("noreturn")), - collect(LLVM.function_attributes(copysetfn)), - ), - ) + hasNoRet = has_fn_attr(Compiler.copysetfn, LLVM.EnumAttribute("noreturn")) @assert !hasNoRet if !hasNoRet push!(LLVM.function_attributes(copysetfn), LLVM.EnumAttribute("alwaysinline", 0))