From d51a1d79d6e3fe13ec3f9e2852d73a9884e59e81 Mon Sep 17 00:00:00 2001 From: billera Date: Fri, 13 Dec 2024 12:12:38 +0100 Subject: [PATCH] fix --- src/layers.jl | 1 - src/rope.jl | 23 ++--------------------- 2 files changed, 2 insertions(+), 22 deletions(-) diff --git a/src/layers.jl b/src/layers.jl index 3f93126..6598829 100644 --- a/src/layers.jl +++ b/src/layers.jl @@ -57,7 +57,6 @@ IPA_settings( pairwise = c_z > 0, use_softmax1, scaling_qk, - rope, ) diff --git a/src/rope.jl b/src/rope.jl index 9e94303..b371147 100644 --- a/src/rope.jl +++ b/src/rope.jl @@ -1,3 +1,5 @@ +# Taken from https://github.com/MurrellGroup/Jjama3.jl/blob/main/src/model.jl +# TODO: rewrite to avoid the O(N) permutedims. struct RoPE{A<:AbstractArray} cos::A sin::A @@ -86,11 +88,8 @@ end # Apply fixed rotation to queries only function (rope::FixedRoPE)(x) head_dim = size(x, 1) - # Split queries into two parts x1 = x[1:head_dim÷2, :, :, :] x2 = x[head_dim÷2+1:end, :, :, :] - - # Rotate queries rotx = vcat( x1 .* rope.cos .- x2 .* rope.sin, x2 .* rope.cos .+ x1 .* rope.sin @@ -98,24 +97,6 @@ function (rope::FixedRoPE)(x) return rotx end - - -# Apply fixed un-rotation to queries only -function unapply_fixed_rope(rope::FixedRoPE, rotated_queries) - head_dim = size(rotated_queries, 1) - # Split rotated queries into two parts - q1 = rotated_queries[1:head_dim÷2, :, :, :] - q2 = rotated_queries[head_dim÷2+1:end, :, :, :] - - # Un-rotate queries (inverse rotation) - unrotated_queries = vcat( - q1 .* rope.cos .+ q2 .* rope.sin, # Note: + instead of - - q2 .* rope.cos .- q1 .* rope.sin # Note: - instead of + - ) - - return unrotated_queries -end - struct IPARoPE rope::RoPE fixed_rope::FixedRoPE