diff --git a/src/model.jl b/src/model.jl index 787151f..8a0b59e 100644 --- a/src/model.jl +++ b/src/model.jl @@ -50,7 +50,7 @@ function (model::Transformer)(tokens::AbstractArray{Int}, opt_state; clear_cache end h = model.tok_embeddings(tokens) # Embedding: (dim, seq_len, batch) rope = model.rope[model.pos+1:model.pos+size(tokens, 1)] - if size(h, 2) == 1 + if size(h, 2) == 1 #If there is only one new token, then a 1-by-1 mask = 0 works, via broadcasting (if the attention functions allow it) mask = Jjama3.create_mask(h) else mask = Jjama3.create_mask(h; precached_size = model.pos) diff --git a/src/sdpa.jl b/src/sdpa.jl index d2f1ac3..669f742 100644 --- a/src/sdpa.jl +++ b/src/sdpa.jl @@ -1,6 +1,6 @@ #Trying out some tricks for attention. -#Figure out where to thunk: https://juliadiff.org/ChainRulesCore.jl/stable/rule_author/example.html +#Figure out where to thunk... #Will use Zygote - for testing grad correctness: function sdpa_norrule(xq::AbstractArray{T}, xk::AbstractArray{T}, xv::AbstractArray{T}, mask::AbstractArray{T}, head_dim::Int) where T @@ -35,14 +35,14 @@ function keychunked_sdpa(xq::AbstractArray{T,3}, xv::AbstractArray{T,3}, mask::AbstractArray{T}, head_dim::Int; - k_chunk_size::Int=256 + k_chunk_size::Int=128 ) where {T<:Real} + k_len = size(xk,2) q_len = size(xq,2) nbatch = size(xq,3) scale = one(T) / sqrt(T(head_dim)) - partial_max = fill!(similar(xq, 1, q_len, nbatch), -Inf) partial_expw = fill!(similar(xq, 1, q_len, nbatch), T(0)) @@ -62,7 +62,11 @@ function keychunked_sdpa(xq::AbstractArray{T,3}, k_batch = min(k_chunk_size, k_len - kstart + 1) xk_chunk = @view xk[:, kstart : kstart + k_batch - 1, :] xv_chunk = @view xv[:, kstart : kstart + k_batch - 1, :] - mask_chunk = @view mask[kstart : kstart + k_batch - 1, :, :] + if length(mask) > 1 + mask_chunk = @view mask[kstart : kstart + k_batch - 1, :, :] + else + mask_chunk = mask #Handles the case where the mask is 1-by-1 for sampling a single token. + end attn_view = @view attn[1:k_batch, 1:q_len, 1:nbatch] xkT_chunk = batched_transpose(xk_chunk) @@ -171,8 +175,8 @@ end #= #Testing forward passes begin - L1 = 400 #Query - L2 = 599 #Key/Value + L1 = 872 #Query + L2 = 267 #Key/Value D = 32 HB = 80 xq, xk, xv, mask = randn(Float32, D, L1, HB), randn(Float32, D, L2, HB), randn(Float32, D, L2, HB), zeros(Float32, L2, L1); @@ -187,9 +191,13 @@ begin @assert isapprox(res, qcres) @assert isapprox(res, kcres) - @btime f($xq, $xk, $xv, $mask, $D) - @btime fqc($xq, $xk, $xv, $mask, $D) - @btime fkc($xq, $xk, $xv, $mask, $D) + @show size(res) + @show size(kcres) + + + #@btime f($xq, $xk, $xv, $mask, $D) + #@btime fqc($xq, $xk, $xv, $mask, $D) + #@btime fkc($xq, $xk, $xv, $mask, $D) end; =# diff --git a/test/runtests.jl b/test/runtests.jl index 5579a45..4fcef7c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,6 +13,9 @@ using Downloads model = load_llama3_from_safetensors([model_path], config) AAs = collect(">ACDEFGHIKLMNPQRSTVWY.") @test generate(model, encode(AAs, ">"), end_token = 22) == [1, 15, 19, 15, 11, 19, 15, 17, 7, 2, 5, 19, 10, 10, 14, 7, 2, 17, 19, 10, 19, 17, 3, 10, 2, 17, 7, 21, 18, 6, 18, 17, 21, 7, 9, 17, 20, 19, 16, 15, 2, 14, 7, 15, 7, 11, 5, 20, 12, 7, 20, 9, 17, 2, 21, 13, 7, 13, 18, 13, 21, 2, 15, 10, 11, 15, 7, 16, 19, 18, 12, 18, 18, 4, 18, 17, 18, 17, 18, 2, 21, 12, 5, 11, 16, 17, 11, 16, 17, 4, 4, 18, 2, 19, 21, 21, 3, 2, 16, 4, 16] + @test generate(model, encode(AAs, ">"), end_token = 22, sdpa_func = Jjama3.keychunked_sdpa) == [1, 15, 19, 15, 11, 19, 15, 17, 7, 2, 5, 19, 10, 10, 14, 7, 2, 17, 19, 10, 19, 17, 3, 10, 2, 17, 7, 21, 18, 6, 18, 17, 21, 7, 9, 17, 20, 19, 16, 15, 2, 14, 7, 15, 7, 11, 5, 20, 12, 7, 20, 9, 17, 2, 21, 13, 7, 13, 18, 13, 21, 2, 15, 10, 11, 15, 7, 16, 19, 18, 12, 18, 18, 4, 18, 17, 18, 17, 18, 2, 21, 12, 5, 11, 16, 17, 11, 16, 17, 4, 4, 18, 2, 19, 21, 21, 3, 2, 16, 4, 16] + @test generate(model, encode(AAs, ">"), end_token = 22, sdpa_func = Jjama3.querychunked_sdpa) == [1, 15, 19, 15, 11, 19, 15, 17, 7, 2, 5, 19, 10, 10, 14, 7, 2, 17, 19, 10, 19, 17, 3, 10, 2, 17, 7, 21, 18, 6, 18, 17, 21, 7, 9, 17, 20, 19, 16, 15, 2, 14, 7, 15, 7, 11, 5, 20, 12, 7, 20, 9, 17, 2, 21, 13, 7, 13, 18, 13, 21, 2, 15, 10, 11, 15, 7, 16, 19, 18, 12, 18, 18, 4, 18, 17, 18, 17, 18, 2, 21, 12, 5, 11, 16, 17, 11, 16, 17, 4, 4, 18, 2, 19, 21, 21, 3, 2, 16, 4, 16] + @test generate(model, encode(AAs, ">"), end_token = 22, sdpa_func = Jjama3.sdpa_norrule) == [1, 15, 19, 15, 11, 19, 15, 17, 7, 2, 5, 19, 10, 10, 14, 7, 2, 17, 19, 10, 19, 17, 3, 10, 2, 17, 7, 21, 18, 6, 18, 17, 21, 7, 9, 17, 20, 19, 16, 15, 2, 14, 7, 15, 7, 11, 5, 20, 12, 7, 20, 9, 17, 2, 21, 13, 7, 13, 18, 13, 21, 2, 15, 10, 11, 15, 7, 16, 19, 18, 12, 18, 18, 4, 18, 17, 18, 17, 18, 2, 21, 12, 5, 11, 16, 17, 11, 16, 17, 4, 4, 18, 2, 19, 21, 21, 3, 2, 16, 4, 16] end end