Skip to content

Commit

Permalink
Determine correct eltype in sparse map (#177)
Browse files Browse the repository at this point in the history
  • Loading branch information
dkarrasch authored Aug 1, 2022
1 parent 94c0a1c commit 7bf0c5c
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 47 deletions.
8 changes: 8 additions & 0 deletions src/SparseArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ export AbstractSparseArray, AbstractSparseMatrix, AbstractSparseVector,
sprand, sprandn, spzeros, nnz, permute, findnz, fkeep!, ftranspose!,
sparse_hcat, sparse_vcat, sparse_hvcat

# Zero-test helpers shared by sparsematrix, sparsevector and higherorderfns.
#
# `_iszero(x)`: numbers and arrays defer to `Base.iszero`; every other value is compared
# against the literal `0` with `==`.
@inline _iszero(x::AbstractArray) = Base.iszero(x)
@inline _iszero(x::Number) = Base.iszero(x)
@inline _iszero(x) = x == 0
# `_isnotzero(x)`: numbers and arrays negate `iszero`; the generic fallback is like
# `x != 0`, but handles `x::Missing`: any comparison result other than `false` counts as
# "not zero", so the helper always yields a `Bool` usable in an `if` condition.
@inline _isnotzero(x::AbstractArray) = !iszero(x)
@inline _isnotzero(x::Number) = !iszero(x)
@inline _isnotzero(x) = !((x != 0) === false)

include("abstractsparse.jl")
include("sparsematrix.jl")
include("sparseconvert.jl")
Expand Down
32 changes: 15 additions & 17 deletions src/higherorderfns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ import Base: map, map!, broadcast, copy, copyto!
using Base: front, tail, to_shape
using ..SparseArrays: SparseVector, SparseMatrixCSC, AbstractSparseVector, AbstractSparseMatrixCSC,
AbstractSparseMatrix, AbstractSparseArray, indtype, nnz, nzrange, spzeros,
SparseVectorUnion, AdjOrTransSparseVectorUnion, nonzeroinds, nonzeros, rowvals, getcolptr, widelength
SparseVectorUnion, AdjOrTransSparseVectorUnion, nonzeroinds, nonzeros,
rowvals, getcolptr, widelength, _iszero, _isnotzero
using Base.Broadcast: BroadcastStyle, Broadcasted, flatten
using LinearAlgebra

Expand Down Expand Up @@ -202,9 +203,6 @@ end
# helper functions for map[!]/broadcast[!] entry points (and related methods below)
@inline _sumnnzs(A) = nnz(A)
@inline _sumnnzs(A, Bs...) = nnz(A) + _sumnnzs(Bs...)
@inline _iszero(x) = x == 0
@inline _iszero(x::Number) = Base.iszero(x)
@inline _iszero(x::AbstractArray) = Base.iszero(x)
@inline _zeros_eltypes(A) = (zero(eltype(A)),)
@inline _zeros_eltypes(A, Bs...) = (zero(eltype(A)), _zeros_eltypes(Bs...)...)
@inline _promote_indtype(A) = indtype(A)
Expand Down Expand Up @@ -244,7 +242,7 @@ function _map_zeropres!(f::Tf, C::SparseVecOrMat, A::SparseVecOrMat) where Tf
setcolptr!(C, j, Ck)
for Ak in colrange(A, j)
Cx = f(storedvals(A)[Ak])
if !_iszero(Cx)
if _isnotzero(Cx)
Ck > spaceC && (spaceC = expandstorage!(C, Ck + nnz(A) - (Ak - 1)))
storedinds(C)[Ck] = storedinds(A)[Ak]
storedvals(C)[Ck] = Cx
Expand Down Expand Up @@ -325,7 +323,7 @@ function _map_zeropres!(f::Tf, C::SparseVecOrMat, A::SparseVecOrMat, B::SparseVe
# cases are equally or more likely than the Ai < Bi and Bi < Ai cases. Hence
# the ordering of the conditional chain above differs from that in the
# corresponding broadcast code (below).
if !_iszero(Cx)
if _isnotzero(Cx)
Ck > spaceC && (spaceC = expandstorage!(C, Ck + (nnz(A) - (Ak - 1)) + (nnz(B) - (Bk - 1))))
storedinds(C)[Ck] = Ci
storedvals(C)[Ck] = Cx
Expand Down Expand Up @@ -386,7 +384,7 @@ function _map_zeropres!(f::Tf, C::SparseVecOrMat, As::Vararg{SparseVecOrMat,N})
while activerow < rowsentinel
vals, ks, rows = _fusedupdate_all(rowsentinel, activerow, rows, ks, stopks, As)
Cx = f(vals...)
if !_iszero(Cx)
if _isnotzero(Cx)
Ck > spaceC && (spaceC = expandstorage!(C, Int(min(widelength(C), Ck + _sumnnzs(As...) - (sum(ks) - N)))))
storedinds(C)[Ck] = activerow
storedvals(C)[Ck] = Cx
Expand Down Expand Up @@ -477,7 +475,7 @@ function _broadcast_zeropres!(f::Tf, C::SparseVecOrMat, A::SparseVecOrMat) where
bccolrangejA = numcols(A) == 1 ? colrange(A, 1) : colrange(A, j)
for Ak in bccolrangejA
Cx = f(storedvals(A)[Ak])
if !_iszero(Cx)
if _isnotzero(Cx)
Ck > spaceC && (spaceC = expandstorage!(C, _unchecked_maxnnzbcres(size(C), A)))
storedinds(C)[Ck] = storedinds(A)[Ak]
storedvals(C)[Ck] = Cx
Expand All @@ -496,7 +494,7 @@ function _broadcast_zeropres!(f::Tf, C::SparseVecOrMat, A::SparseVecOrMat) where
# contains a nonzero value x but f(Ax) is nonetheless zero, so we need store
# nothing in C's jth column. if to the contrary fofAx is nonzero, then we must
# densely populate C's jth column with fofAx.
if !_iszero(fofAx)
if _isnotzero(fofAx)
for Ci::indtype(C) in 1:numrows(C)
Ck > spaceC && (spaceC = expandstorage!(C, _unchecked_maxnnzbcres(size(C), A)))
storedinds(C)[Ck] = Ci
Expand Down Expand Up @@ -599,7 +597,7 @@ function _broadcast_zeropres!(f::Tf, C::SparseVecOrMat, A::SparseVecOrMat, B::Sp
# pattern) the Ai < Bi and Bi < Ai cases are equally or more likely than the
# Ai == Bi and termination cases. Hence the ordering of the conditional
# chain above differs from that in the corresponding map code.
if !_iszero(Cx)
if _isnotzero(Cx)
Ck > spaceC && (spaceC = expandstorage!(C, _unchecked_maxnnzbcres(size(C), A, B)))
storedinds(C)[Ck] = Ci
storedvals(C)[Ck] = Cx
Expand All @@ -616,7 +614,7 @@ function _broadcast_zeropres!(f::Tf, C::SparseVecOrMat, A::SparseVecOrMat, B::Sp
Ax = Ak < stopAk ? storedvals(A)[Ak] : zero(eltype(A))
Bx = Bk < stopBk ? storedvals(B)[Bk] : zero(eltype(B))
Cx = f(Ax, Bx)
if !_iszero(Cx)
if _isnotzero(Cx)
for Ci::indtype(C) in 1:numrows(C)
Ck > spaceC && (spaceC = expandstorage!(C, _unchecked_maxnnzbcres(size(C), A, B)))
storedinds(C)[Ck] = Ci
Expand All @@ -638,7 +636,7 @@ function _broadcast_zeropres!(f::Tf, C::SparseVecOrMat, A::SparseVecOrMat, B::Sp
# B's jth column without storing every entry in C's jth column
while Bk < stopBk
Cx = f(Ax, storedvals(B)[Bk])
if !_iszero(Cx)
if _isnotzero(Cx)
Ck > spaceC && (spaceC = expandstorage!(C, _unchecked_maxnnzbcres(size(C), A, B)))
storedinds(C)[Ck] = storedinds(B)[Bk]
storedvals(C)[Ck] = Cx
Expand All @@ -657,7 +655,7 @@ function _broadcast_zeropres!(f::Tf, C::SparseVecOrMat, A::SparseVecOrMat, B::Sp
else
Cx = fvAzB
end
if !_iszero(Cx)
if _isnotzero(Cx)
Ck > spaceC && (spaceC = expandstorage!(C, _unchecked_maxnnzbcres(size(C), A, B)))
storedinds(C)[Ck] = Ci
storedvals(C)[Ck] = Cx
Expand All @@ -679,7 +677,7 @@ function _broadcast_zeropres!(f::Tf, C::SparseVecOrMat, A::SparseVecOrMat, B::Sp
# A's jth column without storing every entry in C's jth column
while Ak < stopAk
Cx = f(storedvals(A)[Ak], Bx)
if !_iszero(Cx)
if _isnotzero(Cx)
Ck > spaceC && (spaceC = expandstorage!(C, _unchecked_maxnnzbcres(size(C), A, B)))
storedinds(C)[Ck] = storedinds(A)[Ak]
storedvals(C)[Ck] = Cx
Expand All @@ -698,7 +696,7 @@ function _broadcast_zeropres!(f::Tf, C::SparseVecOrMat, A::SparseVecOrMat, B::Sp
else
Cx = fzAvB
end
if !_iszero(Cx)
if _isnotzero(Cx)
Ck > spaceC && (spaceC = expandstorage!(C, _unchecked_maxnnzbcres(size(C), A, B)))
storedinds(C)[Ck] = Ci
storedvals(C)[Ck] = Cx
Expand Down Expand Up @@ -888,7 +886,7 @@ function _broadcast_zeropres!(f::Tf, C::SparseVecOrMat, As::Vararg{SparseVecOrMa
while activerow < rowsentinel
args, ks, rows = _fusedupdatebc_all(rowsentinel, activerow, rows, defargs, ks, stopks, As)
Cx = f(args...)
if !_iszero(Cx)
if _isnotzero(Cx)
Ck > spaceC && (spaceC = expandstorage!(C, _unchecked_maxnnzbcres(size(C), As)))
storedinds(C)[Ck] = activerow
storedvals(C)[Ck] = Cx
Expand All @@ -905,7 +903,7 @@ function _broadcast_zeropres!(f::Tf, C::SparseVecOrMat, As::Vararg{SparseVecOrMa
else
Cx = defaultCx
end
if !_iszero(Cx)
if _isnotzero(Cx)
Ck > spaceC && (spaceC = expandstorage!(C, _unchecked_maxnnzbcres(size(C), As)))
storedinds(C)[Ck] = Ci
storedvals(C)[Ck] = Cx
Expand Down
4 changes: 2 additions & 2 deletions src/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -707,7 +707,7 @@ function SparseMatrixCSC{Tv,Ti}(M::AbstractMatrix) where {Tv,Ti}
i = 0
for v in M
i += 1
if !iszero(v)
if _isnotzero(v)
push!(I, i)
push!(V, v)
end
Expand Down Expand Up @@ -2728,7 +2728,7 @@ function _setindex_scalar!(A::AbstractSparseMatrixCSC{Tv,Ti}, _v, _i::Integer, _
end
# Column j does not contain entry A[i,j]. If v is nonzero, insert entry A[i,j] = v
# and return. If to the contrary v is zero, then simply return.
if !iszero(v)
if _isnotzero(v)
nz = getcolptr(A)[size(A, 2)+1]
# throw exception before state is partially modified
!isbitstype(Ti) || nz < typemax(Ti) ||
Expand Down
65 changes: 41 additions & 24 deletions src/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,24 @@ using LinearAlgebra: _SpecialArrays, _DenseConcatGroup
SparseVector{Tv,Ti<:Integer} <: AbstractSparseVector{Tv,Ti}
Vector type for storing sparse vectors. Can be created by passing the length of the vector,
a *sorted* vector of non-zero indices, and a vector of non-zero values.
a *sorted* vector of non-zero indices, and a vector of non-zero values.
For instance, the vector `[5, 6, 0, 7]` can be represented as
For instance, the Vector `[5, 6, 0, 7]` can be represented as
```julia
SparseVector(4, [1, 2, 4], [5, 6, 7])
```
This indicates that the index 1 is 5, the index 2 is 6, the index 3 is `zero(Int)`, and index 4 is 7.
It may be more convenient to create sparse vectors directly from dense vectors using `sparse` as
This indicates that the element at index 1 is 5, at index 2 is 6, at index 3 is `zero(Int)`,
and at index 4 is 7.
It may be more convenient to create sparse vectors directly from dense vectors using `sparse` as
```julia
sparse([5, 6, 0, 7])
```
yeilds the same sparse vector.
```
yields the same sparse vector.
"""
struct SparseVector{Tv,Ti<:Integer} <: AbstractSparseVector{Tv,Ti}
n::Ti # Length of the sparse vector
Expand Down Expand Up @@ -344,7 +349,7 @@ function setindex!(x::SparseVector{Tv,Ti}, v::Tv, i::Ti) where {Tv,Ti<:Integer}
if 1 <= k <= m && nzind[k] == i # i found
nzval[k] = v
else # i not found
if !iszero(v)
if _isnotzero(v)
insert!(nzind, k, i)
insert!(nzval, k, v)
end
Expand Down Expand Up @@ -1201,7 +1206,7 @@ macro unarymap_nz2z_z2z(op, TF)
@inbounds for j = 1:m
i = xnzind[j]
v = $(op)(xnzval[j])
if v != zero(v)
if _isnotzero(v)
ir += 1
ynzind[ir] = i
ynzval[ir] = v
Expand Down Expand Up @@ -1253,7 +1258,7 @@ function _binarymap(f::Function,
y::AbstractSparseVector{Ty},
mode::Int) where {Tx,Ty}
0 <= mode <= 2 || throw(ArgumentError("Incorrect mode $mode."))
R = typeof(f(zero(Tx), zero(Ty)))
R = Base.Broadcast.combine_eltypes(f, (x, y))
n = length(x)
length(y) == n || throw(DimensionMismatch())

Expand All @@ -1268,9 +1273,6 @@ function _binarymap(f::Function,
rind = Vector{Int}(undef, cap)
rval = Vector{R}(undef, cap)
ir = 0
ix = 1
iy = 1

ir = (
mode == 0 ? _binarymap_mode_0!(f, mx, my,
xnzind, xnzval, ynzind, ynzval, rind, rval) :
Expand All @@ -1288,6 +1290,7 @@ end
function _binarymap_mode_0!(f::Function, mx::Int, my::Int,
xnzind, xnzval, ynzind, ynzval, rind, rval)
# f(nz, nz) -> nz, f(z, nz) -> z, f(nz, z) -> z
require_one_based_indexing(xnzind, ynzind, xnzval, ynzval, rind, rval)
ir = 0; ix = 1; iy = 1
@inbounds while ix <= mx && iy <= my
jx = xnzind[ix]
Expand All @@ -1310,13 +1313,14 @@ function _binarymap_mode_1!(f::Function, mx::Int, my::Int,
ynzind, ynzval::AbstractVector{Ty},
rind, rval) where {Tx,Ty}
# f(nz, nz) -> z/nz, f(z, nz) -> nz, f(nz, z) -> nz
require_one_based_indexing(xnzind, ynzind, xnzval, ynzval, rind, rval)
ir = 0; ix = 1; iy = 1
@inbounds while ix <= mx && iy <= my
jx = xnzind[ix]
jy = ynzind[iy]
if jx == jy
v = f(xnzval[ix], ynzval[iy])
if v != zero(v)
if _isnotzero(v)
ir += 1; rind[ir] = jx; rval[ir] = v
end
ix += 1; iy += 1
Expand Down Expand Up @@ -1348,40 +1352,41 @@ function _binarymap_mode_2!(f::Function, mx::Int, my::Int,
ynzind, ynzval::AbstractVector{Ty},
rind, rval) where {Tx,Ty}
# f(nz, nz) -> z/nz, f(z, nz) -> z/nz, f(nz, z) -> z/nz
require_one_based_indexing(xnzind, ynzind, xnzval, ynzval, rind, rval)
ir = 0; ix = 1; iy = 1
@inbounds while ix <= mx && iy <= my
jx = xnzind[ix]
jy = ynzind[iy]
if jx == jy
v = f(xnzval[ix], ynzval[iy])
if v != zero(v)
if _isnotzero(v)
ir += 1; rind[ir] = jx; rval[ir] = v
end
ix += 1; iy += 1
elseif jx < jy
v = f(xnzval[ix], zero(Ty))
if v != zero(v)
if _isnotzero(v)
ir += 1; rind[ir] = jx; rval[ir] = v
end
ix += 1
else
v = f(zero(Tx), ynzval[iy])
if v != zero(v)
if _isnotzero(v)
ir += 1; rind[ir] = jy; rval[ir] = v
end
iy += 1
end
end
@inbounds while ix <= mx
v = f(xnzval[ix], zero(Ty))
if v != zero(v)
if _isnotzero(v)
ir += 1; rind[ir] = xnzind[ix]; rval[ir] = v
end
ix += 1
end
@inbounds while iy <= my
v = f(zero(Tx), ynzval[iy])
if v != zero(v)
if _isnotzero(v)
ir += 1; rind[ir] = ynzind[iy]; rval[ir] = v
end
iy += 1
Expand All @@ -1392,16 +1397,28 @@ end
# definition of a few known broadcasted/mapped binary functions — all others defer to HigherOrderFunctions

# Broadcast entry point for the known binary functions: operands of equal length go
# through `_binarymap` with the given `mode`; operands of different length are handed to
# `HigherOrderFns._diffshape_broadcast` instead.
function _bcast_binary_map(f, x, y, mode)
    if length(x) == length(y)
        return _binarymap(f, x, y, mode)
    else
        return HigherOrderFns._diffshape_broadcast(f, x, y)
    end
end
# `_getmode(f, Tx, Ty)` picks the `mode` handed to `_binarymap` for the function `f` and
# the element types `Tx`, `Ty` of the two sparse vectors. The modes, as described on the
# `_binarymap_mode_N!` kernels:
#   0: f(nz, nz) -> nz,   f(z, nz) -> z,    f(nz, z) -> z
#   1: f(nz, nz) -> z/nz, f(z, nz) -> nz,   f(nz, z) -> nz
#   2: f(nz, nz) -> z/nz, f(z, nz) -> z/nz, f(nz, z) -> z/nz

# addition and subtraction use mode 1 for every pair of element types
_getmode(::typeof(-), ::Type, ::Type) = 1
_getmode(::typeof(+), ::Type, ::Type) = 1

# min and max use mode 2 for every pair of element types
_getmode(::typeof(max), ::Type, ::Type) = 2
_getmode(::typeof(min), ::Type, ::Type) = 2

# multiplication uses mode 0, unless at least one element type is a `Union{Missing, T}`,
# in which case mode 2 is selected; the two-sided method covers the case where both
# element types are such unions
_getmode(::typeof(*), ::Type{Union{Missing, A}}, ::Type{Union{Missing, B}}) where {A, B} = 2
_getmode(::typeof(*), ::Type, ::Type{Union{Missing, B}}) where {B} = 2
_getmode(::typeof(*), ::Type{Union{Missing, A}}, ::Type) where {A} = 2
_getmode(::typeof(*), ::Type, ::Type) = 0
for (fun, mode) in [(:+, 1), (:-, 1), (:*, 0), (:min, 2), (:max, 2)]
fun in (:+, :-) && @eval begin
# Addition and subtraction can be defined directly on the arrays (without map/broadcast)
$(fun)(x::AbstractSparseVector, y::AbstractSparseVector) = _binarymap($(fun), x, y, $mode)
end
@eval begin
map(::typeof($fun), x::AbstractSparseVector, y::AbstractSparseVector) = _binarymap($fun, x, y, $mode)
map(::typeof($fun), x::SparseVector, y::SparseVector) = _binarymap($fun, x, y, $mode)
broadcast(::typeof($fun), x::AbstractSparseVector, y::AbstractSparseVector) = _bcast_binary_map($fun, x, y, $mode)
broadcast(::typeof($fun), x::SparseVector, y::SparseVector) = _bcast_binary_map($fun, x, y, $mode)
map(::typeof($fun), x::AbstractSparseVector{Tx}, y::AbstractSparseVector{Ty}) where {Tx, Ty} =
_binarymap($fun, x, y, _getmode($fun, Tx, Ty))
map(::typeof($fun), x::SparseVector{Tx}, y::SparseVector{Ty}) where {Tx, Ty} =
_binarymap($fun, x, y, _getmode($fun, Tx, Ty))
broadcast(::typeof($fun), x::AbstractSparseVector{Tx}, y::AbstractSparseVector{Ty}) where {Tx, Ty} =
_bcast_binary_map($fun, x, y, _getmode($fun, Tx, Ty))
broadcast(::typeof($fun), x::SparseVector{Tx}, y::SparseVector{Ty}) where {Tx, Ty} =
_bcast_binary_map($fun, x, y, _getmode($fun, Tx, Ty))
end
end

Expand Down Expand Up @@ -1631,7 +1648,7 @@ function mul!(y::AbstractVector, A::_StridedOrTriangularMatrix, x::AbstractSpars
xnzval = nonzeros(x)
@inbounds for i = 1:length(xnzind)
v = xnzval[i]
if v != zero(v)
if _isnotzero(v)
j = xnzind[i]
αv = v * α
for r = 1:m
Expand Down Expand Up @@ -1765,7 +1782,7 @@ function mul!(y::AbstractVector, A::AbstractSparseMatrixCSC, x::AbstractSparseVe

@inbounds for i = 1:length(xnzind)
v = xnzval[i]
if v != zero(v)
if _isnotzero(v)
αv = v * α
j = xnzind[i]
for r = Acolptr[j]:(Acolptr[j+1]-1)
Expand Down
28 changes: 28 additions & 0 deletions test/sparsematrix_ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,34 @@ do33 = fill(1.,3)
end
end
end
# Checks that +, -, *, min and max applied to sparse matrices whose element type is a
# `Union` with `Missing` give the same non-missing entries (compared via `skipmissing`)
# as the same operations applied to the dense `Array` copies.
@testset "binary operations on sparse matrices with union eltype" begin
# small fixed case built from the values 1, `missing` and 0
A = sparse([1,2,1], [1,1,2], Union{Int, Missing}[1, missing, 0])
for fun in (+, -, *, min, max)
# only + and - are applied directly to the matrices; every function goes through
# `map` and `broadcast` below
if fun in (+, -)
@test collect(skipmissing(Array(fun(A, A)))) == collect(skipmissing(Array(fun(Array(A), Array(A)))))
end
@test collect(skipmissing(Array(map(fun, A, A)))) == collect(skipmissing(map(fun, Array(A), Array(A))))
@test collect(skipmissing(Array(broadcast(fun, A, A)))) == collect(skipmissing(broadcast(fun, Array(A), Array(A))))
end
# random 20x10 operands: `b` (density 0.2) and `C` (density 0.9) each get `missing`
# written at three randomly drawn linear indices; `D` is an all-zero matrix with three
# such `missing` writes; `E` is an all-zero matrix without any `missing`
b = convert(SparseMatrixCSC{Union{Float64, Missing}}, sprandn(Float64, 20, 10, 0.2)); b[rand(1:200, 3)] .= missing
C = convert(SparseMatrixCSC{Union{Float64, Missing}}, sprandn(Float64, 20, 10, 0.9)); C[rand(1:200, 3)] .= missing
CA = Array(C)
D = convert(SparseMatrixCSC{Union{Float64, Missing}}, spzeros(Float64, 20, 10)); D[rand(1:200, 3)] .= missing
E = convert(SparseMatrixCSC{Union{Float64, Missing}}, spzeros(Float64, 20, 10))
# every left operand is paired with `C`, and each pairing is tested in both argument
# orders
for B in (b, C, D, E), fun in (+, -, *, min, max)
BA = Array(B)
# reverse order for opposite nonzeroinds-structure
if fun in (+, -)
@test collect(skipmissing(Array(fun(B, C)))) == collect(skipmissing(Array(fun(BA, CA))))
@test collect(skipmissing(Array(fun(C, B)))) == collect(skipmissing(Array(fun(CA, BA))))
end
@test collect(skipmissing(Array(map(fun, B, C)))) == collect(skipmissing(map(fun, BA, CA)))
@test collect(skipmissing(Array(map(fun, C, B)))) == collect(skipmissing(map(fun, CA, BA)))
@test collect(skipmissing(Array(broadcast(fun, B, C)))) == collect(skipmissing(broadcast(fun, BA, CA)))
@test collect(skipmissing(Array(broadcast(fun, C, B)))) == collect(skipmissing(broadcast(fun, CA, BA)))
end
end

end

let
Expand Down
Loading

0 comments on commit 7bf0c5c

Please sign in to comment.