Skip to content

Commit

Permalink
Setindex fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Dec 30, 2023
1 parent 7f25f04 commit 86468df
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 5 deletions.
30 changes: 26 additions & 4 deletions src/compiler/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,28 @@ const libjulia = Ref{Ptr{Cvoid}}(C_NULL)
# List of methods to location of arg which is the mi/function, then start of args
const generic_method_offsets = Dict{String, Tuple{Int,Int}}(("jl_f__apply_latest" => (2,3), "ijl_f__apply_latest" => (2,3), "jl_f__call_latest" => (2,3), "ijl_f__call_latest" => (2,3), "jl_f_invoke" => (2,3), "jl_invoke" => (1,3), "jl_apply_generic" => (1,2), "ijl_f_invoke" => (2,3), "ijl_invoke" => (1,3), "ijl_apply_generic" => (1,2)))

@inline function has_method(sig, world::UInt, mt::Union{Nothing,Core.MethodTable})
return ccall(:jl_gf_invoke_lookup, Any, (Any, Any, UInt), sig, mt, world) !== nothing
end

@inline function has_method(sig, world::UInt, mt::Core.Compiler.InternalMethodTable)
return has_method(sig, mt.world, nothing)
end

@inline function has_method(sig, world::UInt, mt::Core.Compiler.OverlayMethodTable)
return has_method(sig, mt.mt, mt.world) || has_method(sig, nothing, mt.world)
end

@inline function is_inactive(tys, world::UInt, mt)
if has_method(Tuple{typeof(EnzymeRules.inactive), tys...}, world, mt)
return true
end
if has_method(Tuple{typeof(EnzymeRules.inactive_noinl), tys...}, world, mt)
return true
end
return false
end

import GPUCompiler: DYNAMIC_CALL, DELAYED_BINDING, RUNTIME_FUNCTION, UNKNOWN_FUNCTION, POINTER_FUNCTION
import GPUCompiler: backtrace, isintrinsic
function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls)
Expand Down Expand Up @@ -449,7 +471,7 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls)
rep = reinterpret(Ptr{Cvoid}, convert(Csize_t, funclib))
funclib = Base.unsafe_pointer_to_objref(rep)
tys = [typeof(funclib), Vararg{Any}]
if EnzymeRules.is_inactive_from_sig(Tuple{tys...}; world, method_table) || EnzymeRules.is_inactive_noinl_from_sig(Tuple{tys...}; world, method_table)
if is_inactive(tys, world, method_table)
inactive = LLVM.StringAttribute("enzyme_inactive", "")
LLVM.API.LLVMAddCallSiteAttribute(inst, LLVM.API.LLVMAttributeFunctionIndex, inactive)
nofree = LLVM.EnumAttribute("nofree")
Expand Down Expand Up @@ -486,7 +508,7 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls)
end
tys = flib.specTypes.parameters
end
if EnzymeRules.is_inactive_from_sig(Tuple{tys...}; world, method_table) || EnzymeRules.is_inactive_noinl_from_sig(Tuple{tys...}; world, method_table)
if is_inactive(tys, world, method_table)
inactive = LLVM.StringAttribute("enzyme_inactive", "")
LLVM.API.LLVMAddCallSiteAttribute(inst, LLVM.API.LLVMAttributeFunctionIndex, inactive)
nofree = LLVM.EnumAttribute("nofree")
Expand Down Expand Up @@ -558,7 +580,7 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst, calls)
end
tys = flib.specTypes.parameters
end
if EnzymeRules.is_inactive_from_sig(Tuple{tys...}; world, method_table) || EnzymeRules.is_inactive_noinl_from_sig(Tuple{tys...}; world, method_table)
if is_inactive(tys, world, method_table)
ofn = LLVM.parent(LLVM.parent(inst))
mod = LLVM.parent(ofn)
inactive = LLVM.StringAttribute("enzyme_inactive", "")
Expand Down Expand Up @@ -725,4 +747,4 @@ function rewrite_union_returns_as_ref(enzymefn::LLVM.Function, off, world, width
@show cur, off
@assert false
end
end
end
2 changes: 1 addition & 1 deletion src/rules/jitrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ function body_runtime_generic_augfwd(N, Width, wrapped, primttypes)

return quote
args = ($(wrapped...),)

# TODO: Annotation of return value
# tt0 = Tuple{$(primtypes...)}
tt′ = Tuple{$(Types...)}
Expand Down
23 changes: 23 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,29 @@ end
@test dweights[1] 1.
end

function Valuation1(z,Ls1)
@inbounds Ls1[1] = sum(Base.inferencebarrier(z))
return nothing
end
@testset "Active setindex!" begin
v=ones(5)
dv=zero(v)

DV1=Float32[0]
DV2=Float32[1]

Enzyme.autodiff(Reverse,Valuation1,Duplicated(v,dv),Duplicated(DV1,DV2))
@test dv[1] 1.

DV1=Float32[0]
DV2=Float32[1]
v=ones(5)
dv=zero(v)
dv[1] = 1.
Enzyme.autodiff(Forward,Valuation1,Duplicated(v,dv),Duplicated(DV1,DV2))
@test DV2[1] 1.
end

@testset "Null init union" begin
@noinline function unionret(itr, cond)
if cond
Expand Down

0 comments on commit 86468df

Please sign in to comment.