diff --git a/examples/cross_example.jl b/examples/cross_example.jl index bf55b92..eb82bed 100644 --- a/examples/cross_example.jl +++ b/examples/cross_example.jl @@ -1,9 +1,4 @@ -using LinearAlgebra -using StatsBase -using Flux - -include("../src/layers.jl") -include("../src/rotational_utils.jl") +using InvariantPointAttention len_L, len_R, dim, batch = 5, 5, 4, 10 T_L = (get_rotation(len_L, batch), randn(Float32, 3, len_L, batch)) diff --git a/examples/masking_example.jl b/examples/masking_example.jl index fed801b..5640762 100644 --- a/examples/masking_example.jl +++ b/examples/masking_example.jl @@ -1,9 +1,5 @@ -using LinearAlgebra -using StatsBase -using Flux - -include("../src/layers.jl") -include("../src/rotational_utils.jl") +using InvariantPointAttention +using InvariantPointAttention: get_rotation, get_translation N_frames = 7 dim = 32 diff --git a/examples/softmax1_invariance.jl b/examples/softmax1_invariance.jl index e266929..9d36dff 100644 --- a/examples/softmax1_invariance.jl +++ b/examples/softmax1_invariance.jl @@ -1,5 +1,3 @@ -using Pkg; Pkg.activate(".") -using Revise using InvariantPointAttention using InvariantPointAttention: get_rotation, get_translation, T_T, sumdrop diff --git a/examples/test_invariance.jl b/examples/test_invariance.jl index e76f2da..ad670be 100644 --- a/examples/test_invariance.jl +++ b/examples/test_invariance.jl @@ -1,9 +1,5 @@ -using LinearAlgebra -using StatsBase -using Flux - -include("../src/layers.jl") -include("../src/rotational_utils.jl") +using InvariantPointAttention +using InvariantPointAttention: get_rotation, get_translation, T_T batch_size = 32 frames = 64 diff --git a/src/masks.jl b/src/masks.jl index 2842284..569dd1f 100644 --- a/src/masks.jl +++ b/src/masks.jl @@ -52,7 +52,7 @@ function virtual_residues( Nr = size(S, 2) start = 1 if rand_start - start = sample(1:step) + start = rand(1:step) end vr = start:step:Nr S_virt = S[:,vr,:] diff --git a/test/runtests.jl b/test/runtests.jl index 061267a..ba6275a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,82 +9,6 @@ import Flux using ChainRulesTestUtils @testset "InvariantPointAttention.jl" begin - # Write your tests here. - @testset "Softmax1" begin - #Check if softmax1 is consistent with softmax, when adding an additional zero logit - x = randn(4,3) - xaug = hcat(x, zeros(4,1)) - @test softmax1(x, dims = 2) ≈ Flux.softmax(xaug, dims = 2)[:,1:end-1] - end - - @testset "softmax1 rrule" begin - x = randn(2,3,4) - - foreach(i -> test_rrule(softmax1, x; fkwargs=(; dims=i)), 1:3) - end - - @testset "T_R3 rrule" begin - x = randn(Float64, 3, 2, 1, 2) - R = get_rotation(Float64, 1, 2) - t = get_translation(Float64, 1, 2) - test_rrule(T_R3, x, R, t) - end - - @testset "T_R3_inv rrule" begin - x = randn(Float64, 3, 2, 1, 2) - R = get_rotation(Float64, 1, 2) - t = get_translation(Float64, 1, 2) - test_rrule(T_R3_inv, x, R, t) - end - - @testset "sumabs2 rrule" begin - x = rand(2,3,4) - foreach(i -> test_rrule(sumabs2, x; fkwargs=(; dims=i)), 1:3) - end - - @testset "L2norm rrule" begin - x = randn(2,3,4,5) - foreach(i -> test_rrule(L2norm, x; fkwargs=(; dims=i)), 1:3) - end - - @testset "pair_diff custom grad" begin - x = randn(1,4,2) - y = randn(1,3,2) - test_rrule(pair_diff, x, y; fkwargs=(; dims=2)) - end - - @testset "ipa_customgrad" begin - batch_size = 3 - framesL = 10 - framesR = 10 - dim = 10 - - siL = randn(Float32, dim, framesL, batch_size) - siR = siL - # Use CLOPS.jl shape notation - TiL = (get_rotation(Float32, framesL, batch_size), get_translation(Float32, framesL, batch_size)) - TiR = TiL - zij = randn(Float32, 16, framesR, framesL, batch_size) - - ipa = IPCrossA(IPA_settings(dim; use_softmax1 = true, c_z = 16, Typ = Float32)) - # Batching on mask - mask = right_to_left_mask(framesL)[:, :, ones(Int, batch_size)] - ps = Flux.params(ipa) - - lz,gs = Flux.withgradient(ps) do - sum(ipa(TiL, siL, TiR, siR; zij, mask, customgrad = true)) - end - - lz2, zygotegs = Flux.withgradient(ps) do - sum(ipa(TiL, siL, TiR, siR; zij, mask, customgrad = false)) - end - - for (gs, zygotegs) in zip(keys(gs),keys(zygotegs)) - @test maximum(abs.(gs .- zygotegs)) < 2f-5 - end - #@show lz, lz2 - @test abs.(lz - lz2) < 1f-5 - end @testset "IPAsoftmax_invariance" begin batch_size = 3 @@ -232,4 +156,81 @@ using ChainRulesTestUtils end @test cat(siRs..., dims = 2) ≈ ipa(TiL, siL, TiR, siR; zij, mask = right_to_left_mask(10)) end + + @testset "Softmax1" begin + #Check if softmax1 is consistent with softmax, when adding an additional zero logit + x = randn(4,3) + xaug = hcat(x, zeros(4,1)) + @test softmax1(x, dims = 2) ≈ Flux.softmax(xaug, dims = 2)[:,1:end-1] + end + + @testset "softmax1 rrule" begin + x = randn(2,3,4) + + foreach(i -> test_rrule(softmax1, x; fkwargs=(; dims=i)), 1:3) + end + + @testset "T_R3 rrule" begin + x = randn(Float64, 3, 2, 1, 2) + R = get_rotation(Float64, 1, 2) + t = get_translation(Float64, 1, 2) + test_rrule(T_R3, x, R, t) + end + + @testset "T_R3_inv rrule" begin + x = randn(Float64, 3, 2, 1, 2) + R = get_rotation(Float64, 1, 2) + t = get_translation(Float64, 1, 2) + test_rrule(T_R3_inv, x, R, t) + end + + @testset "sumabs2 rrule" begin + x = rand(2,3,4) + foreach(i -> test_rrule(sumabs2, x; fkwargs=(; dims=i)), 1:3) + end + + @testset "L2norm rrule" begin + x = randn(2,3,4,5) + foreach(i -> test_rrule(L2norm, x; fkwargs=(; dims=i)), 1:3) + end + + @testset "pair_diff rrule" begin + x = randn(1,4,2) + y = randn(1,3,2) + test_rrule(pair_diff, x, y; fkwargs=(; dims=2)) + end + + @testset "ipa_customgrad" begin + batch_size = 3 + framesL = 10 + framesR = 10 + dim = 10 + + siL = randn(Float32, dim, framesL, batch_size) + siR = siL + # Use CLOPS.jl shape notation + TiL = (get_rotation(Float32, framesL, batch_size), get_translation(Float32, framesL, batch_size)) + TiR = TiL + zij = randn(Float32, 16, framesR, framesL, batch_size) + + ipa = IPCrossA(IPA_settings(dim; use_softmax1 = true, c_z = 16, Typ = Float32)) + # Batching on mask + mask = right_to_left_mask(framesL)[:, :, ones(Int, batch_size)] + ps = Flux.params(ipa) + + lz,gs = Flux.withgradient(ps) do + sum(ipa(TiL, siL, TiR, siR; zij, mask, customgrad = true)) + end + + lz2, zygotegs = Flux.withgradient(ps) do + sum(ipa(TiL, siL, TiR, siR; zij, mask, customgrad = false)) + end + + for (gs, zygotegs) in zip(keys(gs),keys(zygotegs)) + @test maximum(abs.(gs .- zygotegs)) < 2f-5 + end + #@show lz, lz2 + @test abs.(lz - lz2) < 1f-5 + end + end \ No newline at end of file