use pkg ext and drop lts julia #26

Merged · 3 commits · Apr 20, 2024
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
@@ -15,7 +15,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.6'
- '1.10'
- '1'
os:
- ubuntu-latest
34 changes: 19 additions & 15 deletions Project.toml
@@ -1,33 +1,37 @@
name = "NeuralAttentionlib"
uuid = "12afc1b8-fad6-47e1-9132-84abc478905f"
authors = ["chengchingwen <[email protected]>"]
version = "0.2.12"
version = "0.2.13"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
NeuralAttentionlibCUDAExt = "CUDA"
NeuralAttentionlibFiniteDifferences = "FiniteDifferences"
NeuralAttentionlibZygoteExt = "Zygote"

[compat]
Adapt = "3.3"
CUDA = "3, 4"
Adapt = "3.3, 4"
CUDA = "5"
ChainRulesCore = "1.3"
GPUArrays = "8"
GPUArraysCore = "0.1"
NNlib = "0.7, 0.8"
NNlibCUDA = "0.2"
Requires = "1.1"
FiniteDifferences = "0.12"
NNlib = "0.9"
Static = "0.7, 0.8"
julia = "1.6"
Zygote = "0.6"
julia = "1.10"

[extras]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
@@ -39,4 +43,4 @@ ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[targets]
test = ["Test", "Flux", "MacroTools", "ZygoteRules", "ChainRulesTestUtils", "Random", "Pickle", "ZipFile", "Statistics"]
test = ["Test", "Flux", "MacroTools", "ZygoteRules", "CUDA", "ChainRulesTestUtils", "Random", "Pickle", "ZipFile", "Statistics"]
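For reference, a minimal sketch of how the new weakdeps behave from a user's perspective: the core package no longer pulls in CUDA, and the extension module only loads once the weak dependency is imported. All names below come from this diff; Base.get_extension requires Julia ≥ 1.9, which the new julia = "1.10" bound guarantees.

using NeuralAttentionlib   # loads only the core package and its hard deps
using CUDA                 # triggers loading of NeuralAttentionlibCUDAExt

# Confirm the extension module is active (returns nothing if not loaded):
ext = Base.get_extension(NeuralAttentionlib, :NeuralAttentionlibCUDAExt)
@assert ext !== nothing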
265 changes: 265 additions & 0 deletions ext/NeuralAttentionlibCUDAExt/NeuralAttentionlibCUDAExt.jl
@@ -0,0 +1,265 @@
module NeuralAttentionlibCUDAExt

using NeuralAttentionlib
using NeuralAttentionlib.Adapt
using NeuralAttentionlib: AbstractArrayMask, Indexer, GetIndexer
using CUDA
using CUDA.GPUArrays
using CUDA.GPUArrays.GPUArraysCore

import LinearAlgebra
import LinearAlgebra.BLAS
using LinearAlgebra.BLAS: get_num_threads, set_num_threads

const NAlib = NeuralAttentionlib

GPUArraysCore.backend(T::Type{<:NAlib.CollapsedDimsArray{E, <:CuArray}}) where E = GPUArraysCore.backend(CuArray{E, 3})

function NeuralAttentionlib.batched_transpose_f!(f, B::AnyGPUArray{T, 3}, A::AnyGPUArray{T, 3}) where T
axes(B,1) == axes(A,2) && axes(B,2) == axes(A,1) && axes(A,3) == axes(B,3) || throw(DimensionMismatch(string(f)))
GPUArrays.gpu_call(B, A) do ctx, B, A
idx = GPUArrays.@cartesianidx A
@inbounds B[idx[2], idx[1], idx[3]] = f(A[idx[1], idx[2], idx[3]])
return
end
return B
end

import CUDA.CUBLAS
for (fname, elty) in
((:cublasDgemmStridedBatched,:Float64),
(:cublasSgemmStridedBatched,:Float32),
(:cublasHgemmStridedBatched,:Float16),
(:cublasZgemmStridedBatched,:ComplexF64),
(:cublasCgemmStridedBatched,:ComplexF32))
@eval begin

@inline function NeuralAttentionlib.unsafe_gemm_strided_batched!(
transA::Char, transB::Char,
m::Int, n::Int, k::Int,
alpha::($elty), ptrA::CuPtr{$elty}, lda::Int, strideA::Int,
ptrB::CuPtr{$elty}, ldb::Int, strideB::Int, beta::($elty),
ptrC::CuPtr{$elty}, ldc::Int, strideC::Int, batchCount::Int)

CUBLAS.$fname(CUBLAS.handle(),
transA, transB, m, n, k,
alpha, ptrA, lda, strideA,
ptrB, ldb, strideB, beta,
ptrC, ldc, strideC, batchCount)
return nothing
end

end
end

for (elty, array) in (
(:Float16, :CuArray),
)
@eval begin
@inline function NeuralAttentionlib.unsafe_gemm_strided_batched!(
transA::Char, transB::Char,
m::Int, n::Int, k::Int,
alpha::($elty), ptrA::Ptr{$elty}, lda::Int, strideA::Int,
ptrB::Ptr{$elty}, ldb::Int, strideB::Int, beta::($elty),
ptrC::Ptr{$elty}, ldc::Int, strideC::Int, batchCount::Int)

# https://github.com/FluxML/NNlib.jl/blob/cd3851d31e95020e77e67f80fb6402b5b87db1e6/src/gemm.jl#L91-L139
n_threads = min(Threads.nthreads(), 1 + max(m * k * batchCount, n * k * batchCount) ÷ 8000)
if n_threads > 1
old_threads = get_num_threads()
set_num_threads(1)
Threads.@sync for bs in Iterators.partition(1:batchCount, cld(batchCount, n_threads))
Threads.@spawn for b in bs
ptrAi = ptrA + (b - 1) * strideA * sizeof($elty)
ptrBi = ptrB + (b - 1) * strideB * sizeof($elty)
ptrCi = ptrC + (b - 1) * strideC * sizeof($elty)

NeuralAttentionlib.unsafe_gemm!(transA, transB, m, n, k,
alpha, ptrAi, lda,
ptrBi, ldb, beta,
ptrCi, ldc)
end
end
set_num_threads(old_threads)
else
for i = 1:batchCount
ptrAi = ptrA + (i - 1) * strideA * sizeof($elty)
ptrBi = ptrB + (i - 1) * strideB * sizeof($elty)
ptrCi = ptrC + (i - 1) * strideC * sizeof($elty)
NeuralAttentionlib.unsafe_gemm!(transA, transB, m, n, k,
alpha, ptrAi, lda,
ptrBi, ldb, beta,
ptrCi, ldc)
end
end
return nothing
end

@inline function NeuralAttentionlib.gemm_strided_batched_impl!(
transA::Char, transB::Char,
m::Int, n::Int, k::Int,
alpha::($elty), A::$array{$elty}, lda::Int, strideA::Int,
B::$array{$elty}, ldb::Int, strideB::Int, beta::($elty),
C::$array{$elty}, ldc::Int, strideC::Int, batchCount::Int)

ptrA = pointer(A)
ptrB = pointer(B)
ptrC = pointer(C)
GC.@preserve A B C begin
NeuralAttentionlib.unsafe_gemm_strided_batched!(
transA, transB, m, n, k,
alpha, ptrA, lda, strideA,
ptrB, ldb, strideB, beta,
ptrC, ldc, strideC, batchCount)
end
return C
end

@inline function NeuralAttentionlib.gemm_strided_batched!(
transA::Char, transB::Char,
alpha::($elty), A::$array{$elty, 3},
B::$array{$elty, 3}, beta::($elty),
C::$array{$elty, 3})

Base.require_one_based_indexing(A, B, C)
BLAS.chkstride1(A, B, C)
@assert size(A, 3) == size(C, 3) || size(A, 3) == 1 "batch size mismatch: A != C"
@assert size(B, 3) == size(C, 3) || size(B, 3) == 1 "batch size mismatch: B != C"

m = size(A, transA == 'N' ? 1 : 2)
ka = size(A, transA == 'N' ? 2 : 1)
kb = size(B, transB == 'N' ? 1 : 2)
n = size(B, transB == 'N' ? 2 : 1)

if m != size(C,1) || n != size(C,2) || ka != kb
throw(DimensionMismatch("A has size ($m,$ka,$(size(A, 3))), B has size ($kb,$n,$(size(B, 3))), C has size $(size(C))"))
end

lda = max(1, stride(A,2))
ldb = max(1, stride(B,2))
ldc = max(1, stride(C,2))

strideA = size(A, 3) == 1 ? 0 : stride(A, 3)
strideB = size(B, 3) == 1 ? 0 : stride(B, 3)
strideC = stride(C, 3)
batchCount = size(C, 3)

NeuralAttentionlib.gemm_strided_batched_impl!(
transA, transB, m, n, ka,
alpha, A, lda, strideA,
B, ldb, strideB, beta,
C, ldc, strideC, batchCount)

return C
end

function NeuralAttentionlib.gemm_strided_batched(
transA::Char, transB::Char,
alpha::($elty), A::$array{$elty, 3},
B::$array{$elty, 3})
C = similar(B, (size(A, transA == 'N' ? 1 : 2), size(B, transB == 'N' ? 2 : 1), max(size(A, 3), size(B, 3))))
return NeuralAttentionlib.gemm_strided_batched!(transA, transB, alpha, A, B, zero($elty), C)
end

function NeuralAttentionlib.gemm_strided_batched(
transA::Char, transB::Char,
A::$array{$elty, 3},
B::$array{$elty, 3})
return NeuralAttentionlib.gemm_strided_batched(transA, transB, one($elty), A, B)
end

function NeuralAttentionlib.gemm_strided_batched!(
transA::Char, transB::Char,
alpha::($elty), A::$array{$elty, N1},
B::$array{$elty, N2}, beta::($elty),
C::$array{$elty, N3},
Ai, Aj, Bi, Bj, Ci, Cj) where {N1, N2, N3}

# (a1, a2, ..., ai-1, ai, ai+1, ..., aj-1, aj, ..., an)
# |______lda______| |____K/M (Ai)_____| |___Aj____|
# (b1, b2, ..., bi-1, bi, bi+1, ..., bj-1, bj, ..., bn)
# |______ldb______| |____K/N (Bi)_____| |___Bj____|
# (c1, c2, ..., ci-1, ci, ci+1, ..., cj-1, cj, ..., cn)
# |______ldc______| |____K/N (Ci)_____| |___Cj____|

Base.require_one_based_indexing(A, B, C)
BLAS.chkstride1(A, B, C)

sa1, sa2, sa3 = NeuralAttentionlib.collapsed_size(A, Ai, Aj)
sb1, sb2, sb3 = NeuralAttentionlib.collapsed_size(B, Bi, Bj)
sc1, sc2, sc3 = NeuralAttentionlib.collapsed_size(C, Ci, Cj)

@assert sa3 == sc3 || sa3 == 1 "batch size mismatch: A != C"
@assert sb3 == sc3 || sb3 == 1 "batch size mismatch: B != C"

m = transA == 'N' ? sa1 : sa2
ka = transA == 'N' ? sa2 : sa1
kb = transB == 'N' ? sb1 : sb2
n = transB == 'N' ? sb2 : sb1

if m != sc1 || n != sc2 || ka != kb
throw(DimensionMismatch("A has size ($m,$ka,$sa3), B has size ($kb,$n,$sb3), C has size ($sc1, $sc2, $sc3)"))
end

lda = max(1, stride(A, N1 - Ai - Aj + 1))
ldb = max(1, stride(B, N2 - Bi - Bj + 1))
ldc = max(1, stride(C, N3 - Ci - Cj + 1))

strideA = sa3 == 1 ? 0 : stride(A, N1 - Aj + 1)
strideB = sb3 == 1 ? 0 : stride(B, N2 - Bj + 1)
strideC = stride(C, N3 - Cj + 1)
batchCount = sc3

NeuralAttentionlib.gemm_strided_batched_impl!(
transA, transB, m, n, ka,
alpha, A, lda, strideA,
B, ldb, strideB, beta,
C, ldc, strideC, batchCount)

return C
end

function NeuralAttentionlib.gemm_strided_batched(
transA::Char, transB::Char,
alpha::($elty), A::$array{$elty, N1},
B::$array{$elty, N2},
Ai, Aj, Bi, Bj) where {N1, N2}

m = NeuralAttentionlib.noncollapsed_size(A, Ai, Aj, transA == 'N' ? 1 : 2)
n = NeuralAttentionlib.noncollapsed_size(B, Bi, Bj, transB == 'N' ? 2 : 1)
sc3 = NeuralAttentionlib.collapsed_size(A, Ai, Aj, 3) > NeuralAttentionlib.collapsed_size(B, Bi, Bj, 3) ?
NeuralAttentionlib.noncollapsed_size(A, Ai, Aj, 3) :
NeuralAttentionlib.noncollapsed_size(B, Bi, Bj, 3)

Ci = length(n)
Cj = length(sc3)
C = similar(B, (m..., n..., sc3...))
return NeuralAttentionlib.gemm_strided_batched!(
transA, transB, alpha, A, B, zero($elty), C, Ai, Aj, Bi, Bj, Ci, Cj)
end

function NeuralAttentionlib.gemm_strided_batched(
transA::Char, transB::Char,
A::$array{$elty, N1},
B::$array{$elty, N2},
Ai, Aj, Bi, Bj) where {N1, N2}
return NeuralAttentionlib.gemm_strided_batched(transA, transB, one($elty), A, B, Ai, Aj, Bi, Bj)
end

end
end

NeuralAttentionlib.check_strided_gemm_type(A::CuArray{Float16}) = true

Adapt.adapt(to::CUDA.KernelAdaptor, m::AbstractArrayMask) =
Indexer{typeof(m)}(map(Base.Fix1(Adapt.adapt, to), GetIndexer(m).__fields))
Adapt.adapt(to::CUDA.KernelAdaptor, m::NAlib.FlipMask) = Indexer{typeof(m)}((mask = adapt(to, m.mask),))
Adapt.adapt(to::CUDA.KernelAdaptor, m::NAlib.CombinedMask) =
Indexer{typeof(m)}((f = adapt(to, m.f), masks = map(Base.Fix1(adapt, to), m.masks)))
Adapt.adapt(to::CUDA.KernelAdaptor, m::NAlib.BatchedMask) =
Indexer{typeof(m)}((mask = adapt(to, m.mask), batch_dim = static(m.batch_dim)))
Adapt.adapt(to::CUDA.KernelAdaptor, m::NAlib.RepeatMask) = Indexer{typeof(m)}((mask = adapt(to, m.mask), num = m.num))
Adapt.adapt(to::CUDA.KernelAdaptor, m::NAlib.BiSequenceMask) =
Indexer{typeof(m)}((q_mask = adapt(to, m.q_mask), k_mask = adapt(to, m.k_mask)))

end
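As a usage note, not part of the diff: the methods above are generated only for the element types in the @eval loops, and Float16 is the case that the check_strided_gemm_type(::CuArray{Float16}) method enables. A minimal smoke test of that path, assuming a CUDA-capable device, might look like:

using CUDA, NeuralAttentionlib

A = CUDA.rand(Float16, 4, 3, 8)  # batch of eight 4×3 matrices
B = CUDA.rand(Float16, 3, 5, 8)  # batch of eight 3×5 matrices
C = NeuralAttentionlib.gemm_strided_batched('N', 'N', A, B)
@assert size(C) == (4, 5, 8)     # ends up in cublasHgemmStridedBatched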
@@ -0,0 +1,18 @@
module NeuralAttentionlibFiniteDifferences

using NeuralAttentionlib
using NeuralAttentionlib: CollapsedDimsArray, collapseddims
using FiniteDifferences

function FiniteDifferences.to_vec(X::CollapsedDimsArray)
x_vec, back = to_vec(collapseddims(X))
s = size(parent(X))
ni = X.ni
nj = X.nj
function CollapsedDimsArray_from_vec(x_vec)
return CollapsedDimsArray(reshape(back(x_vec), s), ni, nj)
end
return x_vec, CollapsedDimsArray_from_vec
end

end
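A hedged sketch of what this to_vec overload enables: flattening a CollapsedDimsArray so FiniteDifferences can gradient-check functions of it. The CollapsedDimsArray(parent, ni, nj) constructor call and its collapsing semantics (the trailing nj dims fold into the batch dimension, the ni dims before them into the length dimension) are my assumption from the package's documentation, not part of this diff:

using FiniteDifferences
using NeuralAttentionlib: CollapsedDimsArray, collapseddims

x = CollapsedDimsArray(randn(2, 3, 4, 5), 2, 1)  # assumed: collapses to a 2×12×5 view
v, from_vec = FiniteDifferences.to_vec(x)        # flatten; ni/nj are kept to rebuild
y = from_vec(v)
@assert size(parent(y)) == (2, 3, 4, 5)          # round-trips the parent shape
@assert collapseddims(y) ≈ collapseddims(x)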
@@ -0,0 +1,8 @@
module NeuralAttentionlibZygoteExt

using NeuralAttentionlib
using Zygote

Zygote.unbroadcast(x::NeuralAttentionlib.AbstractMask, _) = nothing

end
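A gloss on the one-liner above: masks are index-computed boolean structures with no trainable state, and during a broadcast pullback Zygote calls unbroadcast on each broadcast argument to accumulate its gradient, so returning nothing declares masks gradient-free. A minimal check, with CausalMask standing in for any AbstractMask:

using Zygote, NeuralAttentionlib
using NeuralAttentionlib.Masks: CausalMask

# The extension method short-circuits gradient accumulation for masks:
@assert Zygote.unbroadcast(CausalMask(), ones(3, 3)) === nothing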
6 changes: 0 additions & 6 deletions src/NeuralAttentionlib.jl
@@ -2,16 +2,11 @@ module NeuralAttentionlib

using Static

using CUDA
using Adapt
import Adapt: adapt_structure, adapt
import GPUArraysCore
using ChainRulesCore

using NNlib
using NNlibCUDA

using Requires

export multihead_qkv_attention, Functional, Masks

@@ -23,7 +18,6 @@ include("./matmul/gemm.jl")
include("./matmul/matmul.jl")
include("./matmul/grad.jl")
include("./matmul/scaled_matmul.jl")
include("./matmul/gpu.jl")

# attention score masking
include("./mask/indexer.jl")
4 changes: 0 additions & 4 deletions src/mask/grad.jl
@@ -80,7 +80,3 @@ function ChainRulesCore.rrule(config::RuleConfig, pf::PrefixedFunction{typeof(ap
pullback(Ȳ) = (NoTangent(), mask_pullback(Ȳ)[4])
return y, pullback
end

@init @require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin
Zygote.unbroadcast(x::AbstractMask, _) = nothing
end