Implement IPARoPE
billera committed Dec 13, 2024
1 parent 47ffa01 commit b172b21
Showing 3 changed files with 183 additions and 10 deletions.
2 changes: 2 additions & 0 deletions src/InvariantPointAttention.jl
@@ -6,11 +6,13 @@ using ChainRulesCore

include("rotational_utils.jl")
include("grads.jl")
include("rope.jl")
include("layers.jl")
include("masks.jl")

export IPA_settings
export IPA, IPCrossA
export IPARoPE
export IPAStructureModuleLayer, IPCrossAStructureModuleLayer
export BackboneUpdate
export right_to_left_mask
26 changes: 16 additions & 10 deletions src/layers.jl
@@ -57,6 +57,7 @@ IPA_settings(
pairwise = c_z > 0,
use_softmax1,
scaling_qk,
rope,
)


@@ -74,7 +75,7 @@ end

Flux.@layer IPCrossA # provides parameter collection, gpu movement and more

function IPCrossA(settings::NamedTuple)
function IPCrossA(settings::NamedTuple; rope_kwargs...)
dims, c, N_head, N_query_points, N_point_values, c_z, Typ, pairwise = settings
# Needs a slightly unusual initialization - hat-tip: Kenta
init = Flux.kaiming_uniform(gain = 1.0f0)
@@ -132,11 +133,14 @@ function (ipa::Union{IPA, IPCrossA})(T::Tuple{AbstractArray,AbstractArray}, S::A
end

#Attention props from L (Keys, Values) to R (Queries).
#Because IPA uses Q'K, our pairwise matrices are R-by-L
#rope is an IPARoPE, applying the usual RoPE to queries and keys belonging to the same chain and a fixed rotation to queries and keys belonging to different chains.
#chain_diffs defaults to 1, meaning everything is in the same chain. Otherwise, pass a pairwise matrix where 1 denotes pairs in the same chain and 0 denotes pairs in different chains.
function (ipa::Union{IPCrossA, IPA})(
TiL::Tuple{AbstractArray, AbstractArray}, siL::AbstractArray,
TiR::Tuple{AbstractArray, AbstractArray}, siR::AbstractArray;
zij = nothing, mask = 0, customgrad = true,
rope::Union{IPARoPE, Nothing} = nothing, chain_diffs = 1,
)
if isnothing(zij) || mask == 0 || siL != siR || TiL != TiR
@warn "Forcing customgrad to false"
@@ -179,6 +183,13 @@ function (ipa::Union{IPCrossA, IPA})(

qh = reshape(l.proj_qh(siR),(c,N_head,N_frames_R,:))
kh = reshape(l.proj_kh(siL),(c,N_head,N_frames_L,:))

if !isnothing(rope)
qhTkh = dotproducts(rope, qh, kh; chain_diffs)
else
qhTkh = dotproducts(qh, kh)
end

vh = reshape(l.proj_vh(siL),(c,N_head,N_frames_L,:))

if isnothing(l.scale_h)
@@ -194,14 +205,9 @@
# This should be Q'K, following IPA, which isn't like the regular QK'
# Dot products between queries and keys.
#FramesR, c, N_head, Batch
qhT = permutedims(qh, (3, 1, 2, 4))
#c, FramesL, N_head, Batch
kh = permutedims(kh, (1, 3, 2, 4))
qhTkh = permutedims(#FramesR, #FramesL, N_head, Batch
batched_mul(qhT,kh)
#N_head, FramesR, FramesL, Batch when we use (3,1,2,4)
,(3,1,2,4))



# Applying our transformations to the queries, keys, and values to put them in the global frame.
Tqhp = reshape(T_R3(qhp, rot_TiR,translate_TiR),3,N_head,N_query_points,N_frames_R,:)
Tkhp = reshape(T_R3(khp, rot_TiL,translate_TiL),3,N_head,N_query_points,N_frames_L,:)
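The comments in the forward pass above introduce two new keywords, rope and chain_diffs. A minimal usage sketch follows (hypothetical chain ids and sizes; it assumes the layer ipa, the frames TiL/TiR, the features siL/siR, and the per-head channel dimension c are already set up as elsewhere in the package):

# Pairwise same-chain indicator: entry (i, j) is 1 when R-frame i and L-frame j share a chain.
chainids_R = [1, 1, 1, 2, 2]
chainids_L = [1, 1, 2, 2, 2]
chain_diffs = Float32.(chainids_R .== chainids_L')   # (N_frames_R, N_frames_L)

# Build with end_pos equal to the number of frames, or slice with iparope[1:n] before use.
iparope = IPARoPE(c, 5)
out = ipa(TiL, siL, TiR, siR; rope = iparope, chain_diffs)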
165 changes: 165 additions & 0 deletions src/rope.jl
@@ -0,0 +1,165 @@
struct RoPE{A<:AbstractArray}
cos::A
sin::A
end
Base.getindex(rope::RoPE, i) = RoPE(rope.cos[:,i,:,:], rope.sin[:,i,:,:])

Flux.@layer RoPE trainable=()

function apply_scaling!(freqs::AbstractVector; scale_factor=8)
low_freq_factor = 1
high_freq_factor = 4
old_context_len = 8192
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
for (i, freq) in enumerate(freqs)
wavelen = 2π / freq
if wavelen > low_freq_wavelen
freqs[i] = freq / scale_factor
elseif wavelen > high_freq_wavelen
@assert low_freq_wavelen != high_freq_wavelen
smooth = (old_context_len / wavelen - low_freq_factor) /
(high_freq_factor - low_freq_factor)
freqs[i] = (1 - smooth) * freq / scale_factor + smooth * freq
end
end
return freqs
end

function RoPE(
dim::Int, end_pos::Int;
theta::T=10000f0, use_scaled=true, scale_factor=8, start_pos=0
) where T
freqs = 1f0 ./ (theta .^ (T.(0:2:dim-1)[1:dim÷2] ./ dim))
use_scaled && apply_scaling!(freqs; scale_factor)
freqs_complex = cis.(T.(start_pos:end_pos-1) * freqs')
cos = permutedims(real(freqs_complex), (2, 1)) # (head_dim/2, seq_len)
sin = permutedims(imag(freqs_complex), (2, 1))
cos = reshape(cos, (dim÷2, end_pos - start_pos, 1))
sin = reshape(sin, (dim÷2, end_pos - start_pos, 1))
return RoPE(cos, sin)
end
# Note about Huggingface weights and rotary embeddings:
# https://discuss.huggingface.co/t/is-llama-rotary-embedding-implementation-correct/44509
# Use this one if you're using the Hugging Face weights.
function (rope::RoPE)(x)
head_dim = size(x, 1)
x1 = x[1:head_dim÷2, :, :, :]
x2 = x[head_dim÷2+1:end, :, :, :]
return vcat(
x1 .* rope.cos .- x2 .* rope.sin,
x2 .* rope.cos .+ x1 .* rope.sin
)
end

function unrope(rope, x)
head_dim = size(x, 1)
x1 = x[1:head_dim÷2, :, :, :]
x2 = x[head_dim÷2+1:end, :, :, :]
return vcat(
x1 .* rope.cos .+ x2 .* rope.sin,
x2 .* rope.cos .- x1 .* rope.sin
)
end

struct FixedRoPE
freqs::AbstractArray
cos::AbstractArray
sin::AbstractArray
end

Flux.@layer FixedRoPE trainable=(:freqs)

function FixedRoPE(dim::Int; theta::T=10000f0) where T
# Create a single fixed rotation
freqs = 1f0 ./ (theta .^ (T.(0:2:dim-1)[1:dim÷2] ./ dim))
# Only generate for a single position
freqs_complex = cis.(freqs)
cos = real(freqs_complex) # (head_dim/2)
sin = imag(freqs_complex)
# Reshape to (head_dim/2, 1, 1, 1) for broadcasting
cos = reshape(cos, (dim÷2, 1, 1, 1))
sin = reshape(sin, (dim÷2, 1, 1, 1))
return FixedRoPE(freqs, cos, sin)
end

# Apply fixed rotation to queries only
function (rope::FixedRoPE)(x)
head_dim = size(x, 1)
# Split queries into two parts
x1 = x[1:head_dim÷2, :, :, :]
x2 = x[head_dim÷2+1:end, :, :, :]

# Rotate queries
rotx = vcat(
x1 .* rope.cos .- x2 .* rope.sin,
x2 .* rope.cos .+ x1 .* rope.sin
)
return rotx
end



# Apply fixed un-rotation to queries only
function unapply_fixed_rope(rope::FixedRoPE, rotated_queries)
head_dim = size(rotated_queries, 1)
# Split rotated queries into two parts
q1 = rotated_queries[1:head_dim÷2, :, :, :]
q2 = rotated_queries[head_dim÷2+1:end, :, :, :]

# Un-rotate queries (inverse rotation)
unrotated_queries = vcat(
q1 .* rope.cos .+ q2 .* rope.sin, # Note: + instead of -
q2 .* rope.cos .- q1 .* rope.sin # Note: - instead of +
)

return unrotated_queries
end

struct IPARoPE
rope::RoPE
fixed_rope::FixedRoPE
end
Base.getindex(rope::IPARoPE, i) = IPARoPE(rope.rope[i], rope.fixed_rope)

function IPARoPE(dim::Int, end_pos::Int;
theta::T=10000f0, use_scaled=true, scale_factor=8, start_pos=0) where T
return IPARoPE(
RoPE(dim, end_pos; theta, use_scaled, scale_factor, start_pos),
FixedRoPE(dim; theta)
)
end

function dotproducts(qh::AbstractArray{T, 4}, kh::AbstractArray{T, 4}) where T<: Real
qhT = permutedims(qh, (3, 1, 2, 4))
#c, FramesL, N_head, Batch
kh = permutedims(kh, (1, 3, 2, 4))
qhTkh = permutedims(#FramesR, #FramesL, N_head, Batch
batched_mul(qhT,kh)
#N_head, FramesR, FramesL, Batch when we use (3,1,2,4)
,(3,1,2,4))
return qhTkh
end

"""
function RoPEdotproducts(iparope::IPARoPE, q, k; chain_diffs = nothing)
chain_diffs is either nothing or a array of 0's and 1's describing the ij-pair as pertaining to the same chain if the entry at ij is 1, else 0.
"""
function dotproducts(iparope::IPARoPE, qh::AbstractArray{T, 4}, kh::AbstractArray{T, 4}; chain_diffs = nothing) where T<: Real
# O(N) permutedims, shouldn't be too bad.
qropshape = permutedims(qh, (1,3,2,4))
kropshape = permutedims(kh, (1,3,2,4))
rotq, rotk = permutedims(iparope.rope(qropshape), (1,3,2,4)), permutedims(iparope.rope(kropshape), (1,3,2,4))
rotqTrotk = dotproducts(rotq, rotk)
# when things are from different chains, we rotate only the queries by a fixed amount
if !isnothing(chain_diffs)
rotq2 = permutedims(iparope.fixed_rope(qropshape), (1,3,2,4))
rotq2Trotk2 = dotproducts(rotq2, kh)
# unsqueeze chain diffs to shape 1, framesR, framesL
rotqTrotk = unsqueeze(chain_diffs, 1) .* rotqTrotk .+ (1 .- unsqueeze(chain_diffs, 1)) .* rotq2Trotk2
end
return rotqTrotk
end

Flux.@layer IPARoPE
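
A quick sanity check of the rotation and its inverse above, as a sketch assuming the definitions in this file are loaded: applying the RoPE and then unrope should recover the input up to Float32 round-off.

r = RoPE(8, 16)                       # head_dim = 8, positions 0..15
x = randn(Float32, 8, 16, 4, 2)       # (head_dim, seq_len, N_head, batch)
maximum(abs.(unrope(r, r(x)) .- x))   # close to 0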

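Similarly, a small sketch of the chain-aware dot products described in the docstring (hypothetical sizes, again assuming this file's definitions are in scope): same-chain pairs use the positional RoPE on both queries and keys, while cross-chain pairs fall back to the fixed query rotation, blended through chain_diffs.

c, N_head, N_frames, batch = 8, 2, 5, 1
iparope = IPARoPE(c, N_frames)

qh = randn(Float32, c, N_head, N_frames, batch)
kh = randn(Float32, c, N_head, N_frames, batch)

# Everything in one chain: equivalent to rotating both queries and keys with the positional RoPE.
same = dotproducts(iparope, qh, kh; chain_diffs = ones(Float32, N_frames, N_frames))

# Frames 1-3 in one chain, frames 4-5 in another.
ids = [1, 1, 1, 2, 2]
mixed = dotproducts(iparope, qh, kh; chain_diffs = Float32.(ids .== ids'))

size(mixed)   # (N_head, FramesR, FramesL, batch) = (2, 5, 5, 1)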