Skip to content

Commit

Permalink
Merge pull request #10 from MasonProtter/BLAS2
Browse files Browse the repository at this point in the history
BLAS2 Support
  • Loading branch information
MasonProtter authored Apr 4, 2020
2 parents 3e56d10 + a45343a commit 7081a22
Show file tree
Hide file tree
Showing 8 changed files with 529 additions and 49 deletions.
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ os:
- linux
julia:
- 1.3
- 1.4
- nightly

matrix:
Expand Down
8 changes: 5 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
name = "Gaius"
uuid = "bffe22d1-cb55-4f4e-ac2c-f4dd4bf58912"
authors = ["MasonProtter <[email protected]>"]
version = "0.1.0"
version = "0.2.0"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
UnsafeArrays = "c4a57d5a-5b31-53a6-b365-19f8c011fbd6"

[compat]
julia = "1.3"
julia = "1.3, 1.4"

[extras]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
test = ["Test", "BenchmarkTools"]
26 changes: 19 additions & 7 deletions src/Gaius.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,16 @@
module Gaius

using LoopVectorization: @avx, VectorizationBase.AbstractStridedPointer, VectorizationBase.gesp, VectorizationBase.vload, VectorizationBase.vstore!, VectorizationBase.AVX512F
using LoopVectorization: @avx, VectorizationBase.AbstractStridedPointer, VectorizationBase.gesp, VectorizationBase.gep, VectorizationBase.vload, VectorizationBase.vstore!, VectorizationBase.AVX512F
import LoopVectorization: @avx, VectorizationBase.stridedpointer
using StructArrays: StructArray
using LinearAlgebra: LinearAlgebra
using LinearAlgebra: LinearAlgebra, Adjoint, Transpose

using UnsafeArrays: UnsafeArrays, @uviews, UnsafeArray

export blocked_mul, blocked_mul!

const DEFAULT_BLOCK_SIZE = AVX512F ? 96 : 64

const Eltypes = Union{Float64, Float32, Int64, Int32, Int16}
const MatTypesC{T <: Eltypes} = Union{Matrix{T}, SubArray{T, 2, <: Array}} # C for Column Major
const MatTypesR{T <: Eltypes} = Union{LinearAlgebra.Adjoint{T,<:MatTypesC{T}}, LinearAlgebra.Transpose{T,<:MatTypesC{T}}} # R for Row Major
const MatTypes{T <: Eltypes} = Union{MatTypesC{T}, MatTypesR{T}}

# Note this does not support changing the number of threads at runtime
macro _spawn(ex)
if Threads.nthreads() > 1
Expand All @@ -31,6 +28,21 @@ macro _sync(ex)
end

include("pointermatrix.jl")

# Scalar element types named by this package.
# NOTE(review): the aliases below no longer constrain `T` by `Eltypes`; confirm
# where it is still used.
# Declared `const` so the binding has a fixed type wherever it is referenced.
const Eltypes = Union{Float64, Float32, Int64, Int32, Int16}

# C for Column Major
const MatTypesC{T} = Union{Matrix{T},
                           SubArray{T, 2, <:AbstractArray},
                           PointerMatrix{T},
                           UnsafeArray{T, 2}}

# R for Row Major: lazy adjoint/transpose wrappers around a column-major matrix
const MatTypesR{T} = Union{Adjoint{T, <:MatTypesC{T}},
                           Transpose{T, <:MatTypesC{T}}}

# Any matrix layout accepted by the matrix kernels
const MatTypes{T} = Union{MatTypesC{T}, MatTypesR{T}}

# Column vectors, and their lazy adjoint/transpose (row-vector) wrappers
const VecTypes{T} = Union{Vector{T}, SubArray{T, 1, <:Array}}
const CoVecTypes{T} = Union{Adjoint{T, <:VecTypes{T}},
                            Transpose{T, <:VecTypes{T}}}


