From 064776aeadaf8afa424ae18b24f7f9f6d8cd00de Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Fri, 17 Jun 2022 22:30:11 +0200 Subject: [PATCH 1/4] Add rules for division by Cholesky --- src/rulesets/LinearAlgebra/factorization.jl | 35 +++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index f2f8d32d1..9bba0c155 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -578,3 +578,38 @@ function _x_divide_conj_y(x, y) z = x / conj(y) return iszero(x) ? zero(z) : z end + +# these rules exists because the primals mutates using `ldiv!` and `rdiv!` +function rrule(::typeof(\), A::Cholesky, B::AbstractVecOrMat{<:Union{Real,Complex}}) + U, getproperty_back = rrule(getproperty, A, :U) + Z = U' \ B + Y = U \ Z + project_B = ProjectTo(B) + function ldiv_Cholesky_AbsVecOrMat_pullback(ΔY) + ∂Z = U' \ ΔY + ∂B = U \ ∂Z + ∂A = Thunk() do + _, Ā = getproperty_back(-add!!(∂Z * Y', Z * ∂B')) + return Ā + end + return NoTangent(), ∂A, project_B(∂B) + end + return Y, ldiv_Cholesky_AbsVecOrMat_pullback +end + +function rrule(::typeof(/), B::AbstractMatrix{<:Union{Real,Complex}}, A::Cholesky) + U, getproperty_back = rrule(getproperty, A, :U) + Z = B / U + Y = Z / U' + project_B = ProjectTo(B) + function rdiv_AbstractMatrix_Cholesky_pullback(ΔY) + ∂Z = ΔY / U + ∂B = ∂Z / U' + ∂A = Thunk() do + _, Ā = getproperty_back(-add!!(∂Z' * Y, Z' * ∂B)) + return Ā + end + return NoTangent(), project_B(∂B), ∂A + end + return Y, rdiv_AbstractMatrix_Cholesky_pullback +end From c9c8485ece4649987ed5ef57e53be5e268e052c3 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Fri, 17 Jun 2022 22:30:19 +0200 Subject: [PATCH 2/4] Add tests for new rules --- test/rulesets/LinearAlgebra/factorization.jl | 24 ++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/test/rulesets/LinearAlgebra/factorization.jl b/test/rulesets/LinearAlgebra/factorization.jl index 2973ba893..17d034946 100644 --- a/test/rulesets/LinearAlgebra/factorization.jl +++ b/test/rulesets/LinearAlgebra/factorization.jl @@ -521,5 +521,29 @@ end @test ΔX.factors isa Diagonal && all(iszero, ΔX.factors) end end + + @testset "\\(::Cholesky, ::AbstractVecOrMat)" begin + n = 10 + for T in (Float64, ComplexF64), sz in (n, (n, 5)) + A = generate_well_conditioned_matrix(T, n) + C = cholesky(A) + B = randn(T, sz) + # because the rule calls the rrule for getproperty, its rrule is not + # completely type-inferrable + test_rrule(\, C, B; check_inferred=false) + end + end + + @testset "/(::AbstractMatrix, ::Cholesky)" begin + n = 10 + for T in (Float64, ComplexF64) + A = generate_well_conditioned_matrix(T, n) + C = cholesky(A) + B = randn(T, 5, n) + # because the rule calls the rrule for getproperty, its rrule is not + # completely type-inferrable + test_rrule(/, B, C; check_inferred=false) + end + end end end From 630c1cbedfab4c8b47b09ff2244d0f36d11c49ff Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Fri, 17 Jun 2022 22:30:51 +0200 Subject: [PATCH 3/4] Increment minor version number --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 2433a21c1..c688ce857 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.35.3" +version = "1.36.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" From d5ac0a7491a86efcece6b475d94be1681759de44 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 3 Aug 2022 12:59:47 +0200 Subject: [PATCH 4/4] Increment minor version number --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 2c22e3a0b..0c208ae89 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.40.0" +version = "1.41.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"