Reactant compatibility (#15)
* Reactant compatibility

* Update Project.toml

* Refactor

* Refactor and fixes

* rm ReactantCore

* Fixes, conditional caching
AntonOresten authored Dec 4, 2024
1 parent 0b50d41 commit 1f44b58
Showing 9 changed files with 276 additions and 356 deletions.
5 changes: 3 additions & 2 deletions Project.toml
@@ -1,9 +1,10 @@
name = "Jjama3"
uuid = "1285d783-1a6d-4703-8f05-8ac83ef55592"
authors = ["murrellb <[email protected]> and contributors"]
version = "1.0.0-DEV"
version = "1.1.0-DEV"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
HuggingFaceTokenizers = "a6888d44-1185-43bb-bd0f-7806f9976d18"
@@ -19,13 +20,13 @@ Metal = "dde4c033-4e86-420c-a63e-0dd931031962"

[sources]
HuggingFaceTokenizers = {rev = "main", url = "https://github.com/MurrellGroup/HuggingFaceTokenizers.jl"}
LogitSamplers = {rev = "main", url = "https://github.com/MurrellGroup/LogitSamplers.jl"}
LowRankLayers = {rev = "main", url = "https://github.com/MurrellGroup/LowRankLayers.jl"}

[extensions]
MetalExt = "Metal"

[compat]
Accessors = "0.1.38"
Distributions = "0.25"
Flux = "0.14"
LogitSamplers = "0.1"
2 changes: 1 addition & 1 deletion README.md
@@ -45,7 +45,7 @@ generate(model, prompt,
- RoPE scaling (for exceeding the model's max training-time context length) is implemented, but likely incorrect with KV cache. Be careful if you're using it with really long sequences.
- Imported models are trainable (with Flux), including low-rank (i.e. LoRA) finetuning.
- Sampling, training, etc. are compatible with CUDA, where everything is much faster.
-- Metal acceleration for forward_inference, forward_loss, and sampling. Gradients (with Zygote) fail. Sampling works, but is slower with Metal than with CPU.
+- Metal acceleration for forward inference, forward loss, and sampling. Gradients (with Zygote) fail. Sampling works, but is slower with Metal than with CPU.


## Samplers
6 changes: 4 additions & 2 deletions ext/MetalExt.jl
@@ -1,7 +1,9 @@
module MetalExt

-#Note: Metal speeds things up a little for forward_inference and forward_loss calls, but is VERY slow for sampling.
-#It seems that each single Metal call has some constant overhead that kills it.
+# See https://github.com/FluxML/NNlib.jl/pull/614
+
+# Note: Metal speeds things up a little for forward inference and forward_loss calls, but is VERY slow for sampling.
+# It seems that each single Metal call has some constant overhead that kills it.
using Metal, NNlib

69 changes: 39 additions & 30 deletions src/Jjama3.jl
@@ -1,44 +1,53 @@
module Jjama3

-using Flux, SafeTensors, Distributions, LinearAlgebra, StatsBase, NNlib
-using LogitSamplers, LowRankLayers
-import HuggingFaceTokenizers
+using Flux
+using SafeTensors
+using Distributions
+using LinearAlgebra
+using StatsBase
+using NNlib
+using LogitSamplers
+using LowRankLayers
+
+using HuggingFaceTokenizers: HuggingFaceTokenizers, Tokenizer

const tokenizer_from_repo = HuggingFaceTokenizers.from_pretrained
const tokenizer_from_file = HuggingFaceTokenizers.from_file
-const Tokenizer = HuggingFaceTokenizers.Tokenizer

const top_pk_sampler = LogitSamplers.top_pk_sampler
const argmax_sampler = LogitSamplers.argmax_sampler
const min_p_sampler = LogitSamplers.min_p_sampler
const top_nσ_sampler = LogitSamplers.top_nσ_sampler


include("cache.jl")
export KVCache

include("layers.jl")
export FeedForward
export RMSNorm
export RoPE
export Attention
export TransformerBlock
export Transformer

include("model.jl")
include("utils.jl")
export forward_loss
export forward_inference

include("sampling.jl")
export top_pk_sampler
export argmax_sampler
export top_nσ_sampler
export min_p_sampler
export generate
export tokenizer_from_repo
export tokenizer_from_file
export Tokenizer

-export load_llama321B_from_safetensors,
-    load_llama3_from_safetensors,
-    generate,
-    forward_loss,
-    forward_inference,
-    top_pk_sampler,
-    argmax_sampler,
-    top_nσ_sampler,
-    min_p_sampler,
-    tokenizer_from_repo,
-    tokenizer_from_file,
-    encode,
-    decode,
-    Tokenizer,
-    llama3_instruct_prompt,
-    llama3_assistant_prompt,
-    smollm2_instruct_prompt,
-    smollm2_assistant_prompt,
-    structured_choice
-include("utils.jl")
+export encode
+export decode
+export load_llama321B_from_safetensors
+export load_llama3_from_safetensors
+export llama3_instruct_prompt
+export llama3_assistant_prompt
+export smollm2_instruct_prompt
+export smollm2_assistant_prompt
+export structured_choice

end
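For reference, the exported API after this refactor composes along these lines. This is a sketch only: the weights path, tokenizer repo id, prompt-helper arguments, and sampler keywords are assumptions for illustration; only the generate(model, prompt, ...) shape is taken from the README.

using Jjama3

# Hypothetical path and repo id; substitute real Llama 3.2 weights and tokenizer.
model = load_llama321B_from_safetensors("path/to/model.safetensors")
tkn = tokenizer_from_repo("meta-llama/Llama-3.2-1B-Instruct")

# Build an instruct-formatted prompt and sample with one of the exported samplers
# (argument shapes here are guesses for illustration).
prompt = llama3_instruct_prompt(tkn, "You are a helpful assistant.", "Hello!")
generate(model, prompt, sampler=top_pk_sampler(p=0.5f0, k=5))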
37 changes: 37 additions & 0 deletions src/cache.jl
@@ -0,0 +1,37 @@
# Keys and values are stored as (head_dim, seq_length, n_kv_heads, batch_size) arrays.
mutable struct KVCache{T,A<:AbstractArray{T,4}}
    cache_k::A
    cache_v::A
end

Flux.@layer KVCache

head_dim(cache::KVCache) = size(cache.cache_k, 1)
seq_length(cache::KVCache) = size(cache.cache_k, 2)
n_kv_heads(cache::KVCache) = size(cache.cache_k, 3)
batch_size(cache::KVCache) = size(cache.cache_k, 4)

function KVCache(T; head_dim, seq_length=0, n_kv_heads, batch_size=1)
    cache_k = zeros(T, head_dim, seq_length, n_kv_heads, batch_size)
    cache_v = zeros(T, head_dim, seq_length, n_kv_heads, batch_size)
    return KVCache(cache_k, cache_v)
end

# Reallocate the cache to a new sequence length and/or batch size, zero-filled.
function config!(cache::KVCache; seq_length=seq_length(cache), batch_size=batch_size(cache))
    cache.cache_k = similar(cache.cache_k, head_dim(cache), seq_length, n_kv_heads(cache), batch_size) .= 0
    cache.cache_v = similar(cache.cache_v, head_dim(cache), seq_length, n_kv_heads(cache), batch_size) .= 0
end

clear!(cache::KVCache) = config!(cache, seq_length=0)

function update!(cache::KVCache, start_pos::Int, xk::AbstractArray, xv::AbstractArray)
    if iszero(seq_length(cache))
        # Caching disabled (zero-length cache): pass the new keys/values through unchanged.
        return xk, xv
    else
        # Write the new keys/values at start_pos, then return everything cached so far.
        seqlen = size(xk, 2)
        cache.cache_k[:, start_pos+1:start_pos+seqlen, :, :] .= xk
        cache.cache_v[:, start_pos+1:start_pos+seqlen, :, :] .= xv
        return cache.cache_k[:, 1:start_pos+seqlen, :, :],
               cache.cache_v[:, 1:start_pos+seqlen, :, :]
    end
end
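For orientation, here is a minimal usage sketch of the KVCache API added above. The dimensions, element type, and data are illustrative only; note that update!, config!, and clear! are not in the module's export list, so qualify them (e.g. Jjama3.update!) from outside the package.

# Allocate a Float32 cache: 64-dim heads, 8 KV heads, capacity for 512 positions.
cache = KVCache(Float32; head_dim=64, seq_length=512, n_kv_heads=8, batch_size=1)

# Cache keys/values for 5 new tokens starting at position 0; the call returns
# everything cached so far (positions 1:5 here).
xk = rand(Float32, 64, 5, 8, 1)
xv = rand(Float32, 64, 5, 8, 1)
k, v = update!(cache, 0, xk, xv)

# Reallocate for a longer context, or shrink to zero length to disable caching,
# in which case update! passes new keys/values through unchanged.
config!(cache; seq_length=1024)
clear!(cache)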
