diff --git a/Project.toml b/Project.toml index 019788a..d90d237 100644 --- a/Project.toml +++ b/Project.toml @@ -6,31 +6,40 @@ version = "0.2.13" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" [weakdeps] +AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" -FiniteDifferences="26cc04aa-876d-5657-8c51-4c34ba976000" +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" +GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] -NeuralAttentionlibCUDAExt = "CUDA" -NeuralAttentionlibFiniteDifferences = "FiniteDifferences" +NeuralAttentionlibAMDGPUExt = ["AMDGPU", "GPUArrays"] +NeuralAttentionlibCUDAExt = ["CUDA", "GPUArrays"] +NeuralAttentionlibFiniteDifferencesExt = "FiniteDifferences" +NeuralAttentionlibGPUArraysExt = "GPUArrays" NeuralAttentionlibZygoteExt = "Zygote" [compat] +AMDGPU = "0.8" Adapt = "3.3, 4" CUDA = "5" ChainRulesCore = "1.3" FiniteDifferences = "0.12" +GPUArrays = "10" +GPUArraysCore = "0.1" NNlib = "0.9" Static = "0.7, 0.8" Zygote = "0.6" julia = "1.10" [extras] +AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" @@ -43,4 +52,4 @@ ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [targets] -test = ["Test", "Flux", "MacroTools", "ZygoteRules", "CUDA", "ChainRulesTestUtils", "Random", "Pickle", "ZipFile", "Statistics"] +test = ["Test", "Flux", "MacroTools", "ZygoteRules", "CUDA", "AMDGPU", "ChainRulesTestUtils", "Random", "Pickle", "ZipFile", "Statistics"] diff --git a/ext/NeuralAttentionlibAMDGPUExt/NeuralAttentionlibAMDGPUExt.jl b/ext/NeuralAttentionlibAMDGPUExt/NeuralAttentionlibAMDGPUExt.jl new file mode 100644 index 0000000..feddc65 --- /dev/null +++ b/ext/NeuralAttentionlibAMDGPUExt/NeuralAttentionlibAMDGPUExt.jl @@ -0,0 +1,56 @@ +module NeuralAttentionlibAMDGPUExt + +using NeuralAttentionlib +using NeuralAttentionlib.Adapt +using NeuralAttentionlib: TypedPtr, AbstractArrayMask, Indexer, GetIndexer +using AMDGPU + +import LinearAlgebra +import LinearAlgebra.BLAS +using LinearAlgebra.BLAS: get_num_threads, set_num_threads + +const NAlib = NeuralAttentionlib + +import AMDGPU.rocBLAS +for (fname, elty) in + ((:rocblas_dgemm_strided_batched, :Float64), + (:rocblas_sgemm_strided_batched, :Float32), + (:rocblas_hgemm_strided_batched, :Float16), + (:rocblas_zgemm_strided_batched, :ComplexF64), + (:rocblas_cgemm_strided_batched, :ComplexF32)) + @eval begin + @inline function NeuralAttentionlib.unsafe_gemm_strided_batched!( + transA::Char, transB::Char, + m::Integer, n::Integer, k::Integer, + alpha::($elty), tptrA::TypedPtr{ROCArray, $elty}, lda::Integer, strideA::Integer, + tptrB::TypedPtr{ROCArray, $elty}, ldb::Integer, strideB::Integer, beta::($elty), + tptrC::TypedPtr{ROCArray, $elty}, ldc::Integer, strideC::Integer, batchCount::Integer) + + ptrA = tptrA.ptr + ptrB = tptrB.ptr + ptrC = tptrC.ptr + rocBLAS.$fname(rocBLAS.handle(), + transA, transB, m, n, k, + alpha, ptrA, lda, strideA, + ptrB, ldb, strideB, beta, + ptrC, ldc, strideC, batchCount) + return nothing + end + end +end + +NeuralAttentionlib.ptrtypetag(::AMDGPU.ROCArrayBackend) = ROCArray +NeuralAttentionlib.check_strided_gemm_type(A::ROCArray{Float16}) = true + +Adapt.adapt(to::AMDGPU.Runtime.Adaptor, m::AbstractArrayMask) = + Indexer{typeof(m)}(map(Base.Fix1(Adapt.adapt, to), GetIndexer(m).__fields)) +Adapt.adapt(to::AMDGPU.Runtime.Adaptor, m::NAlib.FlipMask) = Indexer{typeof(m)}((mask = adapt(to, m.mask),)) +Adapt.adapt(to::AMDGPU.Runtime.Adaptor, m::NAlib.CombinedMask) = + Indexer{typeof(m)}((f = adapt(to, m.f), masks = map(Base.Fix1(adapt, to), m.masks))) +Adapt.adapt(to::AMDGPU.Runtime.Adaptor, m::NAlib.BatchedMask) = + Indexer{typeof(m)}((mask = adapt(to, m.mask), batch_dim = static(m.batch_dim))) +Adapt.adapt(to::AMDGPU.Runtime.Adaptor, m::NAlib.RepeatMask) = Indexer{typeof(m)}((mask = adapt(to, m.mask), num = m.num)) +Adapt.adapt(to::AMDGPU.Runtime.Adaptor, m::NAlib.BiSequenceMask) = + Indexer{typeof(m)}((q_mask = adapt(to, m.q_mask), k_mask = adapt(to, m.k_mask))) + +end diff --git a/ext/NeuralAttentionlibCUDAExt/NeuralAttentionlibCUDAExt.jl b/ext/NeuralAttentionlibCUDAExt/NeuralAttentionlibCUDAExt.jl index 22b01b7..9975c78 100644 --- a/ext/NeuralAttentionlibCUDAExt/NeuralAttentionlibCUDAExt.jl +++ b/ext/NeuralAttentionlibCUDAExt/NeuralAttentionlibCUDAExt.jl @@ -2,10 +2,8 @@ module NeuralAttentionlibCUDAExt using NeuralAttentionlib using NeuralAttentionlib.Adapt -using NeuralAttentionlib: AbstractArrayMask, Indexer, GetIndexer +using NeuralAttentionlib: TypedPtr, AbstractArrayMask, Indexer, GetIndexer using CUDA -using CUDA.GPUArrays -using CUDA.GPUArrays.GPUArraysCore import LinearAlgebra import LinearAlgebra.BLAS @@ -13,18 +11,6 @@ using LinearAlgebra.BLAS: get_num_threads, set_num_threads const NAlib = NeuralAttentionlib -GPUArraysCore.backend(T::Type{<:NAlib.CollapsedDimsArray{E, <:CuArray}}) where E = GPUArraysCore.backend(CuArray{E, 3}) - -function NeuralAttentionlib.batched_transpose_f!(f, B::AnyGPUArray{T, 3}, A::AnyGPUArray{T, 3}) where T - axes(B,1) == axes(A,2) && axes(B,2) == axes(A,1) && axes(A,3) == axes(B,3) || throw(DimensionMismatch(string(f))) - GPUArrays.gpu_call(B, A) do ctx, B, A - idx = GPUArrays.@cartesianidx A - @inbounds B[idx[2], idx[1], idx[3]] = f(A[idx[1], idx[2], idx[3]]) - return - end - return B -end - import CUDA.CUBLAS for (fname, elty) in ((:cublasDgemmStridedBatched,:Float64), @@ -33,14 +19,16 @@ for (fname, elty) in (:cublasZgemmStridedBatched,:ComplexF64), (:cublasCgemmStridedBatched,:ComplexF32)) @eval begin - @inline function NeuralAttentionlib.unsafe_gemm_strided_batched!( transA::Char, transB::Char, - m::Int, n::Int, k::Int, - alpha::($elty), ptrA::CuPtr{$elty}, lda::Int, strideA::Int, - ptrB::CuPtr{$elty}, ldb::Int, strideB::Int, beta::($elty), - ptrC::CuPtr{$elty}, ldc::Int, strideC::Int, batchCount::Int) + m::Integer, n::Integer, k::Integer, + alpha::($elty), tptrA::TypedPtr{CuArray, $elty}, lda::Integer, strideA::Integer, + tptrB::TypedPtr{CuArray, $elty}, ldb::Integer, strideB::Integer, beta::($elty), + tptrC::TypedPtr{CuArray, $elty}, ldc::Integer, strideC::Integer, batchCount::Integer) + ptrA = tptrA.ptr + ptrB = tptrB.ptr + ptrC = tptrC.ptr CUBLAS.$fname(CUBLAS.handle(), transA, transB, m, n, k, alpha, ptrA, lda, strideA, @@ -48,207 +36,10 @@ for (fname, elty) in ptrC, ldc, strideC, batchCount) return nothing end - - end -end - -for (elty, array) in ( - (:Float16, :CuArray), -) - @eval begin - @inline function NeuralAttentionlib.unsafe_gemm_strided_batched!( - transA::Char, transB::Char, - m::Int, n::Int, k::Int, - alpha::($elty), ptrA::Ptr{$elty}, lda::Int, strideA::Int, - ptrB::Ptr{$elty}, ldb::Int, strideB::Int, beta::($elty), - ptrC::Ptr{$elty}, ldc::Int, strideC::Int, batchCount::Int) - - # https://github.com/FluxML/NNlib.jl/blob/cd3851d31e95020e77e67f80fb6402b5b87db1e6/src/gemm.jl#L91-L139 - n_threads = min(Threads.nthreads(), 1 + max(m * k * batchCount, n * k * batchCount) ÷ 8000) - if n_threads > 1 - old_threads = get_num_threads() - set_num_threads(1) - Threads.@sync for bs in Iterators.partition(1:batchCount, cld(batchCount, n_threads)) - Threads.@spawn for b in bs - ptrAi = ptrA + (b - 1) * strideA * sizeof($elty) - ptrBi = ptrB + (b - 1) * strideB * sizeof($elty) - ptrCi = ptrC + (b - 1) * strideC * sizeof($elty) - - NeuralAttentionlib.unsafe_gemm!(transA, transB, m, n, k, - alpha, ptrAi, lda, - ptrBi, ldb, beta, - ptrCi, ldc) - end - end - set_num_threads(old_threads) - else - for i = 1:batchCount - ptrAi = ptrA + (i - 1) * strideA * sizeof($elty) - ptrBi = ptrB + (i - 1) * strideB * sizeof($elty) - ptrCi = ptrC + (i - 1) * strideC * sizeof($elty) - NeuralAttentionlib.unsafe_gemm!(transA, transB, m, n, k, - alpha, ptrAi, lda, - ptrBi, ldb, beta, - ptrCi, ldc) - end - end - return nothing - end - - @inline function NeuralAttentionlib.gemm_strided_batched_impl!( - transA::Char, transB::Char, - m::Int, n::Int, k::Int, - alpha::($elty), A::$array{$elty}, lda::Int, strideA::Int, - B::$array{$elty}, ldb::Int, strideB::Int, beta::($elty), - C::$array{$elty}, ldc::Int, strideC::Int, batchCount::Int) - - ptrA = pointer(A) - ptrB = pointer(B) - ptrC = pointer(C) - GC.@preserve A B C begin - NeuralAttentionlib.unsafe_gemm_strided_batched!( - transA, transB, m, n, k, - alpha, ptrA, lda, strideA, - ptrB, ldb, strideB, beta, - ptrC, ldc, strideC, batchCount) - end - return C - end - - @inline function NeuralAttentionlib.gemm_strided_batched!( - transA::Char, transB::Char, - alpha::($elty), A::$array{$elty, 3}, - B::$array{$elty, 3}, beta::($elty), - C::$array{$elty, 3}) - - Base.require_one_based_indexing(A, B, C) - BLAS.chkstride1(A, B, C) - @assert size(A, 3) == size(C, 3) || size(A, 3) == 1 "batch size mismatch: A != C" - @assert size(B, 3) == size(C, 3) || size(B, 3) == 1 "batch size mismatch: B != C" - - m = size(A, transA == 'N' ? 1 : 2) - ka = size(A, transA == 'N' ? 2 : 1) - kb = size(B, transB == 'N' ? 1 : 2) - n = size(B, transB == 'N' ? 2 : 1) - - if m != size(C,1) || n != size(C,2) || ka != kb - throw(DimensionMismatch("A has size ($m,$ka,$(size(A, 3))), B has size ($kb,$n,$(size(B, 3))), C has size $(size(C))")) - end - - lda = max(1, stride(A,2)) - ldb = max(1, stride(B,2)) - ldc = max(1, stride(C,2)) - - strideA = size(A, 3) == 1 ? 0 : stride(A, 3) - strideB = size(B, 3) == 1 ? 0 : stride(B, 3) - strideC = stride(C, 3) - batchCount = size(C, 3) - - NeuralAttentionlib.gemm_strided_batched_impl!( - transA, transB, m, n, ka, - alpha, A, lda, strideA, - B, ldb, strideB, beta, - C, ldc, strideC, batchCount) - - return C - end - - function NeuralAttentionlib.gemm_strided_batched( - transA::Char, transB::Char, - alpha::($elty), A::$array{$elty, 3}, - B::$array{$elty, 3}) - C = similar(B, (size(A, transA == 'N' ? 1 : 2), size(B, transB == 'N' ? 2 : 1), max(size(A, 3), size(B, 3)))) - return NeuralAttentionlib.gemm_strided_batched!(transA, transB, alpha, A, B, zero($elty), C) - end - - function NeuralAttentionlib.gemm_strided_batched( - transA::Char, transB::Char, - A::$array{$elty, 3}, - B::$array{$elty, 3}) - return NeuralAttentionlib.gemm_strided_batched(transA, transB, one($elty), A, B) - end - - function NeuralAttentionlib.gemm_strided_batched!( - transA::Char, transB::Char, - alpha::($elty), A::$array{$elty, N1}, - B::$array{$elty, N2}, beta::($elty), - C::$array{$elty, N3}, - Ai, Aj, Bi, Bj, Ci, Cj) where {N1, N2, N3} - - # (a1, a2, ..., ai-1, ai, ai+1, ..., aj-1, aj, ..., an) - # |______lda______| |____K/M (Ai)_____| |___Aj____| - # (b1, b2, ..., bi-1, bi, bi+1, ..., bj-1, bj, ..., bn) - # |______ldb______| |____K/N (Bi)_____| |___Bj____| - # (c1, c2, ..., ci-1, ci, ci+1, ..., cj-1, cj, ..., cn) - # |______ldc______| |____K/N (Ci)_____| |___Cj____| - - Base.require_one_based_indexing(A, B, C) - BLAS.chkstride1(A, B, C) - - sa1, sa2, sa3 = NeuralAttentionlib.collapsed_size(A, Ai, Aj) - sb1, sb2, sb3 = NeuralAttentionlib.collapsed_size(B, Bi, Bj) - sc1, sc2, sc3 = NeuralAttentionlib.collapsed_size(C, Ci, Cj) - - @assert sa3 == sc3 || sa3 == 1 "batch size mismatch: A != C" - @assert sb3 == sc3 || sb3 == 1 "batch size mismatch: B != C" - - m = transA == 'N' ? sa1 : sa2 - ka = transA == 'N' ? sa2 : sa1 - kb = transB == 'N' ? sb1 : sb2 - n = transB == 'N' ? sb2 : sb1 - - if m != sc1 || n != sc2 || ka != kb - throw(DimensionMismatch("A has size ($m,$ka,$sa3), B has size ($kb,$n,$sb3), C has size ($sc1, $sc2, $sc3)")) - end - - lda = max(1, stride(A, N1 - Ai - Aj + 1)) - ldb = max(1, stride(B, N2 - Bi - Bj + 1)) - ldc = max(1, stride(C, N3 - Ci - Cj + 1)) - - strideA = sa3 == 1 ? 0 : stride(A, N1 - Aj + 1) - strideB = sb3 == 1 ? 0 : stride(B, N2 - Bj + 1) - strideC = stride(C, N3 - Cj + 1) - batchCount = sc3 - - NeuralAttentionlib.gemm_strided_batched_impl!( - transA, transB, m, n, ka, - alpha, A, lda, strideA, - B, ldb, strideB, beta, - C, ldc, strideC, batchCount) - - return C - end - - function NeuralAttentionlib.gemm_strided_batched( - transA::Char, transB::Char, - alpha::($elty), A::$array{$elty, N1}, - B::$array{$elty, N2}, - Ai, Aj, Bi, Bj) where {N1, N2} - - m = NeuralAttentionlib.noncollapsed_size(A, Ai, Aj, transA == 'N' ? 1 : 2) - n = NeuralAttentionlib.noncollapsed_size(B, Bi, Bj, transB == 'N' ? 2 : 1) - sc3 = NeuralAttentionlib.collapsed_size(A, Ai, Aj, 3) > NeuralAttentionlib.collapsed_size(B, Bi, Bj, 3) ? - NeuralAttentionlib.noncollapsed_size(A, Ai, Aj, 3) : - NeuralAttentionlib.noncollapsed_size(B, Bi, Bj, 3) - - Ci = length(n) - Cj = length(sc3) - C = similar(B, (m..., n..., sc3...)) - return NeuralAttentionlib.gemm_strided_batched!( - transA, transB, alpha, A, B, zero($elty), C, Ai, Aj, Bi, Bj, Ci, Cj) - end - - function NeuralAttentionlib.gemm_strided_batched( - transA::Char, transB::Char, - A::$array{$elty, N1}, - B::$array{$elty, N2}, - Ai, Aj, Bi, Bj) where {N1, N2} - return NeuralAttentionlib.gemm_strided_batched(transA, transB, one($elty), A, B, Ai, Aj, Bi, Bj) - end - end end +NeuralAttentionlib.ptrtypetag(::CUDA.CuArrayBackend) = CuArray NeuralAttentionlib.check_strided_gemm_type(A::CuArray{Float16}) = true Adapt.adapt(to::CUDA.KernelAdaptor, m::AbstractArrayMask) = diff --git a/ext/NeuralAttentionlibFiniteDifferences/NeuralAttentionlibFiniteDifferences.jl b/ext/NeuralAttentionlibFiniteDifferencesExt/NeuralAttentionlibFiniteDifferencesExt.jl similarity index 90% rename from ext/NeuralAttentionlibFiniteDifferences/NeuralAttentionlibFiniteDifferences.jl rename to ext/NeuralAttentionlibFiniteDifferencesExt/NeuralAttentionlibFiniteDifferencesExt.jl index e6bb087..0902d3d 100644 --- a/ext/NeuralAttentionlibFiniteDifferences/NeuralAttentionlibFiniteDifferences.jl +++ b/ext/NeuralAttentionlibFiniteDifferencesExt/NeuralAttentionlibFiniteDifferencesExt.jl @@ -1,4 +1,4 @@ -module NeuralAttentionlibFiniteDifferences +module NeuralAttentionlibFiniteDifferencesExt using NeuralAttentionlib using NeuralAttentionlib: CollapsedDimsArray, collapseddims diff --git a/ext/NeuralAttentionlibGPUArraysExt/NeuralAttentionlibGPUArraysExt.jl b/ext/NeuralAttentionlibGPUArraysExt/NeuralAttentionlibGPUArraysExt.jl new file mode 100644 index 0000000..1945ab8 --- /dev/null +++ b/ext/NeuralAttentionlibGPUArraysExt/NeuralAttentionlibGPUArraysExt.jl @@ -0,0 +1,22 @@ +module NeuralAttentionlibGPUArraysExt + +using NeuralAttentionlib +using NeuralAttentionlib: CollapsedDimsArray +using GPUArrays +using GPUArrays.GPUArraysCore + +GPUArraysCore.backend(::Type{<:CollapsedDimsArray{E, A}}) where {E, A} = GPUArraysCore.backend(A) + +NeuralAttentionlib.ptrtypetag(arr::AnyGPUArray) = NeuralAttentionlib.ptrtypetag(GPUArraysCore.backend(arr)) + +function NeuralAttentionlib.batched_transpose_f!(f, B::AnyGPUArray{T, 3}, A::AnyGPUArray{T, 3}) where T + axes(B,1) == axes(A,2) && axes(B,2) == axes(A,1) && axes(A,3) == axes(B,3) || throw(DimensionMismatch(string(f))) + GPUArrays.gpu_call(B, A) do ctx, B, A + idx = GPUArrays.@cartesianidx A + @inbounds B[idx[2], idx[1], idx[3]] = f(A[idx[1], idx[2], idx[3]]) + return + end + return B +end + +end diff --git a/src/matmul/gemm.jl b/src/matmul/gemm.jl index 1b33292..8fac14b 100644 --- a/src/matmul/gemm.jl +++ b/src/matmul/gemm.jl @@ -6,16 +6,45 @@ using LinearAlgebra.BLAS: get_num_threads, set_num_threads const libblas = Base.libblas_name +struct TypedPtr{T, ET, P} + ptr::P + TypedPtr{T}(ptr) where T = new{T, eltype(ptr), typeof(ptr)}(ptr) +end +@inline typedpointer(arr::AbstractArray) = TypedPtr{ptrtypetag(arr)}(pointer(arr)) +@inline ptrtypetag(arr) = Array + +@inline function gemm_strided_batched_impl!( + transA::Char, transB::Char, + m::Integer, n::Integer, k::Integer, + alpha::ET, A::T1, lda::Integer, strideA::Integer, + B::T2, ldb::Integer, strideB::Integer, beta::ET, + C::T3, ldc::Integer, strideC::Integer, batchCount::Integer +) where {ET, N1, N2, N3, T1 <: AbstractArray{ET, N1}, T2 <: AbstractArray{ET, N2}, T3 <: AbstractArray{ET, N3}} + ptrA = typedpointer(A) + ptrB = typedpointer(B) + ptrC = typedpointer(C) + GC.@preserve A B C begin + unsafe_gemm_strided_batched!( + transA, transB, m, n, k, + alpha, ptrA, lda, strideA, + ptrB, ldb, strideB, beta, + ptrC, ldc, strideC, batchCount) + end + return C +end + for (gemm, elty) in NNlib.gemm_datatype_mappings @eval begin - @inline function unsafe_gemm!( transA::Char, transB::Char, - m::Int, n::Int, k::Int, - alpha::($elty), ptrA::Ptr{$elty}, lda::Int, - ptrB::Ptr{$elty}, ldb::Int, beta::($elty), - ptrC::Ptr{$elty}, ldc::Int) - + m::Integer, n::Integer, k::Integer, + alpha::($elty), tptrA::TypedPtr{Array, $elty}, lda::Integer, + tptrB::TypedPtr{Array, $elty}, ldb::Integer, beta::($elty), + tptrC::TypedPtr{Array, $elty}, ldc::Integer) + + ptrA = tptrA.ptr + ptrB = tptrB.ptr + ptrC = tptrC.ptr ccall((BLAS.@blasfunc($gemm), libblas), Nothing, (Ref{UInt8}, Ref{UInt8}, Ref{BLAS.BlasInt}, Ref{BLAS.BlasInt}, Ref{BLAS.BlasInt}, Ref{$elty}, Ptr{$elty}, Ref{BLAS.BlasInt}, @@ -27,21 +56,17 @@ for (gemm, elty) in NNlib.gemm_datatype_mappings ptrC, ldc) return nothing end - end -end -for (elty, array) in ( - (:ComplexF64, :AbstractArray), (:ComplexF32, :AbstractArray), - (:Float64, :AbstractArray), (:Float32, :AbstractArray), -) - @eval begin @inline function unsafe_gemm_strided_batched!( transA::Char, transB::Char, - m::Int, n::Int, k::Int, - alpha::($elty), ptrA::Ptr{$elty}, lda::Int, strideA::Int, - ptrB::Ptr{$elty}, ldb::Int, strideB::Int, beta::($elty), - ptrC::Ptr{$elty}, ldc::Int, strideC::Int, batchCount::Int) - + m::Integer, n::Integer, k::Integer, + alpha::($elty), tptrA::TypedPtr{Array, $elty}, lda::Integer, strideA::Integer, + tptrB::TypedPtr{Array, $elty}, ldb::Integer, strideB::Integer, beta::($elty), + tptrC::TypedPtr{Array, $elty}, ldc::Integer, strideC::Integer, batchCount::Integer) + + ptrA = tptrA.ptr + ptrB = tptrB.ptr + ptrC = tptrC.ptr # https://github.com/FluxML/NNlib.jl/blob/cd3851d31e95020e77e67f80fb6402b5b87db1e6/src/gemm.jl#L91-L139 n_threads = min(Threads.nthreads(), 1 + max(m * k * batchCount, n * k * batchCount) ÷ 8000) if n_threads > 1 @@ -52,12 +77,14 @@ for (elty, array) in ( ptrAi = ptrA + (b - 1) * strideA * sizeof($elty) ptrBi = ptrB + (b - 1) * strideB * sizeof($elty) ptrCi = ptrC + (b - 1) * strideC * sizeof($elty) + tptrAi = TypedPtr{Array}(ptrAi) + tptrBi = TypedPtr{Array}(ptrBi) + tptrCi = TypedPtr{Array}(ptrCi) unsafe_gemm!(transA, transB, m, n, k, - alpha, ptrAi, lda, - ptrBi, ldb, beta, - ptrCi, ldc) - + alpha, tptrAi, lda, + tptrBi, ldb, beta, + tptrCi, ldc) end end set_num_threads(old_threads) @@ -66,167 +93,91 @@ for (elty, array) in ( ptrAi = ptrA + (i - 1) * strideA * sizeof($elty) ptrBi = ptrB + (i - 1) * strideB * sizeof($elty) ptrCi = ptrC + (i - 1) * strideC * sizeof($elty) + tptrAi = TypedPtr{Array}(ptrAi) + tptrBi = TypedPtr{Array}(ptrBi) + tptrCi = TypedPtr{Array}(ptrCi) unsafe_gemm!(transA, transB, m, n, k, - alpha, ptrAi, lda, - ptrBi, ldb, beta, - ptrCi, ldc) - + alpha, tptrAi, lda, + tptrBi, ldb, beta, + tptrCi, ldc) end end return nothing end + end +end - @inline function gemm_strided_batched_impl!( - transA::Char, transB::Char, - m::Int, n::Int, k::Int, - alpha::($elty), A::$array{$elty}, lda::Int, strideA::Int, - B::$array{$elty}, ldb::Int, strideB::Int, beta::($elty), - C::$array{$elty}, ldc::Int, strideC::Int, batchCount::Int) - - ptrA = pointer(A) - ptrB = pointer(B) - ptrC = pointer(C) - - GC.@preserve A B C begin - unsafe_gemm_strided_batched!( - transA, transB, m, n, k, - alpha, ptrA, lda, strideA, - ptrB, ldb, strideB, beta, - ptrC, ldc, strideC, batchCount) - end - return C - end - - @inline function gemm_strided_batched!( - transA::Char, transB::Char, - alpha::($elty), A::$array{$elty, 3}, - B::$array{$elty, 3}, beta::($elty), - C::$array{$elty, 3}) - - Base.require_one_based_indexing(A, B, C) - BLAS.chkstride1(A, B, C) - @assert size(A, 3) == size(C, 3) || size(A, 3) == 1 "batch size mismatch: A != C" - @assert size(B, 3) == size(C, 3) || size(B, 3) == 1 "batch size mismatch: B != C" - - m = size(A, transA == 'N' ? 1 : 2) - ka = size(A, transA == 'N' ? 2 : 1) - kb = size(B, transB == 'N' ? 1 : 2) - n = size(B, transB == 'N' ? 2 : 1) - - if m != size(C,1) || n != size(C,2) || ka != kb - throw(DimensionMismatch("A has size ($m,$ka,$(size(A, 3))), B has size ($kb,$n,$(size(B, 3))), C has size $(size(C))")) - end - - lda = max(1, stride(A,2)) - ldb = max(1, stride(B,2)) - ldc = max(1, stride(C,2)) - - strideA = size(A, 3) == 1 ? 0 : stride(A, 3) - strideB = size(B, 3) == 1 ? 0 : stride(B, 3) - strideC = stride(C, 3) - batchCount = size(C, 3) - - gemm_strided_batched_impl!( - transA, transB, m, n, ka, - alpha, A, lda, strideA, - B, ldb, strideB, beta, - C, ldc, strideC, batchCount) - - return C - end - - function gemm_strided_batched( - transA::Char, transB::Char, - alpha::($elty), A::$array{$elty, 3}, - B::$array{$elty, 3}) - C = similar(B, (size(A, transA == 'N' ? 1 : 2), size(B, transB == 'N' ? 2 : 1), max(size(A, 3), size(B, 3)))) - return gemm_strided_batched!(transA, transB, alpha, A, B, zero($elty), C) - end - - function gemm_strided_batched( - transA::Char, transB::Char, - A::$array{$elty, 3}, - B::$array{$elty, 3}) - return gemm_strided_batched(transA, transB, one($elty), A, B) - end - - function gemm_strided_batched!( - transA::Char, transB::Char, - alpha::($elty), A::$array{$elty, N1}, - B::$array{$elty, N2}, beta::($elty), - C::$array{$elty, N3}, - Ai, Aj, Bi, Bj, Ci, Cj) where {N1, N2, N3} - - # (a1, a2, ..., ai-1, ai, ai+1, ..., aj-1, aj, ..., an) - # |______lda______| |____K/M (Ai)_____| |___Aj____| - # (b1, b2, ..., bi-1, bi, bi+1, ..., bj-1, bj, ..., bn) - # |______ldb______| |____K/N (Bi)_____| |___Bj____| - # (c1, c2, ..., ci-1, ci, ci+1, ..., cj-1, cj, ..., cn) - # |______ldc______| |____K/N (Ci)_____| |___Cj____| - - Base.require_one_based_indexing(A, B, C) - BLAS.chkstride1(A, B, C) - - sa1, sa2, sa3 = collapsed_size(A, Ai, Aj) - sb1, sb2, sb3 = collapsed_size(B, Bi, Bj) - sc1, sc2, sc3 = collapsed_size(C, Ci, Cj) - - @assert sa3 == sc3 || sa3 == 1 "batch size mismatch: A != C" - @assert sb3 == sc3 || sb3 == 1 "batch size mismatch: B != C" - - m = transA == 'N' ? sa1 : sa2 - ka = transA == 'N' ? sa2 : sa1 - kb = transB == 'N' ? sb1 : sb2 - n = transB == 'N' ? sb2 : sb1 - - if m != sc1 || n != sc2 || ka != kb - throw(DimensionMismatch("A has size ($m,$ka,$sa3), B has size ($kb,$n,$sb3), C has size ($sc1, $sc2, $sc3)")) - end - - lda = max(1, stride(A, N1 - Ai - Aj + 1)) - ldb = max(1, stride(B, N2 - Bi - Bj + 1)) - ldc = max(1, stride(C, N3 - Ci - Cj + 1)) - - strideA = sa3 == 1 ? 0 : stride(A, N1 - Aj + 1) - strideB = sb3 == 1 ? 0 : stride(B, N2 - Bj + 1) - strideC = stride(C, N3 - Cj + 1) - batchCount = sc3 - - gemm_strided_batched_impl!( - transA, transB, m, n, ka, - alpha, A, lda, strideA, - B, ldb, strideB, beta, - C, ldc, strideC, batchCount) - - return C - end - - function gemm_strided_batched( - transA::Char, transB::Char, - alpha::($elty), A::$array{$elty, N1}, - B::$array{$elty, N2}, - Ai, Aj, Bi, Bj) where {N1, N2} - - m = noncollapsed_size(A, Ai, Aj, transA == 'N' ? 1 : 2) - n = noncollapsed_size(B, Bi, Bj, transB == 'N' ? 2 : 1) - sc3 = collapsed_size(A, Ai, Aj, 3) > collapsed_size(B, Bi, Bj, 3) ? - noncollapsed_size(A, Ai, Aj, 3) : - noncollapsed_size(B, Bi, Bj, 3) - - Ci = length(n) - Cj = length(sc3) - C = similar(B, (m..., n..., sc3...)) - return gemm_strided_batched!(transA, transB, alpha, A, B, zero($elty), C, Ai, Aj, Bi, Bj, Ci, Cj) - end - - function gemm_strided_batched( - transA::Char, transB::Char, - A::$array{$elty, N1}, - B::$array{$elty, N2}, - Ai, Aj, Bi, Bj) where {N1, N2} - return gemm_strided_batched(transA, transB, one($elty), A, B, Ai, Aj, Bi, Bj) - end - +gemm_strided_batched( + transA::Char, transB::Char, A::T1, B::T2) where {ET, T1 <: AbstractArray{ET, 3}, T2 <: AbstractArray{ET, 3}} = + gemm_strided_batched(transA, transB, A, B, static(1), static(1), static(1), static(1)) +gemm_strided_batched( + transA::Char, transB::Char, alpha::ET, A::T1, B::T2 +) where {ET, T1 <: AbstractArray{ET, 3}, T2 <: AbstractArray{ET, 3}} = + gemm_strided_batched(transA, transB, alpha, A, B, static(1), static(1), static(1), static(1)) +gemm_strided_batched!( + transA::Char, transB::Char, alpha::ET, A::T1, B::T2, beta::ET, C::T3 +) where {ET, T1 <: AbstractArray{ET, 3}, T2 <: AbstractArray{ET, 3}, T3 <: AbstractArray{ET, 3}} = + gemm_strided_batched!( + transA, transB, alpha, A, B, beta, C, static(1), static(1), static(1), static(1), static(1), static(1)) + +gemm_strided_batched( + transA::Char, transB::Char, A::T1, B::T2, Ai, Aj, Bi, Bj +) where {ET, N1, N2, T1 <: AbstractArray{ET, N1}, T2 <: AbstractArray{ET, N2}} = + gemm_strided_batched(transA, transB, one(ET), A, B, Ai, Aj, Bi, Bj) +function gemm_strided_batched( + transA::Char, transB::Char, alpha::ET, A::T1, B::T2, Ai, Aj, Bi, Bj +) where {ET, N1, N2, T1 <: AbstractArray{ET, N1}, T2 <: AbstractArray{ET, N2}} + m = noncollapsed_size(A, Ai, Aj, transA == 'N' ? 1 : 2) + n = noncollapsed_size(B, Bi, Bj, transB == 'N' ? 2 : 1) + sc3 = collapsed_size(A, Ai, Aj, 3) > collapsed_size(B, Bi, Bj, 3) ? + noncollapsed_size(A, Ai, Aj, 3) : + noncollapsed_size(B, Bi, Bj, 3) + Ci = length(n) + Cj = length(sc3) + C = similar(B, (m..., n..., sc3...)) + return gemm_strided_batched!(transA, transB, alpha, A, B, zero(ET), C, Ai, Aj, Bi, Bj, Ci, Cj) +end +function gemm_strided_batched!( + transA::Char, transB::Char, + alpha::ET, A::T1, B::T2, beta::ET, C::T3, Ai, Aj, Bi, Bj, Ci, Cj +) where {ET, N1, N2, N3, T1 <: AbstractArray{ET, N1}, T2 <: AbstractArray{ET, N2}, T3 <: AbstractArray{ET, N3}} + # (a1, a2, ..., ai-1, ai, ai+1, ..., aj-1, aj, ..., an) + # |______lda______| |____K/M (Ai)_____| |___Aj____| + # (b1, b2, ..., bi-1, bi, bi+1, ..., bj-1, bj, ..., bn) + # |______ldb______| |____K/N (Bi)_____| |___Bj____| + # (c1, c2, ..., ci-1, ci, ci+1, ..., cj-1, cj, ..., cn) + # |______ldc______| |____K/N (Ci)_____| |___Cj____| + Base.require_one_based_indexing(A, B, C) + BLAS.chkstride1(A, B, C) + + sa1, sa2, sa3 = collapsed_size(A, Ai, Aj) + sb1, sb2, sb3 = collapsed_size(B, Bi, Bj) + sc1, sc2, sc3 = collapsed_size(C, Ci, Cj) + @assert sa3 == sc3 || sa3 == 1 "batch size mismatch: A != C" + @assert sb3 == sc3 || sb3 == 1 "batch size mismatch: B != C" + + m = transA == 'N' ? sa1 : sa2 + ka = transA == 'N' ? sa2 : sa1 + kb = transB == 'N' ? sb1 : sb2 + n = transB == 'N' ? sb2 : sb1 + if m != sc1 || n != sc2 || ka != kb + throw(DimensionMismatch("A has size ($m, $ka, $sa3), B has size ($kb, $n, $sb3), C has size ($sc1, $sc2, $sc3)")) end + + lda = max(1, stride(A, N1 - Ai - Aj + 1)) + ldb = max(1, stride(B, N2 - Bi - Bj + 1)) + ldc = max(1, stride(C, N3 - Ci - Cj + 1)) + + strideA = sa3 == 1 ? 0 : stride(A, N1 - Aj + 1) + strideB = sb3 == 1 ? 0 : stride(B, N2 - Bj + 1) + strideC = stride(C, N3 - Cj + 1) + batchCount = sc3 + + gemm_strided_batched_impl!( + transA, transB, m, n, ka, + alpha, A, lda, strideA, + B, ldb, strideB, beta, + C, ldc, strideC, batchCount) + return C end diff --git a/test/collapseddims.jl b/test/collapseddims.jl index 9cef2b8..03cc45f 100644 --- a/test/collapseddims.jl +++ b/test/collapseddims.jl @@ -1,4 +1,4 @@ -if !USE_CUDA +if !USE_GPU @testset "CollapsedDim" begin using NeuralAttentionlib.Matmul using NeuralAttentionlib: collapseddims_fdim1, collapseddims_nonbatch, collapseddims_nonbatch_fdim1 diff --git a/test/functional.jl b/test/functional.jl index 2be1488..edeac52 100644 --- a/test/functional.jl +++ b/test/functional.jl @@ -13,7 +13,7 @@ layer_norm, rms_layer_norm, get_sincos_position_embeddings @testset "score" begin - if !USE_CUDA + if !USE_GPU @testset "AD" begin test_rrule(dot_product_score, randn(5, 3, 2), randn(5, 4, 2); check_inferred = false) test_rrule(dot_product_score, randn(5, 3, 2, 2), randn(5, 4, 2, 2)) @@ -84,7 +84,7 @@ end end - if !USE_CUDA + if !USE_GPU @testset "AD" begin test_rrule( scalar_relative_position_embedding, t5_bucketed_position_id(8, 20), randn(3, 8), @@ -182,7 +182,7 @@ @test with_rotary_position_embedding(x) ≈ naive_rotary_pe(x) @test with_rotary_position_embedding(256, x) ≈ naive_rotary_pe_w_dim(256, x) @test with_rotary_position_embedding(256)(x) ≈ naive_rotary_pe_w_dim(256, x) - if !USE_CUDA + if !USE_GPU @testset "AD" begin x = randn(512, 5, 3, 2) @test Zygote.gradient(x->sum(sin.(with_rotary_position_embedding(x))), x)[1] ≈ @@ -226,7 +226,7 @@ atol = 5e-1 ) - if !USE_CUDA + if !USE_GPU @testset "AD" begin g = randn(20) b = randn(20) @@ -258,7 +258,7 @@ @testset "attention" begin @testset "multihead_qkv_attention" begin - if !USE_CUDA + if !USE_GPU @testset "AD" begin for i = 1:3 a = randn(20, 3, 2) @@ -296,7 +296,7 @@ @test grad[2] ≈ ngrad[2] @test grad[3] ≈ ngrad[3] - if !USE_CUDA + if !USE_GPU @testset "AD" begin for i = 1:3 a = randn(30, 3, 2) diff --git a/test/mask.jl b/test/mask.jl index 47aca7a..921afee 100644 --- a/test/mask.jl +++ b/test/mask.jl @@ -285,7 +285,7 @@ @test_throws DimensionMismatch drandn(5, 4) .* (GenericAttenMask(drand(Bool, 3, 4)) | SymLengthMask([2])) end - if !USE_CUDA + if !USE_GPU @testset "AD" begin m = (LocalMask(1) | CausalMask() & !(BandPartMask(5,5)) | BiLengthMask([2,3], [3, 7])) diff --git a/test/matmul.jl b/test/matmul.jl index 15091f7..283cd3e 100644 --- a/test/matmul.jl +++ b/test/matmul.jl @@ -18,7 +18,7 @@ end uwcs(x) = size(unwrap_collapse(x)) - if USE_CUDA + if USE_GPU eltype_list = (Float64, Float32, Float16, ComplexF64, ComplexF32) else eltype_list = (Float64, Float32, ComplexF64, ComplexF32) @@ -178,7 +178,7 @@ end end - if !USE_CUDA + if !USE_GPU @testset "AD" begin test_rrule(matmul, randn(7,6,5), randn(6, 2), randn()) test_rrule(matmul, randn(7,6,5,4), randn(6), randn()) diff --git a/test/runtests.jl b/test/runtests.jl index ae3a323..b89f06d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,6 +4,7 @@ using NeuralAttentionlib using Random using Flux using CUDA +using AMDGPU using NNlib using Static using ChainRulesCore @@ -23,24 +24,39 @@ include("old_impl/old_impl.jl") using .Old_Impl using .Old_Impl: batched_triu!, batched_tril! -function should_test_cuda() - e = get(ENV, "JL_PKG_TEST_CUDA", false) - e isa Bool && return e +function testing_gpu() + e = get(ENV, "JL_PKG_TEST_GPU", nothing) + isnothing(e) && return nothing if e isa String - x = tryparse(Bool, e) - return isnothing(x) ? false : x - else - return false + x = lowercase(e) + if isempty(x) + return nothing + elseif x == "cuda" + return :cuda + elseif x == "amdgpu" + return :amdgpu + end end + error("Unknown value for `JL_PKG_TEST_GPU`: $x") end -const USE_CUDA = @show should_test_cuda() - -if USE_CUDA - CUDA.allowscalar(false) +const GPUBACKEND = testing_gpu() +if isnothing(GPUBACKEND) + const USE_GPU = false +else + const USE_GPU = true + if GPUBACKEND == :cuda + using CUDA + CUDA.allowscalar(false) + elseif GPUBACKEND == :amdgpu + using AMDGPU + AMDGPU.allowscalar(false) + end end +@show GPUBACKEND +@show USE_GPU -device(x) = USE_CUDA ? gpu(x) : x +device(x) = USE_GPU ? gpu(x) : x drandn(arg...) = randn(arg...) |> device drand(arg...) = rand(arg...) |> device