From 112507ba1b6e23dd68d854991794ac1b7ec87bc3 Mon Sep 17 00:00:00 2001 From: murrellb Date: Sat, 28 Dec 2024 11:51:03 +0100 Subject: [PATCH 1/3] SDPA tricks --- Project.toml | 2 +- src/Jjama3.jl | 4 +- src/layers.jl | 32 +++---- src/model.jl | 17 +--- src/sdpa.jl | 250 ++++++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 273 insertions(+), 32 deletions(-) create mode 100644 src/sdpa.jl diff --git a/Project.toml b/Project.toml index 2641079..fda7547 100644 --- a/Project.toml +++ b/Project.toml @@ -24,7 +24,7 @@ MetalExt = "Metal" [compat] ChainRulesCore = "1.25.0" -Flux = "0.14, 0.15" +Flux = "0.14, 0.15, 0.16" LogitSamplers = "0.1.0" LowRankLayers = "0.1.2" Metal = "1" diff --git a/src/Jjama3.jl b/src/Jjama3.jl index 00d112c..cb11bc0 100644 --- a/src/Jjama3.jl +++ b/src/Jjama3.jl @@ -6,7 +6,7 @@ using LinearAlgebra using NNlib using LogitSamplers using LowRankLayers -#using ChainRulesCore +using ChainRulesCore using HuggingFaceTokenizers: HuggingFaceTokenizers, Tokenizer @@ -28,7 +28,7 @@ export rerope_cache! export scrape_cache export append_cache! -#include("sdpa.jl") +include("sdpa.jl") include("model.jl") export forward_loss diff --git a/src/layers.jl b/src/layers.jl index 53911fa..c3c8e40 100644 --- a/src/layers.jl +++ b/src/layers.jl @@ -101,6 +101,12 @@ function unrope(rope, x) ) end +#Scaled dot product attention +function sdpa(xq::AbstractArray{T}, xk::AbstractArray{T}, xv::AbstractArray{T}, mask::AbstractArray{T}, head_dim::Int) where T + A = softmax(batched_mul(batched_transpose(xk), xq) / sqrt(T(head_dim)) .+ mask; dims=1) + return batched_mul(xv, A) +end + mutable struct Attention wq::AnyDense wk::AnyDense @@ -111,11 +117,12 @@ mutable struct Attention n_kv_heads::Int head_dim::Int cache::KVCache + sdpa_func::Function end Flux.@layer Attention trainable=(wq,wv) -function Attention(dim::Int, n_heads::Int, n_kv_heads=n_heads; qkv_bias=false) +function Attention(dim::Int, n_heads::Int, n_kv_heads=n_heads; qkv_bias=false, sdpa_func = sdpa) head_dim = dim ÷ n_heads Attention( Dense(dim => n_heads * head_dim, bias=qkv_bias), @@ -127,18 +134,12 @@ function Attention(dim::Int, n_heads::Int, n_kv_heads=n_heads; qkv_bias=false) n_kv_heads, head_dim, KVCache(Float32; n_kv_heads, head_dim), + sdpa_func ) end repeat_kv(x::AbstractArray, n_rep::Int) = isone(n_rep) ? 
x : repeat(x, 1, n_rep, 1, 1) -function sdpa(xq::AbstractArray{T}, xk::AbstractArray{T}, xv::AbstractArray{T}, mask::AbstractArray{T}, head_dim::Int) where T - scores = batched_mul(batched_transpose(xk), xq) / sqrt(T(head_dim)) - scores = scores .+ mask - sm_scores = softmax(scores; dims=1) - return batched_mul(xv, sm_scores) -end - function (attn::Attention)(x::AbstractArray{T}, start_pos::Integer, rope=nothing, mask=false) where T _, seqlen, batch = size(x) @@ -168,7 +169,7 @@ function (attn::Attention)(x::AbstractArray{T}, start_pos::Integer, rope=nothing xk_for_attn = reshape(xk, attn.head_dim, :, attn.n_heads * batch) xv_for_attn = reshape(xv, attn.head_dim, :, attn.n_heads * batch) - output = sdpa(xq_for_attn, xk_for_attn, xv_for_attn, mask, attn.head_dim) + output = attn.sdpa_func(xq_for_attn, xk_for_attn, xv_for_attn, mask, attn.head_dim) e_output = reshape(output, (attn.head_dim, seqlen, attn.n_heads, batch)) p_output = permutedims(e_output, (1,3,2,4)) @@ -177,9 +178,6 @@ function (attn::Attention)(x::AbstractArray{T}, start_pos::Integer, rope=nothing return proj end - - - struct TransformerBlock{A<:Attention,F<:FeedForward,AN<:RMSNorm,FN<:RMSNorm} attention::A feed_forward::F @@ -189,17 +187,17 @@ end function TransformerBlock( dim::Int, n_heads::Int, n_kv_heads::Int = n_heads, ff_hidden_dim = 4 * dim; - norm_eps=1f-5, qkv_bias=false, + norm_eps=1f-5, qkv_bias=false, sdpa_func = sdpa ) TransformerBlock( - Attention(dim, n_heads, n_kv_heads; qkv_bias), + Attention(dim, n_heads, n_kv_heads; qkv_bias, sdpa_func), FeedForward(dim, ff_hidden_dim), RMSNorm(dim, eps=norm_eps), RMSNorm(dim, eps=norm_eps) ) end -function (block::TransformerBlock)(x, start_pos, rope, mask=nothing) +function (block::TransformerBlock)(x, start_pos, rope, mask) h = x + block.attention(block.attention_norm(x), start_pos, rope, mask) out = h + block.feed_forward(block.ffn_norm(h)) return out @@ -207,7 +205,6 @@ end Flux.@layer TransformerBlock trainable=(attention,) - mutable struct Transformer{E<:Flux.Embedding,B<:Tuple{Vararg{TransformerBlock}},N<:RMSNorm,O<:Dense,R<:RoPE} tok_embeddings::E layers::B @@ -227,9 +224,10 @@ function Transformer( rope_theta::T=500000f0, use_scaled_rope=false, scale_factor=8, + sdpa_func = sdpa ) where T tok_embeddings = Flux.Embedding(vocab_size => dim) - layers = Tuple(TransformerBlock(dim, n_heads, n_kv_heads, ff_hidden_dim; norm_eps=norm_eps, qkv_bias=qkv_bias) for _ in 1:n_layers) + layers = Tuple(TransformerBlock(dim, n_heads, n_kv_heads, ff_hidden_dim; norm_eps=norm_eps, qkv_bias=qkv_bias, sdpa_func=sdpa_func) for _ in 1:n_layers) norm = RMSNorm(dim, eps=norm_eps) output = Dense(dim => vocab_size, bias=false) #This should probably be generated to a sane length, and then extended in the forward pass if needed. diff --git a/src/model.jl b/src/model.jl index dbbebdb..e49b0bc 100644 --- a/src/model.jl +++ b/src/model.jl @@ -42,7 +42,7 @@ end wrap(model, xs...) = model(xs...) 
-function (model::Transformer)(tokens::AbstractArray{Int}, opt_state; clear_cache = false, checkpointed = false) +function (model::Transformer)(tokens::AbstractArray{Int}, opt_state; clear_cache = false, checkpoint_func = wrap) if clear_cache Flux.ChainRulesCore.ignore_derivatives() do Jjama3.clear_cache!(model) @@ -57,17 +57,10 @@ function (model::Transformer)(tokens::AbstractArray{Int}, opt_state; clear_cache end for i in 1:length(model.layers) if !isnothing(opt_state) - if checkpointed - h = Flux.Zygote.checkpointed(wrap, eager_update!(opt_state.layers[i], model.layers[i], Optimisers.update!), h, model.pos, rope, mask) - else - h = wrap(eager_update!(opt_state.layers[i], model.layers[i], Optimisers.update!), h, model.pos, rope, mask) - end + #If checkpoint_func is also just wrap, then this does nothing, but if its Zygote.checkpointed, then this is a checkpointed update + h = checkpoint_func(wrap, eager_update!(opt_state.layers[i], model.layers[i], Optimisers.update!), h, model.pos, rope, mask) else - if checkpointed - h = Flux.Zygote.checkpointed(wrap, model.layers[i], h, model.pos, rope, mask) - else - h = model.layers[i](h, model.pos, rope, mask) - end + h = checkpoint_func(wrap, model.layers[i], h, model.pos, rope, mask) end end h = model.norm(h) @@ -76,7 +69,7 @@ function (model::Transformer)(tokens::AbstractArray{Int}, opt_state; clear_cache return output end -(model::Transformer)(tokens::AbstractArray{Int}; clear_cache = false, checkpointed = false) = model(tokens, nothing; clear_cache, checkpointed) +(model::Transformer)(tokens::AbstractArray{Int}; clear_cache = false, checkpoint_func = wrap) = model(tokens, nothing; clear_cache, checkpoint_func) function loss(logits, targets::AbstractArray; loss_mask = nothing) vocab_size = size(logits,1) diff --git a/src/sdpa.jl b/src/sdpa.jl new file mode 100644 index 0000000..014f307 --- /dev/null +++ b/src/sdpa.jl @@ -0,0 +1,250 @@ +#Trying out some tricks for attention. 
+ +#Figure out where to thunk: https://juliadiff.org/ChainRulesCore.jl/stable/rule_author/example.html + +#Will use Zygote - for testing grad correctness: +function sdpa_norrule(xq::AbstractArray{T}, xk::AbstractArray{T}, xv::AbstractArray{T}, mask::AbstractArray{T}, head_dim::Int) where T + A = softmax(batched_mul(batched_transpose(xk), xq) / sqrt(T(head_dim)) .+ mask; dims=1) + return batched_mul(xv, A) +end + +function ChainRulesCore.rrule(::typeof(sdpa), + xq::AbstractArray{T}, #(D, LQ, HB) + xk::AbstractArray{T}, #(D, LKV, HB) + xv::AbstractArray{T}, #(D, LKV, HB) + mask::AbstractArray{T}, #(LKV, LQ) + head_dim::Int + ) where {T} + α = sqrt(T(head_dim)) + A = softmax(((batched_mul(batched_transpose(xk), xq) ./ α) .+ mask); dims=1) #(LKV, LQ, HB) "head-batch" + y = batched_mul(xv, A) #(D, LQ, HB) + function sdpa_pullback(ȳ) + xv̄ = batched_mul(ȳ, batched_transpose(A)) #(D, LKV, HB) + Ā = batched_mul(batched_transpose(xv), ȳ) #(LKV, LQ, HB) + dM = (A .* (Ā .- (sum(A .* Ā, dims=1)))) ./ α #(LKV, LQ, HB) + xq̄ = batched_mul(xk, dM) #(D, LQ, HB) + xk̄ = batched_mul(xq, batched_transpose(dM)) #(D, LKV, HB) + return NoTangent(), xq̄, xk̄, xv̄, NoTangent(), NoTangent() + end + return y, sdpa_pullback +end + + +function keychunked_sdpa(xq::AbstractArray{T,3}, + xk::AbstractArray{T,3}, + xv::AbstractArray{T,3}, + mask::AbstractArray{T}, + head_dim::Int; + k_chunk_size::Int=256 + ) where {T<:Real} + + k_len = size(xk,2) + q_len = size(xq,2) + nbatch = size(xq,3) + + scale = one(T) / sqrt(T(head_dim)) + + + partial_max = fill!(similar(xq, 1, q_len, nbatch), -Inf) + partial_expw = fill!(similar(xq, 1, q_len, nbatch), T(0)) + partial_vals = fill!(similar(xq, head_dim, q_len, nbatch), T(0)) + + # Preallocate local buffers for each chunk + attn = fill!(similar(xq, k_chunk_size, q_len, nbatch), T(0)) + local_max = fill!(similar(xq, 1, q_len, nbatch), T(0)) + new_max = similar(local_max) + w_old = similar(local_max) + w_new = similar(local_max) + chunk_sum = similar(local_max) + valpart = fill!(similar(xq, head_dim, q_len, nbatch), T(0)) + + kstart = 1 + while kstart <= k_len + k_batch = min(k_chunk_size, k_len - kstart + 1) + xk_chunk = @view xk[:, kstart : kstart + k_batch - 1, :] + xv_chunk = @view xv[:, kstart : kstart + k_batch - 1, :] + mask_chunk = @view mask[kstart : kstart + k_batch - 1, :, :] + attn_view = @view attn[1:k_batch, 1:q_len, 1:nbatch] + xkT_chunk = batched_transpose(xk_chunk) + + batched_mul!(attn_view, xkT_chunk, xq, scale, 0) # attn_view = scale*(xkT_chunk*xq) + attn_view .= attn_view .+ mask_chunk # add mask + + local_max .= maximum(attn_view, dims=1) + @. new_max = max(partial_max, local_max) + @. w_old = exp(partial_max - new_max) + @. w_new = exp(local_max - new_max) + @. attn_view = exp(attn_view - local_max) + + partial_vals .= partial_vals .* w_old # Rescale old accumulators by w_old + partial_expw .= partial_expw .* w_old + + chunk_sum .= sum(attn_view, dims=1) .* w_new + partial_expw .+= chunk_sum + + batched_mul!(valpart, xv_chunk, attn_view) + valpart .= valpart .* w_new + partial_vals .+= valpart + partial_max .= new_max + kstart += k_batch + end + + y = partial_vals ./ partial_expw + return y +end + + +#Todo: use this to ignore parts of the -Inf mask triangle, since we're processing over chunks of queries. 
+function querychunked_sdpa( + xq::AbstractArray{T,3}, + xk::AbstractArray{T,3}, + xv::AbstractArray{T,3}, + mask::AbstractArray{T}, + head_dim::Int; + q_chunk_size::Int=128 +) where {T<:Real} + q_len = size(xq, 2) + kv_len = size(xv, 2) + nbatch = size(xq, 3) + q_chunk_size = min(q_chunk_size, q_len) + α = sqrt(T(head_dim)) + y = similar(xq) + qk_chunk = similar(xq, kv_len, q_chunk_size, nbatch) + Achunk = similar(xq, kv_len, q_chunk_size, nbatch) + qstart = 1 + while qstart <= q_len + q_batch = min(q_chunk_size, q_len - qstart + 1) + qinds = qstart:qstart+q_batch-1 + qk_chunkview = view(qk_chunk,:,1:q_batch,:) + batched_mul!(qk_chunkview,batched_transpose(xk), view(xq, :, qinds, :), 1/α) + Achunk[:,1:q_batch,:] .= softmax((qk_chunkview .+ view(mask,:,qinds)); dims=1) #(LKV, LQ, HB) "head-batch" + batched_mul!(view(y,:,qinds,:),xv, view(Achunk,:,1:q_batch,:)) #(D, LQ, HB) + qstart += q_batch + end + return y +end + +function ChainRulesCore.rrule(::typeof(querychunked_sdpa), + xq::AbstractArray{T}, #(D, LQ, HB) + xk::AbstractArray{T}, #(D, LKV, HB) + xv::AbstractArray{T}, #(D, LKV, HB) + mask::AbstractArray{T}, #(LKV, LQ) + head_dim::Int; + q_chunk_size = 128 + ) where {T} + y = querychunked_sdpa(xq, xk, xv, mask, head_dim, q_chunk_size=q_chunk_size) + function sdpa_pullback(ȳ) + k_len = size(xk, 2) + q_len = size(xq, 2) + kv_len = size(xv, 2) + nbatch = size(xq, 3) + q_chunk_size = min(q_chunk_size, q_len) + α = sqrt(T(head_dim)) + + xq̄, xk̄, xv̄ = similar(xq), fill!(similar(xk), 0), fill!(similar(xv), 0) + Achunk = similar(xq, kv_len, q_chunk_size, nbatch) + Āchunk = similar(xq, kv_len, q_chunk_size, nbatch) + dMchunk = similar(xq, kv_len, q_chunk_size, nbatch) + qk_chunk = similar(xq, kv_len, q_chunk_size, nbatch) + qstart = 1 + while qstart <= q_len + q_batch = min(q_chunk_size, q_len - qstart + 1) + qinds = qstart:qstart+q_batch-1 + ȳview = view(ȳ,:,qinds,:) + qk_chunkview = view(qk_chunk,:,1:q_batch,:) + batched_mul!(qk_chunkview,batched_transpose(xk), view(xq, :, qinds, :), 1/α) + Achunk[:,1:q_batch,:] .= softmax((qk_chunkview .+ view(mask,:,qinds)); dims=1) + batched_mul!(xv̄, ȳview, batched_transpose(view(Achunk,:,1:q_batch,:)), one(T), one(T)) + Āchunkview = view(Āchunk,:,1:q_batch,:) + batched_mul!(Āchunkview, batched_transpose(xv), ȳview) + Achunkview = view(Achunk,:,1:q_batch,:) + dMchunk[:,1:q_batch,:] .= (Achunkview .* (Āchunkview .- (sum(Achunkview .* Āchunkview, dims=1)))) ./ α #(LKV, LQ, HB) + dMchunkview = view(dMchunk,:,1:q_batch,:) + batched_mul!(xk̄, view(xq,:,qinds,:), batched_transpose(dMchunkview), one(T), one(T)) #(LKV, D, HB) + batched_mul!(view(xq̄,:,qinds,:),xk, dMchunkview) #(D, LQ, HB) + qstart += q_batch + end + return NoTangent(), xq̄, xk̄, xv̄, NoTangent(), NoTangent() + end + return y, sdpa_pullback +end + +#= +#Testing forward passes +begin + L1 = 400 #Query + L2 = 599 #Key/Value + D = 32 + HB = 80 + xq, xk, xv, mask = randn(Float32, D, L1, HB), randn(Float32, D, L2, HB), randn(Float32, D, L2, HB), zeros(Float32, L2, L1); + f(xq, xk, xv, mask, hd) = (Jjama3.sdpa(xq, xk, xv, mask, hd)); + fqc(xq, xk, xv, mask, hd) = (Jjama3.querychunked_sdpa(xq, xk, xv, mask, hd, q_chunk_size = 64)); + fkc(xq, xk, xv, mask, hd) = (Jjama3.keychunked_sdpa(xq, xk, xv, mask, hd, k_chunk_size = 64)); + + res = f(xq, xk, xv, mask, D); + qcres = fqc(xq, xk, xv, mask, D); + kcres = fkc(xq, xk, xv, mask, D); + + @assert isapprox(res, qcres) + @assert isapprox(res, kcres) + + @btime f($xq, $xk, $xv, $mask, $D) + @btime fqc($xq, $xk, $xv, $mask, $D) + @btime fkc($xq, $xk, $xv, 
$mask, $D) +end; +=# + + + +#= +#Testing grads +begin +L1 = 1000 +L2 = 1200 +D = 32 +HB = 80 +xq, xk, xv, mask = randn(Float32, D, L1, HB), randn(Float32, D, L2, HB), randn(Float32, D, L2, HB), zeros(Float32, L2, L1); +fnr(xq, xk, xv, mask, hd) = sum(Zygote.checkpointed(Jjama3.sdpa_norrule,xq, xk, xv, mask, hd)); +f(xq, xk, xv, mask, hd) = sum(Zygote.checkpointed(Jjama3.sdpa,xq, xk, xv, mask, hd)); +flm(xq, xk, xv, mask, hd) = sum(Jjama3.querychunked_sdpa(xq, xk, xv, mask, hd, q_chunk_size = 64)); +@time res = withgradient(f, xq, xk, xv, mask, D); +@time nrres = withgradient(fnr, xq, xk, xv, mask, D); +@time lmres = withgradient(flm, xq, xk, xv, mask, D); + +@assert isapprox(res[1], nrres[1]) +@assert isapprox(res[2][1], nrres[2][1]) +@assert isapprox(res[2][2], nrres[2][2]) +@assert isapprox(res[2][3], nrres[2][3]) +@assert isapprox(res[1], lmres[1]) +@assert isapprox(res[2][1], lmres[2][1]) +@assert isapprox(res[2][2], lmres[2][2]) +@assert isapprox(res[2][3], lmres[2][3]) + +GC.gc() +println("normal+rrule chechpointed:") +@time res = withgradient(f, xq, xk, xv, mask, D); +@time res = withgradient(f, xq, xk, xv, mask, D); + +GC.gc() +println("normal+Zygote chechpointed:") +@time nrres = withgradient(fnr, xq, xk, xv, mask, D); +@time nrres = withgradient(fnr, xq, xk, xv, mask, D); + +GC.gc() +println("chunked:") +@time lmres = withgradient(flm, xq, xk, xv, mask, D); +@time lmres = withgradient(flm, xq, xk, xv, mask, D); + + +println("btimed:") +GC.gc() +@btime res = withgradient(f, xq, xk, xv, mask, D); +GC.gc() +@btime nrres = withgradient(fnr, xq, xk, xv, mask, D); +GC.gc() +@btime lmres = withgradient(flm, xq, xk, xv, mask, D); + +true +end +=# + From e656f43df42f179570ca2466ec79d7ead6736172 Mon Sep 17 00:00:00 2001 From: murrellb Date: Sat, 28 Dec 2024 12:23:07 +0100 Subject: [PATCH 2/3] Passing sdpa in when the model is called --- src/layers.jl | 23 ++++++++++------------- src/model.jl | 8 ++++---- src/sampling.jl | 7 ++++--- src/sdpa.jl | 1 - 4 files changed, 18 insertions(+), 21 deletions(-) diff --git a/src/layers.jl b/src/layers.jl index c3c8e40..81d9b45 100644 --- a/src/layers.jl +++ b/src/layers.jl @@ -117,12 +117,11 @@ mutable struct Attention n_kv_heads::Int head_dim::Int cache::KVCache - sdpa_func::Function end Flux.@layer Attention trainable=(wq,wv) -function Attention(dim::Int, n_heads::Int, n_kv_heads=n_heads; qkv_bias=false, sdpa_func = sdpa) +function Attention(dim::Int, n_heads::Int, n_kv_heads=n_heads; qkv_bias=false) head_dim = dim ÷ n_heads Attention( Dense(dim => n_heads * head_dim, bias=qkv_bias), @@ -133,14 +132,13 @@ function Attention(dim::Int, n_heads::Int, n_kv_heads=n_heads; qkv_bias=false, s n_heads, n_kv_heads, head_dim, - KVCache(Float32; n_kv_heads, head_dim), - sdpa_func + KVCache(Float32; n_kv_heads, head_dim) ) end repeat_kv(x::AbstractArray, n_rep::Int) = isone(n_rep) ? 
x : repeat(x, 1, n_rep, 1, 1) -function (attn::Attention)(x::AbstractArray{T}, start_pos::Integer, rope=nothing, mask=false) where T +function (attn::Attention)(x::AbstractArray{T}, start_pos::Integer, rope=nothing, mask=false, sdpa_func = sdpa) where T _, seqlen, batch = size(x) xq = attn.wq(x) @@ -169,7 +167,7 @@ function (attn::Attention)(x::AbstractArray{T}, start_pos::Integer, rope=nothing xk_for_attn = reshape(xk, attn.head_dim, :, attn.n_heads * batch) xv_for_attn = reshape(xv, attn.head_dim, :, attn.n_heads * batch) - output = attn.sdpa_func(xq_for_attn, xk_for_attn, xv_for_attn, mask, attn.head_dim) + output = sdpa_func(xq_for_attn, xk_for_attn, xv_for_attn, mask, attn.head_dim) e_output = reshape(output, (attn.head_dim, seqlen, attn.n_heads, batch)) p_output = permutedims(e_output, (1,3,2,4)) @@ -187,18 +185,18 @@ end function TransformerBlock( dim::Int, n_heads::Int, n_kv_heads::Int = n_heads, ff_hidden_dim = 4 * dim; - norm_eps=1f-5, qkv_bias=false, sdpa_func = sdpa + norm_eps=1f-5, qkv_bias=false ) TransformerBlock( - Attention(dim, n_heads, n_kv_heads; qkv_bias, sdpa_func), + Attention(dim, n_heads, n_kv_heads; qkv_bias), FeedForward(dim, ff_hidden_dim), RMSNorm(dim, eps=norm_eps), RMSNorm(dim, eps=norm_eps) ) end -function (block::TransformerBlock)(x, start_pos, rope, mask) - h = x + block.attention(block.attention_norm(x), start_pos, rope, mask) +function (block::TransformerBlock)(x, start_pos, rope, mask, sdpa) + h = x + block.attention(block.attention_norm(x), start_pos, rope, mask, sdpa) out = h + block.feed_forward(block.ffn_norm(h)) return out end @@ -223,11 +221,10 @@ function Transformer( qkv_bias=false, rope_theta::T=500000f0, use_scaled_rope=false, - scale_factor=8, - sdpa_func = sdpa + scale_factor=8 ) where T tok_embeddings = Flux.Embedding(vocab_size => dim) - layers = Tuple(TransformerBlock(dim, n_heads, n_kv_heads, ff_hidden_dim; norm_eps=norm_eps, qkv_bias=qkv_bias, sdpa_func=sdpa_func) for _ in 1:n_layers) + layers = Tuple(TransformerBlock(dim, n_heads, n_kv_heads, ff_hidden_dim; norm_eps=norm_eps, qkv_bias=qkv_bias) for _ in 1:n_layers) norm = RMSNorm(dim, eps=norm_eps) output = Dense(dim => vocab_size, bias=false) #This should probably be generated to a sane length, and then extended in the forward pass if needed. diff --git a/src/model.jl b/src/model.jl index e49b0bc..787151f 100644 --- a/src/model.jl +++ b/src/model.jl @@ -42,7 +42,7 @@ end wrap(model, xs...) = model(xs...) 
-function (model::Transformer)(tokens::AbstractArray{Int}, opt_state; clear_cache = false, checkpoint_func = wrap) +function (model::Transformer)(tokens::AbstractArray{Int}, opt_state; clear_cache = false, checkpoint_func = wrap, sdpa_func = sdpa) if clear_cache Flux.ChainRulesCore.ignore_derivatives() do Jjama3.clear_cache!(model) @@ -58,9 +58,9 @@ function (model::Transformer)(tokens::AbstractArray{Int}, opt_state; clear_cache for i in 1:length(model.layers) if !isnothing(opt_state) #If checkpoint_func is also just wrap, then this does nothing, but if its Zygote.checkpointed, then this is a checkpointed update - h = checkpoint_func(wrap, eager_update!(opt_state.layers[i], model.layers[i], Optimisers.update!), h, model.pos, rope, mask) + h = checkpoint_func(wrap, eager_update!(opt_state.layers[i], model.layers[i], Optimisers.update!), h, model.pos, rope, mask, sdpa_func) else - h = checkpoint_func(wrap, model.layers[i], h, model.pos, rope, mask) + h = checkpoint_func(wrap, model.layers[i], h, model.pos, rope, mask, sdpa_func) end end h = model.norm(h) @@ -69,7 +69,7 @@ function (model::Transformer)(tokens::AbstractArray{Int}, opt_state; clear_cache return output end -(model::Transformer)(tokens::AbstractArray{Int}; clear_cache = false, checkpoint_func = wrap) = model(tokens, nothing; clear_cache, checkpoint_func) +(model::Transformer)(tokens::AbstractArray{Int}; clear_cache = false, checkpoint_func = wrap, sdpa_func = sdpa) = model(tokens, nothing; clear_cache, checkpoint_func, sdpa_func) function loss(logits, targets::AbstractArray; loss_mask = nothing) vocab_size = size(logits,1) diff --git a/src/sampling.jl b/src/sampling.jl index 945c68c..69df2e1 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -21,7 +21,8 @@ function generate( end_token = 128010, clear_cache = true, pos_offset = 0, - device = identity + device = identity, + sdpa_func = sdpa ) where T current_len = length(initial_tokens) tokens = vcat(initial_tokens, similar(initial_tokens, max_new_tokens)) @@ -32,10 +33,10 @@ function generate( extend_cache!(model, current_len + max_new_tokens) end input_tokens = device(reshape(initial_tokens, :, 1)) # (seq_len, batch=1) - logits = model(input_tokens) + logits = model(input_tokens, sdpa_func = sdpa_func) for _ in 1:max_new_tokens input_tokens = device(reshape([tokens[current_len]], :, 1)) # Just the last token - logits = model(input_tokens) + logits = model(input_tokens, sdpa_func = sdpa_func) next_token = sampler(logits[:, end, 1]) current_len += 1 tokens[current_len] = next_token diff --git a/src/sdpa.jl b/src/sdpa.jl index 014f307..d2f1ac3 100644 --- a/src/sdpa.jl +++ b/src/sdpa.jl @@ -37,7 +37,6 @@ function keychunked_sdpa(xq::AbstractArray{T,3}, head_dim::Int; k_chunk_size::Int=256 ) where {T<:Real} - k_len = size(xk,2) q_len = size(xq,2) nbatch = size(xq,3) From 9dd30e034b1902f9b3977eecf4405e422cd58308 Mon Sep 17 00:00:00 2001 From: murrellb Date: Sat, 28 Dec 2024 13:13:08 +0100 Subject: [PATCH 3/3] Fixing key chunk when sampling --- src/model.jl | 2 +- src/sdpa.jl | 26 +++++++++++++++++--------- test/runtests.jl | 3 +++ 3 files changed, 21 insertions(+), 10 deletions(-) diff --git a/src/model.jl b/src/model.jl index 787151f..8a0b59e 100644 --- a/src/model.jl +++ b/src/model.jl @@ -50,7 +50,7 @@ function (model::Transformer)(tokens::AbstractArray{Int}, opt_state; clear_cache end h = model.tok_embeddings(tokens) # Embedding: (dim, seq_len, batch) rope = model.rope[model.pos+1:model.pos+size(tokens, 1)] - if size(h, 2) == 1 + if size(h, 2) == 1 #If there is only one new 
token, then a 1-by-1 mask = 0 works, via broadcasting (if the attention functions allow it) mask = Jjama3.create_mask(h) else mask = Jjama3.create_mask(h; precached_size = model.pos) diff --git a/src/sdpa.jl b/src/sdpa.jl index d2f1ac3..669f742 100644 --- a/src/sdpa.jl +++ b/src/sdpa.jl @@ -1,6 +1,6 @@ #Trying out some tricks for attention. -#Figure out where to thunk: https://juliadiff.org/ChainRulesCore.jl/stable/rule_author/example.html +#Figure out where to thunk... #Will use Zygote - for testing grad correctness: function sdpa_norrule(xq::AbstractArray{T}, xk::AbstractArray{T}, xv::AbstractArray{T}, mask::AbstractArray{T}, head_dim::Int) where T @@ -35,14 +35,14 @@ function keychunked_sdpa(xq::AbstractArray{T,3}, xv::AbstractArray{T,3}, mask::AbstractArray{T}, head_dim::Int; - k_chunk_size::Int=256 + k_chunk_size::Int=128 ) where {T<:Real} + k_len = size(xk,2) q_len = size(xq,2) nbatch = size(xq,3) scale = one(T) / sqrt(T(head_dim)) - partial_max = fill!(similar(xq, 1, q_len, nbatch), -Inf) partial_expw = fill!(similar(xq, 1, q_len, nbatch), T(0)) @@ -62,7 +62,11 @@ function keychunked_sdpa(xq::AbstractArray{T,3}, k_batch = min(k_chunk_size, k_len - kstart + 1) xk_chunk = @view xk[:, kstart : kstart + k_batch - 1, :] xv_chunk = @view xv[:, kstart : kstart + k_batch - 1, :] - mask_chunk = @view mask[kstart : kstart + k_batch - 1, :, :] + if length(mask) > 1 + mask_chunk = @view mask[kstart : kstart + k_batch - 1, :, :] + else + mask_chunk = mask #Handles the case where the mask is 1-by-1 for sampling a single token. + end attn_view = @view attn[1:k_batch, 1:q_len, 1:nbatch] xkT_chunk = batched_transpose(xk_chunk) @@ -171,8 +175,8 @@ end #= #Testing forward passes begin - L1 = 400 #Query - L2 = 599 #Key/Value + L1 = 872 #Query + L2 = 267 #Key/Value D = 32 HB = 80 xq, xk, xv, mask = randn(Float32, D, L1, HB), randn(Float32, D, L2, HB), randn(Float32, D, L2, HB), zeros(Float32, L2, L1); @@ -187,9 +191,13 @@ begin @assert isapprox(res, qcres) @assert isapprox(res, kcres) - @btime f($xq, $xk, $xv, $mask, $D) - @btime fqc($xq, $xk, $xv, $mask, $D) - @btime fkc($xq, $xk, $xv, $mask, $D) + @show size(res) + @show size(kcres) + + + #@btime f($xq, $xk, $xv, $mask, $D) + #@btime fqc($xq, $xk, $xv, $mask, $D) + #@btime fkc($xq, $xk, $xv, $mask, $D) end; =# diff --git a/test/runtests.jl b/test/runtests.jl index 5579a45..4fcef7c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,6 +13,9 @@ using Downloads model = load_llama3_from_safetensors([model_path], config) AAs = collect(">ACDEFGHIKLMNPQRSTVWY.") @test generate(model, encode(AAs, ">"), end_token = 22) == [1, 15, 19, 15, 11, 19, 15, 17, 7, 2, 5, 19, 10, 10, 14, 7, 2, 17, 19, 10, 19, 17, 3, 10, 2, 17, 7, 21, 18, 6, 18, 17, 21, 7, 9, 17, 20, 19, 16, 15, 2, 14, 7, 15, 7, 11, 5, 20, 12, 7, 20, 9, 17, 2, 21, 13, 7, 13, 18, 13, 21, 2, 15, 10, 11, 15, 7, 16, 19, 18, 12, 18, 18, 4, 18, 17, 18, 17, 18, 2, 21, 12, 5, 11, 16, 17, 11, 16, 17, 4, 4, 18, 2, 19, 21, 21, 3, 2, 16, 4, 16] + @test generate(model, encode(AAs, ">"), end_token = 22, sdpa_func = Jjama3.keychunked_sdpa) == [1, 15, 19, 15, 11, 19, 15, 17, 7, 2, 5, 19, 10, 10, 14, 7, 2, 17, 19, 10, 19, 17, 3, 10, 2, 17, 7, 21, 18, 6, 18, 17, 21, 7, 9, 17, 20, 19, 16, 15, 2, 14, 7, 15, 7, 11, 5, 20, 12, 7, 20, 9, 17, 2, 21, 13, 7, 13, 18, 13, 21, 2, 15, 10, 11, 15, 7, 16, 19, 18, 12, 18, 18, 4, 18, 17, 18, 17, 18, 2, 21, 12, 5, 11, 16, 17, 11, 16, 17, 4, 4, 18, 2, 19, 21, 21, 3, 2, 16, 4, 16] + @test generate(model, encode(AAs, ">"), end_token = 22, sdpa_func = Jjama3.querychunked_sdpa) == [1, 
15, 19, 15, 11, 19, 15, 17, 7, 2, 5, 19, 10, 10, 14, 7, 2, 17, 19, 10, 19, 17, 3, 10, 2, 17, 7, 21, 18, 6, 18, 17, 21, 7, 9, 17, 20, 19, 16, 15, 2, 14, 7, 15, 7, 11, 5, 20, 12, 7, 20, 9, 17, 2, 21, 13, 7, 13, 18, 13, 21, 2, 15, 10, 11, 15, 7, 16, 19, 18, 12, 18, 18, 4, 18, 17, 18, 17, 18, 2, 21, 12, 5, 11, 16, 17, 11, 16, 17, 4, 4, 18, 2, 19, 21, 21, 3, 2, 16, 4, 16] + @test generate(model, encode(AAs, ">"), end_token = 22, sdpa_func = Jjama3.sdpa_norrule) == [1, 15, 19, 15, 11, 19, 15, 17, 7, 2, 5, 19, 10, 10, 14, 7, 2, 17, 19, 10, 19, 17, 3, 10, 2, 17, 7, 21, 18, 6, 18, 17, 21, 7, 9, 17, 20, 19, 16, 15, 2, 14, 7, 15, 7, 11, 5, 20, 12, 7, 20, 9, 17, 2, 21, 13, 7, 13, 18, 13, 21, 2, 15, 10, 11, 15, 7, 16, 19, 18, 12, 18, 18, 4, 18, 17, 18, 17, 18, 2, 21, 12, 5, 11, 16, 17, 11, 16, 17, 4, 4, 18, 2, 19, 21, 21, 3, 2, 16, 4, 16] end end
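
Usage sketch (editor's note, not part of the patch series above): across the three commits, the choice of attention kernel moves from being stored on the Attention layer (commit 1) to a per-call `sdpa_func` keyword on the Transformer forward pass and on `generate` (commits 2-3), with `Jjama3.sdpa` and `wrap` as the defaults. The lines below illustrate how that API might be exercised; `model`, `AAs`, `encode`, `tokens`, and `opt_state` are assumed to be set up as in test/runtests.jl and a typical Optimisers-based training loop, and are not defined here.

using Jjama3, Flux

# Sampling: the default dense kernel, or the key-chunked kernel for long prompts,
# selected per call via the sdpa_func keyword (as in test/runtests.jl).
toks_dense   = generate(model, encode(AAs, ">"), end_token = 22)
toks_chunked = generate(model, encode(AAs, ">"), end_token = 22, sdpa_func = Jjama3.keychunked_sdpa)

# Training forward pass: pass the query-chunked kernel through the model call,
# optionally combined with Zygote checkpointing via checkpoint_func. Both keywords
# default to the plain dense path (Jjama3.sdpa and wrap).
logits = model(tokens, opt_state;
               sdpa_func = Jjama3.querychunked_sdpa,
               checkpoint_func = Flux.Zygote.checkpointed)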