Skip to content

Commit

Permalink
Merge pull request #24 from MurrellGroup/sdpa2
Browse files Browse the repository at this point in the history
Fixing key chunk when sampling
  • Loading branch information
murrellb authored Dec 28, 2024
2 parents cf06f65 + 9dd30e0 commit 43922d5
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 10 deletions.
2 changes: 1 addition & 1 deletion src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ function (model::Transformer)(tokens::AbstractArray{Int}, opt_state; clear_cache
end
h = model.tok_embeddings(tokens) # Embedding: (dim, seq_len, batch)
rope = model.rope[model.pos+1:model.pos+size(tokens, 1)]
if size(h, 2) == 1
if size(h, 2) == 1 #With only one new token, a 1-by-1 mask equal to 0 suffices: it is applied via broadcasting (provided the attention function supports that)
mask = Jjama3.create_mask(h)
else
mask = Jjama3.create_mask(h; precached_size = model.pos)
Expand Down
26 changes: 17 additions & 9 deletions src/sdpa.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#Trying out some tricks for attention.

#Figure out where to thunk: https://juliadiff.org/ChainRulesCore.jl/stable/rule_author/example.html
#Figure out where to thunk...

#Will use Zygote - for testing grad correctness:
function sdpa_norrule(xq::AbstractArray{T}, xk::AbstractArray{T}, xv::AbstractArray{T}, mask::AbstractArray{T}, head_dim::Int) where T
Expand Down Expand Up @@ -35,14 +35,14 @@ function keychunked_sdpa(xq::AbstractArray{T,3},
xv::AbstractArray{T,3},
mask::AbstractArray{T},
head_dim::Int;
k_chunk_size::Int=256
k_chunk_size::Int=128
) where {T<:Real}

k_len = size(xk,2)
q_len = size(xq,2)
nbatch = size(xq,3)

scale = one(T) / sqrt(T(head_dim))


partial_max = fill!(similar(xq, 1, q_len, nbatch), -Inf)
partial_expw = fill!(similar(xq, 1, q_len, nbatch), T(0))
Expand All @@ -62,7 +62,11 @@ function keychunked_sdpa(xq::AbstractArray{T,3},
k_batch = min(k_chunk_size, k_len - kstart + 1)
xk_chunk = @view xk[:, kstart : kstart + k_batch - 1, :]
xv_chunk = @view xv[:, kstart : kstart + k_batch - 1, :]
mask_chunk = @view mask[kstart : kstart + k_batch - 1, :, :]
if length(mask) > 1
mask_chunk = @view mask[kstart : kstart + k_batch - 1, :, :]

Check warning on line 66 in src/sdpa.jl

View check run for this annotation

Codecov / codecov/patch

src/sdpa.jl#L66

Added line #L66 was not covered by tests
else
mask_chunk = mask #Handles the case where the mask is 1-by-1 for sampling a single token.
end
attn_view = @view attn[1:k_batch, 1:q_len, 1:nbatch]
xkT_chunk = batched_transpose(xk_chunk)

Expand Down Expand Up @@ -171,8 +175,8 @@ end
#=
#Testing forward passes
begin
L1 = 400 #Query
L2 = 599 #Key/Value
L1 = 872 #Query
L2 = 267 #Key/Value
D = 32
HB = 80
xq, xk, xv, mask = randn(Float32, D, L1, HB), randn(Float32, D, L2, HB), randn(Float32, D, L2, HB), zeros(Float32, L2, L1);
Expand All @@ -187,9 +191,13 @@ begin
@assert isapprox(res, qcres)
@assert isapprox(res, kcres)
@btime f($xq, $xk, $xv, $mask, $D)
@btime fqc($xq, $xk, $xv, $mask, $D)
@btime fkc($xq, $xk, $xv, $mask, $D)
@show size(res)
@show size(kcres)
#@btime f($xq, $xk, $xv, $mask, $D)
#@btime fqc($xq, $xk, $xv, $mask, $D)
#@btime fkc($xq, $xk, $xv, $mask, $D)
end;
=#

Expand Down
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ using Downloads
model = load_llama3_from_safetensors([model_path], config)
AAs = collect(">ACDEFGHIKLMNPQRSTVWY.")
@test generate(model, encode(AAs, ">"), end_token = 22) == [1, 15, 19, 15, 11, 19, 15, 17, 7, 2, 5, 19, 10, 10, 14, 7, 2, 17, 19, 10, 19, 17, 3, 10, 2, 17, 7, 21, 18, 6, 18, 17, 21, 7, 9, 17, 20, 19, 16, 15, 2, 14, 7, 15, 7, 11, 5, 20, 12, 7, 20, 9, 17, 2, 21, 13, 7, 13, 18, 13, 21, 2, 15, 10, 11, 15, 7, 16, 19, 18, 12, 18, 18, 4, 18, 17, 18, 17, 18, 2, 21, 12, 5, 11, 16, 17, 11, 16, 17, 4, 4, 18, 2, 19, 21, 21, 3, 2, 16, 4, 16]
@test generate(model, encode(AAs, ">"), end_token = 22, sdpa_func = Jjama3.keychunked_sdpa) == [1, 15, 19, 15, 11, 19, 15, 17, 7, 2, 5, 19, 10, 10, 14, 7, 2, 17, 19, 10, 19, 17, 3, 10, 2, 17, 7, 21, 18, 6, 18, 17, 21, 7, 9, 17, 20, 19, 16, 15, 2, 14, 7, 15, 7, 11, 5, 20, 12, 7, 20, 9, 17, 2, 21, 13, 7, 13, 18, 13, 21, 2, 15, 10, 11, 15, 7, 16, 19, 18, 12, 18, 18, 4, 18, 17, 18, 17, 18, 2, 21, 12, 5, 11, 16, 17, 11, 16, 17, 4, 4, 18, 2, 19, 21, 21, 3, 2, 16, 4, 16]
@test generate(model, encode(AAs, ">"), end_token = 22, sdpa_func = Jjama3.querychunked_sdpa) == [1, 15, 19, 15, 11, 19, 15, 17, 7, 2, 5, 19, 10, 10, 14, 7, 2, 17, 19, 10, 19, 17, 3, 10, 2, 17, 7, 21, 18, 6, 18, 17, 21, 7, 9, 17, 20, 19, 16, 15, 2, 14, 7, 15, 7, 11, 5, 20, 12, 7, 20, 9, 17, 2, 21, 13, 7, 13, 18, 13, 21, 2, 15, 10, 11, 15, 7, 16, 19, 18, 12, 18, 18, 4, 18, 17, 18, 17, 18, 2, 21, 12, 5, 11, 16, 17, 11, 16, 17, 4, 4, 18, 2, 19, 21, 21, 3, 2, 16, 4, 16]
@test generate(model, encode(AAs, ">"), end_token = 22, sdpa_func = Jjama3.sdpa_norrule) == [1, 15, 19, 15, 11, 19, 15, 17, 7, 2, 5, 19, 10, 10, 14, 7, 2, 17, 19, 10, 19, 17, 3, 10, 2, 17, 7, 21, 18, 6, 18, 17, 21, 7, 9, 17, 20, 19, 16, 15, 2, 14, 7, 15, 7, 11, 5, 20, 12, 7, 20, 9, 17, 2, 21, 13, 7, 13, 18, 13, 21, 2, 15, 10, 11, 15, 7, 16, 19, 18, 12, 18, 18, 4, 18, 17, 18, 17, 18, 2, 21, 12, 5, 11, 16, 17, 11, 16, 17, 4, 4, 18, 2, 19, 21, 21, 3, 2, 16, 4, 16]
end

end

0 comments on commit 43922d5

Please sign in to comment.