Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Widen cholesky rule to Hermitian and Symmetric matrices #1273

Merged
merged 3 commits into from
Feb 10, 2024

Conversation

devmotion
Copy link
Contributor

Seems to fix #1272:

julia> using Enzyme, LinearAlgebra

julia> g(C, X) = sum(C \ X)
g (generic function with 1 method)

julia> g2(A, X) = g(cholesky(A * A' + I), X)
g2 (generic function with 1 method)

julia> A = rand(2, 2);

julia> X = rand(2, 2);

julia> ∂g2_∂A = zero(A);

julia> ∂g2_∂X = zero(X);

julia> autodiff(Reverse, g2, Active, Duplicated(A, ∂g2_∂A), Duplicated(X, ∂g2_∂X))
┌ Warning: Using fallback BLAS replacements, performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/U36Ed/src/utils.jl:59
warning: didn't implement memmove, using memcpy as fallback which can result in errors
((nothing, nothing),)

julia> g3(A, X) = g(cholesky(Symmetric(A * A' + I)), X)
g3 (generic function with 1 method)

julia> ∂g3_∂A = zero(A);

julia> ∂g3_∂X = zero(X);

julia> autodiff(Reverse, g3, Active, Duplicated(A, ∂g3_∂A), Duplicated(X, ∂g3_∂X))
┌ Warning: Using fallback BLAS replacements, performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/U36Ed/src/utils.jl:59
warning: didn't implement memmove, using memcpy as fallback which can result in errors
((nothing, nothing),)

julia> ∂g3_∂A == ∂g2_∂A
true

julia> ∂g3_∂X == ∂g2_∂X
true

However, the big caveat is that I don't know if this is the correct approach.

A::Annotation{AT};
kwargs...) where {AT <: Array}
A::Annotation{<:Union{Matrix,LinearAlgebra.RealHermSym{<:Real,<:Matrix}}};
kwargs...)

if !(RT <: Const) && !isa(A, Const)
dAs = EnzymeRules.width(config) == 1 ? (A.dval,) : A.dval
dfacts = EnzymeRules.width(config) == 1 ? (dfact,) : dfact

for (dA, dfact) in zip(dAs, dfacts)
if dA !== dfact.factors
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering, should this check be changed as well?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah the check should make sure we're not adding to itself

Copy link
Member

@wsmoses wsmoses left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix that check otherwise lgtm

@codecov-commenter
Copy link

codecov-commenter commented Feb 7, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison is base (1390045) 75.14% compared to head (d9042ad) 75.21%.

❗ Your organization needs to install the Codecov GitHub app to enable full functionality.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1273      +/-   ##
==========================================
+ Coverage   75.14%   75.21%   +0.06%     
==========================================
  Files          35       35              
  Lines       10393    10394       +1     
==========================================
+ Hits         7810     7818       +8     
+ Misses       2583     2576       -7     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@devmotion
Copy link
Contributor Author

I updated the check 🙂

@wsmoses
Copy link
Member

wsmoses commented Feb 8, 2024

The 1.10 failure is new/real, but not caused by this PR -- but rather only caused [and first noticed] by a forward mode test added here.

The fix requires a jll bump which is currently blocked by: JuliaPackaging/Yggdrasil#8017. I've said on that thread that for now we should just disable Windows for LLVM16+ as hopefully one of the libLLVM_jll maintainers can fix windows on 16.

@wsmoses wsmoses merged commit aefaaa2 into EnzymeAD:main Feb 10, 2024
39 of 43 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Differentiating cholesky(::Symmetric) throws an error
3 participants