include("matmul.jl")
include("block_operations.jl")
include("kernels.jl")
Expand Down
66 changes: 65 additions & 1 deletion src/block_operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,26 @@ function block_mat_vec_mul!(C, A, B, sz)
_mul_add!(C21, A22, B21, sz)
end
end
"""
    block_mat_vec_mul!(C::VecTypes, A, B::VecTypes, sz)

Split `C` and `B` at index `sz` and `A` into four quadrants at (`sz`, `sz`),
then fill the two halves of `C` blockwise:

- upper half: `gemm_kernel!` on the leading `sz × sz` block, followed by
  `_mul_add!` with the upper-right block (wrapped in `@_spawn`);
- lower half: `_mul!` with the lower-left block, followed by `_mul_add!` with
  the lower-right block.

Both halves are enclosed in a single `@_sync` region.
"""
function block_mat_vec_mul!(C::VecTypes, A, B::VecTypes, sz)
    head = 1:sz
    @inbounds @views begin
        # output halves
        c_upper = C[head]
        c_lower = C[sz+1:end]

        # quadrants of A
        a_ul = A[head, head]
        a_ur = A[head, sz+1:end]
        a_ll = A[sz+1:end, head]
        a_lr = A[sz+1:end, sz+1:end]

        # input halves
        b_upper = B[head]
        b_lower = B[sz+1:end]
    end
    @_sync begin
        @_spawn begin
            gemm_kernel!(c_upper, a_ul, b_upper)
            _mul_add!(c_upper, a_ur, b_lower, sz)
        end
        _mul!(c_lower, a_ll, b_upper, sz)
        _mul_add!(c_lower, a_lr, b_lower, sz)
    end
end

function block_covec_mat_mul!(C, A, B, sz)
@inbounds @views begin
Expand Down Expand Up @@ -99,9 +119,10 @@ function block_vec_covec_mul!(C, A, B, sz)
end
end


function block_covec_vec_mul!(C, A, B, sz)
@inbounds @views begin
A11 = A[1:end, 1:sz]; A12 = A[1:sz, sz+1:end]
A11 = A[1:end, 1:sz]; A12 = A[1:end, sz+1:end]

B11 = B[1:sz, 1:end];
B21 = B[sz+1:end, 1:end];
Expand All @@ -111,6 +132,17 @@ function block_covec_vec_mul!(C, A, B, sz)
_mul_add!(C, A12, B21, sz)
end

"""
    block_covec_vec_mul!(C::VecTypes, A, B::VecTypes, sz)

Split the columns of `A` and the entries of `B` at index `sz`. The leading
parts go through `gemm_kernel!` into `C`; the trailing parts are then passed
to `_mul_add!` with the same `C`.
"""
function block_covec_vec_mul!(C::VecTypes, A, B::VecTypes, sz)
    head = 1:sz
    @inbounds @views begin
        a_left = A[1:end, head]
        a_right = A[1:end, sz+1:end]

        b_head = B[head]
        b_tail = B[sz+1:end]
    end
    gemm_kernel!(C, a_left, b_head)
    _mul_add!(C, a_right, b_tail, sz)
end

