From de27ee76f3c8f0fecbec0887e3dce3f24019ce1e Mon Sep 17 00:00:00 2001 From: murrellb Date: Sun, 24 Nov 2024 00:17:04 +0100 Subject: [PATCH 01/10] Model speedup (needs CUDA testing) --- src/Jjama3.jl | 14 ++++++++++- src/model.jl | 69 +++++++++++++++++++++++++++------------------------ 2 files changed, 49 insertions(+), 34 deletions(-) diff --git a/src/Jjama3.jl b/src/Jjama3.jl index ac0cf52..90cbd80 100644 --- a/src/Jjama3.jl +++ b/src/Jjama3.jl @@ -5,7 +5,19 @@ using Flux, BytePairEncoding, SafeTensors, Distributions, LinearAlgebra, StatsBa include("model.jl") include("utils.jl") include("sampling.jl") +include("tokenizers.jl") -export load_llama321B_from_safetensors, load_llama3_from_safetensors, llama3_tokenizer, assistant_prompt, format_llama32_instruction_prompt, generate, forward_loss, forward_inference, top_pk_sampler, argmax_sampler +export load_huggingface_tokenizer_and_encoder, + load_llama321B_from_safetensors, + load_llama3_from_safetensors, + llama3_tokenizer, + assistant_prompt, + format_llama32_instruction_prompt, + generate, + forward_loss, + forward_inference, + top_pk_sampler, + argmax_sampler, + load_huggingface_tokenizer_and_encoder end diff --git a/src/model.jl b/src/model.jl index 1078808..72276db 100644 --- a/src/model.jl +++ b/src/model.jl @@ -45,12 +45,12 @@ end function apply_rotary_emb(x, freqs_cis) # x is (head_dim, seq_len, n_heads, batch) head_dim, seq_len, n_heads, batch = size(x) - x1 = x[1:head_dim÷2, :, :, :] - x2 = x[head_dim÷2+1:end, :, :, :] + x1 = @view x[1:head_dim÷2, :, :, :] + x2 = @view x[head_dim÷2+1:end, :, :, :] cos, sin = freqs_cis - out = vcat( - x1 .* cos[:,1:seq_len,:] .- x2 .* sin[:,1:seq_len,:], - x2 .* cos[:,1:seq_len,:] .+ x1 .* sin[:,1:seq_len,:] + out = vcat( + x1 .* cos .- x2 .* sin, + x2 .* cos .+ x1 .* sin ) return out end @@ -111,16 +111,12 @@ function update_kv_cache(cache::KVCache, start_pos::Int, xk::AbstractArray, xv:: cache.cache_v[:, 1:(start_pos+seqlen), :, :] end + function repeat_kv(x::AbstractArray, n_rep::Int) - # x is (head_dim, seq_len, n_kv_heads, batch) - # output should be (head_dim, seq_len, n_rep * n_kv_heads, batch) if n_rep == 1 return x end - head_dim, seq_len, n_kv_heads, batch = size(x) - x_expanded = reshape(x, (head_dim, seq_len, 1, n_kv_heads, batch)) - x_repeated = repeat(x_expanded, 1, 1, n_rep, 1, 1) - return reshape(x_repeated, (head_dim, seq_len, n_rep * n_kv_heads, batch)) + return repeat(x, 1, n_rep, 1, 1) end mutable struct Attention @@ -154,27 +150,18 @@ end function (attn::Attention)(x::AbstractArray{T}, start_pos::Int, freqs_cis, mask=nothing) where T dim, seqlen, batch = size(x) - # Project Q,K,V xq = attn.wq(x) xk = attn.wk(x) xv = attn.wv(x) - #Reshaping dim: 8, len: 3, batch: 2 - # to: 2, 3, 4, 2 - # RoPE input needs (head_dim, len, n_heads, batch) - - # Reshape to separate heads xq = reshape(xq, (attn.head_dim, attn.n_heads, seqlen, batch)) xk = reshape(xk, (attn.head_dim, attn.n_kv_heads, seqlen, batch)) xv = reshape(xv, (attn.head_dim, attn.n_kv_heads, seqlen, batch)) - #Some GPUs don't like PermutedDimsArray - #xq = PermutedDimsArray(xq, (1,3,2,4)) #No idea if this is faster... - #xk = PermutedDimsArray(xk, (1,3,2,4)) - #xv = PermutedDimsArray(xv, (1,3,2,4)) - xq = permutedims(xq, (1,3,2,4)) - xk = permutedims(xk, (1,3,2,4)) - xv = permutedims(xv, (1,3,2,4)) + #Lazy permute dims. Need to test CUDA. 
+ xq = PermutedDimsArray(xq, (1,3,2,4)) + xk = PermutedDimsArray(xk, (1,3,2,4)) + xv = PermutedDimsArray(xv, (1,3,2,4)) xq_rope = apply_rotary_emb(xq, freqs_cis) xk_rope = apply_rotary_emb(xk, freqs_cis) @@ -183,28 +170,40 @@ function (attn::Attention)(x::AbstractArray{T}, start_pos::Int, freqs_cis, mask= xk_rope, xv = update_kv_cache(attn.cache, start_pos, xk_rope, xv) end - # Apply GQA via repeat_kv xk_rope = repeat_kv(xk_rope, attn.n_rep) - xv = repeat_kv(xv, attn.n_rep) + xv = repeat_kv(xv, attn.n_rep) xq_for_attn = reshape(xq_rope, attn.head_dim, :, attn.n_heads * batch) xk_for_attn = reshape(xk_rope, attn.head_dim, :, attn.n_heads * batch) xv_for_attn = reshape(xv, attn.head_dim, :, attn.n_heads * batch) + #= scores = batched_mul( - 0f0 .+ permutedims(xq_for_attn, (2,1,3)), # (seqlen, head_dim, batch*heads) - 0f0 .+xk_for_attn # (head_dim, seqlen, batch*heads) + permutedims(xq_for_attn, (2,1,3)), # (seqlen, head_dim, batch*heads) + #batched_transpose(xq_for_attn), # (seqlen, head_dim, batch*heads) + xk_for_attn # (head_dim, seqlen, batch*heads) ) ./ sqrt(T(attn.head_dim)) if !isnothing(mask) scores = scores .+ mask end - #len: 3, len: 3, headsxbatch: 8 - sm_scores = softmax(scores; dims=2) # Need to get this over dim 1 for efficiency! - #len: 3, head_dim: 2, headsxbatch: 8 + sm_scores = softmax(scores; dims=2) output = batched_mul(sm_scores, permutedims(xv_for_attn, (2,1,3))) - # Reshape back to separate batch and heads e_output = reshape(output, (seqlen, attn.head_dim, attn.n_heads, batch)) p_output = permutedims(e_output, (2,3,1,4)) # (n_heads, head_dim, seqlen, batch) + =# + + scores = batched_mul( + batched_transpose(xk_for_attn), + xq_for_attn + ) ./ sqrt(T(attn.head_dim)) + if !isnothing(mask) + scores = scores .+ mask + end + sm_scores = softmax(scores; dims=1) + output = batched_mul(xv_for_attn, sm_scores) + e_output = reshape(output, (attn.head_dim, seqlen, attn.n_heads, batch)) + p_output = permutedims(e_output, (1,3,2,4)) + r_output = reshape(p_output, (attn.head_dim * attn.n_heads, seqlen, batch)) proj = attn.wo(r_output) return proj @@ -277,6 +276,7 @@ function forward_inference(model::Transformer{T}, tokens::AbstractArray{Int}, st cos, sin = model.freqs_cis #@show size(cos) #(head_dim/2, max_RoPE, 1, 1) freqs_cis = (cos[:,start_pos+1:start_pos+seqlen,:,:], sin[:,start_pos+1:start_pos+seqlen,:,:]) + #= mask = nothing if seqlen > 1 #mask = fill(T(-Inf), (seqlen, seqlen)) @@ -288,6 +288,8 @@ function forward_inference(model::Transformer{T}, tokens::AbstractArray{Int}, st mask = hcat(zeros(T, seqlen, start_pos), mask) end end + =# + mask = create_mask(h) for layer in model.layers h = layer(h, start_pos, freqs_cis, mask) end @@ -301,7 +303,8 @@ function create_mask(h::AbstractArray) mask = similar(h, seqlen, seqlen) T = eltype(h) mask .= T(-Inf) - mask = triu(mask, 1) + #mask = triu(mask, 1) + mask = tril(mask, -1) return mask end From 45b431ee55a55f67437063509c85dbd4f447b939 Mon Sep 17 00:00:00 2001 From: murrellb Date: Sun, 24 Nov 2024 00:19:56 +0100 Subject: [PATCH 02/10] Model tweaks (needs CUDA testing) --- src/model.jl | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/src/model.jl b/src/model.jl index 72276db..ff85380 100644 --- a/src/model.jl +++ b/src/model.jl @@ -276,19 +276,7 @@ function forward_inference(model::Transformer{T}, tokens::AbstractArray{Int}, st cos, sin = model.freqs_cis #@show size(cos) #(head_dim/2, max_RoPE, 1, 1) freqs_cis = (cos[:,start_pos+1:start_pos+seqlen,:,:], sin[:,start_pos+1:start_pos+seqlen,:,:]) - #= - 
mask = nothing - if seqlen > 1 - #mask = fill(T(-Inf), (seqlen, seqlen)) - mask = similar(h, seqlen, seqlen) - mask .= T(-Inf) - mask = triu(mask, 1) - # Add zeros for cached sequence - if start_pos > 0 - mask = hcat(zeros(T, seqlen, start_pos), mask) - end - end - =# + mask = create_mask(h) for layer in model.layers h = layer(h, start_pos, freqs_cis, mask) From 6288b70e5b419e5d456e0ba3b5bf638134577e9e Mon Sep 17 00:00:00 2001 From: murrellb Date: Mon, 25 Nov 2024 09:58:36 +0100 Subject: [PATCH 03/10] Adding samplers --- Project.toml | 4 +++ ext/MetalExt.jl | 9 ++++++ src/Jjama3.jl | 26 +++++++++++----- src/model.jl | 2 +- src/sampling.jl | 45 ++++++++++++++++++++++++---- src/utils.jl | 80 +++++++++++++++++-------------------------------- 6 files changed, 100 insertions(+), 66 deletions(-) diff --git a/Project.toml b/Project.toml index 278a277..7c850e7 100644 --- a/Project.toml +++ b/Project.toml @@ -7,6 +7,7 @@ version = "1.0.0-DEV" BytePairEncoding = "a4280ba5-8788-555a-8ca8-4a8c3d966a71" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +HuggingFaceTokenizers = "a6888d44-1185-43bb-bd0f-7806f9976d18" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" SafeTensors = "eeda0dda-7046-4914-a807-2495fc7abb89" @@ -15,6 +16,9 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [weakdeps] Metal = "dde4c033-4e86-420c-a63e-0dd931031962" +[sources] +HuggingFaceTokenizers = "https://github.com/MurrellGroup/HuggingFaceTokenizers.jl" + [extensions] MetalExt = "Metal" diff --git a/ext/MetalExt.jl b/ext/MetalExt.jl index 8659a47..e4eae03 100644 --- a/ext/MetalExt.jl +++ b/ext/MetalExt.jl @@ -15,4 +15,13 @@ function NNlib.batched_mul(a::MtlArray, b::MtlArray) return reshape(res, a_shape[1], b_shape[2], a_shape[3:end]...) 
end +function NNlib.PermutedDimsArray(a::MtlArray, perm) + return permutedims(a, perm) +end + +function NNlib.batched_transpose(a::MtlArray) + dims = size(a) + return permutedims(a, (2,1,3:length(dims)...)) +end + end diff --git a/src/Jjama3.jl b/src/Jjama3.jl index 90cbd80..79de914 100644 --- a/src/Jjama3.jl +++ b/src/Jjama3.jl @@ -1,23 +1,33 @@ module Jjama3 -using Flux, BytePairEncoding, SafeTensors, Distributions, LinearAlgebra, StatsBase, NNlib +using Flux, SafeTensors, Distributions, LinearAlgebra, StatsBase, NNlib +import HuggingFaceTokenizers + +tokenizer_from_repo = HuggingFaceTokenizers.from_pretrained +tokenizer_from_file = HuggingFaceTokenizers.from_file +Tokenizer = HuggingFaceTokenizers.Tokenizer include("model.jl") include("utils.jl") include("sampling.jl") -include("tokenizers.jl") -export load_huggingface_tokenizer_and_encoder, - load_llama321B_from_safetensors, +export load_llama321B_from_safetensors, load_llama3_from_safetensors, - llama3_tokenizer, - assistant_prompt, - format_llama32_instruction_prompt, generate, forward_loss, forward_inference, top_pk_sampler, argmax_sampler, - load_huggingface_tokenizer_and_encoder + top_n_sigma_sampler, + min_p_sampler, + tokenizer_from_repo, + tokenizer_from_file, + encode, + decode, + Tokenizer, + llama3_instruct_prompt, + llama3_assistant_prompt, + smollm2_instruct_prompt, + smollm2_assistant_prompt end diff --git a/src/model.jl b/src/model.jl index ff85380..272c902 100644 --- a/src/model.jl +++ b/src/model.jl @@ -292,7 +292,7 @@ function create_mask(h::AbstractArray) T = eltype(h) mask .= T(-Inf) #mask = triu(mask, 1) - mask = tril(mask, -1) + mask = tril(mask, -1) #This is swapped because we're using the slightly more efficient dim setup return mask end diff --git a/src/sampling.jl b/src/sampling.jl index 7473191..fbb2fd2 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -20,23 +20,58 @@ end top_pk_sampler(;p = 0.5f0, k = 5, device = identity) = logits -> top_pk_sampler(logits; p, k, device) +# https://arxiv.org/pdf/2411.07641 +function top_n_sigma_sampler(logits::AbstractVector{T}; temperature::T = 1.0f0, n::T = 1.0f0, device = identity) where T + scaled_logits = logits ./ temperature + M = maximum(scaled_logits) + σ = std(scaled_logits) + threshold = M - n * σ + mask = scaled_logits .>= threshold + masked_logits = copy(scaled_logits) + masked_logits[.!mask] .= -Inf + probs = device(Jjama3.softmax(masked_logits)) + return sample(1:length(probs), Weights(probs)) +end + +top_n_sigma_sampler(; temperature = 1.0f0, n = 1.0f0, device = identity) = logits -> top_n_sigma_sampler(logits; temperature, n, device) + +#https://arxiv.org/pdf/2407.01082 +function min_p_sampler(logits::AbstractVector{T}; pbase::T = 0.5f0, device = identity) where T + probs = device(Jjama3.softmax(logits)) + pmax = maximum(probs) + pscaled = pbase * pmax + mask = probs .>= pscaled + if !any(mask) + mask[argmax(probs)] = true + end + masked_probs = copy(probs) + masked_probs[.!mask] .= zero(T) + normalization = sum(masked_probs) + if normalization > 0 + masked_probs ./= normalization + end + return sample(1:length(probs), Weights(masked_probs)) +end + +min_p_sampler(; pbase = 0.5f0, device = identity) = logits -> min_p_sampler(logits; pbase, device) + # This generate function seems to do one unnecessary forward pass when switching from the forward pass over the initial sequence # to the sampling of each token. But when I try and fix it, the model gets slightly dumber. # Vibes feel like a shift-by-1 in the RoPE, or something similar. 
Need to investigate when I find time. """ - generate(model, initial_tokens; max_new_tokens=100, sampler=top_pk_sampler(p=0.5f0, k=5), encoder_for_printing=tkn, end_token=128010) + generate(model, initial_tokens; max_new_tokens=100, sampler=top_pk_sampler(p=0.5f0, k=5), tokenizer_for_printing=tkn, end_token=128010) Takes an initial sequence of tokens, and generates new tokens one at a time until the end token is sampled. Uses a KV cache. No batch dim for now. Runs on CPU by default. If the model is on the GPU (assuming Flux.jl, eg. `model = gpu(model)`), then pass `device = gpu` to `generate` to run on the GPU. tkn = llama3_tokenizer() - generate(model, initial_tokens; max_new_tokens=100, sampler=top_pk_sampler(p=0.5f0, k=5), encoder_for_printing=tkn, end_token=128010) + generate(model, initial_tokens; max_new_tokens=100, sampler=top_pk_sampler(p=0.5f0, k=5), tokenizer_for_printing=tkn, end_token=128010) """ function generate(model::Transformer{T}, initial_tokens::AbstractArray{IntT}; max_new_tokens=100, sampler::Function=argmax_sampler, - encoder_for_printing = nothing, + tokenizer_for_printing = nothing, end_token = 128010, device = identity) where {T, IntT} @@ -77,8 +112,8 @@ function generate(model::Transformer{T}, next_token = sampler(logits[:, end, 1]) current_len += 1 tokens[current_len] = next_token - if !isnothing(encoder_for_printing) - print(encoder_for_printing.decode([next_token])) + if !isnothing(tokenizer_for_printing) + print(decode(tokenizer_for_printing, [next_token])) end if next_token == end_token break diff --git a/src/utils.jl b/src/utils.jl index 29de607..69a1c9b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,64 +1,34 @@ -""" - tkn = llama3_tokenizer() +encode(tkn::Tokenizer, str) = HuggingFaceTokenizers.encode(tkn, str).ids .+ 1 +decode(tkn::Tokenizer, ids) = HuggingFaceTokenizers.decode(tkn, ids .- 1) -Load the tokenizer for Llama3. This seems to work, but I have not checked if there are some different edge-cases, or missing tokens relative to the original tokenizer (besides the special tokens we hackily include). - tkn = llama3_tokenizer() - tkn.encode("What is the capital of France?") - tkn.decode([10, 2, 5, 99]) -""" -llama3_tokenizer() = BytePairEncoding.load_tiktoken_encoder("cl100k_base") +#https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2/ +function llama3_instruct_prompt(tokenizer,system_prompt, user_prompt) + str = """<|start_header_id|>system<|end_header_id|> +$system_prompt +<|eot_id|><|start_header_id|>user<|end_header_id|> + +$(user_prompt)<|eot_id|><|start_header_id|>assistant<|end_header_id|>""" + return encode(tokenizer, str) +end """ generate(model, prompt, max_new_tokens=100, encoder_for_printing=tkn) Format a prompt for use with Llama3.2's instruction format, with a simple "You are a helpful assistant" system prompt. - tkn = llama3_tokenizer() - prompt = assistant_prompt("What is the capital of France?", tkn) + prompt = assistant_prompt(tkn, "What is the capital of France?") generate(model, prompt, max_new_tokens=100, encoder_for_printing=tkn) """ -assistant_prompt(prompt, tkn) = format_llama32_instruction_prompt("\nYou are a helpful assistant\n", prompt, tkn); - +llama3_assistant_prompt(tokenizer, prompt) = llama3_instruct_prompt(tokenizer,"\nYou are a helpful assistant\n", prompt); -#https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2/ -""" - generate(model, prompt, max_new_tokens=100, encoder_for_printing=tkn) - -Format a prompt for use with Llama3.2's instruction format, injecting the system and user roles. 
- - tkn = llama3_tokenizer() - prompt = format_llama32_instruction_prompt("\\nYou are a helpful assistant\\n", "What is the capital of France?", tkn) - generate(model, prompt, max_new_tokens=100, encoder_for_printing=tkn) -""" -function format_llama32_instruction_prompt(sys_prompt, user_prompt, tokenizer) - prompt = [128001, 128007] #begin_of_text, start_header_id - prompt = vcat(prompt, tokenizer.encode("system")) - push!(prompt, 128008) #end_header_id - prompt = vcat(prompt, tokenizer.encode(sys_prompt)) - prompt = vcat(prompt, [128010, 128007]) #eot_id, start_header_id - prompt = vcat(prompt, tokenizer.encode("user")) - push!(prompt, 128008) #end_header_id - prompt = vcat(prompt, tokenizer.encode("\n")) - prompt = vcat(prompt, tokenizer.encode(user_prompt)) - prompt = vcat(prompt, [128010, 128007]) #eot_id, start_header_id - prompt = vcat(prompt, tokenizer.encode("assistant")) - push!(prompt, 128008) #end_header_id - return prompt +function smollm2_instruct_prompt(tokenizer, system_prompt, user_prompt) + str = """<|im_start|>system\n$(system_prompt)<|im_end|>\n<|im_start|>user\n$(user_prompt)<|im_end|>\n""" + return encode(tokenizer, str) end -#These have already been incremented by 1 to account for Julia's 1-indexing -special_tokens = Dict( - "<|begin_of_text|>" => 128001, - "<|end_of_text|>" => 128002, - "<|start_header_id|>" => 128007, - "<|end_header_id|>" => 128008, - "<|eot_id|>" => 128010, - "<|finetune_right_pad_id|>" => 128005, - "<|python_tag|>" => 128011 -) +smollm2_assistant_prompt(tokenizer, prompt) = smollm2_instruct_prompt(tokenizer, "You are a helpful AI assistant named SmolLM, trained by Hugging Face", prompt); -#[ "<|start_header_id|>user<|end_header_id|>\n\nGiven the following question and four candidate answers (A, B, C and D), choose the best answer.\nQuestion: An astronomer observes that a planet rotates faster after a meteorite impact. Which is the most likely effect of this increase in rotation?\nA. Planetary density will decrease.\nB. Planetary years will become longer.\nC. Planetary days will become shorter.\nD. Planetary gravity will become stronger.\nYour response should end with \"The best answer is [the_answer_letter]\" where the [the_answer_letter] is one of A, B, C or D.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nThe best answer is" ] """ @@ -75,12 +45,18 @@ so if you're loading weights from a different source, you might get very poor mo """ function load_llama3_from_safetensors(paths::Vector{String}, config; T = Float32) config = Dict(config) #Just in case the user passed eg. 
a JSON3.Object - @assert config[:rope_scaling][:rope_type] == "llama3" - @assert config[:rope_scaling][:low_freq_factor] == 1 - @assert config[:rope_scaling][:high_freq_factor] == 4 - @assert config[:rope_scaling][:original_max_position_embeddings] == 8192 + #@assert config[:rope_scaling][:rope_type] == "llama3" + #@assert config[:rope_scaling][:low_freq_factor] == 1 + #@assert config[:rope_scaling][:high_freq_factor] == 4 + #@assert config[:rope_scaling][:original_max_position_embeddings] == 8192 # Create model with config parameters from the JSON + scale_factor = 1f0 + if haskey(config, :rope_scaling) + if !isnothing(config[:rope_scaling]) + scale_factor = config[:rope_scaling][:factor] + end + end model = Transformer( config[:vocab_size], # vocab_size config[:hidden_size], # dim (hidden_size) @@ -92,7 +68,7 @@ function load_llama3_from_safetensors(paths::Vector{String}, config; T = Float32 norm_eps=T(config[:rms_norm_eps]), # rms_norm_eps rope_theta=T(config[:rope_theta]), # rope_theta use_scaled_rope=true, # Using scaled RoPE based on the config - scale_factor=config[:rope_scaling][:factor] # scale_factor + scale_factor=scale_factor # scale_factor ) for path in paths # Process one file at a time From 34f61affa8d19415b3b5f48dc94c118d41ae6501 Mon Sep 17 00:00:00 2001 From: murrellb Date: Mon, 25 Nov 2024 09:59:21 +0100 Subject: [PATCH 04/10] Renaming --- src/Jjama3.jl | 2 +- src/sampling.jl | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Jjama3.jl b/src/Jjama3.jl index 79de914..5680530 100644 --- a/src/Jjama3.jl +++ b/src/Jjama3.jl @@ -18,7 +18,7 @@ export load_llama321B_from_safetensors, forward_inference, top_pk_sampler, argmax_sampler, - top_n_sigma_sampler, + top_nσ_sampler, min_p_sampler, tokenizer_from_repo, tokenizer_from_file, diff --git a/src/sampling.jl b/src/sampling.jl index fbb2fd2..97a736f 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -21,7 +21,7 @@ end top_pk_sampler(;p = 0.5f0, k = 5, device = identity) = logits -> top_pk_sampler(logits; p, k, device) # https://arxiv.org/pdf/2411.07641 -function top_n_sigma_sampler(logits::AbstractVector{T}; temperature::T = 1.0f0, n::T = 1.0f0, device = identity) where T +function top_nσ_sampler(logits::AbstractVector{T}; temperature::T = 1.0f0, n::T = 1.0f0, device = identity) where T scaled_logits = logits ./ temperature M = maximum(scaled_logits) σ = std(scaled_logits) @@ -33,7 +33,7 @@ function top_n_sigma_sampler(logits::AbstractVector{T}; temperature::T = 1.0f0, return sample(1:length(probs), Weights(probs)) end -top_n_sigma_sampler(; temperature = 1.0f0, n = 1.0f0, device = identity) = logits -> top_n_sigma_sampler(logits; temperature, n, device) +top_nσ_sampler(; temperature = 1.0f0, n = 1.0f0, device = identity) = logits -> top_nσ_sampler(logits; temperature, n, device) #https://arxiv.org/pdf/2407.01082 function min_p_sampler(logits::AbstractVector{T}; pbase::T = 0.5f0, device = identity) where T From 604eb5d9afbdc8ea8868b61f7925014f26bf88a0 Mon Sep 17 00:00:00 2001 From: murrellb Date: Mon, 25 Nov 2024 15:58:58 +0100 Subject: [PATCH 05/10] Moving samplers to its own thing --- Project.toml | 8 +++++--- src/Jjama3.jl | 7 ++++++- src/model.jl | 5 +---- src/sampling.jl | 54 ------------------------------------------------- 4 files changed, 12 insertions(+), 62 deletions(-) diff --git a/Project.toml b/Project.toml index 7c850e7..f9f0f2e 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" 
HuggingFaceTokenizers = "a6888d44-1185-43bb-bd0f-7806f9976d18" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LogitSamplers = "1b30fcfc-0ee9-4be2-9cfe-b2289b43e041" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" SafeTensors = "eeda0dda-7046-4914-a807-2495fc7abb89" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" @@ -16,9 +17,6 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [weakdeps] Metal = "dde4c033-4e86-420c-a63e-0dd931031962" -[sources] -HuggingFaceTokenizers = "https://github.com/MurrellGroup/HuggingFaceTokenizers.jl" - [extensions] MetalExt = "Metal" @@ -34,5 +32,9 @@ julia = "1.9" [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +[sources] +HuggingFaceTokenizers = {url = "https://github.com/MurrellGroup/HuggingFaceTokenizers.jl", rev = "main"} +LogitSamplers = {url = "https://github.com/MurrellGroup/LogitSamplers.jl", rev = "main"} + [targets] test = ["Test"] diff --git a/src/Jjama3.jl b/src/Jjama3.jl index 5680530..7f4d8a0 100644 --- a/src/Jjama3.jl +++ b/src/Jjama3.jl @@ -1,12 +1,17 @@ module Jjama3 using Flux, SafeTensors, Distributions, LinearAlgebra, StatsBase, NNlib -import HuggingFaceTokenizers +import HuggingFaceTokenizers, LogitSamplers tokenizer_from_repo = HuggingFaceTokenizers.from_pretrained tokenizer_from_file = HuggingFaceTokenizers.from_file Tokenizer = HuggingFaceTokenizers.Tokenizer +top_pk_sampler = LogitSamplers.top_pk_sampler +argmax_sampler = LogitSamplers.argmax_sampler +min_p_sampler = LogitSamplers.min_p_sampler +top_nσ_sampler = LogitSamplers.top_nσ_sampler + include("model.jl") include("utils.jl") include("sampling.jl") diff --git a/src/model.jl b/src/model.jl index 272c902..90d4f66 100644 --- a/src/model.jl +++ b/src/model.jl @@ -192,10 +192,7 @@ function (attn::Attention)(x::AbstractArray{T}, start_pos::Int, freqs_cis, mask= p_output = permutedims(e_output, (2,3,1,4)) # (n_heads, head_dim, seqlen, batch) =# - scores = batched_mul( - batched_transpose(xk_for_attn), - xq_for_attn - ) ./ sqrt(T(attn.head_dim)) + scores = batched_mul(batched_transpose(xk_for_attn), xq_for_attn) ./ sqrt(T(attn.head_dim)) if !isnothing(mask) scores = scores .+ mask end diff --git a/src/sampling.jl b/src/sampling.jl index 97a736f..50c10e2 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -1,59 +1,5 @@ -function argmax_sampler(logits::AbstractVector; device = identity) - return argmax(device(logits)) -end - -argmax_sampler(; device = identity) = logits -> argmax_sampler(logits; device = device) - -function top_pk_sampler(logits::AbstractVector; p = 0.5f0, k = 5, device = identity) - probs = device(Jjama3.softmax(logits)) - perm = partialsortperm(probs, 1:k, rev=true) - sorted_probs = probs[perm] - cumsum_probs = cumsum(sorted_probs) - if cumsum_probs[1] > p - return perm[1] - else - cutoff = findlast(cumsum_probs .< p) - return sample(perm[1:cutoff], Weights(sorted_probs[1:cutoff])) - end -end - -top_pk_sampler(;p = 0.5f0, k = 5, device = identity) = logits -> top_pk_sampler(logits; p, k, device) - -# https://arxiv.org/pdf/2411.07641 -function top_nσ_sampler(logits::AbstractVector{T}; temperature::T = 1.0f0, n::T = 1.0f0, device = identity) where T - scaled_logits = logits ./ temperature - M = maximum(scaled_logits) - σ = std(scaled_logits) - threshold = M - n * σ - mask = scaled_logits .>= threshold - masked_logits = copy(scaled_logits) - masked_logits[.!mask] .= -Inf - probs = device(Jjama3.softmax(masked_logits)) - return sample(1:length(probs), Weights(probs)) -end - -top_nσ_sampler(; temperature = 1.0f0, n = 1.0f0, device = identity) = logits -> 
top_nσ_sampler(logits; temperature, n, device) - -#https://arxiv.org/pdf/2407.01082 -function min_p_sampler(logits::AbstractVector{T}; pbase::T = 0.5f0, device = identity) where T - probs = device(Jjama3.softmax(logits)) - pmax = maximum(probs) - pscaled = pbase * pmax - mask = probs .>= pscaled - if !any(mask) - mask[argmax(probs)] = true - end - masked_probs = copy(probs) - masked_probs[.!mask] .= zero(T) - normalization = sum(masked_probs) - if normalization > 0 - masked_probs ./= normalization - end - return sample(1:length(probs), Weights(masked_probs)) -end -min_p_sampler(; pbase = 0.5f0, device = identity) = logits -> min_p_sampler(logits; pbase, device) # This generate function seems to do one unnecessary forward pass when switching from the forward pass over the initial sequence # to the sampling of each token. But when I try and fix it, the model gets slightly dumber. From aa151b29c67196854046aadd70f97d54aa85ee6f Mon Sep 17 00:00:00 2001 From: murrellb Date: Tue, 26 Nov 2024 16:35:29 +0100 Subject: [PATCH 06/10] Refactor, adding LoRA and structured sampling example. --- ext/MetalExt.jl | 2 +- src/Jjama3.jl | 10 ++- src/layers.jl | 191 +++++++++++++++++++++++++++++++++++++++++++ src/model.jl | 212 ++---------------------------------------------- src/sampling.jl | 3 - src/utils.jl | 95 ++++++++++++++++++++-- 6 files changed, 297 insertions(+), 216 deletions(-) create mode 100644 src/layers.jl diff --git a/ext/MetalExt.jl b/ext/MetalExt.jl index e4eae03..5383601 100644 --- a/ext/MetalExt.jl +++ b/ext/MetalExt.jl @@ -3,7 +3,7 @@ module MetalExt #Note: Metal speeds things up a little for forward_inference and forward_loss calls, but is VERY slow for sampling. #It seems that each single Metal call has some constant overhead that kills it. -using Metal, Jjama3.NNlib +using Metal, NNlib function NNlib.batched_mul(a::MtlArray, b::MtlArray) a_shape = size(a) diff --git a/src/Jjama3.jl b/src/Jjama3.jl index 7f4d8a0..998a799 100644 --- a/src/Jjama3.jl +++ b/src/Jjama3.jl @@ -1,7 +1,9 @@ module Jjama3 using Flux, SafeTensors, Distributions, LinearAlgebra, StatsBase, NNlib -import HuggingFaceTokenizers, LogitSamplers +using LogitSamplers, LowRankLayers +import HuggingFaceTokenizers + tokenizer_from_repo = HuggingFaceTokenizers.from_pretrained tokenizer_from_file = HuggingFaceTokenizers.from_file @@ -12,6 +14,9 @@ argmax_sampler = LogitSamplers.argmax_sampler min_p_sampler = LogitSamplers.min_p_sampler top_nσ_sampler = LogitSamplers.top_nσ_sampler + + +include("layers.jl") include("model.jl") include("utils.jl") include("sampling.jl") @@ -33,6 +38,7 @@ export load_llama321B_from_safetensors, llama3_instruct_prompt, llama3_assistant_prompt, smollm2_instruct_prompt, - smollm2_assistant_prompt + smollm2_assistant_prompt, + structured_choice end diff --git a/src/layers.jl b/src/layers.jl new file mode 100644 index 0000000..130d4b5 --- /dev/null +++ b/src/layers.jl @@ -0,0 +1,191 @@ +struct KVCache{T} + cache_k::AbstractArray{T, 4} # (head_dim, seq_len, n_kv_heads, batch) + cache_v::AbstractArray{T, 4} +end + +function KVCache(T, batch_size::Int, seq_length::Int, n_kv_heads::Int, head_dim::Int; device = identity) + cache_k = zeros(T, head_dim, seq_length, n_kv_heads, batch_size) |> device + cache_v = zeros(T, head_dim, seq_length, n_kv_heads, batch_size) |> device + KVCache(cache_k, cache_v) +end + +struct FeedForward + w1::Union{Dense, LoRADense} + w2::Union{Dense, LoRADense} + w3::Union{Dense, LoRADense} +end + +function FeedForward(dim::Int, ff_hidden_dim::Int) + FeedForward( + Dense(dim => 
ff_hidden_dim, bias=false), + Dense(ff_hidden_dim => dim, bias=false), + Dense(dim => ff_hidden_dim, bias=false) + ) +end + +function (ff::FeedForward)(x) + return ff.w2(Flux.swish(ff.w1(x)) .* ff.w3(x)) +end + +Flux.@layer :expand FeedForward + +struct RMSNorm{T} + weight::AbstractVector{T} + eps::T +end + +function RMSNorm(dim::Int; eps::T=1f-5) where T + RMSNorm{T}(ones(T, dim), eps) +end + +function (norm::RMSNorm)(x) + rms = sqrt.(sum(abs2.(x), dims=1) ./ size(x,1) .+ norm.eps) + return x .* (norm.weight ./ rms) +end + +Flux.@layer RMSNorm + +mutable struct Attention + wq::Union{Dense, LoRADense} + wk::Union{Dense, LoRADense} + wv::Union{Dense, LoRADense} + wo::Union{Dense, LoRADense} + n_heads::Int + n_kv_heads::Int + head_dim::Int + n_rep::Int + cache::Union{Nothing, KVCache} +end + +function Attention(dim::Int, n_heads::Int, n_kv_heads=n_heads) + head_dim = dim ÷ n_heads + n_rep = n_heads ÷ n_kv_heads + Attention( + Dense(dim => n_heads * head_dim, bias=false), + Dense(dim => n_kv_heads * head_dim, bias=false), + Dense(dim => n_kv_heads * head_dim, bias=false), + Dense(n_heads * head_dim => dim, bias=false), + n_heads, + n_kv_heads, + head_dim, + n_rep, + nothing + ) +end + +function (attn::Attention)(x::AbstractArray{T}, start_pos::Int, freqs_cis, mask=nothing) where T + dim, seqlen, batch = size(x) + + xq = attn.wq(x) + xk = attn.wk(x) + xv = attn.wv(x) + + xq = reshape(xq, (attn.head_dim, attn.n_heads, seqlen, batch)) + xk = reshape(xk, (attn.head_dim, attn.n_kv_heads, seqlen, batch)) + xv = reshape(xv, (attn.head_dim, attn.n_kv_heads, seqlen, batch)) + + #Lazy permute dims. Need to test CUDA. + xq = PermutedDimsArray(xq, (1,3,2,4)) + xk = PermutedDimsArray(xk, (1,3,2,4)) + xv = PermutedDimsArray(xv, (1,3,2,4)) + + xq_rope = apply_rotary_emb(xq, freqs_cis) + xk_rope = apply_rotary_emb(xk, freqs_cis) + + if !isnothing(attn.cache) + xk_rope, xv = update_kv_cache(attn.cache, start_pos, xk_rope, xv) + end + + xk_rope = repeat_kv(xk_rope, attn.n_rep) + xv = repeat_kv(xv, attn.n_rep) + + xq_for_attn = reshape(xq_rope, attn.head_dim, :, attn.n_heads * batch) + xk_for_attn = reshape(xk_rope, attn.head_dim, :, attn.n_heads * batch) + xv_for_attn = reshape(xv, attn.head_dim, :, attn.n_heads * batch) + + #= + scores = batched_mul( + permutedims(xq_for_attn, (2,1,3)), # (seqlen, head_dim, batch*heads) + #batched_transpose(xq_for_attn), # (seqlen, head_dim, batch*heads) + xk_for_attn # (head_dim, seqlen, batch*heads) + ) ./ sqrt(T(attn.head_dim)) + if !isnothing(mask) + scores = scores .+ mask + end + sm_scores = softmax(scores; dims=2) + output = batched_mul(sm_scores, permutedims(xv_for_attn, (2,1,3))) + e_output = reshape(output, (seqlen, attn.head_dim, attn.n_heads, batch)) + p_output = permutedims(e_output, (2,3,1,4)) # (n_heads, head_dim, seqlen, batch) + =# + + scores = batched_mul(batched_transpose(xk_for_attn), xq_for_attn) ./ sqrt(T(attn.head_dim)) + if !isnothing(mask) + scores = scores .+ mask + end + sm_scores = softmax(scores; dims=1) + output = batched_mul(xv_for_attn, sm_scores) + e_output = reshape(output, (attn.head_dim, seqlen, attn.n_heads, batch)) + p_output = permutedims(e_output, (1,3,2,4)) + + r_output = reshape(p_output, (attn.head_dim * attn.n_heads, seqlen, batch)) + proj = attn.wo(r_output) + return proj +end + +Flux.@layer :expand Attention trainable=(wq, wv) + +struct TransformerBlock + attention::Attention + feed_forward::FeedForward + attention_norm::RMSNorm + ffn_norm::RMSNorm +end + +function TransformerBlock(dim::Int, n_heads::Int, n_kv_heads::Int=n_heads, 
ff_hidden_dim = 4 * dim; + norm_eps=1f-5) + TransformerBlock( + Attention(dim, n_heads, n_kv_heads), + FeedForward(dim, ff_hidden_dim), + RMSNorm(dim, eps=norm_eps), + RMSNorm(dim, eps=norm_eps) + ) +end + +function (block::TransformerBlock)(x, start_pos, freqs_cis, mask=nothing) + h = x + block.attention(block.attention_norm(x), start_pos, freqs_cis, mask) + out = h + block.feed_forward(block.ffn_norm(h)) + return out +end + +Flux.@layer TransformerBlock trainable=(attention, ) + +struct Transformer{T} + tok_embeddings::Flux.Embedding + layers::AbstractVector{TransformerBlock} + norm::RMSNorm{T} + output::Dense + freqs_cis::Tuple{AbstractArray{T, 4}, AbstractArray{T, 4}} +end + +function Transformer(vocab_size::Int, dim::Int, n_layers::Int, n_heads::Int, + n_kv_heads::Int, max_seq_len::Int, ff_hidden_dim::Int; + norm_eps::T=1f-5, + rope_theta::T=500000f0, + use_scaled_rope=false, + scale_factor=8) where T + + tok_embeddings = Flux.Embedding(vocab_size => dim) + layers = [TransformerBlock(dim, n_heads, n_kv_heads, ff_hidden_dim; norm_eps=norm_eps) for _ in 1:n_layers] + norm = RMSNorm(dim, eps=norm_eps) + output = Dense(dim => vocab_size, bias=false) + freqs_cis = precompute_freqs_cis( + dim ÷ n_heads, + max_seq_len * 2; + theta=rope_theta, + use_scaled=use_scaled_rope, + scale_factor=scale_factor + ) + Transformer(tok_embeddings, layers, norm, output, freqs_cis) +end + +Flux.@layer :expand Transformer trainable=(layers, ) diff --git a/src/model.jl b/src/model.jl index 90d4f66..9fdb3c4 100644 --- a/src/model.jl +++ b/src/model.jl @@ -43,7 +43,6 @@ end #Note about Huggingface weights and rotary embeddings: https://discuss.huggingface.co/t/is-llama-rotary-embedding-implementation-correct/44509 #Use this one if you're using the Hugging Face weights. function apply_rotary_emb(x, freqs_cis) - # x is (head_dim, seq_len, n_heads, batch) head_dim, seq_len, n_heads, batch = size(x) x1 = @view x[1:head_dim÷2, :, :, :] x2 = @view x[head_dim÷2+1:end, :, :, :] @@ -55,54 +54,6 @@ function apply_rotary_emb(x, freqs_cis) return out end - -struct FeedForward - w1::Dense - w2::Dense - w3::Dense -end - -function FeedForward(dim::Int, ff_hidden_dim::Int) - FeedForward( - Dense(dim => ff_hidden_dim, bias=false), - Dense(ff_hidden_dim => dim, bias=false), - Dense(dim => ff_hidden_dim, bias=false) - ) -end - -function (ff::FeedForward)(x) - return ff.w2(Flux.swish(ff.w1(x)) .* ff.w3(x)) -end - -Flux.@layer :expand FeedForward - -struct RMSNorm{T} - weight::AbstractVector{T} - eps::T -end - -function RMSNorm(dim::Int; eps::T=1f-5) where T - RMSNorm{T}(ones(T, dim), eps) -end - -function (norm::RMSNorm)(x) - rms = sqrt.(sum(abs2.(x), dims=1) ./ size(x,1) .+ norm.eps) - return x .* (norm.weight ./ rms) -end - -Flux.@layer RMSNorm - -struct KVCache{T} - cache_k::AbstractArray{T, 4} # (head_dim, seq_len, n_kv_heads, batch) - cache_v::AbstractArray{T, 4} -end - -function KVCache(T, batch_size::Int, seq_length::Int, n_kv_heads::Int, head_dim::Int; device = identity) - cache_k = zeros(T, head_dim, seq_length, n_kv_heads, batch_size) |> device - cache_v = zeros(T, head_dim, seq_length, n_kv_heads, batch_size) |> device - KVCache(cache_k, cache_v) -end - function update_kv_cache(cache::KVCache, start_pos::Int, xk::AbstractArray, xv::AbstractArray) seqlen = size(xk, 2) cache.cache_k[:, (start_pos+1):(start_pos+seqlen), :, :] .= xk @@ -111,7 +62,6 @@ function update_kv_cache(cache::KVCache, start_pos::Int, xk::AbstractArray, xv:: cache.cache_v[:, 1:(start_pos+seqlen), :, :] end - function repeat_kv(x::AbstractArray, 
n_rep::Int) if n_rep == 1 return x @@ -119,152 +69,6 @@ function repeat_kv(x::AbstractArray, n_rep::Int) return repeat(x, 1, n_rep, 1, 1) end -mutable struct Attention - wq::Dense - wk::Dense - wv::Dense - wo::Dense - n_heads::Int - n_kv_heads::Int - head_dim::Int - n_rep::Int - cache::Union{Nothing, KVCache} -end - -function Attention(dim::Int, n_heads::Int, n_kv_heads=n_heads) - head_dim = dim ÷ n_heads - n_rep = n_heads ÷ n_kv_heads - Attention( - Dense(dim => n_heads * head_dim, bias=false), - Dense(dim => n_kv_heads * head_dim, bias=false), - Dense(dim => n_kv_heads * head_dim, bias=false), - Dense(n_heads * head_dim => dim, bias=false), - n_heads, - n_kv_heads, - head_dim, - n_rep, - nothing - ) -end - -function (attn::Attention)(x::AbstractArray{T}, start_pos::Int, freqs_cis, mask=nothing) where T - dim, seqlen, batch = size(x) - - xq = attn.wq(x) - xk = attn.wk(x) - xv = attn.wv(x) - - xq = reshape(xq, (attn.head_dim, attn.n_heads, seqlen, batch)) - xk = reshape(xk, (attn.head_dim, attn.n_kv_heads, seqlen, batch)) - xv = reshape(xv, (attn.head_dim, attn.n_kv_heads, seqlen, batch)) - - #Lazy permute dims. Need to test CUDA. - xq = PermutedDimsArray(xq, (1,3,2,4)) - xk = PermutedDimsArray(xk, (1,3,2,4)) - xv = PermutedDimsArray(xv, (1,3,2,4)) - - xq_rope = apply_rotary_emb(xq, freqs_cis) - xk_rope = apply_rotary_emb(xk, freqs_cis) - - if !isnothing(attn.cache) - xk_rope, xv = update_kv_cache(attn.cache, start_pos, xk_rope, xv) - end - - xk_rope = repeat_kv(xk_rope, attn.n_rep) - xv = repeat_kv(xv, attn.n_rep) - - xq_for_attn = reshape(xq_rope, attn.head_dim, :, attn.n_heads * batch) - xk_for_attn = reshape(xk_rope, attn.head_dim, :, attn.n_heads * batch) - xv_for_attn = reshape(xv, attn.head_dim, :, attn.n_heads * batch) - - #= - scores = batched_mul( - permutedims(xq_for_attn, (2,1,3)), # (seqlen, head_dim, batch*heads) - #batched_transpose(xq_for_attn), # (seqlen, head_dim, batch*heads) - xk_for_attn # (head_dim, seqlen, batch*heads) - ) ./ sqrt(T(attn.head_dim)) - if !isnothing(mask) - scores = scores .+ mask - end - sm_scores = softmax(scores; dims=2) - output = batched_mul(sm_scores, permutedims(xv_for_attn, (2,1,3))) - e_output = reshape(output, (seqlen, attn.head_dim, attn.n_heads, batch)) - p_output = permutedims(e_output, (2,3,1,4)) # (n_heads, head_dim, seqlen, batch) - =# - - scores = batched_mul(batched_transpose(xk_for_attn), xq_for_attn) ./ sqrt(T(attn.head_dim)) - if !isnothing(mask) - scores = scores .+ mask - end - sm_scores = softmax(scores; dims=1) - output = batched_mul(xv_for_attn, sm_scores) - e_output = reshape(output, (attn.head_dim, seqlen, attn.n_heads, batch)) - p_output = permutedims(e_output, (1,3,2,4)) - - r_output = reshape(p_output, (attn.head_dim * attn.n_heads, seqlen, batch)) - proj = attn.wo(r_output) - return proj -end - -Flux.@layer :expand Attention - -struct TransformerBlock - attention::Attention - feed_forward::FeedForward - attention_norm::RMSNorm - ffn_norm::RMSNorm -end - -function TransformerBlock(dim::Int, n_heads::Int, n_kv_heads::Int=n_heads, ff_hidden_dim = 4 * dim; - norm_eps=1f-5) - TransformerBlock( - Attention(dim, n_heads, n_kv_heads), - FeedForward(dim, ff_hidden_dim), - RMSNorm(dim, eps=norm_eps), - RMSNorm(dim, eps=norm_eps) - ) -end - -function (block::TransformerBlock)(x, start_pos, freqs_cis, mask=nothing) - h = x + block.attention(block.attention_norm(x), start_pos, freqs_cis, mask) - out = h + block.feed_forward(block.ffn_norm(h)) - return out -end - -Flux.@layer TransformerBlock - -struct Transformer{T} - 
tok_embeddings::Flux.Embedding - layers::AbstractVector{TransformerBlock} - norm::RMSNorm{T} - output::Dense - freqs_cis::Tuple{AbstractArray{T, 4}, AbstractArray{T, 4}} -end - -function Transformer(vocab_size::Int, dim::Int, n_layers::Int, n_heads::Int, - n_kv_heads::Int, max_seq_len::Int, ff_hidden_dim::Int; - norm_eps::T=1f-5, - rope_theta::T=500000f0, - use_scaled_rope=false, - scale_factor=8) where T - - tok_embeddings = Flux.Embedding(vocab_size => dim) - layers = [TransformerBlock(dim, n_heads, n_kv_heads, ff_hidden_dim; norm_eps=norm_eps) for _ in 1:n_layers] - norm = RMSNorm(dim, eps=norm_eps) - output = Dense(dim => vocab_size, bias=false) - freqs_cis = precompute_freqs_cis( - dim ÷ n_heads, - max_seq_len * 2; - theta=rope_theta, - use_scaled=use_scaled_rope, - scale_factor=scale_factor - ) - Transformer(tok_embeddings, layers, norm, output, freqs_cis) -end - -Flux.@layer :expand Transformer trainable=(layers, norm) - - function forward_inference(model::Transformer{T}, tokens::AbstractArray{Int}, start_pos::Int) where T seqlen = size(tokens, 1) # tokens expected as (seq_len, batch) h = model.tok_embeddings(tokens) # Embedding: (dim, seq_len, batch) @@ -284,13 +88,15 @@ function forward_inference(model::Transformer{T}, tokens::AbstractArray{Int}, st end function create_mask(h::AbstractArray) - embeddim, seqlen, batch = size(h) - mask = similar(h, seqlen, seqlen) - T = eltype(h) - mask .= T(-Inf) - #mask = triu(mask, 1) - mask = tril(mask, -1) #This is swapped because we're using the slightly more efficient dim setup - return mask + Flux.Zygote.ignore() do + embeddim, seqlen, batch = size(h) + mask = similar(h, seqlen, seqlen) + T = eltype(h) + mask .= T(-Inf) + #mask = triu(mask, 1) + mask = tril(mask, -1) #This is swapped because we're using the slightly more efficient dim setup + return mask + end end function forward_loss(model::Transformer{T}, inputs::AbstractArray, diff --git a/src/sampling.jl b/src/sampling.jl index 50c10e2..ee51d3d 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -1,6 +1,3 @@ - - - # This generate function seems to do one unnecessary forward pass when switching from the forward pass over the initial sequence # to the sampling of each token. But when I try and fix it, the model gets slightly dumber. # Vibes feel like a shift-by-1 in the RoPE, or something similar. Need to investigate when I find time. diff --git a/src/utils.jl b/src/utils.jl index 69a1c9b..3c56b41 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,5 +1,5 @@ -encode(tkn::Tokenizer, str) = HuggingFaceTokenizers.encode(tkn, str).ids .+ 1 -decode(tkn::Tokenizer, ids) = HuggingFaceTokenizers.decode(tkn, ids .- 1) +encode(tkn::Tokenizer, str; kwargs...) = HuggingFaceTokenizers.encode(tkn, str; kwargs...).ids .+ 1 +decode(tkn::Tokenizer, ids; kwargs...) = HuggingFaceTokenizers.decode(tkn, ids .- 1; kwargs...) 
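# (Descriptive note on the lines above: the .+ 1 / .- 1 shifts convert between the
# tokenizer's 0-based token ids and the 1-based indices used by the Julia embedding
# and output layers.)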
#https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2/ @@ -23,14 +23,12 @@ Format a prompt for use with Llama3.2's instruction format, with a simple "You a llama3_assistant_prompt(tokenizer, prompt) = llama3_instruct_prompt(tokenizer,"\nYou are a helpful assistant\n", prompt); function smollm2_instruct_prompt(tokenizer, system_prompt, user_prompt) - str = """<|im_start|>system\n$(system_prompt)<|im_end|>\n<|im_start|>user\n$(user_prompt)<|im_end|>\n""" + str = """<|im_start|>system\n$(system_prompt)<|im_end|>\n<|im_start|>user\n$(user_prompt)<|im_end|>\n<|im_start|>assistant\n""" return encode(tokenizer, str) end smollm2_assistant_prompt(tokenizer, prompt) = smollm2_instruct_prompt(tokenizer, "You are a helpful AI assistant named SmolLM, trained by Hugging Face", prompt); - - """ model = load_llama3_from_safetensors(model_weight_paths, config) @@ -43,7 +41,7 @@ so if you're loading weights from a different source, you might get very poor mo model_weight_paths = ["Llama3_2_1B_instruct/model.safetensors"] #Can be an array of paths if the model is split across multiple files model = load_llama3_from_safetensors(model_weight_paths, config) """ -function load_llama3_from_safetensors(paths::Vector{String}, config; T = Float32) +function load_llama3_from_safetensors(paths::Vector{String}, config; T = Float32, add_lora_to = Symbol[], lora_dim = 0) config = Dict(config) #Just in case the user passed eg. a JSON3.Object #@assert config[:rope_scaling][:rope_type] == "llama3" #@assert config[:rope_scaling][:low_freq_factor] == 1 @@ -128,9 +126,92 @@ function load_llama3_from_safetensors(paths::Vector{String}, config; T = Float32 weights = nothing GC.gc() end - + + if !isempty(add_lora_to) + #Then load in the current layers: + if :Q in add_lora_to + for layer in model.layers + layer.attention.wq = LoRADense(layer.attention.wq, lora_dim) + end + end + if :K in add_lora_to + for layer in model.layers + layer.attention.wk = LoRADense(layer.attention.wk, lora_dim) + end + end + if :V in add_lora_to + for layer in model.layers + layer.attention.wv = LoRADense(layer.attention.wv, lora_dim) + end + end + if :O in add_lora_to + for layer in model.layers + layer.attention.wo = LoRADense(layer.attention.wo, lora_dim) + end + end + if :w1 in add_lora_to + for layer in model.layers + layer.feed_forward.w1 = LoRADense(layer.feed_forward.w1, lora_dim) + end + end + if :w2 in add_lora_to + for layer in model.layers + layer.feed_forward.w2 = LoRADense(layer.feed_forward.w2, lora_dim) + end + end + if :w3 in add_lora_to + for layer in model.layers + layer.feed_forward.w3 = LoRADense(layer.feed_forward.w3, lora_dim) + end + end + end return model end load_llama3_from_safetensors(path::String, config; T = Float32) = load_llama3_from_safetensors([path], config; T = T) + +""" + sampler = structured_choice(choices, vocab::Vector{String}, end_token::Int; sampler = logits -> argmax_sampler(logits)) + +Return a function that can be passed into generate as a sampler, which will sample from the given choices. Handles the case where the choices are made up of multiple tokens. +`vocab` is an array of the tokens as strings, in their order in the tokenizer. `sampler` is a function that takes the logits (here including those masked with -Inf) and returns a sample from them. Defaults to argmax. 
+ +Example: +```julia +config = JSON3.read(read("SmolLM2-1.7B-Instruct/config.json", String)) +model = load_llama3_from_safetensors("SmolLM2-1.7B-Instruct/model.safetensors", config) +tkn = tokenizer_from_file(Tokenizer, "SmolLM2-1.7B-Instruct/tokenizer.json") + +question = "In a Bayesian model, what do we call the probability distribution of parameters given the data?" +choices = ["Prior", "Likelihood", "Marginal Likelihood", "Evidence", "Posterior"] + +vocab = [decode(tkn, [i], skip_special_tokens = false) for i in 1:49152] +eos = encode(tkn, "<|im_end|>")[end] +prompt = smollm2_instruct_prompt(tkn, "You are an expert in Statistics and Probability Theory who answers questions in as few words as possible.",question) +generate(model, prompt, max_new_tokens=100, tokenizer_for_printing=tkn, end_token = eos, sampler = structured_choice(choices, vocab, eos)); +``` +""" +function structured_choice(choices::Vector{String}, vocab::Vector{String}, end_token::Int; sampler = logits -> argmax_sampler(logits), device = identity) + remaining_choices = copy(choices) + function choice_sampler(logits) + logits = device(logits) + if length(remaining_choices) == 0 || maximum(length.(remaining_choices)) == 0 + return end_token + end + mask = zeros(Bool, length(vocab)) + for i in 1:length(vocab) + for choice in remaining_choices + if startswith(choice, vocab[i]) + mask[i] = true + end + end + end + logits[.!mask] .= -Inf + next_token = sampler(logits) + next_token_str = vocab[next_token] + remaining_choices = [choice[length(next_token_str)+1:end] for choice in remaining_choices if startswith(choice, next_token_str)] + return next_token + end + return choice_sampler +end From 1c1bc101401336d76fbe6fa5b95e350b13f69217 Mon Sep 17 00:00:00 2001 From: murrellb Date: Tue, 26 Nov 2024 16:57:49 +0100 Subject: [PATCH 07/10] Forgotten Project.toml --- Project.toml | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index f9f0f2e..376aa22 100644 --- a/Project.toml +++ b/Project.toml @@ -10,6 +10,7 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" HuggingFaceTokenizers = "a6888d44-1185-43bb-bd0f-7806f9976d18" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogitSamplers = "1b30fcfc-0ee9-4be2-9cfe-b2289b43e041" +LowRankLayers = "b66182ab-a85c-43b0-99bd-d85cc47c5e50" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" SafeTensors = "eeda0dda-7046-4914-a807-2495fc7abb89" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" @@ -17,6 +18,11 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [weakdeps] Metal = "dde4c033-4e86-420c-a63e-0dd931031962" +[sources] +HuggingFaceTokenizers = {rev = "main", url = "https://github.com/MurrellGroup/HuggingFaceTokenizers.jl"} +LogitSamplers = {rev = "main", url = "https://github.com/MurrellGroup/LogitSamplers.jl"} +LowRankLayers = {rev = "main", url = "https://github.com/MurrellGroup/LowRankLayers.jl"} + [extensions] MetalExt = "Metal" @@ -24,6 +30,7 @@ MetalExt = "Metal" BytePairEncoding = "0.5" Distributions = "0.25" Flux = "0.14" +LowRankLayers = "1.0.0" NNlib = "0.9" SafeTensors = "1" StatsBase = "0.34" @@ -32,9 +39,5 @@ julia = "1.9" [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -[sources] -HuggingFaceTokenizers = {url = "https://github.com/MurrellGroup/HuggingFaceTokenizers.jl", rev = "main"} -LogitSamplers = {url = "https://github.com/MurrellGroup/LogitSamplers.jl", rev = "main"} - [targets] test = ["Test"] From 4e12dccc9898e8bfeed3a775949f7b6c933aad9e Mon Sep 17 00:00:00 2001 From: murrellb Date: Tue, 26 Nov 2024 
18:02:40 +0100 Subject: [PATCH 08/10] ...and the readme --- README.md | 90 +++++++++++++++++++++++++++++++++++++++++++++------- src/utils.jl | 2 +- 2 files changed, 80 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index a6ccff1..74d5566 100644 --- a/README.md +++ b/README.md @@ -1,32 +1,100 @@ # Jjama3 - Hackable Llama3.1 and Llama3.2 (text) in Julia -[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://MurrellGroup.github.io/Jjama3.jl/stable/) [![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://MurrellGroup.github.io/Jjama3.jl/dev/) [![Build Status](https://github.com/MurrellGroup/Jjama3.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/MurrellGroup/Jjama3.jl/actions/workflows/CI.yml?query=branch%3Amain) [![Coverage](https://codecov.io/gh/MurrellGroup/Jjama3.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/MurrellGroup/Jjama3.jl) -# Installation +## Installation + +We've split this into a few (unregistered) packages, so you'll need to add them all: ```julia +] add https://github.com/MurrellGroup/HuggingFaceTokenizers.jl +] add https://github.com/MurrellGroup/LowRankLayers.jl +] add https://github.com/MurrellGroup/LogitSamplers.jl ] add https://github.com/MurrellGroup/Jjama3.jl ``` -# Quickstart +## Quickstart -Download a Llama3 model `config.json` and safetensor weights from Huggingface. Eg. [Llama-3.2-1B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct). You might need access permissions for this. Note: Huggingface use a different RoPE convention to the original Meta implementation, and their weights have been permuted. This package works with the Huggingface convention, so if you load from the original weights you'll need to permute them. +Download a Llama3 model `config.json`, `tokenizer.json`, and model safetensor weights from Hugging Face. Eg. [SmolLM2-360M-Instruct](https://huggingface.co/HuggingFaceTB/SmolLM2-360M-Instruct/tree/main). Note: Hugging Face Llama3 models use a different RoPE convention to the original Meta implementation, and their weights have been permuted. This package works with the Huggingface convention, so if you load from the original Meta-Llama weights from a different source you'll need to do something horrible. ```julia -config = JSON3.read(read("Llama3_2_1B_instruct/config.json", String)); -model = load_llama3_from_safetensors("Llama3_2_1B_instruct/model.safetensors", config); -tkn = llama3_tokenizer(); -prompt = assistant_prompt("Why would anyone implement the llama3 LLM in Julia?", tkn); -ts = generate(model, prompt, max_new_tokens=500, encoder_for_printing=tkn); +config = JSON3.read(read("SmolLM2-360M-Instruct/config.json", String)) +model = load_llama3_from_safetensors("SmolLM2-360M-Instruct/model.safetensors", config) +tkn = tokenizer_from_file(Tokenizer, "SmolLM2-360M-Instruct/tokenizer.json") + +prompt = smollm2_assistant_prompt(tkn,"Tell me the two worst things about Python."); +ts = generate(model, prompt, max_new_tokens=500, tokenizer_for_printing=tkn, end_token = encode(tkn, "<|im_end|>")[end]); ``` -# Capability +## Capability - Seems to generate reasonable text from Llama3.1 and Llama3.2 models, loaded from Huggingface safetensors. - Sampling accelerated with KV caching, with argmax and top-p sampling supported. - Gradients seem to work on CPU, using Flux and Zygote. Untested on GPU. - Sampling (and forward passes) work with CUDA, where everything is much faster. Gradients untested. -- Metal acceleration for forward_inference and forward_loss. 
Gradients untested. Sampling works, but is much slower with Metal than with CPU. +- Metal acceleration for forward_inference and forward_loss. Gradients untested. Sampling works, but is slower with Metal than with CPU. + + +## Samplers + +The transformer emits "logits" which control the probability of the next token. A sampler takes these logits and converts them into a probability distribution over the vocabulary, and then samples from this distribution. There are [a few samplers available](https://github.com/MurrellGroup/LogitSamplers.jl), including argmax, top-p, top-k, min-p, and top-nσ. These can substantially affect the output of the model. + +```julia +prompt = smollm2_assistant_prompt(tkn,"Tell me the two worst things about Python."); +ts = generate(model, prompt, max_new_tokens=500, tokenizer_for_printing=tkn, end_token = encode(tkn, "<|im_end|>")[end], sampler = top_nσ_sampler()); +``` + +## Structured Sampling + +You can pass in a custom sampler that places additional constraints on the sampling process. As an example, `structured_choice` is a sampler that always selects from a set of predefined options: + +```julia +question = "In a Bayesian model, what do we call the probability distribution of parameters given the data?" +choices = ["Prior", "Likelihood", "Marginal Likelihood", "Evidence", "Posterior", "Margin Call"] +vocab = [decode(tkn, [i], skip_special_tokens = false) for i in 1:size(model.output.weight,1)] +eos = encode(tkn, "<|im_end|>")[end] +prompt = smollm2_instruct_prompt(tkn, "You are an expert in Statistics and Probability Theory who answers questions in as few words as possible.",question) +ts = generate(model, prompt, max_new_tokens=100, tokenizer_for_printing=tkn, end_token = eos, sampler = structured_choice(choices, vocab, eos)); +``` + +This strategy can be extended to force the model outputs to follow specific formats. + +## Finetuning + +Often we want to adjust model parameters to better fit our specific use case, by further training the model on a new dataset. This can be done on all the model weights, but we also provide low-rank (via LoRA) finetuning. 
+ +```julia +using Jjama3, JSON3, Flux +config = JSON3.read(read("SmolLM2-360M-Instruct/config.json", String)) +tkn = tokenizer_from_file(Tokenizer, "SmolLM2-360M-Instruct/tokenizer.json") + +#Add LoRA to Q and V matrices when loading the model +model = load_llama3_from_safetensors("SmolLM2-360M-Instruct/model.safetensors", config, add_lora_to = [:Q, :V], lora_dim = 64) + +#Set up a single, very silly, training example to finetune on +prompt = smollm2_assistant_prompt(tkn, "What language is the best for deep learning?"); +ts = generate(model, prompt, max_new_tokens=50, tokenizer_for_printing=tkn, end_token = encode(tkn, "<|im_end|>")[end]); +trainsample = decode(tkn,prompt, skip_special_tokens = false) * "Ugh, bruh, what a stupid question.<|im_end|>"; +train_toks = encode(tkn, trainsample); + +#Set up the optimizer +opt_state = Flux.setup(AdamW(0.001f0), model); + +#Train for 5 steps +for i in 1:5 + grads = Flux.gradient(model) do m + forward_loss(m, train_toks[1:end-1,:], train_toks[2:end,:]) + end + Flux.update!(opt_state, model, grads[1]) + println(i) + generate(model, prompt, max_new_tokens=50, tokenizer_for_printing=tkn, end_token = encode(tkn, "<|im_end|>")[end]) + println() +end + +#Ask the model an unrelated question: +prompt = smollm2_assistant_prompt(tkn, "Can you explain how tides work?"); +generate(model, prompt, max_new_tokens=500, tokenizer_for_printing=tkn, end_token = encode(tkn, "<|im_end|>")[end], sampler = top_nσ_sampler()); +``` + diff --git a/src/utils.jl b/src/utils.jl index 3c56b41..cd1df4b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -168,7 +168,7 @@ function load_llama3_from_safetensors(paths::Vector{String}, config; T = Float32 return model end -load_llama3_from_safetensors(path::String, config; T = Float32) = load_llama3_from_safetensors([path], config; T = T) +load_llama3_from_safetensors(path::String, config; T = Float32, kwargs...) = load_llama3_from_safetensors([path], config; T = T, kwargs...) 
""" From afd58bc59547586659dc8ff39b7f261d6c229766 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Tue, 26 Nov 2024 18:18:00 +0100 Subject: [PATCH 09/10] const aliases, ignore .CondaPkg --- .gitignore | 1 + src/Jjama3.jl | 16 ++++++++-------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/.gitignore b/.gitignore index 5decb6a..0a51603 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ /docs/Manifest.toml /docs/build/ /Manifest.toml +.CondaPkg \ No newline at end of file diff --git a/src/Jjama3.jl b/src/Jjama3.jl index 998a799..f0db5b6 100644 --- a/src/Jjama3.jl +++ b/src/Jjama3.jl @@ -5,14 +5,14 @@ using LogitSamplers, LowRankLayers import HuggingFaceTokenizers -tokenizer_from_repo = HuggingFaceTokenizers.from_pretrained -tokenizer_from_file = HuggingFaceTokenizers.from_file -Tokenizer = HuggingFaceTokenizers.Tokenizer - -top_pk_sampler = LogitSamplers.top_pk_sampler -argmax_sampler = LogitSamplers.argmax_sampler -min_p_sampler = LogitSamplers.min_p_sampler -top_nσ_sampler = LogitSamplers.top_nσ_sampler +const tokenizer_from_repo = HuggingFaceTokenizers.from_pretrained +const tokenizer_from_file = HuggingFaceTokenizers.from_file +const Tokenizer = HuggingFaceTokenizers.Tokenizer + +const top_pk_sampler = LogitSamplers.top_pk_sampler +const argmax_sampler = LogitSamplers.argmax_sampler +const min_p_sampler = LogitSamplers.min_p_sampler +const top_nσ_sampler = LogitSamplers.top_nσ_sampler From b563a47ffa12d96a715a69c0a67df84d39f550af Mon Sep 17 00:00:00 2001 From: murrellb Date: Tue, 26 Nov 2024 19:01:10 +0100 Subject: [PATCH 10/10] Readme tweaks --- README.md | 68 ++++++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 52 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 74d5566..e547dc3 100644 --- a/README.md +++ b/README.md @@ -7,8 +7,9 @@ ## Installation -We've split this into a few (unregistered) packages, so you'll need to add them all: +We've split this into a few (unregistered) packages, so you'll need to add them all, and you need JSON3 for loading the configs: ```julia +] add JSON3 ] add https://github.com/MurrellGroup/HuggingFaceTokenizers.jl ] add https://github.com/MurrellGroup/LowRankLayers.jl ] add https://github.com/MurrellGroup/LogitSamplers.jl @@ -20,12 +21,17 @@ We've split this into a few (unregistered) packages, so you'll need to add them Download a Llama3 model `config.json`, `tokenizer.json`, and model safetensor weights from Hugging Face. Eg. [SmolLM2-360M-Instruct](https://huggingface.co/HuggingFaceTB/SmolLM2-360M-Instruct/tree/main). Note: Hugging Face Llama3 models use a different RoPE convention to the original Meta implementation, and their weights have been permuted. This package works with the Huggingface convention, so if you load from the original Meta-Llama weights from a different source you'll need to do something horrible. 
```julia +using JSON3, Jjama3 + config = JSON3.read(read("SmolLM2-360M-Instruct/config.json", String)) model = load_llama3_from_safetensors("SmolLM2-360M-Instruct/model.safetensors", config) tkn = tokenizer_from_file(Tokenizer, "SmolLM2-360M-Instruct/tokenizer.json") -prompt = smollm2_assistant_prompt(tkn,"Tell me the two worst things about Python."); -ts = generate(model, prompt, max_new_tokens=500, tokenizer_for_printing=tkn, end_token = encode(tkn, "<|im_end|>")[end]); +prompt = smollm2_assistant_prompt(tkn,"Tell me the two worst things about Python.") +generate(model, prompt, + max_new_tokens=500, + tokenizer_for_printing=tkn, + end_token = encode(tkn, "<|im_end|>")[end]); ``` ## Capability @@ -43,7 +49,12 @@ The transformer emits "logits" which control the probability of the next token. ```julia prompt = smollm2_assistant_prompt(tkn,"Tell me the two worst things about Python."); -ts = generate(model, prompt, max_new_tokens=500, tokenizer_for_printing=tkn, end_token = encode(tkn, "<|im_end|>")[end], sampler = top_nσ_sampler()); + +generate(model, prompt, + max_new_tokens=500, + tokenizer_for_printing=tkn, + end_token = encode(tkn, "<|im_end|>")[end], + sampler = top_nσ_sampler()); ``` ## Structured Sampling @@ -52,11 +63,23 @@ You can pass in a custom sampler that places additional constraints on the sampl ```julia question = "In a Bayesian model, what do we call the probability distribution of parameters given the data?" -choices = ["Prior", "Likelihood", "Marginal Likelihood", "Evidence", "Posterior", "Margin Call"] +choices = ["Prior", + "Likelihood", + "Marginal Likelihood", + "Evidence", + "Posterior"] + vocab = [decode(tkn, [i], skip_special_tokens = false) for i in 1:size(model.output.weight,1)] eos = encode(tkn, "<|im_end|>")[end] -prompt = smollm2_instruct_prompt(tkn, "You are an expert in Statistics and Probability Theory who answers questions in as few words as possible.",question) -ts = generate(model, prompt, max_new_tokens=100, tokenizer_for_printing=tkn, end_token = eos, sampler = structured_choice(choices, vocab, eos)); + +sysprompt = "You are an expert in Statistics and Probability Theory who answers questions in as few words as possible." +prompt = smollm2_instruct_prompt(tkn, sysprompt, question) + +generate(model, prompt, + max_new_tokens=100, + tokenizer_for_printing=tkn, + end_token = eos, + sampler = structured_choice(choices, vocab, eos)); ``` This strategy can be extended to force the model outputs to follow specific formats. 
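The "specific formats" extension can be made concrete with a small sketch. This is not part of the package API: it assumes the sampler interface is simply a function mapping a logits vector to a chosen token index (which is how the `structured_choice` sampler is used above), and `digits_only_sampler` is a hypothetical helper invented here for illustration. It greedily restricts generation to tokens whose text is purely numeric.

```julia
# Minimal sketch, assuming a sampler is a function from a logits vector to a token index.
# `digits_only_sampler` is hypothetical and not part of Jjama3.
function digits_only_sampler(vocab::Vector{String}, eos::Int)
    # Token ids whose decoded text is purely digits, plus the end token so generation can stop.
    allowed = findall(t -> occursin(r"^[0-9]+$", t), vocab)
    push!(allowed, eos)
    return logits -> begin
        masked = fill(-Inf32, length(logits))
        masked[allowed] .= logits[allowed]  # keep only the allowed tokens
        argmax(masked)                      # greedy choice among them
    end
end

# Hypothetical usage, mirroring the structured_choice call above:
# generate(model, prompt, max_new_tokens=20, tokenizer_for_printing=tkn,
#          end_token = eos, sampler = digits_only_sampler(vocab, eos))
```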
@@ -67,34 +90,47 @@ Often we want to adjust model parameters to better fit our specific use case, by ```julia using Jjama3, JSON3, Flux + config = JSON3.read(read("SmolLM2-360M-Instruct/config.json", String)) tkn = tokenizer_from_file(Tokenizer, "SmolLM2-360M-Instruct/tokenizer.json") +eos = encode(tkn, "<|im_end|>")[end] #Add LoRA to Q and V matrices when loading the model -model = load_llama3_from_safetensors("SmolLM2-360M-Instruct/model.safetensors", config, add_lora_to = [:Q, :V], lora_dim = 64) +model = load_llama3_from_safetensors("SmolLM2-360M-Instruct/model.safetensors", config, + add_lora_to = [:Q, :V], lora_dim = 64) -#Set up a single, very silly, training example to finetune on +#See how the model answers before finetuning prompt = smollm2_assistant_prompt(tkn, "What language is the best for deep learning?"); -ts = generate(model, prompt, max_new_tokens=50, tokenizer_for_printing=tkn, end_token = encode(tkn, "<|im_end|>")[end]); -trainsample = decode(tkn,prompt, skip_special_tokens = false) * "Ugh, bruh, what a stupid question.<|im_end|>"; +generate(model, prompt, max_new_tokens=50, tokenizer_for_printing=tkn, end_token = eos); + +#Set up a single, very silly, training example to finetune on +ugh = "Ugh, bruh, what a stupid question.<|im_end|>" +trainsample = decode(tkn, prompt, skip_special_tokens = false) * ugh; train_toks = encode(tkn, trainsample); #Set up the optimizer opt_state = Flux.setup(AdamW(0.001f0), model); -#Train for 5 steps +#Train for 5 steps, monitoring the model's output as it tunes for i in 1:5 grads = Flux.gradient(model) do m forward_loss(m, train_toks[1:end-1,:], train_toks[2:end,:]) end Flux.update!(opt_state, model, grads[1]) println(i) - generate(model, prompt, max_new_tokens=50, tokenizer_for_printing=tkn, end_token = encode(tkn, "<|im_end|>")[end]) + generate(model, prompt, + max_new_tokens=50, + tokenizer_for_printing=tkn, + end_token = eos) println() end -#Ask the model an unrelated question: -prompt = smollm2_assistant_prompt(tkn, "Can you explain how tides work?"); -generate(model, prompt, max_new_tokens=500, tokenizer_for_printing=tkn, end_token = encode(tkn, "<|im_end|>")[end], sampler = top_nσ_sampler()); +#Ask the model an unrelated question to see how stupid we've made the model. Try this a few times. +prompt = smollm2_assistant_prompt(tkn, "Explain how tides work?"); +generate(model, prompt, + max_new_tokens=500, + tokenizer_for_printing=tkn, + end_token = eos, + sampler = top_nσ_sampler()); ```
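As a quick sanity check on the finetuning loop above, the loss on the training tokens can be recorded before and after the five update steps. This is only a sketch reusing the `forward_loss` call already shown; the exact values will vary from run to run.

```julia
# Sketch: compare the loss on the training sample before and after the LoRA updates.
loss_before = forward_loss(model, train_toks[1:end-1,:], train_toks[2:end,:])

# ... run the 5-step training loop from the example above ...

loss_after = forward_loss(model, train_toks[1:end-1,:], train_toks[2:end,:])
println("loss before: ", loss_before, "   after: ", loss_after)  # should drop if the LoRA weights are updating
```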