Eager updates, and loss refactor #20

Status: Open. Wants to merge 1 commit into base: main.
1 change: 1 addition & 0 deletions src/Jjama3.jl
@@ -28,6 +28,7 @@ export rerope_cache!
include("model.jl")
export forward_loss
export forward_inference
export loss

include("sampling.jl")
export top_pk_sampler
Expand Down
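The newly exported loss (defined in src/model.jl below) scores precomputed logits directly instead of coupling the loss computation to a forward pass. A minimal sketch under assumed shapes, with a hypothetical 128-token vocabulary and 16 positions; the only real constraint on loss_mask is that it broadcasts against the per-position cross-entropies:

    using Jjama3, Flux

    logits = randn(Float32, 128, 16)        # (vocab, seq_len), hypothetical shapes
    targets = rand(1:128, 16)

    loss(logits, targets)                   # mean cross-entropy over all positions

    mask = reshape(vcat(zeros(Float32, 8), ones(Float32, 8)), 1, 16)
    loss(logits, targets; loss_mask = mask) # average only over the last 8 positions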
73 changes: 53 additions & 20 deletions src/model.jl

@@ -15,39 +15,71 @@
     end
 end

-function (model::Transformer)(tokens::AbstractArray{Int})
+function masked_agg(ce, mask)
+    if mask !== nothing
+        ce = ce .* mask
+    end
+    return sum(ce)/sum(mask)
+end
+
+#Hoping this will wind up in Zygote.jl
+"""
+    eager_update!(state, model, update!)
+
+Updates params during the backward pass, saving memory.
+
+    f(model, xs...) = model(xs...)
+    h = f(Zygote.eager_update!(state.layers[i], model.layers[i], Optimisers.update!), h, other_args...)
+"""
+function eager_update!(state, model, update!)
+    function update_hook(dmodel)
+        update!(state, model, dmodel)
+        return nothing
+    end
+    return Flux.Zygote.hook(update_hook, model)
+end
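The mechanism here is Zygote.hook(f, x), which returns x unchanged on the forward pass and applies f to x's gradient during the backward pass. update_hook therefore runs the optimiser step for a layer the moment its gradient is produced, and returns nothing so that gradient does not have to be kept alive. A toy illustration of the same idea on a single Dense layer (hypothetical example, not code from this PR):

    using Flux, Optimisers

    layer = Dense(4 => 4)
    opt_state = Optimisers.setup(Optimisers.Adam(1f-3), layer)
    x = randn(Float32, 4, 8)

    g = Flux.gradient(layer) do m
        # When the pullback reaches m, step the optimiser in place and
        # replace the gradient with nothing so it is freed immediately.
        hooked = Flux.Zygote.hook(dm -> (Optimisers.update!(opt_state, m, dm); nothing), m)
        sum(abs2, hooked(x))
    end
    # layer has been stepped by Adam; g[1] is nothing.

Zygote.checkpointed(f, args...) composes with this below: it avoids storing f's intermediates on the forward pass and recomputes them during the pullback, trading compute for memory.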


+wrap(model, xs...) = model(xs...)
+function (model::Transformer)(tokens::AbstractArray{Int}, opt_state; clear_cache = false, checkpointed = false)
+    if clear_cache
+        Flux.ChainRulesCore.ignore_derivatives() do
+            Jjama3.clear_cache!(model)
+        end
+    end
     h = model.tok_embeddings(tokens) # Embedding: (dim, seq_len, batch)
     rope = model.rope[model.pos+1:model.pos+size(tokens, 1)]
     if size(h, 2) == 1
-        mask = create_mask(h)
+        mask = Jjama3.create_mask(h)
     else
-        mask = create_mask(h; precached_size = model.pos)
+        mask = Jjama3.create_mask(h; precached_size = model.pos)
     end
-    for layer in model.layers
-        h = layer(h, model.pos, rope, mask)
+    for i in 1:length(model.layers)
+        if !isnothing(opt_state)
+            if checkpointed
+                h = Flux.Zygote.checkpointed(wrap, eager_update!(opt_state.layers[i], model.layers[i], Optimisers.update!), h, model.pos, rope, mask)
+            else
+                h = wrap(eager_update!(opt_state.layers[i], model.layers[i], Optimisers.update!), h, model.pos, rope, mask)
+            end
+        else
+            if checkpointed
+                h = Flux.Zygote.checkpointed(wrap, model.layers[i], h, model.pos, rope, mask)
+            else
+                h = model.layers[i](h, model.pos, rope, mask)
+            end
+        end
     end
     h = model.norm(h)
     output = model.output(h)
     model.pos += size(tokens, 1)
     return output
 end

-function masked_agg(ce, mask)
-    if mask !== nothing
-        ce = ce .* mask
-    end
-    return sum(ce)/sum(mask)
-end
+(model::Transformer)(tokens::AbstractArray{Int}; clear_cache = false, checkpointed = false) = model(tokens, nothing; clear_cache, checkpointed)

-function forward_loss(model::Transformer, inputs::AbstractArray,
-                      targets::AbstractArray; clear_cache = true, loss_mask = nothing)
-    if clear_cache
-        Flux.ChainRulesCore.ignore_derivatives() do
-            clear_cache!(model)
-        end
-    end
-    logits = model(inputs)
-    vocab_size = size(model.tok_embeddings.weight, 2)
+function loss(logits, targets::AbstractArray; loss_mask = nothing)
+    vocab_size = size(logits,1)
     gt = Flux.onehotbatch(targets, 1:vocab_size)
     if loss_mask !== nothing
         loss = Flux.logitcrossentropy(logits, gt, agg = x -> masked_agg(x, loss_mask))
@@ -59,3 +91,4 @@

 # compat
 forward_inference(model, args...) = model(args...)
+forward_loss(model::Transformer, inputs::AbstractArray, targets::AbstractArray; clear_cache = true, loss_mask = nothing) = loss(model(inputs; clear_cache = clear_cache), targets; loss_mask = loss_mask)
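Taken together, forward_loss is now a one-line compat shim over the cache-aware forward call plus the standalone loss. A hedged sketch of a training step driving the new API; this usage pattern is an assumption, not something the PR prescribes, and it presumes model is a Transformer, opt_state comes from Optimisers.setup, and tokens and targets are Int arrays shaped (seq_len, batch):

    using Jjama3, Flux, Optimisers

    opt_state = Optimisers.setup(Optimisers.AdamW(1f-4), model)

    grads = Flux.gradient(model) do m
        # Each transformer block is updated eagerly inside the pullback and
        # its gradient replaced by nothing; checkpointed = true recomputes
        # block activations instead of storing them.
        logits = m(tokens, opt_state; clear_cache = true, checkpointed = true)
        loss(logits, targets)
    end

    # The eagerly-updated blocks come back as nothing and are skipped here;
    # this applies the step to the remaining params (embeddings, norm, output).
    Optimisers.update!(opt_state, model, grads[1])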
