From 9ddd90f2d650eec11a1229170d9519540a3cea62 Mon Sep 17 00:00:00 2001 From: anton083 Date: Tue, 11 Jun 2024 13:15:29 +0200 Subject: [PATCH] Update docstrings --- src/grads.jl | 92 +++++++++++++++++++++-------------------- src/layers.jl | 6 +-- src/rotational_utils.jl | 62 +++++++++++---------------- test/runtests.jl | 4 +- 4 files changed, 74 insertions(+), 90 deletions(-) diff --git a/src/grads.jl b/src/grads.jl index fdaf29e..cf51eb1 100644 --- a/src/grads.jl +++ b/src/grads.jl @@ -69,6 +69,53 @@ function ChainRulesCore.rrule(::typeof(T_R3_inv), x::AbstractArray{T,N}, R::Abst return y, T_R3_inv_pullback end +""" + softmax1(x, dims = 1) + +Behaves like softmax, but as though there was an additional logit of zero along dims (which is excluded from the output). So the values will sum to a value between zero and 1. + +See https://www.evanmiller.org/attention-is-off-by-one.html +""" +function softmax1(x::AbstractArray{T}; dims = 1) where {T} + _zero = T(0) + max_ = max.(fast_maximum2(x; dims), _zero) + @fastmath out = exp.(x .- max_) + tmp = sum(out, dims = dims) + out ./ (tmp + exp.(-max_)) +end + +# taken from NNlib +fast_maximum2(x::AbstractArray{T}; dims) where {T} = @fastmath reduce(max, x; dims, init = float(T)(-Inf)) + +function ∇softmax1_data(dy::AbstractArray{T}, y::AbstractArray{S}; dims = 1) where {T,S} + dx = if NNlib.within_gradient(y) + tmp = dy .* y + tmp .- y .* sum(tmp; dims) + else + # This path is faster, only safe for 1st derivatives though. + # Was previously `∇softmax!(dx, dy, x, y; dims)` to allow CUDA overloads, + # but that was slow: https://github.com/FluxML/NNlibCUDA.jl/issues/30 + out = similar(y, promote_type(T,S)) # sure to be mutable + out .= dy .* y + out .= out .- y .* sum(out; dims) + end +end + +function ChainRulesCore.rrule(::typeof(softmax1), x; dims = 1) + y = softmax1(x; dims) + softmax_pullback(dy) = (NoTangent(), ∇softmax1_data(unthunk(dy), y; dims)) + return y, softmax_pullback +end + + +function pre_softmax_aijh(qh::AbstractArray{T},kh::AbstractArray{T},Ti,qhp::AbstractArray{T},khp::AbstractArray{T}, bij::AbstractArray{T}, gamma_h::AbstractArray{T}) where T + w_C = T(sqrt(2f0/(9f0*size(qhp,3)))) + dim_scale = T(1f0/sqrt(size(qh,1))) + w_L = T(1f0/sqrt(3f0)) + + w_L.*(dim_scale.*qhTkh(qh,kh) .+ bij .- w_C/2 .* gamma_h .* dropdims(diff_sum_glob(Ti,qhp,khp),dims=(1,3))) +end + #= function diff_sum_glob(T, q, k) bs = size(q) @@ -151,47 +198,4 @@ function ChainRulesCore.rrule(::typeof(qhTkh), q, k) end return qhTkh, qhTkh_pullback end -=# -""" -softmax1(x, dims = 1) - -Behaves like softmax, but as though there was an additional logit of zero along dims (which is excluded from the output). So the values will sum to a value between zero and 1. -""" -function softmax1(x::AbstractArray{T}; dims = 1) where {T} - _zero = T(0) - max_ = max.(fast_maximum2(x; dims), _zero) - @fastmath out = exp.(x .- max_) - tmp = sum(out, dims = dims) - out ./ (tmp + exp.(-max_)) -end -# Pirated/adapted from NNlib -fast_maximum2(x::AbstractArray{T}; dims) where {T} = @fastmath reduce(max, x; dims, init = float(T)(-Inf)) - -function ∇softmax1_data(dy::AbstractArray{T}, y::AbstractArray{S}; dims = 1) where {T,S} - dx = if NNlib.within_gradient(y) - tmp = dy .* y - tmp .- y .* sum(tmp; dims) - else - # This path is faster, only safe for 1st derivatives though. - # Was previously `∇softmax!(dx, dy, x, y; dims)` to allow CUDA overloads, - # but that was slow: https://github.com/FluxML/NNlibCUDA.jl/issues/30 - out = similar(y, promote_type(T,S)) # sure to be mutable - out .= dy .* y - out .= out .- y .* sum(out; dims) - end -end - -function ChainRulesCore.rrule(::typeof(softmax1), x; dims = 1) - y = softmax1(x; dims) - softmax_pullback(dy) = (NoTangent(), ∇softmax1_data(unthunk(dy), y; dims)) - return y, softmax_pullback -end - - -function pre_softmax_aijh(qh::AbstractArray{T},kh::AbstractArray{T},Ti,qhp::AbstractArray{T},khp::AbstractArray{T}, bij::AbstractArray{T}, gamma_h::AbstractArray{T}) where T - w_C = T(sqrt(2f0/(9f0*size(qhp,3)))) - dim_scale = T(1f0/sqrt(size(qh,1))) - w_L = T(1f0/sqrt(3f0)) - - w_L.*(dim_scale.*qhTkh(qh,kh) .+ bij .- w_C/2 .* gamma_h .* dropdims(diff_sum_glob(Ti,qhp,khp),dims=(1,3))) -end +=# \ No newline at end of file diff --git a/src/layers.jl b/src/layers.jl index d8bd34e..9a9f58e 100644 --- a/src/layers.jl +++ b/src/layers.jl @@ -399,11 +399,10 @@ function IPCrossAStructureModuleLayer(settings::NamedTuple; dropout_p = 0.1, af return IPCrossAStructureModuleLayer(settings, layers) end +# We could skip making this struct and just have it be a cross IPA struct. """ Self IPA Partial Structure Module initialization - single layer - adapted from AF2. """ - -# We could skip making this struct and just have it be a cross IPA struct. struct IPAStructureModuleLayer settings::NamedTuple layers::NamedTuple @@ -418,9 +417,6 @@ function (structuremodulelayer::Union{IPAStructureModuleLayer, IPCrossAStructure return structuremodulelayer(T, S, T, S; zij = zij, mask = mask) end -""" -Cross IPA Partial Structure Module - single layer - adapted from AF2. From left to right. -""" function (structuremodulelayer::Union{IPCrossAStructureModuleLayer, IPAStructureModuleLayer})(T_L, S_L, T_R, S_R; zij = nothing, mask = 0) settings = structuremodulelayer.settings if settings.c_z > 0 && zij === nothing diff --git a/src/rotational_utils.jl b/src/rotational_utils.jl index 2780dd3..3145d8c 100644 --- a/src/rotational_utils.jl +++ b/src/rotational_utils.jl @@ -1,15 +1,11 @@ -# from AF2 supplementary: Algorithm 23 Backbone update -""" -Takes a 3xN matrix of imaginary quaternion components, `bcd`, sets the real part to `a`, and normalizes to unit quaternions. -""" -function bcds2quats(bcd::AbstractMatrix{T}, a::T=T(1)) where T <: Real +#= from AF2 supplementary: Algorithm 23 Backbone update +Takes a 3xN matrix of imaginary quaternion components, `bcd`, sets the real part to `a`, and normalizes to unit quaternions. =# +function bcds2quats(bcd::AbstractMatrix{T}, a::T=T(1)) where T<:Real norms = sqrt.(a .+ sum(abs2, bcd, dims=1)) return vcat(a ./ norms, bcd ./ norms) end -""" -Takes a 4xN matrix of unit quaternions and returns a 3x3xN array of rotation matrices. -""" +# Takes a 4xN matrix of unit quaternions and returns a 3x3xN array of rotation matrices. function rotmatrix_from_quat(q::AbstractMatrix{<:Real}) sx = 2q[1, :] .* q[2, :] sy = 2q[1, :] .* q[3, :] @@ -73,12 +69,9 @@ function batched_mul_T2(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where {T return reshape(z, size(z, 1), size(z, 2), batch_size...) end - -""" -Applies the SE3 transformations T = (rot,trans) ∈ SE(3)^N +#= Applies the SE3 transformations T = (rot,trans) ∈ SE(3)^N to N batches of m points in R3, i.e., mat ∈ R^(3 x m x N) ↦ T(mat) ∈ R^(3 x m x N). -Note here that rotations here are represented in matrix form. -""" +Note here that rotations here are represented in matrix form. =# function T_R3(x::AbstractArray{T}, R::AbstractArray{T}, t::AbstractArray{T}) where T x′ = reshape(x, 3, size(x, 2), :) R′ = reshape(R, 3, 3, :) @@ -88,10 +81,8 @@ function T_R3(x::AbstractArray{T}, R::AbstractArray{T}, t::AbstractArray{T}) whe return y end -""" -Applies the group inverse of the SE3 transformations T = (R,t) ∈ SE(3)^N to N batches of m points in R3, -such that T^-1(T*x) = T^-1(Rx+t) = R^T(Rx+t-t) = x. -""" +#= Applies the group inverse of the SE3 transformations T = (R,t) ∈ SE(3)^N to N batches of m points in R3, +such that T^-1(T*x) = T^-1(Rx+t) = R^T(Rx+t-t) = x. =# function T_R3_inv(y::AbstractArray{T}, R::AbstractArray{T}, t::AbstractArray{T}) where T y′ = reshape(y, 3, size(y, 2), :) R′ = reshape(R, 3, 3, :) @@ -101,9 +92,7 @@ function T_R3_inv(y::AbstractArray{T}, R::AbstractArray{T}, t::AbstractArray{T}) return x end -""" -Returns the composition of two SE(3) transformations T_1 and T_2. If T1 = (R1,t1), and T2 = (R2,t2) then T1*T2 = (R1*R2, R1*t2 + t1). -""" +# Returns the composition of two SE(3) transformations T_1 and T_2. If T1 = (R1,t1), and T2 = (R2,t2) then T1*T2 = (R1*R2, R1*t2 + t1). function T_T(T_1, T_2) R1, t1 = T_1 R2, t2 = T_2 @@ -112,9 +101,7 @@ function T_T(T_1, T_2) return (new_rot,new_trans) end -""" -Takes a 6-dim vec and maps to a rotation matrix and translation vector, which is then applied to the input frames. -""" +# Takes a 6-dim vec and maps to a rotation matrix and translation vector, which is then applied to the input frames. function update_frame(Ti, arr) bcds = reshape(arr[:,1,:,:],3,:) rotmat = rotmatrix_from_quat(bcds2quats(bcds)) @@ -130,9 +117,7 @@ unzip(a) = map(x->getfield.(a, x), fieldnames(eltype(a))) centroid(coords::AbstractMatrix) = vec(sum(coords; dims=2)) / size(coords, 2) -""" -Get frame from residue -""" +# Get rotation matrix and translation from residue function calculate_residue_rotation_and_translation(residue_xyz::AbstractMatrix) # Returns the rotation matrix and the translation of a gien residue. N = residue_xyz[:, 1] @@ -152,30 +137,29 @@ function calculate_residue_rotation_and_translation(residue_xyz::AbstractMatrix) end """ -Get the assosciated SE(3) frame for all residues in a prot + get_T(coords::Array{<:Real, 3}) +Get the assosciated SE(3) frame for all residues in a protein backbone represented as a 3x3xL array of coordinates. """ -function get_T(protxyz::Array{<:Real, 3}) - ti = stack.(unzip([calculate_residue_rotation_and_translation(protxyz[:,:,i]) for i in axes(protxyz,3)])) - return (ti[1],reshape(ti[2],3,1,:)) +function get_T(coords::Array{<:Real, 3}) + Ti = stack.(unzip([calculate_residue_rotation_and_translation(coords[:,:,i]) for i in axes(coords,3)])) + return (Ti[1],reshape(Ti[2],3,1,:)) end """ -Get the assosciated SE(3) frames for all residues in a batch of prots +Get the associated SE(3) frames for all residues in a batch of proteins """ -function get_T_batch(protxyz::Array{<:Real, 4}) - rots = zeros(3,3,size(protxyz)[3:4]...) - trans = zeros(3,1,size(protxyz)[3:4]...) - for j in axes(protxyz,4) - Tij = get_T(protxyz[:,:,:,j]) +function get_T_batch(coords::Array{<:Real, 4}) + rots = zeros(3,3,size(coords)[3:4]...) + trans = zeros(3,1,size(coords)[3:4]...) + for j in axes(coords,4) + Tij = get_T(coords[:,:,:,j]) rots[:,:,:,j] = Tij[1] trans[:,:,:,j] = Tij[2] end return (rots, trans) end -""" -Index into a T up to index i. -""" +# Index into a T up to index i. function T_till(T,i) Tr, Tt = T[1][:,:,1:i,:], T[2][:,:,1:i,:] return Tr, Tt diff --git a/test/runtests.jl b/test/runtests.jl index e07da30..061267a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -100,8 +100,8 @@ using ChainRulesTestUtils # Get 1 global SE(3) transformation for each batch. T_glob = (get_rotation(batch_size), get_translation(batch_size)) - T_GlobL = (stack([T_glob[1] for i in 1:framesL],dims = 3), stack([T_glob[2] for i in 1:framesL], dims = 3)) - T_GlobR = (stack([T_glob[1] for i in 1:framesR],dims = 3), stack([T_glob[2] for i in 1:framesR], dims = 3)) + T_GlobL = (stack([T_glob[1] for i in 1:framesL],dims = 3), stack([T_glob[2] for i in 1:framesL], dims=3)) + T_GlobR = (stack([T_glob[1] for i in 1:framesR],dims = 3), stack([T_glob[2] for i in 1:framesR], dims=3)) T_newL = InvariantPointAttention.T_T(T_GlobL,T_locL) T_newR = InvariantPointAttention.T_T(T_GlobR,T_locR)