diff --git a/src/dispatch.jl b/src/dispatch.jl index 144dba2..eb7d0a4 100644 --- a/src/dispatch.jl +++ b/src/dispatch.jl @@ -475,28 +475,33 @@ Base.:*(A::AbstractArray, α::AbstractMutable) = A .* α # Needed for Julia v1.0, otherwise, `broadcast(*, α, A)` gives a `Array` and # not a `Symmetric`. -_mult_upper(α, A) = parent(α * LinearAlgebra.UpperTriangular(parent(A))) -_mult_lower(α, A) = parent(α * LinearAlgebra.LowerTriangular(parent(A))) +function _mult_triangle( + ::Type{T}, + x, + A::T, +) where {T<:Union{LinearAlgebra.Symmetric,LinearAlgebra.Hermitian}} + c = LinearAlgebra.sym_uplo(A.uplo) + B = if c == :U + parent(x * LinearAlgebra.UpperTriangular(parent(A))) + else + parent(x * LinearAlgebra.LowerTriangular(parent(A))) + end + # Intermediate conversion to `Matrix` is needed to work around + # https://github.com/JuliaLang/julia/issues/52895 + return T(Matrix(LinearAlgebra.Symmetric(B, c)), c) +end function Base.:*(α::Number, A::LinearAlgebra.Symmetric{<:AbstractMutable}) - c = LinearAlgebra.sym_uplo(A.uplo) - B = c == :U ? _mult_upper(α, A) : _mult_lower(α, A) - return LinearAlgebra.Symmetric(B, c) + return _mult_triangle(LinearAlgebra.Symmetric, α, A) end +Base.:*(A::LinearAlgebra.Symmetric{<:AbstractMutable}, α::Number) = α * A + function Base.:*( α::AbstractMutable, A::LinearAlgebra.Symmetric{<:AbstractMutable}, ) - c = LinearAlgebra.sym_uplo(A.uplo) - B = c == :U ? _mult_upper(α, A) : _mult_lower(α, A) - return LinearAlgebra.Symmetric(Matrix(LinearAlgebra.Symmetric(B, c)),c) -end - -function Base.:*(α::AbstractMutable, A::LinearAlgebra.Symmetric) - c = LinearAlgebra.sym_uplo(A.uplo) - B = c == :U ? _mult_upper(α, A) : _mult_lower(α, A) - return LinearAlgebra.Symmetric(Matrix(LinearAlgebra.Symmetric(B, c)),c) + return _mult_triangle(LinearAlgebra.Symmetric, α, A) end function Base.:*( @@ -506,11 +511,18 @@ function Base.:*( return α * A end +function Base.:*(α::AbstractMutable, A::LinearAlgebra.Symmetric) + return _mult_triangle(LinearAlgebra.Symmetric, α, A) +end + Base.:*(A::LinearAlgebra.Symmetric, α::AbstractMutable) = α * A + function Base.:*(α::Real, A::LinearAlgebra.Hermitian{<:AbstractMutable}) - c = LinearAlgebra.sym_uplo(A.uplo) - B = c == :U ? _mult_upper(α, A) : _mult_lower(α, A) - return LinearAlgebra.Hermitian(B, c) + return _mult_triangle(LinearAlgebra.Hermitian, α, A) +end + +function Base.:*(A::LinearAlgebra.Hermitian{<:AbstractMutable}, α::Real) + return α * A end # These three have specific methods that just redirect to `Matrix{T}` which diff --git a/test/dispatch.jl b/test/dispatch.jl index 79750c1..37f18e6 100644 --- a/test/dispatch.jl +++ b/test/dispatch.jl @@ -44,33 +44,37 @@ end end end -@testset "*(::Real, ::Union{Hermitian,Symmetric})" begin +@testset "*(::Real, ::Hermitian)" begin A = DummyBigInt[1 2; 2 3] B = DummyBigInt[2 4; 4 6] - @test MA.isequal_canonical(2 * A, B) - C = LinearAlgebra.Symmetric(B) - @test MA.isequal_canonical(2 * LinearAlgebra.Symmetric(A, :U), C) - @test MA.isequal_canonical(2 * LinearAlgebra.Symmetric(A, :L), C) D = LinearAlgebra.Hermitian(B) - @test all(MA.isequal_canonical.(2 * LinearAlgebra.Hermitian(A, :L), D)) - @test all(MA.isequal_canonical.(2 * LinearAlgebra.Hermitian(A, :U), D)) -end - -@testset "*(::AbstractMutable, ::Symmetric)" begin - A = [1 2; 5 3] for s in (:L, :U) - B = DummyBigInt(2) * LinearAlgebra.Symmetric(A, s) - C = LinearAlgebra.Symmetric(DummyBigInt.(2 * A), s) - @test MA.isequal_canonical(B, C) + Ah = LinearAlgebra.Hermitian(A, s) + @test all(MA.isequal_canonical.(2 * Ah, D)) + @test all(MA.isequal_canonical.(Ah * 2, D)) end end -@testset "*(::AbstractMutable, ::Symmetric{<:AbstractMutable})" begin - A = DummyBigInt[1 2; 5 3] - for s in (:L, :U) - B = DummyBigInt(2) * LinearAlgebra.Symmetric(A, s) - C = LinearAlgebra.Symmetric(2 * A, s) - @test MA.isequal_canonical(B, C) +@testset "*(::AbstractMutable, ::Symmetric)" begin + for A in ([1 2; 5 3], DummyBigInt[1 2; 5 3]) + for x in (2, DummyBigInt(2)) + for s in (:L, :U) + # *(::AbstractMutable, ::Symmetric) + B = LinearAlgebra.Symmetric(A, s) + C = LinearAlgebra.Symmetric(x * A, s) + D = x * B + @test D isa LinearAlgebra.Symmetric + @test MA.isequal_canonical(D, C) + @test MA.isequal_canonical(D + D, 2 * D) + # *(::Symmetric, ::AbstractMutable) + B = LinearAlgebra.Symmetric(A, s) + C = LinearAlgebra.Symmetric(A * x, s) + D = B * x + @test D isa LinearAlgebra.Symmetric + @test MA.isequal_canonical(D, C) + @test MA.isequal_canonical(D + D, 2 * D) + end + end end end