diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index 048329de1e..f461a3b0c5 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -215,6 +215,28 @@ const libjulia = Ref{Ptr{Cvoid}}(C_NULL) # List of methods to location of arg which is the mi/function, then start of args const generic_method_offsets = Dict{String, Tuple{Int,Int}}(("jl_f__apply_latest" => (2,3), "ijl_f__apply_latest" => (2,3), "jl_f__call_latest" => (2,3), "ijl_f__call_latest" => (2,3), "jl_f_invoke" => (2,3), "jl_invoke" => (1,3), "jl_apply_generic" => (1,2), "ijl_f_invoke" => (2,3), "ijl_invoke" => (1,3), "ijl_apply_generic" => (1,2))) +@inline function has_method(sig, world::UInt, mt::Union{Nothing,Core.MethodTable}) + return ccall(:jl_gf_invoke_lookup, Any, (Any, Any, UInt), sig, mt, world) !== nothing +end + +@inline function has_method(sig, world::UInt, mt::Core.Compiler.InternalMethodTable) + return has_method(sig, mt.world, nothing) +end + +@inline function has_method(sig, world::UInt, mt::Core.Compiler.OverlayMethodTable) + return has_method(sig, mt.mt, mt.world) || has_method(sig, nothing, mt.world) +end + +@inline function is_inactive(tys, world::UInt, mt) + if has_method(Tuple{typeof(EnzymeRules.inactive), tys...}, world, mt) + return true + end + if has_method(Tuple{typeof(EnzymeRules.inactive_noinl), tys...}, world, mt) + return true + end + return false +end + import GPUCompiler: DYNAMIC_CALL, DELAYED_BINDING, RUNTIME_FUNCTION, UNKNOWN_FUNCTION, POINTER_FUNCTION import GPUCompiler: backtrace, isintrinsic function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) @@ -449,7 +471,7 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) rep = reinterpret(Ptr{Cvoid}, convert(Csize_t, funclib)) funclib = Base.unsafe_pointer_to_objref(rep) tys = [typeof(funclib), Vararg{Any}] - if EnzymeRules.is_inactive_from_sig(Tuple{tys...}; world, method_table) || EnzymeRules.is_inactive_noinl_from_sig(Tuple{tys...}; world, method_table) + if is_inactive(tys, world, method_table) inactive = LLVM.StringAttribute("enzyme_inactive", "") LLVM.API.LLVMAddCallSiteAttribute(inst, LLVM.API.LLVMAttributeFunctionIndex, inactive) nofree = LLVM.EnumAttribute("nofree") @@ -486,7 +508,7 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) end tys = flib.specTypes.parameters end - if EnzymeRules.is_inactive_from_sig(Tuple{tys...}; world, method_table) || EnzymeRules.is_inactive_noinl_from_sig(Tuple{tys...}; world, method_table) + if is_inactive(tys, world, method_table) inactive = LLVM.StringAttribute("enzyme_inactive", "") LLVM.API.LLVMAddCallSiteAttribute(inst, LLVM.API.LLVMAttributeFunctionIndex, inactive) nofree = LLVM.EnumAttribute("nofree") @@ -558,7 +580,7 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls) end tys = flib.specTypes.parameters end - if EnzymeRules.is_inactive_from_sig(Tuple{tys...}; world, method_table) || EnzymeRules.is_inactive_noinl_from_sig(Tuple{tys...}; world, method_table) + if is_inactive(tys, world, method_table) ofn = LLVM.parent(LLVM.parent(inst)) mod = LLVM.parent(ofn) inactive = LLVM.StringAttribute("enzyme_inactive", "") @@ -725,4 +747,4 @@ function rewrite_union_returns_as_ref(enzymefn::LLVM.Function, off, world, width @show cur, off @assert false end -end \ No newline at end of file +end diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index 261aaf1f34..95365f309b 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -157,7 +157,7 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes) return quote args = ($(wrapped...),) - + # TODO: Annotation of return value # tt0 = Tuple{$(primtypes...)} tt′ = Tuple{$(Types...)} diff --git a/test/runtests.jl b/test/runtests.jl index c1d7d49b72..911be62d9d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -720,6 +720,29 @@ end @test dweights[1] ≈ 1. end +function Valuation1(z,Ls1) + @inbounds Ls1[1] = sum(Base.inferencebarrier(z)) + return nothing +end +@testset "Active setindex!" begin + v=ones(5) + dv=zero(v) + + DV1=Float32[0] + DV2=Float32[1] + + Enzyme.autodiff(Reverse,Valuation1,Duplicated(v,dv),Duplicated(DV1,DV2)) + @test dv[1] ≈ 1. + + DV1=Float32[0] + DV2=Float32[1] + v=ones(5) + dv=zero(v) + dv[1] = 1. + Enzyme.autodiff(Forward,Valuation1,Duplicated(v,dv),Duplicated(DV1,DV2)) + @test DV2[1] ≈ 1. +end + @testset "Null init union" begin @noinline function unionret(itr, cond) if cond