Skip to content

Commit

Permalink
Update docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonOresten committed Jun 11, 2024
1 parent d68aa96 commit 9ddd90f
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 90 deletions.
92 changes: 48 additions & 44 deletions src/grads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,53 @@ function ChainRulesCore.rrule(::typeof(T_R3_inv), x::AbstractArray{T,N}, R::Abst
return y, T_R3_inv_pullback
end

"""
    softmax1(x; dims = 1)

Behaves like softmax, but as though there was an additional logit of zero along `dims`
(which is excluded from the output). The values therefore sum to a value between zero
and 1 along `dims`.

See https://www.evanmiller.org/attention-is-off-by-one.html
"""
function softmax1(x::AbstractArray{T}; dims = 1) where {T}
    # Fold the implicit zero logit into the running maximum, so that every exponent
    # below (including the one for the zero logit itself) is <= 0.
    shift = max.(fast_maximum2(x; dims), T(0))
    @fastmath expd = exp.(x .- shift)
    # `exp.(-shift)` is the term contributed by the implicit zero logit.
    denom = sum(expd, dims = dims) + exp.(-shift)
    return expd ./ denom
end

# Taken from NNlib: maximum of `x` along `dims`, reduced with `@fastmath` and seeded
# with `-Inf` converted to the floating-point type of `T`.
function fast_maximum2(x::AbstractArray{T}; dims) where {T}
    return @fastmath reduce(max, x; dims = dims, init = float(T)(-Inf))
end

# Gradient of `softmax1` with respect to its input, given the upstream gradient `dy`
# and the forward output `y`: computes `dy .* y .- y .* sum(dy .* y; dims)`.
function ∇softmax1_data(dy::AbstractArray{T}, y::AbstractArray{S}; dims = 1) where {T,S}
    if NNlib.within_gradient(y)
        # Non-mutating path, taken when `NNlib.within_gradient` reports that this
        # call is itself being differentiated.
        weighted = dy .* y
        return weighted .- y .* sum(weighted; dims)
    end
    # This path is faster, only safe for 1st derivatives though.
    # Was previously `∇softmax!(dx, dy, x, y; dims)` to allow CUDA overloads,
    # but that was slow: https://github.com/FluxML/NNlibCUDA.jl/issues/30
    buf = similar(y, promote_type(T, S))  # sure to be mutable
    buf .= dy .* y
    # `sum(buf; dims)` is materialized before the in-place broadcast overwrites `buf`.
    buf .= buf .- y .* sum(buf; dims)
    return buf
end

# Reverse-mode rule for `softmax1`: returns the forward output `y` together with a
# pullback that maps the output cotangent to `(NoTangent(), dx)`.
function ChainRulesCore.rrule(::typeof(softmax1), x; dims = 1)
    y = softmax1(x; dims)
    function softmax1_pullback(dy)
        # `dy` may arrive as a thunk; materialize it before use.
        dx = ∇softmax1_data(unthunk(dy), y; dims)
        return (NoTangent(), dx)
    end
    return y, softmax1_pullback
end


# Combines the three terms that feed the attention softmax:
#   w_L * (qhTkh(qh, kh) / sqrt(size(qh, 1)) + bij - (w_C / 2) * gamma_h * D)
# where D is `diff_sum_glob(Ti, qhp, khp)` with its singleton dims 1 and 3 dropped.
function pre_softmax_aijh(qh::AbstractArray{T},kh::AbstractArray{T},Ti,qhp::AbstractArray{T},khp::AbstractArray{T}, bij::AbstractArray{T}, gamma_h::AbstractArray{T}) where T
    # Scalar weights, all converted to the element type `T`.
    w_C = T(sqrt(2f0/(9f0*size(qhp,3))))
    dim_scale = T(1f0/sqrt(size(qh,1)))
    w_L = T(1f0/sqrt(3f0))
    half_w_C = w_C/2

    scalar_term = qhTkh(qh, kh)
    point_term = dropdims(diff_sum_glob(Ti, qhp, khp), dims=(1,3))

    return w_L .* (dim_scale .* scalar_term .+ bij .- half_w_C .* gamma_h .* point_term)
end

#=
function diff_sum_glob(T, q, k)
bs = size(q)
Expand Down Expand Up @@ -151,47 +198,4 @@ function ChainRulesCore.rrule(::typeof(qhTkh), q, k)
end
return qhTkh, qhTkh_pullback
end
=#
"""
    softmax1(x; dims = 1)

Behaves like softmax, but as though there was an additional logit of zero along `dims`
(which is excluded from the output). The values therefore sum to a value between zero
and 1 along `dims`.
"""
function softmax1(x::AbstractArray{T}; dims = 1) where {T}
    # Fold the implicit zero logit into the running maximum, so that every exponent
    # below (including the one for the zero logit itself) is <= 0.
    shift = max.(fast_maximum2(x; dims), T(0))
    @fastmath expd = exp.(x .- shift)
    # `exp.(-shift)` is the term contributed by the implicit zero logit.
    denom = sum(expd, dims = dims) + exp.(-shift)
    return expd ./ denom
end
# Adapted from NNlib: maximum of `x` along `dims`, reduced with `@fastmath` and seeded
# with `-Inf` converted to the floating-point type of `T`.
function fast_maximum2(x::AbstractArray{T}; dims) where {T}
    return @fastmath reduce(max, x; dims = dims, init = float(T)(-Inf))
end

# Gradient of `softmax1` with respect to its input, given the upstream gradient `dy`
# and the forward output `y`: computes `dy .* y .- y .* sum(dy .* y; dims)`.
function ∇softmax1_data(dy::AbstractArray{T}, y::AbstractArray{S}; dims = 1) where {T,S}
    if NNlib.within_gradient(y)
        # Non-mutating path, taken when `NNlib.within_gradient` reports that this
        # call is itself being differentiated.
        weighted = dy .* y
        return weighted .- y .* sum(weighted; dims)
    end
    # This path is faster, only safe for 1st derivatives though.
    # Was previously `∇softmax!(dx, dy, x, y; dims)` to allow CUDA overloads,
    # but that was slow: https://github.com/FluxML/NNlibCUDA.jl/issues/30
    buf = similar(y, promote_type(T, S))  # sure to be mutable
    buf .= dy .* y
    # `sum(buf; dims)` is materialized before the in-place broadcast overwrites `buf`.
    buf .= buf .- y .* sum(buf; dims)
    return buf
end

# Reverse-mode rule for `softmax1`: returns the forward output `y` together with a
# pullback that maps the output cotangent to `(NoTangent(), dx)`.
function ChainRulesCore.rrule(::typeof(softmax1), x; dims = 1)
    y = softmax1(x; dims)
    function softmax1_pullback(dy)
        # `dy` may arrive as a thunk; materialize it before use.
        dx = ∇softmax1_data(unthunk(dy), y; dims)
        return (NoTangent(), dx)
    end
    return y, softmax1_pullback
end


# Combines the three terms that feed the attention softmax:
#   w_L * (qhTkh(qh, kh) / sqrt(size(qh, 1)) + bij - (w_C / 2) * gamma_h * D)
# where D is `diff_sum_glob(Ti, qhp, khp)` with its singleton dims 1 and 3 dropped.
function pre_softmax_aijh(qh::AbstractArray{T},kh::AbstractArray{T},Ti,qhp::AbstractArray{T},khp::AbstractArray{T}, bij::AbstractArray{T}, gamma_h::AbstractArray{T}) where T
    # Scalar weights, all converted to the element type `T`.
    w_C = T(sqrt(2f0/(9f0*size(qhp,3))))
    dim_scale = T(1f0/sqrt(size(qh,1)))
    w_L = T(1f0/sqrt(3f0))
    half_w_C = w_C/2

    scalar_term = qhTkh(qh, kh)
    point_term = dropdims(diff_sum_glob(Ti, qhp, khp), dims=(1,3))

    return w_L .* (dim_scale .* scalar_term .+ bij .- half_w_C .* gamma_h .* point_term)
end
=#
6 changes: 1 addition & 5 deletions src/layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -399,11 +399,10 @@ function IPCrossAStructureModuleLayer(settings::NamedTuple; dropout_p = 0.1, af
return IPCrossAStructureModuleLayer(settings, layers)
end

# We could skip making this struct and just have it be a cross IPA struct.
"""
Self IPA Partial Structure Module initialization - single layer - adapted from AF2.
"""

# We could skip making this struct and just have it be a cross IPA struct.
struct IPAStructureModuleLayer
settings::NamedTuple
layers::NamedTuple
Expand All @@ -418,9 +417,6 @@ function (structuremodulelayer::Union{IPAStructureModuleLayer, IPCrossAStructure
return structuremodulelayer(T, S, T, S; zij = zij, mask = mask)
end

"""
Cross IPA Partial Structure Module - single layer - adapted from AF2. From left to right.
"""
function (structuremodulelayer::Union{IPCrossAStructureModuleLayer, IPAStructureModuleLayer})(T_L, S_L, T_R, S_R; zij = nothing, mask = 0)
settings = structuremodulelayer.settings
if settings.c_z > 0 && zij === nothing
Expand Down
62 changes: 23 additions & 39 deletions src/rotational_utils.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
# From AF2 supplementary: Algorithm 23 (Backbone update).
# Takes a 3xN matrix of imaginary quaternion components, `bcd`, sets the real part of
# every quaternion to `a` (default 1), and normalizes to unit quaternions.
# Returns a 4xN matrix with the real part in row 1 and the imaginary parts in rows 2-4.
function bcds2quats(bcd::AbstractMatrix{T}, a::T=T(1)) where T<:Real
    # The norm of (a, b, c, d) is sqrt(a^2 + b^2 + c^2 + d^2). The real part has to
    # enter squared; using `a` itself is only correct for the default a = 1.
    norms = sqrt.(abs2(a) .+ sum(abs2, bcd, dims=1))
    return vcat(a ./ norms, bcd ./ norms)
end

"""
Takes a 4xN matrix of unit quaternions and returns a 3x3xN array of rotation matrices.
"""
# Takes a 4xN matrix of unit quaternions and returns a 3x3xN array of rotation matrices.
function rotmatrix_from_quat(q::AbstractMatrix{<:Real})
sx = 2q[1, :] .* q[2, :]
sy = 2q[1, :] .* q[3, :]
Expand Down Expand Up @@ -73,12 +69,9 @@ function batched_mul_T2(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where {T
return reshape(z, size(z, 1), size(z, 2), batch_size...)
end


"""
Applies the SE3 transformations T = (rot,trans) ∈ SE(3)^N
#= Applies the SE3 transformations T = (rot,trans) ∈ SE(3)^N
to N batches of m points in R3, i.e., mat ∈ R^(3 x m x N) ↦ T(mat) ∈ R^(3 x m x N).
Note here that rotations here are represented in matrix form.
"""
Note here that rotations here are represented in matrix form. =#
function T_R3(x::AbstractArray{T}, R::AbstractArray{T}, t::AbstractArray{T}) where T
x′ = reshape(x, 3, size(x, 2), :)
R′ = reshape(R, 3, 3, :)
Expand All @@ -88,10 +81,8 @@ function T_R3(x::AbstractArray{T}, R::AbstractArray{T}, t::AbstractArray{T}) whe
return y
end

"""
Applies the group inverse of the SE3 transformations T = (R,t) ∈ SE(3)^N to N batches of m points in R3,
such that T^-1(T*x) = T^-1(Rx+t) = R^T(Rx+t-t) = x.
"""
#= Applies the group inverse of the SE3 transformations T = (R,t) ∈ SE(3)^N to N batches of m points in R3,
such that T^-1(T*x) = T^-1(Rx+t) = R^T(Rx+t-t) = x. =#
function T_R3_inv(y::AbstractArray{T}, R::AbstractArray{T}, t::AbstractArray{T}) where T
y′ = reshape(y, 3, size(y, 2), :)
R′ = reshape(R, 3, 3, :)
Expand All @@ -101,9 +92,7 @@ function T_R3_inv(y::AbstractArray{T}, R::AbstractArray{T}, t::AbstractArray{T})
return x
end

"""
Returns the composition of two SE(3) transformations T_1 and T_2. If T1 = (R1,t1), and T2 = (R2,t2) then T1*T2 = (R1*R2, R1*t2 + t1).
"""
# Returns the composition of two SE(3) transformations T_1 and T_2. If T1 = (R1,t1), and T2 = (R2,t2) then T1*T2 = (R1*R2, R1*t2 + t1).
function T_T(T_1, T_2)
R1, t1 = T_1
R2, t2 = T_2
Expand All @@ -112,9 +101,7 @@ function T_T(T_1, T_2)
return (new_rot,new_trans)
end

"""
Takes a 6-dim vec and maps to a rotation matrix and translation vector, which is then applied to the input frames.
"""
# Takes a 6-dim vec and maps to a rotation matrix and translation vector, which is then applied to the input frames.
function update_frame(Ti, arr)
bcds = reshape(arr[:,1,:,:],3,:)
rotmat = rotmatrix_from_quat(bcds2quats(bcds))
Expand All @@ -130,9 +117,7 @@ unzip(a) = map(x->getfield.(a, x), fieldnames(eltype(a)))

# Mean of the columns of `coords`, returned as a vector with one entry per row.
function centroid(coords::AbstractMatrix)
    n_points = size(coords, 2)
    column_sum = sum(coords; dims=2)
    return vec(column_sum) / n_points
end

"""
Get frame from residue
"""
# Get rotation matrix and translation from residue
function calculate_residue_rotation_and_translation(residue_xyz::AbstractMatrix)
# Returns the rotation matrix and the translation of a given residue.
N = residue_xyz[:, 1]
Expand All @@ -152,30 +137,29 @@ function calculate_residue_rotation_and_translation(residue_xyz::AbstractMatrix)
end

"""
    get_T(coords::Array{<:Real, 3})

Get the associated SE(3) frame for all residues in a protein backbone represented as a
3x3xL array of coordinates.

Returns a tuple `(rotations, translations)`: the per-residue results of
`calculate_residue_rotation_and_translation` stacked along a new last dimension, with
the translations reshaped to `3 x 1 x L`.
"""
function get_T(coords::Array{<:Real, 3})
    # One (rotation, translation) pair per slice along the third dimension.
    frames = [
        calculate_residue_rotation_and_translation(coords[:,:,i])
        for i in axes(coords,3)
    ]
    rots, trans = stack.(unzip(frames))
    return (rots, reshape(trans,3,1,:))
end

"""
    get_T_batch(coords::Array{<:Real, 4})

Get the associated SE(3) frames for all residues in a batch of proteins.

Calls `get_T` on every slice along the fourth dimension of `coords` and returns a tuple
`(rots, trans)` of sizes `3 x 3 x L x B` and `3 x 1 x L x B`, where `L = size(coords, 3)`
and `B = size(coords, 4)`.
"""
function get_T_batch(coords::Array{<:Real, 4})
    trailing = size(coords)[3:4]
    # NOTE: `zeros` without an element type allocates Float64 buffers, so the outputs
    # are Float64 regardless of `eltype(coords)`.
    rots = zeros(3, 3, trailing...)
    trans = zeros(3, 1, trailing...)
    for j in axes(coords, 4)
        R, t = get_T(coords[:,:,:,j])
        rots[:,:,:,j] = R
        trans[:,:,:,j] = t
    end
    return (rots, trans)
end

"""
Index into a T up to index i.
"""
# Index into a T up to index i.
function T_till(T,i)
Tr, Tt = T[1][:,:,1:i,:], T[2][:,:,1:i,:]
return Tr, Tt
Expand Down
4 changes: 2 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ using ChainRulesTestUtils

# Get 1 global SE(3) transformation for each batch.
T_glob = (get_rotation(batch_size), get_translation(batch_size))
T_GlobL = (stack([T_glob[1] for i in 1:framesL],dims = 3), stack([T_glob[2] for i in 1:framesL], dims = 3))
T_GlobR = (stack([T_glob[1] for i in 1:framesR],dims = 3), stack([T_glob[2] for i in 1:framesR], dims = 3))
T_GlobL = (stack([T_glob[1] for i in 1:framesL],dims = 3), stack([T_glob[2] for i in 1:framesL], dims=3))
T_GlobR = (stack([T_glob[1] for i in 1:framesR],dims = 3), stack([T_glob[2] for i in 1:framesR], dims=3))

T_newL = InvariantPointAttention.T_T(T_GlobL,T_locL)
T_newR = InvariantPointAttention.T_T(T_GlobR,T_locR)
Expand Down

0 comments on commit 9ddd90f

Please sign in to comment.