Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix rules for cholesky and ldiv! on Cholesky #1307

Closed
wants to merge 77 commits into from
Closed
Show file tree
Hide file tree
Changes from 64 commits
Commits
Show all changes
77 commits
Select commit Hold shift + click to select a range
fa2c92e
Add regression test
simsurace Feb 24, 2024
fb47774
Improve test
simsurace Feb 25, 2024
cc6fdb6
Fix regression test
simsurace Feb 25, 2024
bf6461e
Fix reverse rules and tests
simsurace Feb 26, 2024
885a38f
Fix typo
simsurace Feb 26, 2024
f53d1de
Fix all tests
simsurace Feb 26, 2024
af3c750
Simplify helper functions
simsurace Feb 26, 2024
50049ee
Simplify rule
simsurace Feb 26, 2024
cde0c43
Apply suggestions from review
simsurace Feb 27, 2024
62d2278
Fix typo
simsurace Feb 27, 2024
3d805f3
Change testset name and add explanations
simsurace Feb 27, 2024
c012073
Add unit tests for forward rule
simsurace Feb 27, 2024
7ec8316
Fix tests
simsurace Feb 27, 2024
3265429
Fix forward rule
simsurace Feb 27, 2024
861bde4
Merge branch 'main' into fix-cholesky
simsurace Feb 28, 2024
434707a
Merge branch 'main' into fix-cholesky
wsmoses Feb 29, 2024
b6a5fe8
Fix cholesky reverse rule
simsurace Mar 2, 2024
91ee976
Remove wrong branch
simsurace Mar 5, 2024
c531aae
Merge branch 'main' into fix-cholesky
simsurace Mar 5, 2024
61cf2b3
Simplify Ubar definition
simsurace Mar 6, 2024
eaee9c5
Simplify Lbar definition
simsurace Mar 6, 2024
7780fde
Remove redundant line
simsurace Mar 6, 2024
2bd2458
Remove _maybeUpperTri
simsurace Mar 6, 2024
bc78219
Update internal_rules.jl
simsurace Mar 6, 2024
9d6d45d
Improve notation
simsurace Mar 6, 2024
ea2ea2c
Improve notation
simsurace Mar 6, 2024
8bfb459
Fix wrong argument name
simsurace Mar 6, 2024
33e48af
Add copyright notice
simsurace Mar 6, 2024
c259980
Implement changes suggested in review
simsurace Mar 7, 2024
d7c04d0
Merge branch 'main' into fix-cholesky
simsurace Mar 16, 2024
cf5e573
Rewrite tests from scratch - real case
simsurace Mar 17, 2024
3aba370
Complete tests
simsurace Mar 17, 2024
ce8062f
Merge branch 'main' into fix-cholesky
simsurace Mar 19, 2024
1889491
Run formatter
simsurace Mar 20, 2024
ba171eb
Merge branch 'main' into fix-cholesky
simsurace Apr 11, 2024
2d0b22e
Merge branch 'main' into fix-cholesky
simsurace Apr 20, 2024
2977be4
Fix typo
simsurace Apr 21, 2024
886f4d8
Merge branch 'main' into fix-cholesky
simsurace Apr 21, 2024
ae2a033
Add tests suggested in review
simsurace Apr 27, 2024
3c1784e
Fix rrule
simsurace Apr 27, 2024
b4c17d2
Fix forward rules
simsurace Apr 27, 2024
7e51733
Increase test coverage, remove old tests
simsurace Apr 27, 2024
ea098be
Fix additional tests
simsurace Apr 27, 2024
4066fd1
Try to fix positive definiteness issues in CI
simsurace Apr 28, 2024
22c533a
Revert "Run formatter"
simsurace Apr 28, 2024
f19f4b1
Merge commit '1e27530c10989926c45377e1efd47f047415603e' into fix-chol…
simsurace Apr 28, 2024
c348c61
Change forward rule
simsurace Apr 30, 2024
74e2183
Fix `Duplicated` case
simsurace Apr 30, 2024
e5f85e0
Slightly refactor forward rule
simsurace May 1, 2024
1e5481d
Disable questionable tests, fix reverse rule for `ldiv!`
simsurace May 2, 2024
3f8ee74
Run formatter
simsurace May 2, 2024
4f216cc
Revert "Run formatter"
simsurace May 3, 2024
37fde99
Format changes
simsurace May 3, 2024
0b15bec
Format other lines
simsurace May 3, 2024
6637b0a
Format remaining lines
simsurace May 3, 2024
5983d9e
Generalize `ldiv!` rule to `uplo = :L`
simsurace May 3, 2024
f7d04c2
Link to issue about open question
simsurace May 3, 2024
b0c9d6a
Generalize tests
simsurace May 3, 2024
308b626
Refactor rule
simsurace May 4, 2024
20877a0
Add tests for `BatchDuplicated`
simsurace May 4, 2024
3171823
Include all activities
simsurace May 4, 2024
781feab
Merge branch 'main' into fix-cholesky
simsurace May 4, 2024
82c2491
Merge branch 'main' into fix-cholesky
wsmoses May 10, 2024
5ea55f2
Merge branch 'main' into fix-cholesky
sethaxen May 15, 2024
8be3391
Include more activities
simsurace May 18, 2024
b31ed35
Fix typo
simsurace May 18, 2024
5bf0057
Write testsets in a different way
simsurace May 18, 2024
f885996
Merge branch 'main' into fix-cholesky
simsurace May 18, 2024
250e24f
Test complex element type
simsurace May 20, 2024
62dbbc2
Add comment regarding `I` in test
simsurace May 20, 2024
75c0766
Work around issue #1456
simsurace May 20, 2024
4d02587
Increase coverage of rules
simsurace May 20, 2024
f5c7abf
Simplify test
simsurace May 20, 2024
510c7dd
Fix test
simsurace May 20, 2024
71fc359
Merge branch 'main' into fix-cholesky
wsmoses Jun 30, 2024
3c57c1a
Merge commit 'c799584d85afeff75eb304ea57583d5fd97de98b' into fix-chol…
simsurace Jul 8, 2024
b2e922a
Merge branch 'main' into fix-cholesky
simsurace Jul 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
293 changes: 157 additions & 136 deletions src/internal_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -748,26 +748,18 @@ function EnzymeRules.forward(::Const{typeof(cholesky)}, RT::Type, A; kwargs...)
else
N = width(RT)

