From 30c32a45c09bbced44fbfee86ba1b222adedd3da Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sun, 14 Apr 2024 18:48:32 +0530 Subject: [PATCH 1/3] Extend copytrito! for a sparse source --- src/SparseArrays.jl | 2 +- src/sparsematrix.jl | 24 ++++++++++++++++++++++++ test/sparsematrix_ops.jl | 24 ++++++++++++++++++++++++ 3 files changed, 49 insertions(+), 1 deletion(-) diff --git a/src/SparseArrays.jl b/src/SparseArrays.jl index ff554c8e..84690b95 100644 --- a/src/SparseArrays.jl +++ b/src/SparseArrays.jl @@ -18,7 +18,7 @@ import Base: +, -, *, \, /, &, |, xor, ==, zero, @propagate_inbounds import LinearAlgebra: mul!, ldiv!, rdiv!, cholesky, adjoint!, diag, eigen, dot, issymmetric, istril, istriu, lu, tr, transpose!, tril!, triu!, isbanded, cond, diagm, factorize, ishermitian, norm, opnorm, lmul!, rmul!, tril, triu, - matprod_dest, generic_matvecmul!, generic_matmatmul! + matprod_dest, generic_matvecmul!, generic_matmatmul!, copytrito! import Base: adjoint, argmin, argmax, Array, broadcast, circshift!, complex, Complex, conj, conj!, convert, copy, copy!, copyto!, count, diff, findall, findmax, findmin, diff --git a/src/sparsematrix.jl b/src/sparsematrix.jl index b9637eda..363abff1 100644 --- a/src/sparsematrix.jl +++ b/src/sparsematrix.jl @@ -4540,3 +4540,27 @@ function _reverse!(A::SparseMatrixCSC, dims::Tuple{Integer,Integer}) dims == (1,2) || dims == (2,1) || throw(ArgumentError("invalid dimension $dims in reverse")) _reverse!(A, :) end + +function copytrito!(M::AbstractMatrix, S::AbstractSparseMatrixCSC, uplo::Char) + Base.require_one_based_indexing(M, S) + if !(uplo == 'U' || uplo == 'L') + throw(ArgumentError(lazy"uplo argument must be 'U' (upper) or 'L' (lower), got '$uplo'")) + end + m,n = size(S) + m1,n1 = size(M) + (m1 < m || n1 < n) && throw(DimensionMismatch("dest of size ($m1,$n1) should have at least the same number of rows and columns than src of size ($m,$n)")) + + rv = rowvals(S) + nz = nonzeros(S) + for col in axes(S,2) + trirange = uplo == 'U' ? (1:min(col, size(S,1))) : (col:size(S,1)) + fill!(view(M, trirange, col), zero(eltype(S))) + for i in nzrange(S, col) + row = rv[i] + (uplo == 'U' && row <= col) || (uplo == 'L' && row >= col) || continue + v = nz[i] + M[row, col] = v + end + end + return M +end diff --git a/test/sparsematrix_ops.jl b/test/sparsematrix_ops.jl index 829baf89..2d3562bd 100644 --- a/test/sparsematrix_ops.jl +++ b/test/sparsematrix_ops.jl @@ -601,4 +601,28 @@ end end end +@testset "copytrito!" begin + S = sparse([1,2,2,2,3], [1,1,2,2,4], [5, -19, 73, 12, -7]) + M = fill(Inf, size(S)) + copytrito!(M, S, 'U') + for col in axes(S, 2) + for row in 1:min(col, size(S,1)) + @test M[row, col] == S[row, col] + end + for row in min(col, size(S,1))+1:size(S,1) + @test isinf(M[row, col]) + end + end + M .= Inf + copytrito!(M, S, 'L') + for col in axes(S, 2) + for row in 1:col-1 + @test isinf(M[row, col]) + end + for row in col:size(S, 1) + @test M[row, col] == S[row, col] + end + end +end + end # module From f921966a6e9bc91d360c22d1017e97eb17bafb2e Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 15 Apr 2024 11:52:31 +0530 Subject: [PATCH 2/3] Test for invalid uplo --- test/sparsematrix_ops.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/sparsematrix_ops.jl b/test/sparsematrix_ops.jl index 2d3562bd..94e268e0 100644 --- a/test/sparsematrix_ops.jl +++ b/test/sparsematrix_ops.jl @@ -623,6 +623,7 @@ end @test M[row, col] == S[row, col] end end + @test_throws ArgumentError copytrito!(M, S, 'M') end end # module From 16076cda80cd378f0c29eff881286c1da898222b Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 15 Apr 2024 18:36:52 +0530 Subject: [PATCH 3/3] Assign nz[i] directly --- src/sparsematrix.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/sparsematrix.jl b/src/sparsematrix.jl index 363abff1..f94318b2 100644 --- a/src/sparsematrix.jl +++ b/src/sparsematrix.jl @@ -4558,8 +4558,7 @@ function copytrito!(M::AbstractMatrix, S::AbstractSparseMatrixCSC, uplo::Char) for i in nzrange(S, col) row = rv[i] (uplo == 'U' && row <= col) || (uplo == 'L' && row >= col) || continue - v = nz[i] - M[row, col] = v + M[row, col] = nz[i] end end return M