Skip to content

Commit

Permalink
Update examples, reorder tests, fix
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonOresten committed Jun 11, 2024
1 parent 9ddd90f commit 1518c5d
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 97 deletions.
7 changes: 1 addition & 6 deletions examples/cross_example.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,4 @@
using LinearAlgebra
using StatsBase
using Flux

include("../src/layers.jl")
include("../src/rotational_utils.jl")
using InvariantPointAttention

len_L, len_R, dim, batch = 5, 5, 4, 10
T_L = (get_rotation(len_L, batch), randn(Float32, 3, len_L, batch))
Expand Down
8 changes: 2 additions & 6 deletions examples/masking_example.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
using LinearAlgebra
using StatsBase
using Flux

include("../src/layers.jl")
include("../src/rotational_utils.jl")
using InvariantPointAttention
using InvariantPointAttention: get_rotation, get_translation

N_frames = 7
dim = 32
Expand Down
2 changes: 0 additions & 2 deletions examples/softmax1_invariance.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
using Pkg; Pkg.activate(".")
using Revise
using InvariantPointAttention
using InvariantPointAttention: get_rotation, get_translation, T_T, sumdrop

Expand Down
8 changes: 2 additions & 6 deletions examples/test_invariance.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
using LinearAlgebra
using StatsBase
using Flux

include("../src/layers.jl")
include("../src/rotational_utils.jl")
using InvariantPointAttention
using InvariantPointAttention: get_rotation, get_translation, T_T

