From 6f0e6994557a732947d343908059e8e9b07ac22a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Wed, 17 Apr 2024 02:03:25 +0200 Subject: [PATCH] Add corresponding operate_to! method for Matrix-scalar product (#273) --- src/dispatch.jl | 25 +++++++++++++++++++++++++ test/dispatch.jl | 21 +++++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/src/dispatch.jl b/src/dispatch.jl index b6bdf92..2960906 100644 --- a/src/dispatch.jl +++ b/src/dispatch.jl @@ -468,10 +468,35 @@ end # +(::SparseMatrixCSC) is not defined for generic types in Base. Base.:+(A::AbstractArray{<:AbstractMutable}) = A +# `Base.*(::AbstractArray, α)` is only defined if `α isa Number` +# Currently, mutable types are scalar elements (e.g. JuMP expression, +# MOI functions or polynomials) so broadcasting is the right dispatch. +# If this causes issues in the future, e.g., because a user define a non-scalar +# subtype of `AbstractMutable`, we might want to check that +# `ndims` is zero and error otherwise. + Base.:*(α::AbstractMutable, A::AbstractArray) = α .* A Base.:*(A::AbstractArray, α::AbstractMutable) = A .* α +function operate_to!( + output::AbstractArray, + ::typeof(*), + v::AbstractArray, + α::Union{Number,AbstractMutable}, +) + return Base.broadcast!(*, output, v, α) +end + +function operate_to!( + output::AbstractArray, + ::typeof(*), + α::Union{Number,AbstractMutable}, + v::AbstractArray, +) + return Base.broadcast!(*, output, α, v) +end + # Needed for Julia v1.0, otherwise, `broadcast(*, α, A)` gives a `Array` and # not a `Symmetric`. diff --git a/test/dispatch.jl b/test/dispatch.jl index 37f18e6..84ba729 100644 --- a/test/dispatch.jl +++ b/test/dispatch.jl @@ -85,3 +85,24 @@ end @test 2im * B == C @test C isa Matrix{Complex{BigInt}} end + +@testset "operate_to!(::Array, ::typeof(*), ::AbstractMutable, ::Array)" begin + for A in ([1 2; 5 3], DummyBigInt[1 2; 5 3]) + for x in (2, DummyBigInt(2)) + # operate_to!(::Array, *, ::AbstractMutable, ::Array) + B = x * A + C = zero(B) + D = MA.operate_to!(C, *, x, A) + @test C === D + @test typeof(B) == typeof(C) + @test MA.isequal_canonical(B, C) + # operate_to!(::Array, *, ::Array, ::AbstractMutable) + B = A * x + C = zero(B) + D = MA.operate_to!(C, *, A, x) + @test C === D + @test typeof(B) == typeof(C) + @test MA.isequal_canonical(B, C) + end + end +end