flash-attention-like gpu kernel #23

Draft · wants to merge 2 commits into master

15 changes: 8 additions & 7 deletions Project.toml
@@ -5,27 +5,28 @@ version = "0.2.12"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SIMD = "fdea26ae-647d-5447-a871-4b548cad5224"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"

[compat]
-Adapt = "3.3"
-CUDA = "3, 4"
+Adapt = "3.3, 4"
+CUDA = "3, 4, 5"
ChainRulesCore = "1.3"
-GPUArrays = "8"
+GPUArrays = "8, 9, 10"
GPUArraysCore = "0.1"
-NNlib = "0.7, 0.8"
-NNlibCUDA = "0.2"
+NNlib = "0.7, 0.8, 0.9"
Requires = "1.1"
Static = "0.7, 0.8"
-julia = "1.6"
+julia = "1.9"

[extras]
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
3 changes: 2 additions & 1 deletion src/NeuralAttentionlib.jl
@@ -9,7 +9,6 @@ import GPUArraysCore
using ChainRulesCore

using NNlib
using NNlibCUDA

using Requires

@@ -66,4 +65,6 @@ using .Masks
using .Matmul
using .Functional

include("./flash/Flash.jl")

end # module
11 changes: 11 additions & 0 deletions src/flash/Flash.jl
@@ -0,0 +1,11 @@
module Flash

include("utils.jl")
include("mma.jl")
include("forward_utils.jl")
include("forward.jl")
include("backward_utils.jl")
include("backward.jl")
include("launch.jl")

end
242 changes: 242 additions & 0 deletions src/flash/backward.jl
@@ -0,0 +1,242 @@
using CUDA
using CUDA: i32
using KernelAbstractions.Extras: @unroll

function create_backward_dynamic_shm(config)
(; Br, Bc) = config
(; Wk, Wn) = config
(; Dk, Dv) = config
(; computeT, reduceT) = config
D = max(Dk, Dv)
offset = 0i32
Li = @inbounds CuDynamicSharedArray(reduceT, (1, Br), offset)
offset += sizeof(Li) % Int32
Di = @inbounds CuDynamicSharedArray(reduceT, (1, Br), offset)
offset += sizeof(Di) % Int32
P = @inbounds CuDynamicSharedArray(computeT, (Bc, Br), offset)
offset += sizeof(P) % Int32
dS = @inbounds CuDynamicSharedArray(computeT, (Bc, Br), offset)
offset += sizeof(dS) % Int32
Qi = @inbounds CuDynamicSharedArray(computeT, (Dk, Br), offset)
offset += sizeof(Qi) % Int32
dQi = @inbounds CuDynamicSharedArray(reduceT, (Dk, Br), offset)
offset += sizeof(dQi) % Int32
Kj = @inbounds CuDynamicSharedArray(computeT, (Dk, Bc), offset)
offset += sizeof(Kj) % Int32
dKj = @inbounds CuDynamicSharedArray(reduceT, (Dk, Bc), offset)
offset += sizeof(dKj) % Int32
Vj = @inbounds CuDynamicSharedArray(computeT, (Dv, Bc), offset)
offset += sizeof(Vj) % Int32
dVj = @inbounds CuDynamicSharedArray(reduceT, (Dv, Bc), offset)
offset += sizeof(dVj) % Int32
dOi = @inbounds CuDynamicSharedArray(computeT, (Dv, Br), offset)
offset += sizeof(dOi) % Int32
return (; P, dS, Kj, dKj, Vj, dVj, dOi, Qi, dQi, Li, Di)
end
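
# [Reviewer sketch, not part of this PR.] The kernel has to be launched with enough
# dynamic shared memory to cover every CuDynamicSharedArray carved out above. The
# PR's actual size computation lives in src/flash/launch.jl (not shown in this diff);
# a host-side equivalent, assuming the same config fields, could look like the
# following, with the result passed as the shmem keyword of the @cuda launch.
function backward_shmem_bytes(config)
    (; Br, Bc, Dk, Dv, computeT, reduceT) = config
    szc, szr = sizeof(computeT), sizeof(reduceT)
    return 2 * Br * szr +            # Li, Di   (1 × Br, reduceT)
           2 * Bc * Br * szc +       # P, dS    (Bc × Br, computeT)
           Dk * Br * (szc + szr) +   # Qi, dQi  (Dk × Br)
           Dk * Bc * (szc + szr) +   # Kj, dKj  (Dk × Bc)
           Dv * Bc * (szc + szr) +   # Vj, dVj  (Dv × Bc)
           Dv * Br * szc             # dOi      (Dv × Br, computeT)
end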

@inline function compute_P_and_dS(config, grps, shms, sizes)
(; Br, Bc, Wm, Wn, Wk, Dk, Dv) = config
(; computeT, reduceT) = config
(; ss, minval) = config
(; Qi, Kj, Vj, Li, dOi, P, dS, Di) = shms
(; W, warp, lane) = grps
(; sfull, klen, qlen, tr, tc, tdk, tdv) = sizes
rstop = min(Br, qlen)
cstop = min(Bc, klen)
np = tr * tc
@unroll for ti_tj = warp:W:np
ti, tj = fldmod1(ti_tj, tc)
sj = (tj - 1i32) * Wn + 1i32
si = (ti - 1i32) * Wm + 1i32
(si > rstop || sj > cstop) && continue
d = warp_load_1xm(config, lane, Di, si)
dp = warp_fill_c(config, zero(reduceT))
@unroll for tk = 1i32:tdv
sk = (tk - 1i32) * Wk + 1i32
sk > Dv && break
doi = warp_load_kxm(config, lane, dOi, sk, si)
v = warp_load_kxn(config, lane, Vj, sk, sj)
dp = warp_mma(config, lane, v, doi, dp) # dp += v' * doi
end
dpd = rbroadcast(-, dp, d)
l = warp_load_1xm(config, lane, Li, si)
s = warp_fill_c(config, zero(reduceT))
@unroll for tk = 1i32:tdk
sk = (tk - 1i32) * Wk + 1i32
sk > Dk && break
q = warp_load_kxm(config, lane, Qi, sk, si)
k = warp_load_kxn(config, lane, Kj, sk, sj)
s = warp_mma(config, lane, k, q, s) # s += k' * q
end
s = s .* ss
p = rbroadcast(-, s, l)
if !sfull
pmask = warp_gen_pmask(config, lane, klen, qlen, sj, si)
p = _elseif.(p, pmask, minval) # @. ifelse(pmask, p, minval)
end
p = exp.(p)
warp_shm_write_nxm!(config, lane, P, fragment_type(computeT, p), sj, si)
ds = p .* dpd
ds = ds .* ss
warp_shm_write_nxm!(config, lane, dS, fragment_type(computeT, ds), sj, si)
end
return nothing
end

@inline function compute_dQ_KdS(config, grps, shms, sizes)
(; Br, Bc, Wm, Wn, Wk, Dk) = config
(; computeT, reduceT) = config
(; ss) = config
(; dQi, Kj, dS) = shms
(; W, warp, lane) = grps
(; klen, qlen) = sizes
tr = cld(Br, Wk) % Int32
tc = cld(Bc, Wm) % Int32
tdk = cld(Dk, Wn) % Int32
rstop = min(Br, qlen)
cstop = min(Bc, klen)
np = tr * tdk
@unroll for ti_tk = warp:W:np
ti, tk = fldmod1(ti_tk, tdk)
si = (ti - 1i32) * Wm + 1i32
sk = (tk - 1i32) * Wn + 1i32
(si > rstop || sk > Dk) && continue
dq = warp_fill_c(config, zero(reduceT))
@unroll for tj = 1i32:tc
sj = (tj - 1i32) * Wk + 1i32
sj > cstop && break
ds = warp_load_kxm(config, lane, dS, sj, si)
k = warp_load_nxk(config, lane, Kj, sk, sj)
dq = warp_mma(config, lane, k, ds, dq) # dQi += K * dSi
end
warp_shm_write_nxm!(config, lane, dQi, dq, sk, si)
end
return nothing
end

@inline function compute_dV_dOPᵀ(config, grps, shms, sizes)
(; Br, Bc, Wm, Wn, Wk, Dv) = config
(; computeT, reduceT) = config
(; P, dVj, dOi) = shms
(; W, warp, lane) = grps
(; klen, qlen) = sizes
(; Tc) = sizes
tc = cld(Bc, Wm) % Int32
tdv = cld(Dv, Wn) % Int32
tr = cld(Br, Wk) % Int32
rstop = min(Br, qlen)
cstop = min(Bc, klen)
np = tc * tdv
@unroll for tj_tk = warp:W:np
tj, tk = fldmod1(tj_tk, tdv)
sj = (tj - 1i32) * Wm + 1i32
sk = (tk - 1i32) * Wn + 1i32
(sj > cstop || sk > Dv) && continue
dv = warp_load_nxm(config, lane, dVj, sk, sj)
@unroll for ti = 1i32:tr
si = (ti - 1i32) * Wk + 1i32
si > rstop && break
p = warp_load_mxk(config, lane, P, sj, si)
doi = warp_load_nxk(config, lane, dOi, sk, si)
dv = warp_mma(config, lane, doi, p, dv) # dvj += doi * p'
end
warp_shm_write_nxm!(config, lane, dVj, dv, sk, sj)
end
return nothing
end

@inline function compute_dK_QdSᵀ(config, grps, shms, sizes)
(; Br, Bc, Wm, Wn, Wk, Dk) = config
(; computeT, reduceT) = config
(; ss) = config
(; Qi, dKj, dS) = shms
(; W, warp, lane) = grps
(; klen, qlen) = sizes
tr = cld(Br, Wk) % Int32
tc = cld(Bc, Wm) % Int32
tdk = cld(Dk, Wn) % Int32
rstop = min(Br, qlen)
cstop = min(Bc, klen)
np = tc * tdk
@unroll for tj_tk = warp:W:np
tj, tk = fldmod1(tj_tk, tdk)
sj = (tj - 1i32) * Wm + 1i32
sk = (tk - 1i32) * Wn + 1i32
(sj > cstop || sk > Dk) && continue
dk = warp_load_nxm(config, lane, dKj, sk, sj)
@unroll for ti = 1i32:tr
si = (ti - 1i32) * Wk + 1i32
si > rstop && break
ds = warp_load_mxk(config, lane, dS, sj, si)
q = warp_load_nxk(config, lane, Qi, sk, si)
dk = warp_mma(config, lane, q, ds, dk) # dKj += Q * dSi'
end
warp_shm_write_nxm!(config, lane, dKj, dk, sk, sj)
end
return nothing
end

function flash_attention_backward_kernel!(config, dQ, dK, dV, dO, O, L, Q, K, V)
(; Br, Bc, Dk, Dv) = config
Wm = config.Wm % Int32 # WMMA size
Wn = config.Wn % Int32
Wk = config.Wk % Int32
(; computeT, reduceT) = config
(; minval, ss) = config
threads = blockDim().x
ws = warpsize()
W = fld(threads, ws)
index = threadIdx().x
warp, lane = fldmod1(index, ws)
grps = (; W, index, warp, lane)
B = size(O, 3) % Int32
Nq = size(Q, 2) % Int32
Nk = size(K, 2) % Int32
dk = size(Q, 1) % Int32
dv = size(V, 1) % Int32
Tr = cld(Nq, Br) % Int32
Tc = cld(Nk, Bc) % Int32
tr = cld(Br, Wm) % Int32
tc = cld(Bc, Wn) % Int32
tdk = cld(Dk, Wk) % Int32
tdv = cld(Dv, Wn) % Int32
sizes = (; Nq, Nk, Tr, Tc, dk, dv, tr, tc, tdk, tdv)
shms = create_backward_dynamic_shm(config)
stride = gridDim().x
bidx = blockIdx().x
NP = B * Tc
for b_j = bidx:stride:NP
b, j = fldmod1(b_j, Tc)
kfull, krange, klen = chunkrange(Bc, Nk, j)
sizes = merge(sizes, (; kfull, krange, klen))
block_glb2shm!(config, shms.Kj, K, krange, b)
block_glb2shm!(config, shms.Vj, V, krange, b)
block_shm_fill!(shms.dKj, zero(reduceT))
block_shm_fill!(shms.dVj, zero(reduceT))
block_shm_fill!(shms.Di, zero(reduceT))
sync_threads()
for i = 1i32:Tr
qfull, qrange, qlen = chunkrange(Br, Nq, i)
sfull = qfull & kfull
sizes = merge(sizes, (; sfull, qfull, qrange, qlen,))
block_glb2shm!(config, shms.dOi, dO, qrange, b)
block_glb2shm!(config, shms.Li, L, qrange, b)
block_glb2shm_rowreduce_atomic!(config, *, shms.Di, shms.dOi, O, qrange, b)
block_glb2shm!(config, shms.Qi, Q, qrange, b)
sync_threads()
compute_P_and_dS(config, grps, shms, sizes)
sync_threads()
block_shm_fill!(shms.Di, zero(reduceT))
compute_dV_dOPᵀ(config, grps, shms, sizes)
compute_dK_QdSᵀ(config, grps, shms, sizes)
compute_dQ_KdS(config, grps, shms, sizes)
sync_threads()
block_shm2glb_atomic!(config, +, dQ, shms.dQi, qrange, b)
sync_threads()
end
block_shm2glb!(config, dK, shms.dKj, krange, b)
block_shm2glb!(config, dV, shms.dVj, krange, b)
sync_threads()
end
return nothing
end
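
# [Reviewer sketch, not part of this PR.] For orientation: per batch slice, and
# ignoring the length mask handled by warp_gen_pmask, the tiled kernel above produces
# the same gradients as this dense backward pass, written in the same
# (feature, sequence) layout, where ss is the softmax scale and L stores the logsumexp
# of the scaled scores saved by the forward pass. A reference like this is also what
# one would test the kernel against, up to computeT/reduceT precision.
using LinearAlgebra

# Q :: (dk, Nq), K :: (dk, Nk), V :: (dv, Nk), O/dO :: (dv, Nq), L :: (1, Nq)
function dense_attention_backward(dO, O, L, Q, K, V, ss)
    S  = ss .* (K' * Q)            # (Nk, Nq) scaled scores, cf. compute_P_and_dS
    P  = exp.(S .- L)              # softmax probabilities recovered via the logsumexp
    D  = sum(dO .* O; dims = 1)    # (1, Nq), the term accumulated into shms.Di
    dP = V' * dO                   # (Nk, Nq)
    dS = ss .* P .* (dP .- D)      # the shared-memory dS tile, with the scale folded in
    dV = dO * P'                   # cf. compute_dV_dOPᵀ
    dK = Q * dS'                   # cf. compute_dK_QdSᵀ
    dQ = K * dS                    # cf. compute_dQ_KdS, accumulated atomically into dQ
    return dQ, dK, dV
end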
65 changes: 65 additions & 0 deletions src/flash/backward_utils.jl
@@ -0,0 +1,65 @@
@inline function block_glb2shm_reduce!(config, op, shm, dOi, glb, range, b)
D = size(shm, 1) % Int32
d = size(glb, 1) % Int32
N = length(range) % Int32
cbase = first(range) - 1i32
works = length(shm) % Int32
workload = cld(works, blockDim().x) % Int32
top = threadIdx().x * workload
base = top - workload + 1i32
for i = base:min(top, works)
c, r = fldmod1(i, D)
@inbounds if r > d || c > N
shm[r, c] = zero(eltype(shm))
else
shm[r, c] = convert(eltype(shm), op(dOi[r, c], glb[r, c + cbase, b]))
end
end
return nothing
end

@inline function block_glb2shm_rowreduce_atomic!(config, op, shm, dOi, glb, range, b)
D = size(dOi, 1) % Int32
d = size(glb, 1) % Int32
N = length(range) % Int32
cbase = first(range) - 1i32
works = length(dOi) % Int32
workload = cld(works, blockDim().x) % Int32
top = threadIdx().x * workload
base = top - workload + 1i32
stop = min(top, works)
acc = zero(eltype(shm))
for i = base:stop
c, r = fldmod1(i, D)
@inbounds if r <= d && c <= N
v = oftype(acc, op(dOi[r, c], glb[r, c + cbase, b]))
else
v = zero(acc)
end
acc += v
if r == D || i == stop
@inbounds CUDA.atomic_arrayset(shm, c, +, acc)
acc = zero(eltype(shm))
end
end
return nothing
end
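
# [Reviewer note, not part of this PR; my reading of the call site in backward.jl.]
# With op = * and shm = Di as used by the kernel, this routine is a block-parallel,
# atomically accumulated equivalent of
#
#     for c in 1:N, r in 1:d
#         Di[1, c] += convert(eltype(Di), dOi[r, c] * glb[r, c + cbase, b])
#     end
#
# i.e. the per-query D = rowsum(dO ∘ O) correction used when forming dS. Each thread
# walks a contiguous column-major chunk of dOi and flushes its partial sum with an
# atomic add whenever it crosses a column boundary (r == D) or reaches the end of its
# chunk.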

function block_shm2glb_atomic!(config, op, glb, shm, range, b = 1)
# size(shm) = (D, Bx)
# size(glb) = (d, N, B)
# assume D >= d, Bx >= N
D = size(shm, 1) % Int32
d = size(glb, 1) % Int32
N = length(range) % Int32
cbase = first(range) - 1i32
works = (d * N) % Int32
workload = cld(works, blockDim().x) % Int32
top = threadIdx().x * workload
base = top - workload + 1i32
for i = base:min(top, works)
c, r = fldmod1(i, d)
@inbounds CUDA.atomic_arrayset(glb, (r, c + cbase, b), op, shm[r, c])
end
return nothing
end
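
# [Reviewer note, not part of this PR.] This atomic path is used only for dQ: every
# key block j, possibly running in a different CUDA block, contributes to the same
# query rows, so flash_attention_backward_kernel! accumulates the tile with + into
# global memory, which presumably means dQ is zero-initialized before the launch
# (see launch.jl):
#
#     block_shm2glb_atomic!(config, +, dQ, shms.dQi, qrange, b)   # dQ[:, qrange, b] .+= dQi
#
# dK and dV are each owned by a single block and are written back with the plain
# block_shm2glb! instead.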