batch_size = 32
frames = 64
Expand Down
2 changes: 1 addition & 1 deletion src/masks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ function virtual_residues(
Nr = size(S, 2)
start = 1
if rand_start
start = sample(1:step)
start = rand(1:step)
end
vr = start:step:Nr
S_virt = S[:,vr,:]
Expand Down
153 changes: 77 additions & 76 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,82 +9,6 @@ import Flux
using ChainRulesTestUtils

@testset "InvariantPointAttention.jl" begin
# Write your tests here.
@testset "Softmax1" begin
#Check if softmax1 is consistent with softmax, when adding an additional zero logit
x = randn(4,3)
xaug = hcat(x, zeros(4,1))
@test softmax1(x, dims = 2) Flux.softmax(xaug, dims = 2)[:,1:end-1]
end

@testset "softmax1 rrule" begin
x = randn(2,3,4)

foreach(i -> test_rrule(softmax1, x; fkwargs=(; dims=i)), 1:3)
end

@testset "T_R3 rrule" begin
x = randn(Float64, 3, 2, 1, 2)
R = get_rotation(Float64, 1, 2)
t = get_translation(Float64, 1, 2)
test_rrule(T_R3, x, R, t)
end

@testset "T_R3_inv rrule" begin
x = randn(Float64, 3, 2, 1, 2)
R = get_rotation(Float64, 1, 2)
t = get_translation(Float64, 1, 2)
test_rrule(T_R3_inv, x, R, t)
end

@testset "sumabs2 rrule" begin
x = rand(2,3,4)
foreach(i -> test_rrule(sumabs2, x; fkwargs=(; dims=i)), 1:3)
end

@testset "L2norm rrule" begin
x = randn(2,3,4,5)
foreach(i -> test_rrule(L2norm, x; fkwargs=(; dims=i)), 1:3)
end

@testset "pair_diff custom grad" begin
x = randn(1,4,2)
y = randn(1,3,2)
test_rrule(pair_diff, x, y; fkwargs=(; dims=2))
end

@testset "ipa_customgrad" begin
batch_size = 3
framesL = 10
framesR = 10
dim = 10

siL = randn(Float32, dim, framesL, batch_size)
siR = siL
# Use CLOPS.jl shape notation
TiL = (get_rotation(Float32, framesL, batch_size), get_translation(Float32, framesL, batch_size))
TiR = TiL
zij = randn(Float32, 16, framesR, framesL, batch_size)

ipa = IPCrossA(IPA_settings(dim; use_softmax1 = true, c_z = 16, Typ = Float32))
# Batching on mask
mask = right_to_left_mask(framesL)[:, :, ones(Int, batch_size)]
ps = Flux.params(ipa)

lz,gs = Flux.withgradient(ps) do
sum(ipa(TiL, siL, TiR, siR; zij, mask, customgrad = true))
end

lz2, zygotegs = Flux.withgradient(ps) do
sum(ipa(TiL, siL, TiR, siR; zij, mask, customgrad = false))
end

for (gs, zygotegs) in zip(keys(gs),keys(zygotegs))
@test maximum(abs.(gs .- zygotegs)) < 2f-5
end
#@show lz, lz2
@test abs.(lz - lz2) < 1f-5
end

@testset "IPAsoftmax_invariance" begin
batch_size = 3
Expand Down Expand Up @@ -232,4 +156,81 @@ using ChainRulesTestUtils
end
@test cat(siRs..., dims = 2) ipa(TiL, siL, TiR, siR; zij, mask = right_to_left_mask(10))
end

@testset "Softmax1" begin
#Check if softmax1 is consistent with softmax, when adding an additional zero logit
x = randn(4,3)
xaug = hcat(x, zeros(4,1))
@test softmax1(x, dims = 2) Flux.softmax(xaug, dims = 2)[:,1:end-1]
end

@testset "softmax1 rrule" begin
x = randn(2,3,4)

foreach(i -> test_rrule(softmax1, x; fkwargs=(; dims=i)), 1:3)
end

@testset "T_R3 rrule" begin
x = randn(Float64, 3, 2, 1, 2)
R = get_rotation(Float64, 1, 2)
t = get_translation(Float64, 1, 2)
test_rrule(T_R3, x, R, t)
end

@testset "T_R3_inv rrule" begin
x = randn(Float64, 3, 2, 1, 2)
R = get_rotation(Float64, 1, 2)
t = get_translation(Float64, 1, 2)
test_rrule(T_R3_inv, x, R, t)
end

@testset "sumabs2 rrule" begin
x = rand(2,3,4)
foreach(i -> test_rrule(sumabs2, x; fkwargs=(; dims=i)), 1:3)
end

@testset "L2norm rrule" begin
x = randn(2,3,4,5)
foreach(i -> test_rrule(L2norm, x; fkwargs=(; dims=i)), 1:3)
end

@testset "pair_diff rrule" begin
x = randn(1,4,2)
y = randn(1,3,2)
test_rrule(pair_diff, x, y; fkwargs=(; dims=2))
end

@testset "ipa_customgrad" begin
batch_size = 3
framesL = 10
framesR = 10
dim = 10

siL = randn(Float32, dim, framesL, batch_size)
siR = siL
# Use CLOPS.jl shape notation
TiL = (get_rotation(Float32, framesL, batch_size), get_translation(Float32, framesL, batch_size))
TiR = TiL
zij = randn(Float32, 16, framesR, framesL, batch_size)

ipa = IPCrossA(IPA_settings(dim; use_softmax1 = true, c_z = 16, Typ = Float32))
# Batching on mask
mask = right_to_left_mask(framesL)[:, :, ones(Int, batch_size)]
ps = Flux.params(ipa)

lz,gs = Flux.withgradient(ps) do
sum(ipa(TiL, siL, TiR, siR; zij, mask, customgrad = true))
end

lz2, zygotegs = Flux.withgradient(ps) do
sum(ipa(TiL, siL, TiR, siR; zij, mask, customgrad = false))
end

for (gs, zygotegs) in zip(keys(gs),keys(zygotegs))
@test maximum(abs.(gs .- zygotegs)) < 2f-5
end
#@show lz, lz2
@test abs.(lz - lz2) < 1f-5
end

end

0 comments on commit 1518c5d

Please sign in to comment.