Skip to content

Commit

Permalink
Generic memory slice (#2234)
Browse files Browse the repository at this point in the history
* Generic memory slice

* fix
  • Loading branch information
wsmoses authored Dec 28, 2024
1 parent f46e44d commit df85909
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 1 deletion.
9 changes: 8 additions & 1 deletion src/llvm/attributes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ const nofreefns = Set{String}((
"jl_array_ptr_copy",
"ijl_array_copy",
"jl_array_copy",
"ijl_genericmemory_slice",
"jl_genericmemory_slice",
"ijl_genericmemory_copy_slice",
"jl_genericmemory_copy_slice",
"ijl_get_nth_field_checked",
Expand Down Expand Up @@ -644,6 +646,8 @@ function annotate!(mod::LLVM.Module)
"ijl_alloc_array_3d",
"jl_array_copy",
"ijl_array_copy",
"jl_genericmemory_slice",
"ijl_genericmemory_slice",
"jl_genericmemory_copy_slice",
"ijl_genericmemory_copy_slice",
"jl_alloc_genericmemory",
Expand All @@ -670,8 +674,11 @@ function annotate!(mod::LLVM.Module)
LLVM.EnumAttribute("inaccessiblememonly")
else
if fname in (
"jl_genericmemory_slice",
"ijl_genericmemory_slice",
"jl_genericmemory_copy_slice",
"ijl_genericmemory_copy_slice",)
"ijl_genericmemory_copy_slice",
)
EnumAttribute(
"memory",
MemoryEffect(
Expand Down
88 changes: 88 additions & 0 deletions src/rules/llvmrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -856,6 +856,88 @@ end
return nothing
end


@register_fwd function genericmemory_slice_fwd(B, orig, gutils, normalR, shadowR)
ctx = LLVM.context(orig)

if is_constant_value(gutils, orig) || unsafe_load(shadowR) == C_NULL
return true
end

origops = LLVM.operands(orig)

width = get_width(gutils)

shadowin = invert_pointer(gutils, origops[1], B)
shadowdata = invert_pointer(gutils, origops[2], B)
len = new_from_original(gutils, origops[3])

i8 = LLVM.IntType(8)
algn = 0

shadowres =
UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))))
for idx = 1:width
ev = if width == 1
shadowin
else
extract_value!(B, shadowin, idx - 1)
end
ev2 = if width == 1
shadowdata
else
extract_value!(B, shadowdata, idx - 1)
end
callv = call_samefunc_with_inverted_bundles!(
B,
gutils,
orig,
[ev, ev2, len],
[API.VT_Shadow, API.VT_Shadow, API.VT_Primal],
false,
) #=lookup=#
if is_constant_value(gutils, origops[1])
emit_error(B, orig, "ijl_genericmemory_slice memory argument (1st arg) was constant but return was active")
end
if is_constant_value(gutils, origops[2])
emit_error(B, orig, "ijl_genericmemory_slice ptr argument (2nd arg) was constant but return was active")
end
if get_runtime_activity(gutils)
prev = new_from_original(gutils, orig)
callv = LLVM.select!(
B,
LLVM.icmp!(
B,
LLVM.API.LLVMIntNE,
ev,
new_from_original(gutils, origops[1]),
),
callv,
prev,
)
if idx == 1
API.moveBefore(prev, callv, B)
end
end
shadowres = if width == 1
callv
else
insert_value!(B, shadowres, callv, idx - 1)
end
end

unsafe_store!(shadowR, shadowres.ref)
return false
end

@register_aug function genericmemory_slice_augfwd(B, orig, gutils, normalR, shadowR, tapeR)
return genericmemory_slice_fwd(B, orig, gutils, normalR, shadowR)
end

@register_rev function genericmemory_slice_rev(B, orig, gutils, tape)
return nothing
end

@register_fwd function arrayreshape_fwd(B, orig, gutils, normalR, shadowR)
if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig)
return true
Expand Down Expand Up @@ -2111,6 +2193,12 @@ end
@revfunc(genericmemory_copy_slice_rev),
@fwdfunc(genericmemory_copy_slice_fwd),
)
register_handler!(
("jl_genericmemory_slice", "ijl_genericmemory_slice"),
@augfunc(genericmemory_slice_augfwd),
@revfunc(genericmemory_slice_rev),
@fwdfunc(genericmemory_slice_fwd),
)
register_handler!(
("jl_reshape_array", "ijl_reshape_array"),
@augfunc(arrayreshape_augfwd),
Expand Down

0 comments on commit df85909

Please sign in to comment.