Skip to content

Commit

Permalink
added diagonal-sparse multiplication (#564)
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel Karrasch <[email protected]>
  • Loading branch information
matbesancon and dkarrasch authored Oct 29, 2024
1 parent 8f02b7f commit 33491e0
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 0 deletions.
17 changes: 17 additions & 0 deletions src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,23 @@ function _A_mul_Bt_or_Bc!(tfun::Function, C::StridedMatrix, A::AbstractMatrix, B
C
end

function *(A::Diagonal, b::AbstractSparseVector)
if size(A, 2) != length(b)
throw(
DimensionMismatch(lazy"The dimension of the matrix A $(size(A)) and of the vector b $(length(b))")
)
end
T = promote_eltype(A, b)
res = similar(b, T)
nzind_b = nonzeroinds(b)
nzval_b = nonzeros(b)
nzval_res = nonzeros(res)
for idx in eachindex(nzind_b)
nzval_res[idx] = A.diag[nzind_b[idx]] * nzval_b[idx]
end
return res
end

# Sparse matrix multiplication as described in [Gustavson, 1978]:
# http://dl.acm.org/citation.cfm?id=355796

Expand Down
15 changes: 15 additions & 0 deletions test/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,21 @@ end
end
end

@testset "diagonal - sparse vector mutliplication" begin
for _ in 1:10
b = spzeros(10)
b[1:3] .= 1:3
A = Diagonal(randn(10))
@test norm(A * b - A * Vector(b)) <= 10eps()
@test norm(A * b - Array(A) * b) <= 10eps()
Ac = Diagonal(randn(Complex{Float64}, 10))
@test norm(Ac * b - Ac * Vector(b)) <= 10eps()
@test norm(Ac * b - Array(Ac) * b) <= 10eps()
@test_throws DimensionMismatch A * [b; 1]
@test_throws DimensionMismatch A * b[1:end-1]
end
end

@testset "sparse matrix * BitArray" begin
A = sprand(5,5,0.3)
MA = Array(A)
Expand Down

0 comments on commit 33491e0

Please sign in to comment.