Skip to content

Commit

Permalink
Merge pull request #1114 from st--/st/remove_cholesky_adjoint
Browse files Browse the repository at this point in the history
remove `@adjoint function cholesky`
  • Loading branch information
ToucheSir authored Jun 20, 2022
2 parents a4d0ad4 + f7203ff commit c4b4fa9
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 31 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
AbstractFFTs = "0.5, 1.0"
ChainRules = "1.33"
ChainRules = "1.35.3"
ChainRulesCore = "1.9"
ChainRulesTestUtils = "1"
DiffRules = "1.4"
Expand Down
29 changes: 0 additions & 29 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -564,35 +564,6 @@ end
@adjoint Matrix(A::LinearAlgebra.HermOrSym{T,S}) where {T,S} = Matrix(A),
Δ -> (convert(S, Δ),)

@adjoint function cholesky::Real)
C = cholesky(Σ)
return C, Δ::NamedTuple->.factors[1, 1] / (2 * C.U[1, 1]),)
end

@adjoint function cholesky::Diagonal; check = true)
C = cholesky(Σ, check = check)
return C, Δ::NamedTuple -> begin
issuccess(C) || throw(PosDefException(C.info))
return Diagonal(diag.factors) .* inv.(2 .* C.factors.diag)), nothing
end
end

# Implementation due to Seeger, Matthias, et al. "Auto-differentiating linear algebra."
@adjoint function cholesky::Union{StridedMatrix, Symmetric{<:Real, <:StridedMatrix}}; check = true)
C = cholesky(Σ, check = check)
return C, function::NamedTuple)
issuccess(C) || throw(PosDefException(C.info))
U, Ū = C.U, Δ.factors
Σ̄ = similar(U.data)
Σ̄ = mul!(Σ̄, Ū, U')
Σ̄ = copytri!(Σ̄, 'U')
Σ̄ = ldiv!(U, Σ̄)
Σ̄ = BLAS.trsm!('R', 'U', 'T', 'N', one(eltype(Σ)), U.data, Σ̄)
Σ̄[diagind(Σ̄)] ./= 2
return (UpperTriangular(Σ̄),)
end
end

@adjoint function lyap(A::AbstractMatrix, C::AbstractMatrix)
X = lyap(A, C)
return X, function (X̄)
Expand Down
27 changes: 26 additions & 1 deletion test/gradcheck.jl
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,6 @@ end
g(X) = cholesky(X * X' + I)
@test Zygote.pullback(g, X)[2]((factors=LowerTriangular(X),))[1]
Zygote.pullback(g, X)[2]((factors=Matrix(LowerTriangular(X)),))[1]
@test_throws PosDefException Zygote.pullback(X -> cholesky(X, check = false), X)[2]((factors=X,))

# https://github.com/FluxML/Zygote.jl/issues/932
@test gradcheck(rand(5, 5), rand(5)) do A, x
Expand Down Expand Up @@ -820,6 +819,32 @@ end
@test back′(C̄)[1] isa Diagonal
@test diag(back′(C̄)[1]) diag(back(C̄)[1])
end
@testset "cholesky - Hermitian{Complex}" begin
rng, N = MersenneTwister(123456), 3
A = randn(rng, Complex{Float64}, N, N)
H = Hermitian(A * A' + I)
Hmat = Matrix(H)
y, back = Zygote.pullback(cholesky, Hmat)
y′, back′ = Zygote.pullback(cholesky, H)
= (factors=randn(rng, N, N),)
@test only(back′(C̄)) isa Hermitian
# gradtest does not support complex gradients, even though the pullback exists
d = only(back(C̄))
d′ = only(back′(C̄))
@test (d + d')/2 d′
end
@testset "cholesky - Hermitian{Real}" begin
rng, N = MersenneTwister(123456), 3
A = randn(rng, N, N)
H = Hermitian(A * A' + I)
Hmat = Matrix(H)
y, back = Zygote.pullback(cholesky, Hmat)
y′, back′ = Zygote.pullback(cholesky, H)
= (factors=randn(rng, N, N),)
@test back′(C̄)[1] isa Hermitian
@test gradtest(B->cholesky(Hermitian(B)).U, Hmat)
@test gradtest(B->logdet(cholesky(Hermitian(B))), Hmat)
end
end

@testset "lyap" begin
Expand Down

0 comments on commit c4b4fa9

Please sign in to comment.