Skip to content

Commit

Permalink
Specialize matmul dest for structured matrices
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub committed Dec 10, 2023
1 parent 84cfe04 commit e7361e7
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 23 deletions.
8 changes: 8 additions & 0 deletions stdlib/LinearAlgebra/src/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,14 @@ _init_eltype(op, ::Type{TA}, ::Type{TB}) where {TA,TB} =
_initarray(op, ::Type{TA}, ::Type{TB}, C) where {TA,TB} =
similar(C, _init_eltype(op, TA, TB), size(C))

# destination type for matmul
matprod_dest(A::StructuredMatrix, B::StructuredMatrix, TS) = similar(B, TS, size(B))
matprod_dest(A, B::StructuredMatrix, TS) = similar(A, TS, size(A))
matprod_dest(A::StructuredMatrix, B, TS) = similar(B, TS, size(B))
matprod_dest(A::StructuredMatrix, B::Diagonal, TS) = similar(A, TS)
matprod_dest(A::Diagonal, B::StructuredMatrix, TS) = similar(B, TS)
matprod_dest(A::Diagonal, B::Diagonal, TS) = similar(B, TS)

# General fallback definition for handling under- and overdetermined system as well as square problems
# While this definition is pretty general, it does e.g. promote to common element type of lhs and rhs
# which is required by LAPACK but not SuiteSparse which allows real-complex solves in some cases. Hence,
Expand Down
9 changes: 0 additions & 9 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -308,15 +308,6 @@ function (*)(D::Diagonal, V::AbstractVector)
return D.diag .* V
end

(*)(A::AbstractMatrix, D::Diagonal) =
mul!(similar(A, promote_op(*, eltype(A), eltype(D.diag))), A, D)
(*)(A::HermOrSym, D::Diagonal) =
mul!(similar(A, promote_op(*, eltype(A), eltype(D.diag)), size(A)), A, D)
(*)(D::Diagonal, A::AbstractMatrix) =
mul!(similar(A, promote_op(*, eltype(D.diag), eltype(A))), D, A)
(*)(D::Diagonal, A::HermOrSym) =
mul!(similar(A, promote_op(*, eltype(A), eltype(D.diag)), size(A)), D, A)

rmul!(A::AbstractMatrix, D::Diagonal) = @inline mul!(A, A, D)
lmul!(D::Diagonal, B::AbstractVecOrMat) = @inline mul!(B, D, B)

Expand Down
5 changes: 4 additions & 1 deletion stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,11 @@ julia> [1 1; 0 1] * [1 0; 1 1]
"""
function (*)(A::AbstractMatrix, B::AbstractMatrix)
TS = promote_op(matprod, eltype(A), eltype(B))
mul!(similar(B, TS, (size(A, 1), size(B, 2))), A, B)
mul!(matprod_dest(A, B, TS), A, B)
end

matprod_dest(A, B, TS) = similar(B, TS, (size(A, 1), size(B, 2)))

# optimization for dispatching to BLAS, e.g. *(::Matrix{Float32}, ::Matrix{Float64})
# but avoiding the case *(::Matrix{<:BlasComplex}, ::Matrix{<:BlasReal})
# which is better handled by reinterpreting rather than promotion
Expand Down
13 changes: 0 additions & 13 deletions stdlib/LinearAlgebra/src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1471,12 +1471,6 @@ function *(A::AbstractTriangular, B::AbstractTriangular)
end

for mat in (:AbstractVector, :AbstractMatrix)
### Multiplication with triangle to the left and hence rhs cannot be transposed.
@eval function *(A::AbstractTriangular, B::$mat)
require_one_based_indexing(B)
TAB = _init_eltype(*, eltype(A), eltype(B))
mul!(similar(B, TAB, size(B)), A, B)
end
### Left division with triangle to the left hence rhs cannot be transposed. No quotients.
@eval function \(A::Union{UnitUpperTriangular,UnitLowerTriangular}, B::$mat)
require_one_based_indexing(B)
Expand All @@ -1502,13 +1496,6 @@ for mat in (:AbstractVector, :AbstractMatrix)
_rdiv!(similar(A, TAB, size(A)), A, B)
end
end
### Multiplication with triangle to the right and hence lhs cannot be transposed.
# Only for AbstractMatrix, hence outside the above loop.
function *(A::AbstractMatrix, B::AbstractTriangular)
require_one_based_indexing(A)
TAB = _init_eltype(*, eltype(A), eltype(B))
mul!(similar(A, TAB, size(A)), A, B)
end
# ambiguity resolution with definitions in matmul.jl
*(v::AdjointAbsVec, A::AbstractTriangular) = adjoint(adjoint(A) * v.parent)
*(v::TransposeAbsVec, A::AbstractTriangular) = transpose(transpose(A) * v.parent)
Expand Down
8 changes: 8 additions & 0 deletions stdlib/LinearAlgebra/test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1235,4 +1235,12 @@ end
end
end

@testset "avoid matmul ambiguities with ::MyMatrix * ::AbstractMatrix" begin
A = [i+j for i in 1:2, j in 1:2]
S = SizedArrays.SizedArray{(2,2)}(A)
D = Diagonal([1:2;])
@test S * D == A * D
@test D * S == D * A
end

end # module TestDiagonal
12 changes: 12 additions & 0 deletions stdlib/LinearAlgebra/test/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ using LinearAlgebra: BlasFloat, errorbounds, full!, transpose!,
UnitUpperTriangular, UnitLowerTriangular,
mul!, rdiv!, rmul!, lmul!

const BASE_TEST_PATH = joinpath(Sys.BINDIR, "..", "share", "julia", "test")
isdefined(Main, :SizedArrays) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "SizedArrays.jl"))
using .Main.SizedArrays

debug && println("Triangular matrices")

n = 9
Expand Down Expand Up @@ -866,4 +870,12 @@ end
end
end

@testset "avoid matmul ambiguities with ::MyMatrix * ::AbstractMatrix" begin
A = [i+j for i in 1:2, j in 1:2]
S = SizedArrays.SizedArray{(2,2)}(A)
U = UpperTriangular(ones(2,2))
@test S * U == A * U
@test U * S == U * A
end

end # module TestTriangular
5 changes: 5 additions & 0 deletions test/testhelpers/SizedArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,9 @@ function *(S1::SizedArrayLike, S2::SizedArrayLike)
SZ = ndims(data) == 1 ? (size(S1, 1), ) : (size(S1, 1), size(S2, 2))
SizedArray{SZ}(data)
end

# deliberately wide method definition to ensure that this doesn't lead to ambiguities with
# structured matrices
*(S1::SizedArrayLike, M::AbstractMatrix) = _data(S1) * M

end

0 comments on commit e7361e7

Please sign in to comment.