Skip to content

Commit

Permalink
Make batched mul transpose work on GPU
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonOresten committed Jun 7, 2024
1 parent 2559937 commit 79786c6
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 29 deletions.
41 changes: 29 additions & 12 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ on:
- main
tags: ['*']
pull_request:
workflow_dispatch:
concurrency:
# Skip intermediate builds: always.
# Cancel intermediate builds: only if it is a pull request build.
Expand All @@ -14,46 +15,62 @@ jobs:
test:
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }}
runs-on: ${{ matrix.os }}
timeout-minutes: 60
permissions: # needed to allow julia-actions/cache to proactively delete old caches that it has created
actions: write
contents: read
strategy:
fail-fast: false
matrix:
version:
- '1.9'
- '1.10'
- 'nightly'
os:
- ubuntu-latest
arch:
- x64
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
with:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
- uses: julia-actions/cache@v1
- uses: julia-actions/cache@v2
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
- uses: julia-actions/julia-processcoverage@v1
- uses: codecov/codecov-action@v2
- uses: codecov/codecov-action@v3
with:
files: lcov.info
docs:
name: Documentation
runs-on: ubuntu-latest
permissions:
actions: write # needed to allow julia-actions/cache to proactively delete old caches that it has created
contents: write
statuses: write
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
with:
version: '1'
- uses: julia-actions/cache@v2
- name: Configure doc environment
shell: julia --project=docs --color=yes {0}
run: |
using Pkg
Pkg.develop(PackageSpec(path=pwd()))
Pkg.instantiate()
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-docdeploy@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- run: |
julia --project=docs -e '
using Documenter: DocMeta, doctest
using InvariantPointAttention
DocMeta.setdocmeta!(InvariantPointAttention, :DocTestSetup, :(using InvariantPointAttention); recursive=true)
doctest(InvariantPointAttention)'
DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }}
- name: Run doctests
shell: julia --project=docs --color=yes {0}
run: |
using Documenter: DocMeta, doctest
using InvariantPointAttention
DocMeta.setdocmeta!(InvariantPointAttention, :DocTestSetup, :(using InvariantPointAttention); recursive=true)
doctest(InvariantPointAttention)'
3 changes: 1 addition & 2 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ DocMeta.setdocmeta!(InvariantPointAttention, :DocTestSetup, :(using InvariantPoi
makedocs(;
modules=[InvariantPointAttention],
authors="Ben Murrell <[email protected]> and contributors",
repo="https://github.com/murrellb/InvariantPointAttention.jl/blob/{commit}{path}#{line}",
sitename="InvariantPointAttention.jl",
format=Documenter.HTML(;
prettyurls=get(ENV, "CI", "false") == "true",
Expand All @@ -20,6 +19,6 @@ makedocs(;
)

deploydocs(;
    # Repository moved from the personal account to the MurrellGroup org;
    # keep only the post-move location (the diff scrape had left both).
    repo="github.com/MurrellGroup/InvariantPointAttention.jl",
    devbranch="main",
)
14 changes: 5 additions & 9 deletions src/grads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ end
function ChainRulesCore.rrule(::typeof(T_R3), x::AbstractArray{T,N}, R::AbstractArray{T,N}, t::AbstractArray{T,N}) where {T,N}
function T_R3_pullback(_Δy)
Δy = unthunk(_Δy)
Δx = @thunk(batched_mul(_batched_transpose(R), Δy))
ΔR = @thunk(batched_mul(Δy, _batched_transpose(x)))
Δx = @thunk(batched_mul_T1(R, Δy))
ΔR = @thunk(batched_mul_T2(Δy, x))
Δt = @thunk(sum(Δy, dims=2))
return (NoTangent(), Δx, ΔR, Δt)
end
Expand All @@ -58,15 +58,15 @@ end

# Custom reverse rule for T_R3_inv, the batchwise inverse rigid transform
# Rᵀ(x .- t). The scraped diff had both the pre- and post-commit lines of
# this method; this is the post-commit version, which (a) computes the
# primal via batched_mul_T1 (GPU-compatible batched transpose-multiply)
# and (b) reuses the primal value `y` instead of recomputing T_R3_inv on
# the backward pass.
function ChainRulesCore.rrule(::typeof(T_R3_inv), x::AbstractArray{T,N}, R::AbstractArray{T,N}, t::AbstractArray{T,N}) where {T,N}
    z = x .- t
    y = batched_mul_T1(R, z)
    function T_R3_inv_pullback(_Δy)
        Δy = unthunk(_Δy)
        # d/dx [Rᵀ(x - t)] pulls back through R; d/dR pairs z with Δy.
        Δx = @thunk(batched_mul(R, Δy))
        ΔR = @thunk(batched_mul_T2(z, Δy))
        # t is in the same position as x, but negated and broadcasted.
        # NOTE(review): Δx is itself a Thunk here; confirm that
        # sum(::AbstractThunk, dims=2) unthunks as intended.
        Δt = @thunk(-sum(Δx, dims=2))
        return (NoTangent(), Δx, ΔR, Δt)
    end
    return y, T_R3_inv_pullback
end

#=
Expand Down Expand Up @@ -195,7 +195,3 @@ function pre_softmax_aijh(qh::AbstractArray{T},kh::AbstractArray{T},Ti,qhp::Abst

w_L.*(dim_scale.*qhTkh(qh,kh) .+ bij .- w_C/2 .* gamma_h .* dropdims(diff_sum_glob(Ti,qhp,khp),dims=(1,3)))
end

# Smoke-test helper: prints a fixed marker string (used to identify this
# build/branch at the REPL). Returns `nothing`, like `println`.
test_version() = println("Hello World! gradablateaij")
23 changes: 19 additions & 4 deletions src/rotational_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,26 @@ Generates random translations of given size.
# Draw random translation vectors from the standard normal distribution,
# returned as a `3 × 1 × dims...` array of element type `T`. The keyword
# form defaults the element type to Float32.
function get_translation(T::Type{<:Real}, dims...)
    return randn(T, 3, 1, dims...)
end
get_translation(dims...; T::Type{<:Real}=Float32) = get_translation(T, dims...)

function _batched_transpose(data::A) where {T,N,A<:AbstractArray{T,N}}
perm = (2,1,3:N...)
PermutedDimsArray{T,N,perm,perm,A}(data)

"""
    batched_mul_T1(x, y)

Batched product `transpose(xᵢ) * yᵢ` over the trailing batch dimensions of
`x` and `y`. All batch dimensions are flattened into one so NNlib's 3-array
`batched_mul`/`batched_transpose` apply; per the commit this path is
GPU-compatible, unlike a `PermutedDimsArray`-based transpose.

Throws `DimensionMismatch` if the trailing (batch) dimensions differ.
"""
function batched_mul_T1(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where {T1,T2,N}
    batch_size = size(x)[3:end]
    # Explicit throw rather than @assert: assertions may be compiled out at
    # higher optimization levels and must not guard user input.
    batch_size == size(y)[3:end] ||
        throw(DimensionMismatch("batch size has to be the same for the two arrays."))
    x2 = reshape(x, size(x, 1), size(x, 2), :) |> batched_transpose
    y2 = reshape(y, size(y, 1), size(y, 2), :)
    z = batched_mul(x2, y2)
    # Restore the original batch dimensions on the result.
    return reshape(z, size(z, 1), size(z, 2), batch_size...)
end

"""
    batched_mul_T2(x, y)

Batched product `xᵢ * transpose(yᵢ)` over the trailing batch dimensions of
`x` and `y`. All batch dimensions are flattened into one so NNlib's 3-array
`batched_mul`/`batched_transpose` apply; per the commit this path is
GPU-compatible, unlike a `PermutedDimsArray`-based transpose.

Throws `DimensionMismatch` if the trailing (batch) dimensions differ.
"""
function batched_mul_T2(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where {T1,T2,N}
    batch_size = size(x)[3:end]
    # Explicit throw rather than @assert: assertions may be compiled out at
    # higher optimization levels and must not guard user input.
    batch_size == size(y)[3:end] ||
        throw(DimensionMismatch("batch size has to be the same for the two arrays."))
    x2 = reshape(x, size(x, 1), size(x, 2), :)
    y2 = reshape(y, size(y, 1), size(y, 2), :) |> batched_transpose
    z = batched_mul(x2, y2)
    # Restore the original batch dimensions on the result.
    return reshape(z, size(z, 1), size(z, 2), batch_size...)
end


"""
Applies the SE3 transformations T = (rot,trans) ∈ SE(3)^N
to N batches of m points in R3, i.e., mat ∈ R^(3 x m x N) ↦ T(mat) ∈ R^(3 x m x N).
Expand All @@ -73,7 +88,7 @@ Applies the group inverse of the SE3 transformations T = (R,t) ∈ SE(3)^N to N
such that T^-1(T*x) = T^-1(Rx+t) = R^T(Rx+t-t) = x.
"""
function T_R3_inv(x::AbstractArray{T,N}, R::AbstractArray{T,N}, t::AbstractArray{T,N}) where {T,N}
    # Undo the translation first, then apply Rᵀ batchwise via the
    # GPU-compatible batched_mul_T1 (the diff scrape had left both the old
    # _batched_transpose-based line and this post-commit one).
    return batched_mul_T1(R, x .- t)
end

"""
Expand Down
4 changes: 2 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ using ChainRulesTestUtils
test_rrule(pair_diff, x, y; fkwargs=(; dims=2))
end

#=@testset "ipa_customgrad" begin
@testset "ipa_customgrad" begin
batch_size = 3
framesL = 10
framesR = 10
Expand Down Expand Up @@ -84,7 +84,7 @@ using ChainRulesTestUtils
end
#@show lz, lz2
@test abs.(lz - lz2) < 1f-5
end=#
end

@testset "IPAsoftmax_invariance" begin
batch_size = 3
Expand Down

0 comments on commit 79786c6

Please sign in to comment.