AMDGPU support (#28)
* add AMDGPU support

* remove patch noises

* use gpuarrays ext

* rework gemm interface

* update test for switching gpu backend

---------

Co-authored-by: Radu Diaconu <[email protected]>
chengchingwen and radudiaconu0 authored May 1, 2024
1 parent 006a5be commit e7f47be
Showing 11 changed files with 267 additions and 422 deletions.
17 changes: 13 additions & 4 deletions Project.toml
@@ -6,31 +6,40 @@ version = "0.2.13"
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"

[weakdeps]
+AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
-FiniteDifferences="26cc04aa-876d-5657-8c51-4c34ba976000"
+FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
+GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
-NeuralAttentionlibCUDAExt = "CUDA"
-NeuralAttentionlibFiniteDifferences = "FiniteDifferences"
+NeuralAttentionlibAMDGPUExt = ["AMDGPU", "GPUArrays"]
+NeuralAttentionlibCUDAExt = ["CUDA", "GPUArrays"]
+NeuralAttentionlibFiniteDifferencesExt = "FiniteDifferences"
+NeuralAttentionlibGPUArraysExt = "GPUArrays"
NeuralAttentionlibZygoteExt = "Zygote"

[compat]
+AMDGPU = "0.8"
Adapt = "3.3, 4"
CUDA = "5"
ChainRulesCore = "1.3"
FiniteDifferences = "0.12"
+GPUArrays = "10"
+GPUArraysCore = "0.1"
NNlib = "0.9"
Static = "0.7, 0.8"
Zygote = "0.6"
julia = "1.10"

[extras]
+AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
@@ -43,4 +52,4 @@ ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[targets]
-test = ["Test", "Flux", "MacroTools", "ZygoteRules", "CUDA", "ChainRulesTestUtils", "Random", "Pickle", "ZipFile", "Statistics"]
+test = ["Test", "Flux", "MacroTools", "ZygoteRules", "CUDA", "AMDGPU", "ChainRulesTestUtils", "Random", "Pickle", "ZipFile", "Statistics"]
56 changes: 56 additions & 0 deletions ext/NeuralAttentionlibAMDGPUExt/NeuralAttentionlibAMDGPUExt.jl
@@ -0,0 +1,56 @@
module NeuralAttentionlibAMDGPUExt

using NeuralAttentionlib
using NeuralAttentionlib.Adapt
using NeuralAttentionlib: TypedPtr, AbstractArrayMask, Indexer, GetIndexer
using AMDGPU

import LinearAlgebra
import LinearAlgebra.BLAS
using LinearAlgebra.BLAS: get_num_threads, set_num_threads

const NAlib = NeuralAttentionlib

import AMDGPU.rocBLAS
for (fname, elty) in
((:rocblas_dgemm_strided_batched, :Float64),
(:rocblas_sgemm_strided_batched, :Float32),
(:rocblas_hgemm_strided_batched, :Float16),
(:rocblas_zgemm_strided_batched, :ComplexF64),
(:rocblas_cgemm_strided_batched, :ComplexF32))
@eval begin
@inline function NeuralAttentionlib.unsafe_gemm_strided_batched!(
transA::Char, transB::Char,
m::Integer, n::Integer, k::Integer,
alpha::($elty), tptrA::TypedPtr{ROCArray, $elty}, lda::Integer, strideA::Integer,
tptrB::TypedPtr{ROCArray, $elty}, ldb::Integer, strideB::Integer, beta::($elty),
tptrC::TypedPtr{ROCArray, $elty}, ldc::Integer, strideC::Integer, batchCount::Integer)

ptrA = tptrA.ptr
ptrB = tptrB.ptr
ptrC = tptrC.ptr
rocBLAS.$fname(rocBLAS.handle(),
transA, transB, m, n, k,
alpha, ptrA, lda, strideA,
ptrB, ldb, strideB, beta,
ptrC, ldc, strideC, batchCount)
return nothing
end
end
end

NeuralAttentionlib.ptrtypetag(::AMDGPU.ROCArrayBackend) = ROCArray
NeuralAttentionlib.check_strided_gemm_type(A::ROCArray{Float16}) = true

Adapt.adapt(to::AMDGPU.Runtime.Adaptor, m::AbstractArrayMask) =
Indexer{typeof(m)}(map(Base.Fix1(Adapt.adapt, to), GetIndexer(m).__fields))
Adapt.adapt(to::AMDGPU.Runtime.Adaptor, m::NAlib.FlipMask) = Indexer{typeof(m)}((mask = adapt(to, m.mask),))
Adapt.adapt(to::AMDGPU.Runtime.Adaptor, m::NAlib.CombinedMask) =
Indexer{typeof(m)}((f = adapt(to, m.f), masks = map(Base.Fix1(adapt, to), m.masks)))
Adapt.adapt(to::AMDGPU.Runtime.Adaptor, m::NAlib.BatchedMask) =
Indexer{typeof(m)}((mask = adapt(to, m.mask), batch_dim = static(m.batch_dim)))
Adapt.adapt(to::AMDGPU.Runtime.Adaptor, m::NAlib.RepeatMask) = Indexer{typeof(m)}((mask = adapt(to, m.mask), num = m.num))
Adapt.adapt(to::AMDGPU.Runtime.Adaptor, m::NAlib.BiSequenceMask) =
Indexer{typeof(m)}((q_mask = adapt(to, m.q_mask), k_mask = adapt(to, m.k_mask)))

end
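The @eval loop above generates one `unsafe_gemm_strided_batched!` method per element type, each dispatching on `TypedPtr{ROCArray, T}` and forwarding directly to the matching rocBLAS strided-batched GEMM, while the `Adapt.adapt` methods let the mask types move onto the device when captured in ROCm kernels. A usage sketch, under the assumption that the reworked generic front end in the main package keeps the `gemm_strided_batched(transA, transB, A, B)` signature visible in the CUDA diff below:

using AMDGPU, NeuralAttentionlib

A = AMDGPU.rand(Float32, 4, 8, 16)   # m × k × batch, a ROCArray{Float32, 3}
B = AMDGPU.rand(Float32, 8, 2, 16)   # k × n × batch
# 'N', 'N' means no transposes; this should bottom out in rocblas_sgemm_strided_batched
# via the TypedPtr{ROCArray, Float32} method generated above.
C = NeuralAttentionlib.gemm_strided_batched('N', 'N', A, B)  # 4 × 2 × 16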
227 changes: 9 additions & 218 deletions ext/NeuralAttentionlibCUDAExt/NeuralAttentionlibCUDAExt.jl
@@ -2,29 +2,15 @@ module NeuralAttentionlibCUDAExt

using NeuralAttentionlib
using NeuralAttentionlib.Adapt
-using NeuralAttentionlib: AbstractArrayMask, Indexer, GetIndexer
+using NeuralAttentionlib: TypedPtr, AbstractArrayMask, Indexer, GetIndexer
using CUDA
-using CUDA.GPUArrays
-using CUDA.GPUArrays.GPUArraysCore

import LinearAlgebra
import LinearAlgebra.BLAS
using LinearAlgebra.BLAS: get_num_threads, set_num_threads

const NAlib = NeuralAttentionlib

-GPUArraysCore.backend(T::Type{<:NAlib.CollapsedDimsArray{E, <:CuArray}}) where E = GPUArraysCore.backend(CuArray{E, 3})
-
-function NeuralAttentionlib.batched_transpose_f!(f, B::AnyGPUArray{T, 3}, A::AnyGPUArray{T, 3}) where T
-    axes(B,1) == axes(A,2) && axes(B,2) == axes(A,1) && axes(A,3) == axes(B,3) || throw(DimensionMismatch(string(f)))
-    GPUArrays.gpu_call(B, A) do ctx, B, A
-        idx = GPUArrays.@cartesianidx A
-        @inbounds B[idx[2], idx[1], idx[3]] = f(A[idx[1], idx[2], idx[3]])
-        return
-    end
-    return B
-end

import CUDA.CUBLAS
for (fname, elty) in
((:cublasDgemmStridedBatched,:Float64),
     (:cublasSgemmStridedBatched,:Float32),
     (:cublasHgemmStridedBatched,:Float16),
@@ -33,222 +19,27 @@ for (fname, elty) in
(:cublasZgemmStridedBatched,:ComplexF64),
(:cublasCgemmStridedBatched,:ComplexF32))
@eval begin

@inline function NeuralAttentionlib.unsafe_gemm_strided_batched!(
transA::Char, transB::Char,
-            m::Int, n::Int, k::Int,
-            alpha::($elty), ptrA::CuPtr{$elty}, lda::Int, strideA::Int,
-            ptrB::CuPtr{$elty}, ldb::Int, strideB::Int, beta::($elty),
-            ptrC::CuPtr{$elty}, ldc::Int, strideC::Int, batchCount::Int)
+            m::Integer, n::Integer, k::Integer,
+            alpha::($elty), tptrA::TypedPtr{CuArray, $elty}, lda::Integer, strideA::Integer,
+            tptrB::TypedPtr{CuArray, $elty}, ldb::Integer, strideB::Integer, beta::($elty),
+            tptrC::TypedPtr{CuArray, $elty}, ldc::Integer, strideC::Integer, batchCount::Integer)

+            ptrA = tptrA.ptr
+            ptrB = tptrB.ptr
+            ptrC = tptrC.ptr
CUBLAS.$fname(CUBLAS.handle(),
transA, transB, m, n, k,
alpha, ptrA, lda, strideA,
ptrB, ldb, strideB, beta,
ptrC, ldc, strideC, batchCount)
return nothing
end

end
end

for (elty, array) in (
(:Float16, :CuArray),
)
@eval begin
@inline function NeuralAttentionlib.unsafe_gemm_strided_batched!(
transA::Char, transB::Char,
m::Int, n::Int, k::Int,
alpha::($elty), ptrA::Ptr{$elty}, lda::Int, strideA::Int,
ptrB::Ptr{$elty}, ldb::Int, strideB::Int, beta::($elty),
ptrC::Ptr{$elty}, ldc::Int, strideC::Int, batchCount::Int)

# https://github.com/FluxML/NNlib.jl/blob/cd3851d31e95020e77e67f80fb6402b5b87db1e6/src/gemm.jl#L91-L139
n_threads = min(Threads.nthreads(), 1 + max(m * k * batchCount, n * k * batchCount) ÷ 8000)
if n_threads > 1
old_threads = get_num_threads()
set_num_threads(1)
Threads.@sync for bs in Iterators.partition(1:batchCount, cld(batchCount, n_threads))
Threads.@spawn for b in bs
ptrAi = ptrA + (b - 1) * strideA * sizeof($elty)
ptrBi = ptrB + (b - 1) * strideB * sizeof($elty)
ptrCi = ptrC + (b - 1) * strideC * sizeof($elty)

NeuralAttentionlib.unsafe_gemm!(transA, transB, m, n, k,
alpha, ptrAi, lda,
ptrBi, ldb, beta,
ptrCi, ldc)
end
end
set_num_threads(old_threads)
else
for i = 1:batchCount
ptrAi = ptrA + (i - 1) * strideA * sizeof($elty)
ptrBi = ptrB + (i - 1) * strideB * sizeof($elty)
ptrCi = ptrC + (i - 1) * strideC * sizeof($elty)
NeuralAttentionlib.unsafe_gemm!(transA, transB, m, n, k,
alpha, ptrAi, lda,
ptrBi, ldb, beta,
ptrCi, ldc)
end
end
return nothing
end

@inline function NeuralAttentionlib.gemm_strided_batched_impl!(
transA::Char, transB::Char,
m::Int, n::Int, k::Int,
alpha::($elty), A::$array{$elty}, lda::Int, strideA::Int,
B::$array{$elty}, ldb::Int, strideB::Int, beta::($elty),
C::$array{$elty}, ldc::Int, strideC::Int, batchCount::Int)

ptrA = pointer(A)
ptrB = pointer(B)
ptrC = pointer(C)
GC.@preserve A B C begin
NeuralAttentionlib.unsafe_gemm_strided_batched!(
transA, transB, m, n, k,
alpha, ptrA, lda, strideA,
ptrB, ldb, strideB, beta,
ptrC, ldc, strideC, batchCount)
end
return C
end

@inline function NeuralAttentionlib.gemm_strided_batched!(
transA::Char, transB::Char,
alpha::($elty), A::$array{$elty, 3},
B::$array{$elty, 3}, beta::($elty),
C::$array{$elty, 3})

Base.require_one_based_indexing(A, B, C)
BLAS.chkstride1(A, B, C)
@assert size(A, 3) == size(C, 3) || size(A, 3) == 1 "batch size mismatch: A != C"
@assert size(B, 3) == size(C, 3) || size(B, 3) == 1 "batch size mismatch: B != C"

m = size(A, transA == 'N' ? 1 : 2)
ka = size(A, transA == 'N' ? 2 : 1)
kb = size(B, transB == 'N' ? 1 : 2)
n = size(B, transB == 'N' ? 2 : 1)

if m != size(C,1) || n != size(C,2) || ka != kb
throw(DimensionMismatch("A has size ($m,$ka,$(size(A, 3))), B has size ($kb,$n,$(size(B, 3))), C has size $(size(C))"))
end

lda = max(1, stride(A,2))
ldb = max(1, stride(B,2))
ldc = max(1, stride(C,2))

strideA = size(A, 3) == 1 ? 0 : stride(A, 3)
strideB = size(B, 3) == 1 ? 0 : stride(B, 3)
strideC = stride(C, 3)
batchCount = size(C, 3)

NeuralAttentionlib.gemm_strided_batched_impl!(
transA, transB, m, n, ka,
alpha, A, lda, strideA,
B, ldb, strideB, beta,
C, ldc, strideC, batchCount)

return C
end

function NeuralAttentionlib.gemm_strided_batched(
transA::Char, transB::Char,
alpha::($elty), A::$array{$elty, 3},
B::$array{$elty, 3})
C = similar(B, (size(A, transA == 'N' ? 1 : 2), size(B, transB == 'N' ? 2 : 1), max(size(A, 3), size(B, 3))))
return NeuralAttentionlib.gemm_strided_batched!(transA, transB, alpha, A, B, zero($elty), C)
end

function NeuralAttentionlib.gemm_strided_batched(
transA::Char, transB::Char,
A::$array{$elty, 3},
B::$array{$elty, 3})
return NeuralAttentionlib.gemm_strided_batched(transA, transB, one($elty), A, B)
end

function NeuralAttentionlib.gemm_strided_batched!(
transA::Char, transB::Char,
alpha::($elty), A::$array{$elty, N1},
B::$array{$elty, N2}, beta::($elty),
C::$array{$elty, N3},
Ai, Aj, Bi, Bj, Ci, Cj) where {N1, N2, N3}

# (a1, a2, ..., ai-1, ai, ai+1, ..., aj-1, aj, ..., an)
# |______lda______| |____K/M (Ai)_____| |___Aj____|
# (b1, b2, ..., bi-1, bi, bi+1, ..., bj-1, bj, ..., bn)
# |______ldb______| |____K/N (Bi)_____| |___Bj____|
# (c1, c2, ..., ci-1, ci, ci+1, ..., cj-1, cj, ..., cn)
# |______ldc______| |____K/N (Ci)_____| |___Cj____|

Base.require_one_based_indexing(A, B, C)
BLAS.chkstride1(A, B, C)

sa1, sa2, sa3 = NeuralAttentionlib.collapsed_size(A, Ai, Aj)
sb1, sb2, sb3 = NeuralAttentionlib.collapsed_size(B, Bi, Bj)
sc1, sc2, sc3 = NeuralAttentionlib.collapsed_size(C, Ci, Cj)

@assert sa3 == sc3 || sa3 == 1 "batch size mismatch: A != C"
@assert sb3 == sc3 || sb3 == 1 "batch size mismatch: B != C"

m = transA == 'N' ? sa1 : sa2
ka = transA == 'N' ? sa2 : sa1
kb = transB == 'N' ? sb1 : sb2
n = transB == 'N' ? sb2 : sb1

if m != sc1 || n != sc2 || ka != kb
throw(DimensionMismatch("A has size ($m,$ka,$sa3), B has size ($kb,$n,$sb3), C has size ($sc1, $sc2, $sc3)"))
end

lda = max(1, stride(A, N1 - Ai - Aj + 1))
ldb = max(1, stride(B, N2 - Bi - Bj + 1))
ldc = max(1, stride(C, N3 - Ci - Cj + 1))

strideA = sa3 == 1 ? 0 : stride(A, N1 - Aj + 1)
strideB = sb3 == 1 ? 0 : stride(B, N2 - Bj + 1)
strideC = stride(C, N3 - Cj + 1)
batchCount = sc3

NeuralAttentionlib.gemm_strided_batched_impl!(
transA, transB, m, n, ka,
alpha, A, lda, strideA,
B, ldb, strideB, beta,
C, ldc, strideC, batchCount)

return C
end

function NeuralAttentionlib.gemm_strided_batched(
transA::Char, transB::Char,
alpha::($elty), A::$array{$elty, N1},
B::$array{$elty, N2},
Ai, Aj, Bi, Bj) where {N1, N2}

m = NeuralAttentionlib.noncollapsed_size(A, Ai, Aj, transA == 'N' ? 1 : 2)
n = NeuralAttentionlib.noncollapsed_size(B, Bi, Bj, transB == 'N' ? 2 : 1)
sc3 = NeuralAttentionlib.collapsed_size(A, Ai, Aj, 3) > NeuralAttentionlib.collapsed_size(B, Bi, Bj, 3) ?
NeuralAttentionlib.noncollapsed_size(A, Ai, Aj, 3) :
NeuralAttentionlib.noncollapsed_size(B, Bi, Bj, 3)

Ci = length(n)
Cj = length(sc3)
C = similar(B, (m..., n..., sc3...))
return NeuralAttentionlib.gemm_strided_batched!(
transA, transB, alpha, A, B, zero($elty), C, Ai, Aj, Bi, Bj, Ci, Cj)
end

function NeuralAttentionlib.gemm_strided_batched(
transA::Char, transB::Char,
A::$array{$elty, N1},
B::$array{$elty, N2},
Ai, Aj, Bi, Bj) where {N1, N2}
return NeuralAttentionlib.gemm_strided_batched(transA, transB, one($elty), A, B, Ai, Aj, Bi, Bj)
end

end
end

+NeuralAttentionlib.ptrtypetag(::CUDA.CuArrayBackend) = CuArray
+NeuralAttentionlib.check_strided_gemm_type(A::CuArray{Float16}) = true

Adapt.adapt(to::CUDA.KernelAdaptor, m::AbstractArrayMask) =
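Most of what this diff deletes (everything from `for (elty, array) in ((:Float16, :CuArray),)` through the generic `gemm_strided_batched` front ends above) moves out of the CUDA extension into backend-agnostic code per the "rework gemm interface" change, so CUDA and AMDGPU share one path; the extension keeps only the cuBLAS bindings and the `KernelAdaptor` mask adapters (truncated above), which mirror the ROCm adapters in the AMDGPU extension. The `TypedPtr` wrapper is the glue that keeps dispatch backend-aware after arrays have been lowered to pointers and strides. A rough sketch of the pattern; the struct definition here is hypothetical (only the `ptr` field and the two type parameters are visible in this commit), the real one lives in the main package:

# Hypothetical, simplified sketch of the TypedPtr idea (not the package's
# actual definition): a raw pointer tagged with the owning array type.
struct TypedPtr{Tag, T}
    ptr   # Ptr{T}, CuPtr{T}, or AMDGPU's device pointer type
end

# Each backend extension adds methods that dispatch on the tag:
#   unsafe_gemm_strided_batched!(..., ::TypedPtr{CuArray, Float32}, ...)  -> CUBLAS
#   unsafe_gemm_strided_batched!(..., ::TypedPtr{ROCArray, Float32}, ...) -> rocBLAS
# while generic code only needs ptrtypetag(backend) to build the right tag.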
(ext file for NeuralAttentionlibFiniteDifferencesExt; the module is renamed to match the new extension name)
@@ -1,4 +1,4 @@
-module NeuralAttentionlibFiniteDifferences
+module NeuralAttentionlibFiniteDifferencesExt

using NeuralAttentionlib
using NeuralAttentionlib: CollapsedDimsArray, collapseddims
(new ext file for NeuralAttentionlibGPUArraysExt)
@@ -0,0 +1,22 @@
module NeuralAttentionlibGPUArraysExt

using NeuralAttentionlib
using NeuralAttentionlib: CollapsedDimsArray
using GPUArrays
using GPUArrays.GPUArraysCore

GPUArraysCore.backend(::Type{<:CollapsedDimsArray{E, A}}) where {E, A} = GPUArraysCore.backend(A)

NeuralAttentionlib.ptrtypetag(arr::AnyGPUArray) = NeuralAttentionlib.ptrtypetag(GPUArraysCore.backend(arr))

function NeuralAttentionlib.batched_transpose_f!(f, B::AnyGPUArray{T, 3}, A::AnyGPUArray{T, 3}) where T
axes(B,1) == axes(A,2) && axes(B,2) == axes(A,1) && axes(A,3) == axes(B,3) || throw(DimensionMismatch(string(f)))
GPUArrays.gpu_call(B, A) do ctx, B, A
idx = GPUArrays.@cartesianidx A
@inbounds B[idx[2], idx[1], idx[3]] = f(A[idx[1], idx[2], idx[3]])
return
end
return B
end

end
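This extension holds the pieces that are generic over any GPUArrays backend: the backend lookup for CollapsedDimsArray, the ptrtypetag hook that the CUDA and AMDGPU extensions specialize, and a gpu_call kernel that applies `f` while materializing a batched transpose. With `f = identity` it is a plain batched transpose; a small sketch (any GPU backend should work, AMDGPU is used here as an assumption):

using AMDGPU, NeuralAttentionlib

A = AMDGPU.rand(Float32, 3, 5, 7)
B = similar(A, (5, 3, 7))                                # first two dims swapped
NeuralAttentionlib.batched_transpose_f!(identity, B, A)  # B[j, i, b] == A[i, j, b]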
(remaining changed files not shown)

2 comments on commit e7f47be

@chengchingwen (Owner, Author):

@JuliaRegistrator register()

@JuliaRegistrator:

Registration pull request created: JuliaRegistries/General/105935

Tip: Release Notes

Did you know you can add release notes too? Just add markdown-formatted text underneath the comment after the text "Release notes:" and it will be added to the registry PR; if TagBot is installed, it will also be added to the release that TagBot creates. For example:

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the GitHub interface, or via:

git tag -a v0.2.13 -m "<description of version>" e7f47bef31b5b46aee18126fcbf87351e90c1520
git push origin v0.2.13