#----------------------------------------------------------------
#----------------------------------------------------------------
# Block Matrix addition-multiplication
Expand Down Expand Up @@ -165,6 +197,27 @@ function block_mat_vec_mul_add!(C, A, B, sz, ::Val{factor} = Val(1)) where {fact
end
end

"""
    block_mat_vec_mul_add!(C::VecTypes, A, B::VecTypes, sz, ::Val{factor} = Val(1))

Accumulating counterpart of the blocked matrix–vector product. `C` and `B` are
split at index `sz` and `A` into four quadrants at (`sz`, `sz`); every block
product is added into the matching half of `C`, with `Val(factor)` forwarded
to each call:

- upper half: `add_gemm_kernel!` on the leading block, then `_mul_add!` with
  the upper-right block (wrapped in `@_spawn`);
- lower half: two `_mul_add!` calls, for the lower-left and lower-right blocks.

Both halves are enclosed in a single `@_sync` region.
"""
function block_mat_vec_mul_add!(C::VecTypes, A, B::VecTypes, sz, ::Val{factor} = Val(1)) where {factor}
    head = 1:sz
    scale = Val(factor)
    @inbounds @views begin
        # output halves
        c_upper = C[head]
        c_lower = C[sz+1:end]

        # quadrants of A
        a_ul = A[head, head]
        a_ur = A[head, sz+1:end]
        a_ll = A[sz+1:end, head]
        a_lr = A[sz+1:end, sz+1:end]

        # input halves
        b_upper = B[head]
        b_lower = B[sz+1:end]
    end
    @_sync begin
        @_spawn begin
            add_gemm_kernel!(c_upper, a_ul, b_upper, scale)
            _mul_add!(c_upper, a_ur, b_lower, sz, scale)
        end
        _mul_add!(c_lower, a_ll, b_upper, sz, scale)
        _mul_add!(c_lower, a_lr, b_lower, sz, scale)
    end
end

function block_covec_mat_mul_add!(C, A, B, sz, ::Val{factor} = Val(1)) where {factor}
@inbounds @views begin
C11 = C[1:end, 1:sz]; C12 = C[1:end, sz+1:end]
Expand Down Expand Up @@ -218,3 +271,14 @@ function block_covec_vec_mul_add!(C, A, B, sz, ::Val{factor} = Val(1)) where {fa
add_gemm_kernel!(C, A11, B11, Val(factor))
_mul_add!(C, A12, B21, sz, Val(factor))
end

"""
    block_covec_vec_mul_add!(C::VecTypes, A, B::VecTypes, sz, ::Val{factor} = Val(1))

Accumulating variant of the blocked covector–vector product. The columns of
`A` and the entries of `B` are split at index `sz`; the leading parts go
through `add_gemm_kernel!` and the trailing parts through `_mul_add!`, both
targeting `C` and both receiving `Val(factor)`.
"""
function block_covec_vec_mul_add!(C::VecTypes, A, B::VecTypes, sz, ::Val{factor} = Val(1)) where {factor}
    head = 1:sz
    scale = Val(factor)
    @inbounds @views begin
        a_left = A[1:end, head]
        a_right = A[1:end, sz+1:end]

        b_head = B[head]
        b_tail = B[sz+1:end]
    end
    add_gemm_kernel!(C, a_left, b_head, scale)
    _mul_add!(C, a_right, b_tail, sz, scale)
end
108 changes: 95 additions & 13 deletions src/kernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,36 +8,118 @@ function gemm_kernel!(C, A, B)
end
end

function add_gemm_kernel!(C, A, B, ::Val{factor}) where {factor}
if factor == 1
_add_gemm_kernel!(C, A, B)
elseif factor == -1
_sub_gemm_kernel!(C, A, B)
else
_add_gemm_kernel!(C, C, B, Val(factor))
"""
    add_gemm_kernel!(C::MatTypes, A::MatTypes, B::MatTypes)

Accumulate the matrix product into `C` in place:
`C[n, m] += Σₖ A[n, k] * B[k, m]`.

Loop bounds come from `size(A, 1)`, `size(B, 2)` and `size(A, 2)`; no shape
checks are performed here. The loop nest is handed to LoopVectorization's
`@avx`.
"""
function add_gemm_kernel!(C::MatTypes, A::MatTypes, B::MatTypes)
    @avx for n in 1:size(A, 1), m in 1:size(B, 2)
        Cₙₘ = zero(eltype(C))
        for k in 1:size(A, 2)
            Cₙₘ += A[n,k] * B[k,m]
        end
        C[n,m] += Cₙₘ
    end
end

function _add_gemm_kernel!(C, A, B)
# A unit factor needs no scaling: forward to the three-argument accumulate kernel.
function add_gemm_kernel!(C::MatTypes, A::MatTypes, B::MatTypes, ::Val{1})
    add_gemm_kernel!(C, A, B)
end

"""
    add_gemm_kernel!(C::MatTypes, A::MatTypes, B::MatTypes, ::Val{-1})

Subtract the matrix product from `C` in place:
`C[n, m] -= Σₖ A[n, k] * B[k, m]`.

The inner accumulator is built with `-=` only, so the final `+=` into `C`
applies the negated sum. Loop bounds come from the sizes of `A` and `B`; no
shape checks are performed here.
"""
function add_gemm_kernel!(C::MatTypes, A::MatTypes, B::MatTypes, ::Val{-1})
    @avx for n in 1:size(A, 1), m in 1:size(B, 2)
        Cₙₘ = zero(eltype(C))
        for k in 1:size(A, 2)
            Cₙₘ -= A[n,k] * B[k,m]
        end
        C[n,m] += Cₙₘ
    end
end

function _sub_gemm_kernel!(C, A, B)
"""
    add_gemm_kernel!(C::MatTypes, A::MatTypes, B::MatTypes, ::Val{factor})

Accumulate a scaled matrix product into `C` in place:
`C[n, m] += Σₖ factor * A[n, k] * B[k, m]`.

`factor` is a type parameter taken from the `Val` argument. Loop bounds come
from the sizes of `A` and `B`; no shape checks are performed here.
"""
function add_gemm_kernel!(C::MatTypes, A::MatTypes, B::MatTypes, ::Val{factor}) where {factor}
    @avx for n in 1:size(A, 1), m in 1:size(B, 2)
        Cₙₘ = zero(eltype(C))
        for k in 1:size(A, 2)
            Cₙₘ += factor * A[n,k] * B[k,m]
        end
        C[n,m] += Cₙₘ
    end
end

_add_gemm_kernel!(C, A, B, ::Val{1}) = _add_gemm_kernel!(C, A, B)
_add_gemm_kernel!(C, A, B, ::Val{-1}) = _sub_gemm_kernel!(C, A, B)
#____________

"""
    gemm_kernel!(u::VecTypes, A::MatTypes, v::VecTypes)

Overwrite `u` with the matrix–vector product: `u[n] = Σₖ A[n, k] * v[k]`.

Loop bounds come from `size(A)`; the lengths of `u` and `v` are not checked
here. The loop nest is handed to LoopVectorization's `@avx`.
"""
function gemm_kernel!(u::VecTypes, A::MatTypes, v::VecTypes)
    @avx for n in 1:size(A, 1)
        uₙ = zero(eltype(u))
        for k in 1:size(A, 2)
            uₙ += A[n,k] * v[k]
        end
        u[n] = uₙ
    end
end

"""
    add_gemm_kernel!(u::VecTypes, A::MatTypes, v::VecTypes)

Accumulate the matrix–vector product into `u` in place:
`u[n] += Σₖ A[n, k] * v[k]`.

Loop bounds come from `size(A)`; the lengths of `u` and `v` are not checked
here.
"""
function add_gemm_kernel!(u::VecTypes, A::MatTypes, v::VecTypes)
    @avx for n in 1:size(A, 1)
        uₙ = zero(eltype(u))
        for k in 1:size(A, 2)
            uₙ += A[n,k] * v[k]
        end
        u[n] += uₙ
    end
end

"""
    add_gemm_kernel!(u::VecTypes, A::MatTypes, v::VecTypes, ::Val{-1})

Subtract the matrix–vector product from `u` in place:
`u[n] -= Σₖ A[n, k] * v[k]`.

The inner accumulator is built with `-=`, so the final `+=` into `u` applies
the negated sum. Loop bounds come from `size(A)`; the lengths of `u` and `v`
are not checked here.
"""
function add_gemm_kernel!(u::VecTypes, A::MatTypes, v::VecTypes, ::Val{-1})
    @avx for n in 1:size(A, 1)
        uₙ = zero(eltype(u))
        for k in 1:size(A, 2)
            uₙ -= A[n,k] * v[k]
        end
        u[n] += uₙ
    end
end

"""
    add_gemm_kernel!(u::VecTypes, A::MatTypes, v::VecTypes, ::Val{factor})

Accumulate a scaled matrix–vector product into `u` in place:
`u[n] += Σₖ factor * A[n, k] * v[k]`.

`factor` is a type parameter taken from the `Val` argument. Loop bounds come
from `size(A)`; the lengths of `u` and `v` are not checked here.
"""
function add_gemm_kernel!(u::VecTypes, A::MatTypes, v::VecTypes, ::Val{factor}) where {factor}
    @avx for n in 1:size(A, 1)
        uₙ = zero(eltype(u))
        for k in 1:size(A, 2)
            uₙ += factor * A[n,k] * v[k]
        end
        u[n] += uₙ
    end
end

#____________

"""
    gemm_kernel!(u::CoVecTypes, v::CoVecTypes, A::MatTypes)

Overwrite the covector `u` with the covector–matrix product:
`u[m] = Σₖ v[k] * A[k, m]`.

Loop bounds come from `size(A)`; the lengths of `u` and `v` are not checked
here. The loop nest is handed to LoopVectorization's `@avx`.
"""
function gemm_kernel!(u::CoVecTypes, v::CoVecTypes, A::MatTypes)
    @avx for m in 1:size(A, 2)
        uₘ = zero(eltype(u))
        for k in 1:size(A, 1)
            uₘ += v[k] * A[k, m]
        end
        u[m] = uₘ
    end
end

"""
    add_gemm_kernel!(u::CoVecTypes, v::CoVecTypes, A::MatTypes)

Accumulate the covector–matrix product into `u` in place:
`u[m] += Σₖ v[k] * A[k, m]`.

Loop bounds come from `size(A)`; the lengths of `u` and `v` are not checked
here.
"""
function add_gemm_kernel!(u::CoVecTypes, v::CoVecTypes, A::MatTypes)
    @avx for m in 1:size(A, 2)
        uₘ = zero(eltype(u))
        for k in 1:size(A, 1)
            uₘ += v[k] * A[k, m]
        end
        u[m] += uₘ
    end
end

"""
    add_gemm_kernel!(u::CoVecTypes, v::CoVecTypes, A::MatTypes, ::Val{-1})

Subtract the covector–matrix product from `u` in place:
`u[m] -= Σₖ v[k] * A[k, m]`.

The name carries the trailing `!` so that this is a method of
`add_gemm_kernel!` and is the one selected when the factor is `Val(-1)`.
Loop bounds come from `size(A)`; the lengths of `u` and `v` are not checked
here.
"""
function add_gemm_kernel!(u::CoVecTypes, v::CoVecTypes, A::MatTypes, ::Val{-1})
    @avx for m in 1:size(A, 2)
        uₘ = zero(eltype(u))
        for k in 1:size(A, 1)
            uₘ -= v[k] * A[k, m]
        end
        u[m] += uₘ
    end
end

"""
    add_gemm_kernel!(u::CoVecTypes, v::CoVecTypes, A::MatTypes, ::Val{factor})

Accumulate a scaled covector–matrix product into `u` in place:
`u[m] += Σₖ factor * v[k] * A[k, m]`.

`factor` is a type parameter taken from the `Val` argument. Loop bounds come
from `size(A)`; the lengths of `u` and `v` are not checked here.
"""
function add_gemm_kernel!(u::CoVecTypes, v::CoVecTypes, A::MatTypes, ::Val{factor}) where {factor}
    @avx for m in 1:size(A, 2)
        uₘ = zero(eltype(u))
        for k in 1:size(A, 1)
            uₘ += factor * v[k] * A[k, m]
        end
        u[m] += uₘ
    end
end
Loading

0 comments on commit 7081a22

Please sign in to comment.