diff --git a/src/Jjama3.jl b/src/Jjama3.jl
index 693c426..930cabb 100644
--- a/src/Jjama3.jl
+++ b/src/Jjama3.jl
@@ -28,6 +28,7 @@ export rerope_cache!
 include("model.jl")
 export forward_loss
 export forward_inference
+export loss
 
 include("sampling.jl")
 export top_pk_sampler
diff --git a/src/model.jl b/src/model.jl
index aaf8d62..ec8f96a 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -15,16 +15,60 @@ function create_mask(h::AbstractArray{T}; precached_size = 0) where T<:AbstractF
     end
 end
 
-function (model::Transformer)(tokens::AbstractArray{Int})
+
+function masked_agg(ce, mask)
+    if mask !== nothing
+        ce = ce .* mask
+    end
+    return sum(ce)/sum(mask)
+end
+
+# Hoping this will wind up in Zygote.jl
+"""
+    eager_update!(state, model, update!)
+
+Updates params during the backward pass, saving memory.
+
+    f(model, xs...) = model(xs...)
+    h = f(Zygote.eager_update!(state.layers[i], model.layers[i], Optimisers.update!), h, other_args...)
+"""
+function eager_update!(state, model, update!)
+    function update_hook(dmodel)
+        update!(state, model, dmodel)
+        return nothing
+    end
+    return Flux.Zygote.hook(update_hook, model)
+end
+
+
+wrap(model, xs...) = model(xs...)
+function (model::Transformer)(tokens::AbstractArray{Int}, opt_state; clear_cache = false, checkpointed = false)
+    if clear_cache
+        Flux.ChainRulesCore.ignore_derivatives() do
+            Jjama3.clear_cache!(model)
+        end
+    end
     h = model.tok_embeddings(tokens) # Embedding: (dim, seq_len, batch)
     rope = model.rope[model.pos+1:model.pos+size(tokens, 1)]
     if size(h, 2) == 1
-        mask = create_mask(h)
+        mask = Jjama3.create_mask(h)
     else
-        mask = create_mask(h; precached_size = model.pos)
+        mask = Jjama3.create_mask(h; precached_size = model.pos)
     end
-    for layer in model.layers
-        h = layer(h, model.pos, rope, mask)
+    for i in 1:length(model.layers)
+        if !isnothing(opt_state)
+            if checkpointed
+                h = Flux.Zygote.checkpointed(wrap, eager_update!(opt_state.layers[i], model.layers[i], Optimisers.update!), h, model.pos, rope, mask)
+            else
+                h = wrap(eager_update!(opt_state.layers[i], model.layers[i], Optimisers.update!), h, model.pos, rope, mask)
+            end
+        else
+            if checkpointed
+                h = Flux.Zygote.checkpointed(wrap, model.layers[i], h, model.pos, rope, mask)
+            else
+                h = model.layers[i](h, model.pos, rope, mask)
+            end
+        end
     end
     h = model.norm(h)
     output = model.output(h)
@@ -32,22 +76,10 @@ function (model::Transformer)(tokens::AbstractArray{Int})
     return output
 end
 
-function masked_agg(ce, mask)
-    if mask !== nothing
-        ce = ce .* mask
-    end
-    return sum(ce)/sum(mask)
-end
+(model::Transformer)(tokens::AbstractArray{Int}; clear_cache = false, checkpointed = false) = model(tokens, nothing; clear_cache, checkpointed)
 
-function forward_loss(model::Transformer, inputs::AbstractArray,
-                      targets::AbstractArray; clear_cache = true, loss_mask = nothing)
-    if clear_cache
-        Flux.ChainRulesCore.ignore_derivatives() do
-            clear_cache!(model)
-        end
-    end
-    logits = model(inputs)
-    vocab_size = size(model.tok_embeddings.weight, 2)
+function loss(logits, targets::AbstractArray; loss_mask = nothing)
+    vocab_size = size(logits, 1)
     gt = Flux.onehotbatch(targets, 1:vocab_size)
     if loss_mask !== nothing
         loss = Flux.logitcrossentropy(logits, gt, agg = x -> masked_agg(x, loss_mask))
@@ -59,3 +91,4 @@ end
 
 # compat
 forward_inference(model, args...) = model(args...)
+forward_loss(model::Transformer, inputs::AbstractArray, targets::AbstractArray; clear_cache = true, loss_mask = nothing) = loss(model(inputs; clear_cache = clear_cache), targets; loss_mask = loss_mask)
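
Review note (not part of the diff): a rough sketch of how the new entry points compose, assuming a loaded model::Transformer, Int token matrices inputs/targets of shape (seq_len, batch), and an Optimisers.jl state. The variable names and the AdamW settings below are placeholders for illustration, not taken from the repo.

using Flux, Optimisers, Jjama3

opt_state = Optimisers.setup(Optimisers.AdamW(1f-4), model)

# Standard path: call the model, feed the logits to the new standalone `loss`,
# then take one ordinary optimiser step.
grads = Flux.gradient(model) do m
    loss(m(inputs; clear_cache = true), targets)
end
Optimisers.update!(opt_state, model, grads[1])

# Memory-saving path: passing opt_state into the forward call makes each
# transformer block update eagerly during the backward pass (its gradient
# subtree comes back as `nothing`), and checkpointed = true additionally
# recomputes per-layer activations instead of storing them. The embeddings,
# norm, and output head still return ordinary gradients, so they presumably
# still get a normal update! afterwards.
grads = Flux.gradient(model) do m
    loss(m(inputs, opt_state; clear_cache = true, checkpointed = true), targets)
end
Optimisers.update!(opt_state, model, grads[1])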