From bdf3eef4c15294d8e5c060c703ba475fa4a83db0 Mon Sep 17 00:00:00 2001 From: CompatHelper Julia Date: Thu, 14 Nov 2024 01:20:48 +0000 Subject: [PATCH 1/8] CompatHelper: add new compat entry for BytePairEncoding at version 0.5, (keep existing compat) --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index b833e42..77c52e3 100644 --- a/Project.toml +++ b/Project.toml @@ -11,6 +11,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" SafeTensors = "eeda0dda-7046-4914-a807-2495fc7abb89" [compat] +BytePairEncoding = "0.5" julia = "1.9" [extras] From 56c645294618b371ace416dd26c145a11101dc4f Mon Sep 17 00:00:00 2001 From: murrellb Date: Thu, 14 Nov 2024 12:55:57 +0100 Subject: [PATCH 2/8] Forward pass working on Metal. Sampling slow though. --- Project.toml | 8 ++++++++ ext/MetalExt.jl | 18 ++++++++++++++++++ src/Jjama3.jl | 4 ++-- src/model.jl | 47 ++++++++++++++++++++++------------------------- src/sampling.jl | 41 ++++++++++++++++++++++++++++++++++------- src/utils.jl | 22 ++++++++++++++++++---- 6 files changed, 102 insertions(+), 38 deletions(-) create mode 100644 ext/MetalExt.jl diff --git a/Project.toml b/Project.toml index b833e42..76755b9 100644 --- a/Project.toml +++ b/Project.toml @@ -8,7 +8,15 @@ BytePairEncoding = "a4280ba5-8788-555a-8ca8-4a8c3d966a71" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" SafeTensors = "eeda0dda-7046-4914-a807-2495fc7abb89" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" + +[weakdeps] +Metal = "dde4c033-4e86-420c-a63e-0dd931031962" + +[extensions] +MetalExt = "Metal" [compat] julia = "1.9" diff --git a/ext/MetalExt.jl b/ext/MetalExt.jl new file mode 100644 index 0000000..8659a47 --- /dev/null +++ b/ext/MetalExt.jl @@ -0,0 +1,18 @@ +module MetalExt + +#Note: Metal speeds things up a little for forward_inference and forward_loss calls, but is VERY slow for sampling. +#It seems that each single Metal call has some constant overhead that kills it. + +using Metal, Jjama3.NNlib + +function NNlib.batched_mul(a::MtlArray, b::MtlArray) + a_shape = size(a) + b_shape = size(b) + a_reshaped = reshape(a, a_shape[1], a_shape[2], :) + b_reshaped = reshape(b, b_shape[1], b_shape[2], :) + res = Metal.zeros(a_shape[1], b_shape[2], size(a_reshaped)[3]) + Metal.MPS.matmul!(res, a_reshaped,b_reshaped) + return reshape(res, a_shape[1], b_shape[2], a_shape[3:end]...) 
+end + +end diff --git a/src/Jjama3.jl b/src/Jjama3.jl index a59ac18..ac0cf52 100644 --- a/src/Jjama3.jl +++ b/src/Jjama3.jl @@ -1,11 +1,11 @@ module Jjama3 -using Flux, BytePairEncoding, SafeTensors, Distributions, LinearAlgebra +using Flux, BytePairEncoding, SafeTensors, Distributions, LinearAlgebra, StatsBase, NNlib include("model.jl") include("utils.jl") include("sampling.jl") -export load_llama321B_from_safetensors, load_llama3_from_safetensors, llama3_tokenizer, assistant_prompt, format_llama32_instruction_prompt, generate, forward_loss, forward_inference +export load_llama321B_from_safetensors, load_llama3_from_safetensors, llama3_tokenizer, assistant_prompt, format_llama32_instruction_prompt, generate, forward_loss, forward_inference, top_pk_sampler, argmax_sampler end diff --git a/src/model.jl b/src/model.jl index 6637282..1078808 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1,9 +1,11 @@ -#Important about output layer being tied to embedding: https://github.com/meta-llama/llama-models/issues/172 +#Note about output layer being tied to embedding: https://github.com/meta-llama/llama-models/issues/172 function apply_scaling(freqs::AbstractVector; scale_factor=8) + #Hard-coded - I should move these to the main model struct and grab them from the config. low_freq_factor = 1 high_freq_factor = 4 - old_context_len = 8192 # original llama3 length + old_context_len = 8192 + ### low_freq_wavelen = old_context_len / low_freq_factor high_freq_wavelen = old_context_len / high_freq_factor new_freqs = similar(freqs) @@ -38,7 +40,7 @@ function precompute_freqs_cis(dim::Int, end_pos::Int; end -#https://discuss.huggingface.co/t/is-llama-rotary-embedding-implementation-correct/44509 +#Note about Huggingface weights and rotary embeddings: https://discuss.huggingface.co/t/is-llama-rotary-embedding-implementation-correct/44509 #Use this one if you're using the Hugging Face weights. function apply_rotary_emb(x, freqs_cis) # x is (head_dim, seq_len, n_heads, batch) @@ -118,7 +120,6 @@ function repeat_kv(x::AbstractArray, n_rep::Int) head_dim, seq_len, n_kv_heads, batch = size(x) x_expanded = reshape(x, (head_dim, seq_len, 1, n_kv_heads, batch)) x_repeated = repeat(x_expanded, 1, 1, n_rep, 1, 1) - x_repeated = repeat(x_expanded, 1, 1, n_rep, 1, 1) return reshape(x_repeated, (head_dim, seq_len, n_rep * n_kv_heads, batch)) end @@ -167,7 +168,7 @@ function (attn::Attention)(x::AbstractArray{T}, start_pos::Int, freqs_cis, mask= xk = reshape(xk, (attn.head_dim, attn.n_kv_heads, seqlen, batch)) xv = reshape(xv, (attn.head_dim, attn.n_kv_heads, seqlen, batch)) - #Can we switch to keeping this in its shape, aplpying rot emb in that shape, and only permuting one + #Some GPUs don't like PermutedDimsArray #xq = PermutedDimsArray(xq, (1,3,2,4)) #No idea if this is faster... 
#xk = PermutedDimsArray(xk, (1,3,2,4)) #xv = PermutedDimsArray(xv, (1,3,2,4)) @@ -175,29 +176,26 @@ function (attn::Attention)(x::AbstractArray{T}, start_pos::Int, freqs_cis, mask= xk = permutedims(xk, (1,3,2,4)) xv = permutedims(xv, (1,3,2,4)) - # Apply RoPE xq_rope = apply_rotary_emb(xq, freqs_cis) xk_rope = apply_rotary_emb(xk, freqs_cis) - # Handle KV cache + if !isnothing(attn.cache) xk_rope, xv = update_kv_cache(attn.cache, start_pos, xk_rope, xv) end + # Apply GQA via repeat_kv xk_rope = repeat_kv(xk_rope, attn.n_rep) xv = repeat_kv(xv, attn.n_rep) - # Reshape for attention - dummy dim is seqlength, which isn't the length of the seq when using the KV cache + xq_for_attn = reshape(xq_rope, attn.head_dim, :, attn.n_heads * batch) xk_for_attn = reshape(xk_rope, attn.head_dim, :, attn.n_heads * batch) xv_for_attn = reshape(xv, attn.head_dim, :, attn.n_heads * batch) - # Compute attention scores scores = batched_mul( 0f0 .+ permutedims(xq_for_attn, (2,1,3)), # (seqlen, head_dim, batch*heads) 0f0 .+xk_for_attn # (head_dim, seqlen, batch*heads) ) ./ sqrt(T(attn.head_dim)) if !isnothing(mask) - #@show typeof(scores) - #@show typeof(mask) scores = scores .+ mask end #len: 3, len: 3, headsxbatch: 8 @@ -298,28 +296,34 @@ function forward_inference(model::Transformer{T}, tokens::AbstractArray{Int}, st return output end +function create_mask(h::AbstractArray) + embeddim, seqlen, batch = size(h) + mask = similar(h, seqlen, seqlen) + T = eltype(h) + mask .= T(-Inf) + mask = triu(mask, 1) + return mask +end + function forward_loss(model::Transformer{T}, inputs::AbstractArray, targets::AbstractArray; ignore_index::Int=-100, - mask = triu(fill(T(-Inf), (size(inputs, 1), size(inputs, 1))),1)) where T + mask = :auto) where T seqlen = size(inputs, 1) #(seq_len, batch) h = model.tok_embeddings(inputs) # (dim, seq_len, batch) - cos, sin = model.freqs_cis #@show size(cos) #(head_dim/2, max_RoPE, 1, 1) freqs_cis = (cos[:,1:seqlen,:,:], sin[:,1:seqlen,:,:]) - # Forward through layers (start_pos = 0 disables KV caching) + if mask == :auto + mask = create_mask(h) + end for layer in model.layers h = layer(h, 0, freqs_cis, mask) end h = model.norm(h) logits = model.output(h) - #@show size(logits) # Need to reshape to (vocab_size, seq_len * batch) logits_2d = reshape(logits, size(logits,1), :) - #@show [argmax(logits_2d[:,i]) for i in 1:size(logits_2d,2)] - #@show size(logits_2d) targets_1d = reshape(targets, :) - #@show size(targets_1d) # Mask out ignored indices - will handle this later. # Note: this is not the autoregressive mask, but the mask for the loss function #= @@ -335,18 +339,11 @@ function forward_loss(model::Transformer{T}, inputs::AbstractArray, =# vocab_size = size(model.tok_embeddings.weight, 2) gt = Flux.onehotbatch(targets_1d, 1:vocab_size) - #@show size(gt) loss = Flux.logitcrossentropy(logits_2d, gt) - #@show Flux.logitcrossentropy(logits_2d, gt, agg = identity) return loss end - - - - - #https://discuss.huggingface.co/t/is-llama-rotary-embedding-implementation-correct/44509 #= #Use this one if you're using the original Meta weights. 
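
The model.jl changes above replace the hard-coded default mask in `forward_loss` with `mask = :auto`, which builds the causal mask on the same device and with the same element type as the embeddings via the new `create_mask`. A minimal standalone sketch of that logic, using a toy array in place of the real embedding output (the sizes and the helper name here are illustrative, not from the patch):

```julia
using LinearAlgebra  # for triu

# Mirrors the create_mask added above: -Inf strictly above the diagonal, 0 elsewhere,
# so adding it to the attention scores blocks attention to future positions.
function causal_mask_like(h::AbstractArray)
    _, seqlen, _ = size(h)
    T = eltype(h)
    mask = similar(h, seqlen, seqlen)
    mask .= T(-Inf)
    return triu(mask, 1)
end

h = randn(Float32, 8, 4, 2)   # (dim, seq_len, batch) — toy stand-in for model.tok_embeddings(inputs)
causal_mask_like(h)           # 4×4: upper triangle is -Inf, diagonal and below are 0
```

With `mask = :auto` (the new default), `forward_loss(model, inputs, targets)` builds this mask internally; an explicit mask can still be passed as before.
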
diff --git a/src/sampling.jl b/src/sampling.jl index 6a13e5c..9833da3 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -1,12 +1,39 @@ -function default_sampler(logits::AbstractVector) - return argmax(logits) +function argmax_sampler(logits::AbstractVector; device = identity) + return argmax(device(logits)) end +argmax_sampler(; device = identity) = logits -> argmax_sampler(logits; device = device) + +function top_pk_sampler(logits::AbstractVector; p = 0.5f0, k = 5, device = identity) + probs = device(Jjama3.softmax(logits)) + perm = partialsortperm(probs, 1:k, rev=true) + sorted_probs = probs[perm] + cumsum_probs = cumsum(sorted_probs) + if cumsum_probs[1] > p + return perm[1] + else + cutoff = findlast(cumsum_probs .< p) + return sample(perm[1:cutoff], Weights(sorted_probs[1:cutoff])) + end +end + +top_pk_sampler(;p = 0.5f0, k = 5, device = identity) = logits -> top_pk_sampler(logits; p, k, device) + +# This generate function seems to do one unnecessary forward pass when switching from the forward pass over the initial sequence +# to the sampling of each token. But when I try and fix it, the model gets slightly dumber. +# Vibes feel like a shift-by-1 in the RoPE, or something similar. Need to investigate when I find time. +""" +Takes an initial sequence of tokens, and generates new tokens one at a time until the end token is sampled. Uses a KV cache. No batch dim for now. +Runs on CPU by default. If the model is on the GPU (assuming Flux.jl, eg. `model = gpu(model)`), then pass `device = gpu` to `generate` to run on the GPU. + + tkn = llama3_tokenizer() + generate(model, initial_tokens; max_new_tokens=100, sampler=top_pk_sampler(p=0.5f0, k=5), encoder_for_printing=tkn, end_token=128010) +""" function generate(model::Transformer{T}, initial_tokens::AbstractArray{IntT}; max_new_tokens=100, - sampler::Function=default_sampler, + sampler::Function=argmax_sampler, encoder_for_printing = nothing, end_token = 128010, device = identity) where {T, IntT} @@ -28,7 +55,7 @@ function generate(model::Transformer{T}, end # Process the initial sequence if current_len > 0 - input_tokens = reshape(initial_tokens, :, 1) # (seq_len, batch=1) + input_tokens = device(reshape(initial_tokens, :, 1)) # (seq_len, batch=1) logits = forward_inference(model, input_tokens, 0) start_pos = current_len else @@ -38,14 +65,14 @@ function generate(model::Transformer{T}, for _ in 1:max_new_tokens # If sequence is empty or we want to process just the last token if start_pos == 0 - input_tokens = reshape([128001], :, 1) # Use start of text token if empty + input_tokens = device(reshape([128001], :, 1)) # Use start of text token if empty else - input_tokens = reshape([tokens[current_len]], :, 1) # Just the last token + input_tokens = device(reshape([tokens[current_len]], :, 1)) # Just the last token end # Get logits for next token logits = forward_inference(model, input_tokens, start_pos) # Sample next token (logits are size vocab × 1 × 1) - next_token = sampler(vec(logits[:, end, 1])) + next_token = sampler(logits[:, end, 1]) current_len += 1 tokens[current_len] = next_token if !isnothing(encoder_for_printing) diff --git a/src/utils.jl b/src/utils.jl index f93046f..aec8cc1 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -10,6 +10,7 @@ Format a prompt for use with Llama3.2's instruction format, with a simple "You a assistant_prompt(prompt, tkn) = format_llama32_instruction_prompt("\nYou are a helpful assistant\n", prompt, tkn); +#https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2/ """ Format a prompt for 
use with Llama3.2's instruction format, injecting the system and user roles. @@ -17,9 +18,8 @@ Format a prompt for use with Llama3.2's instruction format, injecting the system prompt = format_llama32_instruction_prompt("\\nYou are a helpful assistant\\n", "What is the capital of France?", tkn) generate(model, prompt, max_new_tokens=100, encoder_for_printing=tkn) """ -function format_llama32_instruction_prompt(sys_prompt, user_prompt, tokenizer) - #begin_of_text, start_header_id - prompt = [128001, 128007] #plus 1 because Julia is 1-indexed +function format_llama32_instruction_prompt(sys_prompt, user_prompt, tokenizer) + prompt = [128001, 128007] #begin_of_text, start_header_id prompt = vcat(prompt, tokenizer.encode("system")) push!(prompt, 128008) #end_header_id prompt = vcat(prompt, tokenizer.encode(sys_prompt)) @@ -28,12 +28,25 @@ function format_llama32_instruction_prompt(sys_prompt, user_prompt, tokenizer) push!(prompt, 128008) #end_header_id prompt = vcat(prompt, tokenizer.encode("\n")) prompt = vcat(prompt, tokenizer.encode(user_prompt)) - prompt = vcat(prompt, [128009, 128007]) #eot_id, start_header_id + prompt = vcat(prompt, [128010, 128007]) #eot_id, start_header_id prompt = vcat(prompt, tokenizer.encode("assistant")) push!(prompt, 128008) #end_header_id return prompt end +#These have already been incremented by 1 to account for Julia's 1-indexing +special_tokens = Dict( + "<|begin_of_text|>" => 128001, + "<|end_of_text|>" => 128002, + "<|start_header_id|>" => 128007, + "<|end_header_id|>" => 128008, + "<|eot_id|>" => 128010, + "<|finetune_right_pad_id|>" => 128005, + "<|python_tag|>" => 128011 +) + +#[ "<|start_header_id|>user<|end_header_id|>\n\nGiven the following question and four candidate answers (A, B, C and D), choose the best answer.\nQuestion: An astronomer observes that a planet rotates faster after a meteorite impact. Which is the most likely effect of this increase in rotation?\nA. Planetary density will decrease.\nB. Planetary years will become longer.\nC. Planetary days will become shorter.\nD. Planetary gravity will become stronger.\nYour response should end with \"The best answer is [the_answer_letter]\" where the [the_answer_letter] is one of A, B, C or D.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nThe best answer is" ] + """ Load a Llama3 model from a set of Huggingface safetensors files, and the config.json file. 
@@ -131,3 +144,4 @@ function load_llama3_from_safetensors(paths::Vector{String}, config; T = Float32 end load_llama3_from_safetensors(path::String, config; T = Float32) = load_llama3_from_safetensors([path], config; T = T) + From e84d4f16a675b0ceb9cdaca3983ebdf29aae7a73 Mon Sep 17 00:00:00 2001 From: Anton Oresten Date: Thu, 14 Nov 2024 13:28:35 +0100 Subject: [PATCH 3/8] Delete Manifest.toml --- Manifest.toml | 7 ------- 1 file changed, 7 deletions(-) delete mode 100644 Manifest.toml diff --git a/Manifest.toml b/Manifest.toml deleted file mode 100644 index 33e1ded..0000000 --- a/Manifest.toml +++ /dev/null @@ -1,7 +0,0 @@ -# This file is machine-generated - editing it directly is not advised - -julia_version = "1.10.0" -manifest_format = "2.0" -project_hash = "d61a61014ab7dc6e0a8e034fba489fc14f7fd619" - -[deps] From e37ec0353a9aa622538619154215d7a29dc87996 Mon Sep 17 00:00:00 2001 From: Anton Oresten Date: Thu, 14 Nov 2024 13:30:15 +0100 Subject: [PATCH 4/8] Update .gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 0887050..5decb6a 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ *.jl.mem /docs/Manifest.toml /docs/build/ +/Manifest.toml From 5e790b1ffa33870aab9498a8a34dc88778e9588f Mon Sep 17 00:00:00 2001 From: Anton Oresten Date: Thu, 14 Nov 2024 13:38:49 +0100 Subject: [PATCH 5/8] Update CI.yml --- .github/workflows/CI.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 51b05fd..b00727a 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -30,7 +30,6 @@ jobs: - ubuntu-latest arch: - x64 - - x86 steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 From 7e78f19ba3a7486d0a080438302c9603ca207a50 Mon Sep 17 00:00:00 2001 From: murrellb Date: Thu, 14 Nov 2024 13:41:23 +0100 Subject: [PATCH 6/8] Docstrings. --- src/sampling.jl | 2 ++ src/utils.jl | 6 ++++++ 2 files changed, 8 insertions(+) diff --git a/src/sampling.jl b/src/sampling.jl index 9833da3..7473191 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -24,6 +24,8 @@ top_pk_sampler(;p = 0.5f0, k = 5, device = identity) = logits -> top_pk_sampler( # to the sampling of each token. But when I try and fix it, the model gets slightly dumber. # Vibes feel like a shift-by-1 in the RoPE, or something similar. Need to investigate when I find time. """ + generate(model, initial_tokens; max_new_tokens=100, sampler=top_pk_sampler(p=0.5f0, k=5), encoder_for_printing=tkn, end_token=128010) + Takes an initial sequence of tokens, and generates new tokens one at a time until the end token is sampled. Uses a KV cache. No batch dim for now. Runs on CPU by default. If the model is on the GPU (assuming Flux.jl, eg. `model = gpu(model)`), then pass `device = gpu` to `generate` to run on the GPU. diff --git a/src/utils.jl b/src/utils.jl index aec8cc1..76ec305 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,6 +1,8 @@ llama3_tokenizer() = BytePairEncoding.load_tiktoken_encoder("cl100k_base") """ + generate(model, prompt, max_new_tokens=100, encoder_for_printing=tkn) + Format a prompt for use with Llama3.2's instruction format, with a simple "You are a helpful assistant" system prompt. 
tkn = llama3_tokenizer() @@ -12,6 +14,8 @@ assistant_prompt(prompt, tkn) = format_llama32_instruction_prompt("\nYou are a h #https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2/ """ + generate(model, prompt, max_new_tokens=100, encoder_for_printing=tkn) + Format a prompt for use with Llama3.2's instruction format, injecting the system and user roles. tkn = llama3_tokenizer() @@ -49,6 +53,8 @@ special_tokens = Dict( """ + model = load_llama3_from_safetensors(model_weight_paths, config) + Load a Llama3 model from a set of Huggingface safetensors files, and the config.json file. Important note: Huggingface uses a different RoPE convention than other implementations, so if you're loading weights from a different source, you might get very poor model performance. From c5ac2abe937fb116edaa50bd305a18689923678d Mon Sep 17 00:00:00 2001 From: Anton Oresten Date: Thu, 14 Nov 2024 13:43:06 +0100 Subject: [PATCH 7/8] Update make.jl --- docs/make.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/make.jl b/docs/make.jl index f08c26d..0a559ba 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -20,4 +20,6 @@ makedocs(; deploydocs(; repo="github.com/MurrellGroup/Jjama3.jl", devbranch="main", + devurl="dev", + versions = ["stable" => "v^", "v#.#", devurl => devurl], ) From 0ce4a2bcfc14d2d7cdff99b2ebafcd7c38a9f1bc Mon Sep 17 00:00:00 2001 From: Anton Oresten Date: Thu, 14 Nov 2024 13:55:47 +0100 Subject: [PATCH 8/8] Update make.jl --- docs/make.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index 0a559ba..f08c26d 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -20,6 +20,4 @@ makedocs(; deploydocs(; repo="github.com/MurrellGroup/Jjama3.jl", devbranch="main", - devurl="dev", - versions = ["stable" => "v^", "v#.#", devurl => devurl], )
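
Taken together, the patches above export a small generation API (`llama3_tokenizer`, `assistant_prompt`, `format_llama32_instruction_prompt`, `argmax_sampler`, `top_pk_sampler`, `generate`, `load_llama3_from_safetensors`). Below is a usage sketch assembled from the docstrings added in patches 2 and 6; the file paths are placeholders, and the use of JSON3 to parse config.json is an assumption not shown in these patches.

```julia
using Jjama3, Flux
using JSON3   # assumed here only for parsing config.json; not introduced by these patches

config = JSON3.read(read("Llama3.2-1B-Instruct/config.json", String))            # placeholder path
model  = load_llama3_from_safetensors("Llama3.2-1B-Instruct/model.safetensors", config)

tkn    = llama3_tokenizer()
prompt = assistant_prompt("What is the capital of France?", tkn)   # wraps format_llama32_instruction_prompt

# Greedy decoding (argmax_sampler) is the default; top_pk_sampler does top-p/top-k sampling.
generate(model, prompt;
         max_new_tokens = 100,
         sampler = top_pk_sampler(p = 0.5f0, k = 5),
         encoder_for_printing = tkn)

# With Metal.jl loaded, the MetalExt extension overrides NNlib.batched_mul for MtlArrays.
# Per the note in ext/MetalExt.jl this helps forward_inference/forward_loss, but token-by-token
# sampling stays slow. To run on GPU: model = gpu(model), then pass device = gpu to generate.
```
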