From 2a55083e38b75726e1b75459a3fdd8ff5c98b5a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Mon, 10 Jun 2024 14:10:20 +0200 Subject: [PATCH] Mutable addition of one array into result (#288) * Mutable addition of one array into result * Add tests * Add test --- src/implementations/LinearAlgebra.jl | 18 ++++++++++++++++++ test/matmul.jl | 14 ++++++++------ 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/src/implementations/LinearAlgebra.jl b/src/implementations/LinearAlgebra.jl index 0caddd7..de806ad 100644 --- a/src/implementations/LinearAlgebra.jl +++ b/src/implementations/LinearAlgebra.jl @@ -98,6 +98,24 @@ function operate!( return broadcast!(op, A, B) end +function operate_to!( + output::AbstractArray, + op::Union{typeof(+),typeof(-)}, + A::AbstractArray, +) + if axes(output) != axes(A) + throw( + DimensionMismatch( + "Cannot sum or substract a matrix of axes `$(axes(A))`" * + " into a matrix of axes `$(axes(output))`, expected" * + " axes `$(axes(A))`.", + ), + ) + end + # We don't have `MA.broadcast_to!` as it would be exactly `Base.broadcast!`. + return Base.broadcast!(op, output, A) +end + function operate_to!( output::AbstractArray, op::Union{typeof(+),typeof(-)}, diff --git a/test/matmul.jl b/test/matmul.jl index fdbf9fe..90fc363 100644 --- a/test/matmul.jl +++ b/test/matmul.jl @@ -139,6 +139,11 @@ end "Cannot sum or substract matrices of axes `$(axes(A))` and `$(axes(B))` into a matrix of axes `$(axes(output))`, expected axes `$(axes(B))`.", ) @test_throws err MA.operate_to!(output, +, A, B) + err = DimensionMismatch( + "Cannot sum or substract a matrix of axes `$(axes(A))` into a matrix of axes `$(axes(output))`, expected axes `$(axes(A))`.", + ) + @test_throws err MA.operate_to!(output, +, A) + @test_throws err MA.operate_to!(output, -, A) end @testset "unsupported_product" begin unsupported_product() @@ -471,17 +476,14 @@ function test_sparse_vector_sum(::Type{T}) where {T} x = SparseArrays.sparsevec([1, 3], T[5, 7]) y = copy(x) z = copy(y) - alloc_test(() -> MA.operate!(+, y, z), 0) - alloc_test(() -> MA.operate!(-, y, z), 0) - alloc_test(() -> MA.add!!(y, z), 0) - alloc_test(() -> MA.sub!!(y, z), 0) alloc_test(() -> MA.operate_to!(x, +, y, z), 0) alloc_test(() -> MA.operate_to!(x, -, y, z), 0) - alloc_test(() -> MA.add_to!!(x, y, z), 0) - alloc_test(() -> MA.sub_to!!(x, y, z), 0) + alloc_test(() -> MA.operate_to!(x, +, y), 0) + alloc_test(() -> MA.operate_to!(x, -, y), 0) return end @testset "Array sum" begin test_array_sum(Int) + test_sparse_vector_sum(Int) end