Skip to content

Commit

Permalink
Speed up onehot of arrays (#1953)
Browse files Browse the repository at this point in the history
* Speed up onehot of arrays

* faster tupstack

* fix

* fix

* fix assert error

* Better and GC safe version

* gc push

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix
  • Loading branch information
wsmoses authored Oct 15, 2024
1 parent 2a1f213 commit 4605716
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 8 deletions.
6 changes: 3 additions & 3 deletions ext/EnzymeStaticArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ using Enzyme
end
# Converting a `TupleArray` to the abstract `StaticArray` type forwards to the
# concrete `SArray` conversion.
@inline function Base.convert(::Type{StaticArray}, tpa::Enzyme.TupleArray)
    return convert(SArray, tpa)
end

@inline function Enzyme.tupstack(rows::(NTuple{N, <:StaticArrays.SArray} where N), inshape, outshape)
reshape(reduce(hcat, map(vec, rows)), Size(inshape..., outshape...))
# Stack a homogeneous tuple of `SArray`s: every entry is flattened with `vec`,
# the flattened entries are concatenated with `hcat`, and the result is
# reshaped to the static size `Size(outshape..., inshape...)`.
@inline function Enzyme.tupstack(
    rows::Tuple{Vararg{T}},
    outshape::Tuple{Vararg{Int}},
    inshape::Tuple{Vararg{Int}},
) where {T<:StaticArrays.SArray}
    flattened = map(vec, rows)
    joined = reduce(hcat, flattened)
    return reshape(joined, Size(outshape..., inshape...))
end

@inline function Enzyme.onehot(x::StaticArrays.SArray{S, T, N, L}) where {S, T, N, L}
Expand All @@ -19,7 +19,7 @@ end
end
end

@inline function Enzyme.onehot(x::StaticArrays.SArray{S, T, N, L}, start, endl) where {S, T, N, L}
@inline function Enzyme.onehot(x::StaticArrays.SArray{S, T, N, L}, start::Int, endl::Int) where {S, T, N, L}
ntuple(Val(endl-start+1)) do i
Base.@_inline_meta
StaticArrays.SArray{S, T, N, L}(
Expand Down
31 changes: 27 additions & 4 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1518,7 +1518,21 @@ end
nothing
end

@inline function onehot(x)
# Return `zero(x)` with the entry at linear index `i` set to 1.
# The store is `@inbounds`: `i` is not bounds-checked here.
function zerosetfn(x, i::Int)
    basis = zero(x)
    @inbounds setindex!(basis, 1, i)
    return basis
end

# One-hot basis for every linear index of a plain `Array`: delegates to
# `Compiler.onehot_internal` with `zerosetfn`, a zero start offset and
# `length(x)` requested entries.
@inline function onehot(x::Array)
    nbasis = length(x)
    return Compiler.onehot_internal(zerosetfn, x, 0, nbasis)
end

# One-hot basis for the linear indices `start:endl` of a plain `Array`.
#
# Arguments:
# - `x`: array whose shape and element type the basis arrays follow.
# - `start`, `endl`: first and last linear index (inclusive) to build a basis
#   array for. `endl == start - 1` requests an empty result.
#
# Throws:
# - `ArgumentError` if the range has negative length.
# - `BoundsError` if a non-empty range is not contained in `1:length(x)`.
#
# The range is validated up front because `zerosetfn` stores with `@inbounds`
# and `Compiler.onehot_internal` sizes its result from the requested length,
# so an invalid request must not be forwarded.
@inline function onehot(x::Array, start::Int, endl::Int)
    nbasis = endl - start + 1
    nbasis >= 0 || throw(ArgumentError("onehot: `endl` must be at least `start - 1`"))
    if nbasis > 0 && (start < 1 || endl > length(x))
        throw(BoundsError(x, start:endl))
    end
    return Compiler.onehot_internal(zerosetfn, x, start - 1, nbasis)
end

@inline function onehot(x::AbstractArray)
N = length(x)
ntuple(Val(N)) do i
Base.@_inline_meta
Expand All @@ -1529,7 +1543,7 @@ end
return res
end
end
@inline function onehot(x, start, endl)
@inline function onehot(x::AbstractArray, start::Int, endl::Int)
ntuple(Val(endl - start + 1)) do i
Base.@_inline_meta
res = similar(x)
Expand Down Expand Up @@ -1852,12 +1866,21 @@ function Base.getindex(a::TupleArray, args::Vararg{Int,N}) where {N}
return a.data[start]
end

@inline function tupstack(x, inshape, outshape)
# Stack a tuple of `Array`s sharing element type `T` into one freshly
# allocated `Array{T}` of size `(outshape..., inshape...)`.
#
# Entry `i` of `data` is copied to the linear positions starting at
# `prod(outshape) * (i - 1) + 1` of the result, `prod(outshape)` elements each.
#
# The copies go through `Base.unsafe_copyto!`, so nothing here verifies that
# every entry holds at least `prod(outshape)` elements or that
# `length(data) == prod(inshape)`.
# NOTE(review): callers are presumably expected to guarantee both — confirm.
#
# The tuple is typed as `Tuple{Vararg{Array{T}}}`: tuples are covariant, so
# this accepts the same arguments as the deprecated `Vararg{<:Array{T}}` form.
@inline function tupstack(
    data::Tuple{Vararg{Array{T}}},
    outshape::Tuple{Vararg{Int}},
    inshape::Tuple{Vararg{Int}},
) where {T}
    num = prod(outshape)
    # Loop-invariant element count, reinterpreted as `UInt` for the copy call.
    ncopy = Base.reinterpret(UInt, num)
    res = Array{T}(undef, outshape..., inshape...)
    for (i, val) in enumerate(data)
        Base.unsafe_copyto!(res, num * (i - 1) + 1, val, 1, ncopy)
    end
    return res
end

# Generic fallback: stack the entries of `x` with `Base.stack`, then reshape
# the result to `(outshape..., inshape...)` unless `outshape` is
# one-dimensional, in which case the stacked array is returned as is.
#
# NOTE(review): when `outshape` is 1-D the stacked array is returned without
# looking at `inshape`, whereas the `Array` method of `tupstack` always
# allocates `(outshape..., inshape...)` — confirm this difference is intended
# for multi-dimensional `inshape`.
@inline function tupstack(x, outshape::Tuple{Vararg{Int}}, inshape::Tuple{Vararg{Int}})
    st = Base.stack(x)
    if length(outshape) == 1
        return st
    end
    return reshape(st, (outshape..., inshape...))
end

Expand Down
119 changes: 119 additions & 0 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9362,4 +9362,123 @@ end

include("compiler/reflection.jl")

# Generated function that builds an LLVM routine `f(x, startv, lengthv)` and
# calls it through `Base.llvmcall`.
#
# The emitted routine allocates an object tagged `NTuple{lengthv, T}`, zeroes
# it, and then stores into slot `k` (0-based) the result of calling the
# compiled `fn` as `fn(x, startv + k + 1)`, for `k = 0, ..., lengthv - 1`.
# With `lengthv == 0` the loop is skipped and the zeroed object is returned.
#
# Arguments (at generation time `fn`, `x` are the *types* `F`, `T`):
# - `fn`: callable compiled for the signature `Tuple{T, Int}`.
# - `x`: the input array, forwarded unchanged to every `fn` call.
# - `startv`: offset added to the 1-based loop counter to form the index.
# - `lengthv`: number of entries to produce.
#
# Neither `startv` nor `lengthv` is validated here.
@generated function onehot_internal(fn::F, x::T, startv::Int, lengthv::Int) where {F, T<:Array}
    ir = JuliaContext() do ctx
        Base.@_inline_meta

        # Compile `fn` for `(T, Int)` to unoptimized LLVM IR.
        target = Compiler.DefaultCompilerTarget()
        params = Compiler.PrimalCompilerParams(API.DEM_ForwardMode)
        mi = GPUCompiler.methodinstance(fn, Tuple{T, Int})
        job = CompilerJob(mi, CompilerConfig(target, params; kernel = false))
        mod, meta = GPUCompiler.codegen(
            :llvm,
            job;
            optimize = false,
            cleanup = false,
            validate = false,
        )
        copysetfn = meta.entry

        # Strip fences and `julia.safepoint` calls from the first basic block
        # of the compiled function. A distinct local name is used for the
        # callee so the captured generated-function argument `fn` is not
        # reassigned inside this closure.
        blk = first(blocks(copysetfn))
        for inst in collect(instructions(blk))
            if isa(inst, LLVM.FenceInst)
                eraseInst(blk, inst)
            elseif isa(inst, LLVM.CallInst)
                callee = LLVM.called_operand(inst)
                if isa(callee, LLVM.Function) && LLVM.name(callee) == "julia.safepoint"
                    eraseInst(blk, inst)
                end
            end
        end

        # The loop below stores the call result, so the compiled function
        # must be able to return.
        hasNoRet = any(
            attr -> kind(attr) == kind(EnumAttribute("noreturn")),
            collect(function_attributes(copysetfn)),
        )
        @assert !hasNoRet
        push!(function_attributes(copysetfn), EnumAttribute("alwaysinline", 0))

        ity = convert(LLVMType, Int)
        jlvaluet = convert(LLVMType, T; allow_boxed=true)

        # Wrapper `f(x, start, len)` returning the boxed tuple.
        FT = LLVM.FunctionType(jlvaluet, LLVMType[jlvaluet, ity, ity])
        llvm_f = LLVM.Function(mod, "f", FT)
        push!(function_attributes(llvm_f), EnumAttribute("alwaysinline", 0))

        # Check if Julia version has https://github.com/JuliaLang/julia/pull/46914
        # and also https://github.com/JuliaLang/julia/pull/47076
        # and also https://github.com/JuliaLang/julia/pull/48620
        needs_dynamic_size_workaround = !(VERSION >= v"1.10.5")

        builder = LLVM.IRBuilder()
        entry = BasicBlock(llvm_f, "entry")
        position!(builder, entry)
        inp, lstart, len = collect(LLVM.Value, parameters(llvm_f))

        # Box the runtime length so it can be used as the `NTuple` parameter.
        boxed_count = if sizeof(Int) == sizeof(Int64)
            emit_box_int64!(builder, len)
        else
            emit_box_int32!(builder, len)
        end

        tag = emit_apply_type!(builder, NTuple, (boxed_count, unsafe_to_llvm(builder, T)))

        # One `sizeof(Int)`-sized slot per entry.
        fullsize = nuwmul!(builder, len, LLVM.ConstantInt(sizeof(Int)))
        obj = emit_allocobj!(builder, tag, fullsize, needs_dynamic_size_workaround)

        # Zero every slot before any `fn` call runs.
        # NOTE(review): presumably so the GC never sees uninitialized slots — confirm.
        T_int8 = LLVM.Int8Type()
        LLVM.memset!(builder, obj, LLVM.ConstantInt(T_int8, 0), fullsize, 0)

        alloc = pointercast!(builder, obj, LLVM.PointerType(jlvaluet, Tracked))
        alloc = pointercast!(builder, alloc, LLVM.PointerType(jlvaluet, 11))

        loop = BasicBlock(llvm_f, "loop")
        exit = BasicBlock(llvm_f, "exit")

        # Skip the loop entirely for a zero-length request.
        br!(builder, icmp!(builder, LLVM.API.LLVMIntEQ, LLVM.ConstantInt(0), len), exit, loop)

        position!(builder, loop)
        idx = phi!(builder, ity)

        # `idx` is the 0-based slot, `inc = idx + 1`, and `rval = inc + lstart`
        # is the index handed to `fn`.
        push!(LLVM.incoming(idx), (LLVM.ConstantInt(0), entry))
        inc = add!(builder, idx, LLVM.ConstantInt(1))
        push!(LLVM.incoming(idx), (inc, loop))
        rval = add!(builder, inc, lstart)
        res = call!(builder, LLVM.function_type(copysetfn), copysetfn, [inp, rval])
        gidx = gep!(builder, jlvaluet, alloc, [idx])
        store!(builder, res, gidx)
        emit_writebarrier!(builder, get_julia_inner_types(builder, obj, res))

        br!(builder, icmp!(builder, LLVM.API.LLVMIntEQ, inc, len), exit, loop)

        reinsert_gcmarker!(llvm_f)

        position!(builder, exit)
        ret!(builder, obj)
        # The builder is no longer needed once the last instruction is emitted.
        LLVM.dispose(builder)

        string(mod)
    end
    return quote
        Base.@_inline_meta
        Base.llvmcall(
            ($ir, "f"),
            Tuple{Vararg{T}},
            Tuple{T, Int, Int},
            x,
            startv,
            lengthv
        )
    end
end


end
2 changes: 1 addition & 1 deletion test/sugar.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Enzyme, Test

using LinearAlgebra

# Scalar test function: sum of the cross terms `x[1]*y[2]` and `x[2]*y[1]`.
function mul_scalar(x, y)
    return x[1] * y[2] + x[2] * y[1]
end
# Vector test function: the two cross terms `x[1]*y[2]` and `x[2]*y[1]`
# returned as a two-element vector.
function mul_vector(x, y)
    first_term = x[1] * y[2]
    second_term = x[2] * y[1]
    return [first_term, second_term]
end
Expand Down

0 comments on commit 4605716

Please sign in to comment.