diff --git a/src/MutableArithmetics.jl b/src/MutableArithmetics.jl index 42b58714..bcf4cfe1 100644 --- a/src/MutableArithmetics.jl +++ b/src/MutableArithmetics.jl @@ -107,6 +107,9 @@ end function isequal_canonical(x::LinearAlgebra.Tridiagonal, y::LinearAlgebra.Tridiagonal) return isequal_canonical(x.dl, y.dl) && isequal_canonical(x.d, y.d) && isequal_canonical(x.du, y.du) end +function isequal_canonical(x::SparseMat, y::SparseMat) + return x.m == y.m && x.n == y.n && isequal_canonical(x.colptr, y.colptr) && isequal_canonical(x.rowval, y.rowval) && isequal_canonical(x.nzval, y.nzval) +end include("rewrite.jl") include("dispatch.jl") diff --git a/src/Test/sparse.jl b/src/Test/sparse.jl index f1f7aa4e..5488679c 100644 --- a/src/Test/sparse.jl +++ b/src/Test/sparse.jl @@ -15,6 +15,21 @@ function sparse_linear_test(X11, X23, Xd) 0 0 4X23 0 0 0] + function _test_broadcast(A, B) + @test_rewrite(A .* B) + @test_rewrite(B .* A) + @test_rewrite(A .+ B) + @test_rewrite(B .+ A) + @test_rewrite(A .- B) + @test_rewrite(B .- A) + end + + _test_broadcast(Xd, X) + _test_broadcast(Xd, Y) + _test_broadcast(Yd, X) + _test_broadcast(Yd, Y) + _test_broadcast(X, Y) + add_test(Xd, Yd) add_test(Xd, Y) add_test(Xd, Xd) diff --git a/src/broadcast.jl b/src/broadcast.jl index cc2a62a1..8fe978b8 100644 --- a/src/broadcast.jl +++ b/src/broadcast.jl @@ -1,7 +1,7 @@ -function broadcasted_type(::Broadcast.DefaultArrayStyle{N}, ::Type{Eltype}) where {N, Eltype} +function broadcasted_type(::Broadcast.DefaultArrayStyle{N}, ::Base.HasShape{N}, ::Type{Eltype}) where {N, Eltype} return Array{Eltype, N} end -function broadcasted_type(::Broadcast.DefaultArrayStyle{N}, ::Type{Bool}) where N +function broadcasted_type(::Broadcast.DefaultArrayStyle{N}, ::Base.HasShape{N}, ::Type{Bool}) where N return BitArray{N} end @@ -11,12 +11,18 @@ combine_styles(c::Type) = Broadcast.result_style(Broadcast.BroadcastStyle(c)) combine_styles(c1::Type, c2::Type) = Broadcast.result_style(combine_styles(c1), combine_styles(c2)) @inline combine_styles(c1::Type, c2::Type, cs::Vararg{Type, N}) where N = Broadcast.result_style(combine_styles(c1), combine_styles(c2, cs...)) +combine_shapes(s) = s +combine_2_shapes(s1::Base.HasShape{N}, s2::Base.HasShape{M}) where {N, M} = Base.HasShape{max(N, M)}() +combine_shapes(s1, s2, args::Vararg{Any, N}) where {N} = combine_shapes(combine_2_shapes(s1, s2), args...) +_shape(T) = Base.HasShape{ndims(T)}() +combine_sizes(args::Vararg{Any, N}) where {N} = combine_shapes(_shape.(args)...) + function promote_broadcast(op::Function, args::Vararg{Any, N}) where N # FIXME we could use `promote_operation` instead as # `combine_eltypes` uses `return_type` hence it may return a non-concrete type # and we do not handle that case. T = Base.Broadcast.combine_eltypes(op, args) - return broadcasted_type(combine_styles(args...), T) + return broadcasted_type(combine_styles(args...), combine_sizes(args...), T) end """ diff --git a/src/sparse_arrays.jl b/src/sparse_arrays.jl index 39b033ef..9a291e5e 100644 --- a/src/sparse_arrays.jl +++ b/src/sparse_arrays.jl @@ -220,3 +220,12 @@ function mutable_operate!(::typeof(add_mul), ret::SparseMat{T}, α::Vararg{Union{T, Scaling}, N}) where {T, N} mutable_operate!(add_mul, ret, copy(A), B, α...) end + +# This `BroadcastStyle` is used when there is a mix of sparse arrays and dense arrays. +# The result is a sparse array. +function broadcasted_type(::SparseArrays.HigherOrderFns.PromoteToSparse, ::Base.HasShape{1}, ::Type{Eltype}) where Eltype + return SparseArrays.SparseVector{Eltype, Int} +end +function broadcasted_type(::SparseArrays.HigherOrderFns.PromoteToSparse, ::Base.HasShape{2}, ::Type{Eltype}) where Eltype + return SparseMat{Eltype, Int} +end