Merge pull request #8 from bicycle1885/expand
Caching
billera authored Mar 27, 2024
2 parents 869ebdb + 8649728 commit 1b6320a
Showing 2 changed files with 196 additions and 1 deletion.
166 changes: 165 additions & 1 deletion src/layers.jl
@@ -139,7 +139,7 @@ function (ipa::Union{IPCrossA, IPA})(TiL::Tuple{AbstractArray,AbstractArray}, si
qhp = reshape(l.proj_qhp(siR),(3,N_head*N_query_points,N_frames_R,:))
khp = reshape(l.proj_khp(siL),(3,N_head*N_query_points,N_frames_L,:))
vhp = reshape(l.proj_vhp(siL),(3,N_head*N_point_values,N_frames_L,:))

# Following IPA, this computes Q'K rather than the usual QK'.
# Dot products between queries and keys.
# FramesR, c, N_head, Batch
@@ -264,3 +264,167 @@ function (structuremodulelayer::Union{IPCrossAStructureModuleLayer, IPAStructure
T_R = l.backbone(T_R, S_R)
return T_R, S_R
end

struct IPACache
sizeL
sizeR
batchsize

# cached arrays
qh # channel × head × residues (R) × batch
kh # channel × head × residues (L) × batch
vh # channel × head × residues (L) × batch

qhp # 3 × {head × query points} × residues (R) × batch
khp # 3 × {head × query points} × residues (L) × batch
vhp # 3 × {head × point values} × residues (L) × batch
end

"""
    IPACache(settings, batchsize)

Initialize an empty IPA cache.
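
A sketch of incremental use, following the test suite (`TiL`/`TiR` are rigid frames,
`siL`/`siR` per-residue features, and `zij` optional pairwise features, as in the
non-cached forward pass):

    settings = IPA_settings(8; c_z = 2)
    ipa = IPCrossA(settings)
    cache = IPACache(settings, 4)   # batch size 4
    siR1, cache = expand(ipa, cache, TiL, siL, 5, TiR, siR, 2; zij)
    siR2, cache = expand(ipa, cache, TiL, siL, 0, TiR, siR, 4; zij)
    # cat(siR1, siR2, dims = 2) ≈ ipa(TiL, siL, TiR, siR; zij)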
"""
function IPACache(settings::NamedTuple, batchsize::Integer)
(; c, N_head, N_query_points, N_point_values) = settings
qh = zeros(Float32, c, N_head, 0, batchsize)
kh = zeros(Float32, c, N_head, 0, batchsize)
vh = zeros(Float32, c, N_head, 0, batchsize)
qhp = zeros(Float32, 3, N_head * N_query_points, 0, batchsize)
khp = zeros(Float32, 3, N_head * N_query_points, 0, batchsize)
vhp = zeros(Float32, 3, N_head * N_point_values, 0, batchsize)
IPACache(0, 0, batchsize, qh, kh, vh, qhp, khp, vhp)
end

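"""
    expand(ipa, cache, TiL, siL, ΔL, TiR, siR, ΔR; zij = nothing, mask = 0)

Incrementally extend a cross-IPA evaluation: project the next `ΔL` left and `ΔR`
right residues, append the projections to `cache`, and attend from the new right
residues to all left residues seen so far. Returns the outputs for the `ΔR` new
right positions together with the updated cache, so earlier residues are never
re-projected.
"""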
function expand(
ipa::IPCrossA,
cache::IPACache,
TiL::Tuple, siL::AbstractArray, ΔL::Integer,
TiR::Tuple, siR::AbstractArray, ΔR::Integer;
zij = nothing,
mask = 0,
)
dims, c, N_head, N_query_points, N_point_values, c_z, Typ, pairwise = ipa.settings
L, R, B = cache.sizeL, cache.sizeR, cache.batchsize
layer = ipa.layers

gamma_h = min.(softplus(layer.gamma_h), 1f2)

Δqh = reshape(calldense(layer.proj_qh, @view siR[:,R+1:R+ΔR,:]), (c, N_head, ΔR, B))
Δkh = reshape(calldense(layer.proj_kh, @view siL[:,L+1:L+ΔL,:]), (c, N_head, ΔL, B))
Δvh = reshape(calldense(layer.proj_vh, @view siL[:,L+1:L+ΔL,:]), (c, N_head, ΔL, B))

Δqhp = reshape(calldense(layer.proj_qhp, @view siR[:,R+1:R+ΔR,:]), (3, N_head * N_query_points, ΔR, B))
Δkhp = reshape(calldense(layer.proj_khp, @view siL[:,L+1:L+ΔL,:]), (3, N_head * N_query_points, ΔL, B))
Δvhp = reshape(calldense(layer.proj_vhp, @view siL[:,L+1:L+ΔL,:]), (3, N_head * N_point_values, ΔL, B))

kh = cat(cache.kh, Δkh, dims = 3)
vh = cat(cache.vh, Δvh, dims = 3)

khp = cat(cache.khp, Δkhp, dims = 3)
vhp = cat(cache.vhp, Δvhp, dims = 3)

# calculate inner products
ΔqhT = permutedims(Δqh, (3, 1, 2, 4))
kh = permutedims(kh, (1, 3, 2, 4))
ΔqhTkh = permutedims(batched_mul(ΔqhT, kh), (3, 1, 2, 4))

# transform vector points to the global frames
rot_TiL, translate_TiL = TiL
rot_TiR, translate_TiR = TiR
ΔTqhp = reshape(T_R3(Δqhp, @view(rot_TiR[:,:,R+1:R+ΔR,:]), @view(translate_TiR[:,:,R+1:R+ΔR,:])), (3, N_head, N_query_points, ΔR, B))
Tkhp = reshape(
T_R3(reshape(khp, (3, N_head * N_query_points, (L + ΔL) * B)), @view(rot_TiL[:,:,1:L+ΔL,:]), @view(translate_TiL[:,:,1:L+ΔL,:])),
(3, N_head, N_query_points, L + ΔL, B)
)
Tvhp = reshape(
T_R3(reshape(vhp, (3, N_head * N_point_values, (L + ΔL) * B)), @view(rot_TiL[:,:,1:L+ΔL,:]), @view(translate_TiL[:,:,1:L+ΔL,:])),
(3, N_head, N_point_values, L + ΔL, B)
)

diffs = unsqueeze(ΔTqhp, dims = 5) .- unsqueeze(Tkhp, dims = 4)
sum_norms = sumdrop(abs2, diffs, dims = (1, 3))

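# Scale factors for the attention logits: w_C = sqrt(2 / (9 N_query_points)) weights the
# point-distance term and 1/sqrt(c) the scalar dot-product term, following the IPA formulation.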
w_C = sqrt(2f0 / 9N_query_points)
dim_scale = sqrt(1f0 / c)
Δatt_logits = reshape(dim_scale .* ΔqhTkh .- w_C/2 .* gamma_h .* sum_norms, (N_head, ΔR, L + ΔL, B))

if mask != 0
mask = unsqueeze(@view(mask[R+1:R+ΔR,1:L+ΔL]), dims = 1)
end

if pairwise
bij = reshape(layer.pair(@view(zij[:,R+1:R+ΔR,1:L+ΔL,:])), (N_head, ΔR, L + ΔL, B))
w_L = sqrt(1f0/3)
Δatt = softmax(w_L .* (Δatt_logits .+ bij) .+ mask, dims = 3)
else
w_L = sqrt(1f0/2)
Δatt = softmax(w_L .* Δatt_logits .+ mask, dims = 3)
end

# take the attention weighted sum of the value vectors
oh = sumdrop(
reshape(Δatt, (1, N_head, ΔR, L + ΔL, B)) .*
reshape( vh, (c, N_head, 1, L + ΔL, B)),
dims = 4,
)
ohp = reshape(
T_R3_inv(
reshape(
# 3 × N_head × N_point_values × ΔR × batch
sumdrop(
reshape(Δatt, (1, N_head, 1, ΔR, L + ΔL, B)) .*
reshape(Tvhp, (3, N_head, N_point_values, 1, L + ΔL, B)),
dims = 5,
),
(3, N_head * N_point_values, ΔR * B)
),
@view(rot_TiR[:,:,R+1:R+ΔR,:]),
@view(translate_TiR[:,:,R+1:R+ΔR,:])
),
(3, N_head, N_point_values, ΔR, B)
)
ohp_norms = sqrt.(sumdrop(abs2, ohp, dims = 1))

# concatenate all outputs
o = [
reshape(oh, (c * N_head, ΔR, B))
reshape(ohp, (3 * N_head * N_point_values, ΔR, B))
reshape(ohp_norms, (N_head * N_point_values, ΔR, B))
]
if pairwise
o = [
o
reshape(
sumdrop(
reshape( Δatt, ( 1, N_head, ΔR, L + ΔL, B)) .*
reshape(@view(zij[:,R+1:R+ΔR,1:L+ΔL,:]), (c_z, 1, ΔR, L + ΔL, B)),
dims = 4
),
(c_z * N_head, ΔR, B)
)
]
end

cache = IPACache(
L + ΔL,
R + ΔR,
B,
cat(cache.qh, Δqh, dims = 3),
cat(cache.kh, Δkh, dims = 3),
cat(cache.vh, Δvh, dims = 3),
cat(cache.qhp, Δqhp, dims = 3),
cat(cache.khp, Δkhp, dims = 3),
cat(cache.vhp, Δvhp, dims = 3),
)
layer.ipa_linear(o), cache
end

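# Sum over `dims`, then drop the reduced singleton dimensions.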
sumdrop(x; dims) = dropdims(sum(x; dims); dims)
sumdrop(f, x; dims) = dropdims(sum(f, x; dims); dims)

# Apply `dense` to an N-dimensional array by flattening the trailing dimensions first; avoids https://github.com/FluxML/Flux.jl/issues/2407.
function calldense(dense::Dense, x::AbstractArray)
d1 = size(dense.weight, 1)
reshape(dense(reshape(x, size(x, 1), :)), d1, size(x)[2:end]...)
end
31 changes: 31 additions & 0 deletions test/runtests.jl
@@ -1,6 +1,37 @@
using InvariantPointAttention
using InvariantPointAttention: get_rotation, get_translation
using Test

@testset "InvariantPointAttention.jl" begin
# Write your tests here.


@testset "IPACache" begin
dims = 8
c_z = 2
settings = IPA_settings(dims; c_z)
ipa = IPCrossA(settings)

# generate random data
L = 5
R = 6
B = 4
siL = randn(Float32, dims, L, B)
siR = randn(Float32, dims, R, B)
zij = randn(Float32, c_z, R, L, B)
TiL = (get_rotation(L, B), get_translation(L, B))
TiR = (get_rotation(R, B), get_translation(R, B))

# check the consistency
cache = InvariantPointAttention.IPACache(settings, B)
siR′, cache′ = InvariantPointAttention.expand(ipa, cache, TiL, siL, L, TiR, siR, R; zij)
@test size(siR′) == size(siR)
@test siR′ ≈ ipa(TiL, siL, TiR, siR; zij)

# calculate in two steps
cache = InvariantPointAttention.IPACache(settings, B)
siR1, cache = InvariantPointAttention.expand(ipa, cache, TiL, siL, L, TiR, siR, 2; zij)
siR2, cache = InvariantPointAttention.expand(ipa, cache, TiL, siL, 0, TiR, siR, 4; zij)
@test cat(siR1, siR2, dims = 2) ≈ ipa(TiL, siL, TiR, siR; zij)
end
end
