From 78eca92056d2f0f73a7f627e0424610ded2cea91 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 8 Nov 2022 10:13:34 -0800 Subject: [PATCH] Fix rev mode shadow alloc bundles (#538) * Fix rev mode shadow alloc bundles * fix --- src/compiler.jl | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 27d69981a1..bf9f07c8ce 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -336,7 +336,7 @@ end function array_inner(::Type{<:Array{T}}) where T return T end -function array_shadow_handler(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, numArgs::Csize_t, Args::Ptr{LLVM.API.LLVMValueRef})::LLVM.API.LLVMValueRef +function array_shadow_handler(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMValueRef, numArgs::Csize_t, Args::Ptr{LLVM.API.LLVMValueRef}, gutils::API.EnzymeGradientUtilsRef)::LLVM.API.LLVMValueRef inst = LLVM.Instruction(OrigCI) mod = LLVM.parent(LLVM.parent(LLVM.parent(inst))) ctx = LLVM.context(LLVM.Value(OrigCI)) @@ -351,11 +351,13 @@ function array_shadow_handler(B::LLVM.API.LLVMBuilderRef, OrigCI::LLVM.API.LLVMV b = LLVM.Builder(B) vals = LLVM.Value[] + valTys = API.CValueType[] for i = 1:numArgs + push!(valTys, API.VT_Primal) push!(vals, LLVM.Value(unsafe_load(Args, i))) end - anti = LLVM.call!(b, LLVM.Value(LLVM.API.LLVMGetCalledValue(OrigCI)), vals) + anti = LLVM.Value(API.EnzymeGradientUtilsCallWithInvertedBundles(gutils, LLVM.Value(LLVM.API.LLVMGetCalledValue(OrigCI)), vals, length(vals), OrigCI, valTys, length(valTys), b, #=lookup=#false )) prod = LLVM.Value(unsafe_load(Args, 2)) for i = 3:numArgs @@ -4105,17 +4107,17 @@ function __init__() end register_alloc_handler!( ("jl_alloc_array_1d", "ijl_alloc_array_1d"), - @cfunction(array_shadow_handler, LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, Csize_t, Ptr{LLVM.API.LLVMValueRef})), + @cfunction(array_shadow_handler, LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, Csize_t, Ptr{LLVM.API.LLVMValueRef}, API.EnzymeGradientUtilsRef)), @cfunction(null_free_handler, LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef)) ) register_alloc_handler!( ("jl_alloc_array_2d", "ijl_alloc_array_2d"), - @cfunction(array_shadow_handler, LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, Csize_t, Ptr{LLVM.API.LLVMValueRef})), + @cfunction(array_shadow_handler, LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, Csize_t, Ptr{LLVM.API.LLVMValueRef}, API.EnzymeGradientUtilsRef)), @cfunction(null_free_handler, LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef)) ) register_alloc_handler!( ("jl_alloc_array_3d", "ijl_alloc_array_3d"), - @cfunction(array_shadow_handler, LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, Csize_t, Ptr{LLVM.API.LLVMValueRef})), + @cfunction(array_shadow_handler, LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, Csize_t, Ptr{LLVM.API.LLVMValueRef}, API.EnzymeGradientUtilsRef)), @cfunction(null_free_handler, LLVM.API.LLVMValueRef, (LLVM.API.LLVMBuilderRef, LLVM.API.LLVMValueRef, LLVM.API.LLVMValueRef)) ) register_handler!(