Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
billera committed Dec 13, 2024
1 parent b172b21 commit d51a1d7
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 22 deletions.
1 change: 0 additions & 1 deletion src/layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ IPA_settings(
pairwise = c_z > 0,
use_softmax1,
scaling_qk,
rope,
)


Expand Down
23 changes: 2 additions & 21 deletions src/rope.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# taking from https://github.com/MurrellGroup/Jjama3.jl/blob/main/src/model.jl
# todo: re-write to avoid the O(N) permutedims.
struct RoPE{A<:AbstractArray}
cos::A
sin::A
Expand Down Expand Up @@ -86,36 +88,15 @@ end
# Apply fixed rotation to queries only
function (rope::FixedRoPE)(x)
head_dim = size(x, 1)
# Split queries into two parts
x1 = x[1:head_dim÷2, :, :, :]
x2 = x[head_dim÷2+1:end, :, :, :]

# Rotate queries
rotx = vcat(
x1 .* rope.cos .- x2 .* rope.sin,
x2 .* rope.cos .+ x1 .* rope.sin
)
return rotx
end



# Apply fixed un-rotation to queries only
function unapply_fixed_rope(rope::FixedRoPE, rotated_queries)
head_dim = size(rotated_queries, 1)
# Split rotated queries into two parts
q1 = rotated_queries[1:head_dim÷2, :, :, :]
q2 = rotated_queries[head_dim÷2+1:end, :, :, :]

# Un-rotate queries (inverse rotation)
unrotated_queries = vcat(
q1 .* rope.cos .+ q2 .* rope.sin, # Note: + instead of -
q2 .* rope.cos .- q1 .* rope.sin # Note: - instead of +
)

return unrotated_queries
end

struct IPARoPE
rope::RoPE
fixed_rope::FixedRoPE
Expand Down

0 comments on commit d51a1d7

Please sign in to comment.