Skip to content

Commit

Permalink
Fix rev mode shadow alloc bundles (#538)
Browse files Browse the repository at this point in the history
* Fix rev mode shadow alloc bundles

* fix
  • Loading branch information
wsmoses authored Nov 8, 2022
1 parent b29b7ac commit 78eca92
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ end
function array_inner(::Type{<:Array{T}}) where T
return T
end
function array_shadow_handler(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, numArgs::Csize_t, Args::Ptr{LLVM.API.LLVMValueRef})::LLVM.API.LLVMValueRef
function array_shadow_handler(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, numArgs::Csize_t, Args::Ptr{LLVM.API.LLVMValueRef}, gutils::API.EnzymeGradientUtilsRef)::LLVM.API.LLVMValueRef
inst = LLVM.Instruction(OrigCI)
mod = LLVM.parent(LLVM.parent(LLVM.parent(inst)))
ctx = LLVM.context(LLVM.Value(OrigCI))
Expand All @@ -351,11 +351,13 @@ function array_shadow_handler(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMV
b = LLVM.Builder(B)

vals = LLVM.Value[]
valTys = API.CValueType[]
for i = 1:numArgs
push!(valTys, API.VT_Primal)
push!(vals, LLVM.Value(unsafe_load(Args, i)))
end

anti = LLVM.call!(b, LLVM.Value(LLVM.API.LLVMGetCalledValue(OrigCI)), vals)
anti = LLVM.Value(API.EnzymeGradientUtilsCallWithInvertedBundles(gutils, LLVM.Value(LLVM.API.LLVMGetCalledValue(OrigCI)), vals, length(vals), OrigCI, valTys, length(valTys), b, #=lookup=#false ))

prod = LLVM.Value(unsafe_load(Args, 2))
for i = 3:numArgs
Expand Down Expand Up @@ -4105,17 +4107,17 @@ function __init__()
end
register_alloc_handler!(
("jl_alloc_array_1d", "ijl_alloc_array_1d"),
@cfunction(array_shadow_handler, LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, Csize_t, Ptr{LLVM.API.LLVMValueRef})),
@cfunction(array_shadow_handler, LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, Csize_t, Ptr{LLVM.API.LLVMValueRef}, API.EnzymeGradientUtilsRef)),
@cfunction(null_free_handler, LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef))
)
register_alloc_handler!(
("jl_alloc_array_2d", "ijl_alloc_array_2d"),
@cfunction(array_shadow_handler, LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, Csize_t, Ptr{LLVM.API.LLVMValueRef})),
@cfunction(array_shadow_handler, LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, Csize_t, Ptr{LLVM.API.LLVMValueRef}, API.EnzymeGradientUtilsRef)),
@cfunction(null_free_handler, LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef))
)
register_alloc_handler!(
("jl_alloc_array_3d", "ijl_alloc_array_3d"),
@cfunction(array_shadow_handler, LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, Csize_t, Ptr{LLVM.API.LLVMValueRef})),
@cfunction(array_shadow_handler, LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, Csize_t, Ptr{LLVM.API.LLVMValueRef}, API.EnzymeGradientUtilsRef)),
@cfunction(null_free_handler, LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef))
)
register_handler!(
Expand Down

0 comments on commit 78eca92

Please sign in to comment.