Skip to content

Commit

Permalink
Fix unused (#2228)
Browse files Browse the repository at this point in the history
* Fix unused

* Update Project.toml
  • Loading branch information
wsmoses authored Dec 27, 2024
1 parent 9de3274 commit 02d4283
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5"
CEnum = "0.4, 0.5"
ChainRulesCore = "1"
EnzymeCore = "0.8.8"
Enzyme_jll = "0.0.170"
Enzyme_jll = "0.0.171"
GPUArraysCore = "0.1.6, 0.2"
GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 1"
LLVM = "6.1, 7, 8, 9"
Expand Down
7 changes: 4 additions & 3 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,8 @@ function create_recursive_stores(B::LLVM.IRBuilder, @nospecialize(Ty::DataType),
end
end

function shadow_alloc_rewrite(V::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradientUtilsRef, Orig::LLVM.API.LLVMValueRef, idx::UInt64, prev::API.LLVMValueRef)
function shadow_alloc_rewrite(V::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradientUtilsRef, Orig::LLVM.API.LLVMValueRef, idx::UInt64, prev::API.LLVMValueRef, used::UInt8)
used = used != 0
V = LLVM.CallInst(V)
gutils = GradientUtils(gutils)
mode = get_mode(gutils)
Expand All @@ -681,7 +682,7 @@ function shadow_alloc_rewrite(V::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradie
end
end

if mode == API.DEM_ForwardMode
if mode == API.DEM_ForwardMode && (used || idx != 0)
# Zero any jlvalue_t inner elements of preceeding allocation.
# Specifically in forward mode, you will first run the original allocation,
# then all shadow allocations. These allocations will thus all run before
Expand Down Expand Up @@ -1224,7 +1225,7 @@ function __init__()
@cfunction(
shadow_alloc_rewrite,
Cvoid,
(LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, LLVM.API.LLVMValueRef, UInt64, LLVM.API.LLVMValueRef)
(LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, LLVM.API.LLVMValueRef, UInt64, LLVM.API.LLVMValueRef, UInt8)
)
)
register_alloc_rules()
Expand Down

0 comments on commit 02d4283

Please sign in to comment.