diff --git a/src/SparseArrays.jl b/src/SparseArrays.jl index ff554c8e..84690b95 100644 --- a/src/SparseArrays.jl +++ b/src/SparseArrays.jl @@ -18,7 +18,7 @@ import Base: +, -, *, \, /, &, |, xor, ==, zero, @propagate_inbounds import LinearAlgebra: mul!, ldiv!, rdiv!, cholesky, adjoint!, diag, eigen, dot, issymmetric, istril, istriu, lu, tr, transpose!, tril!, triu!, isbanded, cond, diagm, factorize, ishermitian, norm, opnorm, lmul!, rmul!, tril, triu, - matprod_dest, generic_matvecmul!, generic_matmatmul! + matprod_dest, generic_matvecmul!, generic_matmatmul!, copytrito! import Base: adjoint, argmin, argmax, Array, broadcast, circshift!, complex, Complex, conj, conj!, convert, copy, copy!, copyto!, count, diff, findall, findmax, findmin, diff --git a/src/sparsematrix.jl b/src/sparsematrix.jl index b9637eda..f94318b2 100644 --- a/src/sparsematrix.jl +++ b/src/sparsematrix.jl @@ -4540,3 +4540,26 @@ function _reverse!(A::SparseMatrixCSC, dims::Tuple{Integer,Integer}) dims == (1,2) || dims == (2,1) || throw(ArgumentError("invalid dimension $dims in reverse")) _reverse!(A, :) end + +function copytrito!(M::AbstractMatrix, S::AbstractSparseMatrixCSC, uplo::Char) + Base.require_one_based_indexing(M, S) + if !(uplo == 'U' || uplo == 'L') + throw(ArgumentError(lazy"uplo argument must be 'U' (upper) or 'L' (lower), got '$uplo'")) + end + m,n = size(S) + m1,n1 = size(M) + (m1 < m || n1 < n) && throw(DimensionMismatch("dest of size ($m1,$n1) should have at least the same number of rows and columns than src of size ($m,$n)")) + + rv = rowvals(S) + nz = nonzeros(S) + for col in axes(S,2) + trirange = uplo == 'U' ? (1:min(col, size(S,1))) : (col:size(S,1)) + fill!(view(M, trirange, col), zero(eltype(S))) + for i in nzrange(S, col) + row = rv[i] + (uplo == 'U' && row <= col) || (uplo == 'L' && row >= col) || continue + M[row, col] = nz[i] + end + end + return M +end diff --git a/test/sparsematrix_ops.jl b/test/sparsematrix_ops.jl index 829baf89..94e268e0 100644 --- a/test/sparsematrix_ops.jl +++ b/test/sparsematrix_ops.jl @@ -601,4 +601,29 @@ end end end +@testset "copytrito!" begin + S = sparse([1,2,2,2,3], [1,1,2,2,4], [5, -19, 73, 12, -7]) + M = fill(Inf, size(S)) + copytrito!(M, S, 'U') + for col in axes(S, 2) + for row in 1:min(col, size(S,1)) + @test M[row, col] == S[row, col] + end + for row in min(col, size(S,1))+1:size(S,1) + @test isinf(M[row, col]) + end + end + M .= Inf + copytrito!(M, S, 'L') + for col in axes(S, 2) + for row in 1:col-1 + @test isinf(M[row, col]) + end + for row in col:size(S, 1) + @test M[row, col] == S[row, col] + end + end + @test_throws ArgumentError copytrito!(M, S, 'M') +end + end # module