From 0f6433b269ffc90001ee060c555bd94e0d4da6a9 Mon Sep 17 00:00:00 2001
From: murrellb
Date: Wed, 13 Nov 2024 13:22:44 +0100
Subject: [PATCH] First commit. CPU working. GPU not tested.

---
 Project.toml    |   7 +
 README.md       |  10 ++
 src/Jjama3.jl   |   8 +-
 src/model.jl    | 415 ++++++++++++++++++++++++++++++++++++++++++++++++
 src/sampling.jl |  63 ++++++++
 src/utils.jl    | 133 ++++++++++++++++
 6 files changed, 635 insertions(+), 1 deletion(-)
 create mode 100644 src/model.jl
 create mode 100644 src/sampling.jl
 create mode 100644 src/utils.jl

diff --git a/Project.toml b/Project.toml
index f52f7c4..b833e42 100644
--- a/Project.toml
+++ b/Project.toml
@@ -3,6 +3,13 @@ uuid = "1285d783-1a6d-4703-8f05-8ac83ef55592"
 authors = ["murrellb and contributors"]
 version = "1.0.0-DEV"
 
+[deps]
+BytePairEncoding = "a4280ba5-8788-555a-8ca8-4a8c3d966a71"
+Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
+Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+SafeTensors = "eeda0dda-7046-4914-a807-2495fc7abb89"
+
 [compat]
 julia = "1.9"
 
diff --git a/README.md b/README.md
index 353de07..42d6998 100644
--- a/README.md
+++ b/README.md
@@ -4,3 +4,13 @@
 [![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://MurrellGroup.github.io/Jjama3.jl/dev/)
 [![Build Status](https://github.com/MurrellGroup/Jjama3.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/MurrellGroup/Jjama3.jl/actions/workflows/CI.yml?query=branch%3Amain)
 [![Coverage](https://codecov.io/gh/MurrellGroup/Jjama3.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/MurrellGroup/Jjama3.jl)
+
+# Quickstart
+
+```julia
+using Jjama3, JSON3
+config = JSON3.read(read("Llama3_2_1B_instruct/config.json", String));
+model = load_llama3_from_safetensors("Llama3_2_1B_instruct/model.safetensors", config);
+tkn = llama3_tokenizer();
+prompt = assistant_prompt("Why would anyone implement the llama3 LLM in Julia?", tkn);
+ts = generate(model, prompt, max_new_tokens=500, encoder_for_printing=tkn);
+```
\ No newline at end of file
diff --git a/src/Jjama3.jl b/src/Jjama3.jl
index 3fe67c4..a59ac18 100644
--- a/src/Jjama3.jl
+++ b/src/Jjama3.jl
@@ -1,5 +1,11 @@
 module Jjama3
 
-# Write your package code here.
+using Flux, BytePairEncoding, SafeTensors, Distributions, LinearAlgebra + +include("model.jl") +include("utils.jl") +include("sampling.jl") + +export load_llama321B_from_safetensors, load_llama3_from_safetensors, llama3_tokenizer, assistant_prompt, format_llama32_instruction_prompt, generate, forward_loss, forward_inference end diff --git a/src/model.jl b/src/model.jl new file mode 100644 index 0000000..ca2657f --- /dev/null +++ b/src/model.jl @@ -0,0 +1,415 @@ +#Important about output layer being tied to embedding: https://github.com/meta-llama/llama-models/issues/172 + +function apply_scaling(freqs::AbstractVector; scale_factor=8) + low_freq_factor = 1 + high_freq_factor = 4 + old_context_len = 8192 # original llama3 length + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + new_freqs = similar(freqs) + for (i, freq) in enumerate(freqs) + wavelen = 2 * π / freq + if wavelen < high_freq_wavelen + new_freqs[i] = freq + elseif wavelen > low_freq_wavelen + new_freqs[i] = freq / scale_factor + else + @assert low_freq_wavelen != high_freq_wavelen + smooth = (old_context_len / wavelen - low_freq_factor) / + (high_freq_factor - low_freq_factor) + new_freqs[i] = (1 - smooth) * freq / scale_factor + smooth * freq + end + end + return new_freqs +end + +function precompute_freqs_cis(dim::Int, end_pos::Int; + theta::T=10000f0, use_scaled=true, scale_factor=8) where T + freqs = 1f0 ./ (theta .^ (T.(0:2:dim-1)[1:dim÷2] ./ dim)) + if use_scaled + freqs = apply_scaling(freqs; scale_factor=scale_factor) + end + freqs_complex = cis.(T.(0:end_pos-1) * freqs') + cos = permutedims(real(freqs_complex), (2, 1)) # (head_dim/2, seq_len) + sin = permutedims(imag(freqs_complex), (2, 1)) + cos = reshape(cos, (dim÷2, end_pos, 1, 1)) + sin = reshape(sin, (dim÷2, end_pos, 1, 1)) + return cos, sin +end + + +#https://discuss.huggingface.co/t/is-llama-rotary-embedding-implementation-correct/44509 +#Use this one if you're using the Hugging Face weights. 
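+# A minimal shape sketch of how precompute_freqs_cis feeds apply_rotary_emb
+# (assumed sizes: head_dim = 64, 2048 RoPE positions, seq_len = 16, n_heads = 8, batch = 1):
+#   cos, sin = precompute_freqs_cis(64, 2048)   # each of size (32, 2048, 1, 1)
+#   x = randn(Float32, 64, 16, 8, 1)            # (head_dim, seq_len, n_heads, batch)
+#   y = apply_rotary_emb(x, (cos, sin))         # same size as x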
+function apply_rotary_emb(x, freqs_cis) + # x is (head_dim, seq_len, n_heads, batch) + head_dim, seq_len, n_heads, batch = size(x) + x1 = x[1:head_dim÷2, :, :, :] + x2 = x[head_dim÷2+1:end, :, :, :] + cos, sin = freqs_cis + out = vcat( + x1 .* cos[:,1:seq_len,:] .- x2 .* sin[:,1:seq_len,:], + x2 .* cos[:,1:seq_len,:] .+ x1 .* sin[:,1:seq_len,:] + ) + return out +end + + +struct FeedForward + w1::Dense + w2::Dense + w3::Dense +end + +function FeedForward(dim::Int, ff_hidden_dim::Int) + FeedForward( + Dense(dim => ff_hidden_dim, bias=false), + Dense(ff_hidden_dim => dim, bias=false), + Dense(dim => ff_hidden_dim, bias=false) + ) +end + +function (ff::FeedForward)(x) + return ff.w2(Flux.swish(ff.w1(x)) .* ff.w3(x)) +end + +Flux.@layer :expand FeedForward + +struct RMSNorm{T} + weight::AbstractVector{T} + eps::T +end + +function RMSNorm(dim::Int; eps::T=1f-5) where T + RMSNorm{T}(ones(T, dim), eps) +end + +function (norm::RMSNorm)(x) + rms = sqrt.(sum(abs2.(x), dims=1) ./ size(x,1) .+ norm.eps) + return x .* (norm.weight ./ rms) +end + +Flux.@layer RMSNorm + +struct KVCache{T} + cache_k::AbstractArray{T, 4} # (head_dim, seq_len, n_kv_heads, batch) + cache_v::AbstractArray{T, 4} +end + +function KVCache(T, batch_size::Int, seq_length::Int, n_kv_heads::Int, head_dim::Int) + cache_k = zeros(T, head_dim, seq_length, n_kv_heads, batch_size) + cache_v = zeros(T, head_dim, seq_length, n_kv_heads, batch_size) + KVCache(cache_k, cache_v) +end + +function update_kv_cache(cache::KVCache, start_pos::Int, xk::AbstractArray, xv::AbstractArray) + seqlen = size(xk, 2) + cache.cache_k[:, (start_pos+1):(start_pos+seqlen), :, :] = xk + cache.cache_v[:, (start_pos+1):(start_pos+seqlen), :, :] = xv + return cache.cache_k[:, 1:(start_pos+seqlen), :, :], + cache.cache_v[:, 1:(start_pos+seqlen), :, :] +end + +function repeat_kv(x::AbstractArray, n_rep::Int) + # x is (head_dim, seq_len, n_kv_heads, batch) + # output should be (head_dim, seq_len, n_rep * n_kv_heads, batch) + if n_rep == 1 + return x + end + head_dim, seq_len, n_kv_heads, batch = size(x) + x_expanded = reshape(x, (head_dim, seq_len, 1, n_kv_heads, batch)) + x_repeated = repeat(x_expanded, 1, 1, n_rep, 1, 1) + #Metal.@allowscalar x_repeated = repeat(x_expanded, 1, 1, n_rep, 1, 1) + return reshape(x_repeated, (head_dim, seq_len, n_rep * n_kv_heads, batch)) +end + +mutable struct Attention + wq::Dense + wk::Dense + wv::Dense + wo::Dense + n_heads::Int + n_kv_heads::Int + head_dim::Int + n_rep::Int + cache::Union{Nothing, KVCache} +end + +function Attention(dim::Int, n_heads::Int, n_kv_heads=n_heads) + head_dim = dim ÷ n_heads + n_rep = n_heads ÷ n_kv_heads + Attention( + Dense(dim => n_heads * head_dim, bias=false), + Dense(dim => n_kv_heads * head_dim, bias=false), + Dense(dim => n_kv_heads * head_dim, bias=false), + Dense(n_heads * head_dim => dim, bias=false), + n_heads, + n_kv_heads, + head_dim, + n_rep, + nothing + ) +end + +function (attn::Attention)(x::AbstractArray{T}, start_pos::Int, freqs_cis, mask=nothing) where T + dim, seqlen, batch = size(x) + + # Project Q,K,V + xq = attn.wq(x) + xk = attn.wk(x) + xv = attn.wv(x) + + #Reshaping dim: 8, len: 3, batch: 2 + # to: 2, 3, 4, 2 + # RoPE input needs (head_dim, len, n_heads, batch) + + # Reshape to separate heads + xq = reshape(xq, (attn.head_dim, attn.n_heads, seqlen, batch)) + xk = reshape(xk, (attn.head_dim, attn.n_kv_heads, seqlen, batch)) + xv = reshape(xv, (attn.head_dim, attn.n_kv_heads, seqlen, batch)) + + #Can we switch to keeping this in its shape, aplpying rot emb in that shape, and only 
permuting one + #xq = PermutedDimsArray(xq, (1,3,2,4)) #No idea if this is faster... + #xk = PermutedDimsArray(xk, (1,3,2,4)) + #xv = PermutedDimsArray(xv, (1,3,2,4)) + xq = permutedims(xq, (1,3,2,4)) + xk = permutedims(xk, (1,3,2,4)) + xv = permutedims(xv, (1,3,2,4)) + + # Apply RoPE + xq_rope = apply_rotary_emb(xq, freqs_cis) + xk_rope = apply_rotary_emb(xk, freqs_cis) + # Handle KV cache + if !isnothing(attn.cache) + xk_rope, xv = update_kv_cache(attn.cache, start_pos, xk_rope, xv) + end + # Apply GQA via repeat_kv + xk_rope = repeat_kv(xk_rope, attn.n_rep) + xv = repeat_kv(xv, attn.n_rep) + # Reshape for attention - dummy dim is seqlength, which isn't the length of the seq when using the KV cache + xq_for_attn = reshape(xq_rope, attn.head_dim, :, attn.n_heads * batch) + xk_for_attn = reshape(xk_rope, attn.head_dim, :, attn.n_heads * batch) + xv_for_attn = reshape(xv, attn.head_dim, :, attn.n_heads * batch) + # Compute attention scores + + scores = batched_mul( + 0f0 .+ permutedims(xq_for_attn, (2,1,3)), # (seqlen, head_dim, batch*heads) + 0f0 .+xk_for_attn # (head_dim, seqlen, batch*heads) + ) ./ sqrt(T(attn.head_dim)) + if !isnothing(mask) + #@show typeof(scores) + #@show typeof(mask) + scores = scores .+ mask + end + #len: 3, len: 3, headsxbatch: 8 + sm_scores = softmax(scores; dims=2) # Need to get this over dim 1 for efficiency! + #len: 3, head_dim: 2, headsxbatch: 8 + output = batched_mul(sm_scores, permutedims(xv_for_attn, (2,1,3))) + # Reshape back to separate batch and heads + e_output = reshape(output, (seqlen, attn.head_dim, attn.n_heads, batch)) + p_output = permutedims(e_output, (2,3,1,4)) # (n_heads, head_dim, seqlen, batch) + r_output = reshape(p_output, (attn.head_dim * attn.n_heads, seqlen, batch)) + proj = attn.wo(r_output) + return proj +end + +Flux.@layer :expand Attention + +struct TransformerBlock + attention::Attention + feed_forward::FeedForward + attention_norm::RMSNorm + ffn_norm::RMSNorm +end + +function TransformerBlock(dim::Int, n_heads::Int, n_kv_heads::Int=n_heads, ff_hidden_dim = 4 * dim; + norm_eps=1f-5) + TransformerBlock( + Attention(dim, n_heads, n_kv_heads), + FeedForward(dim, ff_hidden_dim), + RMSNorm(dim, eps=norm_eps), + RMSNorm(dim, eps=norm_eps) + ) +end + +function (block::TransformerBlock)(x, start_pos, freqs_cis, mask=nothing) + h = x + block.attention(block.attention_norm(x), start_pos, freqs_cis, mask) + out = h + block.feed_forward(block.ffn_norm(h)) + return out +end + +Flux.@layer TransformerBlock + +struct Transformer{T} + tok_embeddings::Flux.Embedding + layers::AbstractVector{TransformerBlock} + norm::RMSNorm{T} + output::Dense + freqs_cis::Tuple{AbstractArray{T, 4}, AbstractArray{T, 4}} +end + +function Transformer(vocab_size::Int, dim::Int, n_layers::Int, n_heads::Int, + n_kv_heads::Int, max_seq_len::Int, ff_hidden_dim::Int; + norm_eps::T=1f-5, + rope_theta::T=500000f0, + use_scaled_rope=false, + scale_factor=8) where T + + tok_embeddings = Flux.Embedding(vocab_size => dim) + layers = [TransformerBlock(dim, n_heads, n_kv_heads, ff_hidden_dim; norm_eps=norm_eps) for _ in 1:n_layers] + norm = RMSNorm(dim, eps=norm_eps) + output = Dense(dim => vocab_size, bias=false) + freqs_cis = precompute_freqs_cis( + dim ÷ n_heads, + max_seq_len * 2; + theta=rope_theta, + use_scaled=use_scaled_rope, + scale_factor=scale_factor + ) + Transformer(tok_embeddings, layers, norm, output, freqs_cis) +end + +Flux.@layer :expand Transformer trainable=(layers, norm) + + +function forward_inference(model::Transformer{T}, tokens::AbstractArray{Int}, 
start_pos::Int) where T + seqlen = size(tokens, 1) # tokens expected as (seq_len, batch) + h = model.tok_embeddings(tokens) # Embedding: (dim, seq_len, batch) + + # Get relevant freqs_cis slice + cos, sin = model.freqs_cis #@show size(cos) #(head_dim/2, max_RoPE, 1, 1) + freqs_cis = (cos[:,start_pos+1:start_pos+seqlen,:,:], sin[:,start_pos+1:start_pos+seqlen,:,:]) + + mask = nothing + if seqlen > 1 + #mask = fill(T(-Inf), (seqlen, seqlen)) + mask = similar(h, seqlen, seqlen) + mask .= T(-Inf) + mask = triu(mask, 1) + # Add zeros for cached sequence + if start_pos > 0 + mask = hcat(zeros(T, seqlen, start_pos), mask) + end + end + for layer in model.layers + h = layer(h, start_pos, freqs_cis, mask) + end + h = model.norm(h) + output = model.output(h) + return output +end + +function forward_loss(model::Transformer{T}, inputs::AbstractArray, + targets::AbstractArray; ignore_index::Int=-100, + mask = triu(fill(T(-Inf), (size(inputs, 1), size(inputs, 1))),1)) where T + seqlen = size(inputs, 1) #(seq_len, batch) + h = model.tok_embeddings(inputs) # (dim, seq_len, batch) + + cos, sin = model.freqs_cis #@show size(cos) #(head_dim/2, max_RoPE, 1, 1) + freqs_cis = (cos[:,1:seqlen,:,:], sin[:,1:seqlen,:,:]) + + # Forward through layers (start_pos = 0 disables KV caching) + for layer in model.layers + h = layer(h, 0, freqs_cis, mask) + end + h = model.norm(h) + logits = model.output(h) + #@show size(logits) + # Need to reshape to (vocab_size, seq_len * batch) + logits_2d = reshape(logits, size(logits,1), :) + #@show [argmax(logits_2d[:,i]) for i in 1:size(logits_2d,2)] + #@show size(logits_2d) + targets_1d = reshape(targets, :) + #@show size(targets_1d) + # Mask out ignored indices - will handle this later. + # Note: this is not the autoregressive mask, but the mask for the loss function + #= + mask = targets_1d .!= ignore_index + if any(mask) + loss = Flux.logitcrossentropy( + logits_2d[:, mask], + targets_1d[mask] + ) + else + loss = zero(Float32) + end + =# + vocab_size = size(model.tok_embeddings.weight, 2) + gt = Flux.onehotbatch(targets_1d, 1:vocab_size) + #@show size(gt) + loss = Flux.logitcrossentropy(logits_2d, gt) + #@show Flux.logitcrossentropy(logits_2d, gt, agg = identity) + return loss +end + + + + + + + +#https://discuss.huggingface.co/t/is-llama-rotary-embedding-implementation-correct/44509 +#= +#Use this one if you're using the original Meta weights. +#You'll need to change the type of the freqs_cis field in Transformer to match. +function precompute_freqs_cis(dim::Int, end_pos::Int; + theta::Float32=10000f0, + use_scaled::Bool=true, scale_factor::Int=8) + # Create frequencies for the first half of dimensions + freqs = 1f0 ./ (theta .^ (Float32.(0:2:dim-1)[1:dim÷2] ./ dim)) + # Create position indices - note, using 0 indexing here because python consistency. Not sure if it makes a difference. + t = Float32.(0:end_pos-1) + if use_scaled + freqs = apply_scaling(freqs; scale_factor=scale_factor) + end + # Compute outer product + freqs = t * freqs' + # Convert to complex exponentials + # Note: Julia's cis(x) = exp(ix) = cos(x) + i*sin(x) + freqs_complex = cis.(freqs) + # Stack real and imaginary parts + # Note: Julia's reshape is similar to PyTorch's stack + freqs_cis_real = reshape( + reinterpret(Float32, reshape(freqs_complex, :)), + (2, size(freqs)...) 
+ ) + # Permute to match PyTorch's dimension ordering + return permutedims(freqs_cis_real, (2,3,1)) +end + +function apply_rotary_emb(x, freqs_cis) + # x is (head_dim, seq_len, n_heads, batch) in Julia + # freqs_cis is (seq_len, head_dim/2, 2) + + #@show size(freqs_cis) + + # Reshape x to separate real/imaginary pairs + head_dim, seq_len, n_heads, batch = size(x) + x_reshaped = reshape(x, (2, head_dim÷2, seq_len, n_heads, batch)) + + # Reshape freqs_cis to broadcast correctly + # Note: reshape to (2, head_dim/2, seq_len, 1, 1) for broadcasting + freqs_cis = permutedims(freqs_cis, (3, 2, 1)) # now (2, head_dim/2, seq_len) + freqs_cis = reshape(freqs_cis, (2, size(freqs_cis, 2), size(freqs_cis, 3), 1, 1)) + + # Apply rotation using complex multiplication formula: + # (a + bi)(c + di) = (ac-bd) + (ad+bc)i + x_real = x_reshaped[1:1, :, :, :, :] + x_imag = x_reshaped[2:2, :, :, :, :] + f_real = freqs_cis[1:1, :, :, :, :] + f_imag = freqs_cis[2:2, :, :, :, :] + + #@show size(f_real) + #@show size(f_imag) + + #This is for checking the freqs_cis. + #Note: the cos, sin values are repeated in python + #g(f_real, f_imag) #passes + + out_real = x_real .* f_real .- x_imag .* f_imag + out_imag = x_imag .* f_real .+ x_real .* f_imag + + # Combine and reshape back + out = vcat(out_real, out_imag) + return reshape(out, (head_dim, seq_len, n_heads, batch)) +end +=# diff --git a/src/sampling.jl b/src/sampling.jl new file mode 100644 index 0000000..48971e0 --- /dev/null +++ b/src/sampling.jl @@ -0,0 +1,63 @@ + +function default_sampler(logits::AbstractVector) + return argmax(logits) +end + +function generate(model::Transformer{T}, + initial_tokens::AbstractArray{IntT}; + max_new_tokens=100, + sampler::Function=default_sampler, + encoder_for_printing = nothing, + end_token = 128010) where {T, IntT} + + # Initialize sequence with a new copy of the tokens + current_len = length(initial_tokens) + tokens = Vector{IntT}(undef, current_len + max_new_tokens) + tokens[1:current_len] = initial_tokens + # Set up KV caches for all attention layers + for layer in model.layers + layer.attention.cache = KVCache( + T, # eltype + 1, # batch_size + current_len + max_new_tokens, # max possible sequence length + layer.attention.n_kv_heads, + layer.attention.head_dim + ) + end + # Process the initial sequence + if current_len > 0 + input_tokens = reshape(initial_tokens, :, 1) # (seq_len, batch=1) + logits = forward_inference(model, input_tokens, 0) + start_pos = current_len + else + start_pos = 0 + end + # Generate new tokens one at a time + for _ in 1:max_new_tokens + # If sequence is empty or we want to process just the last token + if start_pos == 0 + input_tokens = reshape([128001], :, 1) # Use start of text token if empty + else + input_tokens = reshape([tokens[current_len]], :, 1) # Just the last token + end + # Get logits for next token + logits = forward_inference(model, input_tokens, start_pos) + # Sample next token (logits are size vocab × 1 × 1) + next_token = sampler(vec(logits[:, end, 1])) + current_len += 1 + tokens[current_len] = next_token + if !isnothing(encoder_for_printing) + print(encoder_for_printing.decode([next_token])) + end + if next_token == end_token + break + end + start_pos += 1 + end + # Clear KV caches + for layer in model.layers + layer.attention.cache = nothing + end + return tokens[1:current_len] +end + diff --git a/src/utils.jl b/src/utils.jl new file mode 100644 index 0000000..f93046f --- /dev/null +++ b/src/utils.jl @@ -0,0 +1,133 @@ +llama3_tokenizer() = 
BytePairEncoding.load_tiktoken_encoder("cl100k_base") + +""" +Format a prompt for use with Llama3.2's instruction format, with a simple "You are a helpful assistant" system prompt. + + tkn = llama3_tokenizer() + prompt = assistant_prompt("What is the capital of France?", tkn) + generate(model, prompt, max_new_tokens=100, encoder_for_printing=tkn) +""" +assistant_prompt(prompt, tkn) = format_llama32_instruction_prompt("\nYou are a helpful assistant\n", prompt, tkn); + + +""" +Format a prompt for use with Llama3.2's instruction format, injecting the system and user roles. + + tkn = llama3_tokenizer() + prompt = format_llama32_instruction_prompt("\\nYou are a helpful assistant\\n", "What is the capital of France?", tkn) + generate(model, prompt, max_new_tokens=100, encoder_for_printing=tkn) +""" +function format_llama32_instruction_prompt(sys_prompt, user_prompt, tokenizer) + #begin_of_text, start_header_id + prompt = [128001, 128007] #plus 1 because Julia is 1-indexed + prompt = vcat(prompt, tokenizer.encode("system")) + push!(prompt, 128008) #end_header_id + prompt = vcat(prompt, tokenizer.encode(sys_prompt)) + prompt = vcat(prompt, [128010, 128007]) #eot_id, start_header_id + prompt = vcat(prompt, tokenizer.encode("user")) + push!(prompt, 128008) #end_header_id + prompt = vcat(prompt, tokenizer.encode("\n")) + prompt = vcat(prompt, tokenizer.encode(user_prompt)) + prompt = vcat(prompt, [128009, 128007]) #eot_id, start_header_id + prompt = vcat(prompt, tokenizer.encode("assistant")) + push!(prompt, 128008) #end_header_id + return prompt +end + + +""" +Load a Llama3 model from a set of Huggingface safetensors files, and the config.json file. +Important note: Huggingface uses a different RoPE convention than other implementations, +so if you're loading weights from a different source, you might get very poor model performance. + + using JSON3 + config = JSON3.read(read("Llama3_2_1B_instruct/config.json", String)) + model_weight_paths = ["Llama3_2_1B_instruct/model.safetensors"] #Can be an array of paths if the model is split across multiple files + model = load_llama3_from_safetensors(model_weight_paths, config) +""" +function load_llama3_from_safetensors(paths::Vector{String}, config; T = Float32) + config = Dict(config) #Just in case the user passed eg. 
a JSON3.Object
+    @assert config[:rope_scaling][:rope_type] == "llama3"
+    @assert config[:rope_scaling][:low_freq_factor] == 1
+    @assert config[:rope_scaling][:high_freq_factor] == 4
+    @assert config[:rope_scaling][:original_max_position_embeddings] == 8192
+
+    # Create model with config parameters from the JSON
+    model = Transformer(
+        config[:vocab_size],              # vocab_size
+        config[:hidden_size],             # dim (hidden_size)
+        config[:num_hidden_layers],       # n_layers (num_hidden_layers)
+        config[:num_attention_heads],     # n_heads (num_attention_heads)
+        config[:num_key_value_heads],     # n_kv_heads (num_key_value_heads)
+        config[:max_position_embeddings], # max_seq_len (max_position_embeddings)
+        config[:intermediate_size],       # ff_hidden_dim
+        norm_eps=T(config[:rms_norm_eps]),   # rms_norm_eps
+        rope_theta=T(config[:rope_theta]),   # rope_theta
+        use_scaled_rope=true,                # Using scaled RoPE based on the config
+        scale_factor=config[:rope_scaling][:factor] # scale_factor
+    )
+
+    for path in paths # Process one file at a time
+        weights = load_safetensors(path)
+
+        if haskey(weights, "model.embed_tokens.weight")
+            # HF stores embed_tokens.weight as (vocab, dim); Flux.Embedding wants (dim, vocab), hence the transpose.
+            model.tok_embeddings.weight .= weights["model.embed_tokens.weight"]'
+            if config[:tie_word_embeddings]
+                # The output Dense weight is (vocab, dim), matching the HF layout, so no transpose here.
+                model.output.weight .= weights["model.embed_tokens.weight"]
+            end
+        end
+        if !config[:tie_word_embeddings]
+            if haskey(weights, "lm_head.weight")
+                model.output.weight .= weights["lm_head.weight"]
+            else
+                error("tie_word_embeddings was false, but lm_head.weight was not present.")
+            end
+        end
+        if haskey(weights, "model.norm.weight")
+            model.norm.weight .= weights["model.norm.weight"]
+        end
+
+        n_layers = length(model.layers)
+        for i in 0:(n_layers-1)
+            prefix = "model.layers.$i"
+            layer = model.layers[i+1]
+
+            if haskey(weights, "$prefix.self_attn.q_proj.weight")
+                layer.attention.wq.weight .= weights["$prefix.self_attn.q_proj.weight"]
+            end
+            if haskey(weights, "$prefix.self_attn.k_proj.weight")
+                layer.attention.wk.weight .= weights["$prefix.self_attn.k_proj.weight"]
+            end
+            if haskey(weights, "$prefix.self_attn.v_proj.weight")
+                layer.attention.wv.weight .= weights["$prefix.self_attn.v_proj.weight"]
+            end
+            if haskey(weights, "$prefix.self_attn.o_proj.weight")
+                layer.attention.wo.weight .= weights["$prefix.self_attn.o_proj.weight"]
+            end
+
+            if haskey(weights, "$prefix.mlp.gate_proj.weight")
+                layer.feed_forward.w1.weight .= weights["$prefix.mlp.gate_proj.weight"]
+            end
+            if haskey(weights, "$prefix.mlp.down_proj.weight")
+                layer.feed_forward.w2.weight .= weights["$prefix.mlp.down_proj.weight"]
+            end
+            if haskey(weights, "$prefix.mlp.up_proj.weight")
+                layer.feed_forward.w3.weight .= weights["$prefix.mlp.up_proj.weight"]
+            end
+
+            if haskey(weights, "$prefix.input_layernorm.weight")
+                layer.attention_norm.weight .= weights["$prefix.input_layernorm.weight"]
+            end
+            if haskey(weights, "$prefix.post_attention_layernorm.weight")
+                layer.ffn_norm.weight .= weights["$prefix.post_attention_layernorm.weight"]
+            end
+        end
+
+        weights = nothing
+        GC.gc()
+    end
+
+    return model
+end
+
+load_llama3_from_safetensors(path::String, config; T = Float32) = load_llama3_from_safetensors([path], config; T = T)
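
The `sampler` keyword of `generate` accepts any function that maps a vector of logits to a token index, so the greedy `default_sampler` can be swapped for stochastic decoding. Below is a minimal sketch of that hook; `top_k_sampler` and its `k`/`temperature` keywords are illustrative names (not exports of this package), and `model`, `prompt`, and `tkn` are assumed to be set up as in the README quickstart.

```julia
using Flux, Distributions  # both are already dependencies of Jjama3

# Keep the k largest logits, rescale by a temperature, and draw the next token
# from the resulting categorical distribution over those k candidates.
function top_k_sampler(logits::AbstractVector; k::Int=40, temperature::Real=0.7)
    idx = partialsortperm(logits, 1:k, rev=true)           # indices of the top-k logits
    p = Float64.(Flux.softmax(logits[idx] ./ temperature))
    p ./= sum(p)                                           # guard against round-off
    return idx[rand(Categorical(p))]                       # map back to a vocabulary id
end

ts = generate(model, prompt,
              max_new_tokens=500,
              sampler=l -> top_k_sampler(l; k=40, temperature=0.7),
              encoder_for_printing=tkn);
```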