diff --git a/.gitignore b/.gitignore index 5decb6a..0a51603 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ /docs/Manifest.toml /docs/build/ /Manifest.toml +.CondaPkg \ No newline at end of file diff --git a/Project.toml b/Project.toml index f61ad42..d169abf 100644 --- a/Project.toml +++ b/Project.toml @@ -7,7 +7,10 @@ version = "1.0.0-DEV" BytePairEncoding = "a4280ba5-8788-555a-8ca8-4a8c3d966a71" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +HuggingFaceTokenizers = "a6888d44-1185-43bb-bd0f-7806f9976d18" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LogitSamplers = "1b30fcfc-0ee9-4be2-9cfe-b2289b43e041" +LowRankLayers = "b66182ab-a85c-43b0-99bd-d85cc47c5e50" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" SafeTensors = "eeda0dda-7046-4914-a807-2495fc7abb89" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" @@ -15,6 +18,11 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [weakdeps] Metal = "dde4c033-4e86-420c-a63e-0dd931031962" +[sources] +HuggingFaceTokenizers = {rev = "main", url = "https://github.com/MurrellGroup/HuggingFaceTokenizers.jl"} +LogitSamplers = {rev = "main", url = "https://github.com/MurrellGroup/LogitSamplers.jl"} +LowRankLayers = {rev = "main", url = "https://github.com/MurrellGroup/LowRankLayers.jl"} + [extensions] MetalExt = "Metal" @@ -22,6 +30,7 @@ MetalExt = "Metal" BytePairEncoding = "0.5" Distributions = "0.25" Flux = "0.14" +LowRankLayers = "1.0.0" Metal = "1" NNlib = "0.9" SafeTensors = "1" diff --git a/README.md b/README.md index a6ccff1..e547dc3 100644 --- a/README.md +++ b/README.md @@ -1,32 +1,136 @@ # Jjama3 - Hackable Llama3.1 and Llama3.2 (text) in Julia -[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://MurrellGroup.github.io/Jjama3.jl/stable/) [![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://MurrellGroup.github.io/Jjama3.jl/dev/) [![Build Status](https://github.com/MurrellGroup/Jjama3.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/MurrellGroup/Jjama3.jl/actions/workflows/CI.yml?query=branch%3Amain) [![Coverage](https://codecov.io/gh/MurrellGroup/Jjama3.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/MurrellGroup/Jjama3.jl) -# Installation +## Installation + +We've split this into a few (unregistered) packages, so you'll need to add them all, and you need JSON3 for loading the configs: ```julia +] add JSON3 +] add https://github.com/MurrellGroup/HuggingFaceTokenizers.jl +] add https://github.com/MurrellGroup/LowRankLayers.jl +] add https://github.com/MurrellGroup/LogitSamplers.jl ] add https://github.com/MurrellGroup/Jjama3.jl ``` -# Quickstart +## Quickstart -Download a Llama3 model `config.json` and safetensor weights from Huggingface. Eg. [Llama-3.2-1B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct). You might need access permissions for this. Note: Huggingface use a different RoPE convention to the original Meta implementation, and their weights have been permuted. This package works with the Huggingface convention, so if you load from the original weights you'll need to permute them. +Download a Llama3 model `config.json`, `tokenizer.json`, and model safetensor weights from Hugging Face. Eg. [SmolLM2-360M-Instruct](https://huggingface.co/HuggingFaceTB/SmolLM2-360M-Instruct/tree/main). Note: Hugging Face Llama3 models use a different RoPE convention to the original Meta implementation, and their weights have been permuted. 
This package works with the Huggingface convention, so if you load from the original Meta-Llama weights from a different source you'll need to do something horrible. ```julia -config = JSON3.read(read("Llama3_2_1B_instruct/config.json", String)); -model = load_llama3_from_safetensors("Llama3_2_1B_instruct/model.safetensors", config); -tkn = llama3_tokenizer(); -prompt = assistant_prompt("Why would anyone implement the llama3 LLM in Julia?", tkn); -ts = generate(model, prompt, max_new_tokens=500, encoder_for_printing=tkn); +using JSON3, Jjama3 + +config = JSON3.read(read("SmolLM2-360M-Instruct/config.json", String)) +model = load_llama3_from_safetensors("SmolLM2-360M-Instruct/model.safetensors", config) +tkn = tokenizer_from_file(Tokenizer, "SmolLM2-360M-Instruct/tokenizer.json") + +prompt = smollm2_assistant_prompt(tkn,"Tell me the two worst things about Python.") +generate(model, prompt, + max_new_tokens=500, + tokenizer_for_printing=tkn, + end_token = encode(tkn, "<|im_end|>")[end]); ``` -# Capability +## Capability - Seems to generate reasonable text from Llama3.1 and Llama3.2 models, loaded from Huggingface safetensors. - Sampling accelerated with KV caching, with argmax and top-p sampling supported. - Gradients seem to work on CPU, using Flux and Zygote. Untested on GPU. - Sampling (and forward passes) work with CUDA, where everything is much faster. Gradients untested. -- Metal acceleration for forward_inference and forward_loss. Gradients untested. Sampling works, but is much slower with Metal than with CPU. +- Metal acceleration for forward_inference and forward_loss. Gradients untested. Sampling works, but is slower with Metal than with CPU. + + +## Samplers + +The transformer emits "logits" which control the probability of the next token. A sampler takes these logits and converts them into a probability distribution over the vocabulary, and then samples from this distribution. There are [a few samplers available](https://github.com/MurrellGroup/LogitSamplers.jl), including argmax, top-p, top-k, min-p, and top-nσ. These can substantially affect the output of the model. + +```julia +prompt = smollm2_assistant_prompt(tkn,"Tell me the two worst things about Python."); + +generate(model, prompt, + max_new_tokens=500, + tokenizer_for_printing=tkn, + end_token = encode(tkn, "<|im_end|>")[end], + sampler = top_nσ_sampler()); +``` + +## Structured Sampling + +You can pass in a custom sampler that places additional constraints on the sampling process. As an example, `structured_choice` is a sampler that always selects from a set of predefined options: + +```julia +question = "In a Bayesian model, what do we call the probability distribution of parameters given the data?" +choices = ["Prior", + "Likelihood", + "Marginal Likelihood", + "Evidence", + "Posterior"] + +vocab = [decode(tkn, [i], skip_special_tokens = false) for i in 1:size(model.output.weight,1)] +eos = encode(tkn, "<|im_end|>")[end] + +sysprompt = "You are an expert in Statistics and Probability Theory who answers questions in as few words as possible." +prompt = smollm2_instruct_prompt(tkn, sysprompt, question) + +generate(model, prompt, + max_new_tokens=100, + tokenizer_for_printing=tkn, + end_token = eos, + sampler = structured_choice(choices, vocab, eos)); +``` + +This strategy can be extended to force the model outputs to follow specific formats. + +## Finetuning + +Often we want to adjust model parameters to better fit our specific use case, by further training the model on a new dataset. 
This can be done on all the model weights, but we also provide low-rank (via LoRA) finetuning. + +```julia +using Jjama3, JSON3, Flux + +config = JSON3.read(read("SmolLM2-360M-Instruct/config.json", String)) +tkn = tokenizer_from_file(Tokenizer, "SmolLM2-360M-Instruct/tokenizer.json") +eos = encode(tkn, "<|im_end|>")[end] + +#Add LoRA to Q and V matrices when loading the model +model = load_llama3_from_safetensors("SmolLM2-360M-Instruct/model.safetensors", config, + add_lora_to = [:Q, :V], lora_dim = 64) + +#See how the model answers before finetuning +prompt = smollm2_assistant_prompt(tkn, "What language is the best for deep learning?"); +generate(model, prompt, max_new_tokens=50, tokenizer_for_printing=tkn, end_token = eos); + +#Set up a single, very silly, training example to finetune on +ugh = "Ugh, bruh, what a stupid question.<|im_end|>" +trainsample = decode(tkn, prompt, skip_special_tokens = false) * ugh; +train_toks = encode(tkn, trainsample); + +#Set up the optimizer +opt_state = Flux.setup(AdamW(0.001f0), model); + +#Train for 5 steps, monitoring the model's output as it tunes +for i in 1:5 + grads = Flux.gradient(model) do m + forward_loss(m, train_toks[1:end-1,:], train_toks[2:end,:]) + end + Flux.update!(opt_state, model, grads[1]) + println(i) + generate(model, prompt, + max_new_tokens=50, + tokenizer_for_printing=tkn, + end_token = eos) + println() +end + +#Ask the model an unrelated question to see how stupid we've made the model. Try this a few times. +prompt = smollm2_assistant_prompt(tkn, "Explain how tides work?"); +generate(model, prompt, + max_new_tokens=500, + tokenizer_for_printing=tkn, + end_token = eos, + sampler = top_nσ_sampler()); +``` + diff --git a/ext/MetalExt.jl b/ext/MetalExt.jl index 8659a47..5383601 100644 --- a/ext/MetalExt.jl +++ b/ext/MetalExt.jl @@ -3,7 +3,7 @@ module MetalExt #Note: Metal speeds things up a little for forward_inference and forward_loss calls, but is VERY slow for sampling. #It seems that each single Metal call has some constant overhead that kills it. -using Metal, Jjama3.NNlib +using Metal, NNlib function NNlib.batched_mul(a::MtlArray, b::MtlArray) a_shape = size(a) @@ -15,4 +15,13 @@ function NNlib.batched_mul(a::MtlArray, b::MtlArray) return reshape(res, a_shape[1], b_shape[2], a_shape[3:end]...) 
end +function NNlib.PermutedDimsArray(a::MtlArray, perm) + return permutedims(a, perm) +end + +function NNlib.batched_transpose(a::MtlArray) + dims = size(a) + return permutedims(a, (2,1,3:length(dims)...)) +end + end diff --git a/src/Jjama3.jl b/src/Jjama3.jl index ac0cf52..f0db5b6 100644 --- a/src/Jjama3.jl +++ b/src/Jjama3.jl @@ -1,11 +1,44 @@ module Jjama3 -using Flux, BytePairEncoding, SafeTensors, Distributions, LinearAlgebra, StatsBase, NNlib +using Flux, SafeTensors, Distributions, LinearAlgebra, StatsBase, NNlib +using LogitSamplers, LowRankLayers +import HuggingFaceTokenizers + +const tokenizer_from_repo = HuggingFaceTokenizers.from_pretrained +const tokenizer_from_file = HuggingFaceTokenizers.from_file +const Tokenizer = HuggingFaceTokenizers.Tokenizer + +const top_pk_sampler = LogitSamplers.top_pk_sampler +const argmax_sampler = LogitSamplers.argmax_sampler +const min_p_sampler = LogitSamplers.min_p_sampler +const top_nσ_sampler = LogitSamplers.top_nσ_sampler + + + +include("layers.jl") include("model.jl") include("utils.jl") include("sampling.jl") -export load_llama321B_from_safetensors, load_llama3_from_safetensors, llama3_tokenizer, assistant_prompt, format_llama32_instruction_prompt, generate, forward_loss, forward_inference, top_pk_sampler, argmax_sampler +export load_llama321B_from_safetensors, + load_llama3_from_safetensors, + generate, + forward_loss, + forward_inference, + top_pk_sampler, + argmax_sampler, + top_nσ_sampler, + min_p_sampler, + tokenizer_from_repo, + tokenizer_from_file, + encode, + decode, + Tokenizer, + llama3_instruct_prompt, + llama3_assistant_prompt, + smollm2_instruct_prompt, + smollm2_assistant_prompt, + structured_choice end diff --git a/src/layers.jl b/src/layers.jl new file mode 100644 index 0000000..130d4b5 --- /dev/null +++ b/src/layers.jl @@ -0,0 +1,191 @@ +struct KVCache{T} + cache_k::AbstractArray{T, 4} # (head_dim, seq_len, n_kv_heads, batch) + cache_v::AbstractArray{T, 4} +end + +function KVCache(T, batch_size::Int, seq_length::Int, n_kv_heads::Int, head_dim::Int; device = identity) + cache_k = zeros(T, head_dim, seq_length, n_kv_heads, batch_size) |> device + cache_v = zeros(T, head_dim, seq_length, n_kv_heads, batch_size) |> device + KVCache(cache_k, cache_v) +end + +struct FeedForward + w1::Union{Dense, LoRADense} + w2::Union{Dense, LoRADense} + w3::Union{Dense, LoRADense} +end + +function FeedForward(dim::Int, ff_hidden_dim::Int) + FeedForward( + Dense(dim => ff_hidden_dim, bias=false), + Dense(ff_hidden_dim => dim, bias=false), + Dense(dim => ff_hidden_dim, bias=false) + ) +end + +function (ff::FeedForward)(x) + return ff.w2(Flux.swish(ff.w1(x)) .* ff.w3(x)) +end + +Flux.@layer :expand FeedForward + +struct RMSNorm{T} + weight::AbstractVector{T} + eps::T +end + +function RMSNorm(dim::Int; eps::T=1f-5) where T + RMSNorm{T}(ones(T, dim), eps) +end + +function (norm::RMSNorm)(x) + rms = sqrt.(sum(abs2.(x), dims=1) ./ size(x,1) .+ norm.eps) + return x .* (norm.weight ./ rms) +end + +Flux.@layer RMSNorm + +mutable struct Attention + wq::Union{Dense, LoRADense} + wk::Union{Dense, LoRADense} + wv::Union{Dense, LoRADense} + wo::Union{Dense, LoRADense} + n_heads::Int + n_kv_heads::Int + head_dim::Int + n_rep::Int + cache::Union{Nothing, KVCache} +end + +function Attention(dim::Int, n_heads::Int, n_kv_heads=n_heads) + head_dim = dim ÷ n_heads + n_rep = n_heads ÷ n_kv_heads + Attention( + Dense(dim => n_heads * head_dim, bias=false), + Dense(dim => n_kv_heads * head_dim, bias=false), + Dense(dim => n_kv_heads * head_dim, bias=false), + 
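# wo: projects the concatenated heads (n_heads * head_dim) back to the model dimension +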
Dense(n_heads * head_dim => dim, bias=false), + n_heads, + n_kv_heads, + head_dim, + n_rep, + nothing + ) +end + +function (attn::Attention)(x::AbstractArray{T}, start_pos::Int, freqs_cis, mask=nothing) where T + dim, seqlen, batch = size(x) + + xq = attn.wq(x) + xk = attn.wk(x) + xv = attn.wv(x) + + xq = reshape(xq, (attn.head_dim, attn.n_heads, seqlen, batch)) + xk = reshape(xk, (attn.head_dim, attn.n_kv_heads, seqlen, batch)) + xv = reshape(xv, (attn.head_dim, attn.n_kv_heads, seqlen, batch)) + + #Lazy permute dims. Need to test CUDA. + xq = PermutedDimsArray(xq, (1,3,2,4)) + xk = PermutedDimsArray(xk, (1,3,2,4)) + xv = PermutedDimsArray(xv, (1,3,2,4)) + + xq_rope = apply_rotary_emb(xq, freqs_cis) + xk_rope = apply_rotary_emb(xk, freqs_cis) + + if !isnothing(attn.cache) + xk_rope, xv = update_kv_cache(attn.cache, start_pos, xk_rope, xv) + end + + xk_rope = repeat_kv(xk_rope, attn.n_rep) + xv = repeat_kv(xv, attn.n_rep) + + xq_for_attn = reshape(xq_rope, attn.head_dim, :, attn.n_heads * batch) + xk_for_attn = reshape(xk_rope, attn.head_dim, :, attn.n_heads * batch) + xv_for_attn = reshape(xv, attn.head_dim, :, attn.n_heads * batch) + + #= + scores = batched_mul( + permutedims(xq_for_attn, (2,1,3)), # (seqlen, head_dim, batch*heads) + #batched_transpose(xq_for_attn), # (seqlen, head_dim, batch*heads) + xk_for_attn # (head_dim, seqlen, batch*heads) + ) ./ sqrt(T(attn.head_dim)) + if !isnothing(mask) + scores = scores .+ mask + end + sm_scores = softmax(scores; dims=2) + output = batched_mul(sm_scores, permutedims(xv_for_attn, (2,1,3))) + e_output = reshape(output, (seqlen, attn.head_dim, attn.n_heads, batch)) + p_output = permutedims(e_output, (2,3,1,4)) # (n_heads, head_dim, seqlen, batch) + =# + + scores = batched_mul(batched_transpose(xk_for_attn), xq_for_attn) ./ sqrt(T(attn.head_dim)) + if !isnothing(mask) + scores = scores .+ mask + end + sm_scores = softmax(scores; dims=1) + output = batched_mul(xv_for_attn, sm_scores) + e_output = reshape(output, (attn.head_dim, seqlen, attn.n_heads, batch)) + p_output = permutedims(e_output, (1,3,2,4)) + + r_output = reshape(p_output, (attn.head_dim * attn.n_heads, seqlen, batch)) + proj = attn.wo(r_output) + return proj +end + +Flux.@layer :expand Attention trainable=(wq, wv) + +struct TransformerBlock + attention::Attention + feed_forward::FeedForward + attention_norm::RMSNorm + ffn_norm::RMSNorm +end + +function TransformerBlock(dim::Int, n_heads::Int, n_kv_heads::Int=n_heads, ff_hidden_dim = 4 * dim; + norm_eps=1f-5) + TransformerBlock( + Attention(dim, n_heads, n_kv_heads), + FeedForward(dim, ff_hidden_dim), + RMSNorm(dim, eps=norm_eps), + RMSNorm(dim, eps=norm_eps) + ) +end + +function (block::TransformerBlock)(x, start_pos, freqs_cis, mask=nothing) + h = x + block.attention(block.attention_norm(x), start_pos, freqs_cis, mask) + out = h + block.feed_forward(block.ffn_norm(h)) + return out +end + +Flux.@layer TransformerBlock trainable=(attention, ) + +struct Transformer{T} + tok_embeddings::Flux.Embedding + layers::AbstractVector{TransformerBlock} + norm::RMSNorm{T} + output::Dense + freqs_cis::Tuple{AbstractArray{T, 4}, AbstractArray{T, 4}} +end + +function Transformer(vocab_size::Int, dim::Int, n_layers::Int, n_heads::Int, + n_kv_heads::Int, max_seq_len::Int, ff_hidden_dim::Int; + norm_eps::T=1f-5, + rope_theta::T=500000f0, + use_scaled_rope=false, + scale_factor=8) where T + + tok_embeddings = Flux.Embedding(vocab_size => dim) + layers = [TransformerBlock(dim, n_heads, n_kv_heads, ff_hidden_dim; norm_eps=norm_eps) for _ in 1:n_layers] + 
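# Final RMSNorm, the vocab-logit output head, and RoPE cos/sin tables precomputed for 2 * max_seq_len +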
norm = RMSNorm(dim, eps=norm_eps) + output = Dense(dim => vocab_size, bias=false) + freqs_cis = precompute_freqs_cis( + dim ÷ n_heads, + max_seq_len * 2; + theta=rope_theta, + use_scaled=use_scaled_rope, + scale_factor=scale_factor + ) + Transformer(tok_embeddings, layers, norm, output, freqs_cis) +end + +Flux.@layer :expand Transformer trainable=(layers, ) diff --git a/src/model.jl b/src/model.jl index 1078808..9fdb3c4 100644 --- a/src/model.jl +++ b/src/model.jl @@ -43,66 +43,17 @@ end #Note about Huggingface weights and rotary embeddings: https://discuss.huggingface.co/t/is-llama-rotary-embedding-implementation-correct/44509 #Use this one if you're using the Hugging Face weights. function apply_rotary_emb(x, freqs_cis) - # x is (head_dim, seq_len, n_heads, batch) head_dim, seq_len, n_heads, batch = size(x) - x1 = x[1:head_dim÷2, :, :, :] - x2 = x[head_dim÷2+1:end, :, :, :] + x1 = @view x[1:head_dim÷2, :, :, :] + x2 = @view x[head_dim÷2+1:end, :, :, :] cos, sin = freqs_cis - out = vcat( - x1 .* cos[:,1:seq_len,:] .- x2 .* sin[:,1:seq_len,:], - x2 .* cos[:,1:seq_len,:] .+ x1 .* sin[:,1:seq_len,:] + out = vcat( + x1 .* cos .- x2 .* sin, + x2 .* cos .+ x1 .* sin ) return out end - -struct FeedForward - w1::Dense - w2::Dense - w3::Dense -end - -function FeedForward(dim::Int, ff_hidden_dim::Int) - FeedForward( - Dense(dim => ff_hidden_dim, bias=false), - Dense(ff_hidden_dim => dim, bias=false), - Dense(dim => ff_hidden_dim, bias=false) - ) -end - -function (ff::FeedForward)(x) - return ff.w2(Flux.swish(ff.w1(x)) .* ff.w3(x)) -end - -Flux.@layer :expand FeedForward - -struct RMSNorm{T} - weight::AbstractVector{T} - eps::T -end - -function RMSNorm(dim::Int; eps::T=1f-5) where T - RMSNorm{T}(ones(T, dim), eps) -end - -function (norm::RMSNorm)(x) - rms = sqrt.(sum(abs2.(x), dims=1) ./ size(x,1) .+ norm.eps) - return x .* (norm.weight ./ rms) -end - -Flux.@layer RMSNorm - -struct KVCache{T} - cache_k::AbstractArray{T, 4} # (head_dim, seq_len, n_kv_heads, batch) - cache_v::AbstractArray{T, 4} -end - -function KVCache(T, batch_size::Int, seq_length::Int, n_kv_heads::Int, head_dim::Int; device = identity) - cache_k = zeros(T, head_dim, seq_length, n_kv_heads, batch_size) |> device - cache_v = zeros(T, head_dim, seq_length, n_kv_heads, batch_size) |> device - KVCache(cache_k, cache_v) -end - function update_kv_cache(cache::KVCache, start_pos::Int, xk::AbstractArray, xv::AbstractArray) seqlen = size(xk, 2) cache.cache_k[:, (start_pos+1):(start_pos+seqlen), :, :] .= xk @@ -112,163 +63,12 @@ function update_kv_cache(cache::KVCache, start_pos::Int, xk::AbstractArray, xv:: end function repeat_kv(x::AbstractArray, n_rep::Int) - # x is (head_dim, seq_len, n_kv_heads, batch) - # output should be (head_dim, seq_len, n_rep * n_kv_heads, batch) if n_rep == 1 return x end - head_dim, seq_len, n_kv_heads, batch = size(x) - x_expanded = reshape(x, (head_dim, seq_len, 1, n_kv_heads, batch)) - x_repeated = repeat(x_expanded, 1, 1, n_rep, 1, 1) - return reshape(x_repeated, (head_dim, seq_len, n_rep * n_kv_heads, batch)) -end - -mutable struct Attention - wq::Dense - wk::Dense - wv::Dense - wo::Dense - n_heads::Int - n_kv_heads::Int - head_dim::Int - n_rep::Int - cache::Union{Nothing, KVCache} -end - -function Attention(dim::Int, n_heads::Int, n_kv_heads=n_heads) - head_dim = dim ÷ n_heads - n_rep = n_heads ÷ n_kv_heads - Attention( - Dense(dim => n_heads * head_dim, bias=false), - Dense(dim => n_kv_heads * head_dim, bias=false), - Dense(dim => n_kv_heads * head_dim, bias=false), - Dense(n_heads * head_dim => dim, 
bias=false), - n_heads, - n_kv_heads, - head_dim, - n_rep, - nothing - ) -end - -function (attn::Attention)(x::AbstractArray{T}, start_pos::Int, freqs_cis, mask=nothing) where T - dim, seqlen, batch = size(x) - - # Project Q,K,V - xq = attn.wq(x) - xk = attn.wk(x) - xv = attn.wv(x) - - #Reshaping dim: 8, len: 3, batch: 2 - # to: 2, 3, 4, 2 - # RoPE input needs (head_dim, len, n_heads, batch) - - # Reshape to separate heads - xq = reshape(xq, (attn.head_dim, attn.n_heads, seqlen, batch)) - xk = reshape(xk, (attn.head_dim, attn.n_kv_heads, seqlen, batch)) - xv = reshape(xv, (attn.head_dim, attn.n_kv_heads, seqlen, batch)) - - #Some GPUs don't like PermutedDimsArray - #xq = PermutedDimsArray(xq, (1,3,2,4)) #No idea if this is faster... - #xk = PermutedDimsArray(xk, (1,3,2,4)) - #xv = PermutedDimsArray(xv, (1,3,2,4)) - xq = permutedims(xq, (1,3,2,4)) - xk = permutedims(xk, (1,3,2,4)) - xv = permutedims(xv, (1,3,2,4)) - - xq_rope = apply_rotary_emb(xq, freqs_cis) - xk_rope = apply_rotary_emb(xk, freqs_cis) - - if !isnothing(attn.cache) - xk_rope, xv = update_kv_cache(attn.cache, start_pos, xk_rope, xv) - end - - # Apply GQA via repeat_kv - xk_rope = repeat_kv(xk_rope, attn.n_rep) - xv = repeat_kv(xv, attn.n_rep) - - xq_for_attn = reshape(xq_rope, attn.head_dim, :, attn.n_heads * batch) - xk_for_attn = reshape(xk_rope, attn.head_dim, :, attn.n_heads * batch) - xv_for_attn = reshape(xv, attn.head_dim, :, attn.n_heads * batch) - - scores = batched_mul( - 0f0 .+ permutedims(xq_for_attn, (2,1,3)), # (seqlen, head_dim, batch*heads) - 0f0 .+xk_for_attn # (head_dim, seqlen, batch*heads) - ) ./ sqrt(T(attn.head_dim)) - if !isnothing(mask) - scores = scores .+ mask - end - #len: 3, len: 3, headsxbatch: 8 - sm_scores = softmax(scores; dims=2) # Need to get this over dim 1 for efficiency! 
- #len: 3, head_dim: 2, headsxbatch: 8 - output = batched_mul(sm_scores, permutedims(xv_for_attn, (2,1,3))) - # Reshape back to separate batch and heads - e_output = reshape(output, (seqlen, attn.head_dim, attn.n_heads, batch)) - p_output = permutedims(e_output, (2,3,1,4)) # (n_heads, head_dim, seqlen, batch) - r_output = reshape(p_output, (attn.head_dim * attn.n_heads, seqlen, batch)) - proj = attn.wo(r_output) - return proj -end - -Flux.@layer :expand Attention - -struct TransformerBlock - attention::Attention - feed_forward::FeedForward - attention_norm::RMSNorm - ffn_norm::RMSNorm -end - -function TransformerBlock(dim::Int, n_heads::Int, n_kv_heads::Int=n_heads, ff_hidden_dim = 4 * dim; - norm_eps=1f-5) - TransformerBlock( - Attention(dim, n_heads, n_kv_heads), - FeedForward(dim, ff_hidden_dim), - RMSNorm(dim, eps=norm_eps), - RMSNorm(dim, eps=norm_eps) - ) -end - -function (block::TransformerBlock)(x, start_pos, freqs_cis, mask=nothing) - h = x + block.attention(block.attention_norm(x), start_pos, freqs_cis, mask) - out = h + block.feed_forward(block.ffn_norm(h)) - return out -end - -Flux.@layer TransformerBlock - -struct Transformer{T} - tok_embeddings::Flux.Embedding - layers::AbstractVector{TransformerBlock} - norm::RMSNorm{T} - output::Dense - freqs_cis::Tuple{AbstractArray{T, 4}, AbstractArray{T, 4}} -end - -function Transformer(vocab_size::Int, dim::Int, n_layers::Int, n_heads::Int, - n_kv_heads::Int, max_seq_len::Int, ff_hidden_dim::Int; - norm_eps::T=1f-5, - rope_theta::T=500000f0, - use_scaled_rope=false, - scale_factor=8) where T - - tok_embeddings = Flux.Embedding(vocab_size => dim) - layers = [TransformerBlock(dim, n_heads, n_kv_heads, ff_hidden_dim; norm_eps=norm_eps) for _ in 1:n_layers] - norm = RMSNorm(dim, eps=norm_eps) - output = Dense(dim => vocab_size, bias=false) - freqs_cis = precompute_freqs_cis( - dim ÷ n_heads, - max_seq_len * 2; - theta=rope_theta, - use_scaled=use_scaled_rope, - scale_factor=scale_factor - ) - Transformer(tok_embeddings, layers, norm, output, freqs_cis) + return repeat(x, 1, n_rep, 1, 1) end -Flux.@layer :expand Transformer trainable=(layers, norm) - - function forward_inference(model::Transformer{T}, tokens::AbstractArray{Int}, start_pos::Int) where T seqlen = size(tokens, 1) # tokens expected as (seq_len, batch) h = model.tok_embeddings(tokens) # Embedding: (dim, seq_len, batch) @@ -277,17 +77,8 @@ function forward_inference(model::Transformer{T}, tokens::AbstractArray{Int}, st cos, sin = model.freqs_cis #@show size(cos) #(head_dim/2, max_RoPE, 1, 1) freqs_cis = (cos[:,start_pos+1:start_pos+seqlen,:,:], sin[:,start_pos+1:start_pos+seqlen,:,:]) - mask = nothing - if seqlen > 1 - #mask = fill(T(-Inf), (seqlen, seqlen)) - mask = similar(h, seqlen, seqlen) - mask .= T(-Inf) - mask = triu(mask, 1) - # Add zeros for cached sequence - if start_pos > 0 - mask = hcat(zeros(T, seqlen, start_pos), mask) - end - end + + mask = create_mask(h) for layer in model.layers h = layer(h, start_pos, freqs_cis, mask) end @@ -297,12 +88,15 @@ function forward_inference(model::Transformer{T}, tokens::AbstractArray{Int}, st end function create_mask(h::AbstractArray) - embeddim, seqlen, batch = size(h) - mask = similar(h, seqlen, seqlen) - T = eltype(h) - mask .= T(-Inf) - mask = triu(mask, 1) - return mask + Flux.Zygote.ignore() do + embeddim, seqlen, batch = size(h) + mask = similar(h, seqlen, seqlen) + T = eltype(h) + mask .= T(-Inf) + #mask = triu(mask, 1) + mask = tril(mask, -1) #This is swapped because we're using the slightly more efficient dim setup + 
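# (scores are laid out as (key_pos, query_pos, batch*heads) here, so masking the strict lower triangle blocks attention to future key positions) +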
return mask + end end function forward_loss(model::Transformer{T}, inputs::AbstractArray, diff --git a/src/sampling.jl b/src/sampling.jl index 7473191..ee51d3d 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -1,42 +1,20 @@ - -function argmax_sampler(logits::AbstractVector; device = identity) - return argmax(device(logits)) -end - -argmax_sampler(; device = identity) = logits -> argmax_sampler(logits; device = device) - -function top_pk_sampler(logits::AbstractVector; p = 0.5f0, k = 5, device = identity) - probs = device(Jjama3.softmax(logits)) - perm = partialsortperm(probs, 1:k, rev=true) - sorted_probs = probs[perm] - cumsum_probs = cumsum(sorted_probs) - if cumsum_probs[1] > p - return perm[1] - else - cutoff = findlast(cumsum_probs .< p) - return sample(perm[1:cutoff], Weights(sorted_probs[1:cutoff])) - end -end - -top_pk_sampler(;p = 0.5f0, k = 5, device = identity) = logits -> top_pk_sampler(logits; p, k, device) - # This generate function seems to do one unnecessary forward pass when switching from the forward pass over the initial sequence # to the sampling of each token. But when I try and fix it, the model gets slightly dumber. # Vibes feel like a shift-by-1 in the RoPE, or something similar. Need to investigate when I find time. """ - generate(model, initial_tokens; max_new_tokens=100, sampler=top_pk_sampler(p=0.5f0, k=5), encoder_for_printing=tkn, end_token=128010) + generate(model, initial_tokens; max_new_tokens=100, sampler=top_pk_sampler(p=0.5f0, k=5), tokenizer_for_printing=tkn, end_token=128010) Takes an initial sequence of tokens, and generates new tokens one at a time until the end token is sampled. Uses a KV cache. No batch dim for now. Runs on CPU by default. If the model is on the GPU (assuming Flux.jl, eg. `model = gpu(model)`), then pass `device = gpu` to `generate` to run on the GPU. tkn = llama3_tokenizer() - generate(model, initial_tokens; max_new_tokens=100, sampler=top_pk_sampler(p=0.5f0, k=5), encoder_for_printing=tkn, end_token=128010) + generate(model, initial_tokens; max_new_tokens=100, sampler=top_pk_sampler(p=0.5f0, k=5), tokenizer_for_printing=tkn, end_token=128010) """ function generate(model::Transformer{T}, initial_tokens::AbstractArray{IntT}; max_new_tokens=100, sampler::Function=argmax_sampler, - encoder_for_printing = nothing, + tokenizer_for_printing = nothing, end_token = 128010, device = identity) where {T, IntT} @@ -77,8 +55,8 @@ function generate(model::Transformer{T}, next_token = sampler(logits[:, end, 1]) current_len += 1 tokens[current_len] = next_token - if !isnothing(encoder_for_printing) - print(encoder_for_printing.decode([next_token])) + if !isnothing(tokenizer_for_printing) + print(decode(tokenizer_for_printing, [next_token])) end if next_token == end_token break diff --git a/src/utils.jl b/src/utils.jl index 29de607..cd1df4b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,65 +1,33 @@ -""" - tkn = llama3_tokenizer() +encode(tkn::Tokenizer, str; kwargs...) = HuggingFaceTokenizers.encode(tkn, str; kwargs...).ids .+ 1 +decode(tkn::Tokenizer, ids; kwargs...) = HuggingFaceTokenizers.decode(tkn, ids .- 1; kwargs...) -Load the tokenizer for Llama3. This seems to work, but I have not checked if there are some different edge-cases, or missing tokens relative to the original tokenizer (besides the special tokens we hackily include). 
- tkn = llama3_tokenizer() - tkn.encode("What is the capital of France?") - tkn.decode([10, 2, 5, 99]) -""" -llama3_tokenizer() = BytePairEncoding.load_tiktoken_encoder("cl100k_base") +#https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2/ +function llama3_instruct_prompt(tokenizer,system_prompt, user_prompt) + str = """<|start_header_id|>system<|end_header_id|> +$system_prompt +<|eot_id|><|start_header_id|>user<|end_header_id|> + +$(user_prompt)<|eot_id|><|start_header_id|>assistant<|end_header_id|>""" + return encode(tokenizer, str) +end """ generate(model, prompt, max_new_tokens=100, encoder_for_printing=tkn) Format a prompt for use with Llama3.2's instruction format, with a simple "You are a helpful assistant" system prompt. - tkn = llama3_tokenizer() - prompt = assistant_prompt("What is the capital of France?", tkn) + prompt = assistant_prompt(tkn, "What is the capital of France?") generate(model, prompt, max_new_tokens=100, encoder_for_printing=tkn) """ -assistant_prompt(prompt, tkn) = format_llama32_instruction_prompt("\nYou are a helpful assistant\n", prompt, tkn); - - -#https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2/ -""" - generate(model, prompt, max_new_tokens=100, encoder_for_printing=tkn) - -Format a prompt for use with Llama3.2's instruction format, injecting the system and user roles. +llama3_assistant_prompt(tokenizer, prompt) = llama3_instruct_prompt(tokenizer,"\nYou are a helpful assistant\n", prompt); - tkn = llama3_tokenizer() - prompt = format_llama32_instruction_prompt("\\nYou are a helpful assistant\\n", "What is the capital of France?", tkn) - generate(model, prompt, max_new_tokens=100, encoder_for_printing=tkn) -""" -function format_llama32_instruction_prompt(sys_prompt, user_prompt, tokenizer) - prompt = [128001, 128007] #begin_of_text, start_header_id - prompt = vcat(prompt, tokenizer.encode("system")) - push!(prompt, 128008) #end_header_id - prompt = vcat(prompt, tokenizer.encode(sys_prompt)) - prompt = vcat(prompt, [128010, 128007]) #eot_id, start_header_id - prompt = vcat(prompt, tokenizer.encode("user")) - push!(prompt, 128008) #end_header_id - prompt = vcat(prompt, tokenizer.encode("\n")) - prompt = vcat(prompt, tokenizer.encode(user_prompt)) - prompt = vcat(prompt, [128010, 128007]) #eot_id, start_header_id - prompt = vcat(prompt, tokenizer.encode("assistant")) - push!(prompt, 128008) #end_header_id - return prompt +function smollm2_instruct_prompt(tokenizer, system_prompt, user_prompt) + str = """<|im_start|>system\n$(system_prompt)<|im_end|>\n<|im_start|>user\n$(user_prompt)<|im_end|>\n<|im_start|>assistant\n""" + return encode(tokenizer, str) end -#These have already been incremented by 1 to account for Julia's 1-indexing -special_tokens = Dict( - "<|begin_of_text|>" => 128001, - "<|end_of_text|>" => 128002, - "<|start_header_id|>" => 128007, - "<|end_header_id|>" => 128008, - "<|eot_id|>" => 128010, - "<|finetune_right_pad_id|>" => 128005, - "<|python_tag|>" => 128011 -) - -#[ "<|start_header_id|>user<|end_header_id|>\n\nGiven the following question and four candidate answers (A, B, C and D), choose the best answer.\nQuestion: An astronomer observes that a planet rotates faster after a meteorite impact. Which is the most likely effect of this increase in rotation?\nA. Planetary density will decrease.\nB. Planetary years will become longer.\nC. Planetary days will become shorter.\nD. 
Planetary gravity will become stronger.\nYour response should end with \"The best answer is [the_answer_letter]\" where the [the_answer_letter] is one of A, B, C or D.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nThe best answer is" ] - +smollm2_assistant_prompt(tokenizer, prompt) = smollm2_instruct_prompt(tokenizer, "You are a helpful AI assistant named SmolLM, trained by Hugging Face", prompt); """ model = load_llama3_from_safetensors(model_weight_paths, config) @@ -73,14 +41,20 @@ so if you're loading weights from a different source, you might get very poor mo model_weight_paths = ["Llama3_2_1B_instruct/model.safetensors"] #Can be an array of paths if the model is split across multiple files model = load_llama3_from_safetensors(model_weight_paths, config) """ -function load_llama3_from_safetensors(paths::Vector{String}, config; T = Float32) +function load_llama3_from_safetensors(paths::Vector{String}, config; T = Float32, add_lora_to = Symbol[], lora_dim = 0) config = Dict(config) #Just in case the user passed eg. a JSON3.Object - @assert config[:rope_scaling][:rope_type] == "llama3" - @assert config[:rope_scaling][:low_freq_factor] == 1 - @assert config[:rope_scaling][:high_freq_factor] == 4 - @assert config[:rope_scaling][:original_max_position_embeddings] == 8192 + #@assert config[:rope_scaling][:rope_type] == "llama3" + #@assert config[:rope_scaling][:low_freq_factor] == 1 + #@assert config[:rope_scaling][:high_freq_factor] == 4 + #@assert config[:rope_scaling][:original_max_position_embeddings] == 8192 # Create model with config parameters from the JSON + scale_factor = 1f0 + if haskey(config, :rope_scaling) + if !isnothing(config[:rope_scaling]) + scale_factor = config[:rope_scaling][:factor] + end + end model = Transformer( config[:vocab_size], # vocab_size config[:hidden_size], # dim (hidden_size) @@ -92,7 +66,7 @@ function load_llama3_from_safetensors(paths::Vector{String}, config; T = Float32 norm_eps=T(config[:rms_norm_eps]), # rms_norm_eps rope_theta=T(config[:rope_theta]), # rope_theta use_scaled_rope=true, # Using scaled RoPE based on the config - scale_factor=config[:rope_scaling][:factor] # scale_factor + scale_factor=scale_factor # scale_factor ) for path in paths # Process one file at a time @@ -152,9 +126,92 @@ function load_llama3_from_safetensors(paths::Vector{String}, config; T = Float32 weights = nothing GC.gc() end - + + if !isempty(add_lora_to) + #Then load in the current layers: + if :Q in add_lora_to + for layer in model.layers + layer.attention.wq = LoRADense(layer.attention.wq, lora_dim) + end + end + if :K in add_lora_to + for layer in model.layers + layer.attention.wk = LoRADense(layer.attention.wk, lora_dim) + end + end + if :V in add_lora_to + for layer in model.layers + layer.attention.wv = LoRADense(layer.attention.wv, lora_dim) + end + end + if :O in add_lora_to + for layer in model.layers + layer.attention.wo = LoRADense(layer.attention.wo, lora_dim) + end + end + if :w1 in add_lora_to + for layer in model.layers + layer.feed_forward.w1 = LoRADense(layer.feed_forward.w1, lora_dim) + end + end + if :w2 in add_lora_to + for layer in model.layers + layer.feed_forward.w2 = LoRADense(layer.feed_forward.w2, lora_dim) + end + end + if :w3 in add_lora_to + for layer in model.layers + layer.feed_forward.w3 = LoRADense(layer.feed_forward.w3, lora_dim) + end + end + end return model end -load_llama3_from_safetensors(path::String, config; T = Float32) = load_llama3_from_safetensors([path], config; T = T) +load_llama3_from_safetensors(path::String, 
config; T = Float32, kwargs...) = load_llama3_from_safetensors([path], config; T = T, kwargs...) + +""" + sampler = structured_choice(choices, vocab::Vector{String}, end_token::Int; sampler = logits -> argmax_sampler(logits)) + +Return a function that can be passed into generate as a sampler, which will sample from the given choices. Handles the case where the choices are made up of multiple tokens. +`vocab` is an array of the tokens as strings, in their order in the tokenizer. `sampler` is a function that takes the logits (here including those masked with -Inf) and returns a sample from them. Defaults to argmax. + +Example: +```julia +config = JSON3.read(read("SmolLM2-1.7B-Instruct/config.json", String)) +model = load_llama3_from_safetensors("SmolLM2-1.7B-Instruct/model.safetensors", config) +tkn = tokenizer_from_file(Tokenizer, "SmolLM2-1.7B-Instruct/tokenizer.json") + +question = "In a Bayesian model, what do we call the probability distribution of parameters given the data?" +choices = ["Prior", "Likelihood", "Marginal Likelihood", "Evidence", "Posterior"] + +vocab = [decode(tkn, [i], skip_special_tokens = false) for i in 1:49152] +eos = encode(tkn, "<|im_end|>")[end] +prompt = smollm2_instruct_prompt(tkn, "You are an expert in Statistics and Probability Theory who answers questions in as few words as possible.",question) +generate(model, prompt, max_new_tokens=100, tokenizer_for_printing=tkn, end_token = eos, sampler = structured_choice(choices, vocab, eos)); +``` +""" +function structured_choice(choices::Vector{String}, vocab::Vector{String}, end_token::Int; sampler = logits -> argmax_sampler(logits), device = identity) + remaining_choices = copy(choices) + function choice_sampler(logits) + logits = device(logits) + if length(remaining_choices) == 0 || maximum(length.(remaining_choices)) == 0 + return end_token + end + mask = zeros(Bool, length(vocab)) + for i in 1:length(vocab) + for choice in remaining_choices + if startswith(choice, vocab[i]) + mask[i] = true + end + end + end + logits[.!mask] .= -Inf + next_token = sampler(logits) + next_token_str = vocab[next_token] + remaining_choices = [choice[length(next_token_str)+1:end] for choice in remaining_choices if startswith(choice, next_token_str)] + return next_token + end + return choice_sampler +end