diff --git a/Project.toml b/Project.toml index bdb6f327b..a35854934 100644 --- a/Project.toml +++ b/Project.toml @@ -25,7 +25,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] AbstractFFTs = "0.5, 1.0" -ChainRules = "1.33" +ChainRules = "1.35.3" ChainRulesCore = "1.9" ChainRulesTestUtils = "1" DiffRules = "1.4" diff --git a/src/lib/array.jl b/src/lib/array.jl index 548159766..bbe13669d 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -564,35 +564,6 @@ end @adjoint Matrix(A::LinearAlgebra.HermOrSym{T,S}) where {T,S} = Matrix(A), Δ -> (convert(S, Δ),) -@adjoint function cholesky(Σ::Real) - C = cholesky(Σ) - return C, Δ::NamedTuple->(Δ.factors[1, 1] / (2 * C.U[1, 1]),) -end - -@adjoint function cholesky(Σ::Diagonal; check = true) - C = cholesky(Σ, check = check) - return C, Δ::NamedTuple -> begin - issuccess(C) || throw(PosDefException(C.info)) - return Diagonal(diag(Δ.factors) .* inv.(2 .* C.factors.diag)), nothing - end -end - -# Implementation due to Seeger, Matthias, et al. "Auto-differentiating linear algebra." -@adjoint function cholesky(Σ::Union{StridedMatrix, Symmetric{<:Real, <:StridedMatrix}}; check = true) - C = cholesky(Σ, check = check) - return C, function(Δ::NamedTuple) - issuccess(C) || throw(PosDefException(C.info)) - U, Ū = C.U, Δ.factors - Σ̄ = similar(U.data) - Σ̄ = mul!(Σ̄, Ū, U') - Σ̄ = copytri!(Σ̄, 'U') - Σ̄ = ldiv!(U, Σ̄) - Σ̄ = BLAS.trsm!('R', 'U', 'T', 'N', one(eltype(Σ)), U.data, Σ̄) - Σ̄[diagind(Σ̄)] ./= 2 - return (UpperTriangular(Σ̄),) - end -end - @adjoint function lyap(A::AbstractMatrix, C::AbstractMatrix) X = lyap(A, C) return X, function (X̄) diff --git a/test/gradcheck.jl b/test/gradcheck.jl index ac0dd28bf..182f2b666 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -654,7 +654,6 @@ end g(X) = cholesky(X * X' + I) @test Zygote.pullback(g, X)[2]((factors=LowerTriangular(X),))[1] ≈ Zygote.pullback(g, X)[2]((factors=Matrix(LowerTriangular(X)),))[1] - @test_throws PosDefException Zygote.pullback(X -> cholesky(X, check = false), X)[2]((factors=X,)) # https://github.com/FluxML/Zygote.jl/issues/932 @test gradcheck(rand(5, 5), rand(5)) do A, x @@ -820,6 +819,32 @@ end @test back′(C̄)[1] isa Diagonal @test diag(back′(C̄)[1]) ≈ diag(back(C̄)[1]) end + @testset "cholesky - Hermitian{Complex}" begin + rng, N = MersenneTwister(123456), 3 + A = randn(rng, Complex{Float64}, N, N) + H = Hermitian(A * A' + I) + Hmat = Matrix(H) + y, back = Zygote.pullback(cholesky, Hmat) + y′, back′ = Zygote.pullback(cholesky, H) + C̄ = (factors=randn(rng, N, N),) + @test only(back′(C̄)) isa Hermitian + # gradtest does not support complex gradients, even though the pullback exists + d = only(back(C̄)) + d′ = only(back′(C̄)) + @test (d + d')/2 ≈ d′ + end + @testset "cholesky - Hermitian{Real}" begin + rng, N = MersenneTwister(123456), 3 + A = randn(rng, N, N) + H = Hermitian(A * A' + I) + Hmat = Matrix(H) + y, back = Zygote.pullback(cholesky, Hmat) + y′, back′ = Zygote.pullback(cholesky, H) + C̄ = (factors=randn(rng, N, N),) + @test back′(C̄)[1] isa Hermitian + @test gradtest(B->cholesky(Hermitian(B)).U, Hmat) + @test gradtest(B->logdet(cholesky(Hermitian(B))), Hmat) + end end @testset "lyap" begin