Skip to content

Commit

Permalink
Merge pull request #19 from MurrellGroup/cache-tricks
Browse files Browse the repository at this point in the history
Cache tricks.
  • Loading branch information
murrellb authored Dec 24, 2024
2 parents c199388 + e9fc429 commit cb47d34
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/Jjama3.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ export TransformerBlock
export Transformer
export unrope
export rerope_cache!
export scrape_cache
export append_cache!

include("model.jl")
export forward_loss
Expand Down
20 changes: 20 additions & 0 deletions src/layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -251,4 +251,24 @@ function rerope_cache!(model, newstart, rope_theta; range = 1:model.pos)
unroped = unrope(oldrope, l.attention.cache.cache_k[:,range,:,:])
l.attention.cache.cache_k[:,range,:,:] .= newrope(unroped)
end
end

function scrape_cache(model::Transformer)
cache = (k = [], v = [])
for l in model.layers
push!(cache.k, copy(l.attention.cache.cache_k[:,1:model.pos,:,:]))
push!(cache.v, copy(l.attention.cache.cache_v[:,1:model.pos,:,:]))
end
return cache
end

function append_cache!(model, cache)
if model.pos + size(cache.k[1], 2) > size(model.layers[1].attention.cache.cache_k, 2)
extend_cache!(model, model.pos + size(cache.k[1], 2))
end
for (i, l) in enumerate(model.layers)
l.attention.cache.cache_k[:, model.pos+1:model.pos+size(cache.k[i], 2), :, :] .= cache.k[i]
l.attention.cache.cache_v[:, model.pos+1:model.pos+size(cache.v[i], 2), :, :] .= cache.v[i]
end
model.pos = model.pos + size(cache.k[1], 2)
end

0 comments on commit cb47d34

Please sign in to comment.