From 54995d6c82d040d25380c0d5c6c8abc659baf52e Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 10 May 2022 14:06:48 +0200 Subject: [PATCH 1/6] Add rules for `det` and `logdet` of `Cholesky` --- Project.toml | 2 +- src/rulesets/LinearAlgebra/dense.jl | 4 ++-- src/rulesets/LinearAlgebra/factorization.jl | 21 ++++++++++++++++++++ test/rulesets/LinearAlgebra/factorization.jl | 19 ++++++++++++++++++ 4 files changed, 43 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 881f63716..68369c095 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.29.0" +version = "1.30.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index ebd0c07ce..9a8cd7981 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -118,7 +118,7 @@ end ##### `det` ##### -function frule((_, Δx), ::typeof(det), x::AbstractMatrix) +function frule((_, Δx), ::typeof(det), x::StridedMatrix{<:Number}) Ω = det(x) # TODO Performance optimization: probably there is an efficent # way to compute this trace without during the full compution within @@ -126,7 +126,7 @@ function frule((_, Δx), ::typeof(det), x::AbstractMatrix) end frule((_, Δx), ::typeof(det), x::Number) = (det(x), Δx) -function rrule(::typeof(det), x::Union{Number, AbstractMatrix}) +function rrule(::typeof(det), x::Union{Number, StridedMatrix{<:Number}}) Ω = det(x) function det_pullback(ΔΩ) ∂x = x isa Number ? ΔΩ : inv(x)' * dot(Ω, ΔΩ) diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index 41bc24b8c..dd2cca430 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -551,3 +551,24 @@ function rrule(::typeof(getproperty), F::T, x::Symbol) where {T <: Cholesky} end return getproperty(F, x), getproperty_cholesky_pullback end + +# `det` and `logdet` for `Cholesky` +function rrule(::typeof(det), C::Cholesky) + y = det(C) + s = conj!((2 * y) ./ _diag_view(C.factors)) + function det_Cholesky_pullback(ȳ) + ΔC = Tangent{typeof(C)}(; factors=Diagonal(ȳ .* s)) + return NoTangent(), ΔC + end + return y, det_Cholesky_pullback +end + +function rrule(::typeof(logdet), C::Cholesky) + y = logdet(C) + s = conj!((2 * one(eltype(C))) ./ _diag_view(C.factors)) + function logdet_Cholesky_pullback(ȳ) + ΔC = Tangent{typeof(C)}(; factors=Diagonal(ȳ .* s)) + return NoTangent(), ΔC + end + return y, logdet_Cholesky_pullback +end diff --git a/test/rulesets/LinearAlgebra/factorization.jl b/test/rulesets/LinearAlgebra/factorization.jl index 97b0cfa0f..446b3ed66 100644 --- a/test/rulesets/LinearAlgebra/factorization.jl +++ b/test/rulesets/LinearAlgebra/factorization.jl @@ -432,5 +432,24 @@ end ΔX_symmetric = chol_back_sym(Δ)[2] @test sym_back(ΔX_symmetric)[2] ≈ dX_pullback(Δ)[2] end + + @testset "det and logdet (uplo=$p)" for p in ['U', 'L'] + @testset "$op" for op in (det, logdet) + @testset "$T" for T in (Float64, ComplexF64) + n = 5 + # rand (not randn) so det will be postive, so logdet will be defined + A = 3 * rand(T, (n, n)) + X = Cholesky((p === 'U' ? UpperTriangular : LowerTriangular)(A * A' + I)) + X̄_acc = Tangent{typeof(X)}(; factors=Diagonal(randn(T, n))) # sensitivity is always a diagonal + test_rrule(op, X ⊢ X̄_acc) + + # return type + _, op_pullback = rrule(op, X) + X̄ = op_pullback(2.7)[2] + @test X̄ isa Tangent{<:Cholesky} + @test X̄.factors isa Diagonal + end + end + end end end From ac246e70d051fea15bcbbf57d3435973ee6fd87c Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 10 May 2022 16:13:20 +0200 Subject: [PATCH 2/6] Fix tests on Julia 1.6 --- test/rulesets/LinearAlgebra/factorization.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/rulesets/LinearAlgebra/factorization.jl b/test/rulesets/LinearAlgebra/factorization.jl index 446b3ed66..3f3f495e6 100644 --- a/test/rulesets/LinearAlgebra/factorization.jl +++ b/test/rulesets/LinearAlgebra/factorization.jl @@ -433,13 +433,13 @@ end @test sym_back(ΔX_symmetric)[2] ≈ dX_pullback(Δ)[2] end - @testset "det and logdet (uplo=$p)" for p in ['U', 'L'] + @testset "det and logdet (uplo=$p)" for p in (:U, :L) @testset "$op" for op in (det, logdet) @testset "$T" for T in (Float64, ComplexF64) n = 5 # rand (not randn) so det will be postive, so logdet will be defined A = 3 * rand(T, (n, n)) - X = Cholesky((p === 'U' ? UpperTriangular : LowerTriangular)(A * A' + I)) + X = Cholesky(A * A' + I, p, 0) X̄_acc = Tangent{typeof(X)}(; factors=Diagonal(randn(T, n))) # sensitivity is always a diagonal test_rrule(op, X ⊢ X̄_acc) From 40e0890f7f7718c0245fb0dc740484baa980c5f2 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 12 May 2022 22:08:27 +0200 Subject: [PATCH 3/6] Revert restricting `det` to `StridedMatrix` --- src/rulesets/LinearAlgebra/dense.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index 9a8cd7981..ebd0c07ce 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -118,7 +118,7 @@ end ##### `det` ##### -function frule((_, Δx), ::typeof(det), x::StridedMatrix{<:Number}) +function frule((_, Δx), ::typeof(det), x::AbstractMatrix) Ω = det(x) # TODO Performance optimization: probably there is an efficent # way to compute this trace without during the full compution within @@ -126,7 +126,7 @@ function frule((_, Δx), ::typeof(det), x::StridedMatrix{<:Number}) end frule((_, Δx), ::typeof(det), x::Number) = (det(x), Δx) -function rrule(::typeof(det), x::Union{Number, StridedMatrix{<:Number}}) +function rrule(::typeof(det), x::Union{Number, AbstractMatrix}) Ω = det(x) function det_pullback(ΔΩ) ∂x = x isa Number ? ΔΩ : inv(x)' * dot(Ω, ΔΩ) From e7e929b2d207ef180561ea7ac430dc999cfdaaf0 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 12 May 2022 22:09:06 +0200 Subject: [PATCH 4/6] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 68369c095..1fff7120b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.30.0" +version = "1.31.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" From 1482a95e3b6adfc2886c572ea9d36ae0ac223bad Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 18 May 2022 16:21:23 +0200 Subject: [PATCH 5/6] Handle Cholesky factorizations of singular matrices --- src/rulesets/LinearAlgebra/factorization.jl | 15 +++++++++++---- test/rulesets/LinearAlgebra/factorization.jl | 11 +++++++++++ 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index dd2cca430..4d87e7f53 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -555,19 +555,26 @@ end # `det` and `logdet` for `Cholesky` function rrule(::typeof(det), C::Cholesky) y = det(C) - s = conj!((2 * y) ./ _diag_view(C.factors)) + diagF = _diag_view(C.factors) function det_Cholesky_pullback(ȳ) - ΔC = Tangent{typeof(C)}(; factors=Diagonal(ȳ .* s)) + ΔF = Diagonal(_x_divide_conj_y.(2 * ȳ * conj(y), diagF)) + ΔC = Tangent{typeof(C)}(; factors=ΔF) return NoTangent(), ΔC end return y, det_Cholesky_pullback end +# compute `x / conj(y)`, handling `x = y = 0` +function _x_divide_conj_y(x, y) + z = x / conj(y) + # in our case `iszero(x)` implies `iszero(y)` + return iszero(x) ? zero(z) : z +end function rrule(::typeof(logdet), C::Cholesky) y = logdet(C) - s = conj!((2 * one(eltype(C))) ./ _diag_view(C.factors)) + diagF = _diag_view(C.factors) function logdet_Cholesky_pullback(ȳ) - ΔC = Tangent{typeof(C)}(; factors=Diagonal(ȳ .* s)) + ΔC = Tangent{typeof(C)}(; factors=Diagonal((2 * ȳ) ./ conj.(diagF))) return NoTangent(), ΔC end return y, logdet_Cholesky_pullback diff --git a/test/rulesets/LinearAlgebra/factorization.jl b/test/rulesets/LinearAlgebra/factorization.jl index 3f3f495e6..4f2be2e41 100644 --- a/test/rulesets/LinearAlgebra/factorization.jl +++ b/test/rulesets/LinearAlgebra/factorization.jl @@ -450,6 +450,17 @@ end @test X̄.factors isa Diagonal end end + + @testset "singular ($T)" for T in (Float64, ComplexF64) + n = 5 + L = LowerTriangular(randn(T, (n, n))) + L[1, 1] = zero(T) + X = cholesky(L * L'; check=false) + detX, det_pullback = rrule(det, X) + ΔX = det_pullback(rand())[2] + @test iszero(detX) + @test ΔX.factors isa Diagonal && all(iszero, ΔX.factors) + end end end end From d831cd49fc8a7669e8e10930bf3214da61ba22dc Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 19 May 2022 00:05:27 +0200 Subject: [PATCH 6/6] Handle zero co-tangents --- src/rulesets/LinearAlgebra/factorization.jl | 14 +++++++------- test/rulesets/LinearAlgebra/factorization.jl | 4 ++++ 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index 4d87e7f53..e47d4b2c2 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -563,19 +563,19 @@ function rrule(::typeof(det), C::Cholesky) end return y, det_Cholesky_pullback end -# compute `x / conj(y)`, handling `x = y = 0` -function _x_divide_conj_y(x, y) - z = x / conj(y) - # in our case `iszero(x)` implies `iszero(y)` - return iszero(x) ? zero(z) : z -end function rrule(::typeof(logdet), C::Cholesky) y = logdet(C) diagF = _diag_view(C.factors) function logdet_Cholesky_pullback(ȳ) - ΔC = Tangent{typeof(C)}(; factors=Diagonal((2 * ȳ) ./ conj.(diagF))) + ΔC = Tangent{typeof(C)}(; factors=Diagonal(_x_divide_conj_y.(2 * ȳ, diagF))) return NoTangent(), ΔC end return y, logdet_Cholesky_pullback end + +# Return `x / conj(y)`, or a type-stable 0 if `iszero(x)` +function _x_divide_conj_y(x, y) + z = x / conj(y) + return iszero(x) ? zero(z) : z +end diff --git a/test/rulesets/LinearAlgebra/factorization.jl b/test/rulesets/LinearAlgebra/factorization.jl index 4f2be2e41..2f6c4599a 100644 --- a/test/rulesets/LinearAlgebra/factorization.jl +++ b/test/rulesets/LinearAlgebra/factorization.jl @@ -448,6 +448,10 @@ end X̄ = op_pullback(2.7)[2] @test X̄ isa Tangent{<:Cholesky} @test X̄.factors isa Diagonal + + # zero co-tangent + X̄ = op_pullback(0.0)[2] + @test all(iszero, X̄.factors) end end