diff --git a/Project.toml b/Project.toml
index 392195ecd..213882544 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "ChainRules"
 uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
-version = "1.32.1"
+version = "1.33.0"
 
 [deps]
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl
index 41bc24b8c..e47d4b2c2 100644
--- a/src/rulesets/LinearAlgebra/factorization.jl
+++ b/src/rulesets/LinearAlgebra/factorization.jl
@@ -551,3 +551,44 @@ function rrule(::typeof(getproperty), F::T, x::Symbol) where {T <: Cholesky}
     end
     return getproperty(F, x), getproperty_cholesky_pullback
 end
+
+# `det` and `logdet` for `Cholesky`
+#
+# Both pullbacks return a `Tangent` whose only populated field is `factors`, held as a
+# `Diagonal` built from the diagonal entries of `C.factors`.
+function rrule(::typeof(det), C::Cholesky)
+    y = det(C)
+    diagF = _diag_view(C.factors)
+    function det_Cholesky_pullback(ȳ)
+        # Entry `i` is `2ȳ * conj(y) / conj(diagF[i])`. `_x_divide_conj_y` returns zero
+        # for a zero numerator (e.g. when `y` is zero), so such entries are zero.
+        ΔF = Diagonal(_x_divide_conj_y.(2 * ȳ * conj(y), diagF))
+        ΔC = Tangent{typeof(C)}(; factors=ΔF)
+        return NoTangent(), ΔC
+    end
+    return y, det_Cholesky_pullback
+end
+
+function rrule(::typeof(logdet), C::Cholesky)
+    y = logdet(C)
+    diagF = _diag_view(C.factors)
+    function logdet_Cholesky_pullback(ȳ)
+        # Entry `i` is `2ȳ / conj(diagF[i])`.
+        ΔF = Diagonal(_x_divide_conj_y.(2 * ȳ, diagF))
+        ΔC = Tangent{typeof(C)}(; factors=ΔF)
+        return NoTangent(), ΔC
+    end
+    return y, logdet_Cholesky_pullback
+end
+
+"""
+    _x_divide_conj_y(x, y)
+
+Return `x / conj(y)`, or a zero of the quotient's type if `iszero(x)`.
+
+The quotient is computed before the check so that both branches return the same type.
+"""
+function _x_divide_conj_y(x, y)
+    z = x / conj(y)
+    return iszero(x) ? zero(z) : z
+end
diff --git a/test/rulesets/LinearAlgebra/factorization.jl b/test/rulesets/LinearAlgebra/factorization.jl
index 97b0cfa0f..2f6c4599a 100644
--- a/test/rulesets/LinearAlgebra/factorization.jl
+++ b/test/rulesets/LinearAlgebra/factorization.jl
@@ -432,5 +432,40 @@ end
             ΔX_symmetric = chol_back_sym(Δ)[2]
             @test sym_back(ΔX_symmetric)[2] ≈ dX_pullback(Δ)[2]
         end
+
+        @testset "det and logdet (uplo=$p)" for p in (:U, :L)
+            @testset "$op" for op in (det, logdet)
+                @testset "$T" for T in (Float64, ComplexF64)
+                    n = 5
+                    # rand (not randn) so det will be positive, so logdet will be defined
+                    A = 3 * rand(T, (n, n))
+                    X = Cholesky(A * A' + I, p, 0)
+                    # sensitivity is always a diagonal
+                    X̄_acc = Tangent{typeof(X)}(; factors=Diagonal(randn(T, n)))
+                    test_rrule(op, X ⊢ X̄_acc)
+
+                    # return type
+                    _, op_pullback = rrule(op, X)
+                    X̄ = op_pullback(2.7)[2]
+                    @test X̄ isa Tangent{<:Cholesky}
+                    @test X̄.factors isa Diagonal
+
+                    # zero co-tangent
+                    X̄ = op_pullback(0.0)[2]
+                    @test all(iszero, X̄.factors)
+                end
+            end
+
+            @testset "singular ($T)" for T in (Float64, ComplexF64)
+                n = 5
+                L = LowerTriangular(randn(T, (n, n)))
+                L[1, 1] = zero(T)
+                X = cholesky(L * L'; check=false)
+                detX, det_pullback = rrule(det, X)
+                ΔX = det_pullback(rand())[2]
+                @test iszero(detX)
+                @test ΔX.factors isa Diagonal && all(iszero, ΔX.factors)
+            end
+        end
     end
 end