Commit

Qwen2!
murrellb committed Nov 29, 2024
1 parent 806f833 commit 3609b3f
Showing 2 changed files with 24 additions and 7 deletions.
15 changes: 8 additions & 7 deletions src/layers.jl
@@ -57,13 +57,13 @@ mutable struct Attention
cache::Union{Nothing, KVCache}
end

-function Attention(dim::Int, n_heads::Int, n_kv_heads=n_heads)
+function Attention(dim::Int, n_heads::Int, n_kv_heads=n_heads; qkv_bias=false)
head_dim = dim ÷ n_heads
n_rep = n_heads ÷ n_kv_heads
Attention(
-Dense(dim => n_heads * head_dim, bias=false),
-Dense(dim => n_kv_heads * head_dim, bias=false),
-Dense(dim => n_kv_heads * head_dim, bias=false),
+Dense(dim => n_heads * head_dim, bias=qkv_bias),
+Dense(dim => n_kv_heads * head_dim, bias=qkv_bias),
+Dense(dim => n_kv_heads * head_dim, bias=qkv_bias),
Dense(n_heads * head_dim => dim, bias=false),
n_heads,
n_kv_heads,
@@ -146,9 +146,9 @@ struct TransformerBlock
end

function TransformerBlock(dim::Int, n_heads::Int, n_kv_heads::Int=n_heads, ff_hidden_dim = 4 * dim;
-norm_eps=1f-5)
+norm_eps=1f-5, qkv_bias=false)
TransformerBlock(
-Attention(dim, n_heads, n_kv_heads),
+Attention(dim, n_heads, n_kv_heads; qkv_bias),
FeedForward(dim, ff_hidden_dim),
RMSNorm(dim, eps=norm_eps),
RMSNorm(dim, eps=norm_eps)
@@ -174,12 +174,13 @@ end
function Transformer(vocab_size::Int, dim::Int, n_layers::Int, n_heads::Int,
n_kv_heads::Int, max_seq_len::Int, ff_hidden_dim::Int;
norm_eps::T=1f-5,
+qkv_bias=false,
rope_theta::T=500000f0,
use_scaled_rope=false,
scale_factor=8) where T

tok_embeddings = Flux.Embedding(vocab_size => dim)
-layers = [TransformerBlock(dim, n_heads, n_kv_heads, ff_hidden_dim; norm_eps=norm_eps) for _ in 1:n_layers]
+layers = [TransformerBlock(dim, n_heads, n_kv_heads, ff_hidden_dim; norm_eps=norm_eps, qkv_bias=qkv_bias) for _ in 1:n_layers]
norm = RMSNorm(dim, eps=norm_eps)
output = Dense(dim => vocab_size, bias=false)
freqs_cis = precompute_freqs_cis(
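The new qkv_bias keyword defaults to false, so existing Llama-style configurations are unchanged; passing qkv_bias=true gives the query, key, and value projections a learned bias, which is what Qwen2 checkpoints expect, while the output projection wo stays bias-free in both cases. A minimal sketch of the difference, assuming Flux is loaded and the Attention constructor above is in scope (the sizes below are made up for illustration):

using Flux

dim, n_heads, n_kv_heads = 64, 8, 4   # hypothetical sizes, not from any real config

attn_llama = Attention(dim, n_heads, n_kv_heads)                 # qkv_bias defaults to false
attn_qwen2 = Attention(dim, n_heads, n_kv_heads; qkv_bias=true)

attn_llama.wq.bias   # false: Flux's Dense stores false when the bias is disabled
attn_qwen2.wq.bias   # zero-initialised Vector{Float32} of length n_heads * (dim ÷ n_heads)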
16 changes: 16 additions & 0 deletions src/utils.jl
@@ -55,6 +55,11 @@ function load_llama3_from_safetensors(paths::Vector{String}, config; T = Float32
scale_factor = config[:rope_scaling][:factor]
end
end
+if config[:model_type] == "qwen2"
+qkv_bias = true
+else
+qkv_bias = false
+end
model = Transformer(
config[:vocab_size], # vocab_size
config[:hidden_size], # dim (hidden_size)
@@ -63,6 +68,7 @@ function load_llama3_from_safetensors(paths::Vector{String}, config; T = Float32
config[:num_key_value_heads], # n_kv_heads (num_key_value_heads)
config[:max_position_embeddings], # max_seq_len (max_position_embeddings)
config[:intermediate_size], # ff_hidden_dim
+qkv_bias = qkv_bias, # qkv_bias
norm_eps=T(config[:rms_norm_eps]), # rms_norm_eps
rope_theta=T(config[:rope_theta]), # rope_theta
use_scaled_rope=true, # Using scaled RoPE based on the config
@@ -104,6 +110,16 @@ function load_llama3_from_safetensors(paths::Vector{String}, config; T = Float32
if haskey(weights, "$prefix.self_attn.o_proj.weight")
layer.attention.wo.weight .= weights["$prefix.self_attn.o_proj.weight"]
end

+if haskey(weights, "$prefix.self_attn.q_proj.bias")
+layer.attention.wq.bias .= weights["$prefix.self_attn.q_proj.bias"]
+end
+if haskey(weights, "$prefix.self_attn.k_proj.bias")
+layer.attention.wk.bias .= weights["$prefix.self_attn.k_proj.bias"]
+end
+if haskey(weights, "$prefix.self_attn.v_proj.bias")
+layer.attention.wv.bias .= weights["$prefix.self_attn.v_proj.bias"]
+end

if haskey(weights, "$prefix.mlp.gate_proj.weight")
layer.feed_forward.w1.weight .= weights["$prefix.mlp.gate_proj.weight"]
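With these additions the same loader covers Qwen2 checkpoints: when config[:model_type] == "qwen2" it constructs the Transformer with qkv_bias = true and copies the q_proj/k_proj/v_proj biases from the safetensors weights. A rough usage sketch, assuming the module defining load_llama3_from_safetensors is loaded; the JSON3-based config parsing and the local paths are assumptions for illustration, not part of this commit:

using JSON3

# Hypothetical local copy of a Qwen2 checkpoint; adjust the paths to your download.
config = JSON3.read(read("Qwen2.5-0.5B-Instruct/config.json", String))
paths  = ["Qwen2.5-0.5B-Instruct/model.safetensors"]

# config[:model_type] == "qwen2" switches on qkv_bias, so the attention biases
# are loaded alongside the projection weights.
model = load_llama3_from_safetensors(paths, config)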
