Skip to content

Commit

Permalink
Fix T_R3 and T_R3_inv
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonOresten committed Jun 7, 2024
1 parent 176d1dc commit 2559937
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 171 deletions.
106 changes: 23 additions & 83 deletions src/grads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,6 @@ function sumabs2(x::AbstractArray{T}; dims = 1) where {T}
sum(abs2, x; dims)
end

"""
    _sumabs2_no_rrule(x; dims = 1)

Return the sum of `abs2` applied to every entry of `x`, reduced along `dims`.
The reduced dimensions are kept with length 1, as with `sum(...; dims)`.
"""
function _sumabs2_no_rrule(x::AbstractArray{T}; dims = 1) where {T}
    return sum(abs2, x; dims = dims)
end

function ChainRulesCore.rrule(::typeof(sumabs2), x; dims = 1)
function sumabs2_pullback(_Δ)
Δ = unthunk(_Δ)
Expand All @@ -27,10 +23,6 @@ function L2norm(x::AbstractArray{T}; dims = 1, eps = 1f-7) where {T}
sqrt.(sumabs2(x; dims) .+ eps )
end

"""
    _L2norm_no_rrule(x; dims = 1, eps = 1f-7)

Return `sqrt.(sum(abs2, x; dims) .+ eps)`: the Euclidean norm of `x` along
`dims`, with `eps` added under the square root.
"""
function _L2norm_no_rrule(x::AbstractArray{T}; dims = 1, eps = 1f-7) where {T}
    squared = sum(abs2, x; dims = dims)
    return sqrt.(squared .+ eps)
end

function ChainRulesCore.rrule(::typeof(L2norm), x::AbstractArray{T}; dims = 1, eps = 1f-7) where {T}
normx = L2norm(x; dims, eps)
function L2norm_pullback(_Δ)
Expand All @@ -44,91 +36,39 @@ function pair_diff(A::AbstractArray{T}, B::AbstractArray{T}; dims = 4) where {T}
return Flux.unsqueeze(A, dims = dims + 1) .- Flux.unsqueeze(B, dims = dims)
end

"""
    _pair_diff_no_rrule(A, B; dims = 4)

Broadcast-subtract `B` from `A` after passing `A` through `Flux.unsqueeze` at
`dims + 1` and `B` through `Flux.unsqueeze` at `dims`.
"""
function _pair_diff_no_rrule(A::AbstractArray{T}, B::AbstractArray{T}; dims = 4) where {T}
    lhs = Flux.unsqueeze(A, dims = dims + 1)
    rhs = Flux.unsqueeze(B, dims = dims)
    return lhs .- rhs
end

function ChainRulesCore.rrule(::typeof(pair_diff), A::AbstractArray{T}, B::AbstractArray{T}; dims = 4) where {T}
function ChainRulesCore.rrule(::typeof(pair_diff), A::AbstractArray{T}, B::AbstractArray{T}; dims=4) where {T}
y = pair_diff(A, B; dims)
function pair_diff_pullback(_Δ)
Δ = unthunk(_Δ)
return (NoTangent(), @thunk(sumdrop(Δ; dims = dims + 1)), @thunk(-sumdrop(Δ; dims = dims)))
return (NoTangent(), @thunk(sumdrop(Δ; dims=dims+1)), @thunk(-sumdrop(Δ; dims=dims)))
end
return y, pair_diff_pullback
end

function ChainRulesCore.rrule(::typeof(T_R3), A, R, t; dims = 1)
function T_R3_pullback(_Δ)
Δ = unthunk(_Δ)
ΔA = @thunk begin
batch_size = size(A)[3:end]
R2 = reshape(R, size(R,1), size(R,2), :)
Δ2 = reshape(Δ, size(Δ,1), size(Δ,2), :)
ΔA = batched_mul(batched_adjoint(R2), Δ2)
reshape(ΔA, size(ΔA, 1), size(ΔA, 2), batch_size...)
end
ΔR = @thunk begin
batch_size = size(R)[3:end]
A2 = reshape(A, size(A,1), size(A,2), :)
Δ2 = reshape(Δ, size(Δ,1), size(Δ,2), :)
ΔR = batched_mul(Δ2, batched_adjoint(A2))
reshape(ΔR, size(ΔR, 1), size(ΔR, 2), batch_size...)
end
Δt = @thunk begin
# Case for broadcasting t along dim = 2.
size(t,2) == 1 ? tmp = sum(Δ, dims = 2) : tmp = Δ
tmp
end
return (NoTangent(), ΔA, ΔR, Δt)
function ChainRulesCore.rrule(::typeof(T_R3), x::AbstractArray{T,N}, R::AbstractArray{T,N}, t::AbstractArray{T,N}) where {T,N}
function T_R3_pullback(_Δy)
Δy = unthunk(_Δy)
Δx = @thunk(batched_mul(_batched_transpose(R), Δy))
ΔR = @thunk(batched_mul(Δy, _batched_transpose(x)))
Δt = @thunk(sum(Δy, dims=2))
return (NoTangent(), Δx, ΔR, Δt)
end
return T_R3(A, R, t), T_R3_pullback
end

