Merge pull request #20 from MurrellGroup/eager
Eager updates, and loss refactor
murrellb authored Dec 26, 2024
2 parents cb47d34 + e2bac27 commit 355af14
Showing 5 changed files with 86 additions and 45 deletions.
10 changes: 5 additions & 5 deletions Project.toml
@@ -4,7 +4,7 @@ authors = ["murrellb <[email protected]> and contributors"]
version = "1.1.0-DEV"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
HuggingFaceTokenizers = "a6888d44-1185-43bb-bd0f-7806f9976d18"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -23,19 +23,19 @@ HuggingFaceTokenizers = {rev = "main", url = "https://github.com/MurrellGroup/Hu
MetalExt = "Metal"

[compat]
Accessors = "0.1.38"
ChainRulesCore = "1.25.0"
Flux = "0.14, 0.15"
LogitSamplers = "0.1"
LowRankLayers = "0.1"
LogitSamplers = "0.1.0"
LowRankLayers = "0.1.2"
Metal = "1"
NNlib = "0.9"
SafeTensors = "1"
julia = "1.11"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "Downloads", "JSON3"]
4 changes: 4 additions & 0 deletions src/Jjama3.jl
@@ -6,6 +6,7 @@ using LinearAlgebra
using NNlib
using LogitSamplers
using LowRankLayers
#using ChainRulesCore

using HuggingFaceTokenizers: HuggingFaceTokenizers, Tokenizer

@@ -27,9 +28,12 @@ export rerope_cache!
export scrape_cache
export append_cache!

#include("sdpa.jl")

include("model.jl")
export forward_loss
export forward_inference
export loss

include("sampling.jl")
export top_pk_sampler
28 changes: 17 additions & 11 deletions src/layers.jl
@@ -101,16 +101,16 @@ function unrope(rope, x)
    )
end

struct Attention{Q<:AnyDense,K<:AnyDense,V<:AnyDense,O<:AnyDense,C<:KVCache}
    wq::Q
    wk::K
    wv::V
    wo::O
mutable struct Attention
    wq::AnyDense
    wk::AnyDense
    wv::AnyDense
    wo::AnyDense
    dim::Int
    n_heads::Int
    n_kv_heads::Int
    head_dim::Int
    cache::C
    cache::KVCache
end

Flux.@layer Attention trainable=(wq,wv)
@@ -132,6 +132,13 @@ end

repeat_kv(x::AbstractArray, n_rep::Int) = isone(n_rep) ? x : repeat(x, 1, n_rep, 1, 1)

function sdpa(xq::AbstractArray{T}, xk::AbstractArray{T}, xv::AbstractArray{T}, mask::AbstractArray{T}, head_dim::Int) where T
    scores = batched_mul(batched_transpose(xk), xq) / sqrt(T(head_dim))
    scores = scores .+ mask
    sm_scores = softmax(scores; dims=1)
    return batched_mul(xv, sm_scores)
end

function (attn::Attention)(x::AbstractArray{T}, start_pos::Integer, rope=nothing, mask=false) where T
    _, seqlen, batch = size(x)

@@ -161,11 +168,8 @@ function (attn::Attention)(x::AbstractArray{T}, start_pos::Integer, rope=nothing
    xk_for_attn = reshape(xk, attn.head_dim, :, attn.n_heads * batch)
    xv_for_attn = reshape(xv, attn.head_dim, :, attn.n_heads * batch)

    scores = batched_mul(batched_transpose(xk_for_attn), xq_for_attn) / sqrt(T(attn.head_dim))
    scores = scores .+ mask
    sm_scores = softmax(scores; dims=1)

    output = batched_mul(xv_for_attn, sm_scores)
    output = sdpa(xq_for_attn, xk_for_attn, xv_for_attn, mask, attn.head_dim)

    e_output = reshape(output, (attn.head_dim, seqlen, attn.n_heads, batch))
    p_output = permutedims(e_output, (1,3,2,4))
    r_output = reshape(p_output, (attn.n_heads * attn.head_dim, seqlen, batch))
@@ -174,6 +178,8 @@ function (attn::Attention)(x::AbstractArray{T}, start_pos::Integer, rope=nothing
end




struct TransformerBlock{A<:Attention,F<:FeedForward,AN<:RMSNorm,FN<:RMSNorm}
    attention::A
    feed_forward::F
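For reference, a minimal sketch of calling the new `sdpa` helper on its own (the toy shapes, random data, and causal mask below are assumptions, not part of the commit). Inputs follow the layout used by the attention forward pass, `(head_dim, seqlen, n_heads * batch)`, and `sdpa` is not exported, so it is qualified with the module name:

```julia
# Assumed toy example: exercising Jjama3.sdpa directly on random tensors.
using Jjama3

head_dim, seqlen, nheads_batch = 8, 4, 2
xq = randn(Float32, head_dim, seqlen, nheads_batch)  # queries
xk = randn(Float32, head_dim, seqlen, nheads_batch)  # keys
xv = randn(Float32, head_dim, seqlen, nheads_batch)  # values

# Additive causal mask over (key, query) positions: 0 where the key is visible,
# -Inf where it lies in the future; broadcast over the batch dimension.
mask = reshape(Float32[k <= q ? 0f0 : -Inf32 for k in 1:seqlen, q in 1:seqlen],
               seqlen, seqlen, 1)

out = Jjama3.sdpa(xq, xk, xv, mask, head_dim)
@assert size(out) == (head_dim, seqlen, nheads_batch)
```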
73 changes: 53 additions & 20 deletions src/model.jl
@@ -15,39 +15,71 @@ function create_mask(h::AbstractArray{T}; precached_size = 0) where T<:AbstractF
    end
end

function (model::Transformer)(tokens::AbstractArray{Int})

function masked_agg(ce, mask)
    if mask !== nothing
        ce = ce .* mask
    end
    return sum(ce)/sum(mask)
end

#Hoping this will wind up in Zygote.jl
"""
eager_update!(state, model, update!)
Updates params during the backward pass, saving memory.
f(model, xs...) = model(xs...)
h = f(Zygote.eager_update!(state.layers[i], model.layers[i], Optimisers.update!), h, other_args...)
"""
function eager_update!(state, model, update!)
    function update_hook(dmodel)
        update!(state, model, dmodel)
        return nothing
    end
    return Flux.Zygote.hook(update_hook, model)
end


wrap(model, xs...) = model(xs...)
function (model::Transformer)(tokens::AbstractArray{Int}, opt_state; clear_cache = false, checkpointed = false)
    if clear_cache
        Flux.ChainRulesCore.ignore_derivatives() do
            Jjama3.clear_cache!(model)
        end
    end
    h = model.tok_embeddings(tokens) # Embedding: (dim, seq_len, batch)
    rope = model.rope[model.pos+1:model.pos+size(tokens, 1)]
    if size(h, 2) == 1
        mask = create_mask(h)
        mask = Jjama3.create_mask(h)
    else
        mask = create_mask(h; precached_size = model.pos)
        mask = Jjama3.create_mask(h; precached_size = model.pos)
    end
    for layer in model.layers
        h = layer(h, model.pos, rope, mask)
    for i in 1:length(model.layers)
        if !isnothing(opt_state)
            if checkpointed
                h = Flux.Zygote.checkpointed(wrap, eager_update!(opt_state.layers[i], model.layers[i], Optimisers.update!), h, model.pos, rope, mask)
            else
                h = wrap(eager_update!(opt_state.layers[i], model.layers[i], Optimisers.update!), h, model.pos, rope, mask)
            end
        else
            if checkpointed
                h = Flux.Zygote.checkpointed(wrap, model.layers[i], h, model.pos, rope, mask)
            else
                h = model.layers[i](h, model.pos, rope, mask)
            end
        end
    end
    h = model.norm(h)
    output = model.output(h)
    model.pos += size(tokens, 1)
    return output
end

function masked_agg(ce, mask)
    if mask !== nothing
        ce = ce .* mask
    end
    return sum(ce)/sum(mask)
end
(model::Transformer)(tokens::AbstractArray{Int}; clear_cache = false, checkpointed = false) = model(tokens, nothing; clear_cache, checkpointed)

function forward_loss(model::Transformer, inputs::AbstractArray,
    targets::AbstractArray; clear_cache = true, loss_mask = nothing)
    if clear_cache
        Flux.ChainRulesCore.ignore_derivatives() do
            clear_cache!(model)
        end
    end
    logits = model(inputs)
    vocab_size = size(model.tok_embeddings.weight, 2)
function loss(logits, targets::AbstractArray; loss_mask = nothing)
    vocab_size = size(logits,1)
    gt = Flux.onehotbatch(targets, 1:vocab_size)
    if loss_mask !== nothing
        loss = Flux.logitcrossentropy(logits, gt, agg = x -> masked_agg(x, loss_mask))
@@ -59,3 +91,4 @@ end

# compat
forward_inference(model, args...) = model(args...)
forward_loss(model::Transformer, inputs::AbstractArray, targets::AbstractArray; clear_cache = true, loss_mask = nothing) = loss(model(inputs, clear_cache = clear_cache), targets, loss_mask = loss_mask)
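To show how the refactored pieces are meant to fit together, here is a hedged single-training-step sketch. Everything in the setup is an assumption: `model` stands for an already-loaded `Transformer` (e.g. from `load_llama3_from_safetensors`), `tokens` and `targets` for `(seq_len, batch)` integer arrays, and the AdamW hyperparameters are arbitrary.

```julia
# Hedged sketch of one training step using the eager-update path.
using Flux, Optimisers

opt_state = Optimisers.setup(Optimisers.AdamW(1f-5), model)

grads = Flux.gradient(model) do m
    # Passing opt_state into the forward call lets each transformer block's
    # weights be updated (and their gradients dropped) during the backward
    # pass; checkpointed = true additionally recomputes activations to save
    # memory.
    logits = m(tokens, opt_state; clear_cache = true, checkpointed = true)
    loss(logits, targets)
end

# Parameters outside the blocks (embeddings, final norm, output head) still
# carry ordinary gradients and can be updated afterwards.
Optimisers.update!(opt_state, model, grads[1])
```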
16 changes: 7 additions & 9 deletions src/utils.jl
@@ -1,5 +1,3 @@
using Accessors

encode(tkn::Tokenizer, str; kwargs...) = HuggingFaceTokenizers.encode(tkn, str; kwargs...).ids .+ 1
decode(tkn::Tokenizer, ids; kwargs...) = HuggingFaceTokenizers.decode(tkn, ids .- 1; kwargs...)

@@ -162,37 +160,37 @@ function load_llama3_from_safetensors(
    if !isempty(add_lora_to)
        if :Q in add_lora_to
            for layer in model.layers
                @reset layer.attention.wq = LoRADense(layer.attention.wq, lora_dim)
                layer.attention.wq = LoRADense(layer.attention.wq, lora_dim)
            end
        end
        if :K in add_lora_to
            for layer in model.layers
                @reset layer.attention.wk = LoRADense(layer.attention.wk, lora_dim)
                layer.attention.wk = LoRADense(layer.attention.wk, lora_dim)
            end
        end
        if :V in add_lora_to
            for layer in model.layers
                @reset layer.attention.wv = LoRADense(layer.attention.wv, lora_dim)
                layer.attention.wv = LoRADense(layer.attention.wv, lora_dim)
            end
        end
        if :O in add_lora_to
            for layer in model.layers
                @reset layer.attention.wo = LoRADense(layer.attention.wo, lora_dim)
                layer.attention.wo = LoRADense(layer.attention.wo, lora_dim)
            end
        end
        if :w1 in add_lora_to
            for layer in model.layers
                @reset layer.feed_forward.w1 = LoRADense(layer.feed_forward.w1, lora_dim)
                layer.feed_forward.w1 = LoRADense(layer.feed_forward.w1, lora_dim)
            end
        end
        if :w2 in add_lora_to
            for layer in model.layers
                @reset layer.feed_forward.w2 = LoRADense(layer.feed_forward.w2, lora_dim)
                layer.feed_forward.w2 = LoRADense(layer.feed_forward.w2, lora_dim)
            end
        end
        if :w3 in add_lora_to
            for layer in model.layers
                @reset layer.feed_forward.w3 = LoRADense(layer.feed_forward.w3, lora_dim)
                layer.feed_forward.w3 = LoRADense(layer.feed_forward.w3, lora_dim)
            end
        end
    end
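As a small illustration of what dropping `Accessors.@reset` enables (hypothetical usage; `model` and the rank of 8 are assumptions): with `Attention` now mutable, LoRA adapters can be attached to a loaded model by plain field assignment.

```julia
# Hypothetical sketch: attach LoRA adapters to the query and value projections
# of an already-loaded model by ordinary field assignment (rank 8 is arbitrary).
for layer in model.layers
    layer.attention.wq = LoRADense(layer.attention.wq, 8)
    layer.attention.wv = LoRADense(layer.attention.wv, 8)
end
```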
