
Commit

Various fixes
murrellb committed Dec 26, 2024
1 parent f60900f commit e2bac27
Showing 5 changed files with 22 additions and 21 deletions.
10 changes: 5 additions & 5 deletions Project.toml
@@ -4,7 +4,7 @@ authors = ["murrellb <[email protected]> and contributors"]
version = "1.1.0-DEV"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
HuggingFaceTokenizers = "a6888d44-1185-43bb-bd0f-7806f9976d18"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -23,19 +23,19 @@ HuggingFaceTokenizers = {rev = "main", url = "https://github.com/MurrellGroup/Hu
MetalExt = "Metal"

[compat]
Accessors = "0.1.38"
ChainRulesCore = "1.25.0"
Flux = "0.14, 0.15"
LogitSamplers = "0.1"
LowRankLayers = "0.1"
LogitSamplers = "0.1.0"
LowRankLayers = "0.1.2"
Metal = "1"
NNlib = "0.9"
SafeTensors = "1"
julia = "1.11"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "Downloads", "JSON3"]
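
A note on the tightened [compat] entries above, which follows standard Julia Pkg semantics rather than anything specific to this commit: for 0.x packages the minor version is treated as breaking, so the new pins read roughly as

LogitSamplers = "0.1.0"   # allows [0.1.0, 0.2.0), the same range "0.1" already allowed
LowRankLayers = "0.1.2"   # allows [0.1.2, 0.2.0), raising the lower bound to the 0.1.2 patch
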
3 changes: 3 additions & 0 deletions src/Jjama3.jl
@@ -6,6 +6,7 @@ using LinearAlgebra
using NNlib
using LogitSamplers
using LowRankLayers
#using ChainRulesCore

using HuggingFaceTokenizers: HuggingFaceTokenizers, Tokenizer

@@ -25,6 +26,8 @@ export Transformer
export unrope
export rerope_cache!

#include("sdpa.jl")

include("model.jl")
export forward_loss
export forward_inference
12 changes: 6 additions & 6 deletions src/layers.jl
@@ -101,16 +101,16 @@ function unrope(rope, x)
)
end

struct Attention{Q<:AnyDense,K<:AnyDense,V<:AnyDense,O<:AnyDense,C<:KVCache}
wq::Q
wk::K
wv::V
wo::O
mutable struct Attention
wq::AnyDense
wk::AnyDense
wv::AnyDense
wo::AnyDense
dim::Int
n_heads::Int
n_kv_heads::Int
head_dim::Int
cache::C
cache::KVCache
end

Flux.@layer Attention trainable=(wq,wv)
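
For context on the layers.jl change above: making Attention mutable lets its weight fields be reassigned in place, which is what the LoRA-wrapping code in src/utils.jl below now relies on instead of Accessors.@reset. A minimal toy sketch of the difference, illustrative only and not code from the repository:

mutable struct ToyAttention
    wq
end

toy = ToyAttention(identity)
toy.wq = x -> 2x   # fields of a mutable struct can be rebound in place
# With the previous immutable, parametric struct, the same update required
# rebuilding the value, e.g. Accessors.@reset toy.wq = x -> 2x, which constructs
# a new instance and rebinds the variable.

The trade-off is that the fields are now annotated with AnyDense and KVCache rather than per-instance type parameters, giving up some compile-time specialization in exchange for being able to swap weights after construction.
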
2 changes: 1 addition & 1 deletion src/model.jl
@@ -91,4 +91,4 @@ end

# compat
forward_inference(model, args...) = model(args...)
forward_loss(model::Transformer, inputs::AbstractArray, targets::AbstractArray; clear_cache = true, loss_mask = nothing) = loss(forward(model, inputs; clear_cache = clear_cache), targets; loss_mask = loss_mask)
forward_loss(model::Transformer, inputs::AbstractArray, targets::AbstractArray; clear_cache = true, loss_mask = nothing) = loss(model(inputs, clear_cache = clear_cache), targets, loss_mask = loss_mask)

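The model.jl change reroutes the forward_loss compat shim through the Transformer's call overload instead of the old forward helper. A minimal sketch of what the shim now expands to, assuming model, inputs, and targets are set up as in the package's own tests:

logits = model(inputs; clear_cache = true)        # the call forward_loss now makes internally
l = loss(logits, targets; loss_mask = nothing)    # same result as forward_loss(model, inputs, targets)
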
16 changes: 7 additions & 9 deletions src/utils.jl
@@ -1,5 +1,3 @@
using Accessors

encode(tkn::Tokenizer, str; kwargs...) = HuggingFaceTokenizers.encode(tkn, str; kwargs...).ids .+ 1
decode(tkn::Tokenizer, ids; kwargs...) = HuggingFaceTokenizers.decode(tkn, ids .- 1; kwargs...)
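
For reference, the unchanged encode and decode helpers above shift between HuggingFace's 0-based token ids and Julia's 1-based indexing. A hypothetical round trip, assuming tkn is a Tokenizer loaded through HuggingFaceTokenizers:

ids = encode(tkn, "Hello world")   # tokenizer ids shifted up by 1 for Julia indexing
text = decode(tkn, ids)            # shifted back down by 1 before decoding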

@@ -162,37 +160,37 @@ function load_llama3_from_safetensors(
if !isempty(add_lora_to)
if :Q in add_lora_to
for layer in model.layers
@reset layer.attention.wq = LoRADense(layer.attention.wq, lora_dim)
layer.attention.wq = LoRADense(layer.attention.wq, lora_dim)

end
end
if :K in add_lora_to
for layer in model.layers
@reset layer.attention.wk = LoRADense(layer.attention.wk, lora_dim)
layer.attention.wk = LoRADense(layer.attention.wk, lora_dim)

end
end
if :V in add_lora_to
for layer in model.layers
@reset layer.attention.wv = LoRADense(layer.attention.wv, lora_dim)
layer.attention.wv = LoRADense(layer.attention.wv, lora_dim)

end
end
if :O in add_lora_to
for layer in model.layers
@reset layer.attention.wo = LoRADense(layer.attention.wo, lora_dim)
layer.attention.wo = LoRADense(layer.attention.wo, lora_dim)

end
end
if :w1 in add_lora_to
for layer in model.layers
@reset layer.feed_forward.w1 = LoRADense(layer.feed_forward.w1, lora_dim)
layer.feed_forward.w1 = LoRADense(layer.feed_forward.w1, lora_dim)

end
end
if :w2 in add_lora_to
for layer in model.layers
@reset layer.feed_forward.w2 = LoRADense(layer.feed_forward.w2, lora_dim)
layer.feed_forward.w2 = LoRADense(layer.feed_forward.w2, lora_dim)

end
end
if :w3 in add_lora_to
for layer in model.layers
@reset layer.feed_forward.w3 = LoRADense(layer.feed_forward.w3, lora_dim)
layer.feed_forward.w3 = LoRADense(layer.feed_forward.w3, lora_dim)

end
end
end
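
Taken together with the mutable Attention struct, attaching LoRA adapters is now plain field assignment rather than an Accessors.@reset on an immutable layer. A hand-rolled equivalent of the :Q branch above, illustrative only and assuming model is a Transformer returned by load_llama3_from_safetensors with lora_dim set to the desired adapter rank:

for layer in model.layers
    layer.attention.wq = LoRADense(layer.attention.wq, lora_dim)   # direct mutation, no @reset needed
end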
