Merge pull request #23 from MurrellGroup/sdpa
SDPA and memory saving tricks
murrellb authored Dec 28, 2024
2 parents 203938e + 43922d5 commit 90447e3
Showing 7 changed files with 286 additions and 37 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -22,7 +22,7 @@ MetalExt = "Metal"

[compat]
ChainRulesCore = "1.25.0"
Flux = "0.14, 0.15"
Flux = "0.14, 0.15, 0.16"
HuggingFaceTokenizers = "0.1"
LogitSamplers = "0.1.0"
LowRankLayers = "0.1.2"
4 changes: 2 additions & 2 deletions src/Jjama3.jl
@@ -6,7 +6,7 @@ using LinearAlgebra
using NNlib
using LogitSamplers
using LowRankLayers
#using ChainRulesCore
using ChainRulesCore

include("cache.jl")
export KVCache
@@ -23,7 +23,7 @@ export rerope_cache!
export scrape_cache
export append_cache!

#include("sdpa.jl")
include("sdpa.jl")

include("model.jl")
export forward_loss
31 changes: 13 additions & 18 deletions src/layers.jl
@@ -101,6 +101,12 @@ function unrope(rope, x)
)
end

#Scaled dot product attention
function sdpa(xq::AbstractArray{T}, xk::AbstractArray{T}, xv::AbstractArray{T}, mask::AbstractArray{T}, head_dim::Int) where T
A = softmax(batched_mul(batched_transpose(xk), xq) / sqrt(T(head_dim)) .+ mask; dims=1)
return batched_mul(xv, A)
end
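
For orientation, the arrays reaching sdpa are laid out as (head_dim, seq_len, n_heads * batch), and the mask is broadcast onto the (k_len, q_len, n_heads * batch) score array before the softmax over keys. A minimal shape check, assuming only that the Jjama3 package is available (all sizes below are made up):

using Jjama3

head_dim, q_len, k_len, heads_x_batch = 8, 5, 5, 6
xq = randn(Float32, head_dim, q_len, heads_x_batch)   # queries
xk = randn(Float32, head_dim, k_len, heads_x_batch)   # keys
xv = randn(Float32, head_dim, k_len, heads_x_batch)   # values
mask = zeros(Float32, 1, 1)                           # broadcastable no-op mask
out = Jjama3.sdpa(xq, xk, xv, mask, head_dim)
@assert size(out) == (head_dim, q_len, heads_x_batch)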

mutable struct Attention
wq::AnyDense
wk::AnyDense
@@ -126,20 +132,13 @@ function Attention(dim::Int, n_heads::Int, n_kv_heads=n_heads; qkv_bias=false)
n_heads,
n_kv_heads,
head_dim,
KVCache(Float32; n_kv_heads, head_dim),
KVCache(Float32; n_kv_heads, head_dim)
)
end

repeat_kv(x::AbstractArray, n_rep::Int) = isone(n_rep) ? x : repeat(x, 1, n_rep, 1, 1)
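
repeat_kv is what lets grouped-query attention reuse a smaller set of key/value heads: it tiles dimension 2 of its input n_rep times and is a no-op when n_rep == 1. A quick sketch of the expansion (the dimension ordering in the comment is assumed for illustration only):

x = ones(Float32, 4, 2, 3, 1)        # e.g. (head_dim, n_kv_heads, seq_len, batch)
size(Jjama3.repeat_kv(x, 3))         # (4, 6, 3, 1): the 2 kv heads are tiled 3× along dim 2
size(Jjama3.repeat_kv(x, 1))         # (4, 2, 3, 1): unchanged when n_rep == 1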

function sdpa(xq::AbstractArray{T}, xk::AbstractArray{T}, xv::AbstractArray{T}, mask::AbstractArray{T}, head_dim::Int) where T
scores = batched_mul(batched_transpose(xk), xq) / sqrt(T(head_dim))
scores = scores .+ mask
sm_scores = softmax(scores; dims=1)
return batched_mul(xv, sm_scores)
end

function (attn::Attention)(x::AbstractArray{T}, start_pos::Integer, rope=nothing, mask=false) where T
function (attn::Attention)(x::AbstractArray{T}, start_pos::Integer, rope=nothing, mask=false, sdpa_func = sdpa) where T
_, seqlen, batch = size(x)

xq = attn.wq(x)
@@ -168,7 +167,7 @@ function (attn::Attention)(x::AbstractArray{T}, start_pos::Integer, rope=nothing
xk_for_attn = reshape(xk, attn.head_dim, :, attn.n_heads * batch)
xv_for_attn = reshape(xv, attn.head_dim, :, attn.n_heads * batch)

output = sdpa(xq_for_attn, xk_for_attn, xv_for_attn, mask, attn.head_dim)
output = sdpa_func(xq_for_attn, xk_for_attn, xv_for_attn, mask, attn.head_dim)

e_output = reshape(output, (attn.head_dim, seqlen, attn.n_heads, batch))
p_output = permutedims(e_output, (1,3,2,4))
@@ -177,9 +176,6 @@
return proj
end
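
Because the kernel is now an argument (sdpa_func), anything with the same (xq, xk, xv, mask, head_dim) signature can be dropped in. As a rough sketch, not part of this PR, a query-chunked wrapper around sdpa caps the score matrix at k_len × chunk columns instead of k_len × q_len; the chunk keyword and the assumption that the mask is either broadcastable (e.g. 1×1) or a (k_len, q_len) matrix are mine, and the saving mainly matters at inference time (Zygote would still keep every chunk's softmax for the backward pass):

#Hypothetical memory-saving kernel matching sdpa's signature (illustrative sketch; assumes Jjama3 is loaded)
function chunked_sdpa(xq::AbstractArray{T}, xk::AbstractArray{T}, xv::AbstractArray{T}, mask::AbstractArray{T}, head_dim::Int; chunk::Int = 256) where T
    q_len = size(xq, 2)
    pieces = map(1:chunk:q_len) do i
        cols = i:min(i + chunk - 1, q_len)
        m = size(mask, 2) == 1 ? mask : mask[:, cols]     # slice the causal mask for this query chunk
        Jjama3.sdpa(xq[:, cols, :], xk, xv, m, head_dim)  # full softmax over keys, so results match sdpa
    end
    return cat(pieces...; dims = 2)                       # reassemble along the query axis
end

Passing sdpa_func = chunked_sdpa through the model's (or generate's) keyword then swaps the kernel without touching any layer code.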




struct TransformerBlock{A<:Attention,F<:FeedForward,AN<:RMSNorm,FN<:RMSNorm}
attention::A
feed_forward::F
@@ -189,7 +185,7 @@ end

function TransformerBlock(
dim::Int, n_heads::Int, n_kv_heads::Int = n_heads, ff_hidden_dim = 4 * dim;
norm_eps=1f-5, qkv_bias=false,
norm_eps=1f-5, qkv_bias=false
)
TransformerBlock(
Attention(dim, n_heads, n_kv_heads; qkv_bias),
@@ -199,15 +195,14 @@ function TransformerBlock(
)
end

function (block::TransformerBlock)(x, start_pos, rope, mask=nothing)
h = x + block.attention(block.attention_norm(x), start_pos, rope, mask)
function (block::TransformerBlock)(x, start_pos, rope, mask, sdpa)
h = x + block.attention(block.attention_norm(x), start_pos, rope, mask, sdpa)
out = h + block.feed_forward(block.ffn_norm(h))
return out
end

Flux.@layer TransformerBlock trainable=(attention,)


mutable struct Transformer{E<:Flux.Embedding,B<:Tuple{Vararg{TransformerBlock}},N<:RMSNorm,O<:Dense,R<:RoPE}
tok_embeddings::E
layers::B
@@ -226,7 +221,7 @@ function Transformer(
qkv_bias=false,
rope_theta::T=500000f0,
use_scaled_rope=false,
scale_factor=8,
scale_factor=8
) where T
tok_embeddings = Flux.Embedding(vocab_size => dim)
layers = Tuple(TransformerBlock(dim, n_heads, n_kv_heads, ff_hidden_dim; norm_eps=norm_eps, qkv_bias=qkv_bias) for _ in 1:n_layers)
19 changes: 6 additions & 13 deletions src/model.jl
@@ -42,32 +42,25 @@ end


wrap(model, xs...) = model(xs...)
function (model::Transformer)(tokens::AbstractArray{Int}, opt_state; clear_cache = false, checkpointed = false)
function (model::Transformer)(tokens::AbstractArray{Int}, opt_state; clear_cache = false, checkpoint_func = wrap, sdpa_func = sdpa)
if clear_cache
Flux.ChainRulesCore.ignore_derivatives() do
Jjama3.clear_cache!(model)
end
end
h = model.tok_embeddings(tokens) # Embedding: (dim, seq_len, batch)
rope = model.rope[model.pos+1:model.pos+size(tokens, 1)]
if size(h, 2) == 1
if size(h, 2) == 1 #If there is only one new token, a 1-by-1 mask of 0 works via broadcasting (provided the attention functions allow it)
mask = Jjama3.create_mask(h)
else
mask = Jjama3.create_mask(h; precached_size = model.pos)
end
for i in 1:length(model.layers)
if !isnothing(opt_state)
if checkpointed
h = Flux.Zygote.checkpointed(wrap, eager_update!(opt_state.layers[i], model.layers[i], Optimisers.update!), h, model.pos, rope, mask)
else
h = wrap(eager_update!(opt_state.layers[i], model.layers[i], Optimisers.update!), h, model.pos, rope, mask)
end
#If checkpoint_func is also just wrap, this does nothing extra; if it is Zygote.checkpointed, this becomes a checkpointed update
h = checkpoint_func(wrap, eager_update!(opt_state.layers[i], model.layers[i], Optimisers.update!), h, model.pos, rope, mask, sdpa_func)
else
if checkpointed
h = Flux.Zygote.checkpointed(wrap, model.layers[i], h, model.pos, rope, mask)
else
h = model.layers[i](h, model.pos, rope, mask)
end
h = checkpoint_func(wrap, model.layers[i], h, model.pos, rope, mask, sdpa_func)
end
end
h = model.norm(h)
@@ -76,7 +69,7 @@ function (model::Transformer)(tokens::AbstractArray{Int}, opt_state; clear_cache
return output
end

(model::Transformer)(tokens::AbstractArray{Int}; clear_cache = false, checkpointed = false) = model(tokens, nothing; clear_cache, checkpointed)
(model::Transformer)(tokens::AbstractArray{Int}; clear_cache = false, checkpoint_func = wrap, sdpa_func = sdpa) = model(tokens, nothing; clear_cache, checkpoint_func, sdpa_func)
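
The checkpointed::Bool flag is gone; checkpoint_func now wraps each layer call, so gradient checkpointing is opted into by passing Flux.Zygote.checkpointed (the same function the old branches used), while the default wrap leaves behaviour unchanged. A hedged training-side fragment, where model, tokens and opt_state (e.g. from Optimisers.setup) are assumed to already exist and only the keywords come from this diff:

using Flux  # Flux.Zygote.checkpointed recomputes each block's forward pass during the backward pass

logits = model(tokens, opt_state;
               clear_cache = true,
               checkpoint_func = Flux.Zygote.checkpointed,  # trade compute for activation memory
               sdpa_func = Jjama3.sdpa)                     # or any kernel with the same signature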

function loss(logits, targets::AbstractArray; loss_mask = nothing)
vocab_size = size(logits,1)
7 changes: 4 additions & 3 deletions src/sampling.jl
@@ -21,7 +21,8 @@ function generate(
end_token = 128010,
clear_cache = true,
pos_offset = 0,
device = identity
device = identity,
sdpa_func = sdpa
) where T
current_len = length(initial_tokens)
tokens = vcat(initial_tokens, similar(initial_tokens, max_new_tokens))
@@ -32,10 +33,10 @@
extend_cache!(model, current_len + max_new_tokens)
end
input_tokens = device(reshape(initial_tokens, :, 1)) # (seq_len, batch=1)
logits = model(input_tokens)
logits = model(input_tokens, sdpa_func = sdpa_func)
for _ in 1:max_new_tokens
input_tokens = device(reshape([tokens[current_len]], :, 1)) # Just the last token
logits = model(input_tokens)
logits = model(input_tokens, sdpa_func = sdpa_func)
next_token = sampler(logits[:, end, 1])
current_len += 1
tokens[current_len] = next_token
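
The same keyword is threaded through generate, so inference can swap kernels as well. A fragment with the positional arguments assumed (only sdpa_func is taken from this diff; model and prompt_tokens::Vector{Int} are presumed to exist):

new_tokens = generate(model, prompt_tokens;
                      sdpa_func = Jjama3.sdpa)  # e.g. chunked_sdpa from the sketch above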