"""
    _T_R3_no_rrule(mat, rot, trans)

Compute `rot * mat .+ trans` batch-wise and return it with the size of `mat`.

`rot` is reshaped to `3 x 3 x batch`, `trans` to `3 x 1 x batch`, and `mat` to
`3 x size(mat, 2) x batch` before the batched multiplication.
"""
function _T_R3_no_rrule(mat, rot, trans)
    out_size = size(mat)
    rotations = reshape(rot, 3, 3, :)
    offsets = reshape(trans, 3, 1, :)
    points = reshape(mat, 3, size(mat, 2), :)
    return reshape(batched_mul(rotations, points) .+ offsets, out_size)
end

function ChainRulesCore.rrule(::typeof(T_R3_inv), A, R, t; dims = 1)
function T_R3_inv_pullback(_Δ)
Δ = unthunk(_Δ)
ΔA = @thunk begin
batch_size = size(A)[3:end]
R2 = reshape(R, size(R,1), size(R,2), :)
Δ2 = reshape(Δ, size(Δ,1), size(Δ,2), :)
ΔA = batched_mul(R2, Δ2)
reshape(ΔA, size(ΔA, 1), size(ΔA, 2), batch_size...)
end

ΔR = @thunk begin
batch_size = size(R)[3:end]
A2 = reshape(A, size(A,1), size(A,2), :)
Δ2 = reshape(Δ, size(Δ,1), size(Δ,2), :)
ΔR = batched_mul(A2, batched_adjoint(Δ2))
reshape(ΔR, size(ΔR, 1), size(ΔR, 2), batch_size...)
end
Δt = @thunk begin
# Case for broadcasting t along dim = 2.
size(t,2) == 1 ? tmp = sum(Δ, dims = 2) : tmp = Δ
tmp
end
return (NoTangent(), ΔA, ΔR, Δt)
return T_R3(x, R, t), T_R3_pullback
end

function ChainRulesCore.rrule(::typeof(T_R3_inv), x::AbstractArray{T,N}, R::AbstractArray{T,N}, t::AbstractArray{T,N}) where {T,N}
z = x .- t
y = batched_mul(_batched_transpose(R), z)
function T_R3_inv_pullback(_Δy)
Δy = unthunk(_Δy)
Δx = @thunk(batched_mul(R, Δy))
ΔR = @thunk(batched_mul(z, _batched_transpose(Δy)))
Δt = @thunk(-sum(Δx, dims=2)) # t is in the same position as x, but negated and broadcasted
return (NoTangent(), Δx, ΔR, Δt)
end
return T_R3_inv(A, R, t), T_R3_inv_pullback
return T_R3_inv(x, R, t), T_R3_inv_pullback
end

"""
    _T_R3_inv_no_rrule(mat, rot, trans)

Compute `transpose(rot) * (mat .- trans)` batch-wise and return it with the
size of `mat`.

`rot` is reshaped to `3 x 3 x batch` and transposed per batch, `trans` is
reshaped to `3 x 1 x batch`, and `mat` to `3 x size(mat, 2) x batch`.
"""
function _T_R3_inv_no_rrule(mat, rot, trans)
    out_size = size(mat)
    rotations_t = batched_transpose(reshape(rot, 3, 3, :))
    points = reshape(mat, 3, size(mat, 2), :)
    offsets = reshape(trans, 3, 1, :)
    return reshape(batched_mul(rotations_t, points .- offsets), out_size)
end
#=
function diff_sum_glob(T, q, k)
bs = size(q)
Expand All @@ -143,7 +83,7 @@ function _diff_sum_glob_no_rrule(T,q,k)
bs = size(q)
qresh = reshape(q, size(q,1), size(q,2)*size(q,3), size(q,4),size(q,5))
kresh = reshape(k, size(k,1), size(k,2)*size(k,3), size(k,4),size(k,5))
Tq, Tk = _T_R3_no_rrule(qresh,T[1],T[2]),_T_R3_no_rrule(kresh,T[1],T[2])
Tq, Tk = T_R3_no_rrule(qresh,T[1],T[2]),T_R3_no_rrule(kresh,T[1],T[2])
Tq, Tk = reshape(Tq, bs...), reshape(Tk, bs...)
diffs = _sumabs2_no_rrule(_pair_diff_no_rrule(Tq, Tk, dims = 4),dims=[1,3])
end=#
Expand Down
21 changes: 10 additions & 11 deletions src/layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,9 @@ function (ipa::Union{IPCrossA, IPA})(
,(3,1,2,4))

