diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 65933b4237..efef9ab6e5 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -757,11 +757,63 @@ function EnzymeRules.reverse( return (nothing, nothing) end -# y = inv(A) B -# dY = inv(A) [ dB - dA y ] -# -> -# B(out) = inv(A) B(in) -# dB(out) = inv(A) [ dB(in) - dA B(out) ] +function EnzymeRules.forward(::Const{typeof(cholesky)}, RT::Type, A; kwargs...) + fact = cholesky(A.val; kwargs...) + if RT <: Const + return fact + else + N = width(RT) + + dA = if isa(A, Const) + ntuple(Val(N)) do i + Base.@_inline_meta + return zero(A.val) + end + else + N == 1 ? (A.dval,) : A.dval + end + + dfact = ntuple(Val(N)) do i + Base.@_inline_meta + return _cholesky_forward(fact, dA[i]) + end + + if (RT <: DuplicatedNoNeed) || (RT <: BatchDuplicatedNoNeed) + return dfact + elseif RT <: Duplicated + return Duplicated(fact, dfact[1]) + else + return BatchDuplicated(fact, dfact) + end + end +end + +function _cholesky_forward(C::Cholesky, Ȧ) + # Computes the cholesky forward mode update rule + # C.f. eq. 8 in https://arxiv.org/pdf/1602.07527.pdf + if C.uplo === 'U' + U = C.U + U̇ = Ȧ / U + ldiv!(U', U̇) + idx = diagind(U̇) + U̇[idx] ./= 2 + triu!(U̇) + rmul!(U̇, U) + U̇ .+= UpperTriangular(Ȧ)' .- Diagonal(Ȧ) # correction for unused triangle + return Cholesky(U̇, 'U', C.info) + else + L = C.L + L̇ = L \ Ȧ + rdiv!(L̇, L') + idx = diagind(L̇) + L̇[idx] ./= 2 + tril!(L̇) + lmul!(L, L̇) + L̇ .+= LowerTriangular(Ȧ)' .- Diagonal(Ȧ) # correction for unused triangle + return Cholesky(L̇, 'L', C.info) + end +end + function EnzymeRules.forward(func::Const{typeof(ldiv!)}, RT::Type{<:Union{Const,Duplicated,BatchDuplicated}}, fact::Annotation{<:Cholesky}, @@ -777,8 +829,7 @@ function EnzymeRules.forward(func::Const{typeof(ldiv!)}, U = fact.val.U ldiv!(L, B.val) - ntuple(Val(N)) do b - Base.@_inline_meta + for b in 1:N dB = N == 1 ? B.dval : B.dval[b] if !(fact isa Const) dL = N == 1 ? fact.dval.L : fact.dval[b].L @@ -786,16 +837,20 @@ function EnzymeRules.forward(func::Const{typeof(ldiv!)}, end ldiv!(L, dB) end - ldiv!(U, B.val) - dretvals = ntuple(Val(N)) do b - Base.@_inline_meta + for b in 1:N dB = N == 1 ? B.dval : B.dval[b] if !(fact isa Const) dU = N == 1 ? fact.dval.U : fact.dval[b].U mul!(dB, dU, B.val, -1, 1) end ldiv!(U, dB) + end + + ldiv!(U, B.val) + dretvals = ntuple(Val(N)) do b + Base.@_inline_meta + dB = N == 1 ? B.dval : B.dval[b] return dB end @@ -812,3 +867,161 @@ function EnzymeRules.forward(func::Const{typeof(ldiv!)}, end end end + +function EnzymeRules.augmented_primal(config, + func::Const{typeof(cholesky)}, + RT::Type, + A::Annotation{<:Union{Matrix, + LinearAlgebra.RealHermSymComplexHerm}}; + kwargs...) + fact = if EnzymeRules.needs_primal(config) || !(RT <: Const) + cholesky(A.val; kwargs...) + else + nothing + end + + fact_returned = EnzymeRules.needs_primal(config) ? fact : nothing + + # dfact would be a dense matrix, prepare buffer + dfact = if RT <: Const + nothing + else + if EnzymeRules.width(config) == 1 + Enzyme.make_zero(fact) + else + ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + return Enzyme.make_zero(fact) + end + end + end + + cache = isa(A, Const) ? nothing : (fact, dfact) + return EnzymeRules.AugmentedReturn(fact_returned, dfact, cache) +end + +function EnzymeRules.reverse(config, + ::Const{typeof(cholesky)}, + RT::Type, + cache, + A::Annotation{<:Union{Matrix, + LinearAlgebra.RealHermSymComplexHerm}}; + kwargs...) + if !(RT <: Const) && !isa(A, Const) + fact, dfact = cache + dAs = EnzymeRules.width(config) == 1 ? (A.dval,) : A.dval + dfacts = EnzymeRules.width(config) == 1 ? (dfact,) : dfact + + for (dA, dfact) in zip(dAs, dfacts) + _dA = dA isa LinearAlgebra.RealHermSymComplexHerm ? dA.data : dA + if _dA !== dfact.factors + Ā = _cholesky_pullback_shared_code(fact, dfact) + _dA .+= Ā + dfact.factors .= 0 + end + end + end + return (nothing,) +end + +# Adapted from ChainRules.jl +# MIT "Expat" License +# Copyright (c) 2018: Jarrett Revels. +# https://github.com/JuliaDiff/ChainRules.jl/blob/9f1817a22404259113e230bef149a54d379a660b/src/rulesets/LinearAlgebra/factorization.jl#L507-L528 +function _cholesky_pullback_shared_code(C, ΔC) + Δfactors = ΔC.factors + Ā = similar(C.factors) + if C.uplo === 'U' + U = C.U + Ū = ΔC.U + Ū = eltype(U) <: Real ? real(UpperTriangular(Δfactors)) : UpperTriangular(Δfactors) + mul!(Ā, Ū, U') + LinearAlgebra.copytri!(Ā, 'U', true) + eltype(Ā) <: Real || _realifydiag!(Ā) + ldiv!(U, Ā) + rdiv!(Ā, U') + Ā .+= tril!(ΔC.factors, -1)' # correction for unused triangle + triu!(Ā) + else # C.uplo === 'L' + L = C.L + L̄ = ΔC.L + L̄ = eltype(L) <: Real ? real(LowerTriangular(Δfactors)) : LowerTriangular(Δfactors) + mul!(Ā, L', L̄) + LinearAlgebra.copytri!(Ā, 'L', true) + eltype(Ā) <: Real || _realifydiag!(Ā) + rdiv!(Ā, L) + ldiv!(L', Ā) + Ā .+= triu!(ΔC.factors, 1)' # correction for unused triangle + tril!(Ā) + end + idx = diagind(Ā) + @views Ā[idx] .= real.(Ā[idx]) ./ 2 + return Ā +end + +function _realifydiag!(A) + for i in diagind(A) + @inbounds A[i] = real(A[i]) + end + return A +end + +function EnzymeRules.augmented_primal(config, + func::Const{typeof(ldiv!)}, + RT::Type{<:Union{Const,DuplicatedNoNeed,Duplicated, + BatchDuplicatedNoNeed, + BatchDuplicated}}, + A::Annotation{<:Cholesky}, + B::Union{Const,DuplicatedNoNeed,Duplicated, + BatchDuplicatedNoNeed,BatchDuplicated}; + kwargs...) + cache_B = if !isa(A, Const) && !isa(B, Const) + copy(B.val) + else + nothing + end + + cache_A = if !isa(B, Const) + EnzymeRules.overwritten(config)[2] ? copy(A.val) : A.val + else + nothing + end + + primal = EnzymeRules.needs_primal(config) ? B.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? B.dval : nothing + func.val(A.val, B.val; kwargs...) + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_B)) +end + +function EnzymeRules.reverse(config, + func::Const{typeof(ldiv!)}, + dret, + cache, + A::Annotation{<:Cholesky}, + B::Union{Const,DuplicatedNoNeed,Duplicated, + BatchDuplicatedNoNeed,BatchDuplicated}; + kwargs...) + if !isa(B, Const) + (cache_A, cache_B) = cache + U = cache_A.U + Z = isa(A, Const) ? nothing : U' \ cache_B + Y = isa(A, Const) ? nothing : U \ Z + for b in 1:EnzymeRules.width(config) + dB = EnzymeRules.width(config) == 1 ? B.dval : B.dval[b] + dZ = U' \ dB + func.val(cache_A, dB; kwargs...) + if !isa(A, Const) + ∂B = U \ dZ + Ā = -dZ * Y' - Z * ∂B' + dA = EnzymeRules.width(config) == 1 ? A.dval : A.dval[b] + if A.val.uplo === 'U' + dA.factors .+= UpperTriangular(Ā) + else + dA.factors .+= LowerTriangular(Ā') + end + + end + end + end + return (nothing, nothing) +end diff --git a/test/internal_rules.jl b/test/internal_rules.jl index 7cc5c07321..de492c37c8 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -270,299 +270,69 @@ end end @static if VERSION > v"1.8" -@testset "Cholesky" begin - function symmetric_definite(n :: Int=10) - α = one(Float64) - A = spdiagm(-1 => α * ones(n-1), 0 => 4 * ones(n), 1 => conj(α) * ones(n-1)) - b = A * Float64[1:n;] - return A, b - end - - function divdriver_NC(x, fact, b) - res = fact\b - x .= res - return nothing - end - - function ldivdriver_NC(x, fact, b) - ldiv!(fact,b) - x .= b - return nothing - end - - function divdriver(x, A, b) - fact = cholesky(A) - divdriver_NC(x, fact, b) - end - - function divdriver_herm(x, A, b) - fact = cholesky(Hermitian(A)) - divdriver_NC(x, fact, b) - end - - function divdriver_sym(x, A, b) - fact = cholesky(Symmetric(A)) - divdriver_NC(x, fact, b) - end - - function ldivdriver(x, A, b) - fact = cholesky(A) - ldivdriver_NC(x, fact, b) - end - - function ldivdriver_herm(x, A, b) - fact = cholesky(Hermitian(A)) - ldivdriver_NC(x, fact, b) - end - - function ldivdriver_sym(x, A, b) - fact = cholesky(Symmetric(A)) - ldivdriver_NC(x, fact, b) - end - - # Test forward - function fwdJdxdb(driver, A, b) - adJ = zeros(size(A)) - dA = Duplicated(A, zeros(size(A))) - db = Duplicated(b, zeros(length(b))) - dx = Duplicated(zeros(length(b)), zeros(length(b))) - for i in 1:length(b) - copyto!(dA.val, A) - copyto!(db.val, b) - fill!(dA.dval, 0.0) - fill!(db.dval, 0.0) - fill!(dx.dval, 0.0) - db.dval[i] = 1.0 - Enzyme.autodiff( - Forward, - driver, - dx, - dA, - db - ) - adJ[i, :] = dx.dval - end - return adJ - end - - function const_fwdJdxdb(driver, A, b) - adJ = zeros(length(b), length(b)) - db = Duplicated(b, zeros(length(b))) - dx = Duplicated(zeros(length(b)), zeros(length(b))) - for i in 1:length(b) - copyto!(db.val, b) - fill!(db.dval, 0.0) - fill!(dx.dval, 0.0) - db.dval[i] = 1.0 - Enzyme.autodiff( - Forward, - driver, - dx, - Const(A), - db - ) - adJ[i, :] = dx.dval - end - return adJ - end - - function batchedfwdJdxdb(driver, A, b) - n = length(b) - function seed(i) - x = zeros(n) - x[i] = 1.0 - return x - end - adJ = zeros(size(A)) - dA = BatchDuplicated(A, ntuple(i -> zeros(size(A)), n)) - db = BatchDuplicated(b, ntuple(i -> seed(i), n)) - dx = BatchDuplicated(zeros(length(b)), ntuple(i -> zeros(length(b)), n)) - Enzyme.autodiff( - Forward, - driver, - dx, - dA, - db - ) - for i in 1:n - adJ[i, :] = dx.dval[i] - end - return adJ - end - - # Test reverse - function revJdxdb(driver, A, b) - adJ = zeros(size(A)) - dA = Duplicated(A, zeros(size(A))) - db = Duplicated(b, zeros(length(b))) - dx = Duplicated(zeros(length(b)), zeros(length(b))) - for i in 1:length(b) - copyto!(dA.val, A) - copyto!(db.val, b) - fill!(dA.dval, 0.0) - fill!(db.dval, 0.0) - fill!(dx.dval, 0.0) - dx.dval[i] = 1.0 - Enzyme.autodiff( - Reverse, - driver, - dx, - dA, - db - ) - adJ[i, :] = db.dval - end - return adJ - end - - function const_revJdxdb(driver, A, b) - adJ = zeros(length(b), length(b)) - db = Duplicated(b, zeros(length(b))) - dx = Duplicated(zeros(length(b)), zeros(length(b))) - for i in 1:length(b) - copyto!(db.val, b) - fill!(db.dval, 0.0) - fill!(dx.dval, 0.0) - dx.dval[i] = 1.0 - Enzyme.autodiff( - Reverse, - driver, - dx, - Const(A), - db - ) - adJ[i, :] = db.dval + @testset "cholesky" begin + activities = (Const, Duplicated, BatchDuplicated) + function _square(A) + S = A * adjoint(A) + S[diagind(S)] .= real.(S[diagind(S)]) # workaround for issue #1456: + return S end - return adJ - end - - function batchedrevJdxdb(driver, A, b) - n = length(b) - function seed(i) - x = zeros(n) - x[i] = 1.0 - return x - end - adJ = zeros(size(A)) - dA = BatchDuplicated(A, ntuple(i -> zeros(size(A)), n)) - db = BatchDuplicated(b, ntuple(i -> zeros(length(b)), n)) - dx = BatchDuplicated(zeros(length(b)), ntuple(i -> seed(i), n)) - Enzyme.autodiff( - Reverse, - driver, - dx, - dA, - db - ) - for i in 1:n - adJ[i, :] .= db.dval[i] - end - return adJ - end - - function Jdxdb(driver, A, b) - x = A\b - dA = zeros(size(A)) - db = zeros(length(b)) - J = zeros(length(b), length(b)) - for i in 1:length(b) - db[i] = 1.0 - dx = A\db - db[i] = 0.0 - J[i, :] = dx - end - return J - end - - function JdxdA(driver, A, b) - db = zeros(length(b)) - J = zeros(length(b), length(b)) - for i in 1:length(b) - db[i] = 1.0 - dx = A\db - db[i] = 0.0 - J[i, :] = dx + @testset for (Te, TSs) in ( + Float64 => (Symmetric, Hermitian), + ComplexF64 => (Hermitian,), + ), TA in activities, Tret in activities + @testset "without wrapper arguments" begin + A = rand(Te, 5, 5) + are_activities_compatible(Tret, TA) || continue + test_forward(cholesky ∘ _square, Tret, (A, TA)) + test_reverse(cholesky ∘ _square, Tret, (A, TA)) + end + @testset "with wrapper arguments" for TS in TSs, uplo in (:U, :L) + _A = collect(exp(TS(I + rand(Te, 5, 5)))) + A = TS(_A, uplo) + are_activities_compatible(Tret, TA) || continue + test_forward(cholesky, Tret, (A, TA); fdm=FiniteDifferences.forward_fdm(5, 1)) + test_reverse(cholesky, Tret, (A, TA)) + end end - return J end - - @testset "Testing $op" for (op, driver, driver_NC) in ( - (:\, divdriver, divdriver_NC), - (:\, divdriver_herm, divdriver_NC), - (:\, divdriver_sym, divdriver_NC), - (:ldiv!, ldivdriver, ldivdriver_NC), - (:ldiv!, ldivdriver_herm, ldivdriver_NC), - (:ldiv!, ldivdriver_sym, ldivdriver_NC) - ) - A, b = symmetric_definite(10) - n = length(b) - A = Matrix(A) - x = zeros(n) - x = driver(x, A, b) - fdm = forward_fdm(2, 1); - - function b_one(b) - _x = zeros(length(b)) - driver(_x,A,b) - return _x - end - - fdJ = op==:\ ? FiniteDifferences.jacobian(fdm, b_one, copy(b))[1] : nothing - fwdJ = fwdJdxdb(driver, A, b) - revJ = revJdxdb(driver, A, b) - batchedrevJ = batchedrevJdxdb(driver, A, b) - batchedfwdJ = batchedfwdJdxdb(driver, A, b) - J = Jdxdb(driver, A, b) - - if op == :\ - @test isapprox(fwdJ, fdJ) - end - @test isapprox(fwdJ, revJ) - @test isapprox(fwdJ, batchedrevJ) - @test isapprox(fwdJ, batchedfwdJ) - - fwdJ = const_fwdJdxdb(driver_NC, cholesky(A), b) - revJ = const_revJdxdb(driver_NC, cholesky(A), b) - if op == :\ - @test isapprox(fwdJ, fdJ) - end - @test isapprox(fwdJ, revJ) - - function h(A, b) - A = copy(A) - LinearAlgebra.LAPACK.potrf!('U', A) - b2 = copy(b) - LinearAlgebra.LAPACK.potrs!('U', A, b2) - @inbounds b2[1] + @testset "Linear solve for `Cholesky`" begin + activities = (Const, Duplicated, DuplicatedNoNeed, BatchDuplicated, + BatchDuplicatedNoNeed) + @testset for Te in (Float64, ComplexF64), uplo in ('L', 'U') + C = Cholesky(I + rand(Te, 5, 5), uplo, 0) # add `I` for numerical stability + B = rand(Te, 5, 5) + b = rand(Te, 5) + @testset for TC in activities, + TB in activities, + Tret in (Const, Duplicated, BatchDuplicated) + + @testset "$(size(_B))" for _B in (B, b) + are_activities_compatible(Tret, TC, TB) || continue + # Non-uniform activities are disabled due to unresolved questions + # see https://github.com/EnzymeAD/Enzyme.jl/issues/1411 + Tret == TC == TB || continue + test_forward(\, Tret, (C, TC), (_B, TB)) + test_reverse(\, Tret, (C, TC), (_B, TB)) + end + end + @testset for TC in activities, + TB in activities, + Tret in (Const, Duplicated, BatchDuplicated) + + @testset "$(size(_B))" for _B in (B, b) + are_activities_compatible(Tret, TC, TB) || continue + # Non-uniform activities are disabled due to unresolved questions + # see https://github.com/EnzymeAD/Enzyme.jl/issues/1411 + Tret == TC == TB || continue + test_forward(ldiv!, Tret, (C, TC), (_B, TB)) + test_reverse(ldiv!, Tret, (C, TC), (_B, TB)) + end + end end - - A = [1.3 0.5; 0.5 1.5] - b = [1., 2.] - dA = zero(A) - Enzyme.autodiff(Reverse, h, Active, Duplicated(A, dA), Const(b)) - # dA_fwd = Enzyme.gradient(Forward, A->h(A, b), A) - dA_fd = FiniteDifferences.grad(central_fdm(5, 1), A->h(A, b), A)[1] - - @test isapprox(dA, dA_fd) end -end - -function chol_upper(x) - x = reshape(x, 4, 4) - x = parent(cholesky(Hermitian(x)).U) - x = convert(typeof(x), UpperTriangular(x)) - return x[1,2] -end - -@testset "Cholesky upper triangular v1" begin - x = [1.0, -0.10541615131279458, 0.6219810761363638, 0.293343219811946, -0.10541615131279458, 1.0, -0.05258941747718969, 0.34629296878264443, 0.6219810761363638, -0.05258941747718969, 1.0, 0.4692436399208845, 0.293343219811946, 0.34629296878264443, 0.4692436399208845, 1.0] - @test collect(Enzyme.gradient(Forward, chol_upper, x)) ≈ [0.05270807565639728, 0.0, 0.0, 0.0, 0.9999999999999999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] - - @test Enzyme.gradient(Reverse, chol_upper, x) ≈ [0.05270807565639728, 0.0, 0.0, 0.0, 0.9999999999999999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] -end - @testset "Linear solve for triangular matrices" begin @testset for T in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular), TE in (Float64, ComplexF64), sizeB in ((3,), (3, 3))