diff --git a/Project.toml b/Project.toml
index b96873f..79300c8 100644
--- a/Project.toml
+++ b/Project.toml
@@ -22,7 +22,7 @@ MetalExt = "Metal"
 
 [compat]
 ChainRulesCore = "1.25.0"
-Flux = "0.14, 0.15"
+Flux = "0.14, 0.15, 0.16"
 HuggingFaceTokenizers = "0.1"
 LogitSamplers = "0.1.0"
 LowRankLayers = "0.1.2"
diff --git a/src/Jjama3.jl b/src/Jjama3.jl
index 71aa9f5..e40f6af 100644
--- a/src/Jjama3.jl
+++ b/src/Jjama3.jl
@@ -6,7 +6,7 @@ using LinearAlgebra
 using NNlib
 using LogitSamplers
 using LowRankLayers
-#using ChainRulesCore
+using ChainRulesCore
 
 include("cache.jl")
 export KVCache
@@ -23,7 +23,7 @@ export rerope_cache!
 export scrape_cache
 export append_cache!
 
-#include("sdpa.jl")
+include("sdpa.jl")
 
 include("model.jl")
 export forward_loss
diff --git a/src/layers.jl b/src/layers.jl
index 53911fa..81d9b45 100644
--- a/src/layers.jl
+++ b/src/layers.jl
@@ -101,6 +101,12 @@ function unrope(rope, x)
     )
 end
 
+#Scaled dot product attention
+function sdpa(xq::AbstractArray{T}, xk::AbstractArray{T}, xv::AbstractArray{T}, mask::AbstractArray{T}, head_dim::Int) where T
+    A = softmax(batched_mul(batched_transpose(xk), xq) / sqrt(T(head_dim)) .+ mask; dims=1)
+    return batched_mul(xv, A)
+end
+
 mutable struct Attention
     wq::AnyDense
     wk::AnyDense
@@ -126,20 +132,13 @@ function Attention(dim::Int, n_heads::Int, n_kv_heads=n_heads; qkv_bias=false)
         n_heads,
         n_kv_heads,
         head_dim,
-        KVCache(Float32; n_kv_heads, head_dim),
+        KVCache(Float32; n_kv_heads, head_dim)
     )
 end
 
 repeat_kv(x::AbstractArray, n_rep::Int) = isone(n_rep) ? x : repeat(x, 1, n_rep, 1, 1)
 
-function sdpa(xq::AbstractArray{T}, xk::AbstractArray{T}, xv::AbstractArray{T}, mask::AbstractArray{T}, head_dim::Int) where T
-    scores = batched_mul(batched_transpose(xk), xq) / sqrt(T(head_dim))
-    scores = scores .+ mask
-    sm_scores = softmax(scores; dims=1)
-    return batched_mul(xv, sm_scores)
-end
-
-function (attn::Attention)(x::AbstractArray{T}, start_pos::Integer, rope=nothing, mask=false) where T
+function (attn::Attention)(x::AbstractArray{T}, start_pos::Integer, rope=nothing, mask=false, sdpa_func = sdpa) where T
     _, seqlen, batch = size(x)
 
     xq = attn.wq(x)
@@ -168,7 +167,7 @@ function (attn::Attention)(x::AbstractArray{T}, start_pos::Integer, rope=nothing
     xk_for_attn = reshape(xk, attn.head_dim, :, attn.n_heads * batch)
     xv_for_attn = reshape(xv, attn.head_dim, :, attn.n_heads * batch)
 
-    output = sdpa(xq_for_attn, xk_for_attn, xv_for_attn, mask, attn.head_dim)
+    output = sdpa_func(xq_for_attn, xk_for_attn, xv_for_attn, mask, attn.head_dim)
 
     e_output = reshape(output, (attn.head_dim, seqlen, attn.n_heads, batch))
     p_output = permutedims(e_output, (1,3,2,4))
@@ -177,9 +176,6 @@ function (attn::Attention)(x::AbstractArray{T}, start_pos::Integer, rope=nothing
     return proj
 end
 
-
-
-
 struct TransformerBlock{A<:Attention,F<:FeedForward,AN<:RMSNorm,FN<:RMSNorm}
     attention::A
     feed_forward::F
@@ -189,7 +185,7 @@ end
 
 function TransformerBlock(
     dim::Int, n_heads::Int, n_kv_heads::Int = n_heads, ff_hidden_dim = 4 * dim;
-    norm_eps=1f-5, qkv_bias=false,
+    norm_eps=1f-5, qkv_bias=false
 )
     TransformerBlock(
         Attention(dim, n_heads, n_kv_heads; qkv_bias),
@@ -199,15 +195,14 @@ function TransformerBlock(
     )
 end
 
-function (block::TransformerBlock)(x, start_pos, rope, mask=nothing)
-    h = x + block.attention(block.attention_norm(x), start_pos, rope, mask)
+function (block::TransformerBlock)(x, start_pos, rope, mask, sdpa)
+    h = x + block.attention(block.attention_norm(x), start_pos, rope, mask, sdpa)
     out = h + block.feed_forward(block.ffn_norm(h))
     return out
 end
 
 Flux.@layer TransformerBlock trainable=(attention,)
 
-
 mutable struct Transformer{E<:Flux.Embedding,B<:Tuple{Vararg{TransformerBlock}},N<:RMSNorm,O<:Dense,R<:RoPE}
     tok_embeddings::E
     layers::B
@@ -226,7 +221,7 @@ function Transformer(
     qkv_bias=false,
     rope_theta::T=500000f0,
     use_scaled_rope=false,
-    scale_factor=8,
+    scale_factor=8
 ) where T
     tok_embeddings = Flux.Embedding(vocab_size => dim)
     layers = Tuple(TransformerBlock(dim, n_heads, n_kv_heads, ff_hidden_dim; norm_eps=norm_eps, qkv_bias=qkv_bias) for _ in 1:n_layers)
diff --git a/src/model.jl b/src/model.jl
index dbbebdb..8a0b59e 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -42,7 +42,7 @@ end
 
 wrap(model, xs...) = model(xs...)
 
-function (model::Transformer)(tokens::AbstractArray{Int}, opt_state; clear_cache = false, checkpointed = false)
+function (model::Transformer)(tokens::AbstractArray{Int}, opt_state; clear_cache = false, checkpoint_func = wrap, sdpa_func = sdpa)
     if clear_cache
         Flux.ChainRulesCore.ignore_derivatives() do
             Jjama3.clear_cache!(model)
@@ -50,24 +50,17 @@ end
     h = model.tok_embeddings(tokens) # Embedding: (dim, seq_len, batch)
     rope = model.rope[model.pos+1:model.pos+size(tokens, 1)]
-    if size(h, 2) == 1
+    if size(h, 2) == 1 #If there is only one new token, then a 1-by-1 mask = 0 works, via broadcasting (if the attention functions allow it)
         mask = Jjama3.create_mask(h)
     else
         mask = Jjama3.create_mask(h; precached_size = model.pos)
     end
     for i in 1:length(model.layers)
         if !isnothing(opt_state)
-            if checkpointed
-                h = Flux.Zygote.checkpointed(wrap, eager_update!(opt_state.layers[i], model.layers[i], Optimisers.update!), h, model.pos, rope, mask)
-            else
-                h = wrap(eager_update!(opt_state.layers[i], model.layers[i], Optimisers.update!), h, model.pos, rope, mask)
-            end
+            #If checkpoint_func is also just wrap, then this does nothing, but if it's Zygote.checkpointed, then this is a checkpointed update
+            h = checkpoint_func(wrap, eager_update!(opt_state.layers[i], model.layers[i], Optimisers.update!), h, model.pos, rope, mask, sdpa_func)
         else
-            if checkpointed
-                h = Flux.Zygote.checkpointed(wrap, model.layers[i], h, model.pos, rope, mask)
-            else
-                h = model.layers[i](h, model.pos, rope, mask)
-            end
+            h = checkpoint_func(wrap, model.layers[i], h, model.pos, rope, mask, sdpa_func)
         end
     end
     h = model.norm(h)
@@ -76,7 +69,7 @@ function (model::Transformer)(tokens::AbstractArray{Int}, opt_state; clear_cache
     return output
 end
 
-(model::Transformer)(tokens::AbstractArray{Int}; clear_cache = false, checkpointed = false) = model(tokens, nothing; clear_cache, checkpointed)
+(model::Transformer)(tokens::AbstractArray{Int}; clear_cache = false, checkpoint_func = wrap, sdpa_func = sdpa) = model(tokens, nothing; clear_cache, checkpoint_func, sdpa_func)
 
 function loss(logits, targets::AbstractArray; loss_mask = nothing)
     vocab_size = size(logits,1)
diff --git a/src/sampling.jl b/src/sampling.jl
index 945c68c..69df2e1 100644
--- a/src/sampling.jl
+++ b/src/sampling.jl
@@ -21,7 +21,8 @@ function generate(
     end_token = 128010,
     clear_cache = true,
     pos_offset = 0,
-    device = identity
+    device = identity,
+    sdpa_func = sdpa
 ) where T
     current_len = length(initial_tokens)
     tokens = vcat(initial_tokens, similar(initial_tokens, max_new_tokens))
@@ -32,10 +33,10 @@ function generate(
         end
         extend_cache!(model, current_len + max_new_tokens)
     end
     input_tokens = device(reshape(initial_tokens, :, 1)) # (seq_len, batch=1)
-    logits = model(input_tokens)
+    logits = model(input_tokens, sdpa_func = sdpa_func)
     for _ in 1:max_new_tokens
         input_tokens = device(reshape([tokens[current_len]], :, 1)) # Just the last token
-        logits = model(input_tokens)
+        logits = model(input_tokens, sdpa_func = sdpa_func)
         next_token = sampler(logits[:, end, 1])
         current_len += 1
         tokens[current_len] = next_token
diff --git a/src/sdpa.jl b/src/sdpa.jl
new file mode 100644
index 0000000..669f742
--- /dev/null
+++ b/src/sdpa.jl
@@ -0,0 +1,257 @@
+#Trying out some tricks for attention.
+
+#Figure out where to thunk...
+
+#Will use Zygote - for testing grad correctness:
+function sdpa_norrule(xq::AbstractArray{T}, xk::AbstractArray{T}, xv::AbstractArray{T}, mask::AbstractArray{T}, head_dim::Int) where T
+    A = softmax(batched_mul(batched_transpose(xk), xq) / sqrt(T(head_dim)) .+ mask; dims=1)
+    return batched_mul(xv, A)
+end
+
+function ChainRulesCore.rrule(::typeof(sdpa),
+    xq::AbstractArray{T}, #(D, LQ, HB)
+    xk::AbstractArray{T}, #(D, LKV, HB)
+    xv::AbstractArray{T}, #(D, LKV, HB)
+    mask::AbstractArray{T}, #(LKV, LQ)
+    head_dim::Int
+    ) where {T}
+    α = sqrt(T(head_dim))
+    A = softmax(((batched_mul(batched_transpose(xk), xq) ./ α) .+ mask); dims=1) #(LKV, LQ, HB) "head-batch"
+    y = batched_mul(xv, A) #(D, LQ, HB)
+    function sdpa_pullback(ȳ)
+        xv̄ = batched_mul(ȳ, batched_transpose(A)) #(D, LKV, HB)
+        Ā = batched_mul(batched_transpose(xv), ȳ) #(LKV, LQ, HB)
+        dM = (A .* (Ā .- (sum(A .* Ā, dims=1)))) ./ α #(LKV, LQ, HB)
+        xq̄ = batched_mul(xk, dM) #(D, LQ, HB)
+        xk̄ = batched_mul(xq, batched_transpose(dM)) #(D, LKV, HB)
+        return NoTangent(), xq̄, xk̄, xv̄, NoTangent(), NoTangent()
+    end
+    return y, sdpa_pullback
+end
+
+
+function keychunked_sdpa(xq::AbstractArray{T,3},
+    xk::AbstractArray{T,3},
+    xv::AbstractArray{T,3},
+    mask::AbstractArray{T},
+    head_dim::Int;
+    k_chunk_size::Int=128
+    ) where {T<:Real}
+
+    k_len = size(xk,2)
+    q_len = size(xq,2)
+    nbatch = size(xq,3)
+
+    scale = one(T) / sqrt(T(head_dim))
+
+    partial_max = fill!(similar(xq, 1, q_len, nbatch), -Inf)
+    partial_expw = fill!(similar(xq, 1, q_len, nbatch), T(0))
+    partial_vals = fill!(similar(xq, head_dim, q_len, nbatch), T(0))
+
+    # Preallocate local buffers for each chunk
+    attn = fill!(similar(xq, k_chunk_size, q_len, nbatch), T(0))
+    local_max = fill!(similar(xq, 1, q_len, nbatch), T(0))
+    new_max = similar(local_max)
+    w_old = similar(local_max)
+    w_new = similar(local_max)
+    chunk_sum = similar(local_max)
+    valpart = fill!(similar(xq, head_dim, q_len, nbatch), T(0))
+
+    kstart = 1
+    while kstart <= k_len
+        k_batch = min(k_chunk_size, k_len - kstart + 1)
+        xk_chunk = @view xk[:, kstart : kstart + k_batch - 1, :]
+        xv_chunk = @view xv[:, kstart : kstart + k_batch - 1, :]
+        if length(mask) > 1
+            mask_chunk = @view mask[kstart : kstart + k_batch - 1, :, :]
+        else
+            mask_chunk = mask #Handles the case where the mask is 1-by-1 for sampling a single token.
+        end
+        attn_view = @view attn[1:k_batch, 1:q_len, 1:nbatch]
+        xkT_chunk = batched_transpose(xk_chunk)
+
+        batched_mul!(attn_view, xkT_chunk, xq, scale, 0) # attn_view = scale*(xkT_chunk*xq)
+        attn_view .= attn_view .+ mask_chunk # add mask
+
+        local_max .= maximum(attn_view, dims=1)
+        @. new_max = max(partial_max, local_max)
+        @. w_old = exp(partial_max - new_max)
+        @. w_new = exp(local_max - new_max)
+        @. attn_view = exp(attn_view - local_max)
+
+        partial_vals .= partial_vals .* w_old # Rescale old accumulators by w_old
+        partial_expw .= partial_expw .* w_old
+
+        chunk_sum .= sum(attn_view, dims=1) .* w_new
+        partial_expw .+= chunk_sum
+
+        batched_mul!(valpart, xv_chunk, attn_view)
+        valpart .= valpart .* w_new
+        partial_vals .+= valpart
+        partial_max .= new_max
+        kstart += k_batch
+    end
+
+    y = partial_vals ./ partial_expw
+    return y
+end
+
+
+#Todo: use this to ignore parts of the -Inf mask triangle, since we're processing over chunks of queries.
+function querychunked_sdpa(
+    xq::AbstractArray{T,3},
+    xk::AbstractArray{T,3},
+    xv::AbstractArray{T,3},
+    mask::AbstractArray{T},
+    head_dim::Int;
+    q_chunk_size::Int=128
+) where {T<:Real}
+    q_len = size(xq, 2)
+    kv_len = size(xv, 2)
+    nbatch = size(xq, 3)
+    q_chunk_size = min(q_chunk_size, q_len)
+    α = sqrt(T(head_dim))
+    y = similar(xq)
+    qk_chunk = similar(xq, kv_len, q_chunk_size, nbatch)
+    Achunk = similar(xq, kv_len, q_chunk_size, nbatch)
+    qstart = 1
+    while qstart <= q_len
+        q_batch = min(q_chunk_size, q_len - qstart + 1)
+        qinds = qstart:qstart+q_batch-1
+        qk_chunkview = view(qk_chunk,:,1:q_batch,:)
+        batched_mul!(qk_chunkview,batched_transpose(xk), view(xq, :, qinds, :), 1/α)
+        Achunk[:,1:q_batch,:] .= softmax((qk_chunkview .+ view(mask,:,qinds)); dims=1) #(LKV, LQ, HB) "head-batch"
+        batched_mul!(view(y,:,qinds,:),xv, view(Achunk,:,1:q_batch,:)) #(D, LQ, HB)
+        qstart += q_batch
+    end
+    return y
+end
+
+function ChainRulesCore.rrule(::typeof(querychunked_sdpa),
+    xq::AbstractArray{T}, #(D, LQ, HB)
+    xk::AbstractArray{T}, #(D, LKV, HB)
+    xv::AbstractArray{T}, #(D, LKV, HB)
+    mask::AbstractArray{T}, #(LKV, LQ)
+    head_dim::Int;
+    q_chunk_size = 128
+    ) where {T}
+    y = querychunked_sdpa(xq, xk, xv, mask, head_dim, q_chunk_size=q_chunk_size)
+    function sdpa_pullback(ȳ)
+        k_len = size(xk, 2)
+        q_len = size(xq, 2)
+        kv_len = size(xv, 2)
+        nbatch = size(xq, 3)
+        q_chunk_size = min(q_chunk_size, q_len)
+        α = sqrt(T(head_dim))
+
+        xq̄, xk̄, xv̄ = similar(xq), fill!(similar(xk), 0), fill!(similar(xv), 0)
+        Achunk = similar(xq, kv_len, q_chunk_size, nbatch)
+        Āchunk = similar(xq, kv_len, q_chunk_size, nbatch)
+        dMchunk = similar(xq, kv_len, q_chunk_size, nbatch)
+        qk_chunk = similar(xq, kv_len, q_chunk_size, nbatch)
+        qstart = 1
+        while qstart <= q_len
+            q_batch = min(q_chunk_size, q_len - qstart + 1)
+            qinds = qstart:qstart+q_batch-1
+            ȳview = view(ȳ,:,qinds,:)
+            qk_chunkview = view(qk_chunk,:,1:q_batch,:)
+            batched_mul!(qk_chunkview,batched_transpose(xk), view(xq, :, qinds, :), 1/α)
+            Achunk[:,1:q_batch,:] .= softmax((qk_chunkview .+ view(mask,:,qinds)); dims=1)
+            batched_mul!(xv̄, ȳview, batched_transpose(view(Achunk,:,1:q_batch,:)), one(T), one(T))
+            Āchunkview = view(Āchunk,:,1:q_batch,:)
+            batched_mul!(Āchunkview, batched_transpose(xv), ȳview)
+            Achunkview = view(Achunk,:,1:q_batch,:)
+            dMchunk[:,1:q_batch,:] .= (Achunkview .* (Āchunkview .- (sum(Achunkview .* Āchunkview, dims=1)))) ./ α #(LKV, LQ, HB)
+            dMchunkview = view(dMchunk,:,1:q_batch,:)
+            batched_mul!(xk̄, view(xq,:,qinds,:), batched_transpose(dMchunkview), one(T), one(T)) #(LKV, D, HB)
+            batched_mul!(view(xq̄,:,qinds,:),xk, dMchunkview) #(D, LQ, HB)
+            qstart += q_batch
+        end
+        return NoTangent(), xq̄, xk̄, xv̄, NoTangent(), NoTangent()
+    end
+    return y, sdpa_pullback
+end
+
+#=
+#Testing forward passes
+begin
+    L1 = 872 #Query
+    L2 = 267 #Key/Value
+    D = 32
+    HB = 80
+    xq, xk, xv, mask = randn(Float32, D, L1, HB), randn(Float32, D, L2, HB), randn(Float32, D, L2, HB), zeros(Float32, L2, L1);
+    f(xq, xk, xv, mask, hd) = (Jjama3.sdpa(xq, xk, xv, mask, hd));
+    fqc(xq, xk, xv, mask, hd) = (Jjama3.querychunked_sdpa(xq, xk, xv, mask, hd, q_chunk_size = 64));
+    fkc(xq, xk, xv, mask, hd) = (Jjama3.keychunked_sdpa(xq, xk, xv, mask, hd, k_chunk_size = 64));
+
+    res = f(xq, xk, xv, mask, D);
+    qcres = fqc(xq, xk, xv, mask, D);
+    kcres = fkc(xq, xk, xv, mask, D);
+
+    @assert isapprox(res, qcres)
+    @assert isapprox(res, kcres)
+
+    @show size(res)
+    @show size(kcres)
+
+
+    #@btime f($xq, $xk, $xv, $mask, $D)
+    #@btime fqc($xq, $xk, $xv, $mask, $D)
+    #@btime fkc($xq, $xk, $xv, $mask, $D)
+end;
+=#
+
+
+
+#=
+#Testing grads
+begin
+L1 = 1000
+L2 = 1200
+D = 32
+HB = 80
+xq, xk, xv, mask = randn(Float32, D, L1, HB), randn(Float32, D, L2, HB), randn(Float32, D, L2, HB), zeros(Float32, L2, L1);
+fnr(xq, xk, xv, mask, hd) = sum(Zygote.checkpointed(Jjama3.sdpa_norrule,xq, xk, xv, mask, hd));
+f(xq, xk, xv, mask, hd) = sum(Zygote.checkpointed(Jjama3.sdpa,xq, xk, xv, mask, hd));
+flm(xq, xk, xv, mask, hd) = sum(Jjama3.querychunked_sdpa(xq, xk, xv, mask, hd, q_chunk_size = 64));
+@time res = withgradient(f, xq, xk, xv, mask, D);
+@time nrres = withgradient(fnr, xq, xk, xv, mask, D);
+@time lmres = withgradient(flm, xq, xk, xv, mask, D);
+
+@assert isapprox(res[1], nrres[1])
+@assert isapprox(res[2][1], nrres[2][1])
+@assert isapprox(res[2][2], nrres[2][2])
+@assert isapprox(res[2][3], nrres[2][3])
+@assert isapprox(res[1], lmres[1])
+@assert isapprox(res[2][1], lmres[2][1])
+@assert isapprox(res[2][2], lmres[2][2])
+@assert isapprox(res[2][3], lmres[2][3])
+
+GC.gc()
+println("normal+rrule checkpointed:")
+@time res = withgradient(f, xq, xk, xv, mask, D);
+@time res = withgradient(f, xq, xk, xv, mask, D);
+
+GC.gc()
+println("normal+Zygote checkpointed:")
+@time nrres = withgradient(fnr, xq, xk, xv, mask, D);
+@time nrres = withgradient(fnr, xq, xk, xv, mask, D);
+
+GC.gc()
+println("chunked:")
+@time lmres = withgradient(flm, xq, xk, xv, mask, D);
+@time lmres = withgradient(flm, xq, xk, xv, mask, D);
+
+
+println("btimed:")
+GC.gc()
+@btime res = withgradient(f, xq, xk, xv, mask, D);
+GC.gc()
+@btime nrres = withgradient(fnr, xq, xk, xv, mask, D);
+GC.gc()
+@btime lmres = withgradient(flm, xq, xk, xv, mask, D);
+
+true
+end
+=#
+
diff --git a/test/runtests.jl b/test/runtests.jl
index 5579a45..4fcef7c 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -13,6 +13,9 @@ using Downloads
 
         model = load_llama3_from_safetensors([model_path], config)
        AAs = collect(">ACDEFGHIKLMNPQRSTVWY.")
        @test generate(model, encode(AAs, ">"), end_token = 22) == [1, 15, 19, 15, 11, 19, 15, 17, 7, 2, 5, 19, 10, 10, 14, 7, 2, 17, 19, 10, 19, 17, 3, 10, 2, 17, 7, 21, 18, 6, 18, 17, 21, 7, 9, 17, 20, 19, 16, 15, 2, 14, 7, 15, 7, 11, 5, 20, 12, 7, 20, 9, 17, 2, 21, 13, 7, 13, 18, 13, 21, 2, 15, 10, 11, 15, 7, 16, 19, 18, 12, 18, 18, 4, 18, 17, 18, 17, 18, 2, 21, 12, 5, 11, 16, 17, 11, 16, 17, 4, 4, 18, 2, 19, 21, 21, 3, 2, 16, 4, 16]
+        @test generate(model, encode(AAs, ">"), end_token = 22, sdpa_func = Jjama3.keychunked_sdpa) == [1, 15, 19, 15, 11, 19, 15, 17, 7, 2, 5, 19, 10, 10, 14, 7, 2, 17, 19, 10, 19, 17, 3, 10, 2, 17, 7, 21, 18, 6, 18, 17, 21, 7, 9, 17, 20, 19, 16, 15, 2, 14, 7, 15, 7, 11, 5, 20, 12, 7, 20, 9, 17, 2, 21, 13, 7, 13, 18, 13, 21, 2, 15, 10, 11, 15, 7, 16, 19, 18, 12, 18, 18, 4, 18, 17, 18, 17, 18, 2, 21, 12, 5, 11, 16, 17, 11, 16, 17, 4, 4, 18, 2, 19, 21, 21, 3, 2, 16, 4, 16]
+        @test generate(model, encode(AAs, ">"), end_token = 22, sdpa_func = Jjama3.querychunked_sdpa) == [1, 15, 19, 15, 11, 19, 15, 17, 7, 2, 5, 19, 10, 10, 14, 7, 2, 17, 19, 10, 19, 17, 3, 10, 2, 17, 7, 21, 18, 6, 18, 17, 21, 7, 9, 17, 20, 19, 16, 15, 2, 14, 7, 15, 7, 11, 5, 20, 12, 7, 20, 9, 17, 2, 21, 13, 7, 13, 18, 13, 21, 2, 15, 10, 11, 15, 7, 16, 19, 18, 12, 18, 18, 4, 18, 17, 18, 17, 18, 2, 21, 12, 5, 11, 16, 17, 11, 16, 17, 4, 4, 18, 2, 19, 21, 21, 3, 2, 16, 4, 16]
+        @test generate(model, encode(AAs, ">"), end_token = 22, sdpa_func = Jjama3.sdpa_norrule) == [1, 15, 19, 15, 11, 19, 15, 17, 7, 2, 5, 19, 10, 10, 14, 7, 2, 17, 19, 10, 19, 17, 3, 10, 2, 17, 7, 21, 18, 6, 18, 17, 21, 7, 9, 17, 20, 19, 16, 15, 2, 14, 7, 15, 7, 11, 5, 20, 12, 7, 20, 9, 17, 2, 21, 13, 7, 13, 18, 13, 21, 2, 15, 10, 11, 15, 7, 16, 19, 18, 12, 18, 18, 4, 18, 17, 18, 17, 18, 2, 21, 12, 5, 11, 16, 17, 11, 16, 17, 4, 4, 18, 2, 19, 21, 21, 3, 2, 16, 4, 16]
     end
 end
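Usage note (reviewer sketch, not part of the diff): the new keywords default to the previous behavior (sdpa_func = sdpa, checkpoint_func = wrap), so existing call sites are unchanged and callers opt in explicitly. A minimal sketch of the opt-in paths, assuming a constructed model and integer token arrays; model, prompt, tokens, and opt_state are placeholder names for illustration:

    using Jjama3, Flux

    # Sampling with the lower-memory key-chunked attention kernel.
    new_tokens = generate(model, prompt, sdpa_func = Jjama3.keychunked_sdpa)

    # Forward pass with per-layer gradient checkpointing (replacing the old
    # `checkpointed = true` flag) and the query-chunked kernel.
    logits = model(tokens, opt_state,
                   checkpoint_func = Flux.Zygote.checkpointed,
                   sdpa_func = Jjama3.querychunked_sdpa)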