From 8b78f5fab3d1615b7595bc8770941435e4488975 Mon Sep 17 00:00:00 2001
From: chengchingwen
Date: Sun, 21 Jan 2024 16:27:45 +0800
Subject: [PATCH 1/2] init forward impl

refine forward
---
 Project.toml               |  15 ++-
 src/NeuralAttentionlib.jl  |   3 +-
 src/flash/Flash.jl         |   9 ++
 src/flash/forward.jl       | 219 ++++++++++++++++++++++++++++++++
 src/flash/forward_utils.jl | 253 +++++++++++++++++++++++++++++++++++++
 src/flash/launch.jl        | 251 ++++++++++++++++++++++++++++++++++++
 src/flash/mma.jl           | 168 ++++++++++++++++++++++++
 src/flash/utils.jl         | 141 +++++++++++++++++++++
 8 files changed, 1051 insertions(+), 8 deletions(-)
 create mode 100644 src/flash/Flash.jl
 create mode 100644 src/flash/forward.jl
 create mode 100644 src/flash/forward_utils.jl
 create mode 100644 src/flash/launch.jl
 create mode 100644 src/flash/mma.jl
 create mode 100644 src/flash/utils.jl

diff --git a/Project.toml b/Project.toml
index 93c5d95..bc0d418 100644
--- a/Project.toml
+++ b/Project.toml
@@ -5,27 +5,28 @@ version = "0.2.12"
 
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
+BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
 GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
+KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
-NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
+SIMD = "fdea26ae-647d-5447-a871-4b548cad5224"
 Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
 
 [compat]
-Adapt = "3.3"
-CUDA = "3, 4"
+Adapt = "3.3, 4"
+CUDA = "3, 4, 5"
 ChainRulesCore = "1.3"
-GPUArrays = "8"
+GPUArrays = "8, 9, 10"
 GPUArraysCore = "0.1"
-NNlib = "0.7, 0.8"
-NNlibCUDA = "0.2"
+NNlib = "0.7, 0.8, 0.9"
 Requires = "1.1"
 Static = "0.7, 0.8"
-julia = "1.6"
+julia = "1.9"
 
 [extras]
 ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
diff --git a/src/NeuralAttentionlib.jl b/src/NeuralAttentionlib.jl
index 342ee4c..3b7eab6 100644
--- a/src/NeuralAttentionlib.jl
+++ b/src/NeuralAttentionlib.jl
@@ -9,7 +9,6 @@ import GPUArraysCore
 using ChainRulesCore
 
 using NNlib
-using NNlibCUDA
 
 using Requires
 
@@ -66,4 +65,6 @@ using .Masks
 using .Matmul
 using .Functional
 
+include("./flash/Flash.jl")
+
 end # module
diff --git a/src/flash/Flash.jl b/src/flash/Flash.jl
new file mode 100644
index 0000000..eb487cb
--- /dev/null
+++ b/src/flash/Flash.jl
@@ -0,0 +1,9 @@
+module Flash
+
+include("utils.jl")
+include("mma.jl")
+include("forward_utils.jl")
+include("forward.jl")
+include("launch.jl")
+
+end
diff --git a/src/flash/forward.jl b/src/flash/forward.jl
new file mode 100644
index 0000000..47c92b0
--- /dev/null
+++ b/src/flash/forward.jl
@@ -0,0 +1,219 @@
+using CUDA
+using CUDA: i32
+using KernelAbstractions.Extras: @unroll
+
+@inline function create_forward_dynamic_shm(config)
+    (; Br, Bc) = config
+    (; Wk, Wn) = config
+    (; Dk, Dv) = config
+    (; computeT, reduceT) = config
+    D = max(Dk, Dv)
+    offset = 0i32
+    S = @inbounds CuDynamicSharedArray(reduceT, (Bc, Br), offset)
+    offset += sizeof(S) % Int32
+    P = @inbounds CuDynamicSharedArray(computeT, (Bc, Br), offset)
+    offset += sizeof(P) % Int32
+    pi = @inbounds CuDynamicSharedArray(reduceT, (1, Br), offset)
+    offset += sizeof(pi) % Int32
+    li = @inbounds CuDynamicSharedArray(reduceT, (1, Br), offset)
+    offset += sizeof(li) % Int32
+    li2 = @inbounds CuDynamicSharedArray(reduceT, (1, Br), offset)
offset += sizeof(li2) % Int32 + mi = @inbounds CuDynamicSharedArray(reduceT, (1, Br), offset) + offset += sizeof(mi) % Int32 + mi2 = @inbounds CuDynamicSharedArray(reduceT, (1, Br), offset) + offset += sizeof(mi2) % Int32 + Qi = @inbounds CuDynamicSharedArray(computeT, (Dk, Br), offset) + offset += sizeof(Qi) % Int32 + KVj = @inbounds CuDynamicSharedArray(computeT, (D, Bc), offset) + offset += sizeof(KVj) % Int32 + Oi = @inbounds CuDynamicSharedArray(reduceT, (Dv, Br), offset) + return (; Qi, KVj, Oi, S, P, pi, li, li2, mi, mi2) +end + +@inline function compute_KᵀQ(config, grps, shms, sizes) + (; Br, Bc, Wm, Wn, Wk, Dk) = config + (; computeT, reduceT) = config + (; ss) = config + (; Qi, KVj, S, mi, mi2) = shms + (; W, warp, lane) = grps + (; klen, qlen, tr, tc, tdk) = sizes + rstop = min(Br, qlen) + cstop = min(Bc, klen) + np = tr * tc + @unroll for ti_tj = warp:W:np + ti, tj = fldmod1(ti_tj, tc) + si = (ti - 1i32) * Wm + 1i32 + sj = (tj - 1i32) * Wn + 1i32 + (si > rstop || sj > cstop) && continue + s = warp_fill_c(config, zero(reduceT)) + @unroll for tk = 1i32:tdk + sk = (tk - 1i32) * Wk + 1i32 + sk > Dk && break + q = warp_load_kxm(config, lane, Qi, sk, si) + k = warp_load_kxn(config, lane, KVj, sk, sj) + s = warp_mma(config, lane, k, q, s) # s += k' * q + end + s = s .* ss + warp_shm_write_nxm!(config, lane, S, s, sj, si) + m = warp_reducerow_nxm(config, max, s) # 8 -> 4 + warp_shm_reduce_1xm_atomic!(config, lane, max, mi2, m, si) + end + return nothing +end + +@inline function compute_exp_S(config, grps, shms, sizes) + (; Br, Bc, Wm, Wn) = config + (; computeT, reduceT) = config + (; minval) = config + (; S, P, pi, mi2) = shms + (; W, warp, lane) = grps + (; sfull, klen, qlen, tr, tc, tdk) = sizes + rstop = min(Br, qlen) + cstop = min(Bc, klen) + np = tr * tc + @unroll for ti_tj = warp:W:np + ti, tj = fldmod1(ti_tj, tc) + si = (ti - 1i32) * Wm + 1i32 + sj = (tj - 1i32) * Wn + 1i32 + (si > rstop || sj > cstop) && continue + m = warp_load_1xm(config, lane, mi2, si) + p = warp_load_nxm(config, lane, S, sj, si) + p = rbroadcast(-, p, m) + if !sfull + pmask = warp_gen_pmask(config, lane, klen, qlen, sj, si) + p = _elseif.(p, pmask, minval) # @. ifelse(pmask, p, minval) + end + p = exp.(p) # @. exp(p - m) + ps = warp_reducerow_nxm(config, +, p) # 8 -> 4 + p0 = fragment_type(computeT, p) + warp_shm_write_nxm!(config, lane, P, p0, sj, si) + warp_shm_reduce_1xm_atomic!(config, lane, +, pi, ps, si) + end + return nothing +end + +@inline function compute_exp_m_O_VP(config, grps, shms, sizes) + (; Br, Bc, Wm, Wn, Wk, Dv) = config + (; computeT, reduceT) = config + (; mi, mi2, pi, li, li2, P, KVj, Oi) = shms + (; W, warp, lane) = grps + (; klen, qlen, tr, tc, tdv) = sizes + (; j) = grps + (; Tc) = sizes + rstop = min(Br, qlen) + cstop = min(Bc, klen) + np = tr * tdv + is_last = j == Tc + @unroll for ti_tk = warp:W:np + ti, tk = fldmod1(ti_tk, tdv) + si = (ti - 1i32) * Wm + 1i32 + sk = (tk - 1i32) * Wn + 1i32 + (si > rstop || sk > Dv) && continue + mp = warp_load_1xm(config, lane, mi, si) + m = warp_load_1xm(config, lane, mi2, si) + ps = warp_load_1xm(config, lane, pi, si) + l = warp_load_1xm(config, lane, li, si) + mdiff = mp .- m + em = exp.(mdiff) #@. exp(mp - m) + l = em .* l + l = l .+ ps # @. em * l + ps + o = warp_load_nxm(config, lane, Oi, sk, si) + o = rbroadcast(*, em, o) # @. 
em * o + if is_last + m0 = CUDA.log.(l) + m = m .+ m0 + warp_shm_write_1xm!(config, lane, li2, m, si) + l = inv.(l) + @unroll for tj = 1i32:tc + sj = (tj - 1i32) * Wk + 1i32 + sj > cstop && break + p = warp_load_kxm(config, lane, P, sj, si) + v = warp_load_nxk(config, lane, KVj, sk, sj) + o = warp_mma(config, lane, v, p, o) # o += v * p + end + o = rbroadcast(*, l, o) # @. l * o + else + warp_shm_write_1xm!(config, lane, li2, l, si) + @unroll for tj = 1i32:tc + sj = (tj - 1i32) * Wk + 1i32 + sj > cstop && break + p = warp_load_kxm(config, lane, P, sj, si) + v = warp_load_nxk(config, lane, KVj, sk, sj) + o = warp_mma(config, lane, v, p, o) # o += v * p + end + end + warp_shm_write_nxm!(config, lane, Oi, o, sk, si) + end + return nothing +end + +function flash_attention_forward_kernel!(config, O, L, Q, K, V) + (; Br, Bc, Dk, Dv) = config # share memory size + Wm = config.Wm % Int32 # WMMA size + Wn = config.Wn % Int32 + Wk = config.Wk % Int32 + (; computeT, reduceT) = config + (; minval) = config + # warp groups + threads = blockDim().x + ws = warpsize() + W = fld(threads, ws) + index = threadIdx().x + warp, lane = fldmod1(index, ws) + grps = (; W, index, warp, lane) + # chunks + B = size(O, 3) % Int32 + Nq = size(Q, 2) % Int32 + Nk = size(K, 2) % Int32 + dk = size(Q, 1) % Int32 + dv = size(V, 1) % Int32 + Tr = cld(Nq, Br) % Int32 + Tc = cld(Nk, Bc) % Int32 + tr = cld(Br, Wm) % Int32 + tc = cld(Bc, Wn) % Int32 + tdk = cld(Dk, Wk) % Int32 + tdv = cld(Dv, Wn) % Int32 + sizes = (; Nq, Nk, Tr, Tc, dk, dv, tr, tc, tdk, tdv) + # allocs shms + shms = create_forward_dynamic_shm(config) + # batch loop + stride = gridDim().x + bidx = blockIdx().x + NP = B * Tr + for b_i = bidx:stride:NP + b, i = fldmod1(b_i, Tr) + qfull, qrange, qlen = chunkrange(Br, Nq, i) + sizes = merge(sizes, (; qfull, qrange, qlen,)) + block_glb2shm!(config, shms.Qi, Q, qrange, b) + block_shm_fill!(shms.mi, minval) + block_shm_fill!(shms.mi2, minval) + block_shm_fill!(shms.li, zero(reduceT)) + block_shm_fill!(shms.li2, zero(reduceT)) + block_shm_fill!(shms.Oi, zero(reduceT)) + for j = 1i32:Tc + grps = merge(grps, (; j)) + kfull, krange, klen = chunkrange(Bc, Nk, j) + sfull = qfull & kfull + sizes = merge(sizes, (; kfull, krange, klen, sfull)) + block_glb2shm!(config, shms.KVj, K, krange, b) + block_shm_fill!(shms.pi, zero(reduceT)) + sync_threads() # Q, K, S + # S = K^T * Q * dk^-1/2 + compute_KᵀQ(config, grps, shms, sizes) + sync_threads() # S, m + # P = exp(S - m) + block_glb2shm!(config, shms.KVj, V, krange, b) + compute_exp_S(config, grps, shms, sizes) + sync_threads() # P, pi, V + # O = exp(mp - m) * O + V * P + # O *= l + compute_exp_m_O_VP(config, grps, shms, sizes) + shms = merge(shms, (; mi = shms.mi2, mi2 = shms.mi, li = shms.li2, li2 = shms.li)) + sync_threads() + end # Tc loop + block_shm2glb!(config, O, shms.Oi, qrange, b) + block_shm2glb!(config, L, shms.li, qrange, b) + end + return nothing +end diff --git a/src/flash/forward_utils.jl b/src/flash/forward_utils.jl new file mode 100644 index 0000000..2ddbdfb --- /dev/null +++ b/src/flash/forward_utils.jl @@ -0,0 +1,253 @@ +using CUDA +using CUDA: i32 +using KernelAbstractions.Extras: @unroll + +function block_shm_fill!(shm, v) + D = size(shm, 1) % Int32 + works = length(shm) + workload = cld(works, blockDim().x) % Int32 + top = threadIdx().x * workload + base = top - workload + 1i32 + for i = base:min(top, works) + c, r = fldmod1(i, D) + @inbounds shm[r, c] = v + end + return nothing +end + +function block_glb2shm!(config, shm, glb, range, b = 1) + # size(shm) = (D, 
Bx) + # size(glb) = (d, N, B) + # assume D >= d, Bx >= N + D = size(shm, 1) % Int32 + d = size(glb, 1) % Int32 + N = length(range) % Int32 + cbase = first(range) - 1i32 + works = length(shm) % Int32 + workload = cld(works, blockDim().x) % Int32 + top = threadIdx().x * workload + base = top - workload + 1i32 + for i = base:min(top, works) + c, r = fldmod1(i, D) + @inbounds if r > d || c > N + shm[r, c] = zero(eltype(shm)) + else + shm[r, c] = convert(eltype(shm), glb[r, c + cbase, b]) + end + end + return nothing +end + +function block_shm2glb!(config, glb, shm, range, b = 1) + # size(shm) = (D, Bx) + # size(glb) = (d, N, B) + # assume D >= d, Bx >= N + D = size(shm, 1) % Int32 + d = size(glb, 1) % Int32 + N = length(range) % Int32 + cbase = first(range) - 1i32 + works = (d * N) % Int32 + workload = cld(works, blockDim().x) % Int32 + top = threadIdx().x * workload + base = top - workload + 1i32 + for i = base:min(top, works) + c, r = fldmod1(i, d) + @inbounds glb[r, c + cbase, b] = shm[r, c] + end + return nothing +end + +function chunkrange(B, N, i) + stop = i * B + start0 = stop - B + start = start0 + 1i32 + if stop > N + stop = N + len = N - start0 + full = false + else + len = B + full = true + end + return (full, start:stop, len) +end + +_elseif(a, p, b) = ifelse(p, a, b) + +@inline function _warp_gen_pmask( + ::Union{Type{<:MMAConfig{16, 16, 16, MT, T}}, Type{<:Config{16, 16, 16, T}}}, lane, + klen, qlen, sj, si) where {MT, T} + r, c = _fast_fldmod(lane - 1i32, Val(4)) + vs = ((sj + r + Vec{8, Int32}((0i32, 0i32, 8i32, 8i32, 0i32, 0i32, 8i32, 8i32))) <= klen) & + ((si + c + Vec{8, Int32}((0i32, 1i32, 0i32, 1i32, 8i32, 9i32, 8i32, 9i32))) <= qlen) + return Fragment{16, 16, 16, 8, Bool, Unspecified, Accumulator}(WMMA.flatten(vs.data)) +end + +@inline function warp_gen_pmask(config, lane, klen, qlen, sj, si) + (; Wm, Wn, Wk) = config + (; computeT, reduceT) = config + return _warp_gen_pmask(MMAConfig{Wn, Wm, Wk, computeT, reduceT}, lane, klen, qlen, sj, si) +end + +@inline function warp_shm_reduce_nxm_atomic!(config, lane, op, mem, frag, r, c) + (; Wm, Wn, Wk) = config + (; computeT, reduceT) = config + @inbounds fragment_reduce_store_d(MMAConfig{Wn, Wm, Wk, computeT, reduceT}, ColMajor, lane, op, mem, frag, r, c) + return nothing +end + +@inline function _warp_reduce_1xm!( + ::Union{Type{<:MMAConfig{16, 16, 16, MT, T}}, Type{<:Config{16, 16, 16, T}}}, lane, + op, mi, m::Fragment{16, 16, 16, 4, T, Unspecified, Accumulator}, si +) where {MT, T} + shflmask = typemax(UInt32) + grp = _fast_mod(lane - 1i32, Val(4)) + vs = m.x + vs = op.(vs, _shfl_xor_sync(shflmask, vs, 4i32)) + vs = op.(vs, _shfl_xor_sync(shflmask, vs, 8i32)) + vs = op.(vs, _shfl_xor_sync(shflmask, vs, 16i32)) + @inbounds if lane <= 4 + # m = [(1, 1) (1, 2) (9, 1) (9, 2) (1, 9) (1, 10) (9, 9) (9, 10)] + idx = si + _fast_mul(grp, Val(2)) + Vec{4, Int32}((0i32, 1i32, 8i32, 9i32)) + Base.Cartesian.@nexprs 4 i -> CUDA.atomic_arrayset(mi, idx[i], op, vs[i]) + end + return nothing +end + +@inline function warp_shm_reduce_1xm_atomic!(config, lane, op, mi, m, si) + (; Wm, Wn, Wk) = config + (; computeT, reduceT) = config + _warp_reduce_1xm!(MMAConfig{Wn, Wm, Wk, computeT, reduceT}, lane, op, mi, m, si) + return nothing +end + +@inline function _warp_reducerow_nxm( + ::Union{Type{<:MMAConfig{16, 16, 16, MT, T}}, Type{<:Config{16, 16, 16, T}}}, + op, + frag::Fragment{16, 16, 16, 8, T, Unspecified, Accumulator}, + acc::Union{Nothing, Fragment{16, 16, 16, 4, T, Unspecified, Accumulator}} = nothing +) where {MT, T} + vs = frag.x + if 
isnothing(acc) + v1 = @inbounds op(vs[1], vs[3]) + v2 = @inbounds op(vs[2], vs[4]) + v3 = @inbounds op(vs[5], vs[7]) + v4 = @inbounds op(vs[6], vs[8]) + else + a = acc.x + v1 = @inbounds op(op(a[1], vs[1]), vs[3]) + v2 = @inbounds op(op(a[2], vs[2]), vs[4]) + v3 = @inbounds op(op(a[3], vs[5]), vs[7]) + v4 = @inbounds op(op(a[4], vs[6]), vs[8]) + end + return Fragment{16, 16, 16, 4, T, Unspecified, Accumulator}((v1, v2, v3, v4)) +end + +@inline function warp_reducerow_nxm(config, op, frag, acc = nothing) + (; Wm, Wn, Wk) = config + (; computeT, reduceT) = config + return _warp_reducerow_nxm(MMAConfig{Wn, Wm, Wk, computeT, reduceT}, op, frag, acc) +end + +@inline function _warp_load_1xm( + ::Union{Type{<:MMAConfig{16, 16, 16, MT, T}}, Type{<:Config{16, 16, 16, T}}}, + lane, mi, si +) where {MT, T} + grp = _fast_mod(lane - 1i32, Val(4)) + if lane <= 4 + indices = si + _fast_mul(grp, Val(2)) + Vec{4, Int32}((0i32, 1i32, 8i32, 9i32)) + vs = @inbounds Base.Cartesian.@ntuple 4 i -> VecElement(mi[indices[i]]) + else + vs = (VecElement(zero(T)), VecElement(zero(T)), VecElement(zero(T)), VecElement(zero(T))) + end + vs = _shfl_sync(typemax(UInt32), vs, grp + 1i32) + return Fragment{16, 16, 16, 4, T, Unspecified, Accumulator}(WMMA.flatten(vs)) +end + +@inline function warp_load_1xm(config, lane, mi, si) + (; Wm, Wn, Wk) = config + (; computeT, reduceT) = config + return _warp_load_1xm(MMAConfig{Wn, Wm, Wk, computeT, reduceT}, lane, mi, si) +end + +@inline function warp_fill_c(config, value) + (; Wm, Wn, Wk) = config + (; reduceT) = config + return WMMA.fill_c(value, Config{Wn, Wm, Wk, reduceT}) +end + +@inline function _warp_fill_reduce_c(::Union{Type{<:MMAConfig{16, 16, 16, MT, T}}, Type{<:Config{16, 16, 16, T}}}, value) where {MT, T} + v = convert(T, value) + return Fragment{16, 16, 16, 4, T, Unspecified, Accumulator}(ntuple(_->v, Val(4))) +end +@inline function warp_fill_reduce_c(config, value) + (; Wm, Wn, Wk) = config + (; computeT, reduceT) = config + return _warp_fill_reduce_c(MMAConfig{Wn, Wm, Wk, computeT, reduceT}, value) +end + +@inline function warp_load_kxm(config, lane, Qi, sk, si) + (; Wm, Wn, Wk) = config + (; computeT, reduceT) = config + return @inbounds fragment_load_b(MMAConfig{Wn, Wm, Wk, computeT, reduceT}, ColMajor, lane, Qi, sk, si) +end +@inline function warp_load_kxn(config, lane, Kj, sk, sj) + (; Wm, Wn, Wk) = config + (; computeT, reduceT) = config + return @inbounds fragment_load_a(MMAConfig{Wn, Wm, Wk, computeT, reduceT}, RowMajor, lane, Kj, sk, sj) +end +@inline function warp_load_nxk(config, lane, Vj, sk, sj) + (; Wm, Wn, Wk) = config + (; computeT, reduceT) = config + return @inbounds fragment_load_a(MMAConfig{Wn, Wm, Wk, computeT, reduceT}, ColMajor, lane, Vj, sk, sj) +end +@inline function warp_load_nxm(config, lane, Oi, sk, si) + (; Wm, Wn, Wk) = config + (; computeT, reduceT) = config + return @inbounds fragment_load_c(MMAConfig{Wn, Wm, Wk, computeT, reduceT}, ColMajor, lane, Oi, sk, si) +end +@inline function warp_mma(config, lane, a, b, c) + (; Wm, Wn, Wk) = config + (; computeT, reduceT) = config + return @inbounds fragment_mma(MMAConfig{Wn, Wm, Wk, computeT, reduceT}, lane, a, b, c) +end +@inline function warp_shm_write_nxm!(config, lane, Oi, o, sj, si) + (; Wm, Wn, Wk) = config + (; computeT, reduceT) = config + return @inbounds fragment_store_d(MMAConfig{Wn, Wm, Wk, computeT, reduceT}, ColMajor, lane, Oi, o, sj, si) +end +@inline function warp_shm_write_1xm!(config, lane, li, l, si) + grp = _fast_mod(lane - 1i32, Val(4)) + if lane <= 4 + indices = si + 
_fast_mul(grp, Val(2)) + Vec{4, Int32}((0i32, 1i32, 8i32, 9i32)) + @unroll for i = Base.OneTo{Int32}(length(indices)) + @inbounds li[indices[i]] = l[i] + end + end + return nothing +end + +@inline function rbroadcast( + f, + frag::Fragment{16, 16, 16, 8, T, Unspecified, Accumulator}, + acc::Fragment{16, 16, 16, 4, T, Unspecified, Accumulator} +) where T + a, b, c, d = tuplesplitp(frag.x, Val(4)) + x, y = tuplesplitp(acc.x, Val(2)) + vs = tuplejoin(f.(a, x), f.(b, x), f.(c, y), f.(d, y)) + return Fragment{16, 16, 16, 8, T, Unspecified, Accumulator}(vs) +end +@inline function rbroadcast( + f, + acc::Fragment{16, 16, 16, 4, T, Unspecified, Accumulator}, + frag::Fragment{16, 16, 16, 8, T, Unspecified, Accumulator}, +) where T + a, b, c, d = tuplesplitp(frag.x, Val(4)) + x, y = tuplesplitp(acc.x, Val(2)) + vs = tuplejoin(f.(x, a), f.(x, b), f.(y, c), f.(y, d)) + return Fragment{16, 16, 16, 8, T, Unspecified, Accumulator}(vs) +end + +@inline function fragment_type(::Type{T}, frag::Fragment{M, N, K, E, T0, L, U}) where {T, M, N, K, E, T0, L, U} + return Fragment{M, N, K, E, T, L, U}(frag.x) +end diff --git a/src/flash/launch.jl b/src/flash/launch.jl new file mode 100644 index 0000000..12f2ad0 --- /dev/null +++ b/src/flash/launch.jl @@ -0,0 +1,251 @@ +const allow_optin_shmem = Ref{Union{Bool, Nothing}}(nothing) +optin_shmem() = allow_optin_shmem[] = true +optout_shmem() = allow_optin_shmem[] = false +const max_shmem_sizes = Dict{CuDevice, NTuple{2, Int32}}() +get_max_shmem_default(dev) = attribute(dev, CUDA.CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK) +get_max_shmem_possible(dev) = attribute(dev, CUDA.CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN) +function get_max_shmem(dev) + max_shmems = get(max_shmem_sizes, dev, nothing) + _optin = allow_optin_shmem[] + optin_set = !isnothing(_optin) + optin = optin_set ? _optin : true + if !isnothing(max_shmems) + shmem_max, shmem_max_possible = max_shmems + if optin_set && !optin + shmem_max_possible = shmem_max + end + return shmem_max, shmem_max_possible + else + support_optin = capability(dev) >= v"7" + shmem_max = get_max_shmem_default(dev) + if optin + shmem_max_possible = if support_optin + try + get_max_shmem_possible(dev) + catch + @error "error occurs when querying the max size of opt-in shared memory, default max size is used." + shmem_max + end + else + optin_set && @warn "The current device does not support opt-in shared memory, default max size is used." 
+ shmem_max + end + else + shmem_max_possible = shmem_max + end + max_shmems = (shmem_max, shmem_max_possible) + max_shmem_sizes[dev] = max_shmems + return max_shmems + end +end + +struct FlashAttenConfig{MMAConfig, NT<:NamedTuple, Options} + fields::NT +end +@inline function Base.getproperty( + config::FlashAttenConfig{MMAConfig{Wm, Wn, Wk, computeT, reduceT}}, + sym::Symbol +) where {Wm, Wn, Wk, computeT, reduceT} + if sym == :Wm + return Wm + elseif sym == :Wn + return Wn + elseif sym == :Wk + return Wk + elseif sym == :computeT + return computeT + elseif sym == :reduceT + return reduceT + else + return getfield(getfield(config, :fields), sym) + end +end +@inline function Base.hasproperty( + config::FlashAttenConfig{MMAConfig{Wm, Wn, Wk, computeT, reduceT}}, + sym::Symbol +) where {Wm, Wn, Wk, computeT, reduceT} + if sym == :Wm || sym == :Wn || sym == :Wk || sym == :computeT || sym == :reduceT + return true + else + return hasproperty(getfield(config, :fields), sym) + end +end + +FlashAttenConfig{M}(kws::NamedTuple) where M = FlashAttenConfig{M, Union{}}(kws) +FlashAttenConfig{M, U}(kws::NamedTuple) where {M, U} = FlashAttenConfig{M, typeof(kws), U}(kws) + +struct FlashAttenKernelConfig{static_config, DC <: FlashAttenConfig} + dynamic_config::DC +end +@inline function FlashAttenKernelConfig(static_config, dynamic_config) + return FlashAttenKernelConfig{static_config, typeof(dynamic_config)}(dynamic_config) +end +@inline function Base.getproperty( + config::FlashAttenKernelConfig{static_config}, + sym::Symbol +) where {static_config} + if hasproperty(static_config, sym) + return getproperty(static_config, sym) + else + return getproperty(getfield(config, :dynamic_config), sym) + end +end +@inline function Base.hasproperty(config::FlashAttenKernelConfig{sconfig}, sym::Symbol) where sconfig + return hasproperty(sconfig, sym) || hasproperty(getfield(config, dynamic_config), sym) +end + +function find_d_dims(Wm, Wn, Wk, dk, dv) + Dk = cld(dk, Wk) * Wk + Dv = cld(dv, Wn) * Wn + D = max(Dk, Dv) + return (; Dk, Dv, D) +end +find_d_dims(::Type{<:FlashAttenConfig{<:MMAConfig{Wm, Wn, Wk}}}, dk, dv) where {Wm, Wn, Wk} = find_d_dims(Wm, Wn, Wk, dk, dv) + +use_fastmath() = CUDA.default_math_mode[] == CUDA.FAST_MATH +function get_compute_precision(reduceT) + cuda_math_prec_sym = CUDA.default_math_precision[] + if isnothing(cuda_math_prec_sym) + computeT = reduceT + elseif cuda_math_prec_sym == :TensorFloat32 + computeT = Float32 + elseif cuda_math_prec_sym == :BFloat16 + computeT = BFloat16 + elseif cuda_math_prec_sym == :Float16 + computeT = Float16 + else + @debug "Unknown precision symbol: $cuda_math_prec_sym\nUse eltype(Q) = $reduceT" + computeT = reduceT + end + return computeT +end + +function total_fw_shm_size(config) + @assert(hasproperty(config, :Br) && hasproperty(config, :Bc) && hasproperty(config, :Dk) && hasproperty(config, :Dv), + "Cannot compute shmem size without knowing the size of inputs") + (; Br, Bc, Dk, Dv) = config + (; computeT, reduceT) = config + return sizeof(computeT) * (Dk * Br + max(Dk, Dv) * Bc + Br * Bc) + sizeof(reduceT) * (Br * (Bc + Dv + 5)) +end + +function find_fw_max_size(Wm, Wn, Wk, computeT, reduceT, Dk, Dv, shmem_max, bcrange = nothing, brrange = nothing) + isnothing(bcrange) && (bcrange = 2:32) + isnothing(brrange) && (brrange = 2:32) + D = max(Dk, Dv) + szC = sizeof(computeT) + szR = sizeof(reduceT) + a = szC + szR + b = szC * Dk + szR * Dv + szR * 5 + c = szC * D + Brmax = Wm + Bcmax = Wn + shmem_possible = -1 + d = typemax(Float64) + for bc = bcrange, br = 
brrange + x = br * Wm + y = bc * Wn + shmem = a * x * y + b * x + c * y + if shmem_possible >> 1 < shmem <= shmem_max + d2 = abs(x - y) / 16 + abs(shmem_max - shmem) / 2048 + if shmem_possible == -1 || d2 < d + d = d2 + shmem_possible = shmem + Brmax = x + Bcmax = y + end + end + end + shmem_possible = a * Brmax * Bcmax + b * Brmax + c * Bcmax + return Brmax, Bcmax, shmem_possible +end +function find_fw_max_size( + ::Type{<:FlashAttenConfig{MMAConfig{Wm, Wn, Wk, computeT, reduceT}}}, Dk, Dv, shmem_max, bcrange = nothing, brrange = nothing +) where {Wm, Wn, Wk, computeT, reduceT} + return find_fw_max_size(Wm, Wn, Wk, computeT, reduceT, Dk, Dv, shmem_max, bcrange, brrange) +end + +function build_fw_flash_attention_kernel( + configT::Type{<:FlashAttenConfig{MMAConfig{Wm, Wn, Wk, computeT, reduceT}}}, + O, L, Q, K, V +) where {Wm, Wn, Wk, computeT, reduceT} + dk, Nq, Bq = size(Q) + dv, Nk, Bk = size(V) + (; Dk, Dv, D) = find_d_dims(configT, dk, dv) + dev = device() + shmem_max, shmem_max_possible = get_max_shmem(dev) + Br, Bc, shmem = find_fw_max_size(configT, Dk, Dv, shmem_max, Nk <= 32 ? 2 : nothing, Nq <= 32 ? 2 : nothing) + if Br < Nq && Bc < Nk && shmem_max_possible > shmem_max + Br, Bc, shmem = find_fw_max_size(configT, Dk, Dv, shmem_max_possible) + end + @debug Br, Bc, shmem + mma_config = MMAConfig{Wm, Wn, Wk, computeT, reduceT} + sconfig = FlashAttenConfig{mma_config}( + @NamedTuple{ + Br::Int32, Bc::Int32, + Dk::Int32, Dv::Int32}(( + Br, Bc, Dk, Dv))) + dconfig = FlashAttenConfig{mma_config}( + @NamedTuple{ + minval::reduceT, ss::reduceT}(( + -1e9, sqrt(inv(dk))))) + config = FlashAttenKernelConfig(sconfig, dconfig) + return build_fw_flash_attention_kernel(config, O, L, Q, K, V) +end + +function build_fw_flash_attention_kernel( + config::FlashAttenConfig{mma_config}, + O, L, Q, K, V +) where {mma_config} + (; Br, Bc, Dk, Dv) = config + sconfig = @NamedTuple{Br::Int32, Bc::Int32, Dk::Int32, Dv::Int32}((Br, Bc, Dk, Dv)) + dconfig = Base.structdiff(getfields(config, :fields), sconfig) + kconfig = FlashAttenKernelConfig(FlashAttenConfig{mma_config}(sconfig), FlashAttenConfig{mma_config}(dconfig)) + return build_fw_flash_attention_kernel(kconfig, O, L, Q, K, V) +end + +function build_fw_flash_attention_kernel( + config::FlashAttenKernelConfig, + O, L, Q, K, V +) + Br = config.Br + Nq = size(Q, 2) + Bk = size(K, 3) + dev = device() + ws = warpsize(dev) + shmem_max, shmem_max_possible = get_max_shmem(dev) + shmem = total_fw_shm_size(config) + @assert shmem <= shmem_max_possible + fastmath = use_fastmath() + kernel = @cuda(always_inline=true, fastmath=fastmath, launch=false, + flash_attention_forward_kernel!(config, O, L, Q, K, V)) + if shmem > shmem_max + CUDA.cuFuncSetAttribute(kernel.fun, CUDA.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shmem) + end + compute_threads(threads) = max(fld(threads, ws), 1) * ws + compute_shmem(threads) = shmem + kernel_config = launch_configuration(kernel.fun; shmem = compute_shmem ∘ compute_threads) + threads = compute_threads(kernel_config.threads) + blocks = min(kernel_config.blocks, Bk * cld(Nq, Br)) + @debug kernel_config + return config, kernel, (; threads, blocks, shmem) +end + +function flash_attention_forward(Q, K, V) + O = similar(Q, size(V, 1), Base.tail(size(Q))...) + L = similar(O, 1, Base.tail(size(Q))...) 
+ reduceT = eltype(Q) + computeT = get_compute_precision(reduceT) + Wm = Wn = Wk = 16 + configT = FlashAttenConfig{MMAConfig{Wm, Wn, Wk, computeT, reduceT}} + config, kernel, kernel_config = build_fw_flash_attention_kernel(configT, O, L, Q, K, V) + kernel(config, O, L, Q, K, V; kernel_config...) + return O, L +end + +function flash_attention_forward(config, Q, K, V) + O = similar(Q, size(V, 1), Base.tail(size(Q))...) + L = similar(O, 1, Base.tail(size(Q))...) + config, kernel, kernel_config = build_fw_flash_attention_kernel(config, O, L, Q, K, V) + kernel(config, O, L, Q, K, V; kernel_config...) + return O, L +end diff --git a/src/flash/mma.jl b/src/flash/mma.jl new file mode 100644 index 0000000..45fa1df --- /dev/null +++ b/src/flash/mma.jl @@ -0,0 +1,168 @@ +using SIMD +using CUDA +using CUDA.WMMA +using CUDA: i32 +using KernelAbstractions.Extras: @unroll + +function fragment_idx1(lane, lda) + lane0 = lane - 1i32 + cgrp, rgrp = _fast_fldmod(lane0, Val(4)) + offset = cgrp * lda + rgrp << 1i32 + idx0 = Vec{4, Int32}((0i32, 1i32, 8i32, 9i32)) + idx1 = idx0 + (lda << 3i32) + idx = offset + Vec{8, Int32}(tuplejoin(idx0.data, idx1.data)) + return idx +end +function fragment_idx2(lane, lda) + lane0 = lane - 1i32 + rgrp, cgrp = _fast_fldmod(lane0, Val(4)) + offset = (cgrp << 1i32) * lda + rgrp + idx0 = Vec{4, Int32}((0i32, 1i32, 8i32, 9i32)) * lda + idx1 = idx0 + 8i32 + idx = offset + Vec{8, Int32}(tuplejoin(idx0.data, idx1.data)) + return idx +end +function fragment_idx3(lane, lda) + idx = fragment_idx1(lane, lda) + idx = shufflevector(idx, Val((0,1,4,5,2,3,6,7))) + return idx +end +function fragment_idx4(lane, lda) + idx = fragment_idx2(lane, lda) + idx = shufflevector(idx, Val((0,1,4,5,2,3,6,7))) + return idx +end + +fragment_a_idx(::Type{ColMajor}, lane, lda) = fragment_idx4(lane, lda) +fragment_a_idx(::Type{RowMajor}, lane, lda) = fragment_idx3(lane, lda) +fragment_b_idx(::Type{ColMajor}, lane, lda) = fragment_idx1(lane, lda) +fragment_b_idx(::Type{RowMajor}, lane, lda) = fragment_idx2(lane, lda) +fragment_c_idx(::Type{ColMajor}, lane, lda) = fragment_idx4(lane, lda) +fragment_c_idx(::Type{RowMajor}, lane, lda) = fragment_idx3(lane, lda) +fragment_idx(::Type{MatrixA}, ::Type{L}, lane, lda) where L <: FragmentLayout = fragment_a_idx(L, lane, lda) +fragment_idx(::Type{MatrixB}, ::Type{L}, lane, lda) where L <: FragmentLayout = fragment_b_idx(L, lane, lda) +fragment_idx(::Type{Accumulator}, ::Type{L}, lane, lda) where L <: FragmentLayout = fragment_c_idx(L, lane, lda) + + +struct MMAConfig{M, N, K, m_type, d_type} end + +Base.@propagate_inbounds function fragment_load( + config::Union{Type{<:MMAConfig{16, 16, 16}}, Type{<:Config{16, 16, 16}}}, use::Type{U}, layout::Type{L}, lane, + mem, r, c, b... +) where {U <: WMMA.FragmentUse, L <: FragmentLayout} + base = LinearIndices(size(mem))[r, c, b...] + lda = stride(mem, 2) * i32 + indices = base + fragment_idx(U, L, lane, lda) + vs = Base.Cartesian.@ntuple 8 i -> mem[indices[i]] + if !(U <: Accumulator) + vs = tuplejoin(vs, vs) + end + return vs +end + +Base.@propagate_inbounds function fragment_load_a( + config::Type{<:MMAConfig{16, 16, 16, T}}, layout::Type{L}, lane, + data, r, c, b... +) where {T, L <: FragmentLayout} + return Fragment{16, 16, 16, 16, T, L, MatrixA}(fragment_load(config, MatrixA, layout, lane, data, r, c, b...)) +end +Base.@propagate_inbounds function fragment_load_a( + config::Type{<:Config{16, 16, 16, T}}, layout::Type{L}, lane, + data, r, c, b... 
+) where {T, L <: FragmentLayout} + return fragment_load_a(MMAConfig{16, 16, 16, eltype(data), T}, layout, lane, data, r, c, b...) +end +Base.@propagate_inbounds function fragment_load_b( + config::Type{<:MMAConfig{16, 16, 16, T}}, layout::Type{L}, lane, + data, r, c, b... +) where {T, L <: FragmentLayout} + return Fragment{16, 16, 16, 16, T, L, MatrixB}(fragment_load(config, MatrixB, layout, lane, data, r, c, b...)) +end +Base.@propagate_inbounds function fragment_load_b( + config::Type{<:Config{16, 16, 16, T}}, layout::Type{L}, lane, + data, r, c, b... +) where {T, L <: FragmentLayout} + return fragment_load_b(MMAConfig{16, 16, 16, eltype(data), T}, layout, lane, data, r, c, b...) +end +Base.@propagate_inbounds function fragment_load_c( + config::Union{Type{<:MMAConfig{16, 16, 16, MT, T}}, Type{<:Config{16, 16, 16, T}}}, layout::Type{L}, lane, + data, r, c, b... +) where {MT, T, L <: FragmentLayout} + return Fragment{16, 16, 16, 8, T, Unspecified, Accumulator}(fragment_load(config, Accumulator, L, lane, data, r, c, b...)) +end + +Base.@propagate_inbounds function fragment_store_d( + config::Union{Type{<:MMAConfig{16, 16, 16}}, Type{<:Config{16, 16, 16}}}, layout::Type{L}, lane, + mem::AbstractArray{T}, frag::Fragment{16, 16, 16, 8, T}, r, c, b... +) where {T, L <: FragmentLayout} + base = LinearIndices(size(mem))[r, c, b...] + lda = stride(mem, 2) * i32 + indices = base + fragment_c_idx(L, lane, lda) + Base.Cartesian.@nexprs 8 i -> mem[indices[i]] = frag[i] + return nothing +end + +Base.@propagate_inbounds function fragment_reduce_store_d( + config::Union{Type{<:MMAConfig{16, 16, 16}}, Type{<:Config{16, 16, 16}}}, layout::Type{L}, lane, + op, mem::AbstractArray{T}, frag::Fragment{16, 16, 16, 8, T}, r, c, b... +) where {T, L <: FragmentLayout} + base = LinearIndices(size(mem))[r, c, b...] 
+ lda = stride(mem, 2) * i32 + idx = base + fragment_c_idx(L, lane, lda) + vs = frag.x + Base.Cartesian.@nexprs 8 i -> CUDA.atomic_arrayset(mem, idx[i], op, vs[i]) + return nothing +end + +function shfl_dot( + config::Union{Type{<:MMAConfig{16, 16, 16, MT, T}}, Type{<:Config{16, 16, 16, T}}}, + frag_a::Fragment{16, 16, 16, 16, MT}, + frag_b::Fragment{16, 16, 16, 16, MT}, + odd_even +) where {T, MT} + @inbounds begin + shflmask = typemax(UInt32) + odd = odd_even[1] + even = odd_even[2] + a = WMMA.unflatten(NTuple{16, VecElement{MT}}, frag_a.x) + b = WMMA.unflatten(NTuple{16, VecElement{MT}}, frag_b.x) + a12, a34, a56, a78 = tuplesplitp(first(tuplesplitp(a, Val(2))), Val(4)) + a1256, a3478 = tuplejoin(a12, a56), tuplejoin(a34, a78) + b1234, b5678 = tuplesplitp(first(tuplesplitp(b, Val(2))), Val(2)) + sb1234 = _shfl_sync(shflmask, b1234, odd) + sb5678 = _shfl_sync(shflmask, b5678, odd) + lb1234 = _shfl_sync(shflmask, b1234, even) + lb5678 = _shfl_sync(shflmask, b5678, even) + c12 = (VecElement{T}(_dot(T, a1256, sb1234)), VecElement{T}(_dot(T, a1256, lb1234))) + c34 = (VecElement{T}(_dot(T, a3478, sb1234)), VecElement{T}(_dot(T, a3478, lb1234))) + c56 = (VecElement{T}(_dot(T, a1256, sb5678)), VecElement{T}(_dot(T, a1256, lb5678))) + c78 = (VecElement{T}(_dot(T, a3478, sb5678)), VecElement{T}(_dot(T, a3478, lb5678))) + return Fragment{16, 16, 16, 8, T, Unspecified, Accumulator}(WMMA.flatten(tuplejoin(c12, c34, c56, c78))) + end +end + +function fragment_mma( + config::Union{Type{<:MMAConfig{16, 16, 16, MT, T}}, Type{<:Config{16, 16, 16, T}}}, lane, + frag_a::Fragment{16, 16, 16, 16, MT, L1, MatrixA}, + frag_b::Fragment{16, 16, 16, 16, MT, L2, MatrixB}, + frag_c::Fragment{16, 16, 16, 8, T, Unspecified, Accumulator}, +) where {MT, T, L1, L2} + shflmask = typemax(UInt32) + ws = Val{32}() + lane0 = lane - 1i32 + agrp, bgrp = _fast_fldmod(lane0, Val(4)) + base = lane - bgrp # 1i32 + _fast_mul(agrp, Val(4)) + odd_even = (_fast_mul(bgrp, Val(8)) + bgrp) + Vec{2, Int32}((1i32, 5i32)) + frag_c = frag_c .+ shfl_dot(config, frag_a, frag_b, odd_even) + grp = bgrp + odd_even = _fast_mod(odd_even + 7i32, ws) + 1i32 + grp = _fast_mod(grp + 3i32, Val(4)) + frag_c = frag_c .+ _shfl_sync(shflmask, shfl_dot(config, frag_a, frag_b, odd_even), grp + base) + odd_even = _fast_mod(odd_even + 7i32, ws) + 1i32 + grp = _fast_mod(grp + 3i32, Val(4)) + frag_c = frag_c .+ _shfl_sync(shflmask, shfl_dot(config, frag_a, frag_b, odd_even), grp + base) + odd_even = _fast_mod(odd_even + 7i32, ws) + 1i32 + grp = _fast_mod(grp + 3i32, Val(4)) + frag_c = frag_c .+ _shfl_sync(shflmask, shfl_dot(config, frag_a, frag_b, odd_even), grp + base) + return frag_c +end diff --git a/src/flash/utils.jl b/src/flash/utils.jl new file mode 100644 index 0000000..e762aba --- /dev/null +++ b/src/flash/utils.jl @@ -0,0 +1,141 @@ +using CUDA +using CUDA.WMMA + +@generated function tuplesplit(t::NTuple{N}, ::Val{M}) where {N, M} + expr = Expr(:tuple) + for indices in Iterators.partition(1:N, M) + ti = Expr(:tuple) + for idx in indices + push!(ti.args, :(t[$idx])) + end + push!(expr.args, ti) + end + return quote + @inbounds $expr + end +end + +@generated function tuplesplitp(t::NTuple{N}, ::Val{M}) where {N, M} + N % M == 0 || error("Cannot split tuple of length $N into $M pieces") + n = cld(N, M) + return :(tuplesplit(t, Val{$n}())) +end + +tuplejoin(a::Tuple, b::Tuple, c::Tuple...) = tuplejoin((a..., b...), c...) +tuplejoin(a::Tuple, b::Tuple) = (a..., b...) 
+ +@generated function _fast_mul(x, ::Val{V}) where V + ispow2(V) || error("No fast mod for $V") + v = Int32(trailing_zeros(V)) + return isone(V) ? :x : :(x << $v) +end +@generated function _fast_fld(x, ::Val{V}) where V + ispow2(V) || error("No fast mod for $V") + v = Int32(trailing_zeros(V)) + return isone(V) ? :x : :(x >> $v) +end +@generated function _fast_mod(x, ::Val{V}) where V + ispow2(V) || error("No fast mod for $V") + v = Int32(V) - 1i32 + return :(x & $v) +end +_fast_fldmod(x, v::Val{V}) where V = (_fast_fld(x, v), _fast_mod(x, v)) + +@generated function _uint(::Type{T}) where T + s = sizeof(T) + if s == 1 + return UInt8 + elseif s == 2 + return UInt16 + elseif s == 4 + return UInt32 + elseif s == 8 + return UInt64 + else + error("no corresponding unsigned type for $T") + end +end + +@generated function _pack_4byte(x::NTuple{N, UT}) where {N, T, UT <: Union{T, VecElement{T}}} + nb = N * sizeof(T) + n, r = fldmod(nb, 4) + x′ = Expr(:tuple) + for i = 1:N + xi = UT <: VecElement ? :(x[$i].value) : :(x[$i]) + xi = :(reinterpret($(_uint(T)), $xi)) + push!(x′.args, xi) + end + expr = Expr(:tuple) + for i = 1:n + push!(expr.args, :(reinterpret(UInt32, Vec{4, UInt8}(xs[$i])))) + end + if !iszero(r) + rest = Expr(:tuple, Expr(:..., :(xs[$(n+1)])), [:(zero(T)) for _ in 1:(4 - r)]...) + push!(expr.args, :(reinterpret(UInt32, Vec{4, UInt8}($rest)))) + end + return quote + @inbounds begin + x′ = $x′ + xs = tuplesplit(WMMA.flatten(reinterpret(Vec{$nb, UInt8}, Vec(x′)).data), Val(4)) + $expr + end + end +end + +@generated function _unpack_4byte(::Type{NTuple{N, UT}}, xs::NTuple{M, UInt32}) where {N, M, T, UT <: Union{T, VecElement{T}}} + expr = Expr(:tuple) + for i = 1:N + xi = :(reinterpret($T, reinterpret($(_uint(T)), Vec(xs[$i])))) + xi = UT <: VecElement ? :(VecElement{$T}($xi)) : xi + push!(expr.args, xi) + end + return quote + @inbounds begin + xs = tuplesplit(WMMA.flatten(reinterpret(Vec{$(4M), UInt8}, Vec(xs)).data), Val{$(sizeof(T))}()) + return $expr + end + end +end + +@generated function _shfl(shfl_op, mask, vals::NTuple{N, UT}, src) where {N, T, UT <: Union{T, VecElement{T}}} + nb = N * sizeof(T) + n = cld(nb, 4) + expr = Expr(:tuple) + for i = 1:n + push!(expr.args, :(shfl_op(mask, xs[$i], src)::UInt32)) + end + return quote + @inbounds begin + xs = _pack_4byte(vals) + xs = $expr + return _unpack_4byte(NTuple{N, UT}, xs) + end + end +end +_shfl_sync(mask, vals::NTuple{N, UT}, src) where {N, T, UT <: Union{T, VecElement{T}}} = _shfl(shfl_sync, mask, vals, src) +_shfl_sync(mask, vals::F, src) where F <: Fragment = F(_shfl_sync(mask, vals.x, src)) +_shfl_xor_sync(mask, vals::NTuple{N, UT}, src) where {N, T, UT <: Union{T, VecElement{T}}} = _shfl(shfl_xor_sync, mask, vals, src) +_shfl_xor_sync(mask, vals::F, src) where F <: Fragment = F(_shfl_xor_sync(mask, vals.x, src)) + +@generated function _dot(::Type{T}, a::NTuple{N, MT}, b::NTuple{N, MT}, c::Union{T, VecElement{T}, Nothing} = nothing) where {T, N, MT} + ab = Expr(:tuple) + for i = 1:N + push!(ab.args, MT <: VecElement ? :((a[$i].value * b[$i].value)) : :(a[$i] * b[$i])) + end + expr = :($T(ab[1])) + for i = 2:N + expr = Expr(:call, :(+), expr, :($T(ab[$i]))) + end + if !(c <: Nothing) + expr = Expr(:call, :(+), c <: VecElement ? 
:(c.value) : :c, expr) + end + if c <: VecElement + expr = :(VecElement{$T}($expr)) + end + return quote + @inbounds begin + ab = $ab + return $expr + end + end +end From 1c3e98ad557b232a1db1a5ccc2130242c6b1df94 Mon Sep 17 00:00:00 2001 From: chengchingwen Date: Sat, 3 Feb 2024 20:44:32 +0800 Subject: [PATCH 2/2] init backward impl --- src/flash/Flash.jl | 2 + src/flash/backward.jl | 242 ++++++++++++++++++++++++++++++++++++ src/flash/backward_utils.jl | 65 ++++++++++ src/flash/forward_utils.jl | 5 + src/flash/launch.jl | 138 ++++++++++++++++++++ 5 files changed, 452 insertions(+) create mode 100644 src/flash/backward.jl create mode 100644 src/flash/backward_utils.jl diff --git a/src/flash/Flash.jl b/src/flash/Flash.jl index eb487cb..aa6f855 100644 --- a/src/flash/Flash.jl +++ b/src/flash/Flash.jl @@ -4,6 +4,8 @@ include("utils.jl") include("mma.jl") include("forward_utils.jl") include("forward.jl") +include("backward_utils.jl") +include("backward.jl") include("launch.jl") end diff --git a/src/flash/backward.jl b/src/flash/backward.jl new file mode 100644 index 0000000..00f878c --- /dev/null +++ b/src/flash/backward.jl @@ -0,0 +1,242 @@ +using CUDA +using CUDA: i32 +using KernelAbstractions.Extras: @unroll + +function create_backward_dynamic_shm(config) + (; Br, Bc) = config + (; Wk, Wn) = config + (; Dk, Dv) = config + (; computeT, reduceT) = config + D = max(Dk, Dv) + offset = 0i32 + Li = @inbounds CuDynamicSharedArray(reduceT, (1, Br), offset) + offset += sizeof(Li) % Int32 + Di = @inbounds CuDynamicSharedArray(reduceT, (1, Br), offset) + offset += sizeof(Di) % Int32 + P = @inbounds CuDynamicSharedArray(computeT, (Bc, Br), offset) + offset += sizeof(P) % Int32 + dS = @inbounds CuDynamicSharedArray(computeT, (Bc, Br), offset) + offset += sizeof(dS) % Int32 + Qi = @inbounds CuDynamicSharedArray(computeT, (Dk, Br), offset) + offset += sizeof(Qi) % Int32 + dQi = @inbounds CuDynamicSharedArray(reduceT, (Dk, Br), offset) + offset += sizeof(dQi) % Int32 + Kj = @inbounds CuDynamicSharedArray(computeT, (Dk, Bc), offset) + offset += sizeof(Kj) % Int32 + dKj = @inbounds CuDynamicSharedArray(reduceT, (Dk, Bc), offset) + offset += sizeof(dKj) % Int32 + Vj = @inbounds CuDynamicSharedArray(computeT, (Dv, Bc), offset) + offset += sizeof(Vj) % Int32 + dVj = @inbounds CuDynamicSharedArray(reduceT, (Dv, Bc), offset) + offset += sizeof(dVj) % Int32 + dOi = @inbounds CuDynamicSharedArray(computeT, (Dv, Br), offset) + offset += sizeof(dOi) % Int32 + return (; P, dS, Kj, dKj, Vj, dVj, dOi, Qi, dQi, Li, Di) +end + +@inline function compute_P_and_dS(config, grps, shms, sizes) + (; Br, Bc, Wm, Wn, Wk, Dk, Dv) = config + (; computeT, reduceT) = config + (; ss, minval) = config + (; Qi, Kj, Vj, Li, dOi, P, dS, Di) = shms + (; W, warp, lane) = grps + (; sfull, klen, qlen, tr, tc, tdk, tdv) = sizes + rstop = min(Br, qlen) + cstop = min(Bc, klen) + np = tr * tc + @unroll for ti_tj = warp:W:np + ti, tj = fldmod1(ti_tj, tc) + sj = (tj - 1i32) * Wn + 1i32 + si = (ti - 1i32) * Wm + 1i32 + (si > rstop || sj > cstop) && continue + d = warp_load_1xm(config, lane, Di, si) + dp = warp_fill_c(config, zero(reduceT)) + @unroll for tk = 1i32:tdv + sk = (tk - 1i32) * Wk + 1i32 + sk > Dv && break + doi = warp_load_kxm(config, lane, dOi, sk, si) + v = warp_load_kxn(config, lane, Vj, sk, sj) + dp = warp_mma(config, lane, v, doi, dp) # dp += v' * doi + end + dpd = rbroadcast(-, dp, d) + l = warp_load_1xm(config, lane, Li, si) + s = warp_fill_c(config, zero(reduceT)) + @unroll for tk = 1i32:tdk + sk = (tk - 1i32) * Wk + 1i32 + sk > Dk 
&& break + q = warp_load_kxm(config, lane, Qi, sk, si) + k = warp_load_kxn(config, lane, Kj, sk, sj) + s = warp_mma(config, lane, k, q, s) # s += k' * q + end + s = s .* ss + p = rbroadcast(-, s, l) + if !sfull + pmask = warp_gen_pmask(config, lane, klen, qlen, sj, si) + p = _elseif.(p, pmask, minval) # @. ifelse(pmask, p, minval) + end + p = exp.(p) + warp_shm_write_nxm!(config, lane, P, fragment_type(computeT, p), sj, si) + ds = p .* dpd + ds = ds .* ss + warp_shm_write_nxm!(config, lane, dS, fragment_type(computeT, ds), sj, si) + end + return nothing +end + +@inline function compute_dQ_KdS(config, grps, shms, sizes) + (; Br, Bc, Wm, Wn, Wk, Dk) = config + (; computeT, reduceT) = config + (; ss) = config + (; dQi, Kj, dS) = shms + (; W, warp, lane) = grps + (; klen, qlen) = sizes + tr = cld(Br, Wk) % Int32 + tc = cld(Bc, Wm) % Int32 + tdk = cld(Dk, Wn) % Int32 + rstop = min(Br, qlen) + cstop = min(Bc, klen) + np = tr * tdk + @unroll for ti_tk = warp:W:np + ti, tk = fldmod1(ti_tk, tdk) + si = (ti - 1i32) * Wm + 1i32 + sk = (tk - 1i32) * Wn + 1i32 + (si > rstop || sk > Dk) && continue + dq = warp_fill_c(config, zero(reduceT)) + @unroll for tj = 1i32:tc + sj = (tj - 1i32) * Wk + 1i32 + sj > cstop && break + ds = warp_load_kxm(config, lane, dS, sj, si) + k = warp_load_nxk(config, lane, Kj, sk, sj) + dq = warp_mma(config, lane, k, ds, dq) # dQi += K * dSi + end + warp_shm_write_nxm!(config, lane, dQi, dq, sk, si) + end + return nothing +end + +@inline function compute_dV_dOPᵀ(config, grps, shms, sizes) + (; Br, Bc, Wm, Wn, Wk, Dv) = config + (; computeT, reduceT) = config + (; P, dVj, dOi) = shms + (; W, warp, lane) = grps + (; klen, qlen) = sizes + (; Tc) = sizes + tc = cld(Bc, Wm) % Int32 + tdv = cld(Dv, Wn) % Int32 + tr = cld(Br, Wk) % Int32 + rstop = min(Br, qlen) + cstop = min(Bc, klen) + np = tc * tdv + @unroll for tj_tk = warp:W:np + tj, tk = fldmod1(tj_tk, tdv) + sj = (tj - 1i32) * Wm + 1i32 + sk = (tk - 1i32) * Wn + 1i32 + (sj > cstop || sk > Dv) && continue + dv = warp_load_nxm(config, lane, dVj, sk, sj) + @unroll for ti = 1i32:tr + si = (ti - 1i32) * Wk + 1i32 + si > rstop && break + p = warp_load_mxk(config, lane, P, sj, si) + doi = warp_load_nxk(config, lane, dOi, sk, si) + dv = warp_mma(config, lane, doi, p, dv) # dvj += doi * p' + end + warp_shm_write_nxm!(config, lane, dVj, dv, sk, sj) + end + return nothing +end + +@inline function compute_dK_QdSᵀ(config, grps, shms, sizes) + (; Br, Bc, Wm, Wn, Wk, Dk) = config + (; computeT, reduceT) = config + (; ss) = config + (; Qi, dKj, dS) = shms + (; W, warp, lane) = grps + (; klen, qlen) = sizes + tr = cld(Br, Wk) % Int32 + tc = cld(Bc, Wm) % Int32 + tdk = cld(Dk, Wn) % Int32 + rstop = min(Br, qlen) + cstop = min(Bc, klen) + np = tc * tdk + @unroll for tj_tk = warp:W:np + tj, tk = fldmod1(tj_tk, tdk) + sj = (tj - 1i32) * Wm + 1i32 + sk = (tk - 1i32) * Wn + 1i32 + (sj > cstop || sk > Dk) && continue + dk = warp_load_nxm(config, lane, dKj, sk, sj) + @unroll for ti = 1i32:tr + si = (ti - 1i32) * Wk + 1i32 + si > rstop && break + ds = warp_load_mxk(config, lane, dS, sj, si) + q = warp_load_nxk(config, lane, Qi, sk, si) + dk = warp_mma(config, lane, q, ds, dk) # dKj += Q * dSi' + end + warp_shm_write_nxm!(config, lane, dKj, dk, sk, sj) + end + return nothing +end + +function flash_attention_backward_kernel!(config, dQ, dK, dV, dO, O, L, Q, K, V) + (; Br, Bc, Dk, Dv) = config + Wm = config.Wm % Int32 # WMMA size + Wn = config.Wn % Int32 + Wk = config.Wk % Int32 + (; computeT, reduceT) = config + (; minval, ss) = config + threads = 
blockDim().x + ws = warpsize() + W = fld(threads, ws) + index = threadIdx().x + warp, lane = fldmod1(index, ws) + grps = (; W, index, warp, lane) + B = size(O, 3) % Int32 + Nq = size(Q, 2) % Int32 + Nk = size(K, 2) % Int32 + dk = size(Q, 1) % Int32 + dv = size(V, 1) % Int32 + Tr = cld(Nq, Br) % Int32 + Tc = cld(Nk, Bc) % Int32 + tr = cld(Br, Wm) % Int32 + tc = cld(Bc, Wn) % Int32 + tdk = cld(Dk, Wk) % Int32 + tdv = cld(Dv, Wn) % Int32 + sizes = (; Nq, Nk, Tr, Tc, dk, dv, tr, tc, tdk, tdv) + shms = create_backward_dynamic_shm(config) + stride = gridDim().x + bidx = blockIdx().x + NP = B * Tc + for b_j = bidx:stride:NP + b, j = fldmod1(b_j, Tc) + kfull, krange, klen = chunkrange(Bc, Nk, j) + sizes = merge(sizes, (; kfull, krange, klen)) + block_glb2shm!(config, shms.Kj, K, krange, b) + block_glb2shm!(config, shms.Vj, V, krange, b) + block_shm_fill!(shms.dKj, zero(reduceT)) + block_shm_fill!(shms.dVj, zero(reduceT)) + block_shm_fill!(shms.Di, zero(reduceT)) + sync_threads() + for i = 1i32:Tr + qfull, qrange, qlen = chunkrange(Br, Nq, i) + sfull = qfull & kfull + sizes = merge(sizes, (; sfull, qfull, qrange, qlen,)) + block_glb2shm!(config, shms.dOi, dO, qrange, b) + block_glb2shm!(config, shms.Li, L, qrange, b) + block_glb2shm_rowreduce_atomic!(config, *, shms.Di, shms.dOi, O, qrange, b) + block_glb2shm!(config, shms.Qi, Q, qrange, b) + sync_threads() + compute_P_and_dS(config, grps, shms, sizes) + sync_threads() + block_shm_fill!(shms.Di, zero(reduceT)) + compute_dV_dOPᵀ(config, grps, shms, sizes) + compute_dK_QdSᵀ(config, grps, shms, sizes) + compute_dQ_KdS(config, grps, shms, sizes) + sync_threads() + block_shm2glb_atomic!(config, +, dQ, shms.dQi, qrange, b) + sync_threads() + end + block_shm2glb!(config, dK, shms.dKj, krange, b) + block_shm2glb!(config, dV, shms.dVj, krange, b) + sync_threads() + end + return nothing +end diff --git a/src/flash/backward_utils.jl b/src/flash/backward_utils.jl new file mode 100644 index 0000000..a39f2cf --- /dev/null +++ b/src/flash/backward_utils.jl @@ -0,0 +1,65 @@ +@inline function block_glb2shm_reduce!(config, op, shm, dOi, glb, range, b) + D = size(shm, 1) % Int32 + d = size(glb, 1) % Int32 + N = length(range) % Int32 + cbase = first(range) - 1i32 + works = length(shm) % Int32 + workload = cld(works, blockDim().x) % Int32 + top = threadIdx().x * workload + base = top - workload + 1i32 + for i = base:min(top, works) + c, r = fldmod1(i, D) + @inbounds if r > d || c > N + shm[r, c] = zero(eltype(shm)) + else + shm[r, c] = convert(eltype(shm), op(dOi[r, c], glb[r, c + cbase, b])) + end + end + return nothing +end + +@inline function block_glb2shm_rowreduce_atomic!(config, op, shm, dOi, glb, range, b) + D = size(dOi, 1) % Int32 + d = size(glb, 1) % Int32 + N = length(range) % Int32 + cbase = first(range) - 1i32 + works = length(dOi) % Int32 + workload = cld(works, blockDim().x) % Int32 + top = threadIdx().x * workload + base = top - workload + 1i32 + stop = min(top, works) + acc = zero(eltype(shm)) + for i = base:stop + c, r = fldmod1(i, D) + @inbounds if r <= d && c <= N + v = oftype(acc, op(dOi[r, c], glb[r, c + cbase, b])) + else + v = zero(acc) + end + acc += v + if r == D || i == stop + @inbounds CUDA.atomic_arrayset(shm, c, +, acc) + acc = zero(eltype(shm)) + end + end + return nothing +end + +function block_shm2glb_atomic!(config, op, glb, shm, range, b = 1) + # size(shm) = (D, Bx) + # size(glb) = (d, N, B) + # assume D >= d, Bx >= N + D = size(shm, 1) % Int32 + d = size(glb, 1) % Int32 + N = length(range) % Int32 + cbase = first(range) - 1i32 + 
works = (d * N) % Int32 + workload = cld(works, blockDim().x) % Int32 + top = threadIdx().x * workload + base = top - workload + 1i32 + for i = base:min(top, works) + c, r = fldmod1(i, d) + @inbounds CUDA.atomic_arrayset(glb, (r, c + cbase, b), op, shm[r, c]) + end + return nothing +end diff --git a/src/flash/forward_utils.jl b/src/flash/forward_utils.jl index 2ddbdfb..e7876b5 100644 --- a/src/flash/forward_utils.jl +++ b/src/flash/forward_utils.jl @@ -191,6 +191,11 @@ end (; computeT, reduceT) = config return @inbounds fragment_load_b(MMAConfig{Wn, Wm, Wk, computeT, reduceT}, ColMajor, lane, Qi, sk, si) end +@inline function warp_load_mxk(config, lane, Qi, sk, si) + (; Wm, Wn, Wk) = config + (; computeT, reduceT) = config + return @inbounds fragment_load_b(MMAConfig{Wn, Wm, Wk, computeT, reduceT}, RowMajor, lane, Qi, sk, si) +end @inline function warp_load_kxn(config, lane, Kj, sk, sj) (; Wm, Wn, Wk) = config (; computeT, reduceT) = config diff --git a/src/flash/launch.jl b/src/flash/launch.jl index 12f2ad0..2369bc1 100644 --- a/src/flash/launch.jl +++ b/src/flash/launch.jl @@ -249,3 +249,141 @@ function flash_attention_forward(config, Q, K, V) kernel(config, O, L, Q, K, V; kernel_config...) return O, L end + +function total_bw_shm_size(config) + @assert(hasproperty(config, :Br) && hasproperty(config, :Bc) && hasproperty(config, :Dk) && hasproperty(config, :Dv), + "Cannot compute shmem size without knowing the size of inputs") + (; Br, Bc, Dk, Dv) = config + (; computeT, reduceT) = config + # dkdv_size = sizeof(computeT) * (2 * Br * Bc + (Dk + Dv) * Bc + (max(Dk, Dv) + Dv) * Br) + sizeof(reduceT) * (2 * Br + (Dk + Dv) * Bc) + # dq_size = sizeof(computeT) * (Br * Bc + (Dk + Dv) * Br + (Dk + Dv) * Bc) + sizeof(reduceT) * (2 * Br + (Dk + Dv) * Br) + # return max(dkdv_size, dq_size) + return sizeof(computeT) * (2 * Br * Bc + (Dk + Dv) * Bc + (Dk + Dv) * Br) + sizeof(reduceT) * (2 * Br + Dk * Br + (Dk + Dv) * Bc) +end + +function find_bw_max_size(Wm, Wn, Wk, computeT, reduceT, Dk, Dv, shmem_max, bcrange = nothing, brrange = nothing) + isnothing(bcrange) && (bcrange = 1:32) + isnothing(brrange) && (brrange = 1:32) + D = max(Dk, Dv) + szC = sizeof(computeT) + szR = sizeof(reduceT) + a = 2 * szC + # b = szC * (D + Dv) + szR * 2 + b = szC * (Dk + Dv) + szR * (2 + Dk) + c = (szC + szR) * (Dk + Dv) + Brmax = Wm + Bcmax = Wn + shmem_possible = -1 + d = typemax(Float64) + for bc = bcrange, br = brrange + x = br * Wm + y = bc * Wn + shmem = a * x * y + b * x + c * y + if shmem_possible >> 1 < shmem <= shmem_max + d2 = abs(x - y) / 16 + abs(shmem_max - shmem) / 2048 + if shmem_possible == -1 || d2 < d + d = d2 + shmem_possible = shmem + Brmax = x + Bcmax = y + end + end + end + shmem_possible = a * Brmax * Bcmax + b * Brmax + c * Bcmax + return Brmax, Bcmax, shmem_possible +end +function find_bw_max_size( + ::Type{<:FlashAttenConfig{MMAConfig{Wm, Wn, Wk, computeT, reduceT}}}, Dk, Dv, shmem_max, bcrange = nothing, brrange = nothing +) where {Wm, Wn, Wk, computeT, reduceT} + return find_bw_max_size(Wm, Wn, Wk, computeT, reduceT, Dk, Dv, shmem_max, bcrange, brrange) +end + +function build_bw_flash_attention_kernel( + configT::Type{<:FlashAttenConfig{MMAConfig{Wm, Wn, Wk, computeT, reduceT}}}, + dQ, dK, dV, dO, O, L, Q, K, V +) where {Wm, Wn, Wk, computeT, reduceT} + dk, Nq, Bq = size(Q) + dv, Nk, Bk = size(V) + (; Dk, Dv, D) = find_d_dims(configT, dk, dv) + dev = device() + shmem_max, shmem_max_possible = get_max_shmem(dev) + Br, Bc, shmem = find_bw_max_size(configT, Dk, Dv, shmem_max, Nk <= 32 ? 
2 : nothing, Nq <= 32 ? 2 : nothing) + if Br < Nq && Bc < Nk && shmem_max_possible > shmem_max + Br, Bc, shmem = find_bw_max_size(configT, Dk, Dv, shmem_max_possible) + end + @debug Br, Bc, shmem + mma_config = MMAConfig{Wm, Wn, Wk, computeT, reduceT} + sconfig = FlashAttenConfig{mma_config}( + @NamedTuple{ + Br::Int32, Bc::Int32, + Dk::Int32, Dv::Int32}(( + Br, Bc, Dk, Dv))) + dconfig = FlashAttenConfig{mma_config}( + @NamedTuple{ + minval::reduceT, ss::reduceT}(( + -1e9, sqrt(inv(dk))))) + config = FlashAttenKernelConfig(sconfig, dconfig) + return build_bw_flash_attention_kernel(config, dQ, dK, dV, dO, O, L, Q, K, V) +end + +function build_bw_flash_attention_kernel( + config::FlashAttenConfig{mma_config}, + dQ, dK, dV, dO, O, L, Q, K, V +) where {mma_config} + (; Br, Bc, Dk, Dv) = config + sconfig = @NamedTuple{Br::Int32, Bc::Int32, Dk::Int32, Dv::Int32}((Br, Bc, Dk, Dv)) + dconfig = Base.structdiff(getfields(config, :fields), sconfig) + kconfig = FlashAttenKernelConfig(FlashAttenConfig{mma_config}(sconfig), FlashAttenConfig{mma_config}(dconfig)) + return build_bw_flash_attention_kernel(kconfig, dQ, dK, dV, dO, O, L, Q, K, V) +end + +function build_bw_flash_attention_kernel( + config::FlashAttenKernelConfig, + dQ, dK, dV, dO, O, L, Q, K, V +) + Br = config.Br + Bc = config.Bc + Nq = size(Q, 2) + Nk = size(K, 2) + Bk = size(K, 3) + dev = device() + ws = warpsize(dev) + shmem_max, shmem_max_possible = get_max_shmem(dev) + shmem = total_bw_shm_size(config) + @assert shmem <= shmem_max_possible + fastmath = use_fastmath() + kernel = @cuda(always_inline=true, fastmath=fastmath, launch=false, + flash_attention_backward_kernel!(config, dQ, dK, dV, dO, O, L, Q, K, V)) + if shmem > shmem_max + CUDA.cuFuncSetAttribute(kernel.fun, CUDA.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shmem) + end + compute_threads(threads) = max(fld(threads, ws), 1) * ws + compute_shmem(threads) = shmem + kernel_config = launch_configuration(kernel.fun; shmem = compute_shmem ∘ compute_threads) + threads = compute_threads(kernel_config.threads) + blocks = min(kernel_config.blocks, Bk * cld(Nk, Bc)) + @debug kernel_config + return config, kernel, (; threads, blocks, shmem) +end + +function flash_attention_backward(dO, O, L, Q, K, V) + dQ = zero(Q) + dK = similar(K) + dV = similar(V) + reduceT = eltype(Q) + computeT = get_compute_precision(reduceT) + Wm = Wn = Wk = 16 + configT = FlashAttenConfig{MMAConfig{Wm, Wn, Wk, computeT, reduceT}} + config, kernel, kernel_config = build_bw_flash_attention_kernel(configT, dQ, dK, dV, dO, O, L, Q, K, V) + kernel(config, dQ, dK, dV, dO, O, L, Q, K, V; kernel_config...) + return dQ, dK, dV +end + +function flash_attention_backward(config, dO, O, L, Q, K, V) + dQ = zero(Q) + dK = similar(K) + dV = similar(V) + config, kernel, kernel_config = build_bw_flash_attention_kernel(config, dQ, dK, dV, dO, O, L, Q, K, V) + kernel(config, dQ, dK, dV, dO, O, L, Q, K, V; kernel_config...) + return dQ, dK, dV +end
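
For reference, a minimal usage sketch of the two entry points added by this series (`flash_attention_forward` from PATCH 1/2 and `flash_attention_backward` from PATCH 2/2). It assumes the `(head_dim, length, batch)` layout used throughout the kernels and that `Flash` remains reachable as the non-exported submodule `NeuralAttentionlib.Flash`; the Float16 element type and the sizes below are purely illustrative.

    using CUDA
    using NeuralAttentionlib
    using NeuralAttentionlib.Flash: flash_attention_forward, flash_attention_backward

    dk, dv, Nq, Nk, B = 64, 64, 512, 512, 4
    Q = CuArray(randn(Float16, dk, Nq, B))   # (head dim, query length, batch)
    K = CuArray(randn(Float16, dk, Nk, B))   # (head dim, key length, batch)
    V = CuArray(randn(Float16, dv, Nk, B))   # (value dim, key length, batch)

    O, L = flash_attention_forward(Q, K, V)  # O: (dv, Nq, B), L: (1, Nq, B) per-query logsumexp
    dO = fill!(similar(O), one(eltype(O)))   # stand-in for the gradient coming from the next layer
    dQ, dK, dV = flash_attention_backward(dO, O, L, Q, K, V)

Both entry points size Br/Bc from the device's shared-memory limit via `build_fw_flash_attention_kernel` / `build_bw_flash_attention_kernel`; passing a prebuilt `FlashAttenConfig` as the first argument overrides that choice.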
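
As a cross-check for the forward kernel on small inputs, the same quantities can be computed non-fused on the CPU: `O` is the attention output and `L` holds the per-query logsumexp of the `1/sqrt(dk)`-scaled scores, i.e. what `li2 = m + log(l)` evaluates to at the last key chunk. A plain-Julia sketch (not part of the patch; unmasked attention, same `(d, N, B)` layout):

    using LinearAlgebra

    function naive_attention_forward(Q::Array, K::Array, V::Array)
        T = float(eltype(Q))
        ss = sqrt(inv(T(size(Q, 1))))                  # the kernel's `ss`
        O = similar(Q, T, size(V, 1), size(Q, 2), size(Q, 3))
        L = similar(Q, T, 1, size(Q, 2), size(Q, 3))
        for b in axes(Q, 3)
            S = ss .* (K[:, :, b]' * Q[:, :, b])       # (Nk, Nq): S = ss * KᵀQ
            m = maximum(S; dims = 1)                   # running row maximum (mi / mi2)
            P = exp.(S .- m)                           # P = exp(S - m)
            l = sum(P; dims = 1)                       # row sum (pi / li)
            L[:, :, b] = m .+ log.(l)                  # logsumexp, the final content of li2
            O[:, :, b] = (V[:, :, b] * P) ./ l         # O = V * P * l⁻¹
        end
        return O, L
    end

Comparing `Array.(flash_attention_forward(Q, K, V))` against this on small random inputs (up to the precision of `computeT`) is a convenient sanity check.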
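
The backward kernel recovers `P` from `L` without re-normalising (`P = exp(ss * KᵀQ - L)`), reduces `D = sum(dO .* O)` over the feature dimension via `block_glb2shm_rowreduce_atomic!`, and forms `dS = ss * P .* (VᵀdO - D)` before the three accumulations. The same computation non-fused, again a CPU sketch under the layout assumptions above:

    function naive_attention_backward(dO::Array, O::Array, L::Array, Q::Array, K::Array, V::Array)
        T = float(eltype(Q))
        ss = sqrt(inv(T(size(Q, 1))))
        dQ, dK, dV = zero(Q), zero(K), zero(V)
        for b in axes(Q, 3)
            P  = exp.(ss .* (K[:, :, b]' * Q[:, :, b]) .- L[:, :, b])  # softmax recovered from the logsumexp L
            D  = sum(dO[:, :, b] .* O[:, :, b]; dims = 1)              # the Di row reduction
            dS = ss .* P .* (V[:, :, b]' * dO[:, :, b] .- D)           # dS = ss * P ∘ (VᵀdO - D)
            dQ[:, :, b] = K[:, :, b] * dS                              # dQ += K * dS
            dK[:, :, b] = Q[:, :, b] * dS'                             # dK += Q * dSᵀ
            dV[:, :, b] = dO[:, :, b] * P'                             # dV += dO * Pᵀ
        end
        return dQ, dK, dV
    end

The kernel accumulates `dQ` atomically across key chunks through `block_shm2glb_atomic!`, which is why `flash_attention_backward` starts from `dQ = zero(Q)` rather than `similar(Q)`.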