From 79786c629fbe6ce010e4e43bf151d368977d811b Mon Sep 17 00:00:00 2001
From: anton083
Date: Sat, 8 Jun 2024 00:15:06 +0200
Subject: [PATCH] Make batched mul transpose work on GPU

---
 .github/workflows/CI.yml | 41 ++++++++++++++++++++++++++++------------
 docs/make.jl             |  3 +--
 src/grads.jl             | 14 +++++---------
 src/rotational_utils.jl  | 23 ++++++++++++++++++----
 test/runtests.jl         |  4 ++--
 5 files changed, 56 insertions(+), 29 deletions(-)

diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
index ed23f01..35c1088 100644
--- a/.github/workflows/CI.yml
+++ b/.github/workflows/CI.yml
@@ -5,6 +5,7 @@ on:
       - main
     tags: ['*']
   pull_request:
+  workflow_dispatch:
 concurrency:
   # Skip intermediate builds: always.
   # Cancel intermediate builds: only if it is a pull request build.
@@ -14,10 +15,15 @@ jobs:
   test:
     name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }}
     runs-on: ${{ matrix.os }}
+    timeout-minutes: 60
+    permissions: # needed to allow julia-actions/cache to proactively delete old caches that it has created
+      actions: write
+      contents: read
     strategy:
       fail-fast: false
       matrix:
         version:
+          - '1.9'
           - '1.10'
           - 'nightly'
         os:
@@ -25,35 +31,46 @@ jobs:
         arch:
           - x64
     steps:
-      - uses: actions/checkout@v2
-      - uses: julia-actions/setup-julia@v1
+      - uses: actions/checkout@v4
+      - uses: julia-actions/setup-julia@v2
         with:
           version: ${{ matrix.version }}
           arch: ${{ matrix.arch }}
-      - uses: julia-actions/cache@v1
+      - uses: julia-actions/cache@v2
       - uses: julia-actions/julia-buildpkg@v1
       - uses: julia-actions/julia-runtest@v1
       - uses: julia-actions/julia-processcoverage@v1
-      - uses: codecov/codecov-action@v2
+      - uses: codecov/codecov-action@v3
         with:
           files: lcov.info
   docs:
     name: Documentation
     runs-on: ubuntu-latest
     permissions:
+      actions: write # needed to allow julia-actions/cache to proactively delete old caches that it has created
       contents: write
+      statuses: write
     steps:
-      - uses: actions/checkout@v2
-      - uses: julia-actions/setup-julia@v1
+      - uses: actions/checkout@v4
+      - uses: julia-actions/setup-julia@v2
         with:
           version: '1'
+      - uses: julia-actions/cache@v2
+      - name: Configure doc environment
+        shell: julia --project=docs --color=yes {0}
+        run: |
+          using Pkg
+          Pkg.develop(PackageSpec(path=pwd()))
+          Pkg.instantiate()
       - uses: julia-actions/julia-buildpkg@v1
       - uses: julia-actions/julia-docdeploy@v1
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-      - run: |
-          julia --project=docs -e '
-            using Documenter: DocMeta, doctest
-            using InvariantPointAttention
-            DocMeta.setdocmeta!(InvariantPointAttention, :DocTestSetup, :(using InvariantPointAttention); recursive=true)
-            doctest(InvariantPointAttention)'
+          DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }}
+      - name: Run doctests
+        shell: julia --project=docs --color=yes {0}
+        run: |
+          using Documenter: DocMeta, doctest
+          using InvariantPointAttention
+          DocMeta.setdocmeta!(InvariantPointAttention, :DocTestSetup, :(using InvariantPointAttention); recursive=true)
+          doctest(InvariantPointAttention)
\ No newline at end of file
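Both new documentation steps use GitHub Actions' custom-shell form: with "shell: julia --project=docs --color=yes {0}", the runner writes the run: block to a temporary script and substitutes its path for {0}, so the block is executed directly by Julia rather than by bash. A rough local equivalent of the "Configure doc environment" step, run from the repository root (a sketch for orientation, not part of the patch):

    using Pkg
    Pkg.activate("docs")                  # stands in for --project=docs
    Pkg.develop(PackageSpec(path=pwd()))  # dev this package into docs/Project.toml
    Pkg.instantiate()                     # install the doc build dependencies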
diff --git a/docs/make.jl b/docs/make.jl
index b55da19..b15f92f 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -6,7 +6,6 @@ DocMeta.setdocmeta!(InvariantPointAttention, :DocTestSetup, :(using InvariantPoi
 makedocs(;
     modules=[InvariantPointAttention],
     authors="Ben Murrell and contributors",
-    repo="https://github.com/murrellb/InvariantPointAttention.jl/blob/{commit}{path}#{line}",
     sitename="InvariantPointAttention.jl",
     format=Documenter.HTML(;
         prettyurls=get(ENV, "CI", "false") == "true",
@@ -20,6 +19,6 @@ makedocs(;
 )
 
 deploydocs(;
-    repo="github.com/murrellb/InvariantPointAttention.jl",
+    repo="github.com/MurrellGroup/InvariantPointAttention.jl",
     devbranch="main",
 )
diff --git a/src/grads.jl b/src/grads.jl
index ead83ac..fdaf29e 100644
--- a/src/grads.jl
+++ b/src/grads.jl
@@ -48,8 +48,8 @@ end
 function ChainRulesCore.rrule(::typeof(T_R3), x::AbstractArray{T,N}, R::AbstractArray{T,N}, t::AbstractArray{T,N}) where {T,N}
     function T_R3_pullback(_Δy)
         Δy = unthunk(_Δy)
-        Δx = @thunk(batched_mul(_batched_transpose(R), Δy))
-        ΔR = @thunk(batched_mul(Δy, _batched_transpose(x)))
+        Δx = @thunk(batched_mul_T1(R, Δy))
+        ΔR = @thunk(batched_mul_T2(Δy, x))
         Δt = @thunk(sum(Δy, dims=2))
         return (NoTangent(), Δx, ΔR, Δt)
     end
@@ -58,15 +58,15 @@ end
 function ChainRulesCore.rrule(::typeof(T_R3_inv), x::AbstractArray{T,N}, R::AbstractArray{T,N}, t::AbstractArray{T,N}) where {T,N}
     z = x .- t
-    y = batched_mul(_batched_transpose(R), z)
+    y = batched_mul_T1(R, z)
     function T_R3_inv_pullback(_Δy)
         Δy = unthunk(_Δy)
         Δx = @thunk(batched_mul(R, Δy))
-        ΔR = @thunk(batched_mul(z, _batched_transpose(Δy)))
+        ΔR = @thunk(batched_mul_T2(z, Δy))
         Δt = @thunk(-sum(Δx, dims=2)) # t is in the same position as x, but negated and broadcasted
         return (NoTangent(), Δx, ΔR, Δt)
     end
-    return T_R3_inv(x, R, t), T_R3_inv_pullback
+    return y, T_R3_inv_pullback
 end
 
 #=
@@ -195,7 +195,3 @@ function pre_softmax_aijh(qh::AbstractArray{T},kh::AbstractArray{T},Ti,qhp::Abst
 
     w_L.*(dim_scale.*qhTkh(qh,kh) .+ bij .- w_C/2 .* gamma_h .* dropdims(diff_sum_glob(Ti,qhp,khp),dims=(1,3)))
 end
-
-function test_version()
-    println("Hello World! gradablateaij")
-end
\ No newline at end of file
diff --git a/src/rotational_utils.jl b/src/rotational_utils.jl
index 85325cb..9cd72c5 100644
--- a/src/rotational_utils.jl
+++ b/src/rotational_utils.jl
@@ -54,11 +54,26 @@ Generates random translations of given size.
 get_translation(T::Type{<:Real}, dims...) = randn(T, 3, 1, dims...)
 get_translation(dims...; T::Type{<:Real}=Float32) = get_translation(T, dims...)
 
-function _batched_transpose(data::A) where {T,N,A<:AbstractArray{T,N}}
-    perm = (2,1,3:N...)
-    PermutedDimsArray{T,N,perm,perm,A}(data)
+
+function batched_mul_T1(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where {T1,T2,N}
+    batch_size = size(x)[3:end]
+    @assert batch_size == size(y)[3:end] "batch size has to be the same for the two arrays."
+    x2 = reshape(x, size(x, 1), size(x, 2), :) |> batched_transpose
+    y2 = reshape(y, size(y, 1), size(y, 2), :)
+    z = batched_mul(x2, y2)
+    return reshape(z, size(z, 1), size(z, 2), batch_size...)
+end
+
+function batched_mul_T2(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where {T1,T2,N}
+    batch_size = size(x)[3:end]
+    @assert batch_size == size(y)[3:end] "batch size has to be the same for the two arrays."
+    x2 = reshape(x, size(x, 1), size(x, 2), :)
+    y2 = reshape(y, size(y, 1), size(y, 2), :) |> batched_transpose
+    z = batched_mul(x2, y2)
+    return reshape(z, size(z, 1), size(z, 2), batch_size...)
 end
+
 
 """
 Applies the SE3 transformations T = (rot,trans) ∈ SE(3)^N to N batches of m points in R3,
 i.e., mat ∈ R^(3 x m x N) ↦ T(mat) ∈ R^(3 x m x N).
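The hunk above is the heart of the patch: the old _batched_transpose returned a PermutedDimsArray, which batched_mul could not route to fast batched GPU kernels (the failure this PR fixes), whereas NNlib's batched_transpose wrapper is recognized by batched_mul and dispatches to the strided-batch BLAS path. Both helpers collapse any trailing batch dimensions into one, multiply with the first (T1) or second (T2) operand transposed slice-wise, and restore the batch shape. A CPU sanity check of those semantics (a sketch, not part of the patch; the helpers are not exported, hence the qualified import):

    using InvariantPointAttention: batched_mul_T1, batched_mul_T2

    x = randn(Float32, 3, 5, 4, 2)   # two trailing batch dimensions
    y = randn(Float32, 3, 6, 4, 2)
    w = randn(Float32, 6, 5, 4, 2)
    zT1 = batched_mul_T1(x, y)       # transpose(x) * y slice-wise; size (5, 6, 4, 2)
    zT2 = batched_mul_T2(x, w)       # x * transpose(w) slice-wise; size (3, 6, 4, 2)
    for i in 1:4, j in 1:2
        @assert zT1[:, :, i, j] ≈ transpose(x[:, :, i, j]) * y[:, :, i, j]
        @assert zT2[:, :, i, j] ≈ x[:, :, i, j] * transpose(w[:, :, i, j])
    end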
""" function T_R3_inv(x::AbstractArray{T,N}, R::AbstractArray{T,N}, t::AbstractArray{T,N}) where {T,N} - return batched_mul(_batched_transpose(R), x .- t) + return batched_mul_T1(R, x .- t) end """ diff --git a/test/runtests.jl b/test/runtests.jl index 5a1e1a7..e07da30 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -53,7 +53,7 @@ using ChainRulesTestUtils test_rrule(pair_diff, x, y; fkwargs=(; dims=2)) end - #=@testset "ipa_customgrad" begin + @testset "ipa_customgrad" begin batch_size = 3 framesL = 10 framesR = 10 @@ -84,7 +84,7 @@ using ChainRulesTestUtils end #@show lz, lz2 @test abs.(lz - lz2) < 1f-5 - end=# + end @testset "IPAsoftmax_invariance" begin batch_size = 3