Commit

use pkg ext and drop lts julia

chengchingwen committed Apr 19, 2024
1 parent 40922f8 commit 347710d
Showing 14 changed files with 308 additions and 92 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
@@ -15,7 +15,7 @@ jobs:
      fail-fast: false
      matrix:
        version:
-         - '1.6'
+         - '1.10'
          - '1'
        os:
          - ubuntu-latest
29 changes: 16 additions & 13 deletions Project.toml
@@ -1,31 +1,34 @@
name = "NeuralAttentionlib"
uuid = "12afc1b8-fad6-47e1-9132-84abc478905f"
authors = ["chengchingwen <[email protected]>"]
-version = "0.2.12"
+version = "0.2.13"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
-CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
-GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
-NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
-Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"

+[weakdeps]
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
+
+[extensions]
+NeuralAttentionlibCUDAExt = "CUDA"
+NeuralAttentionlibFiniteDifferences = "FiniteDifferences"
+NeuralAttentionlibZygoteExt = "Zygote"
+
[compat]
Adapt = "3.3"
-CUDA = "3, 4"
+CUDA = "5"
ChainRulesCore = "1.3"
-GPUArrays = "8"
GPUArraysCore = "0.1"
-NNlib = "0.7, 0.8"
-NNlibCUDA = "0.2"
-Requires = "1.1"
+FiniteDifferences = "0.12"
+NNlib = "0.9"
Static = "0.7, 0.8"
-julia = "1.6"
+Zygote = "0.6"
+julia = "1.10"

[extras]
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
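For context: with Julia's package-extension mechanism ([weakdeps] plus [extensions] above), CUDA, FiniteDifferences, and Zygote are no longer hard dependencies. A minimal sketch of how an extension gets activated, assuming nothing beyond the entries above:

# Sketch: an extension module loads automatically once both the package
# and its weak dependency are present in the session (Julia >= 1.9).
using NeuralAttentionlib   # core package; no CUDA code is loaded yet

using CUDA                 # triggers loading of NeuralAttentionlibCUDAExt

# The loaded extension module can be retrieved for inspection:
ext = Base.get_extension(NeuralAttentionlib, :NeuralAttentionlibCUDAExt)
@assert ext !== nothing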
265 changes: 265 additions & 0 deletions ext/NeuralAttentionlibCUDAExt/NeuralAttentionlibCUDAExt.jl
@@ -0,0 +1,265 @@
module NeuralAttentionlibCUDAExt

using NeuralAttentionlib
using NeuralAttentionlib.Adapt
using NeuralAttentionlib: AbstractArrayMask, Indexer, GetIndexer
using CUDA
using CUDA.GPUArrays
using CUDA.GPUArrays.GPUArraysCore

import LinearAlgebra
import LinearAlgebra.BLAS
using LinearAlgebra.BLAS: get_num_threads, set_num_threads

const NAlib = NeuralAttentionlib

GPUArraysCore.backend(T::Type{<:NAlib.CollapsedDimsArray{E, <:CuArray}}) where E = GPUArraysCore.backend(CuArray{E, 3})

function NeuralAttentionlib.batched_transpose_f!(f, B::AnyGPUArray{T, 3}, A::AnyGPUArray{T, 3}) where T
axes(B,1) == axes(A,2) && axes(B,2) == axes(A,1) && axes(A,3) == axes(B,3) || throw(DimensionMismatch(string(f)))
GPUArrays.gpu_call(B, A) do ctx, B, A
idx = GPUArrays.@cartesianidx A
@inbounds B[idx[2], idx[1], idx[3]] = f(A[idx[1], idx[2], idx[3]])
return
end
return B
end

import CUDA.CUBLAS
for (fname, elty) in
((:cublasDgemmStridedBatched,:Float64),
(:cublasSgemmStridedBatched,:Float32),
(:cublasHgemmStridedBatched,:Float16),
(:cublasZgemmStridedBatched,:ComplexF64),
(:cublasCgemmStridedBatched,:ComplexF32))
@eval begin

@inline function NeuralAttentionlib.unsafe_gemm_strided_batched!(
transA::Char, transB::Char,
m::Int, n::Int, k::Int,
alpha::($elty), ptrA::CuPtr{$elty}, lda::Int, strideA::Int,
ptrB::CuPtr{$elty}, ldb::Int, strideB::Int, beta::($elty),
ptrC::CuPtr{$elty}, ldc::Int, strideC::Int, batchCount::Int)

CUBLAS.$fname(CUBLAS.handle(),
transA, transB, m, n, k,
alpha, ptrA, lda, strideA,
ptrB, ldb, strideB, beta,
ptrC, ldc, strideC, batchCount)
return nothing
end

end
end

for (elty, array) in (
(:Float16, :CuArray),
)
@eval begin
@inline function NeuralAttentionlib.unsafe_gemm_strided_batched!(
transA::Char, transB::Char,
m::Int, n::Int, k::Int,
alpha::($elty), ptrA::Ptr{$elty}, lda::Int, strideA::Int,
ptrB::Ptr{$elty}, ldb::Int, strideB::Int, beta::($elty),
ptrC::Ptr{$elty}, ldc::Int, strideC::Int, batchCount::Int)

# https://github.com/FluxML/NNlib.jl/blob/cd3851d31e95020e77e67f80fb6402b5b87db1e6/src/gemm.jl#L91-L139
n_threads = min(Threads.nthreads(), 1 + max(m * k * batchCount, n * k * batchCount) ÷ 8000)
if n_threads > 1
old_threads = get_num_threads()
set_num_threads(1)
Threads.@sync for bs in Iterators.partition(1:batchCount, cld(batchCount, n_threads))
Threads.@spawn for b in bs
ptrAi = ptrA + (b - 1) * strideA * sizeof($elty)
ptrBi = ptrB + (b - 1) * strideB * sizeof($elty)
ptrCi = ptrC + (b - 1) * strideC * sizeof($elty)

NeuralAttentionlib.unsafe_gemm!(transA, transB, m, n, k,
alpha, ptrAi, lda,
ptrBi, ldb, beta,
ptrCi, ldc)
end
end
set_num_threads(old_threads)
else
for i = 1:batchCount
ptrAi = ptrA + (i - 1) * strideA * sizeof($elty)
ptrBi = ptrB + (i - 1) * strideB * sizeof($elty)
ptrCi = ptrC + (i - 1) * strideC * sizeof($elty)
NeuralAttentionlib.unsafe_gemm!(transA, transB, m, n, k,
alpha, ptrAi, lda,
ptrBi, ldb, beta,
ptrCi, ldc)
end
end
return nothing
end

@inline function NeuralAttentionlib.gemm_strided_batched_impl!(
transA::Char, transB::Char,
m::Int, n::Int, k::Int,
alpha::($elty), A::$array{$elty}, lda::Int, strideA::Int,
B::$array{$elty}, ldb::Int, strideB::Int, beta::($elty),
C::$array{$elty}, ldc::Int, strideC::Int, batchCount::Int)

ptrA = pointer(A)
ptrB = pointer(B)
ptrC = pointer(C)
GC.@preserve A B C begin
NeuralAttentionlib.unsafe_gemm_strided_batched!(
transA, transB, m, n, k,
alpha, ptrA, lda, strideA,
ptrB, ldb, strideB, beta,
ptrC, ldc, strideC, batchCount)
end
return C
end

@inline function NeuralAttentionlib.gemm_strided_batched!(
transA::Char, transB::Char,
alpha::($elty), A::$array{$elty, 3},
B::$array{$elty, 3}, beta::($elty),
C::$array{$elty, 3})

Base.require_one_based_indexing(A, B, C)
BLAS.chkstride1(A, B, C)
@assert size(A, 3) == size(C, 3) || size(A, 3) == 1 "batch size mismatch: A != C"
@assert size(B, 3) == size(C, 3) || size(B, 3) == 1 "batch size mismatch: B != C"

m = size(A, transA == 'N' ? 1 : 2)
ka = size(A, transA == 'N' ? 2 : 1)
kb = size(B, transB == 'N' ? 1 : 2)
n = size(B, transB == 'N' ? 2 : 1)

if m != size(C,1) || n != size(C,2) || ka != kb
throw(DimensionMismatch("A has size ($m,$ka,$(size(A, 3))), B has size ($kb,$n,$(size(B, 3))), C has size $(size(C))"))
end

lda = max(1, stride(A,2))
ldb = max(1, stride(B,2))
ldc = max(1, stride(C,2))

strideA = size(A, 3) == 1 ? 0 : stride(A, 3)
strideB = size(B, 3) == 1 ? 0 : stride(B, 3)
strideC = stride(C, 3)
batchCount = size(C, 3)

NeuralAttentionlib.gemm_strided_batched_impl!(
transA, transB, m, n, ka,
alpha, A, lda, strideA,
B, ldb, strideB, beta,
C, ldc, strideC, batchCount)

return C
end

function NeuralAttentionlib.gemm_strided_batched(
transA::Char, transB::Char,
alpha::($elty), A::$array{$elty, 3},
B::$array{$elty, 3})
C = similar(B, (size(A, transA == 'N' ? 1 : 2), size(B, transB == 'N' ? 2 : 1), max(size(A, 3), size(B, 3))))
return NeuralAttentionlib.gemm_strided_batched!(transA, transB, alpha, A, B, zero($elty), C)
end

function NeuralAttentionlib.gemm_strided_batched(
transA::Char, transB::Char,
A::$array{$elty, 3},
B::$array{$elty, 3})
return NeuralAttentionlib.gemm_strided_batched(transA, transB, one($elty), A, B)
end

function NeuralAttentionlib.gemm_strided_batched!(
transA::Char, transB::Char,
alpha::($elty), A::$array{$elty, N1},
B::$array{$elty, N2}, beta::($elty),
C::$array{$elty, N3},
Ai, Aj, Bi, Bj, Ci, Cj) where {N1, N2, N3}

# (a1, a2, ..., ai-1, ai, ai+1, ..., aj-1, aj, ..., an)
# |______lda______| |____K/M (Ai)_____| |___Aj____|
# (b1, b2, ..., bi-1, bi, bi+1, ..., bj-1, bj, ..., bn)
# |______ldb______| |____K/N (Bi)_____| |___Bj____|
# (c1, c2, ..., ci-1, ci, ci+1, ..., cj-1, cj, ..., cn)
# |______ldc______| |____K/N (Ci)_____| |___Cj____|

Base.require_one_based_indexing(A, B, C)
BLAS.chkstride1(A, B, C)

sa1, sa2, sa3 = NeuralAttentionlib.collapsed_size(A, Ai, Aj)
sb1, sb2, sb3 = NeuralAttentionlib.collapsed_size(B, Bi, Bj)
sc1, sc2, sc3 = NeuralAttentionlib.collapsed_size(C, Ci, Cj)

@assert sa3 == sc3 || sa3 == 1 "batch size mismatch: A != C"
@assert sb3 == sc3 || sb3 == 1 "batch size mismatch: B != C"

m = transA == 'N' ? sa1 : sa2
ka = transA == 'N' ? sa2 : sa1
kb = transB == 'N' ? sb1 : sb2
n = transB == 'N' ? sb2 : sb1

if m != sc1 || n != sc2 || ka != kb
throw(DimensionMismatch("A has size ($m,$ka,$sa3), B has size ($kb,$n,$sb3), C has size ($sc1, $sc2, $sc3)"))
end

lda = max(1, stride(A, N1 - Ai - Aj + 1))
ldb = max(1, stride(B, N2 - Bi - Bj + 1))
ldc = max(1, stride(C, N3 - Ci - Cj + 1))

strideA = sa3 == 1 ? 0 : stride(A, N1 - Aj + 1)
strideB = sb3 == 1 ? 0 : stride(B, N2 - Bj + 1)
strideC = stride(C, N3 - Cj + 1)
batchCount = sc3

NeuralAttentionlib.gemm_strided_batched_impl!(
transA, transB, m, n, ka,
alpha, A, lda, strideA,
B, ldb, strideB, beta,
C, ldc, strideC, batchCount)

return C
end

function NeuralAttentionlib.gemm_strided_batched(
transA::Char, transB::Char,
alpha::($elty), A::$array{$elty, N1},
B::$array{$elty, N2},
Ai, Aj, Bi, Bj) where {N1, N2}

m = NeuralAttentionlib.noncollapsed_size(A, Ai, Aj, transA == 'N' ? 1 : 2)
n = NeuralAttentionlib.noncollapsed_size(B, Bi, Bj, transB == 'N' ? 2 : 1)
sc3 = NeuralAttentionlib.collapsed_size(A, Ai, Aj, 3) > NeuralAttentionlib.collapsed_size(B, Bi, Bj, 3) ?
NeuralAttentionlib.noncollapsed_size(A, Ai, Aj, 3) :
NeuralAttentionlib.noncollapsed_size(B, Bi, Bj, 3)

Ci = length(n)
Cj = length(sc3)
C = similar(B, (m..., n..., sc3...))
return NeuralAttentionlib.gemm_strided_batched!(
transA, transB, alpha, A, B, zero($elty), C, Ai, Aj, Bi, Bj, Ci, Cj)
end

function NeuralAttentionlib.gemm_strided_batched(
transA::Char, transB::Char,
A::$array{$elty, N1},
B::$array{$elty, N2},
Ai, Aj, Bi, Bj) where {N1, N2}
return NeuralAttentionlib.gemm_strided_batched(transA, transB, one($elty), A, B, Ai, Aj, Bi, Bj)
end

end
end

NeuralAttentionlib.check_strided_gemm_type(A::CuArray{Float16}) = true

Adapt.adapt(to::CUDA.KernelAdaptor, m::AbstractArrayMask) =
Indexer{typeof(m)}(map(Base.Fix1(Adapt.adapt, to), GetIndexer(m).__fields))
Adapt.adapt(to::CUDA.KernelAdaptor, m::NAlib.FlipMask) = Indexer{typeof(m)}((mask = adapt(to, m.mask),))
Adapt.adapt(to::CUDA.KernelAdaptor, m::NAlib.CombinedMask) =
Indexer{typeof(m)}((f = adapt(to, m.f), masks = map(Base.Fix1(adapt, to), m.masks)))
Adapt.adapt(to::CUDA.KernelAdaptor, m::NAlib.BatchedMask) =
Indexer{typeof(m)}((mask = adapt(to, m.mask), batch_dim = static(m.batch_dim)))
Adapt.adapt(to::CUDA.KernelAdaptor, m::NAlib.RepeatMask) = Indexer{typeof(m)}((mask = adapt(to, m.mask), num = m.num))
Adapt.adapt(to::CUDA.KernelAdaptor, m::NAlib.BiSequenceMask) =
Indexer{typeof(m)}((q_mask = adapt(to, m.q_mask), k_mask = adapt(to, m.k_mask)))

end
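As a readability aid for the CUBLAS wrappers above, here is a hedged pure-Julia sketch of the semantics that gemm_strided_batched! implements; the name gemm_strided_batched_reference is illustrative, not part of the package:

using LinearAlgebra

# Reference semantics (illustrative): for each batch index b,
#   C[:, :, b] = alpha * op(A)[:, :, ba] * op(B)[:, :, bb] + beta * C[:, :, b]
# where op is identity ('N'), transpose ('T'), or adjoint ('C'), and a
# singleton batch dimension is broadcast across the batch -- the stride-0
# trick used in the strideA/strideB computation above.
function gemm_strided_batched_reference(transA, transB, alpha, A, B, beta, C)
    op(X, t) = t == 'N' ? X : t == 'T' ? transpose(X) : adjoint(X)
    for b in axes(C, 3)
        ba = size(A, 3) == 1 ? 1 : b   # stride 0: reuse the single slice
        bb = size(B, 3) == 1 ? 1 : b
        C[:, :, b] = alpha * op(A[:, :, ba], transA) * op(B[:, :, bb], transB) + beta * C[:, :, b]
    end
    return C
end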
18 changes: 18 additions & 0 deletions ext/NeuralAttentionlibFiniteDifferences.jl
@@ -0,0 +1,18 @@
module NeuralAttentionlibFiniteDifferences

using NeuralAttentionlib
using NeuralAttentionlib: CollapsedDimsArray, collapseddims
using FiniteDifferences

function FiniteDifferences.to_vec(X::CollapsedDimsArray)
x_vec, back = to_vec(collapseddims(X))
s = size(parent(X))
ni = X.ni
nj = X.nj
function CollapsedDimsArray_from_vec(x_vec)
return CollapsedDimsArray(reshape(back(x_vec), s), ni, nj)
end
return x_vec, CollapsedDimsArray_from_vec
end

end
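A hedged usage sketch of the overload above: round-tripping a CollapsedDimsArray through to_vec, which is what FiniteDifferences needs for finite-difference gradient checks. The constructor call with plain integer ni/nj mirrors the CollapsedDimsArray(parent, ni, nj) reconstruction above and should be treated as illustrative:

using FiniteDifferences
using NeuralAttentionlib: CollapsedDimsArray

# Assumed constructor form: collapse a 4-d parent array into a 3-d view.
X = CollapsedDimsArray(randn(2, 3, 4, 5), 2, 1)
v, back = FiniteDifferences.to_vec(X)   # flatten to a plain Vector
X2 = back(v)                            # rebuild with the original parent shape
@assert size(parent(X2)) == size(parent(X))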
8 changes: 8 additions & 0 deletions ext/NeuralAttentionlibZygoteExt.jl
@@ -0,0 +1,8 @@
module NeuralAttentionlibZygoteExt

using NeuralAttentionlib
using Zygote

Zygote.unbroadcast(x::NeuralAttentionlib.AbstractMask, _) = nothing

end
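This replaces the Requires-based definition removed from src/mask/grad.jl below. The overload tells Zygote's broadcasting adjoint that masks carry no tangent: unbroadcast normally reduces the broadcast output gradient onto each argument's shape, and returning nothing marks the mask as non-differentiable, the same treatment Zygote already gives Bool arrays. A small illustration of that baseline behavior (plain Zygote, no package-specific API assumed):

using Zygote

# Bool-valued broadcast arguments get no gradient; the overload above
# extends the same behavior to NeuralAttentionlib's lazy mask types.
grads = Zygote.gradient((x, b) -> sum(x .* b), randn(3), trues(3))
@assert grads[2] === nothing   # the mask-like Bool argument has no tangent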
6 changes: 0 additions & 6 deletions src/NeuralAttentionlib.jl
@@ -2,16 +2,11 @@ module NeuralAttentionlib

using Static

-using CUDA
using Adapt
import Adapt: adapt_structure, adapt
import GPUArraysCore
using ChainRulesCore

using NNlib
-using NNlibCUDA
-
-using Requires

export multihead_qkv_attention, Functional, Masks

@@ -23,7 +18,6 @@ include("./matmul/gemm.jl")
include("./matmul/matmul.jl")
include("./matmul/grad.jl")
include("./matmul/scaled_matmul.jl")
-include("./matmul/gpu.jl")

# attention score masking
include("./mask/indexer.jl")
4 changes: 0 additions & 4 deletions src/mask/grad.jl
@@ -80,7 +80,3 @@ function ChainRulesCore.rrule(config::RuleConfig, pf::PrefixedFunction{typeof(ap
pullback(Ȳ) = (NoTangent(), mask_pullback(Ȳ)[4])
return y, pullback
end
-
-@init @require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin
-    Zygote.unbroadcast(x::AbstractMask, _) = nothing
-end
2 changes: 0 additions & 2 deletions src/mask/mask.jl
@@ -138,7 +138,5 @@ Base.@propagate_inbounds Base.getindex(m::M, I::Integer...) where {M <: Union{<:
Base.@propagate_inbounds Base.getindex(m::MaskIndexer, i::CartesianIndex) = m[Tuple(i)]
Base.@propagate_inbounds Base.getindex(m::MaskIndexer, I::Tuple) = m[I...]

-Adapt.adapt(to::CUDA.Adaptor, m::AbstractArrayMask) = Indexer{typeof(m)}(map(Base.Fix1(Adapt.adapt, to), GetIndexer(m).__fields))
-
randomness(::AbstractMask) = static(false)
require_dest(::AbstractMask) = static(false)