invL = inv(fact.L)

dA = if isa(A, Const)
ntuple(Val(N)) do i
Base.@_inline_meta
zeros(A.val)
return zero(A.val)
end
else
if N == 1
(A.dval,)
else
A.dval
end
N == 1 ? (A.dval,) : A.dval
end

dfact = ntuple(Val(N)) do i
Base.@_inline_meta
Cholesky(
Matrix(fact.L * LowerTriangular(invL * dA[i] * invL' * 0.5 * I)), 'L', 0
)
return _cholesky_forward(fact, dA[i])
end

if (RT <: DuplicatedNoNeed) || (RT <: BatchDuplicatedNoNeed)
Expand All @@ -780,54 +772,69 @@ function EnzymeRules.forward(::Const{typeof(cholesky)}, RT::Type, A; kwargs...)
end
end

# y = inv(A) B
# dY = inv(A) [ dB - dA y ]
# ->
# B(out) = inv(A) B(in)
# dB(out) = inv(A) [ dB(in) - dA B(out) ]
function EnzymeRules.forward(
func::Const{typeof(ldiv!)},
RT::Type,
fact::Annotation{<:Cholesky},
B;
kwargs...
)
if isa(B, Const)
@assert (RT <: Const)
function _cholesky_forward(C::Cholesky, Ȧ)
    # Forward-mode (tangent) rule for the Cholesky factorization.
    #
    # `C` is the primal factorization and `Ȧ` the tangent of the factorized
    # matrix. Returns a new `Cholesky` whose `factors` hold the tangent of the
    # triangular factor; `uplo` and `info` are taken from `C`.
    # Cf. eq. 8 in https://arxiv.org/pdf/1602.07527.pdf
    if C.uplo !== 'U'
        lower = C.L
        # tangent = L⁻¹ Ȧ L⁻ᴴ
        tangent = lower \ Ȧ
        rdiv!(tangent, lower')
        # Keep the lower triangle with a halved diagonal, then map back: L * Φ(…).
        @views tangent[diagind(tangent)] ./= 2
        tril!(tangent)
        lmul!(lower, tangent)
        # Fill the triangle the factorization does not reference: adjoint of
        # the lower triangle of Ȧ, with the diagonal of Ȧ subtracted.
        tangent .+= LowerTriangular(Ȧ)' .- Diagonal(Ȧ)
        return Cholesky(tangent, 'L', C.info)
    end

    upper = C.U
    # tangent = U⁻ᴴ Ȧ U⁻¹
    tangent = Ȧ / upper
    ldiv!(upper', tangent)
    # Keep the upper triangle with a halved diagonal, then map back: Φ(…) * U.
    @views tangent[diagind(tangent)] ./= 2
    triu!(tangent)
    rmul!(tangent, upper)
    # Fill the triangle the factorization does not reference: adjoint of
    # the upper triangle of Ȧ, with the diagonal of Ȧ subtracted.
    tangent .+= UpperTriangular(Ȧ)' .- Diagonal(Ȧ)
    return Cholesky(tangent, 'U', C.info)
end

function EnzymeRules.forward(func::Const{typeof(ldiv!)},
RT::Type{<:Union{Const,Duplicated,BatchDuplicated}},
fact::Annotation{<:Cholesky},
B::Annotation{<:AbstractVecOrMat};
kwargs...)
if B isa Const
return func.val(fact.val, B.val; kwargs...)
else
N = width(B)
retval = B.val

@assert !isa(B, Const)
L = fact.val.L
U = fact.val.U

retval = if !isa(fact, Const) || (RT <: Const) || (RT <: Duplicated) || (RT <: BatchDuplicated)
func.val(fact.val, B.val; kwargs...)
else
nothing
ldiv!(L, B.val)
for b in 1:N
dB = N == 1 ? B.dval : B.dval[b]
if !(fact isa Const)
dL = N == 1 ? fact.dval.L : fact.dval[b].L
mul!(dB, dL, B.val, -1, 1)
end
ldiv!(L, dB)
end
ldiv!(U, B.val)
for b in 1:N
dB = N == 1 ? B.dval : B.dval[b]
if !(fact isa Const)
dU = N == 1 ? fact.dval.U : fact.dval[b].U
mul!(dB, dU, B.val, -1, 1)
end
ldiv!(U, dB)
end

dretvals = ntuple(Val(N)) do b
Base.@_inline_meta

dB = if N == 1
B.dval
else
B.dval[b]
end

if !isa(fact, Const)

dfact = if N == 1
fact.dval
else
fact.dval[b]
end

tmp = dfact.U * retval
mul!(dB, dfact.L, tmp, -1, 1)
end

func.val(fact.val, dB; kwargs...)
dB = N == 1 ? B.dval : B.dval[b]
return dB
end

if RT <: Const
Expand All @@ -844,18 +851,21 @@ function EnzymeRules.forward(
end
end

function EnzymeRules.augmented_primal(
config,
func::Const{typeof(cholesky)},
RT::Type,
A::Annotation{<:Union{Matrix,LinearAlgebra.RealHermSym{<:Real,<:Matrix}}};
kwargs...)
fact = if EnzymeRules.needs_primal(config)
function EnzymeRules.augmented_primal(config,
func::Const{typeof(cholesky)},
RT::Type,
A::Annotation{<:Union{Matrix,
LinearAlgebra.RealHermSym{<:Real,
<:Matrix}}};
kwargs...)
fact = if EnzymeRules.needs_primal(config) || !(RT <: Const)
cholesky(A.val; kwargs...)
else
nothing
end

fact_returned = EnzymeRules.needs_primal(config) ? fact : nothing

# dfact would be a dense matrix, prepare buffer
dfact = if RT <: Const
nothing
Expand All @@ -865,124 +875,135 @@ function EnzymeRules.augmented_primal(
else
ntuple(Val(EnzymeRules.width(config))) do i
Base.@_inline_meta
Enzyme.make_zero(fact)
return Enzyme.make_zero(fact)
end
end
end
cache = if isa(A, Const)
nothing
else
dfact
end

return EnzymeRules.AugmentedReturn(fact, dfact, cache)
cache = isa(A, Const) ? nothing : (fact, dfact)
return EnzymeRules.AugmentedReturn(fact_returned, dfact, cache)
end

function EnzymeRules.reverse(
config,
::Const{typeof(cholesky)},
RT::Type,
dfact,
A::Annotation{<:Union{Matrix,LinearAlgebra.RealHermSym{<:Real,<:Matrix}}};
kwargs...)

function EnzymeRules.reverse(config,
::Const{typeof(cholesky)},
RT::Type,
cache,
A::Annotation{<:Union{Matrix,
LinearAlgebra.RealHermSym{<:Real,
<:Matrix}}};
kwargs...)
if !(RT <: Const) && !isa(A, Const)
fact, dfact = cache
dAs = EnzymeRules.width(config) == 1 ? (A.dval,) : A.dval
dfacts = EnzymeRules.width(config) == 1 ? (dfact,) : dfact

for (dA, dfact) in zip(dAs, dfacts)
_dA = dA isa LinearAlgebra.RealHermSym ? dA.data : dA
if _dA !== dfact.factors
_dA .+= dfact.factors
Ā = _cholesky_pullback_shared_code(fact, dfact)
_dA .+= Ā
dfact.factors .= 0
end
end
end
return (nothing,)
end


# y=inv(A) B
# dA −= z y^T
# dB += z, where z = inv(A^T) dy
# ->
#
# B(out)=inv(A) B(in)
# dA −= z B(out)^T
# dB = z, where z = inv(A^T) dB
function EnzymeRules.augmented_primal(
config,
func::Const{typeof(ldiv!)},
RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated}},

A::Annotation{<:Cholesky},
B::Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated};
kwargs...
)
func.val(A.val, B.val; kwargs...)

cache_Bout = if !isa(A, Const) && !isa(B, Const)
if EnzymeRules.overwritten(config)[3]
copy(B.val)
else
B.val
end
# Adapted from ChainRules.jl
# MIT "Expat" License
# Copyright (c) 2018: Jarrett Revels.
# https://github.com/JuliaDiff/ChainRules.jl/blob/9f1817a22404259113e230bef149a54d379a660b/src/rulesets/LinearAlgebra/factorization.jl#L507-L528
function _cholesky_pullback_shared_code(C, ΔC)
    # Reverse-mode (pullback) rule for the Cholesky factorization.
    #
    # `C` is the primal factorization and `ΔC` a `Cholesky`-shaped cotangent
    # whose `factors` carry the sensitivities. Returns a freshly allocated
    # matrix `Ā` that is non-zero only in the triangle selected by `C.uplo`,
    # with a real, halved diagonal.
    #
    # Neither `C` nor `ΔC` is modified: the unused-triangle correction uses the
    # non-mutating `tril`/`triu`, so the caller's cotangent buffer stays intact.
    Ā = similar(C.factors)
    if C.uplo === 'U'
        U = C.U
        Ū = ΔC.U
        # Ā = U⁻¹ * herm(Ū Uᴴ) * U⁻ᴴ, where herm mirrors the upper triangle.
        mul!(Ā, Ū, U')
        LinearAlgebra.copytri!(Ā, 'U', true)
        eltype(Ā) <: Real || _realifydiag!(Ā)
        ldiv!(U, Ā)
        rdiv!(Ā, U')
        # Correction for the unused triangle: fold the strictly lower part of
        # the cotangent into the upper triangle.
        Ā .+= tril(ΔC.factors, -1)'
        triu!(Ā)
    else # C.uplo === 'L'
        L = C.L
        L̄ = ΔC.L
        # Ā = L⁻ᴴ * herm(Lᴴ L̄) * L⁻¹, where herm mirrors the lower triangle.
        mul!(Ā, L', L̄)
        LinearAlgebra.copytri!(Ā, 'L', true)
        eltype(Ā) <: Real || _realifydiag!(Ā)
        rdiv!(Ā, L)
        ldiv!(L', Ā)
        # Correction for the unused triangle: fold the strictly upper part of
        # the cotangent into the lower triangle.
        Ā .+= triu(ΔC.factors, 1)'
        tril!(Ā)
    end
    # Diagonal entries are stored once, so they carry half the weight.
    idx = diagind(Ā)
    @views Ā[idx] .= real.(Ā[idx]) ./ 2
    return Ā
end

function _realifydiag!(A)
    # Replace every diagonal entry of `A` by its real part, in place.
    # Returns `A` itself so the call can be chained.
    diagonal = view(A, diagind(A))
    diagonal .= real.(diagonal)
    return A
end

function EnzymeRules.augmented_primal(config,
func::Const{typeof(ldiv!)},
RT::Type{<:Union{Const,DuplicatedNoNeed,Duplicated,
BatchDuplicatedNoNeed,
BatchDuplicated}},
A::Annotation{<:Cholesky},
B::Union{Const,DuplicatedNoNeed,Duplicated,
BatchDuplicatedNoNeed,BatchDuplicated};
kwargs...)
cache_B = if !isa(A, Const) && !isa(B, Const)
copy(B.val)
else
nothing
end

cache_A = if !isa(B, Const)
if EnzymeRules.overwritten(config)[2]
copy(A.val)
else
A.val
end
else
nothing
end

primal = if EnzymeRules.needs_primal(config)
B.val
else
nothing
end

shadow = if EnzymeRules.needs_shadow(config)
B.dval
EnzymeRules.overwritten(config)[2] ? copy(A.val) : A.val
else
nothing
end

return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_Bout))
primal = EnzymeRules.needs_primal(config) ? B.val : nothing
shadow = EnzymeRules.needs_shadow(config) ? B.dval : nothing
func.val(A.val, B.val; kwargs...)
return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_B))
end

function EnzymeRules.reverse(
config,
func::Const{typeof(ldiv!)},
dret,
cache,
A::Annotation{<:Cholesky},
B::Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated};
kwargs...
)
function EnzymeRules.reverse(config,
simsurace marked this conversation as resolved.
Show resolved Hide resolved
func::Const{typeof(ldiv!)},
dret,
cache,
A::Annotation{<:Cholesky},
B::Union{Const,DuplicatedNoNeed,Duplicated,
BatchDuplicatedNoNeed,BatchDuplicated};
kwargs...)
if !isa(B, Const)

(cache_A, cache_Bout) = cache

(cache_A, cache_B) = cache
U = cache_A.U
Z = isa(A, Const) ? nothing : U' \ cache_B
Y = isa(A, Const) ? nothing : U \ Z
for b in 1:EnzymeRules.width(config)

dB = EnzymeRules.width(config) == 1 ? B.dval : B.dval[b]

# dB = z, where z = inv(A^T) dB
# dA −= z B(out)^T

dZ = U' \ dB
func.val(cache_A, dB; kwargs...)
if !isa(A, Const)
∂B = U \ dZ
Ā = -dZ * Y' - Z * ∂B'
dA = EnzymeRules.width(config) == 1 ? A.dval : A.dval[b]
mul!(dA.factors, dB, transpose(cache_Bout), -1, 1)
if A.val.uplo === 'U'
dA.factors .+= UpperTriangular(Ā)
else
dA.factors .+= LowerTriangular(Ā')
end

end
end
end

return (nothing, nothing)
end
Loading
Loading