Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid copy in chol_lower #144

Merged
merged 13 commits into from
Nov 4, 2021
10 changes: 5 additions & 5 deletions src/chol.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
CholType{T,S<:AbstractMatrix} = Cholesky{T,S}
st-- marked this conversation as resolved.
Show resolved Hide resolved
chol_lower(a::Matrix) = cholesky(a).L
# Accessing a.L directly might involve an extra copy();
# instead, always use the stored Cholesky factor:
chol_lower(a::Cholesky) = a.uplo === 'L' ? a.L : a.U'
chol_upper(a::Cholesky) = a.uplo === 'U' ? a.U : a.L'

# always use the stored cholesky factor, not a copy
chol_lower(a::CholType) = a.uplo === 'L' ? a.L : a.U'
chol_upper(a::CholType) = a.uplo === 'U' ? a.U : a.L'
chol_lower(a::Matrix) = chol_lower(cholesky(a))

if HAVE_CHOLMOD
CholTypeSparse{T} = SuiteSparse.CHOLMOD.Factor{T}
Expand Down
20 changes: 20 additions & 0 deletions test/chol.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
using LinearAlgebra, PDMats
using PDMats: chol_lower

@testset "chol_lower" begin
A = rand(100, 100)
C = A'A
size_of_one_copy = sizeof(C)
@assert size_of_one_copy > 100 # ensure the matrix is large enough that few-byte allocations don't matter

chol_lower(C)
@test (@allocated chol_lower(C)) < 1.05 * size_of_one_copy # allow 5% overhead

for uplo in (:L, :U)
ch = cholesky(Symmetric(C, uplo))
chol_lower(ch)
@test (@allocated chol_lower(ch)) < 50 # allow small overhead for wrapper types
chol_upper(ch)
@test (@allocated chol_upper(ch)) < 50 # allow small overhead for wrapper types
@test
end
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
include("testutils.jl")
tests = ["pdmtypes", "addition", "generics", "kron"]
tests = ["pdmtypes", "addition", "generics", "kron", "chol"]
println("Running tests ...")

for t in tests
Expand Down