diff --git a/src/dispatch.jl b/src/dispatch.jl index 725c7b3f..3ffa4e36 100644 --- a/src/dispatch.jl +++ b/src/dispatch.jl @@ -343,199 +343,35 @@ end # `LinearAlgebra.mul!` which prevents us from using mutability of the # arithmetic. For this reason we intercept the calls and redirect them to `mul`. -# A few are overwritten below but many more need to be redirected to `mul` in -# `linalg.jl`. - -Base.:*(A::_SparseMat{<:AbstractMutable}, x::StridedVector) = mul(A, x) - -Base.:*(A::_SparseMat, x::StridedVector{<:AbstractMutable}) = mul(A, x) - -function Base.:*( - A::_SparseMat{<:AbstractMutable}, - x::StridedVector{<:AbstractMutable}, -) - return mul(A, x) -end - -# These six methods are needed on Julia v1.2 and earlier -function Base.:*( - A::LinearAlgebra.Adjoint{<:AbstractMutable,<:_SparseMat}, - x::StridedVector, -) - return mul(A, x) -end - -function Base.:*( - A::LinearAlgebra.Adjoint{<:Any,<:_SparseMat}, - x::StridedVector{<:AbstractMutable}, -) - return mul(A, x) -end - -function Base.:*( - A::LinearAlgebra.Adjoint{<:AbstractMutable,<:_SparseMat}, - x::StridedVector{<:AbstractMutable}, -) - return mul(A, x) -end - -function Base.:*( - A::LinearAlgebra.Transpose{<:AbstractMutable,<:_SparseMat}, - x::StridedVector, -) - return mul(A, x) -end - -function Base.:*( - A::LinearAlgebra.Transpose{<:Any,<:_SparseMat}, - x::StridedVector{<:AbstractMutable}, -) - return mul(A, x) -end - -function Base.:*( - A::LinearAlgebra.Transpose{<:AbstractMutable,<:_SparseMat}, - x::StridedVector{<:AbstractMutable}, -) - return mul(A, x) -end - -function Base.:*( - A::_SparseMat{<:AbstractMutable}, - B::_SparseMat{<:AbstractMutable}, -) - return mul(A, B) -end - -Base.:*(A::_SparseMat{<:Any}, B::_SparseMat{<:AbstractMutable}) = mul(A, B) - -Base.:*(A::_SparseMat{<:AbstractMutable}, B::_SparseMat{<:Any}) = mul(A, B) - -function Base.:*( - A::_SparseMat{<:AbstractMutable}, - B::LinearAlgebra.Adjoint{<:AbstractMutable,<:_SparseMat}, -) - return mul(A, B) -end - -function Base.:*( - A::_SparseMat{<:Any}, - B::LinearAlgebra.Adjoint{<:AbstractMutable,<:_SparseMat}, -) - return mul(A, B) -end - -function Base.:*( - A::_SparseMat{<:AbstractMutable}, - B::LinearAlgebra.Adjoint{<:Any,<:_SparseMat}, -) - return mul(A, B) -end - -function Base.:*( - A::LinearAlgebra.Adjoint{<:AbstractMutable,<:_SparseMat}, - B::_SparseMat{<:AbstractMutable}, -) - return mul(A, B) -end - -function Base.:*( - A::LinearAlgebra.Adjoint{<:Any,<:_SparseMat}, - B::_SparseMat{<:AbstractMutable}, -) - return mul(A, B) -end - -function Base.:*( - A::LinearAlgebra.Adjoint{<:AbstractMutable,<:_SparseMat}, - B::_SparseMat{<:Any}, -) - return mul(A, B) -end - -function Base.:*( - A::LinearAlgebra.Transpose{<:AbstractMutable,<:_SparseMat}, - B::_SparseMat{<:AbstractMutable}, -) - return mul(A, B) -end - -function Base.:*( - A::LinearAlgebra.Transpose{<:Any,<:_SparseMat}, - B::_SparseMat{<:AbstractMutable}, -) - return mul(A, B) -end - -function Base.:*( - A::LinearAlgebra.Transpose{<:AbstractMutable,<:_SparseMat}, - B::_SparseMat{<:Any}, -) - return mul(A, B) -end - -function Base.:*( - A::StridedMatrix{<:AbstractMutable}, - B::_SparseMat{<:AbstractMutable}, -) - return mul(A, B) -end - -Base.:*(A::StridedMatrix{<:Any}, B::_SparseMat{<:AbstractMutable}) = mul(A, B) - -Base.:*(A::StridedMatrix{<:AbstractMutable}, B::_SparseMat{<:Any}) = mul(A, B) - -function Base.:*( - A::_SparseMat{<:AbstractMutable}, - B::StridedMatrix{<:AbstractMutable}, -) - return mul(A, B) -end - -Base.:*(A::_SparseMat{<:Any}, B::StridedMatrix{<:AbstractMutable}) = mul(A, B) - -Base.:*(A::_SparseMat{<:AbstractMutable}, B::StridedMatrix{<:Any}) = mul(A, B) - -function Base.:*( - A::LinearAlgebra.Adjoint{<:AbstractMutable,<:_SparseMat}, - B::StridedMatrix{<:AbstractMutable}, -) - return mul(A, B) -end - -function Base.:*( - A::LinearAlgebra.Adjoint{<:Any,<:_SparseMat}, - B::StridedMatrix{<:AbstractMutable}, -) - return mul(A, B) -end - -function Base.:*( - A::LinearAlgebra.Adjoint{<:AbstractMutable,<:_SparseMat}, - B::StridedMatrix{<:Any}, -) - return mul(A, B) -end - -function Base.:*( - A::LinearAlgebra.Transpose{<:AbstractMutable,<:_SparseMat}, - B::StridedMatrix{<:AbstractMutable}, -) - return mul(A, B) -end - -function Base.:*( - A::LinearAlgebra.Transpose{<:Any,<:_SparseMat}, - B::StridedMatrix{<:AbstractMutable}, -) - return mul(A, B) -end - -function Base.:*( - A::LinearAlgebra.Transpose{<:AbstractMutable,<:_SparseMat}, - B::StridedMatrix{<:Any}, -) - return mul(A, B) +const _LinearAlgebraWrappers = ( + LinearAlgebra.Adjoint, + LinearAlgebra.Transpose, + # TODO(odow): we could expand these overloads to other LinearAlgebra types. + # LinearAlgebra.Symmetric, + # LinearAlgebra.Hermitian, + # LinearAlgebra.Diagonal, + # LinearAlgebra.LowerTriangular, + # LinearAlgebra.UpperTriangular, + # LinearAlgebra.UnitLowerTriangular, + # LinearAlgebra.UnitUpperTriangular, +) + +const _MatrixLike = vcat( + Any[T -> LA{<:T,<:_SparseMat} for LA in _LinearAlgebraWrappers], + Any[T->_SparseMat{<:T}, T->StridedMatrix{<:T}], +) + +for f_A in _MatrixLike, f_B in vcat(_MatrixLike, T -> StridedVector{<:T}) + A, mut_A = f_A(Any), f_A(AbstractMutable) + B, mut_B = f_B(Any), f_B(AbstractMutable) + if A <: StridedMatrix && B <: StridedMatrix + continue + end + @eval begin + Base.:*(a::$(mut_A), b::$(B)) = mul(a, b) + Base.:*(a::$(A), b::$(mut_B)) = mul(a, b) + Base.:*(a::$(mut_A), b::$(mut_B)) = mul(a, b) + end end const StridedMaybeAdjOrTransMat{T} = Union{