diff --git a/Project.toml b/Project.toml
index 3e82388..2641079 100644
--- a/Project.toml
+++ b/Project.toml
@@ -4,7 +4,7 @@ authors = ["murrellb and contributors"]
 version = "1.1.0-DEV"
 
 [deps]
-Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
+ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 HuggingFaceTokenizers = "a6888d44-1185-43bb-bd0f-7806f9976d18"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -23,19 +23,19 @@ HuggingFaceTokenizers = {rev = "main", url = "https://github.com/MurrellGroup/Hu
 MetalExt = "Metal"
 
 [compat]
-Accessors = "0.1.38"
+ChainRulesCore = "1.25.0"
 Flux = "0.14, 0.15"
-LogitSamplers = "0.1"
-LowRankLayers = "0.1"
+LogitSamplers = "0.1.0"
+LowRankLayers = "0.1.2"
 Metal = "1"
 NNlib = "0.9"
 SafeTensors = "1"
 julia = "1.11"
 
 [extras]
-Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
 JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
+Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
 test = ["Test", "Downloads", "JSON3"]
diff --git a/src/Jjama3.jl b/src/Jjama3.jl
index 930cabb..af04e8d 100644
--- a/src/Jjama3.jl
+++ b/src/Jjama3.jl
@@ -6,6 +6,7 @@ using LinearAlgebra
 using NNlib
 using LogitSamplers
 using LowRankLayers
+#using ChainRulesCore
 
 using HuggingFaceTokenizers: HuggingFaceTokenizers, Tokenizer
 
@@ -25,6 +26,8 @@ export Transformer
 export unrope
 export rerope_cache!
 
+#include("sdpa.jl")
+
 include("model.jl")
 export forward_loss
 export forward_inference
diff --git a/src/layers.jl b/src/layers.jl
index 7b2ff9f..fca2f81 100644
--- a/src/layers.jl
+++ b/src/layers.jl
@@ -101,16 +101,16 @@ function unrope(rope, x)
     )
 end
 
-struct Attention{Q<:AnyDense,K<:AnyDense,V<:AnyDense,O<:AnyDense,C<:KVCache}
-    wq::Q
-    wk::K
-    wv::V
-    wo::O
+mutable struct Attention
+    wq::AnyDense
+    wk::AnyDense
+    wv::AnyDense
+    wo::AnyDense
     dim::Int
     n_heads::Int
     n_kv_heads::Int
     head_dim::Int
-    cache::C
+    cache::KVCache
 end
 
 Flux.@layer Attention trainable=(wq,wv)
diff --git a/src/model.jl b/src/model.jl
index ec8f96a..dbbebdb 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -91,4 +91,4 @@ end
 
 # compat
 forward_inference(model, args...) = model(args...)
-forward_loss(model::Transformer, inputs::AbstractArray, targets::AbstractArray; clear_cache = true, loss_mask = nothing) = loss(forward(model, inputs; clear_cache = clear_cache), targets; loss_mask = loss_mask)
+forward_loss(model::Transformer, inputs::AbstractArray, targets::AbstractArray; clear_cache = true, loss_mask = nothing) = loss(model(inputs, clear_cache = clear_cache), targets, loss_mask = loss_mask)
diff --git a/src/utils.jl b/src/utils.jl
index 04ad197..e6b76f2 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -1,5 +1,3 @@
-using Accessors
-
 encode(tkn::Tokenizer, str; kwargs...) = HuggingFaceTokenizers.encode(tkn, str; kwargs...).ids .+ 1
 decode(tkn::Tokenizer, ids; kwargs...) = HuggingFaceTokenizers.decode(tkn, ids .- 1; kwargs...)
 
@@ -162,37 +160,37 @@ function load_llama3_from_safetensors(
     if !isempty(add_lora_to)
         if :Q in add_lora_to
             for layer in model.layers
-                @reset layer.attention.wq = LoRADense(layer.attention.wq, lora_dim)
+                layer.attention.wq = LoRADense(layer.attention.wq, lora_dim)
             end
         end
         if :K in add_lora_to
             for layer in model.layers
-                @reset layer.attention.wk = LoRADense(layer.attention.wk, lora_dim)
+                layer.attention.wk = LoRADense(layer.attention.wk, lora_dim)
             end
         end
         if :V in add_lora_to
             for layer in model.layers
-                @reset layer.attention.wv = LoRADense(layer.attention.wv, lora_dim)
+                layer.attention.wv = LoRADense(layer.attention.wv, lora_dim)
             end
         end
         if :O in add_lora_to
             for layer in model.layers
-                @reset layer.attention.wo = LoRADense(layer.attention.wo, lora_dim)
+                layer.attention.wo = LoRADense(layer.attention.wo, lora_dim)
             end
         end
         if :w1 in add_lora_to
             for layer in model.layers
-                @reset layer.feed_forward.w1 = LoRADense(layer.feed_forward.w1, lora_dim)
+                layer.feed_forward.w1 = LoRADense(layer.feed_forward.w1, lora_dim)
             end
         end
         if :w2 in add_lora_to
             for layer in model.layers
-                @reset layer.feed_forward.w2 = LoRADense(layer.feed_forward.w2, lora_dim)
+                layer.feed_forward.w2 = LoRADense(layer.feed_forward.w2, lora_dim)
             end
         end
         if :w3 in add_lora_to
             for layer in model.layers
-                @reset layer.feed_forward.w3 = LoRADense(layer.feed_forward.w3, lora_dim)
+                layer.feed_forward.w3 = LoRADense(layer.feed_forward.w3, lora_dim)
             end
         end
     end
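
For reference, a minimal sketch of the pattern this diff switches to: because Attention is now a mutable struct with abstract-typed fields, LoRA wrappers can be swapped in by plain field assignment instead of Accessors.jl's @reset on an immutable struct. This assumes a Transformer `model` already loaded (e.g. via load_llama3_from_safetensors) and the LoRADense(base_layer, lora_dim) constructor from LowRankLayers shown above; the rank value is hypothetical.

# Sketch only, not part of the package: `model` and `lora_dim` are assumed.
using LowRankLayers

lora_dim = 8  # hypothetical LoRA rank

for layer in model.layers
    # Direct field mutation now works because Attention is mutable;
    # previously this required `@reset layer.attention.wq = ...`.
    layer.attention.wq = LoRADense(layer.attention.wq, lora_dim)
    layer.attention.wv = LoRADense(layer.attention.wv, lora_dim)
end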