# Applying our transformations to the queries, keys, and values to put them in the global frame.
Tqhp = reshape(_T_R3_no_rrule(qhp, rot_TiR,translate_TiR),3,N_head,N_query_points,N_frames_R,:)
Tkhp = reshape(_T_R3_no_rrule(khp, rot_TiL,translate_TiL),3,N_head,N_query_points,N_frames_L,:)
Tvhp = _T_R3_no_rrule(vhp, rot_TiL, translate_TiL)
Tqhp = reshape(T_R3(qhp, rot_TiR,translate_TiR),3,N_head,N_query_points,N_frames_R,:)
Tkhp = reshape(T_R3(khp, rot_TiL,translate_TiL),3,N_head,N_query_points,N_frames_L,:)
Tvhp = T_R3(vhp, rot_TiL, translate_TiL)

diffs_glob = Flux.unsqueeze(Tqhp, dims = 5) .- Flux.unsqueeze(Tkhp, dims = 4)
sum_norms_glob = reshape(sum(abs2, diffs_glob, dims = [1,3]),N_head,N_frames_R,N_frames_L,:) #Sum over points for each head
Expand Down Expand Up @@ -233,7 +233,7 @@ function (ipa::Union{IPCrossA, IPA})(
end

#ohp_r were in the global frame, so we put those back in the recipient local
ohp = _T_R3_inv_no_rrule(ohp_r, rot_TiR, translate_TiR)
ohp = T_R3_inv(ohp_r, rot_TiR, translate_TiR)
normed_ohp = sqrt.(sum(abs2, ohp,dims = 1) .+ Typ(0.000001f0)) #Adding eps

catty = vcat(
Expand Down Expand Up @@ -350,7 +350,7 @@ function ipa_customgrad(ipa::Union{IPCrossA, IPA}, Ti::Tuple{AbstractArray,Abstr
ohp_r = reshape(sum(broadcast_att_ohp.*broadcast_tvhp,dims=5),3,N_head*N_point_values,N_frames_R,:)
end
#ohp_r were in the global frame, so we put those back in the recipient local
ohp = _T_R3_inv_no_rrule(ohp_r, rot_TiR, translate_TiR)
ohp = T_R3_inv(ohp_r, rot_TiR, translate_TiR)
normed_ohp = sqrt.(sumabs2(ohp, dims = 1) .+ Typ(0.000001f0)) #Adding eps
catty = vcat(
reshape(oh, N_head*c, N_frames_R,:),
Expand Down Expand Up @@ -510,11 +510,11 @@ function expand(
rot_TiR, translate_TiR = TiR
ΔTqhp = reshape(T_R3(Δqhp, (rot_TiR[:,:,R+1:R+ΔR,:]), (translate_TiR[:,:,R+1:R+ΔR,:])), (3, N_head, N_query_points, ΔR, B))
Tkhp = reshape(
T_R3(reshape(khp, (3, N_head * N_query_points, (L + ΔL) * B)), (rot_TiL[:,:,1:L+ΔL,:]), (translate_TiL[:,:,1:L+ΔL,:])),
T_R3(reshape(khp, (3, N_head * N_query_points, (L + ΔL), B)), (rot_TiL[:,:,1:L+ΔL,:]), (translate_TiL[:,:,1:L+ΔL,:])),
(3, N_head, N_query_points, L + ΔL, B)
)
Tvhp = reshape(
T_R3(reshape(vhp, (3, N_head * N_point_values, (L + ΔL) * B)), (rot_TiL[:,:,1:L+ΔL,:]), (translate_TiL[:,:,1:L+ΔL,:])),
T_R3(reshape(vhp, (3, N_head * N_point_values, (L + ΔL), B)), (rot_TiL[:,:,1:L+ΔL,:]), (translate_TiL[:,:,1:L+ΔL,:])),
(3, N_head, N_point_values, L + ΔL, B)
)

Expand Down Expand Up @@ -572,7 +572,7 @@ function expand(
) .+
reshape(translate_TiR[:,:,R+1:R+ΔR,:], (3, 1, 1, ΔR, B)) .*
reshape(1 .- sum(Δatt, dims = 3), (1, N_head, 1, ΔR, B)),
(3, N_head * N_point_values, ΔR * B)
(3, N_head * N_point_values, ΔR, B)
)
else
ohp_pre = reshape(
Expand All @@ -582,14 +582,13 @@ function expand(
reshape(Tvhp, (3, N_head, N_point_values, 1, L + ΔL, B)),
dims = 5,
),
(3, N_head * N_point_values, ΔR * B)
(3, N_head * N_point_values, ΔR, B)
)
end

ohp = reshape(
T_R3_inv(
ohp_pre
,
ohp_pre,
(rot_TiR[:,:,R+1:R+ΔR,:]),
(translate_TiR[:,:,R+1:R+ΔR,:])
),
Expand Down
27 changes: 12 additions & 15 deletions src/rotational_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,31 +52,28 @@ get_rotation(dims...; T::Type{<:Real}=Float32) = get_rotation(T, dims...)
Generates random translations of given size.
"""
function get_translation(T::Type{<:Real}, dims...)
    # Standard-normal samples of element type T with size 3 x 1 x dims...
    return randn(T, 3, 1, dims...)
end
function get_translation(dims...; T::Type{<:Real}=Float32)
    # Keyword form: forwards to the positional-type method, Float32 by default.
    return get_translation(T, dims...)
end

"""
    _batched_transpose(data)

Wrap `data` in a `PermutedDimsArray` that swaps its first two dimensions and
leaves dimensions `3:N` in place. No copy of `data` is made here.
"""
function _batched_transpose(data::A) where {T,N,A<:AbstractArray{T,N}}
    # Swapping dims 1 and 2 is its own inverse, so the same tuple serves as
    # both the permutation and the inverse-permutation type parameter.
    swapped = (2, 1, 3:N...)
    return PermutedDimsArray{T,N,swapped,swapped,A}(data)
end

"""
Applies the SE3 transformations T = (rot,trans) ∈ SE(3)^N
to N batches of m points in R3, i.e., mat ∈ R^(3 x m x N) ↦ T(mat) ∈ R^(3 x m x N).
Note that rotations here are represented in matrix form.
"""
function T_R3(mat, rot, trans)
rotc = reshape(rot, 3, 3, :)
trans = reshape(trans, 3, 1, :)
matc = reshape(mat, 3, size(mat, 2), :)
rotated_mat = batched_mul(rotc, matc) .+ trans
return reshape(rotated_mat, size(mat))
function T_R3(x::AbstractArray{T,N}, R::AbstractArray{T,N}, t::AbstractArray{T,N}) where {T,N}
return batched_mul(R, x) .+ t
end

"""
Applies the group inverse of the SE3 transformations T = (R,t) ∈ SE(3)^N to N batches of m points in R3,
such that T^-1(T*x) = T^-1(Rx+t) = R^T(Rx+t-t) = x.
"""
function T_R3_inv(mat, rot, trans)
rotc = batched_transpose(reshape(rot, 3, 3, :))
matc = reshape(mat, 3, size(mat, 2), :)
trans = reshape(trans, 3,1,:)
rotated_mat = batched_mul(rotc, matc .- trans)
return reshape(rotated_mat, size(mat))
function T_R3_inv(x::AbstractArray{T,N}, R::AbstractArray{T,N}, t::AbstractArray{T,N}) where {T,N}
return batched_mul(_batched_transpose(R), x .- t)
end

"""
Expand Down Expand Up @@ -106,7 +103,7 @@ end

"""
    unzip(a)

For each field of `eltype(a)`, collect that field from every element of `a`.
Returns one collection per field, in field order.
"""
function unzip(a)
    fields = fieldnames(eltype(a))
    return map(field -> getfield.(a, field), fields)
end

"""
    calculate_residue_centroid(residue_xyz)

Return the mean of the first three columns of `residue_xyz` as a length-3
vector. The final `reshape(..., 3)` requires `residue_xyz` to have 3 rows.
"""
function calculate_residue_centroid(residue_xyz::AbstractMatrix)
    first_three = residue_xyz[:, 1:3]
    return reshape(mean(first_three, dims = 2), 3)
end
"""
    centroid(coords)

Return the mean of the columns of `coords` as a vector with one entry per row.
"""
function centroid(coords::AbstractMatrix)
    row_sums = vec(sum(coords; dims=2))
    return row_sums / size(coords, 2)
end

"""
Get frame from residue
Expand All @@ -117,7 +114,7 @@ function calculate_residue_rotation_and_translation(residue_xyz::AbstractMatrix)
Ca = residue_xyz[:, 2] # We use the centroid instead of the Ca - not 100% sure if this is correct
C = residue_xyz[:, 3]

t = calculate_residue_centroid(residue_xyz)
t = centroid(residue_xyz)

v1 = C - t
v2 = N - t
Expand Down
Loading

0 comments on commit 2559937

Please sign in to comment.