Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid copy in chol_lower #144

Merged
merged 13 commits into from
Nov 4, 2021
8 changes: 5 additions & 3 deletions src/chol.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
chol_lower(a::Matrix) = cholesky(a).L

# always use the stored cholesky factor, not a copy
# Accessing a.L directly might involve an extra copy();
# instead, always use the stored Cholesky factor:
chol_lower(a::Cholesky) = a.uplo === 'L' ? a.L : a.U'
chol_upper(a::Cholesky) = a.uplo === 'U' ? a.U : a.L'

# For a dense Matrix, the following allows us to avoid the Adjoint wrapper:
chol_lower(a::Matrix) = cholesky(Hermitian(a, :L)).L

if HAVE_CHOLMOD
CholTypeSparse{T} = SuiteSparse.CHOLMOD.Factor{T}

Expand Down
20 changes: 20 additions & 0 deletions test/chol.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
using LinearAlgebra, PDMats
using PDMats: chol_lower, chol_upper

@testset "chol_lower" begin
A = rand(100, 100)
C = A'A
size_of_one_copy = sizeof(C)
@assert size_of_one_copy > 100 # ensure the matrix is large enough that few-byte allocations don't matter

chol_lower(C)
@test (@allocated chol_lower(C)) < 1.05 * size_of_one_copy # allow 5% overhead

for uplo in (:L, :U)
ch = cholesky(Symmetric(C, uplo))
chol_lower(ch)
@test (@allocated chol_lower(ch)) < 50 # allow small overhead for wrapper types
chol_upper(ch)
@test (@allocated chol_upper(ch)) < 50 # allow small overhead for wrapper types
end
end
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
include("testutils.jl")
tests = ["pdmtypes", "addition", "generics", "kron"]
tests = ["pdmtypes", "addition", "generics", "kron", "chol"]
println("Running tests ...")

for t in tests
Expand Down