From 3e28c4d088a6ab061c5cbda17047e71c0e5ac738 Mon Sep 17 00:00:00 2001
From: chengchingwen
Date: Sun, 30 Jun 2024 02:23:13 +0800
Subject: [PATCH 1/5] support metal matmul

---
 .../NeuralAttentionlibAMDGPUExt.jl |  8 +-
 .../NeuralAttentionlibCUDAExt.jl   |  8 +-
 .../NeuralAttentionlibMetalExt.jl  | 90 +++++++++++++++++++
 src/matmul/matmul.jl               |  1 +
 test/matmul.jl                     |  8 +-
 test/old_impl/batched_gemm.jl      | 57 ------------
 test/old_impl/batched_gemm_gpu.jl  | 20 -----
 test/old_impl/batchedmul.jl        |  6 +-
 test/old_impl/old_impl.jl          |  1 -
 test/runtests.jl                   | 60 +++++++------
 10 files changed, 138 insertions(+), 121 deletions(-)
 create mode 100644 ext/NeuralAttentionlibMetalExt/NeuralAttentionlibMetalExt.jl
 delete mode 100644 test/old_impl/batched_gemm.jl
 delete mode 100644 test/old_impl/batched_gemm_gpu.jl

diff --git a/ext/NeuralAttentionlibAMDGPUExt/NeuralAttentionlibAMDGPUExt.jl b/ext/NeuralAttentionlibAMDGPUExt/NeuralAttentionlibAMDGPUExt.jl
index 0b1aab3..469f493 100644
--- a/ext/NeuralAttentionlibAMDGPUExt/NeuralAttentionlibAMDGPUExt.jl
+++ b/ext/NeuralAttentionlibAMDGPUExt/NeuralAttentionlibAMDGPUExt.jl
@@ -5,10 +5,6 @@ using NeuralAttentionlib.Adapt
 using NeuralAttentionlib: TypedPtr
 using AMDGPU
 
-import LinearAlgebra
-import LinearAlgebra.BLAS
-using LinearAlgebra.BLAS: get_num_threads, set_num_threads
-
 const NAlib = NeuralAttentionlib
 
 import AMDGPU.rocBLAS
@@ -39,7 +35,7 @@ for (fname, elty) in
     end
 end
 
-NeuralAttentionlib.ptrtypetag(::AMDGPU.ROCArrayBackend) = ROCArray
-NeuralAttentionlib.check_strided_gemm_type(A::ROCArray{Float16}) = true
+NAlib.ptrtypetag(::AMDGPU.ROCArrayBackend) = ROCArray
+NAlib.check_strided_gemm_type(A::ROCArray{Float16}) = true
 
 end
diff --git a/ext/NeuralAttentionlibCUDAExt/NeuralAttentionlibCUDAExt.jl b/ext/NeuralAttentionlibCUDAExt/NeuralAttentionlibCUDAExt.jl
index a6dc80b..0427caa 100644
--- a/ext/NeuralAttentionlibCUDAExt/NeuralAttentionlibCUDAExt.jl
+++ b/ext/NeuralAttentionlibCUDAExt/NeuralAttentionlibCUDAExt.jl
@@ -5,10 +5,6 @@ using NeuralAttentionlib.Adapt
 using NeuralAttentionlib: TypedPtr
 using CUDA
 
-import LinearAlgebra
-import LinearAlgebra.BLAS
-using LinearAlgebra.BLAS: get_num_threads, set_num_threads
-
 const NAlib = NeuralAttentionlib
 
 import CUDA.CUBLAS
@@ -39,7 +35,7 @@ for (fname, elty) in
     end
 end
 
-NeuralAttentionlib.ptrtypetag(::CUDA.CuArrayBackend) = CuArray
-NeuralAttentionlib.check_strided_gemm_type(A::CuArray{Float16}) = true
+NAlib.ptrtypetag(::CUDA.CuArrayBackend) = CuArray
+NAlib.check_strided_gemm_type(A::CuArray{Float16}) = true
 
 end
diff --git a/ext/NeuralAttentionlibMetalExt/NeuralAttentionlibMetalExt.jl b/ext/NeuralAttentionlibMetalExt/NeuralAttentionlibMetalExt.jl
new file mode 100644
index 0000000..f94d80f
--- /dev/null
+++ b/ext/NeuralAttentionlibMetalExt/NeuralAttentionlibMetalExt.jl
@@ -0,0 +1,90 @@
+module NeuralAttentionlibMetalExt
+
+using NeuralAttentionlib
+using NeuralAttentionlib.Adapt
+using NeuralAttentionlib.NNlib
+using Metal
+using Metal.MPS: MPSMatrix, MTLBuffer, NSUInteger, MPSMatrixDescriptor, id
+
+const NAlib = NeuralAttentionlib
+
+function _mpsmatrix(arr, desc, offset)
+    mat = MPS.@objc [MPSMatrix alloc]::id{MPSMatrix}
+    obj = MPS.MPSMatrix(mat)
+    finalizer(MPS.release, obj)
+    MPS.@objc [obj::MPS.id{MPSMatrix} initWithBuffer:arr::id{MTLBuffer}
+               offset:offset::NSUInteger
+               descriptor:desc::id{MPSMatrixDescriptor}]::id{MPSMatrix}
+    return obj
+end
+function mpsmatrix(arr::MtlArray{T}, lda, stride, batch) where T
+    sz = sizeof(T)
+    N = length(arr)
+    n_cols_rows = iszero(stride) ? N : stride
+    n_cols = lda
+    n_rows = div(n_cols_rows, n_cols)
+    n_matrices = batch
+    row_bytes = sz * n_cols
+    matrix_bytes = iszero(stride) ? 0 : row_bytes * n_rows
+    desc = MPS.MPSMatrixDescriptor(n_rows, n_cols, n_matrices, row_bytes, matrix_bytes, T)
+    offset = arr.offset * sz
+    return _mpsmatrix(arr, desc, offset), (n_cols, n_rows, n_matrices)
+end
+
+for elty in (:Float32, :Float16)
+    @eval begin
+        @inline function NAlib.gemm_strided_batched_impl!(
+            transA::Char, transB::Char,
+            m::Integer, n::Integer, k::Integer,
+            alpha::($elty), A::MtlArray{$elty, N1}, lda::Integer, strideA::Integer,
+            B::MtlArray{$elty, N2}, ldb::Integer, strideB::Integer, beta::($elty),
+            C::MtlArray{$elty, N3}, ldc::Integer, strideC::Integer, batchCount::Integer
+        ) where {N1, N2, N3}
+
+            transpose_a = transA != 'N'
+            transpose_b = transB != 'N'
+
+            mps_a, shp_A = mpsmatrix(A, lda, strideA, batchCount)
+            mps_b, shp_B = mpsmatrix(B, ldb, strideB, batchCount)
+            mps_c, shp_C = mpsmatrix(C, ldc, strideC, batchCount)
+
+            cols_a = shp_A[transpose_a ? 1 : 2]
+            cols_c, rows_c = shp_C
+
+            mat_mul_kernel = MPS.MPSMatrixMultiplication(Metal.current_device(),
+                transpose_b, transpose_a, rows_c, cols_c, cols_a, alpha, beta)
+
+            cmdbuf = Metal.MTLCommandBuffer(Metal.global_queue(Metal.current_device()))
+            MPS.encode!(cmdbuf, mat_mul_kernel, mps_b, mps_a, mps_c)
+            Metal.commit!(cmdbuf)
+            return C
+        end
+    end
+end
+
+# TODO: use a Metal-native batched gemm; for now fall back to NNlib's generic batched_mul.
+function NNlib._batched_gemm!(::Type{<:MtlArray}, transA::Char, transB::Char, α::Number, A, B, β::Number, C)
+    a = if transA == 'T'
+        batched_transpose(A)
+    elseif transA == 'C'
+        batched_adjoint(A)
+    else
+        A
+    end
+    b = if transB == 'T'
+        batched_transpose(B)
+    elseif transB == 'C'
+        batched_adjoint(B)
+    else
+        B
+    end
+    return NNlib.batched_mul_generic!(C, a, b, α, β)
+end
+
+NAlib.use_gemm_strided_batched(A::MtlArray{ComplexF32}, B::MtlArray{ComplexF32}) = false
+NAlib.use_gemm_strided_batched(A::NAlib.CollapsedDimsArray{ComplexF32, <:MtlMatrix{ComplexF32}}, B::NAlib.CollapsedDimsArray{ComplexF32, <:MtlMatrix{ComplexF32}}) = false
+NAlib.ptrtypetag(::Metal.mtlArrayBackend) = MtlArray
+NAlib.check_strided_gemm_type(A::MtlArray{Float16}) = true
+
+
+end
diff --git a/src/matmul/matmul.jl b/src/matmul/matmul.jl
index fb04694..2ec6c39 100644
--- a/src/matmul/matmul.jl
+++ b/src/matmul/matmul.jl
@@ -101,6 +101,7 @@ check_strided_gemm_type(ca::CollapsedDimsArray) = check_strided_gemm_type(parent
 check_strided_gemm_type(A::AbstractArray) = eltype(A) <: BLAS.BlasFloat
 check_strided_gemm_type(A::AbstractArray{Float16}) = false
 
+use_gemm_strided_batched(A::AbstractArray, B::CollapsedDimsArray) = use_gemm_strided_batched(unwrap_collapse(B), A)
 function use_gemm_strided_batched(A::AbstractArray{TA}, B::AbstractArray{TB}) where {TA, TB}
     if NNlib.is_strided(A) && NNlib.is_strided(B)
         return check_strided_gemm_type(A) && check_strided_gemm_type(B)
diff --git a/test/matmul.jl b/test/matmul.jl
index 283cd3e..fca2f01 100644
--- a/test/matmul.jl
+++ b/test/matmul.jl
@@ -14,12 +14,16 @@
         cy = y isa NeuralAttentionlib.Collapsed ? collapseddims(y) : y
         @assert !(cx isa CollapsedDimsArray) "$(typeof(cx))"
         @assert !(cy isa CollapsedDimsArray) "$(typeof(cy))"
-        return matmul(x, y, s) ≈ batched_mul(cx, cy) .* s
+        return matmul(x, y, s) ≈ device(batched_mul(cpu(cx), cpu(cy)) .* s)
     end
     uwcs(x) = size(unwrap_collapse(x))
 
     if USE_GPU
-        eltype_list = (Float64, Float32, Float16, ComplexF64, ComplexF32)
+        if GPUBACKEND == :metal
+            eltype_list = (Float32, Float16, ComplexF32)
+        else
+            eltype_list = (Float64, Float32, Float16, ComplexF64, ComplexF32)
+        end
     else
         eltype_list = (Float64, Float32, ComplexF64, ComplexF32)
     end
diff --git a/test/old_impl/batched_gemm.jl b/test/old_impl/batched_gemm.jl
deleted file mode 100644
index 903c5f6..0000000
--- a/test/old_impl/batched_gemm.jl
+++ /dev/null
@@ -1,57 +0,0 @@
-#=
-    borrow from https://github.com/Roger-luo/BatchedRoutines.jl
-=#
-
-import LinearAlgebra
-
-#batched cpu gemm by BatchedRoutines.jl
-for (gemm, elty) in
-        ((:dgemm_,:Float64),
-         (:sgemm_,:Float32),)
-    @eval begin
-        function batched_gemm!(transA::AbstractChar,
-                               transB::AbstractChar,
-                               alpha::($elty),
-                               A::AbstractArray{$elty, 3},
-                               B::AbstractArray{$elty, 3},
-                               beta::($elty),
-                               C::AbstractArray{$elty, 3})
-            @assert !Base.has_offset_axes(A, B, C)
-            @assert size(A, 3) == size(B, 3) == size(C, 3) "batch size mismatch"
-            m = size(A, transA == 'N' ? 1 : 2)
-            ka = size(A, transA == 'N' ? 2 : 1)
-            kb = size(B, transB == 'N' ? 1 : 2)
-            n = size(B, transB == 'N' ? 2 : 1)
-            if ka != kb || m != size(C,1) || n != size(C,2)
-                throw(DimensionMismatch("A has size ($m,$ka), B has size ($kb,$n), C has size $(size(C))"))
-            end
-            LinearAlgebra.BLAS.chkstride1(A)
-            LinearAlgebra.BLAS.chkstride1(B)
-            LinearAlgebra.BLAS.chkstride1(C)
-
-            ptrA = Base.unsafe_convert(Ptr{$elty}, A)
-            ptrB = Base.unsafe_convert(Ptr{$elty}, B)
-            ptrC = Base.unsafe_convert(Ptr{$elty}, C)
-
-            for k in 1:size(A, 3)
-                ccall((LinearAlgebra.BLAS.@blasfunc($gemm), LinearAlgebra.BLAS.libblas), Cvoid,
-                      (Ref{UInt8}, Ref{UInt8}, Ref{LinearAlgebra.BLAS.BlasInt}, Ref{LinearAlgebra.BLAS.BlasInt},
-                       Ref{LinearAlgebra.BLAS.BlasInt}, Ref{$elty}, Ptr{$elty}, Ref{LinearAlgebra.BLAS.BlasInt},
-                       Ptr{$elty}, Ref{LinearAlgebra.BLAS.BlasInt}, Ref{$elty}, Ptr{$elty},
-                       Ref{LinearAlgebra.BLAS.BlasInt}),
-                      transA, transB, m, n,
-                      ka, alpha, ptrA, max(1,stride(A,2)),
-                      ptrB, max(1,stride(B,2)), beta, ptrC,
-                      max(1,stride(C,2)))
-
-                ptrA += size(A, 1) * size(A, 2) * sizeof($elty)
-                ptrB += size(B, 1) * size(B, 2) * sizeof($elty)
-                ptrC += size(C, 1) * size(C, 2) * sizeof($elty)
-            end
-
-            C
-        end
-    end
-end
-
-
diff --git a/test/old_impl/batched_gemm_gpu.jl b/test/old_impl/batched_gemm_gpu.jl
deleted file mode 100644
index 6da7371..0000000
--- a/test/old_impl/batched_gemm_gpu.jl
+++ /dev/null
@@ -1,20 +0,0 @@
-#batched CuArray gemm by BatchedRoutines.jl
-function batched_gemm!(transA::AbstractChar,
-                       transB::AbstractChar,
-                       alpha::Float32,
-                       A::CUDA.CuArray{Float32, 3},
-                       B::CUDA.CuArray{Float32, 3},
-                       beta::Float32,
-                       C::CUDA.CuArray{Float32, 3})
-    CUDA.CUBLAS.gemm_strided_batched!(transA, transB, alpha, A, B, beta, C)
-end
-
-function batched_gemm!(transA::AbstractChar,
-                       transB::AbstractChar,
-                       alpha::Float64,
-                       A::CUDA.CuArray{Float64, 3},
-                       B::CUDA.CuArray{Float64, 3},
-                       beta::Float64,
-                       C::CUDA.CuArray{Float64, 3})
-    CUDA.CUBLAS.gemm_strided_batched!(transA, transB, alpha, A, B, beta, C)
-end
diff --git a/test/old_impl/batchedmul.jl b/test/old_impl/batchedmul.jl
index a95d2aa..5b5fd73 100644
--- a/test/old_impl/batchedmul.jl
+++ b/test/old_impl/batchedmul.jl
@@ -1,6 +1,4 @@
-include("./batched_gemm.jl")
-include("./batched_gemm_gpu.jl")
-
+using NeuralAttentionlib
 using ZygoteRules: @adjoint
 
 function batchedmul(a::Abstract3DTensor{T}, b::Abstract3DTensor{T};
@@ -27,7 +25,7 @@ function batched_mul!(C::Abstract3DTensor{T}, A::Abstract3DTensor{T}, B::Abstrac
                       transA::Bool = false, transB::Bool = false) where T
     At = transA ? 'T' : 'N'
     Bt = transB ? 'T' : 'N'
-    batched_gemm!(At, Bt, one(T), A, B, zero(T), C)
+    NeuralAttentionlib.gemm_strided_batched!(At, Bt, one(T), A, B, zero(T), C)
     C
 end
diff --git a/test/old_impl/old_impl.jl b/test/old_impl/old_impl.jl
index ce6fe5c..332e55b 100644
--- a/test/old_impl/old_impl.jl
+++ b/test/old_impl/old_impl.jl
@@ -1,7 +1,6 @@
 module Old_Impl
 
 import NeuralAttentionlib
-using CUDA
 
 export MultiheadAttention
 
diff --git a/test/runtests.jl b/test/runtests.jl
index b89f06d..23592a9 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,29 +1,5 @@
-using Test
-using NeuralAttentionlib
-
-using Random
+using Test, Pkg
 using Flux
-using CUDA
-using AMDGPU
-using NNlib
-using Static
-using ChainRulesCore
-using ChainRulesTestUtils
-
-const tests = [
-    "collapseddims",
-    "matmul",
-    "mask",
-    "functional",
-    "mha",
-]
-
-Random.seed!(0)
-
-include("old_impl/old_impl.jl")
-using .Old_Impl
-using .Old_Impl: batched_triu!, batched_tril!
-
 function testing_gpu()
     e = get(ENV, "JL_PKG_TEST_GPU", nothing)
     isnothing(e) && return nothing
@@ -35,6 +11,8 @@ function testing_gpu()
             return :cuda
         elseif x == "amdgpu"
            return :amdgpu
+        elseif x == "metal"
+            return :metal
        end
    end
    error("Unknown value for `JL_PKG_TEST_GPU`: $x")
@@ -46,16 +24,48 @@
 if isnothing(GPUBACKEND)
     const USE_GPU = false
 else
     const USE_GPU = true
     if GPUBACKEND == :cuda
+        Pkg.add(["CUDA"])
         using CUDA
         CUDA.allowscalar(false)
+        Flux.gpu_backend!("CUDA")
     elseif GPUBACKEND == :amdgpu
+        Pkg.add(["AMDGPU"])
         using AMDGPU
         AMDGPU.allowscalar(false)
+        Flux.gpu_backend!("AMDGPU")
+    elseif GPUBACKEND == :metal
+        Pkg.add(["Metal"])
+        using Metal
+        Metal.allowscalar(false)
+        Flux.gpu_backend!("Metal")
     end
+
 end
 
 @show GPUBACKEND
 @show USE_GPU
 
+using NeuralAttentionlib
+
+using Random
+using NNlib
+using Static
+using ChainRulesCore
+using ChainRulesTestUtils
+
+const tests = [
+    "collapseddims",
+    "matmul",
+    "mask",
+    "functional",
+    "mha",
+]
+
+Random.seed!(0)
+
+include("old_impl/old_impl.jl")
+using .Old_Impl
+using .Old_Impl: batched_triu!, batched_tril!
+
 device(x) = USE_GPU ? gpu(x) : x
 drandn(arg...) = randn(arg...) |> device
From 433fca5a2ffbf052e88d905ae41da3a60297cbc5 Mon Sep 17 00:00:00 2001
From: chengchingwen
Date: Tue, 2 Jul 2024 20:16:23 +0800
Subject: [PATCH 2/5] update test

---
 test/functional.jl            | 11 +++++--
 test/mask.jl                  | 14 ++++-----
 test/old_impl/batched_tril.jl | 56 +++++++++++++++++------------------
 3 files changed, 43 insertions(+), 38 deletions(-)

diff --git a/test/functional.jl b/test/functional.jl
index 8e5b865..aec40ba 100644
--- a/test/functional.jl
+++ b/test/functional.jl
@@ -223,8 +223,8 @@
     @testset "alibi position embedding" begin
         x1 = dzeros(10, 10, 3, 2);
         x2 = dzeros(10, 10, 3, 2);
-        x1 .= NeuralAttentionlib._build_alibi(nothing, CollapsedDimsArray(randn(10, 5, 2, 3, 2), 2, 2))
-        x2 .= NeuralAttentionlib._build_alibi(Masks.BatchedMask(Masks.GenericAttenMask(trues(10, 10))), CollapsedDimsArray(randn(10, 5, 2, 3, 2), 2, 2))
+        x1 .= NeuralAttentionlib._build_alibi(nothing, CollapsedDimsArray(drandn(10, 5, 2, 3, 2), 2, 2))
+        x2 .= NeuralAttentionlib._build_alibi(Masks.BatchedMask(Masks.GenericAttenMask(device(trues(10, 10)))), CollapsedDimsArray(drandn(10, 5, 2, 3, 2), 2, 2))
         @test x1 ≈ x2
         if !USE_GPU
             @testset "AD" begin
@@ -313,7 +313,12 @@
     end
 
     @testset "l2norm" begin
-        naive_l2norm(x) = x ./ sqrt.(sum(x .^ 2; dims=1))
+        square(x) = x .^ 2
+        function ChainRulesCore.rrule(::typeof(square), x)
+            square_back(Ȳ) = (NoTangent(), convert(eltype(Ȳ), 2) .* x .* Ȳ)
+            return square(x), square_back
+        end
+        naive_l2norm(x) = x ./ sqrt.(sum(square(x); dims=1))
         x = drandn(512, 3, 2)
         @test l2norm(x) ≈ naive_l2norm(x)
         @test Zygote.gradient(x->sum(sin.(l2norm(x))), x)[1] ≈ Zygote.gradient(x->sum(sin.(naive_l2norm(x))), x)[1]
diff --git a/test/mask.jl b/test/mask.jl
index 1baa2dc..c1c371a 100644
--- a/test/mask.jl
+++ b/test/mask.jl
@@ -3,9 +3,9 @@
     using NeuralAttentionlib.Masks
     using NeuralAttentionlib: getmask, lengths, AttenMask
 
-    causal(x) = batched_triu!(copy(x), 0) |> device
-    trilu(x, d) = batched_tril!(batched_triu!(copy(x), -d), d) |> device
-    bandpart(x, l, u) = batched_tril!(batched_triu!(copy(x), -l), u) |> device
+    causal(x) = batched_triu!(cpu(copy(x)), 0) |> device
+    trilu(x, d) = batched_tril!(batched_triu!(cpu(copy(x)), -d), d) |> device
+    bandpart(x, l, u) = batched_tril!(batched_triu!(cpu(copy(x)), -l), u) |> device
 
     test_random(x, p) = isapprox(sum(x .* RandomMask(p)) / length(x), 1-p, atol=1e-1)
 
@@ -62,7 +62,7 @@ end
         return y |> device
     end
 
-    grow_rlength(len1, len2, n) = reverse!(reverse!(grow_length(len1, len2, n); dims=1); dims=2)
+    grow_rlength(len1, len2, n) = reverse!(reverse!(cpu(grow_length(len1, len2, n)); dims=1); dims=2) |> device
 
     @testset "array mask" begin
         a = dones(Int, 10, 10)
@@ -93,7 +93,7 @@
     @test b .* BiLengthMask(bmaskq_b, bmaskk_b) == grow_length(bmaskk_b, bmaskq_b, 10)
     @test c .* BiLengthMask(bmaskq_c, bmaskk_c) == grow_length(bmaskk_c, bmaskq_c, 10)
 
-    rev!(x) = reverse!(reverse!(x; dims=1); dims=2)
+    rev!(x) = reverse!(reverse!(cpu(x); dims=1); dims=2) |> device
     @test a .* RevSymLengthMask(smask_a) == rev!(a .* SymLengthMask(smask_a))
     @test b .* RevSymLengthMask(smask_b) == rev!(b .* SymLengthMask(smask_b))
     @test c .* RevSymLengthMask(smask_c) == rev!(c .* SymLengthMask(smask_c))
@@ -206,8 +206,8 @@ end
 
     @testset "sequence mask" begin
-        a0 = hcat([1, 1, 1, 1, 0], [1,1,1,0,0])
-        ra0 = hcat([0, 1, 1, 1, 1], [0,0,1,1,1])
+        a0 = hcat(Bool[1, 1, 1, 1, 0], Bool[1,1,1,0,0])
+        ra0 = hcat(Bool[0, 1, 1, 1, 1], Bool[0,0,1,1,1])
         a = device(reshape(a0, (1, 5, 2)))
         ra = device(reshape(ra0, (1, 5, 2)))
         b = device([4,3])
diff --git a/test/old_impl/batched_tril.jl b/test/old_impl/batched_tril.jl
index d2d9adc..f1caa68 100644
--- a/test/old_impl/batched_tril.jl
+++ b/test/old_impl/batched_tril.jl
@@ -1,37 +1,37 @@
 using LinearAlgebra: tril!, triu!
 
 function batched_tril!(x::A, d) where {T, N, A <: AbstractArray{T, N}}
-  if N < 2
-    error("MethodError: no method matching tril!(::Array{Float64,1})")
-  elseif N == 2
-    return tril!(x, d)
-  else
-    s = size(x)
-    m = (s[1], s[2])
-    ms = s[1] * s[2]
-    len = Int(length(x) // ms)
-    Wt = Core.apply_type(A.name.wrapper, T, 2)
-    for i = 1:len
-      tril!(Base.unsafe_wrap(Wt, Base.pointer(x, (ms * (i - 1) + 1)), m), d)
+    if N < 2
+        error("MethodError: no method matching tril!(::Array{Float64,1})")
+    elseif N == 2
+        return tril!(x, d)
+    else
+        s = size(x)
+        m = (s[1], s[2])
+        ms = s[1] * s[2]
+        len = Int(length(x) // ms)
+        C = CartesianIndices(Base.tail(Base.tail(s)))
+        for i = 1:len
+            tril!(@view(x[:, :, C[i]]), d)
+        end
+        return x
     end
-    return x
-  end
 end
 
 function batched_triu!(x::A, d) where {T, N, A <: AbstractArray{T, N}}
-  if N < 2
-    error("MethodError: no method matching triu!(::Array{Float64,1})")
-  elseif N == 2
-    return triu!(x, d)
-  else
-    s = size(x)
-    m = (s[1], s[2])
-    ms = s[1] * s[2]
-    len = Int(length(x) // ms)
-    Wt = Core.apply_type(A.name.wrapper, T, 2)
-    for i = 1:len
-      triu!(Base.unsafe_wrap(Wt, Base.pointer(x, (ms * (i - 1) + 1)), m), d)
+    if N < 2
+        error("MethodError: no method matching triu!(::Array{Float64,1})")
+    elseif N == 2
+        return triu!(x, d)
+    else
+        s = size(x)
+        m = (s[1], s[2])
+        ms = s[1] * s[2]
+        len = Int(length(x) // ms)
+        C = CartesianIndices(Base.tail(Base.tail(s)))
+        for i = 1:len
+            triu!(@view(x[:, :, C[i]]), d)
+        end
+        return x
     end
-    return x
-  end
 end

From ff386053bd60fc4b5fcff23ddd3a1fd1e16f4d63 Mon Sep 17 00:00:00 2001
From: chengchingwen
Date: Tue, 2 Jul 2024 20:16:59 +0800
Subject: [PATCH 3/5] avoid Float64 in position embedding

---
 src/functional/position_embedding/sincos.jl | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/functional/position_embedding/sincos.jl b/src/functional/position_embedding/sincos.jl
index bff850f..1011b15 100644
--- a/src/functional/position_embedding/sincos.jl
+++ b/src/functional/position_embedding/sincos.jl
@@ -4,7 +4,7 @@ using ChainRulesCore
 default_position_func(hidden_size) = Base.Fix1(default_position_func, static(Int32(hidden_size)))
 @inline function default_position_func(hidden_size, i)
     j = (0x1 - i) << 0x3 #8 * (1 - i)
-    return 1e1 ^ (j / hidden_size)
+    return 1f1 ^ (Float32(j) / hidden_size)
 end
 
 function sincos_position_embed(f, ::Val{hidden_size}, pos, indices, ::Val{normalized}) where {hidden_size, normalized}
@@ -17,9 +17,9 @@ function sincos_position_embed(f, ::Val{hidden_size}, pos, indices, ::Val{normal
         half = hidden_size >> 0x1
         if half << 0x1 != hidden_size
             r = sin(_pos * f(half + 0x1))
-            return y * inv(sqrt(half + r))
+            return y * inv(sqrt(Float32(half + r)))
         else
-            return y * inv(sqrt(half))
+            return y * inv(sqrt(Float32(half)))
         end
     else
         return y
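
Note on PATCH 3: `1e1` is a Float64 literal and `j / hidden_size` is integer division
promoted to Float64, so the old code silently pulled the whole sincos embedding into
Float64, which Metal GPUs cannot execute. A quick REPL check of the promotion the patch
avoids (plain CPU Julia; the concrete values are only illustrative):

    julia> j = (0x1 - Int32(2)) << 0x3       # as in default_position_func: 8 * (1 - i)
    -8

    julia> typeof(1e1 ^ (j / 64))            # old form: promotes to Float64
    Float64

    julia> typeof(1f1 ^ (Float32(j) / 64))   # patched form: stays in Float32
    Float32
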
"0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" +Metal = "dde4c033-4e86-420c-a63e-0dd931031962" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] @@ -25,6 +26,7 @@ NeuralAttentionlibAMDGPUExt = ["AMDGPU", "GPUArrays"] NeuralAttentionlibCUDAExt = ["CUDA", "GPUArrays"] NeuralAttentionlibFiniteDifferencesExt = "FiniteDifferences" NeuralAttentionlibGPUArraysExt = "GPUArrays" +NeuralAttentionlibMetalExt = ["Metal", "GPUArrays"] NeuralAttentionlibZygoteExt = "Zygote" [compat] @@ -36,18 +38,18 @@ FiniteDifferences = "0.12" FuncTransforms = "0.1" GPUArrays = "10" GPUArraysCore = "0.1" +Metal = "1.1" NNlib = "0.9" Static = "0.7, 0.8" Zygote = "0.6" julia = "1.10" [extras] -AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" -CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Pickle = "fbb45041-c46e-462f-888f-7c521cafbc2c" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" @@ -55,4 +57,4 @@ ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [targets] -test = ["Test", "Flux", "MacroTools", "ZygoteRules", "CUDA", "AMDGPU", "ChainRulesTestUtils", "Random", "Pickle", "ZipFile", "Statistics"] +test = ["Test", "Flux", "MacroTools", "ZygoteRules", "ChainRulesTestUtils", "Random", "Pickle", "Pkg", "ZipFile", "Statistics"] From a1ee8609ab65a81bc77d93d2ebd648ad8506411d Mon Sep 17 00:00:00 2001 From: chengchingwen Date: Tue, 2 Jul 2024 20:30:17 +0800 Subject: [PATCH 5/5] update metal ext --- .../NeuralAttentionlibMetalExt.jl | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/ext/NeuralAttentionlibMetalExt/NeuralAttentionlibMetalExt.jl b/ext/NeuralAttentionlibMetalExt/NeuralAttentionlibMetalExt.jl index f94d80f..e032738 100644 --- a/ext/NeuralAttentionlibMetalExt/NeuralAttentionlibMetalExt.jl +++ b/ext/NeuralAttentionlibMetalExt/NeuralAttentionlibMetalExt.jl @@ -4,19 +4,9 @@ using NeuralAttentionlib using NeuralAttentionlib.Adapt using NeuralAttentionlib.NNlib using Metal -using Metal.MPS: MPSMatrix, MTLBuffer, NSUInteger, MPSMatrixDescriptor, id const NAlib = NeuralAttentionlib -function _mpsmatrix(arr, desc, offset) - mat = MPS.@objc [MPSMatrix alloc]::id{MPSMatrix} - obj = MPS.MPSMatrix(mat) - finalizer(MPS.release, obj) - MPS.@objc [obj::MPS.id{MPSMatrix} initWithBuffer:arr::id{MTLBuffer} - offset:offset::NSUInteger - descriptor:desc::id{MPSMatrixDescriptor}]::id{MPSMatrix} - return obj -end function mpsmatrix(arr::MtlArray{T}, lda, stride, batch) where T sz = sizeof(T) N = length(arr) @@ -28,7 +18,7 @@ function mpsmatrix(arr::MtlArray{T}, lda, stride, batch) where T matrix_bytes = iszero(stride) ? 0 : row_bytes * n_rows desc = MPS.MPSMatrixDescriptor(n_rows, n_cols, n_matrices, row_bytes, matrix_bytes, T) offset = arr.offset * sz - return _mpsmatrix(arr, desc, offset), (n_cols, n_rows, n_matrices) + return MPS.MPSMatrix(arr, desc, offset), (n_cols, n_rows, n_matrices) end for elty in (:Float32, :Float16) @@ -51,10 +41,9 @@ for elty in (:Float32, :Float16) cols_a = shp_A[transpose_a ? 
1 : 2] cols_c, rows_c = shp_C - mat_mul_kernel = MPS.MPSMatrixMultiplication(Metal.current_device(), - transpose_b, transpose_a, rows_c, cols_c, cols_a, alpha, beta) + mat_mul_kernel = MPS.MPSMatrixMultiplication(device(), transpose_b, transpose_a, rows_c, cols_c, cols_a, alpha, beta) - cmdbuf = Metal.MTLCommandBuffer(Metal.global_queue(Metal.current_device())) + cmdbuf = Metal.MTLCommandBuffer(Metal.global_queue(device())) MPS.encode!(cmdbuf, mat_mul_kernel, mps_b, mps_a, mps_c) Metal.commit!(cmdbuf) return C
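
Note on PATCH 1/PATCH 5: `mpsmatrix` maps a column-major Julia array onto MPS's row-major
MPSMatrix by taking the leading dimension as the MPS column count, so every batch element
is seen transposed; `encode!` therefore passes the operands in swapped order with swapped
transpose flags, computing Bᵀ·Aᵀ = (A·B)ᵀ, which read back column-major is exactly A·B.
A worked example of the shape bookkeeping (sizes assumed purely for illustration):

    # A::MtlArray{Float32} of size (4, 8, 3), no transpose:
    # gemm_strided_batched_impl! receives lda = 4, strideA = 32, batchCount = 3, so
    #   n_cols       = lda             = 4             # MPS columns = Julia rows
    #   n_rows       = strideA ÷ lda   = 32 ÷ 4 = 8    # MPS rows    = Julia columns
    #   row_bytes    = sizeof(Float32) * n_cols = 16
    #   matrix_bytes = row_bytes * n_rows       = 128  # byte stride between batch elements
    # MPS thus sees three row-major 8x4 matrices, i.e. the transposes of the Julia 4x8 slices.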