From a6a7bbdce148c0b0750bd292b7e1bdab47d63dca Mon Sep 17 00:00:00 2001
From: murrellb
Date: Tue, 24 Dec 2024 16:37:31 +0100
Subject: [PATCH 1/2] Eager updates, and loss refactor

---
 src/Jjama3.jl |  1 +
 src/model.jl  | 73 +++++++++++++++++++++++++++++++++++++--------------
 2 files changed, 54 insertions(+), 20 deletions(-)

diff --git a/src/Jjama3.jl b/src/Jjama3.jl
index 693c426..930cabb 100644
--- a/src/Jjama3.jl
+++ b/src/Jjama3.jl
@@ -28,6 +28,7 @@ export rerope_cache!
 include("model.jl")
 export forward_loss
 export forward_inference
+export loss
 
 include("sampling.jl")
 export top_pk_sampler
diff --git a/src/model.jl b/src/model.jl
index aaf8d62..ec8f96a 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -15,16 +15,60 @@ function create_mask(h::AbstractArray{T}; precached_size = 0) where T<:AbstractF
     end
 end
 
-function (model::Transformer)(tokens::AbstractArray{Int})
+
+function masked_agg(ce, mask)
+    if mask !== nothing
+        ce = ce .* mask
+    end
+    return sum(ce)/sum(mask)
+end
+
+#Hoping this will wind up in Zygote.jl
+"""
+    eager_update!(state, model, update!)
+
+Updates params during the backward pass, saving memory.
+
+f(model, xs...) = model(xs...)
+h = f(Zygote.eager_update!(state.layers[i], model.layers[i], Optimisers.update!), h, other_args...)
+"""
+function eager_update!(state, model, update!)
+    function update_hook(dmodel)
+        update!(state, model, dmodel)
+        return nothing
+    end
+    return Flux.Zygote.hook(update_hook, model)
+end
+
+
+wrap(model, xs...) = model(xs...)
+function (model::Transformer)(tokens::AbstractArray{Int}, opt_state; clear_cache = false, checkpointed = false)
+    if clear_cache
+        Flux.ChainRulesCore.ignore_derivatives() do
+            Jjama3.clear_cache!(model)
+        end
+    end
     h = model.tok_embeddings(tokens) # Embedding: (dim, seq_len, batch)
     rope = model.rope[model.pos+1:model.pos+size(tokens, 1)]
     if size(h, 2) == 1
-        mask = create_mask(h)
+        mask = Jjama3.create_mask(h)
     else
-        mask = create_mask(h; precached_size = model.pos)
+        mask = Jjama3.create_mask(h; precached_size = model.pos)
     end
-    for layer in model.layers
-        h = layer(h, model.pos, rope, mask)
+    for i in 1:length(model.layers)
+        if !isnothing(opt_state)
+            if checkpointed
+                h = Flux.Zygote.checkpointed(wrap, eager_update!(opt_state.layers[i], model.layers[i], Optimisers.update!), h, model.pos, rope, mask)
+            else
+                h = wrap(eager_update!(opt_state.layers[i], model.layers[i], Optimisers.update!), h, model.pos, rope, mask)
+            end
+        else
+            if checkpointed
+                h = Flux.Zygote.checkpointed(wrap, model.layers[i], h, model.pos, rope, mask)
+            else
+                h = model.layers[i](h, model.pos, rope, mask)
+            end
+        end
     end
     h = model.norm(h)
     output = model.output(h)
@@ -32,22 +76,10 @@ function (model::Transformer)(tokens::AbstractArray{Int})
     return output
 end
 
-function masked_agg(ce, mask)
-    if mask !== nothing
-        ce = ce .* mask
-    end
-    return sum(ce)/sum(mask)
-end
+(model::Transformer)(tokens::AbstractArray{Int}; clear_cache = false, checkpointed = false) = model(tokens, nothing; clear_cache, checkpointed)
 
-function forward_loss(model::Transformer, inputs::AbstractArray,
-    targets::AbstractArray; clear_cache = true, loss_mask = nothing)
-    if clear_cache
-        Flux.ChainRulesCore.ignore_derivatives() do
-            clear_cache!(model)
-        end
-    end
-    logits = model(inputs)
-    vocab_size = size(model.tok_embeddings.weight, 2)
+function loss(logits, targets::AbstractArray; loss_mask = nothing)
+    vocab_size = size(logits,1)
     gt = Flux.onehotbatch(targets, 1:vocab_size)
     if loss_mask !== nothing
         loss = Flux.logitcrossentropy(logits, gt, agg = x -> masked_agg(x, loss_mask))
@@ -59,3 +91,4 @@ end
 
 # compat
 forward_inference(model, args...) = model(args...)
+forward_loss(model::Transformer, inputs::AbstractArray, targets::AbstractArray; clear_cache = true, loss_mask = nothing) = loss(model(inputs; clear_cache = clear_cache), targets; loss_mask = loss_mask)

From f60900f56c8b47fb64cfafd62c92ccf4502ecaed Mon Sep 17 00:00:00 2001
From: murrellb
Date: Thu, 26 Dec 2024 16:02:32 +0100
Subject: [PATCH 2/2] Splitting out sdpa

---
 src/layers.jl | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/src/layers.jl b/src/layers.jl
index fbb48aa..7b2ff9f 100644
--- a/src/layers.jl
+++ b/src/layers.jl
@@ -132,6 +132,13 @@ end
 
 repeat_kv(x::AbstractArray, n_rep::Int) = isone(n_rep) ? x : repeat(x, 1, n_rep, 1, 1)
 
+function sdpa(xq::AbstractArray{T}, xk::AbstractArray{T}, xv::AbstractArray{T}, mask::AbstractArray{T}, head_dim::Int) where T
+    scores = batched_mul(batched_transpose(xk), xq) / sqrt(T(head_dim))
+    scores = scores .+ mask
+    sm_scores = softmax(scores; dims=1)
+    return batched_mul(xv, sm_scores)
+end
+
 function (attn::Attention)(x::AbstractArray{T}, start_pos::Integer, rope=nothing, mask=false) where T
     _, seqlen, batch = size(x)
 
@@ -161,11 +168,8 @@ function (attn::Attention)(x::AbstractArray{T}, start_pos::Integer, rope=nothing
     xk_for_attn = reshape(xk, attn.head_dim, :, attn.n_heads * batch)
     xv_for_attn = reshape(xv, attn.head_dim, :, attn.n_heads * batch)
 
-    scores = batched_mul(batched_transpose(xk_for_attn), xq_for_attn) / sqrt(T(attn.head_dim))
-    scores = scores .+ mask
-    sm_scores = softmax(scores; dims=1)
-
-    output = batched_mul(xv_for_attn, sm_scores)
+    output = sdpa(xq_for_attn, xk_for_attn, xv_for_attn, mask, attn.head_dim)
+
     e_output = reshape(output, (attn.head_dim, seqlen, attn.n_heads, batch))
     p_output = permutedims(e_output, (1,3,2,4))
     r_output = reshape(p_output, (attn.n_heads * attn.head_dim, seqlen, batch))
@@ -174,6 +178,8 @@ function (attn::Attention)(x::AbstractArray{T}, start_pos::Integer, rope=nothing
 end
 
 
+
+
 struct TransformerBlock{A<:Attention,F<:FeedForward,AN<:RMSNorm,FN<:RMSNorm}
     attention::A
     feed_forward::F
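
Editor's note: below is a minimal sketch of how the API added in PATCH 1/2 might be driven in a training step. The `loss` function and the `model(tokens, opt_state; clear_cache, checkpointed)` method come from the patch; `model`, `train_tokens`, `train_targets`, and the Adam optimiser choice are placeholder assumptions, not part of the diff.

    using Flux, Optimisers
    using Jjama3

    # Placeholder setup (not from the patch): a loaded `model` plus Int token
    # matrices of shape (seq_len, batch) for inputs and shifted targets.
    opt_state = Optimisers.setup(Optimisers.Adam(1f-4), model)

    grads = Flux.gradient(model) do m
        # Passing opt_state switches on the eager path: each layer is updated
        # by Optimisers.update! inside its own pullback, so the full set of
        # layer gradients never has to be held in memory at once.
        logits = m(train_tokens, opt_state; clear_cache = true, checkpointed = true)
        loss(logits, train_targets)
    end

    # The layers were already updated during the backward pass (their gradient
    # is replaced by `nothing` via the hook); this call applies what remains,
    # e.g. token embeddings, the final norm, and the output head.
    Optimisers.update!(opt_state, model, grads[1])

Whether `checkpointed = true` is worth the extra forward recomputation depends on sequence length and available memory; with `opt_state === nothing` the call reduces to the plain `model(tokens)` path.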