Skip to content

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonOresten committed Jun 11, 2024
1 parent 0114001 commit 9d89985
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 10 deletions.
7 changes: 3 additions & 4 deletions src/InvariantPointAttention.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,10 @@ include("grads.jl")
include("layers.jl")
include("masks.jl")

export IPA
export IPAStructureModuleLayer
export BackboneUpdate
export IPA_settings
export IPCrossA
export IPA, IPCrossA
export IPAStructureModuleLayer, IPCrossAStructureModuleLayer
export BackboneUpdate
export right_to_left_mask
export left_to_right_mask
export virtual_residues
Expand Down
18 changes: 17 additions & 1 deletion src/layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,18 @@ function (backboneupdate::BackboneUpdate)(Ti, si)
end

"""
IPA_settings(
dims;
c = 16,
N_head = 12,
N_query_points = 4,
N_point_values = 8,
c_z = 0,
Typ = Float32,
use_softmax1 = false,
scaling_qk = :default,
)
Returns a tuple of the IPA settings, with defaults for everything except dims. This can be passed to the IPA and IPCrossAStructureModuleLayer.
"""
IPA_settings(
Expand Down Expand Up @@ -49,7 +61,11 @@ IPA_settings(


"""
IPCrossA(settings)
Invariant Point Cross Attention (IPCrossA). Information flows from L (Keys, Values) to R (Queries).
Get `settings` with [`IPA_settings`](@ref)
"""
struct IPCrossA
settings::NamedTuple
Expand Down Expand Up @@ -149,7 +165,7 @@ function (ipa::Union{IPCrossA, IPA})(
else
use_softmax1 = false
end

rot_TiL, translate_TiL = TiL
rot_TiR, translate_TiR = TiR

Expand Down
20 changes: 15 additions & 5 deletions src/rotational_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,16 +79,26 @@ Applies the SE3 transformations T = (rot,trans) ∈ SE(3)^N
to N batches of m points in R3, i.e., mat ∈ R^(3 x m x N) ↦ T(mat) ∈ R^(3 x m x N).
Note here that rotations here are represented in matrix form.
"""
function T_R3(x::AbstractArray{T,N}, R::AbstractArray{T,N}, t::AbstractArray{T,N}) where {T,N}
return batched_mul(R, x) .+ t
function T_R3(x::AbstractArray{T}, R::AbstractArray{T}, t::AbstractArray{T}) where T
x′ = reshape(x, 3, size(x, 2), :)
R′ = reshape(R, 3, 3, :)
t′ = reshape(t, 3, 1, :)
y′ = batched_mul(R′, x′) .+ t′
y = reshape(y′, size(x))
return y
end

"""
"""
Applies the group inverse of the SE3 transformations T = (R,t) ∈ SE(3)^N to N batches of m points in R3,
such that T^-1(T*x) = T^-1(Rx+t) = R^T(Rx+t-t) = x.
"""
function T_R3_inv(x::AbstractArray{T,N}, R::AbstractArray{T,N}, t::AbstractArray{T,N}) where {T,N}
return batched_mul_T1(R, x .- t)
function T_R3_inv(y::AbstractArray{T}, R::AbstractArray{T}, t::AbstractArray{T}) where T
y′ = reshape(y, 3, size(y, 2), :)
R′ = reshape(R, 3, 3, :)
t′ = reshape(t, 3, 1, :)
x′ = batched_mul(batched_transpose(R′), y′ .- t′)
x = reshape(x′, size(y))
return x
end

"""
Expand Down

0 comments on commit 9d89985

Please sign in to comment.