From 0a244ac1c79812b3516905b131fa5d88797b11e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Tue, 16 Apr 2024 11:50:17 +0200 Subject: [PATCH] Move method and add tests --- src/dispatch.jl | 25 +++++++++++++++++++++++++ src/implementations/LinearAlgebra.jl | 9 --------- test/dispatch.jl | 21 +++++++++++++++++++++ 3 files changed, 46 insertions(+), 9 deletions(-) diff --git a/src/dispatch.jl b/src/dispatch.jl index 32a5da6..730e277 100644 --- a/src/dispatch.jl +++ b/src/dispatch.jl @@ -468,10 +468,35 @@ end # +(::SparseMatrixCSC) is not defined for generic types in Base. Base.:+(A::AbstractArray{<:AbstractMutable}) = A +# `Base.*(::AbstractArray, α)` is only defined if `α isa Number` +# Currently, mutable types are scalar elements (e.g. JuMP expression, +# MOI functions or polynomials) so broadcasting is the right dispatch. +# If this causes issues in the future, e.g., because a user define a non-scalar +# subtype of `AbstractMutable`, we might want to check that +# `ndims` is zero and error otherwise. + Base.:*(α::AbstractMutable, A::AbstractArray) = α .* A Base.:*(A::AbstractArray, α::AbstractMutable) = A .* α +function operate_to!( + output::AbstractArray, + ::typeof(*), + v::AbstractArray, + α::Union{Number,AbstractMutable}, +) + return Base.broadcast!(*, output, v, α) +end + +function operate_to!( + output::AbstractArray, + ::typeof(*), + α::Union{Number,AbstractMutable}, + v::AbstractArray, +) + return Base.broadcast!(*, output, α, v) +end + # Needed for Julia v1.0, otherwise, `broadcast(*, α, A)` gives a `Array` and # not a `Symmetric`. diff --git a/src/implementations/LinearAlgebra.jl b/src/implementations/LinearAlgebra.jl index 30b87b7..730b7cd 100644 --- a/src/implementations/LinearAlgebra.jl +++ b/src/implementations/LinearAlgebra.jl @@ -155,15 +155,6 @@ function operate_to!( return Base.broadcast!(op, output, A, B) end -function operate_to!( - output::AbstractArray, - ::typeof(*), - v::AbstractArray, - α::Number, -) - return LinearAlgebra.mul!(output, v, α) -end - # Product function similar_array_type( diff --git a/test/dispatch.jl b/test/dispatch.jl index 37f18e6..84ba729 100644 --- a/test/dispatch.jl +++ b/test/dispatch.jl @@ -85,3 +85,24 @@ end @test 2im * B == C @test C isa Matrix{Complex{BigInt}} end + +@testset "operate_to!(::Array, ::typeof(*), ::AbstractMutable, ::Array)" begin + for A in ([1 2; 5 3], DummyBigInt[1 2; 5 3]) + for x in (2, DummyBigInt(2)) + # operate_to!(::Array, *, ::AbstractMutable, ::Array) + B = x * A + C = zero(B) + D = MA.operate_to!(C, *, x, A) + @test C === D + @test typeof(B) == typeof(C) + @test MA.isequal_canonical(B, C) + # operate_to!(::Array, *, ::Array, ::AbstractMutable) + B = A * x + C = zero(B) + D = MA.operate_to!(C, *, A, x) + @test C === D + @test typeof(B) == typeof(C) + @test MA.isequal_canonical(B, C) + end + end +end