Metal support #31

Draft · wants to merge 5 commits into base: master
8 changes: 5 additions & 3 deletions Project.toml
@@ -18,13 +18,15 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
NeuralAttentionlibAMDGPUExt = ["AMDGPU", "GPUArrays"]
NeuralAttentionlibCUDAExt = ["CUDA", "GPUArrays"]
NeuralAttentionlibFiniteDifferencesExt = "FiniteDifferences"
NeuralAttentionlibGPUArraysExt = "GPUArrays"
NeuralAttentionlibMetalExt = ["Metal", "GPUArrays"]
NeuralAttentionlibZygoteExt = "Zygote"

[compat]
@@ -36,23 +38,23 @@ FiniteDifferences = "0.12"
FuncTransforms = "0.1"
GPUArrays = "10"
GPUArraysCore = "0.1"
Metal = "1.1"
NNlib = "0.9"
Static = "0.7, 0.8"
Zygote = "0.6"
julia = "1.10"

[extras]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Pickle = "fbb45041-c46e-462f-888f-7c521cafbc2c"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[targets]
test = ["Test", "Flux", "MacroTools", "ZygoteRules", "CUDA", "AMDGPU", "ChainRulesTestUtils", "Random", "Pickle", "ZipFile", "Statistics"]
test = ["Test", "Flux", "MacroTools", "ZygoteRules", "ChainRulesTestUtils", "Random", "Pickle", "Pkg", "ZipFile", "Statistics"]
ext/NeuralAttentionlibAMDGPUExt/NeuralAttentionlibAMDGPUExt.jl
@@ -5,10 +5,6 @@ using NeuralAttentionlib.Adapt
using NeuralAttentionlib: TypedPtr
using AMDGPU

import LinearAlgebra
import LinearAlgebra.BLAS
using LinearAlgebra.BLAS: get_num_threads, set_num_threads

const NAlib = NeuralAttentionlib

import AMDGPU.rocBLAS
@@ -39,7 +35,7 @@ for (fname, elty) in
end
end

NeuralAttentionlib.ptrtypetag(::AMDGPU.ROCArrayBackend) = ROCArray
NeuralAttentionlib.check_strided_gemm_type(A::ROCArray{Float16}) = true
NAlib.ptrtypetag(::AMDGPU.ROCArrayBackend) = ROCArray
NAlib.check_strided_gemm_type(A::ROCArray{Float16}) = true

end
8 changes: 2 additions & 6 deletions ext/NeuralAttentionlibCUDAExt/NeuralAttentionlibCUDAExt.jl
@@ -5,10 +5,6 @@ using NeuralAttentionlib.Adapt
using NeuralAttentionlib: TypedPtr
using CUDA

import LinearAlgebra
import LinearAlgebra.BLAS
using LinearAlgebra.BLAS: get_num_threads, set_num_threads

const NAlib = NeuralAttentionlib

import CUDA.CUBLAS
@@ -39,7 +35,7 @@ for (fname, elty) in
end
end

NeuralAttentionlib.ptrtypetag(::CUDA.CuArrayBackend) = CuArray
NeuralAttentionlib.check_strided_gemm_type(A::CuArray{Float16}) = true
NAlib.ptrtypetag(::CUDA.CuArrayBackend) = CuArray
NAlib.check_strided_gemm_type(A::CuArray{Float16}) = true

end
79 changes: 79 additions & 0 deletions ext/NeuralAttentionlibMetalExt/NeuralAttentionlibMetalExt.jl
@@ -0,0 +1,79 @@
module NeuralAttentionlibMetalExt

using NeuralAttentionlib
using NeuralAttentionlib.Adapt
using NeuralAttentionlib.NNlib
using Metal

const NAlib = NeuralAttentionlib

function mpsmatrix(arr::MtlArray{T}, lda, stride, batch) where T
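# Note: MPS matrices are row-major, so this descriptor views the column-major Julia
# buffer as transposed: the leading dimension `lda` becomes the row width (columns).
# A `stride` of 0 describes a single matrix (matrix_bytes = 0) whose data is reused
# for every entry in the batch.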
sz = sizeof(T)
N = length(arr)
n_cols_rows = iszero(stride) ? N : stride
n_cols = lda
n_rows = div(n_cols_rows, n_cols)
n_matrices = batch
row_bytes = sz * n_cols
matrix_bytes = iszero(stride) ? 0 : row_bytes * n_rows
desc = MPS.MPSMatrixDescriptor(n_rows, n_cols, n_matrices, row_bytes, matrix_bytes, T)
offset = arr.offset * sz
return MPS.MPSMatrix(arr, desc, offset), (n_cols, n_rows, n_matrices)
end

for elty in (:Float32, :Float16)
@eval begin
@inline function NAlib.gemm_strided_batched_impl!(
transA::Char, transB::Char,
m::Integer, n::Integer, k::Integer,
alpha::($elty), A::MtlArray{$elty, N1}, lda::Integer, strideA::Integer,
B::MtlArray{$elty, N2}, ldb::Integer, strideB::Integer, beta::($elty),
C::MtlArray{$elty, N3}, ldc::Integer, strideC::Integer, batchCount::Integer
) where {N1, N2, N3}

transpose_a = transA != 'N'
transpose_b = transB != 'N'

mps_a, shp_A = mpsmatrix(A, lda, strideA, batchCount)
mps_b, shp_B = mpsmatrix(B, ldb, strideB, batchCount)
mps_c, shp_C = mpsmatrix(C, ldc, strideC, batchCount)

cols_a = shp_A[transpose_a ? 1 : 2]
cols_c, rows_c = shp_C

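# The operands were described in row-major order above, so column-major A*B is
# computed as the row-major product Bᵀ·Aᵀ: hence the swapped transpose flags here
# and the swapped A/B order in encode! below.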
mat_mul_kernel = MPS.MPSMatrixMultiplication(device(), transpose_b, transpose_a, rows_c, cols_c, cols_a, alpha, beta)

cmdbuf = Metal.MTLCommandBuffer(Metal.global_queue(device()))
MPS.encode!(cmdbuf, mat_mul_kernel, mps_b, mps_a, mps_c)
Metal.commit!(cmdbuf)
return C
end
end
end

# TODO:
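# Until the MPS path is wired into NNlib's batched gemm directly, route batched_mul
# on MtlArrays to the generic fallback, applying the requested transposes lazily.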
function NNlib._batched_gemm!(::Type{<:MtlArray}, transA::Char, transB::Char, α::Number, A, B, β::Number, C)
a = if transA == 'T'
batched_transpose(A)
elseif transA == 'C'
batched_adjoint(A)
else
A
end
b = if transB == 'T'
batched_transpose(B)
elseif transB == 'C'
batched_adjoint(B)
else
B
end
return NNlib.batched_mul_generic!(C, a, b, α, β)
end

NAlib.use_gemm_strided_batched(A::MtlArray{ComplexF32}, B::MtlArray{ComplexF32}) = false
NAlib.use_gemm_strided_batched(A::NAlib.CollapsedDimsArray{ComplexF32, <:MtlMatrix{ComplexF32}}, B::NAlib.CollapsedDimsArray{ComplexF32, <:MtlMatrix{ComplexF32}}) = false
NAlib.ptrtypetag(::Metal.mtlArrayBackend) = MtlArray
NAlib.check_strided_gemm_type(A::MtlArray{Float16}) = true


end
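
For reference, a hypothetical smoke test of the extension; `matmul` treats trailing dimensions as the batch, and whether the MPS strided-batched kernel is actually taken depends on `use_gemm_strided_batched`:

    using NeuralAttentionlib, Metal
    A = MtlArray(rand(Float32, 4, 3, 8))   # batch of eight 4×3 matrices
    B = MtlArray(rand(Float32, 3, 5, 8))
    C = NeuralAttentionlib.matmul(A, B)    # 4×5×8, via the MPS kernel when eligible
    @assert Array(C) ≈ NeuralAttentionlib.matmul(Array(A), Array(B))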
6 changes: 3 additions & 3 deletions src/functional/position_embedding/sincos.jl
@@ -4,7 +4,7 @@ using ChainRulesCore
default_position_func(hidden_size) = Base.Fix1(default_position_func, static(Int32(hidden_size)))
@inline function default_position_func(hidden_size, i)
j = (0x1 - i) << 0x3 #8 * (1 - i)
return 1e1 ^ (j / hidden_size)
return 1f1 ^ (Float32(j) / hidden_size)
end

function sincos_position_embed(f, ::Val{hidden_size}, pos, indices, ::Val{normalized}) where {hidden_size, normalized}
@@ -17,9 +17,9 @@ function sincos_position_embed(f, ::Val{hidden_size}, pos, indices, ::Val{normal
half = hidden_size >> 0x1
if half << 0x1 != hidden_size
r = sin(_pos * f(half + 0x1))
return y * inv(sqrt(half + r))
return y * inv(sqrt(Float32(half + r)))
else
return y * inv(sqrt(half))
return y * inv(sqrt(Float32(half)))
end
else
return y
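Metal GPUs have no Float64 support, and the old literals promoted integer arguments to Float64; the new Float32 variants keep the computation in single precision. A quick illustration, with a hypothetical hidden_size of 64:

    j = Int32(-8)
    1e1 ^ (j / 64)            # promotes to Float64 — unsupported in Metal kernels
    1f1 ^ (Float32(j) / 64)   # stays Float32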
1 change: 1 addition & 0 deletions src/matmul/matmul.jl
@@ -101,6 +101,7 @@ check_strided_gemm_type(ca::CollapsedDimsArray) = check_strided_gemm_type(parent
check_strided_gemm_type(A::AbstractArray) = eltype(A) <: BLAS.BlasFloat
check_strided_gemm_type(A::AbstractArray{Float16}) = false

use_gemm_strided_batched(A::AbstractArray, B::CollapsedDimsArray) = use_gemm_strided_batched(unwrap_collapse(B), A)
function use_gemm_strided_batched(A::AbstractArray{TA}, B::AbstractArray{TB}) where {TA, TB}
if NNlib.is_strided(A) && NNlib.is_strided(B)
return check_strided_gemm_type(A) && check_strided_gemm_type(B)
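The added method covers the mixed case where only the second operand is a `CollapsedDimsArray`: it unwraps and swaps the arguments so the strided-type checks below see two plain arrays.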
11 changes: 8 additions & 3 deletions test/functional.jl
@@ -223,8 +223,8 @@
@testset "alibi position embedding" begin
x1 = dzeros(10, 10, 3, 2);
x2 = dzeros(10, 10, 3, 2);
x1 .= NeuralAttentionlib._build_alibi(nothing, CollapsedDimsArray(randn(10, 5, 2, 3, 2), 2, 2))
x2 .= NeuralAttentionlib._build_alibi(Masks.BatchedMask(Masks.GenericAttenMask(trues(10, 10))), CollapsedDimsArray(randn(10, 5, 2, 3, 2), 2, 2))
x1 .= NeuralAttentionlib._build_alibi(nothing, CollapsedDimsArray(drandn(10, 5, 2, 3, 2), 2, 2))
x2 .= NeuralAttentionlib._build_alibi(Masks.BatchedMask(Masks.GenericAttenMask(device(trues(10, 10)))), CollapsedDimsArray(drandn(10, 5, 2, 3, 2), 2, 2))
@test x1 ≈ x2
if !USE_GPU
@testset "AD" begin
@@ -313,7 +313,12 @@
end

@testset "l2norm" begin
naive_l2norm(x) = x ./ sqrt.(sum(x .^ 2; dims=1))
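# Hand-written rrule for the elementwise square: the pullback is 2x .* Ȳ, with the
# constant 2 converted to the gradient eltype so nothing promotes to Float64
# (which Metal cannot represent on device).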
square(x) = x .^ 2
function ChainRulesCore.rrule(::typeof(square), x)
square_back(Ȳ) = (NoTangent(), convert(eltype(Ȳ), 2) .* x .* Ȳ)
return square(x), square_back
end
naive_l2norm(x) = x ./ sqrt.(sum(square(x); dims=1))
x = drandn(512, 3, 2)
@test l2norm(x) ≈ naive_l2norm(x)
@test Zygote.gradient(x->sum(sin.(l2norm(x))), x)[1] ≈ Zygote.gradient(x->sum(sin.(naive_l2norm(x))), x)[1]
14 changes: 7 additions & 7 deletions test/mask.jl
@@ -3,9 +3,9 @@
using NeuralAttentionlib.Masks
using NeuralAttentionlib: getmask, lengths, AttenMask

causal(x) = batched_triu!(copy(x), 0) |> device
trilu(x, d) = batched_tril!(batched_triu!(copy(x), -d), d) |> device
bandpart(x, l, u) = batched_tril!(batched_triu!(copy(x), -l), u) |> device
causal(x) = batched_triu!(cpu(copy(x)), 0) |> device
trilu(x, d) = batched_tril!(batched_triu!(cpu(copy(x)), -d), d) |> device
bandpart(x, l, u) = batched_tril!(batched_triu!(cpu(copy(x)), -l), u) |> device

test_random(x, p) = isapprox(sum(x .* RandomMask(p)) / length(x), 1-p, atol=1e-1)

@@ -62,7 +62,7 @@
end
return y |> device
end
grow_rlength(len1, len2, n) = reverse!(reverse!(grow_length(len1, len2, n); dims=1); dims=2)
grow_rlength(len1, len2, n) = reverse!(reverse!(cpu(grow_length(len1, len2, n)); dims=1); dims=2) |> device

@testset "array mask" begin
a = dones(Int, 10, 10)
@@ -93,7 +93,7 @@
@test b .* BiLengthMask(bmaskq_b, bmaskk_b) == grow_length(bmaskk_b, bmaskq_b, 10)
@test c .* BiLengthMask(bmaskq_c, bmaskk_c) == grow_length(bmaskk_c, bmaskq_c, 10)

rev!(x) = reverse!(reverse!(x; dims=1); dims=2)
rev!(x) = reverse!(reverse!(cpu(x); dims=1); dims=2) |> device
@test a .* RevSymLengthMask(smask_a) == rev!(a .* SymLengthMask(smask_a))
@test b .* RevSymLengthMask(smask_b) == rev!(b .* SymLengthMask(smask_b))
@test c .* RevSymLengthMask(smask_c) == rev!(c .* SymLengthMask(smask_c))
@@ -206,8 +206,8 @@
end

@testset "sequence mask" begin
a0 = hcat([1, 1, 1, 1, 0], [1,1,1,0,0])
ra0 = hcat([0, 1, 1, 1, 1], [0,0,1,1,1])
a0 = hcat(Bool[1, 1, 1, 1, 0], Bool[1,1,1,0,0])
ra0 = hcat(Bool[0, 1, 1, 1, 1], Bool[0,0,1,1,1])
a = device(reshape(a0, (1, 5, 2)))
ra = device(reshape(ra0, (1, 5, 2)))
b = device([4,3])
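(Throughout these tests, the in-place helpers `batched_triu!`, `batched_tril!`, and dims-wise `reverse!` now mutate a CPU copy and move the result back to the device, presumably because their element-wise mutation is not supported on all GPU backends.)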
8 changes: 6 additions & 2 deletions test/matmul.jl
@@ -14,12 +14,16 @@
cy = y isa NeuralAttentionlib.Collapsed ? collapseddims(y) : y
@assert !(cx isa CollapsedDimsArray) "$(typeof(cx))"
@assert !(cy isa CollapsedDimsArray) "$(typeof(cy))"
return matmul(x, y, s) ≈ batched_mul(cx, cy) .* s
return matmul(x, y, s) ≈ device(batched_mul(cpu(cx), cpu(cy)) .* s)
end
uwcs(x) = size(unwrap_collapse(x))

if USE_GPU
eltype_list = (Float64, Float32, Float16, ComplexF64, ComplexF32)
if GPUBACKEND == :metal
eltype_list = (Float32, Float16, ComplexF32)
else
eltype_list = (Float64, Float32, Float16, ComplexF64, ComplexF32)
end
else
eltype_list = (Float64, Float32, ComplexF64, ComplexF32)
end
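(On the Metal backend the element-type list drops Float64 and ComplexF64, since Apple GPUs provide no double precision; the reference `batched_mul` is likewise computed on the CPU and the result moved back for comparison.)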
57 changes: 0 additions & 57 deletions test/old_impl/batched_gemm.jl

This file was deleted.

20 changes: 0 additions & 20 deletions test/old_impl/batched_gemm_gpu.jl

This file was deleted.
