From dfce02aa021c514fe2b25c4953067c29adfca795 Mon Sep 17 00:00:00 2001 From: Oscar Dowson Date: Tue, 5 Mar 2024 09:14:38 +1300 Subject: [PATCH] Fix *(::AbstractMutable, ::Symmetric) (#268) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Mateus Araújo --- src/dispatch.jl | 51 ++++++++++++++++++++++++++++++++++++++++-------- test/dispatch.jl | 36 +++++++++++++++++++++++++++------- 2 files changed, 72 insertions(+), 15 deletions(-) diff --git a/src/dispatch.jl b/src/dispatch.jl index 5ba9d607..32a5da68 100644 --- a/src/dispatch.jl +++ b/src/dispatch.jl @@ -475,19 +475,54 @@ Base.:*(A::AbstractArray, α::AbstractMutable) = A .* α # Needed for Julia v1.0, otherwise, `broadcast(*, α, A)` gives a `Array` and # not a `Symmetric`. -_mult_upper(α, A) = parent(α * LinearAlgebra.UpperTriangular(parent(A))) -_mult_lower(α, A) = parent(α * LinearAlgebra.LowerTriangular(parent(A))) +function _mult_triangle( + ::Type{T}, + x, + A::T, +) where {T<:Union{LinearAlgebra.Symmetric,LinearAlgebra.Hermitian}} + c = LinearAlgebra.sym_uplo(A.uplo) + B = if c == :U + parent(x * LinearAlgebra.UpperTriangular(parent(A))) + else + parent(x * LinearAlgebra.LowerTriangular(parent(A))) + end + # Intermediate conversion to `Matrix` is needed to work around + # https://github.com/JuliaLang/julia/issues/52895 + return T(Matrix(T(B, c)), c) +end function Base.:*(α::Number, A::LinearAlgebra.Symmetric{<:AbstractMutable}) - c = LinearAlgebra.sym_uplo(A.uplo) - B = c == :U ? _mult_upper(α, A) : _mult_lower(α, A) - return LinearAlgebra.Symmetric(B, c) + return _mult_triangle(LinearAlgebra.Symmetric, α, A) end +Base.:*(A::LinearAlgebra.Symmetric{<:AbstractMutable}, α::Number) = α * A + +function Base.:*( + α::AbstractMutable, + A::LinearAlgebra.Symmetric{<:AbstractMutable}, +) + return _mult_triangle(LinearAlgebra.Symmetric, α, A) +end + +function Base.:*( + A::LinearAlgebra.Symmetric{<:AbstractMutable}, + α::AbstractMutable, +) + return α * A +end + +function Base.:*(α::AbstractMutable, A::LinearAlgebra.Symmetric) + return _mult_triangle(LinearAlgebra.Symmetric, α, A) +end + +Base.:*(A::LinearAlgebra.Symmetric, α::AbstractMutable) = α * A + function Base.:*(α::Real, A::LinearAlgebra.Hermitian{<:AbstractMutable}) - c = LinearAlgebra.sym_uplo(A.uplo) - B = c == :U ? _mult_upper(α, A) : _mult_lower(α, A) - return LinearAlgebra.Hermitian(B, c) + return _mult_triangle(LinearAlgebra.Hermitian, α, A) +end + +function Base.:*(A::LinearAlgebra.Hermitian{<:AbstractMutable}, α::Real) + return α * A end # These three have specific methods that just redirect to `Matrix{T}` which diff --git a/test/dispatch.jl b/test/dispatch.jl index 3c8eade2..37f18e65 100644 --- a/test/dispatch.jl +++ b/test/dispatch.jl @@ -44,16 +44,38 @@ end end end -@testset "*(::Real, ::Union{Hermitian,Symmetric})" begin +@testset "*(::Real, ::Hermitian)" begin A = DummyBigInt[1 2; 2 3] B = DummyBigInt[2 4; 4 6] - @test MA.isequal_canonical(2 * A, B) - C = LinearAlgebra.Symmetric(B) - @test MA.isequal_canonical(2 * LinearAlgebra.Symmetric(A, :U), C) - @test MA.isequal_canonical(2 * LinearAlgebra.Symmetric(A, :L), C) D = LinearAlgebra.Hermitian(B) - @test all(MA.isequal_canonical.(2 * LinearAlgebra.Hermitian(A, :L), D)) - @test all(MA.isequal_canonical.(2 * LinearAlgebra.Hermitian(A, :U), D)) + for s in (:L, :U) + Ah = LinearAlgebra.Hermitian(A, s) + @test all(MA.isequal_canonical.(2 * Ah, D)) + @test all(MA.isequal_canonical.(Ah * 2, D)) + end +end + +@testset "*(::AbstractMutable, ::Symmetric)" begin + for A in ([1 2; 5 3], DummyBigInt[1 2; 5 3]) + for x in (2, DummyBigInt(2)) + for s in (:L, :U) + # *(::AbstractMutable, ::Symmetric) + B = LinearAlgebra.Symmetric(A, s) + C = LinearAlgebra.Symmetric(x * A, s) + D = x * B + @test D isa LinearAlgebra.Symmetric + @test MA.isequal_canonical(D, C) + @test MA.isequal_canonical(D + D, 2 * D) + # *(::Symmetric, ::AbstractMutable) + B = LinearAlgebra.Symmetric(A, s) + C = LinearAlgebra.Symmetric(A * x, s) + D = B * x + @test D isa LinearAlgebra.Symmetric + @test MA.isequal_canonical(D, C) + @test MA.isequal_canonical(D + D, 2 * D) + end + end + end end @testset "*(::Complex, ::Hermitian)" begin