Skip to content

Commit

Permalink
Extend copytrito! for a sparse source (#533)
Browse files Browse the repository at this point in the history
* Extend copytrito! for a sparse source

* Test for invalid uplo

* Assign nz[i] directly
  • Loading branch information
jishnub authored Apr 15, 2024
1 parent 33fbc75 commit 4606755
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/SparseArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import Base: +, -, *, \, /, &, |, xor, ==, zero, @propagate_inbounds
import LinearAlgebra: mul!, ldiv!, rdiv!, cholesky, adjoint!, diag, eigen, dot,
issymmetric, istril, istriu, lu, tr, transpose!, tril!, triu!, isbanded,
cond, diagm, factorize, ishermitian, norm, opnorm, lmul!, rmul!, tril, triu,
matprod_dest, generic_matvecmul!, generic_matmatmul!
matprod_dest, generic_matvecmul!, generic_matmatmul!, copytrito!

import Base: adjoint, argmin, argmax, Array, broadcast, circshift!, complex, Complex,
conj, conj!, convert, copy, copy!, copyto!, count, diff, findall, findmax, findmin,
Expand Down
23 changes: 23 additions & 0 deletions src/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4540,3 +4540,26 @@ function _reverse!(A::SparseMatrixCSC, dims::Tuple{Integer,Integer})
dims == (1,2) || dims == (2,1) || throw(ArgumentError("invalid dimension $dims in reverse"))
_reverse!(A, :)
end

function copytrito!(M::AbstractMatrix, S::AbstractSparseMatrixCSC, uplo::Char)
Base.require_one_based_indexing(M, S)
if !(uplo == 'U' || uplo == 'L')
throw(ArgumentError(lazy"uplo argument must be 'U' (upper) or 'L' (lower), got '$uplo'"))
end
m,n = size(S)
m1,n1 = size(M)
(m1 < m || n1 < n) && throw(DimensionMismatch("dest of size ($m1,$n1) should have at least the same number of rows and columns than src of size ($m,$n)"))

rv = rowvals(S)
nz = nonzeros(S)
for col in axes(S,2)
trirange = uplo == 'U' ? (1:min(col, size(S,1))) : (col:size(S,1))
fill!(view(M, trirange, col), zero(eltype(S)))
for i in nzrange(S, col)
row = rv[i]
(uplo == 'U' && row <= col) || (uplo == 'L' && row >= col) || continue
M[row, col] = nz[i]
end
end
return M
end
25 changes: 25 additions & 0 deletions test/sparsematrix_ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -601,4 +601,29 @@ end
end
end

@testset "copytrito!" begin
S = sparse([1,2,2,2,3], [1,1,2,2,4], [5, -19, 73, 12, -7])
M = fill(Inf, size(S))
copytrito!(M, S, 'U')
for col in axes(S, 2)
for row in 1:min(col, size(S,1))
@test M[row, col] == S[row, col]
end
for row in min(col, size(S,1))+1:size(S,1)
@test isinf(M[row, col])
end
end
M .= Inf
copytrito!(M, S, 'L')
for col in axes(S, 2)
for row in 1:col-1
@test isinf(M[row, col])
end
for row in col:size(S, 1)
@test M[row, col] == S[row, col]
end
end
@test_throws ArgumentError copytrito!(M, S, 'M')
end

end # module

0 comments on commit 4606755

Please sign in to comment.