From cbf8d636d898b6058d8e5f564e1533af895e6a5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Wed, 24 Apr 2024 09:32:20 +0200 Subject: [PATCH] Mutable addition and substraction of sparse arrays (#281) * Mutable addition and substraction of sparse arrays * Fix * Fix * Fix --- src/implementations/LinearAlgebra.jl | 24 +++++++++++++++------- test/matmul.jl | 30 +++++++++++++++++++++++++++- 2 files changed, 46 insertions(+), 8 deletions(-) diff --git a/src/implementations/LinearAlgebra.jl b/src/implementations/LinearAlgebra.jl index 2a14178..0caddd7 100644 --- a/src/implementations/LinearAlgebra.jl +++ b/src/implementations/LinearAlgebra.jl @@ -89,13 +89,17 @@ function _check_dims(A, B) return end -function operate!(op::Union{typeof(+),typeof(-)}, A::Array, B::AbstractArray) +function operate!( + op::Union{typeof(+),typeof(-)}, + A::AbstractArray, + B::AbstractArray, +) _check_dims(A, B) return broadcast!(op, A, B) end function operate_to!( - output::Array, + output::AbstractArray, op::Union{typeof(+),typeof(-)}, A::AbstractArray, B::AbstractArray, @@ -116,7 +120,7 @@ end # We call `scaling_to_number` as `UniformScaling` do not support broadcasting function operate!( op::AddSubMul, - A::Array, + A::AbstractArray, B::AbstractArray, α::Vararg{Scaling,M}, ) where {M} @@ -126,7 +130,7 @@ end function operate!( op::AddSubMul, - A::Array, + A::AbstractArray, α::Scaling, B::AbstractArray, β::Vararg{Scaling,M}, @@ -137,7 +141,7 @@ end function operate!( op::AddSubMul, - A::Array, + A::AbstractArray, α1::Scaling, α2::Scaling, B::AbstractArray, @@ -156,11 +160,17 @@ end # Fallback, we may be able to be more efficient in more cases by adding more # specialized methods. -function operate!(op::AddSubMul, A::Array, x, y) +function operate!(op::AddSubMul, A::AbstractArray, x, y) return operate!(op, A, x * y) end -function operate!(op::AddSubMul, A::Array, x, y, args::Vararg{Any,N}) where {N} +function operate!( + op::AddSubMul, + A::AbstractArray, + x, + y, + args::Vararg{Any,N}, +) where {N} @assert N > 0 return operate!(op, A, x, *(y, args...)) end diff --git a/test/matmul.jl b/test/matmul.jl index ea7a045..fdbf9fe 100644 --- a/test/matmul.jl +++ b/test/matmul.jl @@ -9,7 +9,7 @@ import MutableArithmetics as MA struct CustomArray{T,N} <: AbstractArray{T,N} end -import LinearAlgebra +import LinearAlgebra, SparseArrays function dot_test(x, y) @test MA.operate(LinearAlgebra.dot, x, y) == LinearAlgebra.dot(x, y) @@ -119,6 +119,12 @@ end "Cannot sum or substract a matrix of axes `$(axes(B))` into matrix of axes `$(axes(A))`, expected axes `$(axes(B))`.", ) @test_throws err MA.operate!(+, A, B) + A = SparseArrays.spzeros(2) + B = SparseArrays.spzeros(2, 1) + err = DimensionMismatch( + "Cannot sum or substract a matrix of axes `$(axes(B))` into matrix of axes `$(axes(A))`, expected axes `$(axes(B))`.", + ) + @test_throws err MA.operate!(+, A, B) output = zeros(2) A = zeros(2, 1) B = zeros(2, 1) @@ -126,6 +132,13 @@ end "Cannot sum or substract matrices of axes `$(axes(A))` and `$(axes(B))` into a matrix of axes `$(axes(output))`, expected axes `$(axes(B))`.", ) @test_throws err MA.operate_to!(output, +, A, B) + output = SparseArrays.spzeros(2) + A = SparseArrays.spzeros(2, 1) + B = SparseArrays.spzeros(2, 1) + err = DimensionMismatch( + "Cannot sum or substract matrices of axes `$(axes(A))` and `$(axes(B))` into a matrix of axes `$(axes(output))`, expected axes `$(axes(B))`.", + ) + @test_throws err MA.operate_to!(output, +, A, B) end @testset "unsupported_product" begin unsupported_product() @@ -454,6 +467,21 @@ function test_array_sum(::Type{T}) where {T} return end +function test_sparse_vector_sum(::Type{T}) where {T} + x = SparseArrays.sparsevec([1, 3], T[5, 7]) + y = copy(x) + z = copy(y) + alloc_test(() -> MA.operate!(+, y, z), 0) + alloc_test(() -> MA.operate!(-, y, z), 0) + alloc_test(() -> MA.add!!(y, z), 0) + alloc_test(() -> MA.sub!!(y, z), 0) + alloc_test(() -> MA.operate_to!(x, +, y, z), 0) + alloc_test(() -> MA.operate_to!(x, -, y, z), 0) + alloc_test(() -> MA.add_to!!(x, y, z), 0) + alloc_test(() -> MA.sub_to!!(x, y, z), 0) + return +end + @testset "Array sum" begin test_array_sum(Int) end