Skip to content

Commit

Permalink
Fix array mutation
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonOresten committed May 22, 2024
1 parent 97de3ad commit e0ebb88
Showing 1 changed file with 23 additions and 16 deletions.
39 changes: 23 additions & 16 deletions src/layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,15 +167,16 @@ function (ipa::Union{IPCrossA, IPA})(
qh = reshape(l.proj_qh(siR),(c,N_head,N_frames_R,:))
kh = reshape(l.proj_kh(siL),(c,N_head,N_frames_L,:))
vh = reshape(l.proj_vh(siL),(c,N_head,N_frames_L,:))
qhp = reshape(l.proj_qhp(siR),(3,N_head*N_query_points,N_frames_R,:))
khp = reshape(l.proj_khp(siL),(3,N_head*N_query_points,N_frames_L,:))
vhp = reshape(l.proj_vhp(siL),(3,N_head*N_point_values,N_frames_L,:))

if !isnothing(l.scale_h)
if isnothing(l.scale_h)
qhp = reshape(l.proj_qhp(siR),(3,N_head*N_query_points,N_frames_R,:))
khp = reshape(l.proj_khp(siL),(3,N_head*N_query_points,N_frames_L,:))
else
scale_h = reshape(l.scale_h, (1,N_head*N_query_points,1,1))
qhp .*= scale_h
khp .*= scale_h
qhp = reshape(l.proj_qhp(siR),(3,N_head*N_query_points,N_frames_R,:)) .* scale_h
khp = reshape(l.proj_khp(siL),(3,N_head*N_query_points,N_frames_L,:)) .* scale_h
end
vhp = reshape(l.proj_vhp(siL),(3,N_head*N_point_values,N_frames_L,:))

# This should be Q'K, following IPA, which isn't like the regular QK'
# Dot products between queries and keys.
Expand Down Expand Up @@ -285,16 +286,18 @@ function ipa_customgrad(ipa::Union{IPCrossA, IPA}, Ti::Tuple{AbstractArray,Abstr
qh = reshape(l.proj_qh(siR),(c,N_head,N_frames_R,:))
kh = reshape(l.proj_kh(siL),(c,N_head,N_frames_L,:))
vh = reshape(l.proj_vh(siL),(c,N_head,N_frames_L,:))
qhp = reshape(l.proj_qhp(siR),(3,N_head*N_query_points,N_frames_R,:))
khp = reshape(l.proj_khp(siL),(3,N_head*N_query_points,N_frames_L,:))
vhp = reshape(l.proj_vhp(siL),(3,N_head*N_point_values,N_frames_L,:))

if !isnothing(l.scale_h)
if isnothing(l.scale_h)
qhp = reshape(l.proj_qhp(siR),(3,N_head*N_query_points,N_frames_R,:))
khp = reshape(l.proj_khp(siL),(3,N_head*N_query_points,N_frames_L,:))
else
scale_h = reshape(l.scale_h, (1,N_head*N_query_points,1,1))
qhp .*= scale_h
khp .*= scale_h
qhp = reshape(l.proj_qhp(siR),(3,N_head*N_query_points,N_frames_R,:)) .* scale_h
khp = reshape(l.proj_khp(siL),(3,N_head*N_query_points,N_frames_L,:)) .* scale_h
end

vhp = reshape(l.proj_vhp(siL),(3,N_head*N_point_values,N_frames_L,:))

# This should be Q'K, following IPA, which isn't like the regular QK'
# Dot products between queries and keys.
#FramesR, c, N_head, Batch
Expand Down Expand Up @@ -483,14 +486,18 @@ function expand(

Δqhp = reshape(calldense(layer.proj_qhp, siR[:,R+1:R+ΔR,:]), (3, N_head * N_query_points, ΔR, B))
Δkhp = reshape(calldense(layer.proj_khp, siL[:,L+1:L+ΔL,:]), (3, N_head * N_query_points, ΔL, B))
Δvhp = reshape(calldense(layer.proj_vhp, siL[:,L+1:L+ΔL,:]), (3, N_head * N_point_values, ΔL, B))

if !isnothing(layer.scale_h)
if isnothing(layer.scale_h)
Δqhp = reshape(calldense(layer.proj_qhp, siR[:,R+1:R+ΔR,:]), (3, N_head * N_query_points, ΔR, B))
Δkhp = reshape(calldense(layer.proj_khp, siL[:,L+1:L+ΔL,:]), (3, N_head * N_query_points, ΔL, B))
else
scale_h = reshape(layer.scale_h, (1,N_head*N_query_points,1,1))
Δqhp .*= scale_h
Δkhp .*= scale_h
Δqhp = reshape(calldense(layer.proj_qhp, siR[:,R+1:R+ΔR,:]), (3, N_head * N_query_points, ΔR, B)) .* scale_h
Δkhp = reshape(calldense(layer.proj_khp, siL[:,L+1:L+ΔL,:]), (3, N_head * N_query_points, ΔL, B)) .* scale_h
end

Δvhp = reshape(calldense(layer.proj_vhp, siL[:,L+1:L+ΔL,:]), (3, N_head * N_point_values, ΔL, B))

kh = cat(cache.kh, Δkh, dims = 3)
vh = cat(cache.vh, Δvh, dims = 3)

Expand Down

0 comments on commit e0ebb88

Please sign in to comment.