From fa2c92e28a612a147afc71617cc2a2c65d4abee0 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Sun, 25 Feb 2024 00:58:06 +0100 Subject: [PATCH 01/61] Add regression test --- test/internal_rules.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/internal_rules.jl b/test/internal_rules.jl index fe65ce6279..000fe643f9 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -386,6 +386,14 @@ end dA_sym = - (transpose(A) \ [1.0, 0.0]) * transpose(A \ b) @test isapprox(dA, dA_sym) end + @testset "Regression test for #" for TE in (Float64, ComplexF64) + function f(A) + C = cholesky(A * A') + return sum(abs2, C.L * C.U) + end + A = rand(TE, 3, 3) + test_reverse(f, Active, (A, Duplicated)) + end end @testset "Linear solve for triangular matrices" begin From fb47774ddfacfa0bcd32890ca74de570eab31c4e Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Sun, 25 Feb 2024 01:45:02 +0100 Subject: [PATCH 02/61] Improve test --- test/internal_rules.jl | 34 ++++++++++++++++++++++++++++++---- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/test/internal_rules.jl b/test/internal_rules.jl index 000fe643f9..3d59312d95 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -386,13 +386,39 @@ end dA_sym = - (transpose(A) \ [1.0, 0.0]) * transpose(A \ b) @test isapprox(dA, dA_sym) end - @testset "Regression test for #" for TE in (Float64, ComplexF64) + @testset "Regression test for #" begin function f(A) - C = cholesky(A * A') + C = cholesky(A * adjoint(A)) return sum(abs2, C.L * C.U) end - A = rand(TE, 3, 3) - test_reverse(f, Active, (A, Duplicated)) + @testset for TE in (Float64, ComplexF64) + A = rand(TE, 3, 3) + test_reverse(f, Active, (A, Duplicated)) + @testset "Compare against function bypassing `cholesky`" begin + g(A) = sum(abs2, A * adjoint(A)) + @testset "Without wrapper" begin + dA1 = zero(A) + dA2 = zero(A) + autodiff(Reverse, f, Active, Duplicated(A, dA1)) + autodiff(Reverse, g, Active, Duplicated(A, dA2)) + @test dA1 ≈ dA2 + end + if TE == Float64 + @testset "$wrapper wrapper" for wrapper in (Symmetric, Hermitian,) + function fw(A) + C = cholesky(wrapper(A * adjoint(A))) + return sum(abs2, C.L * C.U) + end + gw(A) = sum(abs2, wrapper(A * adjoint(A))) + dA1 = zero(A) + dA2 = zero(A) + autodiff(Reverse, fw, Active, Duplicated(A, dA1)) + autodiff(Reverse, gw, Active, Duplicated(A, dA2)) + @test dA1 ≈ dA2 + end + end + end + end end end From cc6fdb639691d7e6611624e005570a4f12433128 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Sun, 25 Feb 2024 01:48:06 +0100 Subject: [PATCH 03/61] Fix regression test --- src/internal_rules.jl | 49 ++++++++++++++++++++++++++++++++++++++---- test/internal_rules.jl | 2 +- 2 files changed, 46 insertions(+), 5 deletions(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index bea48baed4..4e5c0ac6b2 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -764,7 +764,7 @@ function EnzymeRules.augmented_primal( cache = if isa(A, Const) nothing else - dfact + (fact, dfact) end return EnzymeRules.AugmentedReturn(fact, dfact, cache) @@ -774,10 +774,10 @@ function EnzymeRules.reverse( config, ::Const{typeof(cholesky)}, RT::Type, - dfact, + cache, A::Annotation{<:Union{Matrix,LinearAlgebra.RealHermSym{<:Real,<:Matrix}}}; kwargs...) - + fact, dfact = cache if !(RT <: Const) && !isa(A, Const) dAs = EnzymeRules.width(config) == 1 ? (A.dval,) : A.dval dfacts = EnzymeRules.width(config) == 1 ? (dfact,) : dfact @@ -785,7 +785,13 @@ function EnzymeRules.reverse( for (dA, dfact) in zip(dAs, dfacts) _dA = dA isa LinearAlgebra.RealHermSym ? dA.data : dA if _dA !== dfact.factors - _dA .+= dfact.factors + Ā = _cholesky_pullback_shared_code(fact, dfact) + if dA isa LinearAlgebra.RealHermSym + rmul!(Ā, one(eltype(Ā)) / 2) + else + Ā ./= 2 + end + _dA .+= Ā dfact.factors .= 0 end end @@ -793,6 +799,41 @@ function EnzymeRules.reverse( return (nothing,) end +# Taken from ChainRules.jl +function _cholesky_pullback_shared_code(C, ΔC) + Δfactors = ΔC.factors + Ā = similar(C.factors) + if C.uplo === 'U' + U = C.U + Ū = eltype(U) <: Real ? real(_maybeUpperTri(Δfactors)) : _maybeUpperTri(Δfactors) + mul!(Ā, Ū, U') + LinearAlgebra.copytri!(Ā, 'U', true) + eltype(Ā) <: Real || _realifydiag!(Ā) + ldiv!(U, Ā) + rdiv!(Ā, U') + else # C.uplo === 'L' + L = C.L + L̄ = eltype(L) <: Real ? real(_maybeLowerTri(Δfactors)) : _maybeLowerTri(Δfactors) + mul!(Ā, L', L̄) + LinearAlgebra.copytri!(Ā, 'L', true) + eltype(Ā) <: Real || _realifydiag!(Ā) + rdiv!(Ā, L) + ldiv!(L', Ā) + end + return Ā +end + +_maybeUpperTri(A) = UpperTriangular(A) +_maybeUpperTri(A::Diagonal) = A +_maybeLowerTri(A) = LowerTriangular(A) +_maybeLowerTri(A::Diagonal) = A + +function _realifydiag!(A) + for i in diagind(A) + @inbounds A[i] = real(A[i]) + end + return A +end # y=inv(A) B # dA −= z y^T diff --git a/test/internal_rules.jl b/test/internal_rules.jl index 3d59312d95..1d83302941 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -386,7 +386,7 @@ end dA_sym = - (transpose(A) \ [1.0, 0.0]) * transpose(A \ b) @test isapprox(dA, dA_sym) end - @testset "Regression test for #" begin + @testset "Regression test for #1307" begin function f(A) C = cholesky(A * adjoint(A)) return sum(abs2, C.L * C.U) From bf6461ea975098dceffa99adbecfd31d60c02061 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Mon, 26 Feb 2024 12:05:28 +0100 Subject: [PATCH 04/61] Fix reverse rules and tests --- src/internal_rules.jl | 65 +++++++++++------------------------------- test/internal_rules.jl | 49 +++++-------------------------- 2 files changed, 24 insertions(+), 90 deletions(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 4e5c0ac6b2..1d69f3af7d 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -835,14 +835,6 @@ function _realifydiag!(A) return A end -# y=inv(A) B -# dA −= z y^T -# dB += z, where z = inv(A^T) dy -# -> -# -# B(out)=inv(A) B(in) -# dA −= z B(out)^T -# dB = z, where z = inv(A^T) dB function EnzymeRules.augmented_primal( config, func::Const{typeof(ldiv!)}, @@ -852,41 +844,22 @@ function EnzymeRules.augmented_primal( B::Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated}; kwargs... ) - func.val(A.val, B.val; kwargs...) - - cache_Bout = if !isa(A, Const) && !isa(B, Const) - if EnzymeRules.overwritten(config)[3] - copy(B.val) - else - B.val - end + cache_B = if !isa(A, Const) && !isa(B, Const) + EnzymeRules.overwritten(config)[3] ? copy(B.val) : B.val else nothing end cache_A = if !isa(B, Const) - if EnzymeRules.overwritten(config)[2] - copy(A.val) - else - A.val - end - else - nothing - end - - primal = if EnzymeRules.needs_primal(config) - B.val + EnzymeRules.overwritten(config)[2] ? copy(A.val) : A.val else nothing end - shadow = if EnzymeRules.needs_shadow(config) - B.dval - else - nothing - end - - return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_Bout)) + primal = EnzymeRules.needs_primal(config) ? B.val : nothing + shadow = EnzymeRules.needs_shadow(config) ? B.dval : nothing + func.val(A.val, B.val; kwargs...) + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_B)) end function EnzymeRules.reverse( @@ -898,24 +871,20 @@ function EnzymeRules.reverse( B::Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated}; kwargs... ) - if !isa(B, Const) - - (cache_A, cache_Bout) = cache - + if !isa(B, Const) && !isa(A, Const) + (cache_A, cache_B) = cache + U = cache_A.U + Y = B.val + Z = U' \ cache_B for b in 1:EnzymeRules.width(config) - dB = EnzymeRules.width(config) == 1 ? B.dval : B.dval[b] - - # dB = z, where z = inv(A^T) dB - # dA −= z B(out)^T - + dZ = U' \ dB + ∂B = U \ dZ func.val(cache_A, dB; kwargs...) - if !isa(A, Const) - dA = EnzymeRules.width(config) == 1 ? A.dval : A.dval[b] - mul!(dA.factors, dB, transpose(cache_Bout), -1, 1) - end + Ā = -dZ * Y' - Z * ∂B' + dA = EnzymeRules.width(config) == 1 ? A.dval : A.dval[b] + dA.factors .+= UpperTriangular(Ā) end end - return (nothing, nothing) end diff --git a/test/internal_rules.jl b/test/internal_rules.jl index 1d83302941..c286f3c87c 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -174,13 +174,7 @@ end fill!(db.dval, 0.0) fill!(dx.dval, 0.0) db.dval[i] = 1.0 - Enzyme.autodiff( - Forward, - driver, - dx, - dA, - db - ) + Enzyme.autodiff(Forward, driver, dx, dA, db) adJ[i, :] = dx.dval end return adJ @@ -195,13 +189,7 @@ end fill!(db.dval, 0.0) fill!(dx.dval, 0.0) db.dval[i] = 1.0 - Enzyme.autodiff( - Forward, - driver, - dx, - A, - db - ) + Enzyme.autodiff(Forward, driver, dx, dA, db) adJ[i, :] = dx.dval end return adJ @@ -218,13 +206,7 @@ end dA = BatchDuplicated(A, ntuple(i -> zeros(size(A)), n)) db = BatchDuplicated(b, ntuple(i -> seed(i), n)) dx = BatchDuplicated(zeros(length(b)), ntuple(i -> zeros(length(b)), n)) - Enzyme.autodiff( - Forward, - driver, - dx, - dA, - db - ) + Enzyme.autodiff(Forward, driver, dx, dA, db) for i in 1:n adJ[i, :] = dx.dval[i] end @@ -244,13 +226,7 @@ end fill!(db.dval, 0.0) fill!(dx.dval, 0.0) dx.dval[i] = 1.0 - Enzyme.autodiff( - Reverse, - driver, - dx, - dA, - db - ) + Enzyme.autodiff(Reverse, driver, dx, dA, db) adJ[i, :] = db.dval end return adJ @@ -265,13 +241,7 @@ end fill!(db.dval, 0.0) fill!(dx.dval, 0.0) dx.dval[i] = 1.0 - Enzyme.autodiff( - Reverse, - driver, - dx, - A, - db - ) + Enzyme.autodiff(Reverse, driver, dx, A, db) adJ[i, :] = db.dval end return adJ @@ -288,13 +258,7 @@ end dA = BatchDuplicated(A, ntuple(i -> zeros(size(A)), n)) db = BatchDuplicated(b, ntuple(i -> zeros(length(b)), n)) dx = BatchDuplicated(zeros(length(b)), ntuple(i -> seed(i), n)) - Enzyme.autodiff( - Reverse, - driver, - dx, - dA, - db - ) + Enzyme.autodiff(Reverse, driver, dx, dA, db) for i in 1:n adJ[i, :] .= db.dval[i] end @@ -384,6 +348,7 @@ end Enzyme.autodiff(Reverse, h, Active, Duplicated(A, dA), Const(b)) dA_sym = - (transpose(A) \ [1.0, 0.0]) * transpose(A \ b) + dA_sym = (dA_sym + dA_sym') / 2 @test isapprox(dA, dA_sym) end @testset "Regression test for #1307" begin From 885a38fbf898f837cc4dbb37a75ae6803bcdf1a5 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Mon, 26 Feb 2024 13:25:24 +0100 Subject: [PATCH 05/61] Fix typo --- test/internal_rules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/internal_rules.jl b/test/internal_rules.jl index c286f3c87c..92657455cb 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -189,7 +189,7 @@ end fill!(db.dval, 0.0) fill!(dx.dval, 0.0) db.dval[i] = 1.0 - Enzyme.autodiff(Forward, driver, dx, dA, db) + Enzyme.autodiff(Forward, driver, dx, A, db) adJ[i, :] = dx.dval end return adJ From f53d1de95e2e3d6f671bcf86cdbcf11d510fd53d Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Mon, 26 Feb 2024 16:46:15 +0100 Subject: [PATCH 06/61] Fix all tests --- src/internal_rules.jl | 20 +++++++++++--------- test/internal_rules.jl | 39 ++++++++------------------------------- 2 files changed, 19 insertions(+), 40 deletions(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 1d69f3af7d..ad699ef33c 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -645,7 +645,7 @@ function EnzymeRules.forward(::Const{typeof(cholesky)}, RT::Type, A; kwargs...) dA = if isa(A, Const) ntuple(Val(N)) do i Base.@_inline_meta - zeros(A.val) + zero(A.val) end else if N == 1 @@ -777,8 +777,8 @@ function EnzymeRules.reverse( cache, A::Annotation{<:Union{Matrix,LinearAlgebra.RealHermSym{<:Real,<:Matrix}}}; kwargs...) - fact, dfact = cache if !(RT <: Const) && !isa(A, Const) + fact, dfact = cache dAs = EnzymeRules.width(config) == 1 ? (A.dval,) : A.dval dfacts = EnzymeRules.width(config) == 1 ? (dfact,) : dfact @@ -871,19 +871,21 @@ function EnzymeRules.reverse( B::Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated}; kwargs... ) - if !isa(B, Const) && !isa(A, Const) + if !isa(B, Const) (cache_A, cache_B) = cache - U = cache_A.U Y = B.val - Z = U' \ cache_B + U = cache_A.U + Z = isa(A, Const) ? nothing : U' \ cache_B for b in 1:EnzymeRules.width(config) dB = EnzymeRules.width(config) == 1 ? B.dval : B.dval[b] dZ = U' \ dB - ∂B = U \ dZ func.val(cache_A, dB; kwargs...) - Ā = -dZ * Y' - Z * ∂B' - dA = EnzymeRules.width(config) == 1 ? A.dval : A.dval[b] - dA.factors .+= UpperTriangular(Ā) + if !isa(A, Const) + ∂B = U \ dZ + Ā = -dZ * Y' - Z * ∂B' + dA = EnzymeRules.width(config) == 1 ? A.dval : A.dval[b] + dA.factors .+= UpperTriangular(Ā) + end end end return (nothing, nothing) diff --git a/test/internal_rules.jl b/test/internal_rules.jl index 92657455cb..dde84f332c 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -131,35 +131,12 @@ end return nothing end - function divdriver(x, A, b) - fact = cholesky(A) - divdriver_NC(x, fact, b) - end - - function divdriver_herm(x, A, b) - fact = cholesky(Hermitian(A)) - divdriver_NC(x, fact, b) - end - - function divdriver_sym(x, A, b) - fact = cholesky(Symmetric(A)) - divdriver_NC(x, fact, b) - end - - function ldivdriver(x, A, b) - fact = cholesky(A) - ldivdriver_NC(x, fact, b) - end - - function ldivdriver_herm(x, A, b) - fact = cholesky(Hermitian(A)) - ldivdriver_NC(x, fact, b) - end - - function ldivdriver_sym(x, A, b) - fact = cholesky(Symmetric(A)) - ldivdriver_NC(x, fact, b) - end + divdriver(x, A, b) = divdriver_NC(x, cholesky(A), b) + divdriver_herm(x, A, b) = divdriver_NC(x, cholesky(Hermitian(A)), b) + divdriver_sym(x, A, b) = divdriver_NC(x, cholesky(Symmetric(A)), b) + ldivdriver(x, A, b) = ldivdriver_NC(x, cholesky(A), b) + ldivdriver_herm(x, A, b) = ldivdriver_NC(x, cholesky(Hermitian(A)), b) + ldivdriver_sym(x, A, b) = ldivdriver_NC(x, cholesky(Symmetric(A)), b) # Test forward function fwdJdxdb(driver, A, b) @@ -189,7 +166,7 @@ end fill!(db.dval, 0.0) fill!(dx.dval, 0.0) db.dval[i] = 1.0 - Enzyme.autodiff(Forward, driver, dx, A, db) + Enzyme.autodiff(Forward, driver, dx, Const(A), db) adJ[i, :] = dx.dval end return adJ @@ -241,7 +218,7 @@ end fill!(db.dval, 0.0) fill!(dx.dval, 0.0) dx.dval[i] = 1.0 - Enzyme.autodiff(Reverse, driver, dx, A, db) + Enzyme.autodiff(Reverse, driver, dx, Const(A), db) adJ[i, :] = db.dval end return adJ From af3c75049d01273969c97c58304310ec947b71a5 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Mon, 26 Feb 2024 16:58:56 +0100 Subject: [PATCH 07/61] Simplify helper functions --- test/internal_rules.jl | 28 ++++++++++------------------ 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/test/internal_rules.jl b/test/internal_rules.jl index dde84f332c..bb577b195d 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -145,11 +145,9 @@ end db = Duplicated(b, zeros(length(b))) dx = Duplicated(zeros(length(b)), zeros(length(b))) for i in 1:length(b) - copyto!(dA.val, A) - copyto!(db.val, b) - fill!(dA.dval, 0.0) - fill!(db.dval, 0.0) - fill!(dx.dval, 0.0) + dA.dval .= 0.0 + db.dval .= 0.0 + dx.dval .= 0.0 db.dval[i] = 1.0 Enzyme.autodiff(Forward, driver, dx, dA, db) adJ[i, :] = dx.dval @@ -162,9 +160,8 @@ end db = Duplicated(b, zeros(length(b))) dx = Duplicated(zeros(length(b)), zeros(length(b))) for i in 1:length(b) - copyto!(db.val, b) - fill!(db.dval, 0.0) - fill!(dx.dval, 0.0) + db.dval .= 0.0 + dx.dval .= 0.0 db.dval[i] = 1.0 Enzyme.autodiff(Forward, driver, dx, Const(A), db) adJ[i, :] = dx.dval @@ -197,11 +194,9 @@ end db = Duplicated(b, zeros(length(b))) dx = Duplicated(zeros(length(b)), zeros(length(b))) for i in 1:length(b) - copyto!(dA.val, A) - copyto!(db.val, b) - fill!(dA.dval, 0.0) - fill!(db.dval, 0.0) - fill!(dx.dval, 0.0) + dA.dval .= 0.0 + db.dval .= 0.0 + dx.dval .= 0.0 dx.dval[i] = 1.0 Enzyme.autodiff(Reverse, driver, dx, dA, db) adJ[i, :] = db.dval @@ -214,9 +209,8 @@ end db = Duplicated(b, zeros(length(b))) dx = Duplicated(zeros(length(b)), zeros(length(b))) for i in 1:length(b) - copyto!(db.val, b) - fill!(db.dval, 0.0) - fill!(dx.dval, 0.0) + db.dval .= 0.0 + dx.dval .= 0.0 dx.dval[i] = 1.0 Enzyme.autodiff(Reverse, driver, dx, Const(A), db) adJ[i, :] = db.dval @@ -243,8 +237,6 @@ end end function Jdxdb(driver, A, b) - x = A\b - dA = zeros(size(A)) db = zeros(length(b)) J = zeros(length(b), length(b)) for i in 1:length(b) From 50049ee69d6081228fd5eb5925a0532718fd61c9 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Mon, 26 Feb 2024 17:47:01 +0100 Subject: [PATCH 08/61] Simplify rule --- src/internal_rules.jl | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index ad699ef33c..c2f3be763c 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -786,12 +786,7 @@ function EnzymeRules.reverse( _dA = dA isa LinearAlgebra.RealHermSym ? dA.data : dA if _dA !== dfact.factors Ā = _cholesky_pullback_shared_code(fact, dfact) - if dA isa LinearAlgebra.RealHermSym - rmul!(Ā, one(eltype(Ā)) / 2) - else - Ā ./= 2 - end - _dA .+= Ā + _dA .+= Ā ./ 2 dfact.factors .= 0 end end From cde0c4383d18122da241e508c4e8e63600b53efc Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Tue, 27 Feb 2024 07:53:24 +0100 Subject: [PATCH 09/61] Apply suggestions from review --- src/internal_rules.jl | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index c2f3be763c..ffecb1b9fd 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -742,11 +742,9 @@ function EnzymeRules.augmented_primal( RT::Type, A::Annotation{<:Union{Matrix,LinearAlgebra.RealHermSym{<:Real,<:Matrix}}}; kwargs...) - fact = if EnzymeRules.needs_primal(config) - cholesky(A.val; kwargs...) - else - nothing - end + + fact = cholesky(A.val; kwargs...) + fact_returned = EnzymeRules.needs_primal(config) ? fact : nothing # dfact would be a dense matrix, prepare buffer dfact = if RT <: Const @@ -761,12 +759,8 @@ function EnzymeRules.augmented_primal( end end end - cache = if isa(A, Const) - nothing - else - (fact, dfact) - end + cache = isa(A, Const) : nothing ? (fact_returned, dfact) return EnzymeRules.AugmentedReturn(fact, dfact, cache) end From 62d2278ce7384488fa225f68d610666cef9d007a Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Tue, 27 Feb 2024 07:53:31 +0100 Subject: [PATCH 10/61] Fix typo --- src/internal_rules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index ffecb1b9fd..5112422d52 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -760,7 +760,7 @@ function EnzymeRules.augmented_primal( end end - cache = isa(A, Const) : nothing ? (fact_returned, dfact) + cache = isa(A, Const) ? nothing : (fact_returned, dfact) return EnzymeRules.AugmentedReturn(fact, dfact, cache) end From 3d805f355081702b2ee332cdec21d41df874b365 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Tue, 27 Feb 2024 07:53:43 +0100 Subject: [PATCH 11/61] Change testset name and add explanations --- test/internal_rules.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/test/internal_rules.jl b/test/internal_rules.jl index bb577b195d..b0cdd59c0c 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -320,7 +320,8 @@ end dA_sym = (dA_sym + dA_sym') / 2 @test isapprox(dA, dA_sym) end - @testset "Regression test for #1307" begin + @testset "Unit test for `cholesky` (regression test for #1307)" begin + # This test checks the `cholesky` rule without involving `ldiv!` function f(A) C = cholesky(A * adjoint(A)) return sum(abs2, C.L * C.U) @@ -330,6 +331,9 @@ end test_reverse(f, Active, (A, Duplicated)) @testset "Compare against function bypassing `cholesky`" begin g(A) = sum(abs2, A * adjoint(A)) + # If C = cholesky(A * A'), we have A * A' == C.L * C.U, so `g` + # is essentially the same fucntion as `f`, but bypassing `cholesky`. + # We can therefore use this to check that we get the same derivatives. @testset "Without wrapper" begin dA1 = zero(A) dA2 = zero(A) From c012073a876477fc7b3a45abed57ab3a8f1ab1c0 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Tue, 27 Feb 2024 08:20:01 +0100 Subject: [PATCH 12/61] Add unit tests for forward rule --- test/internal_rules.jl | 48 ++++++++++++++++++++++++++++++------------ 1 file changed, 34 insertions(+), 14 deletions(-) diff --git a/test/internal_rules.jl b/test/internal_rules.jl index b0cdd59c0c..3d0edd7117 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -321,25 +321,35 @@ end @test isapprox(dA, dA_sym) end @testset "Unit test for `cholesky` (regression test for #1307)" begin - # This test checks the `cholesky` rule without involving `ldiv!` + # This test checks the `cholesky` rules without involving `ldiv!` function f(A) C = cholesky(A * adjoint(A)) return sum(abs2, C.L * C.U) end @testset for TE in (Float64, ComplexF64) A = rand(TE, 3, 3) + test_forward(f, Duplicated, (A, Duplicated)) test_reverse(f, Active, (A, Duplicated)) @testset "Compare against function bypassing `cholesky`" begin g(A) = sum(abs2, A * adjoint(A)) - # If C = cholesky(A * A'), we have A * A' == C.L * C.U, so `g` - # is essentially the same fucntion as `f`, but bypassing `cholesky`. - # We can therefore use this to check that we get the same derivatives. + # If C = cholesky(A * A'), we have A * A' ≈ C.L * C.U, so `g` + # is essentially the same function as `f`, but bypassing `cholesky`. + # We can therefore use this to check that we get the correct derivatives. @testset "Without wrapper" begin - dA1 = zero(A) - dA2 = zero(A) - autodiff(Reverse, f, Active, Duplicated(A, dA1)) - autodiff(Reverse, g, Active, Duplicated(A, dA2)) - @test dA1 ≈ dA2 + @testset "Forward mode" begin + dA = rand(TE, size(A)...) + d1 = autodiff(Forward, f, Duplicated, Duplicated(A, dA)) + d2 = autodiff(Forward, g, Duplicated, Duplicated(A, dA)) + @test all(d1 .≈ d2) + end + + @testset "Reverse mode" begin + dA1 = zero(A) + dA2 = zero(A) + autodiff(Reverse, f, Active, Duplicated(A, dA1)) + autodiff(Reverse, g, Active, Duplicated(A, dA2)) + @test dA1 ≈ dA2 + end end if TE == Float64 @testset "$wrapper wrapper" for wrapper in (Symmetric, Hermitian,) @@ -348,11 +358,21 @@ end return sum(abs2, C.L * C.U) end gw(A) = sum(abs2, wrapper(A * adjoint(A))) - dA1 = zero(A) - dA2 = zero(A) - autodiff(Reverse, fw, Active, Duplicated(A, dA1)) - autodiff(Reverse, gw, Active, Duplicated(A, dA2)) - @test dA1 ≈ dA2 + + @testset "Forward mode" begin + dA = rand(TE, size(A)...) + d1 = autodiff(Forward, fw, Duplicated, Duplicated(A, dA)) + d2 = autodiff(Forward, gw, Duplicated, Duplicated(A, dA)) + @test all(d1 .≈ d2) + end + + @testset "Reverse mode" begin + dA1 = zero(A) + dA2 = zero(A) + autodiff(Reverse, fw, Active, Duplicated(A, dA1)) + autodiff(Reverse, gw, Active, Duplicated(A, dA2)) + @test dA1 ≈ dA2 + end end end end From 7ec83164f50c93090c87da8d365490b95707a721 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Tue, 27 Feb 2024 10:45:39 +0100 Subject: [PATCH 13/61] Fix tests --- test/internal_rules.jl | 55 ++++++++++++++++++++++++++---------------- 1 file changed, 34 insertions(+), 21 deletions(-) diff --git a/test/internal_rules.jl b/test/internal_rules.jl index 3d0edd7117..2188f1cd36 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -352,27 +352,40 @@ end end end if TE == Float64 - @testset "$wrapper wrapper" for wrapper in (Symmetric, Hermitian,) - function fw(A) - C = cholesky(wrapper(A * adjoint(A))) - return sum(abs2, C.L * C.U) - end - gw(A) = sum(abs2, wrapper(A * adjoint(A))) - - @testset "Forward mode" begin - dA = rand(TE, size(A)...) - d1 = autodiff(Forward, fw, Duplicated, Duplicated(A, dA)) - d2 = autodiff(Forward, gw, Duplicated, Duplicated(A, dA)) - @test all(d1 .≈ d2) - end - - @testset "Reverse mode" begin - dA1 = zero(A) - dA2 = zero(A) - autodiff(Reverse, fw, Active, Duplicated(A, dA1)) - autodiff(Reverse, gw, Active, Duplicated(A, dA2)) - @test dA1 ≈ dA2 - end + function f_sym(A) + C = cholesky(Symmetric(A * adjoint(A))) + return sum(abs2, C.L * C.U) + end + g_sym(A) = sum(abs2, Symmetric(A * adjoint(A))) + function f_her(A) + C = cholesky(Hermitian(A * adjoint(A))) + return sum(abs2, C.L * C.U) + end + g_her(A) = sum(abs2, Hermitian(A * adjoint(A))) + + @testset "Forward mode" begin + dA = rand(TE, size(A)...) + d1 = autodiff(Forward, f_sym, Duplicated, Duplicated(A, dA)) + d2 = autodiff(Forward, g_sym, Duplicated, Duplicated(A, dA)) + @test all(d1 .≈ d2) + + d1 = autodiff(Forward, f_her, Duplicated, Duplicated(A, dA)) + d2 = autodiff(Forward, g_her, Duplicated, Duplicated(A, dA)) + @test all(d1 .≈ d2) + end + + @testset "Reverse mode" begin + dA1 = zero(A) + dA2 = zero(A) + autodiff(Reverse, f_sym, Active, Duplicated(A, dA1)) + autodiff(Reverse, g_sym, Active, Duplicated(A, dA2)) + @test dA1 ≈ dA2 + + dA1 = zero(A) + dA2 = zero(A) + autodiff(Reverse, f_her, Active, Duplicated(A, dA1)) + autodiff(Reverse, g_her, Active, Duplicated(A, dA2)) + @test dA1 ≈ dA2 end end end From 32654296cc9253cbeb399b729679e9be1e02d125 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Tue, 27 Feb 2024 10:45:54 +0100 Subject: [PATCH 14/61] Fix forward rule --- src/internal_rules.jl | 36 ++++++++++++++++++++++++++---------- 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 5112422d52..da9a216509 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -640,26 +640,18 @@ function EnzymeRules.forward(::Const{typeof(cholesky)}, RT::Type, A; kwargs...) else N = width(RT) - invL = inv(fact.L) - dA = if isa(A, Const) ntuple(Val(N)) do i Base.@_inline_meta zero(A.val) end else - if N == 1 - (A.dval,) - else - A.dval - end + N == 1 ? (A.dval,) : A.dval end dfact = ntuple(Val(N)) do i Base.@_inline_meta - Cholesky( - Matrix(fact.L * LowerTriangular(invL * dA[i] * invL' * 0.5 * I)), 'L', 0 - ) + return _cholesky_forward(fact, dA[i]) end if (RT <: DuplicatedNoNeed) || (RT <: BatchDuplicatedNoNeed) @@ -672,6 +664,30 @@ function EnzymeRules.forward(::Const{typeof(cholesky)}, RT::Type, A; kwargs...) end end +function _cholesky_forward(C::Cholesky, Σdot) + # Computes the cholesky forward mode update rule + # C.f. eq. 8 in https://arxiv.org/pdf/1602.07527.pdf + if C.uplo == 'U' + U = C.U + Udot = Σdot / U + ldiv!(U', Udot) + idx = diagind(Udot) + Udot[idx] ./= 2 + triu!(Udot) + rmul!(Udot, U) + return Cholesky(Udot, 'U', 0) + else + L = C.L + Ldot = L \ Σdot + rdiv!(Ldot, L') + idx = diagind(Ldot) + Ldot[idx] ./= 2 + tril!(Ldot) + lmul!(L, Ldot) + return Cholesky(Ldot, 'L', 0) + end +end + # y = inv(A) B # dY = inv(A) [ dB - dA y ] # -> From b6a5fe8f165a07eda1079701460557e3ee21f17a Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Sun, 3 Mar 2024 00:26:38 +0100 Subject: [PATCH 15/61] Fix cholesky reverse rule --- src/internal_rules.jl | 9 ++++++++- test/internal_rules.jl | 15 +++++++++++---- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index da9a216509..e1097b052c 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -796,7 +796,14 @@ function EnzymeRules.reverse( _dA = dA isa LinearAlgebra.RealHermSym ? dA.data : dA if _dA !== dfact.factors Ā = _cholesky_pullback_shared_code(fact, dfact) - _dA .+= Ā ./ 2 + if A.val isa LinearAlgebra.RealHermSym{<:Real,<:Matrix} + rmul!(Ā, one(eltype(Ā)) / 2) + _dA .+= Ā + else + idx = diagind(Ā) + @views Ā[idx] .= real.(Ā[idx]) ./ 2 + _dA .+= UpperTriangular(Ā) + end dfact.factors .= 0 end end diff --git a/test/internal_rules.jl b/test/internal_rules.jl index 2188f1cd36..3fbf115ad6 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -260,7 +260,7 @@ end return J end - @testset "Testing $op" for (op, driver, driver_NC) in ( + @testset "Testing $op, $driver, $driver_NC" for (op, driver, driver_NC) in ( (:\, divdriver, divdriver_NC), (:\, divdriver_herm, divdriver_NC), (:\, divdriver_sym, divdriver_NC), @@ -312,13 +312,11 @@ end A = [1.3 0.5; 0.5 1.5] b = [1., 2.] - V = [1.0 0.0; 0.0 0.0] dA = zero(A) Enzyme.autodiff(Reverse, h, Active, Duplicated(A, dA), Const(b)) dA_sym = - (transpose(A) \ [1.0, 0.0]) * transpose(A \ b) - dA_sym = (dA_sym + dA_sym') / 2 - @test isapprox(dA, dA_sym) + @test isapprox((dA + dA') / 2, (dA_sym + dA_sym') / 2) end @testset "Unit test for `cholesky` (regression test for #1307)" begin # This test checks the `cholesky` rules without involving `ldiv!` @@ -391,6 +389,15 @@ end end end end + @testset "Linear solve with and without `cholesky`" begin + A = [3. 1.; 1. 2.] + b = [1., 2.] + dA1 = Duplicated(copy(A), zero(A)) + dA2 = Duplicated(copy(A), zero(A)) + autodiff(Reverse, (A, b) -> first(A\b), dA1, Const(b)) + autodiff(Reverse, (A, b) -> first(cholesky(A)\b), dA2, Const(b)) + @test dA1.dval ≈ dA2.dval + end end @testset "Linear solve for triangular matrices" begin From 91ee9763a73989ed8322d9815893a21d52199d38 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Tue, 5 Mar 2024 08:18:44 +0100 Subject: [PATCH 16/61] Remove wrong branch --- src/internal_rules.jl | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index e1097b052c..e3c741693e 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -796,14 +796,9 @@ function EnzymeRules.reverse( _dA = dA isa LinearAlgebra.RealHermSym ? dA.data : dA if _dA !== dfact.factors Ā = _cholesky_pullback_shared_code(fact, dfact) - if A.val isa LinearAlgebra.RealHermSym{<:Real,<:Matrix} - rmul!(Ā, one(eltype(Ā)) / 2) - _dA .+= Ā - else - idx = diagind(Ā) - @views Ā[idx] .= real.(Ā[idx]) ./ 2 - _dA .+= UpperTriangular(Ā) - end + idx = diagind(Ā) + @views Ā[idx] .= real.(Ā[idx]) ./ 2 + _dA .+= UpperTriangular(Ā) dfact.factors .= 0 end end From 61cf2b35c872dbd062439a1381b8b722f870db4f Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Wed, 6 Mar 2024 09:05:03 +0100 Subject: [PATCH 17/61] Simplify Ubar definition Co-authored-by: Seth Axen --- src/internal_rules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index ddcaa1db77..a68c68b973 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -812,7 +812,7 @@ function _cholesky_pullback_shared_code(C, ΔC) Ā = similar(C.factors) if C.uplo === 'U' U = C.U - Ū = eltype(U) <: Real ? real(_maybeUpperTri(Δfactors)) : _maybeUpperTri(Δfactors) + Ū = ΔC.U mul!(Ā, Ū, U') LinearAlgebra.copytri!(Ā, 'U', true) eltype(Ā) <: Real || _realifydiag!(Ā) From eaee9c5f41239580522749de502eb03f364015a4 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Wed, 6 Mar 2024 09:05:39 +0100 Subject: [PATCH 18/61] Simplify Lbar definition Co-authored-by: Seth Axen --- src/internal_rules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index a68c68b973..2e0ce28788 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -820,7 +820,7 @@ function _cholesky_pullback_shared_code(C, ΔC) rdiv!(Ā, U') else # C.uplo === 'L' L = C.L - L̄ = eltype(L) <: Real ? real(_maybeLowerTri(Δfactors)) : _maybeLowerTri(Δfactors) + L̄ = ΔC.L mul!(Ā, L', L̄) LinearAlgebra.copytri!(Ā, 'L', true) eltype(Ā) <: Real || _realifydiag!(Ā) From 7780fdee2ae5ead079adb2d94a679f555e73ac7f Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Wed, 6 Mar 2024 09:06:01 +0100 Subject: [PATCH 19/61] Remove redundant line Co-authored-by: Seth Axen --- src/internal_rules.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 2e0ce28788..40788792b1 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -808,7 +808,6 @@ end # Taken from ChainRules.jl function _cholesky_pullback_shared_code(C, ΔC) - Δfactors = ΔC.factors Ā = similar(C.factors) if C.uplo === 'U' U = C.U From 2bd2458916788e4d216c5fa392adf78ba71b4fcf Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Wed, 6 Mar 2024 09:06:48 +0100 Subject: [PATCH 20/61] Remove _maybeUpperTri Co-authored-by: Seth Axen --- src/internal_rules.jl | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 40788792b1..5af4a0e504 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -829,11 +829,6 @@ function _cholesky_pullback_shared_code(C, ΔC) return Ā end -_maybeUpperTri(A) = UpperTriangular(A) -_maybeUpperTri(A::Diagonal) = A -_maybeLowerTri(A) = LowerTriangular(A) -_maybeLowerTri(A::Diagonal) = A - function _realifydiag!(A) for i in diagind(A) @inbounds A[i] = real(A[i]) From bc78219c814fc4b8b67fc6388d42a2a746283db3 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Wed, 6 Mar 2024 09:23:16 +0100 Subject: [PATCH 21/61] Update internal_rules.jl Co-authored-by: Seth Axen --- src/internal_rules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 5af4a0e504..1160ab72cb 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -667,7 +667,7 @@ end function _cholesky_forward(C::Cholesky, Σdot) # Computes the cholesky forward mode update rule # C.f. eq. 8 in https://arxiv.org/pdf/1602.07527.pdf - if C.uplo == 'U' + if C.uplo === 'U' U = C.U Udot = Σdot / U ldiv!(U', Udot) From 9d6d45dfa8a96aca39b5a6afa55df4790e28a8d0 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Wed, 6 Mar 2024 09:25:57 +0100 Subject: [PATCH 22/61] Improve notation Co-authored-by: Seth Axen --- src/internal_rules.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 1160ab72cb..047f8857ed 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -669,13 +669,13 @@ function _cholesky_forward(C::Cholesky, Σdot) # C.f. eq. 8 in https://arxiv.org/pdf/1602.07527.pdf if C.uplo === 'U' U = C.U - Udot = Σdot / U - ldiv!(U', Udot) - idx = diagind(Udot) - Udot[idx] ./= 2 - triu!(Udot) - rmul!(Udot, U) - return Cholesky(Udot, 'U', 0) + U̇ = Ȧ / U + ldiv!(U', U̇) + idx = diagind(U̇) + U̇[idx] ./= 2 + triu!(U̇) + rmul!(U̇, U) + return Cholesky(U̇, 'U', C.info) else L = C.L Ldot = L \ Σdot From ea2ea2c94587b3e328fce358a47be5f3d185d660 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Wed, 6 Mar 2024 09:26:15 +0100 Subject: [PATCH 23/61] Improve notation Co-authored-by: Seth Axen --- src/internal_rules.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 047f8857ed..4afe5d2a96 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -678,13 +678,13 @@ function _cholesky_forward(C::Cholesky, Σdot) return Cholesky(U̇, 'U', C.info) else L = C.L - Ldot = L \ Σdot - rdiv!(Ldot, L') - idx = diagind(Ldot) - Ldot[idx] ./= 2 - tril!(Ldot) - lmul!(L, Ldot) - return Cholesky(Ldot, 'L', 0) + L̇ = L \ Ȧ + rdiv!(L̇, L') + idx = diagind(L̇) + L̇[idx] ./= 2 + tril!(L̇) + lmul!(L, L̇) + return Cholesky(L̇, 'L', C.info) end end From 8bfb459b96a8de7ac4e650954a38819a50e6b4dd Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Wed, 6 Mar 2024 17:18:00 +0100 Subject: [PATCH 24/61] Fix wrong argument name --- src/internal_rules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 4afe5d2a96..0899d4504f 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -664,7 +664,7 @@ function EnzymeRules.forward(::Const{typeof(cholesky)}, RT::Type, A; kwargs...) end end -function _cholesky_forward(C::Cholesky, Σdot) +function _cholesky_forward(C::Cholesky, Ȧ) # Computes the cholesky forward mode update rule # C.f. eq. 8 in https://arxiv.org/pdf/1602.07527.pdf if C.uplo === 'U' From 33e48afdd7ea345f69218ba6630d4a36f337e7c7 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Wed, 6 Mar 2024 17:34:42 +0100 Subject: [PATCH 25/61] Add copyright notice --- src/internal_rules.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 0899d4504f..68dfad43c0 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -806,7 +806,10 @@ function EnzymeRules.reverse( return (nothing,) end -# Taken from ChainRules.jl +# Adapted from ChainRules.jl +# MIT "Expat" License +# Copyright (c) 2018: Jarrett Revels. +# https://github.com/JuliaDiff/ChainRules.jl/blob/9f1817a22404259113e230bef149a54d379a660b/src/rulesets/LinearAlgebra/factorization.jl#L507-L528 function _cholesky_pullback_shared_code(C, ΔC) Ā = similar(C.factors) if C.uplo === 'U' From c2599803b16734145ed5cf1bb31568b34c033c53 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Thu, 7 Mar 2024 13:06:24 +0100 Subject: [PATCH 26/61] Implement changes suggested in review --- src/internal_rules.jl | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 68dfad43c0..697ad00eab 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -759,7 +759,12 @@ function EnzymeRules.augmented_primal( A::Annotation{<:Union{Matrix,LinearAlgebra.RealHermSym{<:Real,<:Matrix}}}; kwargs...) - fact = cholesky(A.val; kwargs...) + fact = if EnzymeRules.needs_primal(config) || !(RT <: Const) + cholesky(A.val; kwargs...) + else + nothing + end + fact_returned = EnzymeRules.needs_primal(config) ? fact : nothing # dfact would be a dense matrix, prepare buffer @@ -776,8 +781,8 @@ function EnzymeRules.augmented_primal( end end - cache = isa(A, Const) ? nothing : (fact_returned, dfact) - return EnzymeRules.AugmentedReturn(fact, dfact, cache) + cache = isa(A, Const) ? nothing : (fact, dfact) + return EnzymeRules.AugmentedReturn(fact_returned, dfact, cache) end function EnzymeRules.reverse( From cf5e57357372b0dad8c30ffd2997073682729932 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Sun, 17 Mar 2024 14:34:33 +0100 Subject: [PATCH 27/61] Rewrite tests from scratch - real case --- test/internal_rules.jl | 306 +++-------------------------------------- 1 file changed, 22 insertions(+), 284 deletions(-) diff --git a/test/internal_rules.jl b/test/internal_rules.jl index 5481fd4da6..25d60586e5 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -112,291 +112,29 @@ end @static if VERSION > v"1.8" @testset "Cholesky" begin - function symmetric_definite(n :: Int=10) - α = one(Float64) - A = spdiagm(-1 => α * ones(n-1), 0 => 4 * ones(n), 1 => conj(α) * ones(n-1)) - b = A * Float64[1:n;] - return A, b - end - - function divdriver_NC(x, fact, b) - res = fact\b - x .= res - return nothing - end - - function ldivdriver_NC(x, fact, b) - ldiv!(fact,b) - x .= b - return nothing - end - - divdriver(x, A, b) = divdriver_NC(x, cholesky(A), b) - divdriver_herm(x, A, b) = divdriver_NC(x, cholesky(Hermitian(A)), b) - divdriver_sym(x, A, b) = divdriver_NC(x, cholesky(Symmetric(A)), b) - ldivdriver(x, A, b) = ldivdriver_NC(x, cholesky(A), b) - ldivdriver_herm(x, A, b) = ldivdriver_NC(x, cholesky(Hermitian(A)), b) - ldivdriver_sym(x, A, b) = ldivdriver_NC(x, cholesky(Symmetric(A)), b) - - # Test forward - function fwdJdxdb(driver, A, b) - adJ = zeros(size(A)) - dA = Duplicated(A, zeros(size(A))) - db = Duplicated(b, zeros(length(b))) - dx = Duplicated(zeros(length(b)), zeros(length(b))) - for i in 1:length(b) - dA.dval .= 0.0 - db.dval .= 0.0 - dx.dval .= 0.0 - db.dval[i] = 1.0 - Enzyme.autodiff(Forward, driver, dx, dA, db) - adJ[i, :] = dx.dval - end - return adJ - end - - function const_fwdJdxdb(driver, A, b) - adJ = zeros(length(b), length(b)) - db = Duplicated(b, zeros(length(b))) - dx = Duplicated(zeros(length(b)), zeros(length(b))) - for i in 1:length(b) - db.dval .= 0.0 - dx.dval .= 0.0 - db.dval[i] = 1.0 - Enzyme.autodiff(Forward, driver, dx, Const(A), db) - adJ[i, :] = dx.dval - end - return adJ - end - - function batchedfwdJdxdb(driver, A, b) - n = length(b) - function seed(i) - x = zeros(n) - x[i] = 1.0 - return x - end - adJ = zeros(size(A)) - dA = BatchDuplicated(A, ntuple(i -> zeros(size(A)), n)) - db = BatchDuplicated(b, ntuple(i -> seed(i), n)) - dx = BatchDuplicated(zeros(length(b)), ntuple(i -> zeros(length(b)), n)) - Enzyme.autodiff(Forward, driver, dx, dA, db) - for i in 1:n - adJ[i, :] = dx.dval[i] + function cholesky_testfunction(A, b, x1, x2) + C1 = cholesky(A * A') + C2 = cholesky(Symmetric(A * A')) + x1 .= C1 \ b + x2 .= C2 \ b + return sum(abs2, C1.L * C1.U) + sum(abs2, C2.L * C2.U) + end + A = rand(5, 5) + b = rand(5) + x1 = rand(5) + x2 = rand(5) + @testset for TA in (Const, Duplicated), + Tb in (Const, Duplicated), + Tx1 in (Const, Duplicated), + Tx2 in (Const, Duplicated) + @testset for Tret in (Const, Duplicated) + are_activities_compatible(Tret, TA, Tb, Tx1, Tx2) || continue + test_forward(cholesky_testfunction, Tret, (A, TA), (b, Tb), (x1, Tx1), (x2, Tx2)) + end + @testset for Tret in (Const, Active) + are_activities_compatible(Tret, TA, Tb, Tx1, Tx2) || continue + test_reverse(cholesky_testfunction, Tret, (A, TA), (b, Tb), (x1, Tx1), (x2, Tx2)) end - return adJ - end - - # Test reverse - function revJdxdb(driver, A, b) - adJ = zeros(size(A)) - dA = Duplicated(A, zeros(size(A))) - db = Duplicated(b, zeros(length(b))) - dx = Duplicated(zeros(length(b)), zeros(length(b))) - for i in 1:length(b) - dA.dval .= 0.0 - db.dval .= 0.0 - dx.dval .= 0.0 - dx.dval[i] = 1.0 - Enzyme.autodiff(Reverse, driver, dx, dA, db) - adJ[i, :] = db.dval - end - return adJ - end - - function const_revJdxdb(driver, A, b) - adJ = zeros(length(b), length(b)) - db = Duplicated(b, zeros(length(b))) - dx = Duplicated(zeros(length(b)), zeros(length(b))) - for i in 1:length(b) - db.dval .= 0.0 - dx.dval .= 0.0 - dx.dval[i] = 1.0 - Enzyme.autodiff(Reverse, driver, dx, Const(A), db) - adJ[i, :] = db.dval - end - return adJ - end - - function batchedrevJdxdb(driver, A, b) - n = length(b) - function seed(i) - x = zeros(n) - x[i] = 1.0 - return x - end - adJ = zeros(size(A)) - dA = BatchDuplicated(A, ntuple(i -> zeros(size(A)), n)) - db = BatchDuplicated(b, ntuple(i -> zeros(length(b)), n)) - dx = BatchDuplicated(zeros(length(b)), ntuple(i -> seed(i), n)) - Enzyme.autodiff(Reverse, driver, dx, dA, db) - for i in 1:n - adJ[i, :] .= db.dval[i] - end - return adJ - end - - function Jdxdb(driver, A, b) - db = zeros(length(b)) - J = zeros(length(b), length(b)) - for i in 1:length(b) - db[i] = 1.0 - dx = A\db - db[i] = 0.0 - J[i, :] = dx - end - return J - end - - function JdxdA(driver, A, b) - db = zeros(length(b)) - J = zeros(length(b), length(b)) - for i in 1:length(b) - db[i] = 1.0 - dx = A\db - db[i] = 0.0 - J[i, :] = dx - end - return J - end - - @testset "Testing $op, $driver, $driver_NC" for (op, driver, driver_NC) in ( - (:\, divdriver, divdriver_NC), - (:\, divdriver_herm, divdriver_NC), - (:\, divdriver_sym, divdriver_NC), - (:ldiv!, ldivdriver, ldivdriver_NC), - (:ldiv!, ldivdriver_herm, ldivdriver_NC), - (:ldiv!, ldivdriver_sym, ldivdriver_NC) - ) - A, b = symmetric_definite(10) - n = length(b) - A = Matrix(A) - x = zeros(n) - x = driver(x, A, b) - fdm = forward_fdm(2, 1); - - function b_one(b) - _x = zeros(length(b)) - driver(_x,A,b) - return _x - end - - fdJ = op==:\ ? FiniteDifferences.jacobian(fdm, b_one, copy(b))[1] : nothing - fwdJ = fwdJdxdb(driver, A, b) - revJ = revJdxdb(driver, A, b) - batchedrevJ = batchedrevJdxdb(driver, A, b) - batchedfwdJ = batchedfwdJdxdb(driver, A, b) - J = Jdxdb(driver, A, b) - - if op == :\ - @test isapprox(fwdJ, fdJ) - end - - @test isapprox(fwdJ, revJ) - @test isapprox(fwdJ, batchedrevJ) - @test isapprox(fwdJ, batchedfwdJ) - - fwdJ = const_fwdJdxdb(driver_NC, cholesky(A), b) - revJ = const_revJdxdb(driver_NC, cholesky(A), b) - if op == :\ - @test isapprox(fwdJ, fdJ) - end - @test isapprox(fwdJ, revJ) - - function h(A, b) - C = cholesky(A) - b2 = copy(b) - ldiv!(C, b2) - @inbounds b2[1] - end - - A = [1.3 0.5; 0.5 1.5] - b = [1., 2.] - dA = zero(A) - Enzyme.autodiff(Reverse, h, Active, Duplicated(A, dA), Const(b)) - - dA_sym = - (transpose(A) \ [1.0, 0.0]) * transpose(A \ b) - @test isapprox((dA + dA') / 2, (dA_sym + dA_sym') / 2) - end - @testset "Unit test for `cholesky` (regression test for #1307)" begin - # This test checks the `cholesky` rules without involving `ldiv!` - function f(A) - C = cholesky(A * adjoint(A)) - return sum(abs2, C.L * C.U) - end - @testset for TE in (Float64, ComplexF64) - A = rand(TE, 3, 3) - test_forward(f, Duplicated, (A, Duplicated)) - test_reverse(f, Active, (A, Duplicated)) - @testset "Compare against function bypassing `cholesky`" begin - g(A) = sum(abs2, A * adjoint(A)) - # If C = cholesky(A * A'), we have A * A' ≈ C.L * C.U, so `g` - # is essentially the same function as `f`, but bypassing `cholesky`. - # We can therefore use this to check that we get the correct derivatives. - @testset "Without wrapper" begin - @testset "Forward mode" begin - dA = rand(TE, size(A)...) - d1 = autodiff(Forward, f, Duplicated, Duplicated(A, dA)) - d2 = autodiff(Forward, g, Duplicated, Duplicated(A, dA)) - @test all(d1 .≈ d2) - end - - @testset "Reverse mode" begin - dA1 = zero(A) - dA2 = zero(A) - autodiff(Reverse, f, Active, Duplicated(A, dA1)) - autodiff(Reverse, g, Active, Duplicated(A, dA2)) - @test dA1 ≈ dA2 - end - end - if TE == Float64 - function f_sym(A) - C = cholesky(Symmetric(A * adjoint(A))) - return sum(abs2, C.L * C.U) - end - g_sym(A) = sum(abs2, Symmetric(A * adjoint(A))) - function f_her(A) - C = cholesky(Hermitian(A * adjoint(A))) - return sum(abs2, C.L * C.U) - end - g_her(A) = sum(abs2, Hermitian(A * adjoint(A))) - - @testset "Forward mode" begin - dA = rand(TE, size(A)...) - d1 = autodiff(Forward, f_sym, Duplicated, Duplicated(A, dA)) - d2 = autodiff(Forward, g_sym, Duplicated, Duplicated(A, dA)) - @test all(d1 .≈ d2) - - d1 = autodiff(Forward, f_her, Duplicated, Duplicated(A, dA)) - d2 = autodiff(Forward, g_her, Duplicated, Duplicated(A, dA)) - @test all(d1 .≈ d2) - end - - @testset "Reverse mode" begin - dA1 = zero(A) - dA2 = zero(A) - autodiff(Reverse, f_sym, Active, Duplicated(A, dA1)) - autodiff(Reverse, g_sym, Active, Duplicated(A, dA2)) - @test dA1 ≈ dA2 - - dA1 = zero(A) - dA2 = zero(A) - autodiff(Reverse, f_her, Active, Duplicated(A, dA1)) - autodiff(Reverse, g_her, Active, Duplicated(A, dA2)) - @test dA1 ≈ dA2 - end - end - end - end - end - @testset "Linear solve with and without `cholesky`" begin - A = [3. 1.; 1. 2.] - b = [1., 2.] - dA1 = Duplicated(copy(A), zero(A)) - dA2 = Duplicated(copy(A), zero(A)) - autodiff(Reverse, (A, b) -> first(A\b), dA1, Const(b)) - autodiff(Reverse, (A, b) -> first(cholesky(A)\b), dA2, Const(b)) - @test dA1.dval ≈ dA2.dval end end From 3aba37044a0a412efe62b1039775c12551d522ca Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Sun, 17 Mar 2024 22:56:10 +0100 Subject: [PATCH 28/61] Complete tests --- test/internal_rules.jl | 55 ++++++++++++++++++++++++++---------------- 1 file changed, 34 insertions(+), 21 deletions(-) diff --git a/test/internal_rules.jl b/test/internal_rules.jl index 25d60586e5..9cf2de03fd 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -112,28 +112,41 @@ end @static if VERSION > v"1.8" @testset "Cholesky" begin - function cholesky_testfunction(A, b, x1, x2) - C1 = cholesky(A * A') - C2 = cholesky(Symmetric(A * A')) - x1 .= C1 \ b - x2 .= C2 \ b - return sum(abs2, C1.L * C1.U) + sum(abs2, C2.L * C2.U) + function cholesky_testfunction_symmetric(A, b, x1, x2) + C1 = cholesky(A * A') # test factorization without wrapper + C2 = cholesky(Symmetric(A * A')) # test factorization with wrapper + x1 .= C1 \ b # test linear solve with factorization object without wrapper + x2 .= C2 \ b # test linear solve with factorization object with wrapper + return sum(abs2, C1.L * C1.U) + sum(abs2, C2.L * C2.U) # test factorization itself end - A = rand(5, 5) - b = rand(5) - x1 = rand(5) - x2 = rand(5) - @testset for TA in (Const, Duplicated), - Tb in (Const, Duplicated), - Tx1 in (Const, Duplicated), - Tx2 in (Const, Duplicated) - @testset for Tret in (Const, Duplicated) - are_activities_compatible(Tret, TA, Tb, Tx1, Tx2) || continue - test_forward(cholesky_testfunction, Tret, (A, TA), (b, Tb), (x1, Tx1), (x2, Tx2)) - end - @testset for Tret in (Const, Active) - are_activities_compatible(Tret, TA, Tb, Tx1, Tx2) || continue - test_reverse(cholesky_testfunction, Tret, (A, TA), (b, Tb), (x1, Tx1), (x2, Tx2)) + function cholesky_testfunction_hermitian(A, b, x1, x2) + C1 = cholesky(A * adjoint(A)) # test factorization without wrapper + C2 = cholesky(Hermitian(A * adjoint(A))) # test factorization with wrapper + x1 .= C1 \ b # test linear solve with factorization object without wrapper + x2 .= C2 \ b # test linear solve with factorization object with wrapper + return sum(abs2, C1.L * C1.U) + sum(abs2, C2.L * C2.U) # test factorization itself + end + @testset for (TE, testfunction) in ( + Float64 => cholesky_testfunction_symmetric, + Float64 => cholesky_testfunction_hermitian + ) + @testset for TA in (Const, Duplicated), + Tb in (Const, Duplicated), + Tx1 in (Const, Duplicated), + Tx2 in (Const, Duplicated) + A = rand(TE, 5, 5) + b = rand(TE, 5) + x1 = rand(TE, 5) + x2 = rand(TE, 5) + # ishermitian(A * adjoint(A)) || continue + @testset for Tret in (Const, Duplicated) + are_activities_compatible(Tret, TA, Tb, Tx1, Tx2) || continue + test_forward(testfunction, Tret, (A, TA), (b, Tb), (x1, Tx1), (x2, Tx2)) + end + @testset for Tret in (Const, Active) + are_activities_compatible(Tret, TA, Tb, Tx1, Tx2) || continue + test_reverse(testfunction, Tret, (A, TA), (b, Tb), (x1, Tx1), (x2, Tx2)) + end end end end From 188949189853b361db258daea95748893515e4f3 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Wed, 20 Mar 2024 17:20:46 +0100 Subject: [PATCH 29/61] Run formatter --- src/internal_rules.jl | 427 ++++++++++++++++++++++++----------------- test/internal_rules.jl | 199 +++++++++++-------- 2 files changed, 371 insertions(+), 255 deletions(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 697ad00eab..da1553dcef 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -99,7 +99,9 @@ end function EnzymeRules.inactive_noinl(::typeof(Base.size), args...) return nothing end -function EnzymeRules.inactive_noinl(::typeof(Base.setindex!), ::IdDict{K, V}, ::K, ::V) where {K, V <:Integer} +function EnzymeRules.inactive_noinl( + ::typeof(Base.setindex!), ::IdDict{K,V}, ::K, ::V +) where {K,V<:Integer} return nothing end @@ -117,35 +119,45 @@ end @inline EnzymeRules.inactive_type(v::Type{T}) where {T<:AbstractString} = true @inline width(::Duplicated) = 1 -@inline width(::BatchDuplicated{T, N}) where {T, N} = N +@inline width(::BatchDuplicated{T,N}) where {T,N} = N @inline width(::DuplicatedNoNeed) = 1 -@inline width(::BatchDuplicatedNoNeed{T, N}) where {T, N} = N +@inline width(::BatchDuplicatedNoNeed{T,N}) where {T,N} = N -@inline width(::Type{Duplicated{T}}) where T = 1 -@inline width(::Type{BatchDuplicated{T, N}}) where {T, N} = N -@inline width(::Type{DuplicatedNoNeed{T}}) where T = 1 -@inline width(::Type{BatchDuplicatedNoNeed{T, N}}) where {T, N} = N +@inline width(::Type{Duplicated{T}}) where {T} = 1 +@inline width(::Type{BatchDuplicated{T,N}}) where {T,N} = N +@inline width(::Type{DuplicatedNoNeed{T}}) where {T} = 1 +@inline width(::Type{BatchDuplicatedNoNeed{T,N}}) where {T,N} = N # Note all of these forward mode definitions do not support runtime activity as # the do not keep the primal if shadow(x.y) == primal(x.y) -function EnzymeRules.forward(::Const{typeof(Base.deepcopy)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated) +function EnzymeRules.forward( + ::Const{typeof(Base.deepcopy)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated +) return deepcopy(x.dval) end -function EnzymeRules.forward(::Const{typeof(Base.deepcopy)}, ::Type{<:BatchDuplicatedNoNeed}, x::BatchDuplicated{T, N}) where {T, N} +function EnzymeRules.forward( + ::Const{typeof(Base.deepcopy)}, ::Type{<:BatchDuplicatedNoNeed}, x::BatchDuplicated{T,N} +) where {T,N} ntuple(Val(N)) do _ deepcopy(x.dval) end end # Deepcopy preserving the primal if runtime inactive -@inline function deepcopy_rtact(copied::RT, primal::RT, seen::IdDict, shadow::RT) where {RT <: Union{Integer, Char}} +@inline function deepcopy_rtact( + copied::RT, primal::RT, seen::IdDict, shadow::RT +) where {RT<:Union{Integer,Char}} return Base.deepcopy_internal(shadow, seen) end -@inline function deepcopy_rtact(copied::RT, primal::RT, seen::IdDict, shadow::RT) where {RT <: AbstractFloat} +@inline function deepcopy_rtact( + copied::RT, primal::RT, seen::IdDict, shadow::RT +) where {RT<:AbstractFloat} return Base.deepcopy_internal(shadow, seen) end -@inline function deepcopy_rtact(copied::RT, primal::RT, seen::IdDict, shadow::RT) where {RT <: Array} +@inline function deepcopy_rtact( + copied::RT, primal::RT, seen::IdDict, shadow::RT +) where {RT<:Array} if !haskey(seen, shadow) if primal === shadow return seen[shadow] = copied @@ -159,19 +171,28 @@ end return seen[shadow] end -function EnzymeRules.forward(func::Const{typeof(Base.deepcopy)}, ::Type{<:Duplicated}, x::Duplicated) +function EnzymeRules.forward( + func::Const{typeof(Base.deepcopy)}, ::Type{<:Duplicated}, x::Duplicated +) primal = func.val(x.val) return Duplicated(primal, deepcopy_rtact(primal, x.val, IdDict(), x.dval)) end -function EnzymeRules.forward(func::Const{typeof(Base.deepcopy)}, ::Type{<:BatchDuplicated}, x::BatchDuplicated{T, N}) where {T,N} +function EnzymeRules.forward( + func::Const{typeof(Base.deepcopy)}, ::Type{<:BatchDuplicated}, x::BatchDuplicated{T,N} +) where {T,N} primal = func.val(x.val) - return BatchDuplicated(primal, ntuple(Val(N)) do i - deepcopy_rtact(primal, x.val, IdDict(), x.dval[i]) - end) + return BatchDuplicated( + primal, + ntuple(Val(N)) do i + deepcopy_rtact(primal, x.val, IdDict(), x.dval[i]) + end, + ) end -function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.deepcopy)}, ::Type{RT}, x::Annotation{Ty}) where {RT, Ty} +function EnzymeRules.augmented_primal( + config, func::Const{typeof(Base.deepcopy)}, ::Type{RT}, x::Annotation{Ty} +) where {RT,Ty} primal = if EnzymeRules.needs_primal(config) func.val(x.val) else @@ -188,8 +209,9 @@ function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.deepcopy)} shadow = ntuple(Val(EnzymeRules.width(config))) do _ Base.@_inline_meta - Enzyme.make_zero(source, - #=copy_if_inactive=#Val(!EnzymeRules.needs_primal(config)) + Enzyme.make_zero( + source, + Val(!EnzymeRules.needs_primal(config)), #=copy_if_inactive=# ) end @@ -200,8 +222,9 @@ function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.deepcopy)} return EnzymeRules.AugmentedReturn(primal, shadow, shadow) end - -@inline function accumulate_into(into::RT, seen::IdDict, from::RT)::Tuple{RT,RT} where {RT<:Array} +@inline function accumulate_into( + into::RT, seen::IdDict, from::RT +)::Tuple{RT,RT} where {RT<:Array} if Enzyme.Compiler.guaranteed_const(RT) return (into, from) end @@ -216,9 +239,11 @@ end return seen[into] end -@inline function accumulate_into(into::RT, seen::IdDict, from::RT)::Tuple{RT,RT} where {RT<:AbstractFloat} +@inline function accumulate_into( + into::RT, seen::IdDict, from::RT +)::Tuple{RT,RT} where {RT<:AbstractFloat} if !haskey(seen, into) - seen[into] = (into+from, RT(0)) + seen[into] = (into + from, RT(0)) end return seen[into] end @@ -233,7 +258,9 @@ end return seen[into] end -function EnzymeRules.reverse(config, func::Const{typeof(Base.deepcopy)}, ::Type{RT}, shadow, x::Annotation{Ty}) where {RT, Ty} +function EnzymeRules.reverse( + config, func::Const{typeof(Base.deepcopy)}, ::Type{RT}, shadow, x::Annotation{Ty} +) where {RT,Ty} if EnzymeRules.width(config) == 1 accumulate_into(x.dval, IdDict(), shadow) else @@ -245,43 +272,80 @@ function EnzymeRules.reverse(config, func::Const{typeof(Base.deepcopy)}, ::Type{ return (nothing,) end -@inline function pmap_fwd(idx, tapes::Vector, thunk::ThunkTy, f::F, fargs::Vararg{Annotation, N}) where {ThunkTy, F, N} +@inline function pmap_fwd( + idx, tapes::Vector, thunk::ThunkTy, f::F, fargs::Vararg{Annotation,N} +) where {ThunkTy,F,N} @inbounds tapes[idx] = thunk(f, Const(idx), fargs...)[1] end -@inline function pmap_fwd(idx, tapes::Ptr, thunk::ThunkTy, f::F, fargs::Vararg{Annotation, N}) where {ThunkTy, F, N} - unsafe_store!(tapes, thunk(f, Const(idx), fargs...)[1], idx) +@inline function pmap_fwd( + idx, tapes::Ptr, thunk::ThunkTy, f::F, fargs::Vararg{Annotation,N} +) where {ThunkTy,F,N} + return unsafe_store!(tapes, thunk(f, Const(idx), fargs...)[1], idx) end -function EnzymeRules.augmented_primal(config, func::Const{typeof(Enzyme.pmap)}, ::Type{Const{Nothing}}, body::BodyTy, count, args::Vararg{Annotation, N}) where {BodyTy, N} - - config2 = ReverseModeSplit{false, false, EnzymeRules.width(config), EnzymeRules.overwritten(config)[2:end],InlineABI}() - fwd_thunk, rev_thunk = autodiff_thunk(config2, BodyTy, Const, typeof(count), map(typeof, args)...) +function EnzymeRules.augmented_primal( + config, + func::Const{typeof(Enzyme.pmap)}, + ::Type{Const{Nothing}}, + body::BodyTy, + count, + args::Vararg{Annotation,N}, +) where {BodyTy,N} + config2 = ReverseModeSplit{ + false, + false, + EnzymeRules.width(config), + EnzymeRules.overwritten(config)[2:end], + InlineABI, + }() + fwd_thunk, rev_thunk = autodiff_thunk( + config2, BodyTy, Const, typeof(count), map(typeof, args)... + ) TapeType = EnzymeRules.tape_type(fwd_thunk) tapes = if Enzyme.Compiler.any_jltypes(TapeType) Vector{TapeType}(undef, count.val) else - Base.unsafe_convert(Ptr{TapeType}, Libc.malloc(sizeof(TapeType)*count.val)) + Base.unsafe_convert(Ptr{TapeType}, Libc.malloc(sizeof(TapeType) * count.val)) end Enzyme.pmap(pmap_fwd, count.val, tapes, fwd_thunk, body, args...) return EnzymeRules.AugmentedReturn(nothing, nothing, tapes) end -@inline function pmap_rev(idx, tapes::Vector, thunk::ThunkTy, f::F, fargs::Vararg{Annotation, N}) where {ThunkTy, F, N} - thunk(f, Const(idx), fargs..., @inbounds tapes[idx]) +@inline function pmap_rev( + idx, tapes::Vector, thunk::ThunkTy, f::F, fargs::Vararg{Annotation,N} +) where {ThunkTy,F,N} + return thunk(f, Const(idx), fargs..., @inbounds tapes[idx]) end -@inline function pmap_rev(idx, tapes::Ptr, thunk::ThunkTy, f::F, fargs::Vararg{Annotation, N}) where {ThunkTy, F, N} - thunk(f, Const(idx), fargs..., unsafe_load(tapes, idx)) +@inline function pmap_rev( + idx, tapes::Ptr, thunk::ThunkTy, f::F, fargs::Vararg{Annotation,N} +) where {ThunkTy,F,N} + return thunk(f, Const(idx), fargs..., unsafe_load(tapes, idx)) end -function EnzymeRules.reverse(config, func::Const{typeof(Enzyme.pmap)}, ::Type{Const{Nothing}}, tapes, body::BodyTy, count, args::Vararg{Annotation, N}) where {BodyTy, N} - - config2 = ReverseModeSplit{false, false, EnzymeRules.width(config), EnzymeRules.overwritten(config)[2:end],InlineABI}() - fwd_thunk, rev_thunk = autodiff_thunk(config2, BodyTy, Const, typeof(count), map(typeof, args)...) +function EnzymeRules.reverse( + config, + func::Const{typeof(Enzyme.pmap)}, + ::Type{Const{Nothing}}, + tapes, + body::BodyTy, + count, + args::Vararg{Annotation,N}, +) where {BodyTy,N} + config2 = ReverseModeSplit{ + false, + false, + EnzymeRules.width(config), + EnzymeRules.overwritten(config)[2:end], + InlineABI, + }() + fwd_thunk, rev_thunk = autodiff_thunk( + config2, BodyTy, Const, typeof(count), map(typeof, args)... + ) Enzyme.pmap(pmap_rev, count.val, tapes, rev_thunk, body, args...) @@ -291,16 +355,14 @@ function EnzymeRules.reverse(config, func::Const{typeof(Enzyme.pmap)}, ::Type{Co Libc.free(tapes) end - return ntuple(Val(2+length(args))) do _ + return ntuple(Val(2 + length(args))) do _ Base.@_inline_meta nothing end end - - # From LinearAlgebra ~/.julia/juliaup/julia-1.10.0-beta3+0.x64.apple.darwin14/share/julia/stdlib/v1.10/LinearAlgebra/src/generic.jl:1110 -@inline function compute_lu_cache(cache_A::AT, b::BT) where {AT, BT} +@inline function compute_lu_cache(cache_A::AT, b::BT) where {AT,BT} LinearAlgebra.require_one_based_indexing(cache_A, b) m, n = size(cache_A) @@ -323,8 +385,9 @@ end # y=inv(A) B # dA −= z y^T # dB += z, where z = inv(A^T) dy -function EnzymeRules.augmented_primal(config, func::Const{typeof(\)}, ::Type{RT}, A::Annotation{AT}, b::Annotation{BT}) where {RT, AT <: Array, BT <: Array} - +function EnzymeRules.augmented_primal( + config, func::Const{typeof(\)}, ::Type{RT}, A::Annotation{AT}, b::Annotation{BT} +) where {RT,AT<:Array,BT<:Array} cache_A = if EnzymeRules.overwritten(config)[2] copy(A.val) else @@ -362,33 +425,42 @@ function EnzymeRules.augmented_primal(config, func::Const{typeof(\)}, ::Type{RT} nothing end -@static if VERSION < v"1.8.0" - UT = Union{ - LinearAlgebra.Diagonal{eltype(AT), BT}, - LinearAlgebra.LowerTriangular{eltype(AT), AT}, - LinearAlgebra.UpperTriangular{eltype(AT), AT}, - LinearAlgebra.LU{eltype(AT), AT}, - LinearAlgebra.QRCompactWY{eltype(AT), AT} - } -else - UT = Union{ - LinearAlgebra.Diagonal{eltype(AT), BT}, - LinearAlgebra.LowerTriangular{eltype(AT), AT}, - LinearAlgebra.UpperTriangular{eltype(AT), AT}, - LinearAlgebra.LU{eltype(AT), AT, Vector{Int}}, - LinearAlgebra.QRPivoted{eltype(AT), AT, BT, Vector{Int}} - } -end - - cache = NamedTuple{(Symbol("1"),Symbol("2"), Symbol("3"), Symbol("4")), Tuple{typeof(res), typeof(dres), UT, typeof(cache_b)}}( - (cache_res, dres, cache_A, cache_b) + @static if VERSION < v"1.8.0" + UT = Union{ + LinearAlgebra.Diagonal{eltype(AT),BT}, + LinearAlgebra.LowerTriangular{eltype(AT),AT}, + LinearAlgebra.UpperTriangular{eltype(AT),AT}, + LinearAlgebra.LU{eltype(AT),AT}, + LinearAlgebra.QRCompactWY{eltype(AT),AT}, + } + else + UT = Union{ + LinearAlgebra.Diagonal{eltype(AT),BT}, + LinearAlgebra.LowerTriangular{eltype(AT),AT}, + LinearAlgebra.UpperTriangular{eltype(AT),AT}, + LinearAlgebra.LU{eltype(AT),AT,Vector{Int}}, + LinearAlgebra.QRPivoted{eltype(AT),AT,BT,Vector{Int}}, + } + end + + cache = NamedTuple{ + (Symbol("1"), Symbol("2"), Symbol("3"), Symbol("4")), + Tuple{typeof(res),typeof(dres),UT,typeof(cache_b)}, + }((cache_res, dres, cache_A, cache_b)) + + return EnzymeRules.AugmentedReturn{typeof(retres),typeof(dres),typeof(cache)}( + retres, dres, cache ) - - return EnzymeRules.AugmentedReturn{typeof(retres), typeof(dres), typeof(cache)}(retres, dres, cache) end -function EnzymeRules.reverse(config, func::Const{typeof(\)}, ::Type{RT}, cache, A::Annotation{<:Array}, b::Annotation{<:Array}) where RT - +function EnzymeRules.reverse( + config, + func::Const{typeof(\)}, + ::Type{RT}, + cache, + A::Annotation{<:Array}, + b::Annotation{<:Array}, +) where {RT} y, dys, cache_A, cache_b = cache if !EnzymeRules.overwritten(config)[3] @@ -444,14 +516,11 @@ function EnzymeRules.reverse(config, func::Const{typeof(\)}, ::Type{RT}, cache, dy .= eltype(dy)(0) end - return (nothing,nothing) + return (nothing, nothing) end const EnzymeTriangulars = Union{ - UpperTriangular, - LowerTriangular, - UnitUpperTriangular, - UnitLowerTriangular + UpperTriangular,LowerTriangular,UnitUpperTriangular,UnitLowerTriangular } function EnzymeRules.augmented_primal( @@ -460,8 +529,8 @@ function EnzymeRules.augmented_primal( ::Type{RT}, Y::Annotation{YT}, A::Annotation{AT}, - B::Annotation{BT} -) where {RT, YT <: Array, AT <: EnzymeTriangulars, BT <: Array} + B::Annotation{BT}, +) where {RT,YT<:Array,AT<:EnzymeTriangulars,BT<:Array} cache_Y = EnzymeRules.overwritten(config)[1] ? copy(Y.val) : Y.val cache_A = EnzymeRules.overwritten(config)[2] ? copy(A.val) : A.val cache_A = compute_lu_cache(cache_A, B.val) @@ -469,8 +538,9 @@ function EnzymeRules.augmented_primal( primal = EnzymeRules.needs_primal(config) ? Y.val : nothing shadow = EnzymeRules.needs_shadow(config) ? Y.dval : nothing func.val(Y.val, A.val, B.val) - return EnzymeRules.AugmentedReturn{typeof(primal), typeof(shadow), Any}( - primal, shadow, (cache_Y, cache_A, cache_B)) + return EnzymeRules.AugmentedReturn{typeof(primal),typeof(shadow),Any}( + primal, shadow, (cache_Y, cache_A, cache_B) + ) end function EnzymeRules.reverse( @@ -480,8 +550,8 @@ function EnzymeRules.reverse( cache, Y::Annotation{YT}, A::Annotation{AT}, - B::Annotation{BT} -) where {YT <: Array, RT, AT <: EnzymeTriangulars, BT <: Array} + B::Annotation{BT}, +) where {YT<:Array,RT,AT<:EnzymeTriangulars,BT<:Array} if !isa(Y, Const) (cache_Yout, cache_A, cache_B) = cache for b in 1:EnzymeRules.width(config) @@ -507,62 +577,75 @@ _zero_unused_elements!(X, ::UnitUpperTriangular) = triu!(X, 1) _zero_unused_elements!(X, ::UnitLowerTriangular) = tril!(X, -1) @static if VERSION >= v"1.7-" -# Force a rule around hvcat_fill as it is type unstable if the tuple is not of the same type (e.g., int, float, int, float) -function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.hvcat_fill!)}, ::Type{RT}, out::Annotation{AT}, inp::Annotation{BT}) where {RT, AT <: Array, BT <: Tuple} - primal = if EnzymeRules.needs_primal(config) - out.val - else - nothing - end - shadow = if EnzymeRules.needs_shadow(config) - out.dval - else - nothing - end - func.val(out.val, inp.val) - return EnzymeRules.AugmentedReturn(primal, shadow, nothing) -end - -function EnzymeRules.reverse(config, func::Const{typeof(Base.hvcat_fill!)}, ::Type{RT}, _, out::Annotation{AT}, inp::Annotation{BT}) where {RT, AT <: Array, BT <: Tuple} - nr, nc = size(out.val,1), size(out.val,2) - for b in 1:EnzymeRules.width(config) - da = if EnzymeRules.width(config) == 1 + # Force a rule around hvcat_fill as it is type unstable if the tuple is not of the same type (e.g., int, float, int, float) + function EnzymeRules.augmented_primal( + config, + func::Const{typeof(Base.hvcat_fill!)}, + ::Type{RT}, + out::Annotation{AT}, + inp::Annotation{BT}, + ) where {RT,AT<:Array,BT<:Tuple} + primal = if EnzymeRules.needs_primal(config) + out.val + else + nothing + end + shadow = if EnzymeRules.needs_shadow(config) out.dval else - out.dval[b] + nothing end - i = 1 - j = 1 - if (typeof(inp) <: Active) - dinp = ntuple(Val(length(inp.val))) do k - Base.@_inline_meta - res = da[i, j] - da[i, j] = 0 - j += 1 - if j == nc+1 - i += 1 - j = 1 - end - T = BT.parameters[k] - if T <: AbstractFloat - T(res) - else - T(0) + func.val(out.val, inp.val) + return EnzymeRules.AugmentedReturn(primal, shadow, nothing) + end + + function EnzymeRules.reverse( + config, + func::Const{typeof(Base.hvcat_fill!)}, + ::Type{RT}, + _, + out::Annotation{AT}, + inp::Annotation{BT}, + ) where {RT,AT<:Array,BT<:Tuple} + nr, nc = size(out.val, 1), size(out.val, 2) + for b in 1:EnzymeRules.width(config) + da = if EnzymeRules.width(config) == 1 + out.dval + else + out.dval[b] + end + i = 1 + j = 1 + if (typeof(inp) <: Active) + dinp = ntuple(Val(length(inp.val))) do k + Base.@_inline_meta + res = da[i, j] + da[i, j] = 0 + j += 1 + if j == nc + 1 + i += 1 + j = 1 + end + T = BT.parameters[k] + if T <: AbstractFloat + T(res) + else + T(0) + end end + return (nothing, dinp)::Tuple{Nothing,BT} end - return (nothing, dinp)::Tuple{Nothing, BT} end + return (nothing, nothing) end - return (nothing, nothing) -end end function EnzymeRules.forward( - ::Const{typeof(sort!)}, - RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, - xs::Duplicated{T}; - kwargs... - ) where {T <: AbstractArray{<:AbstractFloat}} + ::Const{typeof(sort!)}, + RT::Type{<:Union{Const,DuplicatedNoNeed,Duplicated}}, + xs::Duplicated{T}; + kwargs..., +) where {T<:AbstractArray{<:AbstractFloat}} inds = sortperm(xs.val; kwargs...) xs.val .= xs.val[inds] xs.dval .= xs.dval[inds] @@ -576,11 +659,11 @@ function EnzymeRules.forward( end function EnzymeRules.forward( - ::Const{typeof(sort!)}, - RT::Type{<:Union{Const, BatchDuplicatedNoNeed, BatchDuplicated}}, - xs::BatchDuplicated{T, N}; - kwargs... - ) where {T <: AbstractArray{<:AbstractFloat}, N} + ::Const{typeof(sort!)}, + RT::Type{<:Union{Const,BatchDuplicatedNoNeed,BatchDuplicated}}, + xs::BatchDuplicated{T,N}; + kwargs..., +) where {T<:AbstractArray{<:AbstractFloat},N} inds = sortperm(xs.val; kwargs...) xs.val .= xs.val[inds] for i in 1:N @@ -595,14 +678,13 @@ function EnzymeRules.forward( end end - function EnzymeRules.augmented_primal( - config::EnzymeRules.ConfigWidth{1}, - ::Const{typeof(sort!)}, - RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, - xs::Duplicated{T}; - kwargs... - ) where {T <: AbstractArray{<:AbstractFloat}} + config::EnzymeRules.ConfigWidth{1}, + ::Const{typeof(sort!)}, + RT::Type{<:Union{Const,DuplicatedNoNeed,Duplicated}}, + xs::Duplicated{T}; + kwargs..., +) where {T<:AbstractArray{<:AbstractFloat}} inds = sortperm(xs.val; kwargs...) xs.val .= xs.val[inds] xs.dval .= xs.dval[inds] @@ -620,13 +702,13 @@ function EnzymeRules.augmented_primal( end function EnzymeRules.reverse( - config::EnzymeRules.ConfigWidth{1}, - ::Const{typeof(sort!)}, - RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, - tape, - xs::Duplicated{T}; - kwargs..., - ) where {T <: AbstractArray{<:AbstractFloat}} + config::EnzymeRules.ConfigWidth{1}, + ::Const{typeof(sort!)}, + RT::Type{<:Union{Const,DuplicatedNoNeed,Duplicated}}, + tape, + xs::Duplicated{T}; + kwargs..., +) where {T<:AbstractArray{<:AbstractFloat}} inds = tape back_inds = sortperm(inds) xs.dval .= xs.dval[back_inds] @@ -694,11 +776,7 @@ end # B(out) = inv(A) B(in) # dB(out) = inv(A) [ dB(in) - dA B(out) ] function EnzymeRules.forward( - func::Const{typeof(ldiv!)}, - RT::Type, - fact::Annotation{<:Cholesky}, - B; - kwargs... + func::Const{typeof(ldiv!)}, RT::Type, fact::Annotation{<:Cholesky}, B; kwargs... ) if isa(B, Const) @assert (RT <: Const) @@ -708,11 +786,15 @@ function EnzymeRules.forward( @assert !isa(B, Const) - retval = if !isa(fact, Const) || (RT <: Const) || (RT <: Duplicated) || (RT <: BatchDuplicated) - func.val(fact.val, B.val; kwargs...) - else - nothing - end + retval = + if !isa(fact, Const) || + (RT <: Const) || + (RT <: Duplicated) || + (RT <: BatchDuplicated) + func.val(fact.val, B.val; kwargs...) + else + nothing + end dretvals = ntuple(Val(N)) do b Base.@_inline_meta @@ -724,13 +806,12 @@ function EnzymeRules.forward( end if !isa(fact, Const) - dfact = if N == 1 fact.dval else fact.dval[b] end - + tmp = dfact.U * retval mul!(dB, dfact.L, tmp, -1, 1) end @@ -757,8 +838,8 @@ function EnzymeRules.augmented_primal( func::Const{typeof(cholesky)}, RT::Type, A::Annotation{<:Union{Matrix,LinearAlgebra.RealHermSym{<:Real,<:Matrix}}}; - kwargs...) - + kwargs..., +) fact = if EnzymeRules.needs_primal(config) || !(RT <: Const) cholesky(A.val; kwargs...) else @@ -791,7 +872,8 @@ function EnzymeRules.reverse( RT::Type, cache, A::Annotation{<:Union{Matrix,LinearAlgebra.RealHermSym{<:Real,<:Matrix}}}; - kwargs...) + kwargs..., +) if !(RT <: Const) && !isa(A, Const) fact, dfact = cache dAs = EnzymeRules.width(config) == 1 ? (A.dval,) : A.dval @@ -845,13 +927,14 @@ function _realifydiag!(A) end function EnzymeRules.augmented_primal( - config, - func::Const{typeof(ldiv!)}, - RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated}}, - - A::Annotation{<:Cholesky}, - B::Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated}; - kwargs... + config, + func::Const{typeof(ldiv!)}, + RT::Type{ + <:Union{Const,DuplicatedNoNeed,Duplicated,BatchDuplicatedNoNeed,BatchDuplicated} + }, + A::Annotation{<:Cholesky}, + B::Union{Const,DuplicatedNoNeed,Duplicated,BatchDuplicatedNoNeed,BatchDuplicated}; + kwargs..., ) cache_B = if !isa(A, Const) && !isa(B, Const) EnzymeRules.overwritten(config)[3] ? copy(B.val) : B.val @@ -877,10 +960,10 @@ function EnzymeRules.reverse( dret, cache, A::Annotation{<:Cholesky}, - B::Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated}; - kwargs... + B::Union{Const,DuplicatedNoNeed,Duplicated,BatchDuplicatedNoNeed,BatchDuplicated}; + kwargs..., ) - if !isa(B, Const) + if !isa(B, Const) (cache_A, cache_B) = cache Y = B.val U = cache_A.U diff --git a/test/internal_rules.jl b/test/internal_rules.jl index 9cf2de03fd..a2e7de178c 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -17,7 +17,7 @@ function sorterrfn(t, x) function lt(a, b) return a.a < b.a end - return first(sortperm(t, lt=lt)) * x + return first(sortperm(t; lt=lt)) * x end @testset "Sort rules" begin @@ -28,10 +28,12 @@ end end @test autodiff(Forward, f1, Duplicated(2.0, 1.0))[1] == 1 - @test autodiff(Forward, f1, BatchDuplicated(2.0, (1.0, 2.0)))[1] == (var"1"=1.0, var"2"=2.0) + @test autodiff(Forward, f1, BatchDuplicated(2.0, (1.0, 2.0)))[1] == + (var"1"=1.0, var"2"=2.0) @test autodiff(Reverse, f1, Active, Active(2.0))[1][1] == 1 @test autodiff(Forward, f1, Duplicated(4.0, 1.0))[1] == 0 - @test autodiff(Forward, f1, BatchDuplicated(4.0, (1.0, 2.0)))[1] == (var"1"=0.0, var"2"=0.0) + @test autodiff(Forward, f1, BatchDuplicated(4.0, (1.0, 2.0)))[1] == + (var"1"=0.0, var"2"=0.0) @test autodiff(Reverse, f1, Active, Active(4.0))[1][1] == 0 function f2(x) @@ -41,10 +43,13 @@ end end @test autodiff(Forward, f2, Duplicated(2.0, 1.0))[1] == -3 - @test autodiff(Forward, f2, BatchDuplicated(2.0, (1.0, 2.0)))[1] == (var"1"=-3.0, var"2"=-6.0) + @test autodiff(Forward, f2, BatchDuplicated(2.0, (1.0, 2.0)))[1] == + (var"1"=-3.0, var"2"=-6.0) @test autodiff(Reverse, f2, Active, Active(2.0))[1][1] == -3 - dd = Duplicated([TPair(1, 2), TPair(2, 3), TPair(0, 1)], [TPair(0, 0), TPair(0, 0), TPair(0, 0)]) + dd = Duplicated( + [TPair(1, 2), TPair(2, 3), TPair(0, 1)], [TPair(0, 0), TPair(0, 0), TPair(0, 0)] + ) res = Enzyme.autodiff(Reverse, sorterrfn, dd, Active(1.0)) @test res[1][2] ≈ 3 @@ -62,7 +67,13 @@ end b = Float64[11, 13] db = zero(b) - forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Duplicated{typeof(A)}, Duplicated{typeof(b)}) + forward, pullback = Enzyme.autodiff_thunk( + ReverseSplitNoPrimal, + Const{typeof(\)}, + Duplicated, + Duplicated{typeof(A)}, + Duplicated{typeof(b)}, + ) tape, primal, shadow = forward(Const(\), Duplicated(A, dA), Duplicated(b, db)) @@ -79,7 +90,13 @@ end db = zero(b) - forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Const{typeof(A)}, Duplicated{typeof(b)}) + forward, pullback = Enzyme.autodiff_thunk( + ReverseSplitNoPrimal, + Const{typeof(\)}, + Duplicated, + Const{typeof(A)}, + Duplicated{typeof(b)}, + ) tape, primal, shadow = forward(Const(\), Const(A), Duplicated(b, db)) @@ -95,7 +112,13 @@ end dA = zero(A) - forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Duplicated{typeof(A)}, Const{typeof(b)}) + forward, pullback = Enzyme.autodiff_thunk( + ReverseSplitNoPrimal, + Const{typeof(\)}, + Duplicated, + Duplicated{typeof(A)}, + Const{typeof(b)}, + ) tape, primal, shadow = forward(Const(\), Duplicated(A, dA), Const(b)) @@ -111,88 +134,98 @@ end end @static if VERSION > v"1.8" -@testset "Cholesky" begin - function cholesky_testfunction_symmetric(A, b, x1, x2) - C1 = cholesky(A * A') # test factorization without wrapper - C2 = cholesky(Symmetric(A * A')) # test factorization with wrapper - x1 .= C1 \ b # test linear solve with factorization object without wrapper - x2 .= C2 \ b # test linear solve with factorization object with wrapper - return sum(abs2, C1.L * C1.U) + sum(abs2, C2.L * C2.U) # test factorization itself - end - function cholesky_testfunction_hermitian(A, b, x1, x2) - C1 = cholesky(A * adjoint(A)) # test factorization without wrapper - C2 = cholesky(Hermitian(A * adjoint(A))) # test factorization with wrapper - x1 .= C1 \ b # test linear solve with factorization object without wrapper - x2 .= C2 \ b # test linear solve with factorization object with wrapper - return sum(abs2, C1.L * C1.U) + sum(abs2, C2.L * C2.U) # test factorization itself - end - @testset for (TE, testfunction) in ( - Float64 => cholesky_testfunction_symmetric, - Float64 => cholesky_testfunction_hermitian - ) - @testset for TA in (Const, Duplicated), - Tb in (Const, Duplicated), - Tx1 in (Const, Duplicated), - Tx2 in (Const, Duplicated) - A = rand(TE, 5, 5) - b = rand(TE, 5) - x1 = rand(TE, 5) - x2 = rand(TE, 5) - # ishermitian(A * adjoint(A)) || continue - @testset for Tret in (Const, Duplicated) - are_activities_compatible(Tret, TA, Tb, Tx1, Tx2) || continue - test_forward(testfunction, Tret, (A, TA), (b, Tb), (x1, Tx1), (x2, Tx2)) - end - @testset for Tret in (Const, Active) - are_activities_compatible(Tret, TA, Tb, Tx1, Tx2) || continue - test_reverse(testfunction, Tret, (A, TA), (b, Tb), (x1, Tx1), (x2, Tx2)) + @testset "Cholesky" begin + function cholesky_testfunction_symmetric(A, b, x1, x2) + C1 = cholesky(A * A') # test factorization without wrapper + C2 = cholesky(Symmetric(A * A')) # test factorization with wrapper + x1 .= C1 \ b # test linear solve with factorization object without wrapper + x2 .= C2 \ b # test linear solve with factorization object with wrapper + return sum(abs2, C1.L * C1.U) + sum(abs2, C2.L * C2.U) # test factorization itself + end + function cholesky_testfunction_hermitian(A, b, x1, x2) + C1 = cholesky(A * adjoint(A)) # test factorization without wrapper + C2 = cholesky(Hermitian(A * adjoint(A))) # test factorization with wrapper + x1 .= C1 \ b # test linear solve with factorization object without wrapper + x2 .= C2 \ b # test linear solve with factorization object with wrapper + return sum(abs2, C1.L * C1.U) + sum(abs2, C2.L * C2.U) # test factorization itself + end + @testset for (TE, testfunction) in ( + Float64 => cholesky_testfunction_symmetric, + Float64 => cholesky_testfunction_hermitian, + ) + @testset for TA in (Const, Duplicated), + Tb in (Const, Duplicated), + Tx1 in (Const, Duplicated), + Tx2 in (Const, Duplicated) + + A = rand(TE, 5, 5) + b = rand(TE, 5) + x1 = rand(TE, 5) + x2 = rand(TE, 5) + # ishermitian(A * adjoint(A)) || continue + @testset for Tret in (Const, Duplicated) + are_activities_compatible(Tret, TA, Tb, Tx1, Tx2) || continue + test_forward(testfunction, Tret, (A, TA), (b, Tb), (x1, Tx1), (x2, Tx2)) + end + @testset for Tret in (Const, Active) + are_activities_compatible(Tret, TA, Tb, Tx1, Tx2) || continue + test_reverse(testfunction, Tret, (A, TA), (b, Tb), (x1, Tx1), (x2, Tx2)) + end end end end -end -@testset "Linear solve for triangular matrices" begin - @testset for T in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular), - TE in (Float64, ComplexF64), sizeB in ((3,), (3, 3)) - n = sizeB[1] - M = rand(TE, n, n) - B = rand(TE, sizeB...) - Y = zeros(TE, sizeB...) - A = T(M) - @testset "test through constructor" begin - _A = T(A) - function f!(Y, A, B, ::T) where T - ldiv!(Y, T(A), B) - return nothing + @testset "Linear solve for triangular matrices" begin + @testset for T in ( + UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular + ), + TE in (Float64, ComplexF64), + sizeB in ((3,), (3, 3)) + + n = sizeB[1] + M = rand(TE, n, n) + B = rand(TE, sizeB...) + Y = zeros(TE, sizeB...) + A = T(M) + @testset "test through constructor" begin + _A = T(A) + function f!(Y, A, B, ::T) where {T} + ldiv!(Y, T(A), B) + return nothing + end + for TY in (Const, Duplicated, BatchDuplicated), + TM in (Const, Duplicated, BatchDuplicated), + TB in (Const, Duplicated, BatchDuplicated) + + are_activities_compatible(Const, TY, TM, TB) || continue + test_reverse(f!, Const, (Y, TY), (M, TM), (B, TB), (_A, Const)) + end end - for TY in (Const, Duplicated, BatchDuplicated), - TM in (Const, Duplicated, BatchDuplicated), - TB in (Const, Duplicated, BatchDuplicated) - are_activities_compatible(Const, TY, TM, TB) || continue - test_reverse(f!, Const, (Y, TY), (M, TM), (B, TB), (_A, Const)) + @testset "test through `Adjoint` wrapper (regression test for #1306)" begin + # Test that we get the same derivative for `M` as for the adjoint of its + # (materialized) transpose. It's the same matrix, but represented differently + function f!(Y, A, B) + ldiv!(Y, A, B) + return nothing + end + A1 = T(M) + A2 = T(conj(permutedims(M))') + dA1 = make_zero(A1) + dA2 = make_zero(A2) + dB1 = make_zero(B) + dB2 = make_zero(B) + dY1 = rand(TE, sizeB...) + dY2 = copy(dY1) + autodiff( + Reverse, f!, Duplicated(Y, dY1), Duplicated(A1, dA1), Duplicated(B, dB1) + ) + autodiff( + Reverse, f!, Duplicated(Y, dY2), Duplicated(A2, dA2), Duplicated(B, dB2) + ) + @test dA1.data ≈ dA2.data + @test dB1 ≈ dB2 end end - @testset "test through `Adjoint` wrapper (regression test for #1306)" begin - # Test that we get the same derivative for `M` as for the adjoint of its - # (materialized) transpose. It's the same matrix, but represented differently - function f!(Y, A, B) - ldiv!(Y, A, B) - return nothing - end - A1 = T(M) - A2 = T(conj(permutedims(M))') - dA1 = make_zero(A1) - dA2 = make_zero(A2) - dB1 = make_zero(B) - dB2 = make_zero(B) - dY1 = rand(TE, sizeB...) - dY2 = copy(dY1) - autodiff(Reverse, f!, Duplicated(Y, dY1), Duplicated(A1, dA1), Duplicated(B, dB1)) - autodiff(Reverse, f!, Duplicated(Y, dY2), Duplicated(A2, dA2), Duplicated(B, dB2)) - @test dA1.data ≈ dA2.data - @test dB1 ≈ dB2 - end end end -end end # InternalRules From 2977be4e83f6b7642a198da6e091fc018b99f1ad Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Sun, 21 Apr 2024 11:11:55 +0200 Subject: [PATCH 30/61] Fix typo --- test/internal_rules.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/internal_rules.jl b/test/internal_rules.jl index 350c974e6b..ee1f76660a 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -230,8 +230,6 @@ end end end -end - @testset "rand and randn rules" begin # Distributed as x + unit normal + uniform struct MyDistribution From ae2a0331ff3ccba7e9c19cfc51781dd4da4ddc73 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Sun, 28 Apr 2024 00:51:00 +0200 Subject: [PATCH 31/61] Add tests suggested in review --- test/internal_rules.jl | 84 +++++++++++++++++++++++++----------------- 1 file changed, 50 insertions(+), 34 deletions(-) diff --git a/test/internal_rules.jl b/test/internal_rules.jl index bfaa40cdc3..3b6877a2fa 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -158,41 +158,57 @@ end @static if VERSION > v"1.8" @testset "Cholesky" begin - function cholesky_testfunction_symmetric(A, b, x1, x2) - C1 = cholesky(A * A') # test factorization without wrapper - C2 = cholesky(Symmetric(A * A')) # test factorization with wrapper - x1 .= C1 \ b # test linear solve with factorization object without wrapper - x2 .= C2 \ b # test linear solve with factorization object with wrapper - return sum(abs2, C1.L * C1.U) + sum(abs2, C2.L * C2.U) # test factorization itself - end - function cholesky_testfunction_hermitian(A, b, x1, x2) - C1 = cholesky(A * adjoint(A)) # test factorization without wrapper - C2 = cholesky(Hermitian(A * adjoint(A))) # test factorization with wrapper - x1 .= C1 \ b # test linear solve with factorization object without wrapper - x2 .= C2 \ b # test linear solve with factorization object with wrapper - return sum(abs2, C1.L * C1.U) + sum(abs2, C2.L * C2.U) # test factorization itself - end - @testset for (TE, testfunction) in ( - Float64 => cholesky_testfunction_symmetric, - Float64 => cholesky_testfunction_hermitian, - ) - @testset for TA in (Const, Duplicated), - Tb in (Const, Duplicated), - Tx1 in (Const, Duplicated), - Tx2 in (Const, Duplicated) - - A = rand(TE, 5, 5) - b = rand(TE, 5) - x1 = rand(TE, 5) - x2 = rand(TE, 5) - # ishermitian(A * adjoint(A)) || continue - @testset for Tret in (Const, Duplicated) - are_activities_compatible(Tret, TA, Tb, Tx1, Tx2) || continue - test_forward(testfunction, Tret, (A, TA), (b, Tb), (x1, Tx1), (x2, Tx2)) + @testset "EnzymeTestUtils tests" begin + @testset "cholesky" begin + @testset for Te in (Float64,), TS in (Symmetric, Hermitian) + @testset for TA in (Const, Duplicated), Tret in (Const, Duplicated) + + A = exp(TS(rand(Te, 4, 4))) + are_activities_compatible(Tret, TA) || continue + test_forward(cholesky, Tret, (A, TA)) + test_reverse(cholesky, Tret, (A, TA)) + end end - @testset for Tret in (Const, Active) - are_activities_compatible(Tret, TA, Tb, Tx1, Tx2) || continue - test_reverse(testfunction, Tret, (A, TA), (b, Tb), (x1, Tx1), (x2, Tx2)) + end + end + + @testset "Other tests" begin + function cholesky_testfunction_symmetric(A, b, x1, x2) + C1 = cholesky(A * A') # test factorization without wrapper + C2 = cholesky(Symmetric(A * A')) # test factorization with wrapper + x1 .= C1 \ b # test linear solve with factorization object without wrapper + x2 .= C2 \ b # test linear solve with factorization object with wrapper + return sum(abs2, C1.L * C1.U) + sum(abs2, C2.L * C2.U) # test factorization itself + end + function cholesky_testfunction_hermitian(A, b, x1, x2) + C1 = cholesky(A * adjoint(A)) # test factorization without wrapper + C2 = cholesky(Hermitian(A * adjoint(A))) # test factorization with wrapper + x1 .= C1 \ b # test linear solve with factorization object without wrapper + x2 .= C2 \ b # test linear solve with factorization object with wrapper + return sum(abs2, C1.L * C1.U) + sum(abs2, C2.L * C2.U) # test factorization itself + end + @testset for (TE, testfunction) in ( + Float64 => cholesky_testfunction_symmetric, + Float64 => cholesky_testfunction_hermitian, + ) + @testset for TA in (Const, Duplicated), + Tb in (Const, Duplicated), + Tx1 in (Const, Duplicated), + Tx2 in (Const, Duplicated) + + A = rand(TE, 5, 5) + b = rand(TE, 5) + x1 = rand(TE, 5) + x2 = rand(TE, 5) + # ishermitian(A * adjoint(A)) || continue + @testset for Tret in (Const, Duplicated) + are_activities_compatible(Tret, TA, Tb, Tx1, Tx2) || continue + test_forward(testfunction, Tret, (A, TA), (b, Tb), (x1, Tx1), (x2, Tx2)) + end + @testset for Tret in (Const, Active) + are_activities_compatible(Tret, TA, Tb, Tx1, Tx2) || continue + test_reverse(testfunction, Tret, (A, TA), (b, Tb), (x1, Tx1), (x2, Tx2)) + end end end end From 3c1784e7d8828c9379bd130b87954f81c417473b Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Sun, 28 Apr 2024 00:57:42 +0200 Subject: [PATCH 32/61] Fix rrule --- src/internal_rules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 09d01535c5..5ed230268f 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -993,7 +993,7 @@ function EnzymeRules.reverse( Ā = _cholesky_pullback_shared_code(fact, dfact) idx = diagind(Ā) @views Ā[idx] .= real.(Ā[idx]) ./ 2 - _dA .+= UpperTriangular(Ā) + _dA .+= UpperTriangular(Ā) .+ UpperTriangular(tril!(dfact.factors, -1)') dfact.factors .= 0 end end From b4c17d258cc71dd5ef6dbf8580a58fedd37e7bdd Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Sun, 28 Apr 2024 01:08:29 +0200 Subject: [PATCH 33/61] Fix forward rules --- src/internal_rules.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 5ed230268f..b27697ae12 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -865,6 +865,7 @@ function _cholesky_forward(C::Cholesky, Ȧ) U̇[idx] ./= 2 triu!(U̇) rmul!(U̇, U) + U̇ .+= UpperTriangular(Ȧ)' - Diagonal(Ȧ) # correction for unused triangle return Cholesky(U̇, 'U', C.info) else L = C.L @@ -874,6 +875,7 @@ function _cholesky_forward(C::Cholesky, Ȧ) L̇[idx] ./= 2 tril!(L̇) lmul!(L, L̇) + L̇ .+= LowerTriangular(Ȧ)' - Diagonal(Ȧ) # correction for unused triangle return Cholesky(L̇, 'L', C.info) end end @@ -993,7 +995,7 @@ function EnzymeRules.reverse( Ā = _cholesky_pullback_shared_code(fact, dfact) idx = diagind(Ā) @views Ā[idx] .= real.(Ā[idx]) ./ 2 - _dA .+= UpperTriangular(Ā) .+ UpperTriangular(tril!(dfact.factors, -1)') + _dA .+= UpperTriangular(Ā) dfact.factors .= 0 end end @@ -1015,6 +1017,7 @@ function _cholesky_pullback_shared_code(C, ΔC) eltype(Ā) <: Real || _realifydiag!(Ā) ldiv!(U, Ā) rdiv!(Ā, U') + Ā .+= tril!(ΔC.factors, -1)' # correction for unused triangle else # C.uplo === 'L' L = C.L L̄ = ΔC.L @@ -1023,6 +1026,7 @@ function _cholesky_pullback_shared_code(C, ΔC) eltype(Ā) <: Real || _realifydiag!(Ā) rdiv!(Ā, L) ldiv!(L', Ā) + Ā .+= triu!(ΔC.factors, 1)' # correction for unused triangle end return Ā end From 7e5173345f77ed1cd3bc40864552c1ce3d450a4b Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Sun, 28 Apr 2024 01:36:19 +0200 Subject: [PATCH 34/61] Increase test coverage, remove old tests --- test/internal_rules.jl | 80 +++++++++++++++++------------------------- 1 file changed, 32 insertions(+), 48 deletions(-) diff --git a/test/internal_rules.jl b/test/internal_rules.jl index 3b6877a2fa..e7d961eb11 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -157,58 +157,42 @@ end end @static if VERSION > v"1.8" - @testset "Cholesky" begin - @testset "EnzymeTestUtils tests" begin - @testset "cholesky" begin - @testset for Te in (Float64,), TS in (Symmetric, Hermitian) - @testset for TA in (Const, Duplicated), Tret in (Const, Duplicated) - - A = exp(TS(rand(Te, 4, 4))) - are_activities_compatible(Tret, TA) || continue - test_forward(cholesky, Tret, (A, TA)) - test_reverse(cholesky, Tret, (A, TA)) - end + @testset "cholesky" begin + @testset "with wrapper arguments" begin + @testset for Te in (Float64,), TS in (Symmetric, Hermitian), uplo in (:U, :L) + @testset for TA in (Const, Duplicated), Tret in (Const, Duplicated) + A = TS(exp(rand(Te, 4, 4)), uplo) + are_activities_compatible(Tret, TA) || continue + test_forward(cholesky, Tret, (A, TA)) + test_reverse(cholesky, Tret, (A, TA)) end end end - - @testset "Other tests" begin - function cholesky_testfunction_symmetric(A, b, x1, x2) - C1 = cholesky(A * A') # test factorization without wrapper - C2 = cholesky(Symmetric(A * A')) # test factorization with wrapper - x1 .= C1 \ b # test linear solve with factorization object without wrapper - x2 .= C2 \ b # test linear solve with factorization object with wrapper - return sum(abs2, C1.L * C1.U) + sum(abs2, C2.L * C2.U) # test factorization itself - end - function cholesky_testfunction_hermitian(A, b, x1, x2) - C1 = cholesky(A * adjoint(A)) # test factorization without wrapper - C2 = cholesky(Hermitian(A * adjoint(A))) # test factorization with wrapper - x1 .= C1 \ b # test linear solve with factorization object without wrapper - x2 .= C2 \ b # test linear solve with factorization object with wrapper - return sum(abs2, C1.L * C1.U) + sum(abs2, C2.L * C2.U) # test factorization itself + @testset "without wrapper arguments" begin + _square(A) = A * A' + @testset for Te in (Float64,) + @testset for TA in (Const, Duplicated), Tret in (Const, Duplicated) + A = rand(Te, 4, 4) + are_activities_compatible(Tret, TA) || continue + test_forward(cholesky ∘ _square, Tret, (A, TA)) + test_reverse(cholesky ∘ _square, Tret, (A, TA)) + end end - @testset for (TE, testfunction) in ( - Float64 => cholesky_testfunction_symmetric, - Float64 => cholesky_testfunction_hermitian, - ) - @testset for TA in (Const, Duplicated), - Tb in (Const, Duplicated), - Tx1 in (Const, Duplicated), - Tx2 in (Const, Duplicated) - - A = rand(TE, 5, 5) - b = rand(TE, 5) - x1 = rand(TE, 5) - x2 = rand(TE, 5) - # ishermitian(A * adjoint(A)) || continue - @testset for Tret in (Const, Duplicated) - are_activities_compatible(Tret, TA, Tb, Tx1, Tx2) || continue - test_forward(testfunction, Tret, (A, TA), (b, Tb), (x1, Tx1), (x2, Tx2)) - end - @testset for Tret in (Const, Active) - are_activities_compatible(Tret, TA, Tb, Tx1, Tx2) || continue - test_reverse(testfunction, Tret, (A, TA), (b, Tb), (x1, Tx1), (x2, Tx2)) - end + end + end + + @testset "Linear solve for `Cholesky`" begin + @testset for Te in (Float64,) + A = exp(Symmetric(rand(Te, 4, 4))) + C = cholesky(A) + B = rand(Te, 4, 4) + b = rand(Te, 4) + @testset for TC in (Const, Duplicated), TB in (Const, Duplicated), + Tret in (Const, Duplicated) + @testset "$(size(_B))" for _B in (B, b) + are_activities_compatible(Tret, TC, TB) || continue + test_forward(\, Tret, (C, TC), (_B, TB)) + test_reverse(\, Tret, (C, TC), (_B, TB)) end end end From ea098bee10ddc60ef9e8c8b85bea04b88ba15059 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Sun, 28 Apr 2024 01:36:30 +0200 Subject: [PATCH 35/61] Fix additional tests --- src/internal_rules.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index b27697ae12..3a45a4c8d4 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -993,9 +993,7 @@ function EnzymeRules.reverse( _dA = dA isa LinearAlgebra.RealHermSym ? dA.data : dA if _dA !== dfact.factors Ā = _cholesky_pullback_shared_code(fact, dfact) - idx = diagind(Ā) - @views Ā[idx] .= real.(Ā[idx]) ./ 2 - _dA .+= UpperTriangular(Ā) + _dA .+= Ā dfact.factors .= 0 end end @@ -1018,6 +1016,7 @@ function _cholesky_pullback_shared_code(C, ΔC) ldiv!(U, Ā) rdiv!(Ā, U') Ā .+= tril!(ΔC.factors, -1)' # correction for unused triangle + triu!(Ā) else # C.uplo === 'L' L = C.L L̄ = ΔC.L @@ -1027,7 +1026,10 @@ function _cholesky_pullback_shared_code(C, ΔC) rdiv!(Ā, L) ldiv!(L', Ā) Ā .+= triu!(ΔC.factors, 1)' # correction for unused triangle + tril!(Ā) end + idx = diagind(Ā) + @views Ā[idx] .= real.(Ā[idx]) ./ 2 return Ā end From 4066fd1d02d3497b7d67e8eac76ba1bd0c165982 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Sun, 28 Apr 2024 02:00:36 +0200 Subject: [PATCH 36/61] Try to fix positive definiteness issues in CI --- test/internal_rules.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/internal_rules.jl b/test/internal_rules.jl index e7d961eb11..16e4876a6e 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -161,7 +161,8 @@ end @testset "with wrapper arguments" begin @testset for Te in (Float64,), TS in (Symmetric, Hermitian), uplo in (:U, :L) @testset for TA in (Const, Duplicated), Tret in (Const, Duplicated) - A = TS(exp(rand(Te, 4, 4)), uplo) + _A = collect(exp(TS(rand(Te, 4, 4)))) + A = TS(_A, uplo) are_activities_compatible(Tret, TA) || continue test_forward(cholesky, Tret, (A, TA)) test_reverse(cholesky, Tret, (A, TA)) From 22c533a0501630288f98a7f34f17277b013af99e Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Sun, 28 Apr 2024 22:28:50 +0200 Subject: [PATCH 37/61] Revert "Run formatter" This reverts commit 188949189853b361db258daea95748893515e4f3. --- src/internal_rules.jl | 427 +++++++++++++++++------------------------ test/internal_rules.jl | 186 ++++++++---------- 2 files changed, 250 insertions(+), 363 deletions(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 3a45a4c8d4..02330f2e38 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -102,9 +102,7 @@ end function EnzymeRules.inactive_noinl(::typeof(Base.size), args...) return nothing end -function EnzymeRules.inactive_noinl( - ::typeof(Base.setindex!), ::IdDict{K,V}, ::K, ::V -) where {K,V<:Integer} +function EnzymeRules.inactive_noinl(::typeof(Base.setindex!), ::IdDict{K, V}, ::K, ::V) where {K, V <:Integer} return nothing end @@ -122,45 +120,35 @@ end @inline EnzymeRules.inactive_type(v::Type{T}) where {T<:AbstractString} = true @inline width(::Duplicated) = 1 -@inline width(::BatchDuplicated{T,N}) where {T,N} = N +@inline width(::BatchDuplicated{T, N}) where {T, N} = N @inline width(::DuplicatedNoNeed) = 1 -@inline width(::BatchDuplicatedNoNeed{T,N}) where {T,N} = N +@inline width(::BatchDuplicatedNoNeed{T, N}) where {T, N} = N -@inline width(::Type{Duplicated{T}}) where {T} = 1 -@inline width(::Type{BatchDuplicated{T,N}}) where {T,N} = N -@inline width(::Type{DuplicatedNoNeed{T}}) where {T} = 1 -@inline width(::Type{BatchDuplicatedNoNeed{T,N}}) where {T,N} = N +@inline width(::Type{Duplicated{T}}) where T = 1 +@inline width(::Type{BatchDuplicated{T, N}}) where {T, N} = N +@inline width(::Type{DuplicatedNoNeed{T}}) where T = 1 +@inline width(::Type{BatchDuplicatedNoNeed{T, N}}) where {T, N} = N # Note all of these forward mode definitions do not support runtime activity as # the do not keep the primal if shadow(x.y) == primal(x.y) -function EnzymeRules.forward( - ::Const{typeof(Base.deepcopy)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated -) +function EnzymeRules.forward(::Const{typeof(Base.deepcopy)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated) return deepcopy(x.dval) end -function EnzymeRules.forward( - ::Const{typeof(Base.deepcopy)}, ::Type{<:BatchDuplicatedNoNeed}, x::BatchDuplicated{T,N} -) where {T,N} +function EnzymeRules.forward(::Const{typeof(Base.deepcopy)}, ::Type{<:BatchDuplicatedNoNeed}, x::BatchDuplicated{T, N}) where {T, N} ntuple(Val(N)) do _ deepcopy(x.dval) end end # Deepcopy preserving the primal if runtime inactive -@inline function deepcopy_rtact( - copied::RT, primal::RT, seen::IdDict, shadow::RT -) where {RT<:Union{Integer,Char}} +@inline function deepcopy_rtact(copied::RT, primal::RT, seen::IdDict, shadow::RT) where {RT <: Union{Integer, Char}} return Base.deepcopy_internal(shadow, seen) end -@inline function deepcopy_rtact( - copied::RT, primal::RT, seen::IdDict, shadow::RT -) where {RT<:AbstractFloat} +@inline function deepcopy_rtact(copied::RT, primal::RT, seen::IdDict, shadow::RT) where {RT <: AbstractFloat} return Base.deepcopy_internal(shadow, seen) end -@inline function deepcopy_rtact( - copied::RT, primal::RT, seen::IdDict, shadow::RT -) where {RT<:Array} +@inline function deepcopy_rtact(copied::RT, primal::RT, seen::IdDict, shadow::RT) where {RT <: Array} if !haskey(seen, shadow) if primal === shadow return seen[shadow] = copied @@ -174,28 +162,19 @@ end return seen[shadow] end -function EnzymeRules.forward( - func::Const{typeof(Base.deepcopy)}, ::Type{<:Duplicated}, x::Duplicated -) +function EnzymeRules.forward(func::Const{typeof(Base.deepcopy)}, ::Type{<:Duplicated}, x::Duplicated) primal = func.val(x.val) return Duplicated(primal, deepcopy_rtact(primal, x.val, IdDict(), x.dval)) end -function EnzymeRules.forward( - func::Const{typeof(Base.deepcopy)}, ::Type{<:BatchDuplicated}, x::BatchDuplicated{T,N} -) where {T,N} +function EnzymeRules.forward(func::Const{typeof(Base.deepcopy)}, ::Type{<:BatchDuplicated}, x::BatchDuplicated{T, N}) where {T,N} primal = func.val(x.val) - return BatchDuplicated( - primal, - ntuple(Val(N)) do i - deepcopy_rtact(primal, x.val, IdDict(), x.dval[i]) - end, - ) + return BatchDuplicated(primal, ntuple(Val(N)) do i + deepcopy_rtact(primal, x.val, IdDict(), x.dval[i]) + end) end -function EnzymeRules.augmented_primal( - config, func::Const{typeof(Base.deepcopy)}, ::Type{RT}, x::Annotation{Ty} -) where {RT,Ty} +function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.deepcopy)}, ::Type{RT}, x::Annotation{Ty}) where {RT, Ty} primal = if EnzymeRules.needs_primal(config) func.val(x.val) else @@ -212,9 +191,8 @@ function EnzymeRules.augmented_primal( shadow = ntuple(Val(EnzymeRules.width(config))) do _ Base.@_inline_meta - Enzyme.make_zero( - source, - Val(!EnzymeRules.needs_primal(config)), #=copy_if_inactive=# + Enzyme.make_zero(source, + #=copy_if_inactive=#Val(!EnzymeRules.needs_primal(config)) ) end @@ -225,9 +203,8 @@ function EnzymeRules.augmented_primal( return EnzymeRules.AugmentedReturn(primal, shadow, shadow) end -@inline function accumulate_into( - into::RT, seen::IdDict, from::RT -)::Tuple{RT,RT} where {RT<:Array} + +@inline function accumulate_into(into::RT, seen::IdDict, from::RT)::Tuple{RT,RT} where {RT<:Array} if Enzyme.Compiler.guaranteed_const(RT) return (into, from) end @@ -242,11 +219,9 @@ end return seen[into] end -@inline function accumulate_into( - into::RT, seen::IdDict, from::RT -)::Tuple{RT,RT} where {RT<:AbstractFloat} +@inline function accumulate_into(into::RT, seen::IdDict, from::RT)::Tuple{RT,RT} where {RT<:AbstractFloat} if !haskey(seen, into) - seen[into] = (into + from, RT(0)) + seen[into] = (into+from, RT(0)) end return seen[into] end @@ -261,9 +236,7 @@ end return seen[into] end -function EnzymeRules.reverse( - config, func::Const{typeof(Base.deepcopy)}, ::Type{RT}, shadow, x::Annotation{Ty} -) where {RT,Ty} +function EnzymeRules.reverse(config, func::Const{typeof(Base.deepcopy)}, ::Type{RT}, shadow, x::Annotation{Ty}) where {RT, Ty} if EnzymeRules.width(config) == 1 accumulate_into(x.dval, IdDict(), shadow) else @@ -275,80 +248,43 @@ function EnzymeRules.reverse( return (nothing,) end -@inline function pmap_fwd( - idx, tapes::Vector, thunk::ThunkTy, f::F, fargs::Vararg{Annotation,N} -) where {ThunkTy,F,N} +@inline function pmap_fwd(idx, tapes::Vector, thunk::ThunkTy, f::F, fargs::Vararg{Annotation, N}) where {ThunkTy, F, N} @inbounds tapes[idx] = thunk(f, Const(idx), fargs...)[1] end -@inline function pmap_fwd( - idx, tapes::Ptr, thunk::ThunkTy, f::F, fargs::Vararg{Annotation,N} -) where {ThunkTy,F,N} - return unsafe_store!(tapes, thunk(f, Const(idx), fargs...)[1], idx) +@inline function pmap_fwd(idx, tapes::Ptr, thunk::ThunkTy, f::F, fargs::Vararg{Annotation, N}) where {ThunkTy, F, N} + unsafe_store!(tapes, thunk(f, Const(idx), fargs...)[1], idx) end -function EnzymeRules.augmented_primal( - config, - func::Const{typeof(Enzyme.pmap)}, - ::Type{Const{Nothing}}, - body::BodyTy, - count, - args::Vararg{Annotation,N}, -) where {BodyTy,N} - config2 = ReverseModeSplit{ - false, - false, - EnzymeRules.width(config), - EnzymeRules.overwritten(config)[2:end], - InlineABI, - }() - fwd_thunk, rev_thunk = autodiff_thunk( - config2, BodyTy, Const, typeof(count), map(typeof, args)... - ) +function EnzymeRules.augmented_primal(config, func::Const{typeof(Enzyme.pmap)}, ::Type{Const{Nothing}}, body::BodyTy, count, args::Vararg{Annotation, N}) where {BodyTy, N} + + config2 = ReverseModeSplit{false, false, EnzymeRules.width(config), EnzymeRules.overwritten(config)[2:end],InlineABI}() + fwd_thunk, rev_thunk = autodiff_thunk(config2, BodyTy, Const, typeof(count), map(typeof, args)...) TapeType = EnzymeRules.tape_type(fwd_thunk) tapes = if Enzyme.Compiler.any_jltypes(TapeType) Vector{TapeType}(undef, count.val) else - Base.unsafe_convert(Ptr{TapeType}, Libc.malloc(sizeof(TapeType) * count.val)) + Base.unsafe_convert(Ptr{TapeType}, Libc.malloc(sizeof(TapeType)*count.val)) end Enzyme.pmap(pmap_fwd, count.val, tapes, fwd_thunk, body, args...) return EnzymeRules.AugmentedReturn(nothing, nothing, tapes) end -@inline function pmap_rev( - idx, tapes::Vector, thunk::ThunkTy, f::F, fargs::Vararg{Annotation,N} -) where {ThunkTy,F,N} - return thunk(f, Const(idx), fargs..., @inbounds tapes[idx]) +@inline function pmap_rev(idx, tapes::Vector, thunk::ThunkTy, f::F, fargs::Vararg{Annotation, N}) where {ThunkTy, F, N} + thunk(f, Const(idx), fargs..., @inbounds tapes[idx]) end -@inline function pmap_rev( - idx, tapes::Ptr, thunk::ThunkTy, f::F, fargs::Vararg{Annotation,N} -) where {ThunkTy,F,N} - return thunk(f, Const(idx), fargs..., unsafe_load(tapes, idx)) +@inline function pmap_rev(idx, tapes::Ptr, thunk::ThunkTy, f::F, fargs::Vararg{Annotation, N}) where {ThunkTy, F, N} + thunk(f, Const(idx), fargs..., unsafe_load(tapes, idx)) end -function EnzymeRules.reverse( - config, - func::Const{typeof(Enzyme.pmap)}, - ::Type{Const{Nothing}}, - tapes, - body::BodyTy, - count, - args::Vararg{Annotation,N}, -) where {BodyTy,N} - config2 = ReverseModeSplit{ - false, - false, - EnzymeRules.width(config), - EnzymeRules.overwritten(config)[2:end], - InlineABI, - }() - fwd_thunk, rev_thunk = autodiff_thunk( - config2, BodyTy, Const, typeof(count), map(typeof, args)... - ) +function EnzymeRules.reverse(config, func::Const{typeof(Enzyme.pmap)}, ::Type{Const{Nothing}}, tapes, body::BodyTy, count, args::Vararg{Annotation, N}) where {BodyTy, N} + + config2 = ReverseModeSplit{false, false, EnzymeRules.width(config), EnzymeRules.overwritten(config)[2:end],InlineABI}() + fwd_thunk, rev_thunk = autodiff_thunk(config2, BodyTy, Const, typeof(count), map(typeof, args)...) Enzyme.pmap(pmap_rev, count.val, tapes, rev_thunk, body, args...) @@ -358,14 +294,16 @@ function EnzymeRules.reverse( Libc.free(tapes) end - return ntuple(Val(2 + length(args))) do _ + return ntuple(Val(2+length(args))) do _ Base.@_inline_meta nothing end end + + # From LinearAlgebra ~/.julia/juliaup/julia-1.10.0-beta3+0.x64.apple.darwin14/share/julia/stdlib/v1.10/LinearAlgebra/src/generic.jl:1110 -@inline function compute_lu_cache(cache_A::AT, b::BT) where {AT,BT} +@inline function compute_lu_cache(cache_A::AT, b::BT) where {AT, BT} LinearAlgebra.require_one_based_indexing(cache_A, b) m, n = size(cache_A) @@ -388,9 +326,8 @@ end # y=inv(A) B # dA −= z y^T # dB += z, where z = inv(A^T) dy -function EnzymeRules.augmented_primal( - config, func::Const{typeof(\)}, ::Type{RT}, A::Annotation{AT}, b::Annotation{BT} -) where {RT,AT<:Array,BT<:Array} +function EnzymeRules.augmented_primal(config, func::Const{typeof(\)}, ::Type{RT}, A::Annotation{AT}, b::Annotation{BT}) where {RT, AT <: Array, BT <: Array} + cache_A = if EnzymeRules.overwritten(config)[2] copy(A.val) else @@ -428,42 +365,33 @@ function EnzymeRules.augmented_primal( nothing end - @static if VERSION < v"1.8.0" - UT = Union{ - LinearAlgebra.Diagonal{eltype(AT),BT}, - LinearAlgebra.LowerTriangular{eltype(AT),AT}, - LinearAlgebra.UpperTriangular{eltype(AT),AT}, - LinearAlgebra.LU{eltype(AT),AT}, - LinearAlgebra.QRCompactWY{eltype(AT),AT}, - } - else - UT = Union{ - LinearAlgebra.Diagonal{eltype(AT),BT}, - LinearAlgebra.LowerTriangular{eltype(AT),AT}, - LinearAlgebra.UpperTriangular{eltype(AT),AT}, - LinearAlgebra.LU{eltype(AT),AT,Vector{Int}}, - LinearAlgebra.QRPivoted{eltype(AT),AT,BT,Vector{Int}}, - } - end - - cache = NamedTuple{ - (Symbol("1"), Symbol("2"), Symbol("3"), Symbol("4")), - Tuple{typeof(res),typeof(dres),UT,typeof(cache_b)}, - }((cache_res, dres, cache_A, cache_b)) - - return EnzymeRules.AugmentedReturn{typeof(retres),typeof(dres),typeof(cache)}( - retres, dres, cache +@static if VERSION < v"1.8.0" + UT = Union{ + LinearAlgebra.Diagonal{eltype(AT), BT}, + LinearAlgebra.LowerTriangular{eltype(AT), AT}, + LinearAlgebra.UpperTriangular{eltype(AT), AT}, + LinearAlgebra.LU{eltype(AT), AT}, + LinearAlgebra.QRCompactWY{eltype(AT), AT} + } +else + UT = Union{ + LinearAlgebra.Diagonal{eltype(AT), BT}, + LinearAlgebra.LowerTriangular{eltype(AT), AT}, + LinearAlgebra.UpperTriangular{eltype(AT), AT}, + LinearAlgebra.LU{eltype(AT), AT, Vector{Int}}, + LinearAlgebra.QRPivoted{eltype(AT), AT, BT, Vector{Int}} + } +end + + cache = NamedTuple{(Symbol("1"),Symbol("2"), Symbol("3"), Symbol("4")), Tuple{typeof(res), typeof(dres), UT, typeof(cache_b)}}( + (cache_res, dres, cache_A, cache_b) ) + + return EnzymeRules.AugmentedReturn{typeof(retres), typeof(dres), typeof(cache)}(retres, dres, cache) end -function EnzymeRules.reverse( - config, - func::Const{typeof(\)}, - ::Type{RT}, - cache, - A::Annotation{<:Array}, - b::Annotation{<:Array}, -) where {RT} +function EnzymeRules.reverse(config, func::Const{typeof(\)}, ::Type{RT}, cache, A::Annotation{<:Array}, b::Annotation{<:Array}) where RT + y, dys, cache_A, cache_b = cache if !EnzymeRules.overwritten(config)[3] @@ -519,11 +447,14 @@ function EnzymeRules.reverse( dy .= eltype(dy)(0) end - return (nothing, nothing) + return (nothing,nothing) end const EnzymeTriangulars = Union{ - UpperTriangular,LowerTriangular,UnitUpperTriangular,UnitLowerTriangular + UpperTriangular, + LowerTriangular, + UnitUpperTriangular, + UnitLowerTriangular } function EnzymeRules.augmented_primal( @@ -532,8 +463,8 @@ function EnzymeRules.augmented_primal( ::Type{RT}, Y::Annotation{YT}, A::Annotation{AT}, - B::Annotation{BT}, -) where {RT,YT<:Array,AT<:EnzymeTriangulars,BT<:Array} + B::Annotation{BT} +) where {RT, YT <: Array, AT <: EnzymeTriangulars, BT <: Array} cache_Y = EnzymeRules.overwritten(config)[1] ? copy(Y.val) : Y.val cache_A = EnzymeRules.overwritten(config)[2] ? copy(A.val) : A.val cache_A = compute_lu_cache(cache_A, B.val) @@ -541,9 +472,8 @@ function EnzymeRules.augmented_primal( primal = EnzymeRules.needs_primal(config) ? Y.val : nothing shadow = EnzymeRules.needs_shadow(config) ? Y.dval : nothing func.val(Y.val, A.val, B.val) - return EnzymeRules.AugmentedReturn{typeof(primal),typeof(shadow),Any}( - primal, shadow, (cache_Y, cache_A, cache_B) - ) + return EnzymeRules.AugmentedReturn{typeof(primal), typeof(shadow), Any}( + primal, shadow, (cache_Y, cache_A, cache_B)) end function EnzymeRules.reverse( @@ -553,8 +483,8 @@ function EnzymeRules.reverse( cache, Y::Annotation{YT}, A::Annotation{AT}, - B::Annotation{BT}, -) where {YT<:Array,RT,AT<:EnzymeTriangulars,BT<:Array} + B::Annotation{BT} +) where {YT <: Array, RT, AT <: EnzymeTriangulars, BT <: Array} if !isa(Y, Const) (cache_Yout, cache_A, cache_B) = cache for b in 1:EnzymeRules.width(config) @@ -580,75 +510,62 @@ _zero_unused_elements!(X, ::UnitUpperTriangular) = triu!(X, 1) _zero_unused_elements!(X, ::UnitLowerTriangular) = tril!(X, -1) @static if VERSION >= v"1.7-" - # Force a rule around hvcat_fill as it is type unstable if the tuple is not of the same type (e.g., int, float, int, float) - function EnzymeRules.augmented_primal( - config, - func::Const{typeof(Base.hvcat_fill!)}, - ::Type{RT}, - out::Annotation{AT}, - inp::Annotation{BT}, - ) where {RT,AT<:Array,BT<:Tuple} - primal = if EnzymeRules.needs_primal(config) - out.val - else - nothing - end - shadow = if EnzymeRules.needs_shadow(config) +# Force a rule around hvcat_fill as it is type unstable if the tuple is not of the same type (e.g., int, float, int, float) +function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.hvcat_fill!)}, ::Type{RT}, out::Annotation{AT}, inp::Annotation{BT}) where {RT, AT <: Array, BT <: Tuple} + primal = if EnzymeRules.needs_primal(config) + out.val + else + nothing + end + shadow = if EnzymeRules.needs_shadow(config) + out.dval + else + nothing + end + func.val(out.val, inp.val) + return EnzymeRules.AugmentedReturn(primal, shadow, nothing) +end + +function EnzymeRules.reverse(config, func::Const{typeof(Base.hvcat_fill!)}, ::Type{RT}, _, out::Annotation{AT}, inp::Annotation{BT}) where {RT, AT <: Array, BT <: Tuple} + nr, nc = size(out.val,1), size(out.val,2) + for b in 1:EnzymeRules.width(config) + da = if EnzymeRules.width(config) == 1 out.dval else - nothing + out.dval[b] end - func.val(out.val, inp.val) - return EnzymeRules.AugmentedReturn(primal, shadow, nothing) - end - - function EnzymeRules.reverse( - config, - func::Const{typeof(Base.hvcat_fill!)}, - ::Type{RT}, - _, - out::Annotation{AT}, - inp::Annotation{BT}, - ) where {RT,AT<:Array,BT<:Tuple} - nr, nc = size(out.val, 1), size(out.val, 2) - for b in 1:EnzymeRules.width(config) - da = if EnzymeRules.width(config) == 1 - out.dval - else - out.dval[b] - end - i = 1 - j = 1 - if (typeof(inp) <: Active) - dinp = ntuple(Val(length(inp.val))) do k - Base.@_inline_meta - res = da[i, j] - da[i, j] = 0 - j += 1 - if j == nc + 1 - i += 1 - j = 1 - end - T = BT.parameters[k] - if T <: AbstractFloat - T(res) - else - T(0) - end + i = 1 + j = 1 + if (typeof(inp) <: Active) + dinp = ntuple(Val(length(inp.val))) do k + Base.@_inline_meta + res = da[i, j] + da[i, j] = 0 + j += 1 + if j == nc+1 + i += 1 + j = 1 + end + T = BT.parameters[k] + if T <: AbstractFloat + T(res) + else + T(0) end - return (nothing, dinp)::Tuple{Nothing,BT} end + return (nothing, dinp)::Tuple{Nothing, BT} end - return (nothing, nothing) end + return (nothing, nothing) +end end function EnzymeRules.forward( - ::Const{typeof(sort!)}, - RT::Type{<:Union{Const,DuplicatedNoNeed,Duplicated}}, - xs::Duplicated{T}; - kwargs..., -) where {T<:AbstractArray{<:AbstractFloat}} + ::Const{typeof(sort!)}, + RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, + xs::Duplicated{T}; + kwargs... + ) where {T <: AbstractArray{<:AbstractFloat}} inds = sortperm(xs.val; kwargs...) xs.val .= xs.val[inds] xs.dval .= xs.dval[inds] @@ -662,11 +579,11 @@ function EnzymeRules.forward( end function EnzymeRules.forward( - ::Const{typeof(sort!)}, - RT::Type{<:Union{Const,BatchDuplicatedNoNeed,BatchDuplicated}}, - xs::BatchDuplicated{T,N}; - kwargs..., -) where {T<:AbstractArray{<:AbstractFloat},N} + ::Const{typeof(sort!)}, + RT::Type{<:Union{Const, BatchDuplicatedNoNeed, BatchDuplicated}}, + xs::BatchDuplicated{T, N}; + kwargs... + ) where {T <: AbstractArray{<:AbstractFloat}, N} inds = sortperm(xs.val; kwargs...) xs.val .= xs.val[inds] for i in 1:N @@ -681,13 +598,14 @@ function EnzymeRules.forward( end end + function EnzymeRules.augmented_primal( - config::EnzymeRules.ConfigWidth{1}, - ::Const{typeof(sort!)}, - RT::Type{<:Union{Const,DuplicatedNoNeed,Duplicated}}, - xs::Duplicated{T}; - kwargs..., -) where {T<:AbstractArray{<:AbstractFloat}} + config::EnzymeRules.ConfigWidth{1}, + ::Const{typeof(sort!)}, + RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, + xs::Duplicated{T}; + kwargs... + ) where {T <: AbstractArray{<:AbstractFloat}} inds = sortperm(xs.val; kwargs...) xs.val .= xs.val[inds] xs.dval .= xs.dval[inds] @@ -705,13 +623,13 @@ function EnzymeRules.augmented_primal( end function EnzymeRules.reverse( - config::EnzymeRules.ConfigWidth{1}, - ::Const{typeof(sort!)}, - RT::Type{<:Union{Const,DuplicatedNoNeed,Duplicated}}, - tape, - xs::Duplicated{T}; - kwargs..., -) where {T<:AbstractArray{<:AbstractFloat}} + config::EnzymeRules.ConfigWidth{1}, + ::Const{typeof(sort!)}, + RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, + tape, + xs::Duplicated{T}; + kwargs..., + ) where {T <: AbstractArray{<:AbstractFloat}} inds = tape back_inds = sortperm(inds) xs.dval .= xs.dval[back_inds] @@ -886,7 +804,11 @@ end # B(out) = inv(A) B(in) # dB(out) = inv(A) [ dB(in) - dA B(out) ] function EnzymeRules.forward( - func::Const{typeof(ldiv!)}, RT::Type, fact::Annotation{<:Cholesky}, B; kwargs... + func::Const{typeof(ldiv!)}, + RT::Type, + fact::Annotation{<:Cholesky}, + B; + kwargs... ) if isa(B, Const) @assert (RT <: Const) @@ -896,15 +818,11 @@ function EnzymeRules.forward( @assert !isa(B, Const) - retval = - if !isa(fact, Const) || - (RT <: Const) || - (RT <: Duplicated) || - (RT <: BatchDuplicated) - func.val(fact.val, B.val; kwargs...) - else - nothing - end + retval = if !isa(fact, Const) || (RT <: Const) || (RT <: Duplicated) || (RT <: BatchDuplicated) + func.val(fact.val, B.val; kwargs...) + else + nothing + end dretvals = ntuple(Val(N)) do b Base.@_inline_meta @@ -916,12 +834,13 @@ function EnzymeRules.forward( end if !isa(fact, Const) + dfact = if N == 1 fact.dval else fact.dval[b] end - + tmp = dfact.U * retval mul!(dB, dfact.L, tmp, -1, 1) end @@ -948,8 +867,8 @@ function EnzymeRules.augmented_primal( func::Const{typeof(cholesky)}, RT::Type, A::Annotation{<:Union{Matrix,LinearAlgebra.RealHermSym{<:Real,<:Matrix}}}; - kwargs..., -) + kwargs...) + fact = if EnzymeRules.needs_primal(config) || !(RT <: Const) cholesky(A.val; kwargs...) else @@ -982,8 +901,7 @@ function EnzymeRules.reverse( RT::Type, cache, A::Annotation{<:Union{Matrix,LinearAlgebra.RealHermSym{<:Real,<:Matrix}}}; - kwargs..., -) + kwargs...) if !(RT <: Const) && !isa(A, Const) fact, dfact = cache dAs = EnzymeRules.width(config) == 1 ? (A.dval,) : A.dval @@ -1041,14 +959,13 @@ function _realifydiag!(A) end function EnzymeRules.augmented_primal( - config, - func::Const{typeof(ldiv!)}, - RT::Type{ - <:Union{Const,DuplicatedNoNeed,Duplicated,BatchDuplicatedNoNeed,BatchDuplicated} - }, - A::Annotation{<:Cholesky}, - B::Union{Const,DuplicatedNoNeed,Duplicated,BatchDuplicatedNoNeed,BatchDuplicated}; - kwargs..., + config, + func::Const{typeof(ldiv!)}, + RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated}}, + + A::Annotation{<:Cholesky}, + B::Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated}; + kwargs... ) cache_B = if !isa(A, Const) && !isa(B, Const) EnzymeRules.overwritten(config)[3] ? copy(B.val) : B.val @@ -1074,10 +991,10 @@ function EnzymeRules.reverse( dret, cache, A::Annotation{<:Cholesky}, - B::Union{Const,DuplicatedNoNeed,Duplicated,BatchDuplicatedNoNeed,BatchDuplicated}; - kwargs..., + B::Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated}; + kwargs... ) - if !isa(B, Const) + if !isa(B, Const) (cache_A, cache_B) = cache Y = B.val U = cache_A.U diff --git a/test/internal_rules.jl b/test/internal_rules.jl index 16e4876a6e..6598bfda67 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -18,7 +18,7 @@ function sorterrfn(t, x) function lt(a, b) return a.a < b.a end - return first(sortperm(t; lt=lt)) * x + return first(sortperm(t, lt=lt)) * x end @testset "Sort rules" begin @@ -29,12 +29,10 @@ end end @test autodiff(Forward, f1, Duplicated(2.0, 1.0))[1] == 1 - @test autodiff(Forward, f1, BatchDuplicated(2.0, (1.0, 2.0)))[1] == - (var"1"=1.0, var"2"=2.0) + @test autodiff(Forward, f1, BatchDuplicated(2.0, (1.0, 2.0)))[1] == (var"1"=1.0, var"2"=2.0) @test autodiff(Reverse, f1, Active, Active(2.0))[1][1] == 1 @test autodiff(Forward, f1, Duplicated(4.0, 1.0))[1] == 0 - @test autodiff(Forward, f1, BatchDuplicated(4.0, (1.0, 2.0)))[1] == - (var"1"=0.0, var"2"=0.0) + @test autodiff(Forward, f1, BatchDuplicated(4.0, (1.0, 2.0)))[1] == (var"1"=0.0, var"2"=0.0) @test autodiff(Reverse, f1, Active, Active(4.0))[1][1] == 0 function f2(x) @@ -44,8 +42,7 @@ end end @test autodiff(Forward, f2, Duplicated(2.0, 1.0))[1] == -3 - @test autodiff(Forward, f2, BatchDuplicated(2.0, (1.0, 2.0)))[1] == - (var"1"=-3.0, var"2"=-6.0) + @test autodiff(Forward, f2, BatchDuplicated(2.0, (1.0, 2.0)))[1] == (var"1"=-3.0, var"2"=-6.0) @test autodiff(Reverse, f2, Active, Active(2.0))[1][1] == -3 function f3(x) @@ -90,13 +87,7 @@ end b = Float64[11, 13] db = zero(b) - forward, pullback = Enzyme.autodiff_thunk( - ReverseSplitNoPrimal, - Const{typeof(\)}, - Duplicated, - Duplicated{typeof(A)}, - Duplicated{typeof(b)}, - ) + forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Duplicated{typeof(A)}, Duplicated{typeof(b)}) tape, primal, shadow = forward(Const(\), Duplicated(A, dA), Duplicated(b, db)) @@ -113,13 +104,7 @@ end db = zero(b) - forward, pullback = Enzyme.autodiff_thunk( - ReverseSplitNoPrimal, - Const{typeof(\)}, - Duplicated, - Const{typeof(A)}, - Duplicated{typeof(b)}, - ) + forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Const{typeof(A)}, Duplicated{typeof(b)}) tape, primal, shadow = forward(Const(\), Const(A), Duplicated(b, db)) @@ -135,13 +120,7 @@ end dA = zero(A) - forward, pullback = Enzyme.autodiff_thunk( - ReverseSplitNoPrimal, - Const{typeof(\)}, - Duplicated, - Duplicated{typeof(A)}, - Const{typeof(b)}, - ) + forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Duplicated{typeof(A)}, Const{typeof(b)}) tape, primal, shadow = forward(Const(\), Duplicated(A, dA), Const(b)) @@ -157,101 +136,92 @@ end end @static if VERSION > v"1.8" - @testset "cholesky" begin - @testset "with wrapper arguments" begin - @testset for Te in (Float64,), TS in (Symmetric, Hermitian), uplo in (:U, :L) - @testset for TA in (Const, Duplicated), Tret in (Const, Duplicated) - _A = collect(exp(TS(rand(Te, 4, 4)))) - A = TS(_A, uplo) - are_activities_compatible(Tret, TA) || continue - test_forward(cholesky, Tret, (A, TA)) - test_reverse(cholesky, Tret, (A, TA)) - end +@testset "cholesky" begin + @testset "with wrapper arguments" begin + @testset for Te in (Float64,), TS in (Symmetric, Hermitian), uplo in (:U, :L) + @testset for TA in (Const, Duplicated), Tret in (Const, Duplicated) + _A = collect(exp(TS(rand(Te, 4, 4)))) + A = TS(_A, uplo) + are_activities_compatible(Tret, TA) || continue + test_forward(cholesky, Tret, (A, TA)) + test_reverse(cholesky, Tret, (A, TA)) end end - @testset "without wrapper arguments" begin - _square(A) = A * A' - @testset for Te in (Float64,) - @testset for TA in (Const, Duplicated), Tret in (Const, Duplicated) - A = rand(Te, 4, 4) - are_activities_compatible(Tret, TA) || continue - test_forward(cholesky ∘ _square, Tret, (A, TA)) - test_reverse(cholesky ∘ _square, Tret, (A, TA)) - end + end + @testset "without wrapper arguments" begin + _square(A) = A * A' + @testset for Te in (Float64,) + @testset for TA in (Const, Duplicated), Tret in (Const, Duplicated) + A = rand(Te, 4, 4) + are_activities_compatible(Tret, TA) || continue + test_forward(cholesky ∘ _square, Tret, (A, TA)) + test_reverse(cholesky ∘ _square, Tret, (A, TA)) end end end +end - @testset "Linear solve for `Cholesky`" begin - @testset for Te in (Float64,) - A = exp(Symmetric(rand(Te, 4, 4))) - C = cholesky(A) - B = rand(Te, 4, 4) - b = rand(Te, 4) - @testset for TC in (Const, Duplicated), TB in (Const, Duplicated), - Tret in (Const, Duplicated) - @testset "$(size(_B))" for _B in (B, b) - are_activities_compatible(Tret, TC, TB) || continue - test_forward(\, Tret, (C, TC), (_B, TB)) - test_reverse(\, Tret, (C, TC), (_B, TB)) - end +@testset "Linear solve for `Cholesky`" begin + @testset for Te in (Float64,) + A = exp(Symmetric(rand(Te, 4, 4))) + C = cholesky(A) + B = rand(Te, 4, 4) + b = rand(Te, 4) + @testset for TC in (Const, Duplicated), TB in (Const, Duplicated), + Tret in (Const, Duplicated) + @testset "$(size(_B))" for _B in (B, b) + are_activities_compatible(Tret, TC, TB) || continue + test_forward(\, Tret, (C, TC), (_B, TB)) + test_reverse(\, Tret, (C, TC), (_B, TB)) end end end +end - @testset "Linear solve for triangular matrices" begin - @testset for T in ( - UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular - ), - TE in (Float64, ComplexF64), - sizeB in ((3,), (3, 3)) - - n = sizeB[1] - M = rand(TE, n, n) - B = rand(TE, sizeB...) - Y = zeros(TE, sizeB...) - A = T(M) - @testset "test through constructor" begin - _A = T(A) - function f!(Y, A, B, ::T) where {T} - ldiv!(Y, T(A), B) - return nothing - end - for TY in (Const, Duplicated, BatchDuplicated), - TM in (Const, Duplicated, BatchDuplicated), - TB in (Const, Duplicated, BatchDuplicated) - - are_activities_compatible(Const, TY, TM, TB) || continue - test_reverse(f!, Const, (Y, TY), (M, TM), (B, TB), (_A, Const)) - end +@testset "Linear solve for triangular matrices" begin + @testset for T in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular), + TE in (Float64, ComplexF64), sizeB in ((3,), (3, 3)) + n = sizeB[1] + M = rand(TE, n, n) + B = rand(TE, sizeB...) + Y = zeros(TE, sizeB...) + A = T(M) + @testset "test through constructor" begin + _A = T(A) + function f!(Y, A, B, ::T) where T + ldiv!(Y, T(A), B) + return nothing + end + for TY in (Const, Duplicated, BatchDuplicated), + TM in (Const, Duplicated, BatchDuplicated), + TB in (Const, Duplicated, BatchDuplicated) + are_activities_compatible(Const, TY, TM, TB) || continue + test_reverse(f!, Const, (Y, TY), (M, TM), (B, TB), (_A, Const)) end - @testset "test through `Adjoint` wrapper (regression test for #1306)" begin - # Test that we get the same derivative for `M` as for the adjoint of its - # (materialized) transpose. It's the same matrix, but represented differently - function f!(Y, A, B) - ldiv!(Y, A, B) - return nothing - end - A1 = T(M) - A2 = T(conj(permutedims(M))') - dA1 = make_zero(A1) - dA2 = make_zero(A2) - dB1 = make_zero(B) - dB2 = make_zero(B) - dY1 = rand(TE, sizeB...) - dY2 = copy(dY1) - autodiff( - Reverse, f!, Duplicated(Y, dY1), Duplicated(A1, dA1), Duplicated(B, dB1) - ) - autodiff( - Reverse, f!, Duplicated(Y, dY2), Duplicated(A2, dA2), Duplicated(B, dB2) - ) - @test dA1.data ≈ dA2.data - @test dB1 ≈ dB2 + end + @testset "test through `Adjoint` wrapper (regression test for #1306)" begin + # Test that we get the same derivative for `M` as for the adjoint of its + # (materialized) transpose. It's the same matrix, but represented differently + function f!(Y, A, B) + ldiv!(Y, A, B) + return nothing end + A1 = T(M) + A2 = T(conj(permutedims(M))') + dA1 = make_zero(A1) + dA2 = make_zero(A2) + dB1 = make_zero(B) + dB2 = make_zero(B) + dY1 = rand(TE, sizeB...) + dY2 = copy(dY1) + autodiff(Reverse, f!, Duplicated(Y, dY1), Duplicated(A1, dA1), Duplicated(B, dB1)) + autodiff(Reverse, f!, Duplicated(Y, dY2), Duplicated(A2, dA2), Duplicated(B, dB2)) + @test dA1.data ≈ dA2.data + @test dB1 ≈ dB2 end end end +end @testset "rand and randn rules" begin # Distributed as x + unit normal + uniform From c348c610e53aa165c88ea035f4acf0f05ac9d519 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Tue, 30 Apr 2024 23:03:02 +0200 Subject: [PATCH 38/61] Change forward rule --- src/internal_rules.jl | 50 ++++++++++++++++++++----------------------- 1 file changed, 23 insertions(+), 27 deletions(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 02330f2e38..fdf77cb134 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -798,24 +798,19 @@ function _cholesky_forward(C::Cholesky, Ȧ) end end -# y = inv(A) B -# dY = inv(A) [ dB - dA y ] -# -> -# B(out) = inv(A) B(in) -# dB(out) = inv(A) [ dB(in) - dA B(out) ] function EnzymeRules.forward( func::Const{typeof(ldiv!)}, - RT::Type, + RT::Type{<:Union{Const, Duplicated}}, fact::Annotation{<:Cholesky}, - B; + B::Annotation{<:AbstractVecOrMat}; kwargs... ) + @info "Hi from forward(::typeof(ldiv!),...)" if isa(B, Const) - @assert (RT <: Const) return func.val(fact.val, B.val; kwargs...) else - N = width(B) + N = width(B) @assert !isa(B, Const) retval = if !isa(fact, Const) || (RT <: Const) || (RT <: Duplicated) || (RT <: BatchDuplicated) @@ -826,26 +821,17 @@ function EnzymeRules.forward( dretvals = ntuple(Val(N)) do b Base.@_inline_meta - - dB = if N == 1 - B.dval - else - B.dval[b] - end - + dB = N == 1 ? B.dval : B.dval[b] if !isa(fact, Const) - - dfact = if N == 1 - fact.dval - else - fact.dval[b] - end - - tmp = dfact.U * retval - mul!(dB, dfact.L, tmp, -1, 1) + dfact = N == 1 ? fact.dval : fact.dval[b] + L = fact.val.L + U = fact.val.U + dL = dfact.L + dU = dfact.U + _ldiv_Cholesky_forward!(L, U, B.val, dL, dU, dB) end - - func.val(fact.val, dB; kwargs...) + return dB + # func.val(fact.val, dB; kwargs...) end if RT <: Const @@ -862,6 +848,16 @@ function EnzymeRules.forward( end end +function _ldiv_Cholesky_forward!(L, U, B, dL, dU, dB) + ldiv!(L, B) + mul!(dB, dL, B, -1, 1) + ldiv!(L, dB) + ldiv!(U, B) + mul!(dB, dU, B, -1, 1) + ldiv!(U, dB) + return B, dB +end + function EnzymeRules.augmented_primal( config, func::Const{typeof(cholesky)}, From 74e218399eba80e90e06156ead65fb9e01754e25 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Tue, 30 Apr 2024 23:20:45 +0200 Subject: [PATCH 39/61] Fix `Duplicated` case --- src/internal_rules.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index fdf77cb134..789c239d52 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -814,7 +814,7 @@ function EnzymeRules.forward( @assert !isa(B, Const) retval = if !isa(fact, Const) || (RT <: Const) || (RT <: Duplicated) || (RT <: BatchDuplicated) - func.val(fact.val, B.val; kwargs...) + B.val else nothing end @@ -831,7 +831,6 @@ function EnzymeRules.forward( _ldiv_Cholesky_forward!(L, U, B.val, dL, dU, dB) end return dB - # func.val(fact.val, dB; kwargs...) end if RT <: Const From e5f85e0f02db60871bc9a3164d0134cf0147b275 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Wed, 1 May 2024 12:13:30 +0200 Subject: [PATCH 40/61] Slightly refactor forward rule --- src/internal_rules.jl | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 789c239d52..ad8c905700 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -783,7 +783,7 @@ function _cholesky_forward(C::Cholesky, Ȧ) U̇[idx] ./= 2 triu!(U̇) rmul!(U̇, U) - U̇ .+= UpperTriangular(Ȧ)' - Diagonal(Ȧ) # correction for unused triangle + U̇ .+= UpperTriangular(Ȧ)' .- Diagonal(Ȧ) # correction for unused triangle return Cholesky(U̇, 'U', C.info) else L = C.L @@ -793,7 +793,7 @@ function _cholesky_forward(C::Cholesky, Ȧ) L̇[idx] ./= 2 tril!(L̇) lmul!(L, L̇) - L̇ .+= LowerTriangular(Ȧ)' - Diagonal(Ȧ) # correction for unused triangle + L̇ .+= LowerTriangular(Ȧ)' .- Diagonal(Ȧ) # correction for unused triangle return Cholesky(L̇, 'L', C.info) end end @@ -805,24 +805,19 @@ function EnzymeRules.forward( B::Annotation{<:AbstractVecOrMat}; kwargs... ) - @info "Hi from forward(::typeof(ldiv!),...)" - if isa(B, Const) + if B isa Const return func.val(fact.val, B.val; kwargs...) else - N = width(B) - @assert !isa(B, Const) - - retval = if !isa(fact, Const) || (RT <: Const) || (RT <: Duplicated) || (RT <: BatchDuplicated) - B.val - else - nothing - end + retval = B.val dretvals = ntuple(Val(N)) do b Base.@_inline_meta dB = N == 1 ? B.dval : B.dval[b] - if !isa(fact, Const) + if fact isa Const + ldiv!(fact.val, B.val) + ldiv!(fact.val, dB) + else dfact = N == 1 ? fact.dval : fact.dval[b] L = fact.val.L U = fact.val.U From 1e5481dd2eb64dcecb7a60da6f1063f15edc1c39 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Thu, 2 May 2024 22:41:44 +0200 Subject: [PATCH 41/61] Disable questionable tests, fix reverse rule for `ldiv!` --- src/internal_rules.jl | 4 ++-- test/internal_rules.jl | 12 +++++++++++- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index ad8c905700..d6e862ff6e 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -958,7 +958,7 @@ function EnzymeRules.augmented_primal( kwargs... ) cache_B = if !isa(A, Const) && !isa(B, Const) - EnzymeRules.overwritten(config)[3] ? copy(B.val) : B.val + copy(B.val) else nothing end @@ -986,9 +986,9 @@ function EnzymeRules.reverse( ) if !isa(B, Const) (cache_A, cache_B) = cache - Y = B.val U = cache_A.U Z = isa(A, Const) ? nothing : U' \ cache_B + Y = isa(A, Const) ? nothing : U \ Z for b in 1:EnzymeRules.width(config) dB = EnzymeRules.width(config) == 1 ? B.dval : B.dval[b] dZ = U' \ dB diff --git a/test/internal_rules.jl b/test/internal_rules.jl index 6598bfda67..7bae91b971 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -171,10 +171,20 @@ end Tret in (Const, Duplicated) @testset "$(size(_B))" for _B in (B, b) are_activities_compatible(Tret, TC, TB) || continue - test_forward(\, Tret, (C, TC), (_B, TB)) + # Non-uniform activities are disabled due to unresolved questions + Tret == TC == TB && test_forward(\, Tret, (C, TC), (_B, TB)) test_reverse(\, Tret, (C, TC), (_B, TB)) end end + @testset for TC in (Const, Duplicated), TB in (Const, Duplicated), + Tret in (Const, Duplicated) + @testset "$(size(_B))" for _B in (B, b) + are_activities_compatible(Tret, TC, TB) || continue + # Non-uniform activities are disabled due to unresolved questions + Tret == TC == TB && test_forward(ldiv!, Tret, (C, TC), (_B, TB)) + Tret == TB && test_reverse(ldiv!, Tret, (C, TC), (_B, TB)) + end + end end end From 3f8ee74dae64a28fff1003884cc24907cff09440 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Thu, 2 May 2024 22:45:10 +0200 Subject: [PATCH 42/61] Run formatter --- src/internal_rules.jl | 502 +++++++++++++++++++++-------------------- test/internal_rules.jl | 204 +++++++++-------- 2 files changed, 367 insertions(+), 339 deletions(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index d6e862ff6e..29be9c93a8 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -69,7 +69,8 @@ end function EnzymeRules.inactive(::typeof(Random.rand), ::Random.AbstractRNG, ::Random.Sampler) return nothing end -function EnzymeRules.inactive(::typeof(Random.rand!), ::Random.AbstractRNG, ::Random.Sampler, ::AbstractArray) +function EnzymeRules.inactive(::typeof(Random.rand!), ::Random.AbstractRNG, + ::Random.Sampler, ::AbstractArray) return nothing end function EnzymeRules.inactive(::typeof(Random.randn), args...) @@ -102,7 +103,8 @@ end function EnzymeRules.inactive_noinl(::typeof(Base.size), args...) return nothing end -function EnzymeRules.inactive_noinl(::typeof(Base.setindex!), ::IdDict{K, V}, ::K, ::V) where {K, V <:Integer} +function EnzymeRules.inactive_noinl(::typeof(Base.setindex!), ::IdDict{K,V}, ::K, + ::V) where {K,V<:Integer} return nothing end @@ -120,35 +122,41 @@ end @inline EnzymeRules.inactive_type(v::Type{T}) where {T<:AbstractString} = true @inline width(::Duplicated) = 1 -@inline width(::BatchDuplicated{T, N}) where {T, N} = N +@inline width(::BatchDuplicated{T,N}) where {T,N} = N @inline width(::DuplicatedNoNeed) = 1 -@inline width(::BatchDuplicatedNoNeed{T, N}) where {T, N} = N +@inline width(::BatchDuplicatedNoNeed{T,N}) where {T,N} = N -@inline width(::Type{Duplicated{T}}) where T = 1 -@inline width(::Type{BatchDuplicated{T, N}}) where {T, N} = N -@inline width(::Type{DuplicatedNoNeed{T}}) where T = 1 -@inline width(::Type{BatchDuplicatedNoNeed{T, N}}) where {T, N} = N +@inline width(::Type{Duplicated{T}}) where {T} = 1 +@inline width(::Type{BatchDuplicated{T,N}}) where {T,N} = N +@inline width(::Type{DuplicatedNoNeed{T}}) where {T} = 1 +@inline width(::Type{BatchDuplicatedNoNeed{T,N}}) where {T,N} = N # Note all of these forward mode definitions do not support runtime activity as # the do not keep the primal if shadow(x.y) == primal(x.y) -function EnzymeRules.forward(::Const{typeof(Base.deepcopy)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated) +function EnzymeRules.forward(::Const{typeof(Base.deepcopy)}, ::Type{<:DuplicatedNoNeed}, + x::Duplicated) return deepcopy(x.dval) end -function EnzymeRules.forward(::Const{typeof(Base.deepcopy)}, ::Type{<:BatchDuplicatedNoNeed}, x::BatchDuplicated{T, N}) where {T, N} +function EnzymeRules.forward(::Const{typeof(Base.deepcopy)}, + ::Type{<:BatchDuplicatedNoNeed}, + x::BatchDuplicated{T,N}) where {T,N} ntuple(Val(N)) do _ - deepcopy(x.dval) + return deepcopy(x.dval) end end # Deepcopy preserving the primal if runtime inactive -@inline function deepcopy_rtact(copied::RT, primal::RT, seen::IdDict, shadow::RT) where {RT <: Union{Integer, Char}} +@inline function deepcopy_rtact(copied::RT, primal::RT, seen::IdDict, + shadow::RT) where {RT<:Union{Integer,Char}} return Base.deepcopy_internal(shadow, seen) end -@inline function deepcopy_rtact(copied::RT, primal::RT, seen::IdDict, shadow::RT) where {RT <: AbstractFloat} +@inline function deepcopy_rtact(copied::RT, primal::RT, seen::IdDict, + shadow::RT) where {RT<:AbstractFloat} return Base.deepcopy_internal(shadow, seen) end -@inline function deepcopy_rtact(copied::RT, primal::RT, seen::IdDict, shadow::RT) where {RT <: Array} +@inline function deepcopy_rtact(copied::RT, primal::RT, seen::IdDict, + shadow::RT) where {RT<:Array} if !haskey(seen, shadow) if primal === shadow return seen[shadow] = copied @@ -162,19 +170,23 @@ end return seen[shadow] end -function EnzymeRules.forward(func::Const{typeof(Base.deepcopy)}, ::Type{<:Duplicated}, x::Duplicated) +function EnzymeRules.forward(func::Const{typeof(Base.deepcopy)}, ::Type{<:Duplicated}, + x::Duplicated) primal = func.val(x.val) return Duplicated(primal, deepcopy_rtact(primal, x.val, IdDict(), x.dval)) end -function EnzymeRules.forward(func::Const{typeof(Base.deepcopy)}, ::Type{<:BatchDuplicated}, x::BatchDuplicated{T, N}) where {T,N} +function EnzymeRules.forward(func::Const{typeof(Base.deepcopy)}, ::Type{<:BatchDuplicated}, + x::BatchDuplicated{T,N}) where {T,N} primal = func.val(x.val) - return BatchDuplicated(primal, ntuple(Val(N)) do i - deepcopy_rtact(primal, x.val, IdDict(), x.dval[i]) - end) + return BatchDuplicated(primal, + ntuple(Val(N)) do i + return deepcopy_rtact(primal, x.val, IdDict(), x.dval[i]) + end) end -function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.deepcopy)}, ::Type{RT}, x::Annotation{Ty}) where {RT, Ty} +function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.deepcopy)}, + ::Type{RT}, x::Annotation{Ty}) where {RT,Ty} primal = if EnzymeRules.needs_primal(config) func.val(x.val) else @@ -191,9 +203,8 @@ function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.deepcopy)} shadow = ntuple(Val(EnzymeRules.width(config))) do _ Base.@_inline_meta - Enzyme.make_zero(source, - #=copy_if_inactive=#Val(!EnzymeRules.needs_primal(config)) - ) + return Enzyme.make_zero(source, + Val(!EnzymeRules.needs_primal(config))) end if EnzymeRules.width(config) == 1 @@ -203,8 +214,8 @@ function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.deepcopy)} return EnzymeRules.AugmentedReturn(primal, shadow, shadow) end - -@inline function accumulate_into(into::RT, seen::IdDict, from::RT)::Tuple{RT,RT} where {RT<:Array} +@inline function accumulate_into(into::RT, seen::IdDict, + from::RT)::Tuple{RT,RT} where {RT<:Array} if Enzyme.Compiler.guaranteed_const(RT) return (into, from) end @@ -219,9 +230,10 @@ end return seen[into] end -@inline function accumulate_into(into::RT, seen::IdDict, from::RT)::Tuple{RT,RT} where {RT<:AbstractFloat} +@inline function accumulate_into(into::RT, seen::IdDict, + from::RT)::Tuple{RT,RT} where {RT<:AbstractFloat} if !haskey(seen, into) - seen[into] = (into+from, RT(0)) + seen[into] = (into + from, RT(0)) end return seen[into] end @@ -236,7 +248,8 @@ end return seen[into] end -function EnzymeRules.reverse(config, func::Const{typeof(Base.deepcopy)}, ::Type{RT}, shadow, x::Annotation{Ty}) where {RT, Ty} +function EnzymeRules.reverse(config, func::Const{typeof(Base.deepcopy)}, ::Type{RT}, shadow, + x::Annotation{Ty}) where {RT,Ty} if EnzymeRules.width(config) == 1 accumulate_into(x.dval, IdDict(), shadow) else @@ -248,43 +261,53 @@ function EnzymeRules.reverse(config, func::Const{typeof(Base.deepcopy)}, ::Type{ return (nothing,) end -@inline function pmap_fwd(idx, tapes::Vector, thunk::ThunkTy, f::F, fargs::Vararg{Annotation, N}) where {ThunkTy, F, N} +@inline function pmap_fwd(idx, tapes::Vector, thunk::ThunkTy, f::F, + fargs::Vararg{Annotation,N}) where {ThunkTy,F,N} @inbounds tapes[idx] = thunk(f, Const(idx), fargs...)[1] end -@inline function pmap_fwd(idx, tapes::Ptr, thunk::ThunkTy, f::F, fargs::Vararg{Annotation, N}) where {ThunkTy, F, N} - unsafe_store!(tapes, thunk(f, Const(idx), fargs...)[1], idx) +@inline function pmap_fwd(idx, tapes::Ptr, thunk::ThunkTy, f::F, + fargs::Vararg{Annotation,N}) where {ThunkTy,F,N} + return unsafe_store!(tapes, thunk(f, Const(idx), fargs...)[1], idx) end -function EnzymeRules.augmented_primal(config, func::Const{typeof(Enzyme.pmap)}, ::Type{Const{Nothing}}, body::BodyTy, count, args::Vararg{Annotation, N}) where {BodyTy, N} - - config2 = ReverseModeSplit{false, false, EnzymeRules.width(config), EnzymeRules.overwritten(config)[2:end],InlineABI}() - fwd_thunk, rev_thunk = autodiff_thunk(config2, BodyTy, Const, typeof(count), map(typeof, args)...) +function EnzymeRules.augmented_primal(config, func::Const{typeof(Enzyme.pmap)}, + ::Type{Const{Nothing}}, body::BodyTy, count, + args::Vararg{Annotation,N}) where {BodyTy,N} + config2 = ReverseModeSplit{false,false,EnzymeRules.width(config), + EnzymeRules.overwritten(config)[2:end],InlineABI}() + fwd_thunk, rev_thunk = autodiff_thunk(config2, BodyTy, Const, typeof(count), + map(typeof, args)...) TapeType = EnzymeRules.tape_type(fwd_thunk) tapes = if Enzyme.Compiler.any_jltypes(TapeType) Vector{TapeType}(undef, count.val) else - Base.unsafe_convert(Ptr{TapeType}, Libc.malloc(sizeof(TapeType)*count.val)) + Base.unsafe_convert(Ptr{TapeType}, Libc.malloc(sizeof(TapeType) * count.val)) end Enzyme.pmap(pmap_fwd, count.val, tapes, fwd_thunk, body, args...) return EnzymeRules.AugmentedReturn(nothing, nothing, tapes) end -@inline function pmap_rev(idx, tapes::Vector, thunk::ThunkTy, f::F, fargs::Vararg{Annotation, N}) where {ThunkTy, F, N} - thunk(f, Const(idx), fargs..., @inbounds tapes[idx]) +@inline function pmap_rev(idx, tapes::Vector, thunk::ThunkTy, f::F, + fargs::Vararg{Annotation,N}) where {ThunkTy,F,N} + return thunk(f, Const(idx), fargs..., @inbounds tapes[idx]) end -@inline function pmap_rev(idx, tapes::Ptr, thunk::ThunkTy, f::F, fargs::Vararg{Annotation, N}) where {ThunkTy, F, N} - thunk(f, Const(idx), fargs..., unsafe_load(tapes, idx)) +@inline function pmap_rev(idx, tapes::Ptr, thunk::ThunkTy, f::F, + fargs::Vararg{Annotation,N}) where {ThunkTy,F,N} + return thunk(f, Const(idx), fargs..., unsafe_load(tapes, idx)) end -function EnzymeRules.reverse(config, func::Const{typeof(Enzyme.pmap)}, ::Type{Const{Nothing}}, tapes, body::BodyTy, count, args::Vararg{Annotation, N}) where {BodyTy, N} - - config2 = ReverseModeSplit{false, false, EnzymeRules.width(config), EnzymeRules.overwritten(config)[2:end],InlineABI}() - fwd_thunk, rev_thunk = autodiff_thunk(config2, BodyTy, Const, typeof(count), map(typeof, args)...) +function EnzymeRules.reverse(config, func::Const{typeof(Enzyme.pmap)}, + ::Type{Const{Nothing}}, tapes, body::BodyTy, count, + args::Vararg{Annotation,N}) where {BodyTy,N} + config2 = ReverseModeSplit{false,false,EnzymeRules.width(config), + EnzymeRules.overwritten(config)[2:end],InlineABI}() + fwd_thunk, rev_thunk = autodiff_thunk(config2, BodyTy, Const, typeof(count), + map(typeof, args)...) Enzyme.pmap(pmap_rev, count.val, tapes, rev_thunk, body, args...) @@ -294,16 +317,14 @@ function EnzymeRules.reverse(config, func::Const{typeof(Enzyme.pmap)}, ::Type{Co Libc.free(tapes) end - return ntuple(Val(2+length(args))) do _ + return ntuple(Val(2 + length(args))) do _ Base.@_inline_meta - nothing + return nothing end end - - # From LinearAlgebra ~/.julia/juliaup/julia-1.10.0-beta3+0.x64.apple.darwin14/share/julia/stdlib/v1.10/LinearAlgebra/src/generic.jl:1110 -@inline function compute_lu_cache(cache_A::AT, b::BT) where {AT, BT} +@inline function compute_lu_cache(cache_A::AT, b::BT) where {AT,BT} LinearAlgebra.require_one_based_indexing(cache_A, b) m, n = size(cache_A) @@ -326,8 +347,9 @@ end # y=inv(A) B # dA −= z y^T # dB += z, where z = inv(A^T) dy -function EnzymeRules.augmented_primal(config, func::Const{typeof(\)}, ::Type{RT}, A::Annotation{AT}, b::Annotation{BT}) where {RT, AT <: Array, BT <: Array} - +function EnzymeRules.augmented_primal(config, func::Const{typeof(\)}, ::Type{RT}, + A::Annotation{AT}, + b::Annotation{BT}) where {RT,AT<:Array,BT<:Array} cache_A = if EnzymeRules.overwritten(config)[2] copy(A.val) else @@ -343,7 +365,7 @@ function EnzymeRules.augmented_primal(config, func::Const{typeof(\)}, ::Type{RT} else ntuple(Val(EnzymeRules.width(config))) do i Base.@_inline_meta - zero(res) + return zero(res) end end @@ -365,33 +387,32 @@ function EnzymeRules.augmented_primal(config, func::Const{typeof(\)}, ::Type{RT} nothing end -@static if VERSION < v"1.8.0" - UT = Union{ - LinearAlgebra.Diagonal{eltype(AT), BT}, - LinearAlgebra.LowerTriangular{eltype(AT), AT}, - LinearAlgebra.UpperTriangular{eltype(AT), AT}, - LinearAlgebra.LU{eltype(AT), AT}, - LinearAlgebra.QRCompactWY{eltype(AT), AT} - } -else - UT = Union{ - LinearAlgebra.Diagonal{eltype(AT), BT}, - LinearAlgebra.LowerTriangular{eltype(AT), AT}, - LinearAlgebra.UpperTriangular{eltype(AT), AT}, - LinearAlgebra.LU{eltype(AT), AT, Vector{Int}}, - LinearAlgebra.QRPivoted{eltype(AT), AT, BT, Vector{Int}} - } -end + @static if VERSION < v"1.8.0" + UT = Union{LinearAlgebra.Diagonal{eltype(AT),BT}, + LinearAlgebra.LowerTriangular{eltype(AT),AT}, + LinearAlgebra.UpperTriangular{eltype(AT),AT}, + LinearAlgebra.LU{eltype(AT),AT}, + LinearAlgebra.QRCompactWY{eltype(AT),AT}} + else + UT = Union{LinearAlgebra.Diagonal{eltype(AT),BT}, + LinearAlgebra.LowerTriangular{eltype(AT),AT}, + LinearAlgebra.UpperTriangular{eltype(AT),AT}, + LinearAlgebra.LU{eltype(AT),AT,Vector{Int}}, + LinearAlgebra.QRPivoted{eltype(AT),AT,BT,Vector{Int}}} + end - cache = NamedTuple{(Symbol("1"),Symbol("2"), Symbol("3"), Symbol("4")), Tuple{typeof(res), typeof(dres), UT, typeof(cache_b)}}( - (cache_res, dres, cache_A, cache_b) - ) + cache = NamedTuple{(Symbol("1"), Symbol("2"), Symbol("3"), Symbol("4")), + Tuple{typeof(res),typeof(dres),UT,typeof(cache_b)}}((cache_res, dres, + cache_A, + cache_b)) - return EnzymeRules.AugmentedReturn{typeof(retres), typeof(dres), typeof(cache)}(retres, dres, cache) + return EnzymeRules.AugmentedReturn{typeof(retres),typeof(dres),typeof(cache)}(retres, + dres, + cache) end -function EnzymeRules.reverse(config, func::Const{typeof(\)}, ::Type{RT}, cache, A::Annotation{<:Array}, b::Annotation{<:Array}) where RT - +function EnzymeRules.reverse(config, func::Const{typeof(\)}, ::Type{RT}, cache, + A::Annotation{<:Array}, b::Annotation{<:Array}) where {RT} y, dys, cache_A, cache_b = cache if !EnzymeRules.overwritten(config)[3] @@ -412,7 +433,7 @@ function EnzymeRules.reverse(config, func::Const{typeof(\)}, ::Type{RT}, cache, if typeof(A) <: Const ntuple(Val(EnzymeRules.width(config))) do i Base.@_inline_meta - nothing + return nothing end else A.dval @@ -429,7 +450,7 @@ function EnzymeRules.reverse(config, func::Const{typeof(\)}, ::Type{RT}, cache, if typeof(b) <: Const ntuple(Val(EnzymeRules.width(config))) do i Base.@_inline_meta - nothing + return nothing end else b.dval @@ -447,24 +468,22 @@ function EnzymeRules.reverse(config, func::Const{typeof(\)}, ::Type{RT}, cache, dy .= eltype(dy)(0) end - return (nothing,nothing) + return (nothing, nothing) end -const EnzymeTriangulars = Union{ - UpperTriangular, - LowerTriangular, - UnitUpperTriangular, - UnitLowerTriangular -} +const EnzymeTriangulars = Union{UpperTriangular, + LowerTriangular, + UnitUpperTriangular, + UnitLowerTriangular} -function EnzymeRules.augmented_primal( - config, - func::Const{typeof(ldiv!)}, - ::Type{RT}, - Y::Annotation{YT}, - A::Annotation{AT}, - B::Annotation{BT} -) where {RT, YT <: Array, AT <: EnzymeTriangulars, BT <: Array} +function EnzymeRules.augmented_primal(config, + func::Const{typeof(ldiv!)}, + ::Type{RT}, + Y::Annotation{YT}, + A::Annotation{AT}, + B::Annotation{BT}) where {RT,YT<:Array, + AT<:EnzymeTriangulars, + BT<:Array} cache_Y = EnzymeRules.overwritten(config)[1] ? copy(Y.val) : Y.val cache_A = EnzymeRules.overwritten(config)[2] ? copy(A.val) : A.val cache_A = compute_lu_cache(cache_A, B.val) @@ -472,19 +491,19 @@ function EnzymeRules.augmented_primal( primal = EnzymeRules.needs_primal(config) ? Y.val : nothing shadow = EnzymeRules.needs_shadow(config) ? Y.dval : nothing func.val(Y.val, A.val, B.val) - return EnzymeRules.AugmentedReturn{typeof(primal), typeof(shadow), Any}( - primal, shadow, (cache_Y, cache_A, cache_B)) -end - -function EnzymeRules.reverse( - config, - func::Const{typeof(ldiv!)}, - ::Type{RT}, - cache, - Y::Annotation{YT}, - A::Annotation{AT}, - B::Annotation{BT} -) where {YT <: Array, RT, AT <: EnzymeTriangulars, BT <: Array} + return EnzymeRules.AugmentedReturn{typeof(primal),typeof(shadow),Any}(primal, shadow, + (cache_Y, cache_A, + cache_B)) +end + +function EnzymeRules.reverse(config, + func::Const{typeof(ldiv!)}, + ::Type{RT}, + cache, + Y::Annotation{YT}, + A::Annotation{AT}, + B::Annotation{BT}) where {YT<:Array,RT,AT<:EnzymeTriangulars, + BT<:Array} if !isa(Y, Const) (cache_Yout, cache_A, cache_B) = cache for b in 1:EnzymeRules.width(config) @@ -510,62 +529,65 @@ _zero_unused_elements!(X, ::UnitUpperTriangular) = triu!(X, 1) _zero_unused_elements!(X, ::UnitLowerTriangular) = tril!(X, -1) @static if VERSION >= v"1.7-" -# Force a rule around hvcat_fill as it is type unstable if the tuple is not of the same type (e.g., int, float, int, float) -function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.hvcat_fill!)}, ::Type{RT}, out::Annotation{AT}, inp::Annotation{BT}) where {RT, AT <: Array, BT <: Tuple} - primal = if EnzymeRules.needs_primal(config) - out.val - else - nothing - end - shadow = if EnzymeRules.needs_shadow(config) - out.dval - else - nothing - end - func.val(out.val, inp.val) - return EnzymeRules.AugmentedReturn(primal, shadow, nothing) -end - -function EnzymeRules.reverse(config, func::Const{typeof(Base.hvcat_fill!)}, ::Type{RT}, _, out::Annotation{AT}, inp::Annotation{BT}) where {RT, AT <: Array, BT <: Tuple} - nr, nc = size(out.val,1), size(out.val,2) - for b in 1:EnzymeRules.width(config) - da = if EnzymeRules.width(config) == 1 + # Force a rule around hvcat_fill as it is type unstable if the tuple is not of the same type (e.g., int, float, int, float) + function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.hvcat_fill!)}, + ::Type{RT}, out::Annotation{AT}, + inp::Annotation{BT}) where {RT,AT<:Array, + BT<:Tuple} + primal = if EnzymeRules.needs_primal(config) + out.val + else + nothing + end + shadow = if EnzymeRules.needs_shadow(config) out.dval else - out.dval[b] + nothing end - i = 1 - j = 1 - if (typeof(inp) <: Active) - dinp = ntuple(Val(length(inp.val))) do k - Base.@_inline_meta - res = da[i, j] - da[i, j] = 0 - j += 1 - if j == nc+1 - i += 1 - j = 1 - end - T = BT.parameters[k] - if T <: AbstractFloat - T(res) - else - T(0) + func.val(out.val, inp.val) + return EnzymeRules.AugmentedReturn(primal, shadow, nothing) + end + + function EnzymeRules.reverse(config, func::Const{typeof(Base.hvcat_fill!)}, ::Type{RT}, + _, out::Annotation{AT}, + inp::Annotation{BT}) where {RT,AT<:Array,BT<:Tuple} + nr, nc = size(out.val, 1), size(out.val, 2) + for b in 1:EnzymeRules.width(config) + da = if EnzymeRules.width(config) == 1 + out.dval + else + out.dval[b] + end + i = 1 + j = 1 + if (typeof(inp) <: Active) + dinp = ntuple(Val(length(inp.val))) do k + Base.@_inline_meta + res = da[i, j] + da[i, j] = 0 + j += 1 + if j == nc + 1 + i += 1 + j = 1 + end + T = BT.parameters[k] + if T <: AbstractFloat + T(res) + else + T(0) + end end + return (nothing, dinp)::Tuple{Nothing,BT} end - return (nothing, dinp)::Tuple{Nothing, BT} end + return (nothing, nothing) end - return (nothing, nothing) -end end -function EnzymeRules.forward( - ::Const{typeof(sort!)}, - RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, - xs::Duplicated{T}; - kwargs... - ) where {T <: AbstractArray{<:AbstractFloat}} +function EnzymeRules.forward(::Const{typeof(sort!)}, + RT::Type{<:Union{Const,DuplicatedNoNeed,Duplicated}}, + xs::Duplicated{T}; + kwargs...) where {T<:AbstractArray{<:AbstractFloat}} inds = sortperm(xs.val; kwargs...) xs.val .= xs.val[inds] xs.dval .= xs.dval[inds] @@ -578,12 +600,10 @@ function EnzymeRules.forward( end end -function EnzymeRules.forward( - ::Const{typeof(sort!)}, - RT::Type{<:Union{Const, BatchDuplicatedNoNeed, BatchDuplicated}}, - xs::BatchDuplicated{T, N}; - kwargs... - ) where {T <: AbstractArray{<:AbstractFloat}, N} +function EnzymeRules.forward(::Const{typeof(sort!)}, + RT::Type{<:Union{Const,BatchDuplicatedNoNeed,BatchDuplicated}}, + xs::BatchDuplicated{T,N}; + kwargs...) where {T<:AbstractArray{<:AbstractFloat},N} inds = sortperm(xs.val; kwargs...) xs.val .= xs.val[inds] for i in 1:N @@ -598,14 +618,11 @@ function EnzymeRules.forward( end end - -function EnzymeRules.augmented_primal( - config::EnzymeRules.ConfigWidth{1}, - ::Const{typeof(sort!)}, - RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, - xs::Duplicated{T}; - kwargs... - ) where {T <: AbstractArray{<:AbstractFloat}} +function EnzymeRules.augmented_primal(config::EnzymeRules.ConfigWidth{1}, + ::Const{typeof(sort!)}, + RT::Type{<:Union{Const,DuplicatedNoNeed,Duplicated}}, + xs::Duplicated{T}; + kwargs...) where {T<:AbstractArray{<:AbstractFloat}} inds = sortperm(xs.val; kwargs...) xs.val .= xs.val[inds] xs.dval .= xs.dval[inds] @@ -622,27 +639,23 @@ function EnzymeRules.augmented_primal( return EnzymeRules.AugmentedReturn(primal, shadow, inds) end -function EnzymeRules.reverse( - config::EnzymeRules.ConfigWidth{1}, - ::Const{typeof(sort!)}, - RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, - tape, - xs::Duplicated{T}; - kwargs..., - ) where {T <: AbstractArray{<:AbstractFloat}} +function EnzymeRules.reverse(config::EnzymeRules.ConfigWidth{1}, + ::Const{typeof(sort!)}, + RT::Type{<:Union{Const,DuplicatedNoNeed,Duplicated}}, + tape, + xs::Duplicated{T}; + kwargs...,) where {T<:AbstractArray{<:AbstractFloat}} inds = tape back_inds = sortperm(inds) xs.dval .= xs.dval[back_inds] return (nothing,) end -function EnzymeRules.forward( - ::Const{typeof(partialsort!)}, - RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, - xs::Duplicated{T}, - k::Const{<:Union{Integer, OrdinalRange}}; - kwargs... - ) where {T <: AbstractArray{<:AbstractFloat}} +function EnzymeRules.forward(::Const{typeof(partialsort!)}, + RT::Type{<:Union{Const,DuplicatedNoNeed,Duplicated}}, + xs::Duplicated{T}, + k::Const{<:Union{Integer,OrdinalRange}}; + kwargs...) where {T<:AbstractArray{<:AbstractFloat}} kv = k.val inds = collect(eachindex(xs.val)) partialsortperm!(inds, xs.val, kv; kwargs...) @@ -661,13 +674,11 @@ function EnzymeRules.forward( end end -function EnzymeRules.forward( - ::Const{typeof(partialsort!)}, - RT::Type{<:Union{Const, BatchDuplicatedNoNeed, BatchDuplicated}}, - xs::BatchDuplicated{T, N}, - k::Const{<:Union{Integer, OrdinalRange}}; - kwargs... - ) where {T <: AbstractArray{<:AbstractFloat}, N} +function EnzymeRules.forward(::Const{typeof(partialsort!)}, + RT::Type{<:Union{Const,BatchDuplicatedNoNeed,BatchDuplicated}}, + xs::BatchDuplicated{T,N}, + k::Const{<:Union{Integer,OrdinalRange}}; + kwargs...) where {T<:AbstractArray{<:AbstractFloat},N} kv = k.val inds = collect(eachindex(xs.val)) partialsortperm!(inds, xs.val, kv; kwargs...) @@ -692,14 +703,13 @@ function EnzymeRules.forward( end end -function EnzymeRules.augmented_primal( - config::EnzymeRules.ConfigWidth{1}, - ::Const{typeof(partialsort!)}, - RT::Type{<:Union{Const, Active, DuplicatedNoNeed, Duplicated}}, - xs::Duplicated{T}, - k::Const{<:Union{Integer, OrdinalRange}}; - kwargs... - ) where {T <: AbstractArray{<:AbstractFloat}} +function EnzymeRules.augmented_primal(config::EnzymeRules.ConfigWidth{1}, + ::Const{typeof(partialsort!)}, + RT::Type{<:Union{Const,Active,DuplicatedNoNeed, + Duplicated}}, + xs::Duplicated{T}, + k::Const{<:Union{Integer,OrdinalRange}}; + kwargs...) where {T<:AbstractArray{<:AbstractFloat}} kv = k.val inds = collect(eachindex(xs.val)) partialsortperm!(inds, xs.val, kv; kwargs...) @@ -718,15 +728,15 @@ function EnzymeRules.augmented_primal( return EnzymeRules.AugmentedReturn(primal, shadow, inds) end -function EnzymeRules.reverse( - config::EnzymeRules.ConfigWidth{1}, - ::Const{typeof(partialsort!)}, - dret::Union{Active, Type{<:Union{Const, Active, DuplicatedNoNeed, Duplicated}}}, - tape, - xs::Duplicated{T}, - k::Const{<:Union{Integer, OrdinalRange}}; - kwargs..., - ) where {T <: AbstractArray{<:AbstractFloat}} +function EnzymeRules.reverse(config::EnzymeRules.ConfigWidth{1}, + ::Const{typeof(partialsort!)}, + dret::Union{Active, + Type{<:Union{Const,Active,DuplicatedNoNeed, + Duplicated}}}, + tape, + xs::Duplicated{T}, + k::Const{<:Union{Integer,OrdinalRange}}; + kwargs...,) where {T<:AbstractArray{<:AbstractFloat}} inds = tape kv = k.val if dret isa Active @@ -751,7 +761,7 @@ function EnzymeRules.forward(::Const{typeof(cholesky)}, RT::Type, A; kwargs...) dA = if isa(A, Const) ntuple(Val(N)) do i Base.@_inline_meta - zero(A.val) + return zero(A.val) end else N == 1 ? (A.dval,) : A.dval @@ -798,13 +808,11 @@ function _cholesky_forward(C::Cholesky, Ȧ) end end -function EnzymeRules.forward( - func::Const{typeof(ldiv!)}, - RT::Type{<:Union{Const, Duplicated}}, - fact::Annotation{<:Cholesky}, - B::Annotation{<:AbstractVecOrMat}; - kwargs... -) +function EnzymeRules.forward(func::Const{typeof(ldiv!)}, + RT::Type{<:Union{Const,Duplicated}}, + fact::Annotation{<:Cholesky}, + B::Annotation{<:AbstractVecOrMat}; + kwargs...) if B isa Const return func.val(fact.val, B.val; kwargs...) else @@ -852,13 +860,13 @@ function _ldiv_Cholesky_forward!(L, U, B, dL, dU, dB) return B, dB end -function EnzymeRules.augmented_primal( - config, - func::Const{typeof(cholesky)}, - RT::Type, - A::Annotation{<:Union{Matrix,LinearAlgebra.RealHermSym{<:Real,<:Matrix}}}; - kwargs...) - +function EnzymeRules.augmented_primal(config, + func::Const{typeof(cholesky)}, + RT::Type, + A::Annotation{<:Union{Matrix, + LinearAlgebra.RealHermSym{<:Real, + <:Matrix}}}; + kwargs...) fact = if EnzymeRules.needs_primal(config) || !(RT <: Const) cholesky(A.val; kwargs...) else @@ -876,7 +884,7 @@ function EnzymeRules.augmented_primal( else ntuple(Val(EnzymeRules.width(config))) do i Base.@_inline_meta - Enzyme.make_zero(fact) + return Enzyme.make_zero(fact) end end end @@ -885,13 +893,14 @@ function EnzymeRules.augmented_primal( return EnzymeRules.AugmentedReturn(fact_returned, dfact, cache) end -function EnzymeRules.reverse( - config, - ::Const{typeof(cholesky)}, - RT::Type, - cache, - A::Annotation{<:Union{Matrix,LinearAlgebra.RealHermSym{<:Real,<:Matrix}}}; - kwargs...) +function EnzymeRules.reverse(config, + ::Const{typeof(cholesky)}, + RT::Type, + cache, + A::Annotation{<:Union{Matrix, + LinearAlgebra.RealHermSym{<:Real, + <:Matrix}}}; + kwargs...) if !(RT <: Const) && !isa(A, Const) fact, dfact = cache dAs = EnzymeRules.width(config) == 1 ? (A.dval,) : A.dval @@ -948,15 +957,15 @@ function _realifydiag!(A) return A end -function EnzymeRules.augmented_primal( - config, - func::Const{typeof(ldiv!)}, - RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated}}, - - A::Annotation{<:Cholesky}, - B::Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated}; - kwargs... -) +function EnzymeRules.augmented_primal(config, + func::Const{typeof(ldiv!)}, + RT::Type{<:Union{Const,DuplicatedNoNeed,Duplicated, + BatchDuplicatedNoNeed, + BatchDuplicated}}, + A::Annotation{<:Cholesky}, + B::Union{Const,DuplicatedNoNeed,Duplicated, + BatchDuplicatedNoNeed,BatchDuplicated}; + kwargs...) cache_B = if !isa(A, Const) && !isa(B, Const) copy(B.val) else @@ -975,16 +984,15 @@ function EnzymeRules.augmented_primal( return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_B)) end -function EnzymeRules.reverse( - config, - func::Const{typeof(ldiv!)}, - dret, - cache, - A::Annotation{<:Cholesky}, - B::Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated}; - kwargs... -) - if !isa(B, Const) +function EnzymeRules.reverse(config, + func::Const{typeof(ldiv!)}, + dret, + cache, + A::Annotation{<:Cholesky}, + B::Union{Const,DuplicatedNoNeed,Duplicated, + BatchDuplicatedNoNeed,BatchDuplicated}; + kwargs...) + if !isa(B, Const) (cache_A, cache_B) = cache U = cache_A.U Z = isa(A, Const) ? nothing : U' \ cache_B diff --git a/test/internal_rules.jl b/test/internal_rules.jl index 7bae91b971..048db4f681 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -7,7 +7,7 @@ using FiniteDifferences using LinearAlgebra using SparseArrays using Test -import Random +using Random: Random struct TPair a::Float64 @@ -18,7 +18,7 @@ function sorterrfn(t, x) function lt(a, b) return a.a < b.a end - return first(sortperm(t, lt=lt)) * x + return first(sortperm(t; lt=lt)) * x end @testset "Sort rules" begin @@ -29,10 +29,12 @@ end end @test autodiff(Forward, f1, Duplicated(2.0, 1.0))[1] == 1 - @test autodiff(Forward, f1, BatchDuplicated(2.0, (1.0, 2.0)))[1] == (var"1"=1.0, var"2"=2.0) + @test autodiff(Forward, f1, BatchDuplicated(2.0, (1.0, 2.0)))[1] == + (var"1"=1.0, var"2"=2.0) @test autodiff(Reverse, f1, Active, Active(2.0))[1][1] == 1 @test autodiff(Forward, f1, Duplicated(4.0, 1.0))[1] == 0 - @test autodiff(Forward, f1, BatchDuplicated(4.0, (1.0, 2.0)))[1] == (var"1"=0.0, var"2"=0.0) + @test autodiff(Forward, f1, BatchDuplicated(4.0, (1.0, 2.0)))[1] == + (var"1"=0.0, var"2"=0.0) @test autodiff(Reverse, f1, Active, Active(4.0))[1][1] == 0 function f2(x) @@ -42,7 +44,8 @@ end end @test autodiff(Forward, f2, Duplicated(2.0, 1.0))[1] == -3 - @test autodiff(Forward, f2, BatchDuplicated(2.0, (1.0, 2.0)))[1] == (var"1"=-3.0, var"2"=-6.0) + @test autodiff(Forward, f2, BatchDuplicated(2.0, (1.0, 2.0)))[1] == + (var"1"=-3.0, var"2"=-6.0) @test autodiff(Reverse, f2, Active, Active(2.0))[1][1] == -3 function f3(x) @@ -51,7 +54,8 @@ end end @test autodiff(Forward, f3, Duplicated(1.5, 1.0))[1] == 1.0 - @test autodiff(Forward, f3, BatchDuplicated(1.5, (1.0, 2.0)))[1] == (var"1"=1.0, var"2"=2.0) + @test autodiff(Forward, f3, BatchDuplicated(1.5, (1.0, 2.0)))[1] == + (var"1"=1.0, var"2"=2.0) @test autodiff(Reverse, f3, Active(1.5))[1][1] == 1.0 @test autodiff(Reverse, f3, Active(2.5))[1][1] == 0.0 @@ -63,13 +67,15 @@ end @test autodiff(Forward, f4, Duplicated(1.5, 1.0))[1] == 1.5 @static if VERSION < v"1.7-" || VERSION >= v"1.8-" - @test autodiff(Forward, f4, BatchDuplicated(1.5, (1.0, 2.0)))[1] == (var"1"=1.5, var"2"=3.0) + @test autodiff(Forward, f4, BatchDuplicated(1.5, (1.0, 2.0)))[1] == + (var"1"=1.5, var"2"=3.0) end @test autodiff(Reverse, f4, Active(1.5))[1][1] == 1.5 @test autodiff(Reverse, f4, Active(4.0))[1][1] == 0.5 @test autodiff(Reverse, f4, Active(6.0))[1][1] == 0.0 - dd = Duplicated([TPair(1, 2), TPair(2, 3), TPair(0, 1)], [TPair(0, 0), TPair(0, 0), TPair(0, 0)]) + dd = Duplicated([TPair(1, 2), TPair(2, 3), TPair(0, 1)], + [TPair(0, 0), TPair(0, 0), TPair(0, 0)]) res = Enzyme.autodiff(Reverse, sorterrfn, dd, Active(1.0)) @test res[1][2] ≈ 3 @@ -87,7 +93,9 @@ end b = Float64[11, 13] db = zero(b) - forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Duplicated{typeof(A)}, Duplicated{typeof(b)}) + forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, + Duplicated, Duplicated{typeof(A)}, + Duplicated{typeof(b)}) tape, primal, shadow = forward(Const(\), Duplicated(A, dA), Duplicated(b, db)) @@ -104,7 +112,9 @@ end db = zero(b) - forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Const{typeof(A)}, Duplicated{typeof(b)}) + forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, + Duplicated, Const{typeof(A)}, + Duplicated{typeof(b)}) tape, primal, shadow = forward(Const(\), Const(A), Duplicated(b, db)) @@ -120,7 +130,9 @@ end dA = zero(A) - forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Duplicated{typeof(A)}, Const{typeof(b)}) + forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, + Duplicated, Duplicated{typeof(A)}, + Const{typeof(b)}) tape, primal, shadow = forward(Const(\), Duplicated(A, dA), Const(b)) @@ -136,102 +148,109 @@ end end @static if VERSION > v"1.8" -@testset "cholesky" begin - @testset "with wrapper arguments" begin - @testset for Te in (Float64,), TS in (Symmetric, Hermitian), uplo in (:U, :L) - @testset for TA in (Const, Duplicated), Tret in (Const, Duplicated) - _A = collect(exp(TS(rand(Te, 4, 4)))) - A = TS(_A, uplo) - are_activities_compatible(Tret, TA) || continue - test_forward(cholesky, Tret, (A, TA)) - test_reverse(cholesky, Tret, (A, TA)) + @testset "cholesky" begin + @testset "with wrapper arguments" begin + @testset for Te in (Float64,), TS in (Symmetric, Hermitian), uplo in (:U, :L) + @testset for TA in (Const, Duplicated), Tret in (Const, Duplicated) + _A = collect(exp(TS(rand(Te, 4, 4)))) + A = TS(_A, uplo) + are_activities_compatible(Tret, TA) || continue + test_forward(cholesky, Tret, (A, TA)) + test_reverse(cholesky, Tret, (A, TA)) + end end end - end - @testset "without wrapper arguments" begin - _square(A) = A * A' - @testset for Te in (Float64,) - @testset for TA in (Const, Duplicated), Tret in (Const, Duplicated) - A = rand(Te, 4, 4) - are_activities_compatible(Tret, TA) || continue - test_forward(cholesky ∘ _square, Tret, (A, TA)) - test_reverse(cholesky ∘ _square, Tret, (A, TA)) + @testset "without wrapper arguments" begin + _square(A) = A * A' + @testset for Te in (Float64,) + @testset for TA in (Const, Duplicated), Tret in (Const, Duplicated) + A = rand(Te, 4, 4) + are_activities_compatible(Tret, TA) || continue + test_forward(cholesky ∘ _square, Tret, (A, TA)) + test_reverse(cholesky ∘ _square, Tret, (A, TA)) + end end end end -end -@testset "Linear solve for `Cholesky`" begin - @testset for Te in (Float64,) - A = exp(Symmetric(rand(Te, 4, 4))) - C = cholesky(A) - B = rand(Te, 4, 4) - b = rand(Te, 4) - @testset for TC in (Const, Duplicated), TB in (Const, Duplicated), - Tret in (Const, Duplicated) - @testset "$(size(_B))" for _B in (B, b) - are_activities_compatible(Tret, TC, TB) || continue - # Non-uniform activities are disabled due to unresolved questions - Tret == TC == TB && test_forward(\, Tret, (C, TC), (_B, TB)) - test_reverse(\, Tret, (C, TC), (_B, TB)) + @testset "Linear solve for `Cholesky`" begin + @testset for Te in (Float64,) + A = exp(Symmetric(rand(Te, 4, 4))) + C = cholesky(A) + B = rand(Te, 4, 4) + b = rand(Te, 4) + @testset for TC in (Const, Duplicated), TB in (Const, Duplicated), + Tret in (Const, Duplicated) + + @testset "$(size(_B))" for _B in (B, b) + are_activities_compatible(Tret, TC, TB) || continue + # Non-uniform activities are disabled due to unresolved questions + Tret == TC == TB && test_forward(\, Tret, (C, TC), (_B, TB)) + test_reverse(\, Tret, (C, TC), (_B, TB)) + end end - end - @testset for TC in (Const, Duplicated), TB in (Const, Duplicated), - Tret in (Const, Duplicated) - @testset "$(size(_B))" for _B in (B, b) - are_activities_compatible(Tret, TC, TB) || continue - # Non-uniform activities are disabled due to unresolved questions - Tret == TC == TB && test_forward(ldiv!, Tret, (C, TC), (_B, TB)) - Tret == TB && test_reverse(ldiv!, Tret, (C, TC), (_B, TB)) + @testset for TC in (Const, Duplicated), TB in (Const, Duplicated), + Tret in (Const, Duplicated) + + @testset "$(size(_B))" for _B in (B, b) + are_activities_compatible(Tret, TC, TB) || continue + # Non-uniform activities are disabled due to unresolved questions + Tret == TC == TB && test_forward(ldiv!, Tret, (C, TC), (_B, TB)) + Tret == TB && test_reverse(ldiv!, Tret, (C, TC), (_B, TB)) + end end end end -end -@testset "Linear solve for triangular matrices" begin - @testset for T in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular), - TE in (Float64, ComplexF64), sizeB in ((3,), (3, 3)) - n = sizeB[1] - M = rand(TE, n, n) - B = rand(TE, sizeB...) - Y = zeros(TE, sizeB...) - A = T(M) - @testset "test through constructor" begin - _A = T(A) - function f!(Y, A, B, ::T) where T - ldiv!(Y, T(A), B) - return nothing - end - for TY in (Const, Duplicated, BatchDuplicated), - TM in (Const, Duplicated, BatchDuplicated), - TB in (Const, Duplicated, BatchDuplicated) - are_activities_compatible(Const, TY, TM, TB) || continue - test_reverse(f!, Const, (Y, TY), (M, TM), (B, TB), (_A, Const)) + @testset "Linear solve for triangular matrices" begin + @testset for T in (UpperTriangular, LowerTriangular, UnitUpperTriangular, + UnitLowerTriangular), + TE in (Float64, ComplexF64), sizeB in ((3,), (3, 3)) + + n = sizeB[1] + M = rand(TE, n, n) + B = rand(TE, sizeB...) + Y = zeros(TE, sizeB...) + A = T(M) + @testset "test through constructor" begin + _A = T(A) + function f!(Y, A, B, ::T) where {T} + ldiv!(Y, T(A), B) + return nothing + end + for TY in (Const, Duplicated, BatchDuplicated), + TM in (Const, Duplicated, BatchDuplicated), + TB in (Const, Duplicated, BatchDuplicated) + + are_activities_compatible(Const, TY, TM, TB) || continue + test_reverse(f!, Const, (Y, TY), (M, TM), (B, TB), (_A, Const)) + end end - end - @testset "test through `Adjoint` wrapper (regression test for #1306)" begin - # Test that we get the same derivative for `M` as for the adjoint of its - # (materialized) transpose. It's the same matrix, but represented differently - function f!(Y, A, B) - ldiv!(Y, A, B) - return nothing + @testset "test through `Adjoint` wrapper (regression test for #1306)" begin + # Test that we get the same derivative for `M` as for the adjoint of its + # (materialized) transpose. It's the same matrix, but represented differently + function f!(Y, A, B) + ldiv!(Y, A, B) + return nothing + end + A1 = T(M) + A2 = T(conj(permutedims(M))') + dA1 = make_zero(A1) + dA2 = make_zero(A2) + dB1 = make_zero(B) + dB2 = make_zero(B) + dY1 = rand(TE, sizeB...) + dY2 = copy(dY1) + autodiff(Reverse, f!, Duplicated(Y, dY1), Duplicated(A1, dA1), + Duplicated(B, dB1)) + autodiff(Reverse, f!, Duplicated(Y, dY2), Duplicated(A2, dA2), + Duplicated(B, dB2)) + @test dA1.data ≈ dA2.data + @test dB1 ≈ dB2 end - A1 = T(M) - A2 = T(conj(permutedims(M))') - dA1 = make_zero(A1) - dA2 = make_zero(A2) - dB1 = make_zero(B) - dB2 = make_zero(B) - dY1 = rand(TE, sizeB...) - dY2 = copy(dY1) - autodiff(Reverse, f!, Duplicated(Y, dY1), Duplicated(A1, dA1), Duplicated(B, dB1)) - autodiff(Reverse, f!, Duplicated(Y, dY2), Duplicated(A2, dA2), Duplicated(B, dB2)) - @test dA1.data ≈ dA2.data - @test dB1 ≈ dB2 end end end -end @testset "rand and randn rules" begin # Distributed as x + unit normal + uniform @@ -243,7 +262,8 @@ end Random.rand(d::MyDistribution) = rand(Random.default_rng(), d) # Outer rand should be differentiated through, and inner rand and randn should be ignored. - @test autodiff(Enzyme.Reverse, x -> rand(MyDistribution(x)), Active, Active(1.0)) == ((1.0,),) + @test autodiff(Enzyme.Reverse, x -> rand(MyDistribution(x)), Active, Active(1.0)) == + ((1.0,),) end end # InternalRules From 4f216cca06a1ef4ddf981a544dff9aca4cb0ea3e Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Fri, 3 May 2024 09:58:52 +0200 Subject: [PATCH 43/61] Revert "Run formatter" This reverts commit 3f8ee74dae64a28fff1003884cc24907cff09440. --- src/internal_rules.jl | 502 ++++++++++++++++++++--------------------- test/internal_rules.jl | 204 ++++++++--------- 2 files changed, 339 insertions(+), 367 deletions(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 29be9c93a8..d6e862ff6e 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -69,8 +69,7 @@ end function EnzymeRules.inactive(::typeof(Random.rand), ::Random.AbstractRNG, ::Random.Sampler) return nothing end -function EnzymeRules.inactive(::typeof(Random.rand!), ::Random.AbstractRNG, - ::Random.Sampler, ::AbstractArray) +function EnzymeRules.inactive(::typeof(Random.rand!), ::Random.AbstractRNG, ::Random.Sampler, ::AbstractArray) return nothing end function EnzymeRules.inactive(::typeof(Random.randn), args...) @@ -103,8 +102,7 @@ end function EnzymeRules.inactive_noinl(::typeof(Base.size), args...) return nothing end -function EnzymeRules.inactive_noinl(::typeof(Base.setindex!), ::IdDict{K,V}, ::K, - ::V) where {K,V<:Integer} +function EnzymeRules.inactive_noinl(::typeof(Base.setindex!), ::IdDict{K, V}, ::K, ::V) where {K, V <:Integer} return nothing end @@ -122,41 +120,35 @@ end @inline EnzymeRules.inactive_type(v::Type{T}) where {T<:AbstractString} = true @inline width(::Duplicated) = 1 -@inline width(::BatchDuplicated{T,N}) where {T,N} = N +@inline width(::BatchDuplicated{T, N}) where {T, N} = N @inline width(::DuplicatedNoNeed) = 1 -@inline width(::BatchDuplicatedNoNeed{T,N}) where {T,N} = N +@inline width(::BatchDuplicatedNoNeed{T, N}) where {T, N} = N -@inline width(::Type{Duplicated{T}}) where {T} = 1 -@inline width(::Type{BatchDuplicated{T,N}}) where {T,N} = N -@inline width(::Type{DuplicatedNoNeed{T}}) where {T} = 1 -@inline width(::Type{BatchDuplicatedNoNeed{T,N}}) where {T,N} = N +@inline width(::Type{Duplicated{T}}) where T = 1 +@inline width(::Type{BatchDuplicated{T, N}}) where {T, N} = N +@inline width(::Type{DuplicatedNoNeed{T}}) where T = 1 +@inline width(::Type{BatchDuplicatedNoNeed{T, N}}) where {T, N} = N # Note all of these forward mode definitions do not support runtime activity as # the do not keep the primal if shadow(x.y) == primal(x.y) -function EnzymeRules.forward(::Const{typeof(Base.deepcopy)}, ::Type{<:DuplicatedNoNeed}, - x::Duplicated) +function EnzymeRules.forward(::Const{typeof(Base.deepcopy)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated) return deepcopy(x.dval) end -function EnzymeRules.forward(::Const{typeof(Base.deepcopy)}, - ::Type{<:BatchDuplicatedNoNeed}, - x::BatchDuplicated{T,N}) where {T,N} +function EnzymeRules.forward(::Const{typeof(Base.deepcopy)}, ::Type{<:BatchDuplicatedNoNeed}, x::BatchDuplicated{T, N}) where {T, N} ntuple(Val(N)) do _ - return deepcopy(x.dval) + deepcopy(x.dval) end end # Deepcopy preserving the primal if runtime inactive -@inline function deepcopy_rtact(copied::RT, primal::RT, seen::IdDict, - shadow::RT) where {RT<:Union{Integer,Char}} +@inline function deepcopy_rtact(copied::RT, primal::RT, seen::IdDict, shadow::RT) where {RT <: Union{Integer, Char}} return Base.deepcopy_internal(shadow, seen) end -@inline function deepcopy_rtact(copied::RT, primal::RT, seen::IdDict, - shadow::RT) where {RT<:AbstractFloat} +@inline function deepcopy_rtact(copied::RT, primal::RT, seen::IdDict, shadow::RT) where {RT <: AbstractFloat} return Base.deepcopy_internal(shadow, seen) end -@inline function deepcopy_rtact(copied::RT, primal::RT, seen::IdDict, - shadow::RT) where {RT<:Array} +@inline function deepcopy_rtact(copied::RT, primal::RT, seen::IdDict, shadow::RT) where {RT <: Array} if !haskey(seen, shadow) if primal === shadow return seen[shadow] = copied @@ -170,23 +162,19 @@ end return seen[shadow] end -function EnzymeRules.forward(func::Const{typeof(Base.deepcopy)}, ::Type{<:Duplicated}, - x::Duplicated) +function EnzymeRules.forward(func::Const{typeof(Base.deepcopy)}, ::Type{<:Duplicated}, x::Duplicated) primal = func.val(x.val) return Duplicated(primal, deepcopy_rtact(primal, x.val, IdDict(), x.dval)) end -function EnzymeRules.forward(func::Const{typeof(Base.deepcopy)}, ::Type{<:BatchDuplicated}, - x::BatchDuplicated{T,N}) where {T,N} +function EnzymeRules.forward(func::Const{typeof(Base.deepcopy)}, ::Type{<:BatchDuplicated}, x::BatchDuplicated{T, N}) where {T,N} primal = func.val(x.val) - return BatchDuplicated(primal, - ntuple(Val(N)) do i - return deepcopy_rtact(primal, x.val, IdDict(), x.dval[i]) - end) + return BatchDuplicated(primal, ntuple(Val(N)) do i + deepcopy_rtact(primal, x.val, IdDict(), x.dval[i]) + end) end -function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.deepcopy)}, - ::Type{RT}, x::Annotation{Ty}) where {RT,Ty} +function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.deepcopy)}, ::Type{RT}, x::Annotation{Ty}) where {RT, Ty} primal = if EnzymeRules.needs_primal(config) func.val(x.val) else @@ -203,8 +191,9 @@ function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.deepcopy)} shadow = ntuple(Val(EnzymeRules.width(config))) do _ Base.@_inline_meta - return Enzyme.make_zero(source, - Val(!EnzymeRules.needs_primal(config))) + Enzyme.make_zero(source, + #=copy_if_inactive=#Val(!EnzymeRules.needs_primal(config)) + ) end if EnzymeRules.width(config) == 1 @@ -214,8 +203,8 @@ function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.deepcopy)} return EnzymeRules.AugmentedReturn(primal, shadow, shadow) end -@inline function accumulate_into(into::RT, seen::IdDict, - from::RT)::Tuple{RT,RT} where {RT<:Array} + +@inline function accumulate_into(into::RT, seen::IdDict, from::RT)::Tuple{RT,RT} where {RT<:Array} if Enzyme.Compiler.guaranteed_const(RT) return (into, from) end @@ -230,10 +219,9 @@ end return seen[into] end -@inline function accumulate_into(into::RT, seen::IdDict, - from::RT)::Tuple{RT,RT} where {RT<:AbstractFloat} +@inline function accumulate_into(into::RT, seen::IdDict, from::RT)::Tuple{RT,RT} where {RT<:AbstractFloat} if !haskey(seen, into) - seen[into] = (into + from, RT(0)) + seen[into] = (into+from, RT(0)) end return seen[into] end @@ -248,8 +236,7 @@ end return seen[into] end -function EnzymeRules.reverse(config, func::Const{typeof(Base.deepcopy)}, ::Type{RT}, shadow, - x::Annotation{Ty}) where {RT,Ty} +function EnzymeRules.reverse(config, func::Const{typeof(Base.deepcopy)}, ::Type{RT}, shadow, x::Annotation{Ty}) where {RT, Ty} if EnzymeRules.width(config) == 1 accumulate_into(x.dval, IdDict(), shadow) else @@ -261,53 +248,43 @@ function EnzymeRules.reverse(config, func::Const{typeof(Base.deepcopy)}, ::Type{ return (nothing,) end -@inline function pmap_fwd(idx, tapes::Vector, thunk::ThunkTy, f::F, - fargs::Vararg{Annotation,N}) where {ThunkTy,F,N} +@inline function pmap_fwd(idx, tapes::Vector, thunk::ThunkTy, f::F, fargs::Vararg{Annotation, N}) where {ThunkTy, F, N} @inbounds tapes[idx] = thunk(f, Const(idx), fargs...)[1] end -@inline function pmap_fwd(idx, tapes::Ptr, thunk::ThunkTy, f::F, - fargs::Vararg{Annotation,N}) where {ThunkTy,F,N} - return unsafe_store!(tapes, thunk(f, Const(idx), fargs...)[1], idx) +@inline function pmap_fwd(idx, tapes::Ptr, thunk::ThunkTy, f::F, fargs::Vararg{Annotation, N}) where {ThunkTy, F, N} + unsafe_store!(tapes, thunk(f, Const(idx), fargs...)[1], idx) end -function EnzymeRules.augmented_primal(config, func::Const{typeof(Enzyme.pmap)}, - ::Type{Const{Nothing}}, body::BodyTy, count, - args::Vararg{Annotation,N}) where {BodyTy,N} - config2 = ReverseModeSplit{false,false,EnzymeRules.width(config), - EnzymeRules.overwritten(config)[2:end],InlineABI}() - fwd_thunk, rev_thunk = autodiff_thunk(config2, BodyTy, Const, typeof(count), - map(typeof, args)...) +function EnzymeRules.augmented_primal(config, func::Const{typeof(Enzyme.pmap)}, ::Type{Const{Nothing}}, body::BodyTy, count, args::Vararg{Annotation, N}) where {BodyTy, N} + + config2 = ReverseModeSplit{false, false, EnzymeRules.width(config), EnzymeRules.overwritten(config)[2:end],InlineABI}() + fwd_thunk, rev_thunk = autodiff_thunk(config2, BodyTy, Const, typeof(count), map(typeof, args)...) TapeType = EnzymeRules.tape_type(fwd_thunk) tapes = if Enzyme.Compiler.any_jltypes(TapeType) Vector{TapeType}(undef, count.val) else - Base.unsafe_convert(Ptr{TapeType}, Libc.malloc(sizeof(TapeType) * count.val)) + Base.unsafe_convert(Ptr{TapeType}, Libc.malloc(sizeof(TapeType)*count.val)) end Enzyme.pmap(pmap_fwd, count.val, tapes, fwd_thunk, body, args...) return EnzymeRules.AugmentedReturn(nothing, nothing, tapes) end -@inline function pmap_rev(idx, tapes::Vector, thunk::ThunkTy, f::F, - fargs::Vararg{Annotation,N}) where {ThunkTy,F,N} - return thunk(f, Const(idx), fargs..., @inbounds tapes[idx]) +@inline function pmap_rev(idx, tapes::Vector, thunk::ThunkTy, f::F, fargs::Vararg{Annotation, N}) where {ThunkTy, F, N} + thunk(f, Const(idx), fargs..., @inbounds tapes[idx]) end -@inline function pmap_rev(idx, tapes::Ptr, thunk::ThunkTy, f::F, - fargs::Vararg{Annotation,N}) where {ThunkTy,F,N} - return thunk(f, Const(idx), fargs..., unsafe_load(tapes, idx)) +@inline function pmap_rev(idx, tapes::Ptr, thunk::ThunkTy, f::F, fargs::Vararg{Annotation, N}) where {ThunkTy, F, N} + thunk(f, Const(idx), fargs..., unsafe_load(tapes, idx)) end -function EnzymeRules.reverse(config, func::Const{typeof(Enzyme.pmap)}, - ::Type{Const{Nothing}}, tapes, body::BodyTy, count, - args::Vararg{Annotation,N}) where {BodyTy,N} - config2 = ReverseModeSplit{false,false,EnzymeRules.width(config), - EnzymeRules.overwritten(config)[2:end],InlineABI}() - fwd_thunk, rev_thunk = autodiff_thunk(config2, BodyTy, Const, typeof(count), - map(typeof, args)...) +function EnzymeRules.reverse(config, func::Const{typeof(Enzyme.pmap)}, ::Type{Const{Nothing}}, tapes, body::BodyTy, count, args::Vararg{Annotation, N}) where {BodyTy, N} + + config2 = ReverseModeSplit{false, false, EnzymeRules.width(config), EnzymeRules.overwritten(config)[2:end],InlineABI}() + fwd_thunk, rev_thunk = autodiff_thunk(config2, BodyTy, Const, typeof(count), map(typeof, args)...) Enzyme.pmap(pmap_rev, count.val, tapes, rev_thunk, body, args...) @@ -317,14 +294,16 @@ function EnzymeRules.reverse(config, func::Const{typeof(Enzyme.pmap)}, Libc.free(tapes) end - return ntuple(Val(2 + length(args))) do _ + return ntuple(Val(2+length(args))) do _ Base.@_inline_meta - return nothing + nothing end end + + # From LinearAlgebra ~/.julia/juliaup/julia-1.10.0-beta3+0.x64.apple.darwin14/share/julia/stdlib/v1.10/LinearAlgebra/src/generic.jl:1110 -@inline function compute_lu_cache(cache_A::AT, b::BT) where {AT,BT} +@inline function compute_lu_cache(cache_A::AT, b::BT) where {AT, BT} LinearAlgebra.require_one_based_indexing(cache_A, b) m, n = size(cache_A) @@ -347,9 +326,8 @@ end # y=inv(A) B # dA −= z y^T # dB += z, where z = inv(A^T) dy -function EnzymeRules.augmented_primal(config, func::Const{typeof(\)}, ::Type{RT}, - A::Annotation{AT}, - b::Annotation{BT}) where {RT,AT<:Array,BT<:Array} +function EnzymeRules.augmented_primal(config, func::Const{typeof(\)}, ::Type{RT}, A::Annotation{AT}, b::Annotation{BT}) where {RT, AT <: Array, BT <: Array} + cache_A = if EnzymeRules.overwritten(config)[2] copy(A.val) else @@ -365,7 +343,7 @@ function EnzymeRules.augmented_primal(config, func::Const{typeof(\)}, ::Type{RT} else ntuple(Val(EnzymeRules.width(config))) do i Base.@_inline_meta - return zero(res) + zero(res) end end @@ -387,32 +365,33 @@ function EnzymeRules.augmented_primal(config, func::Const{typeof(\)}, ::Type{RT} nothing end - @static if VERSION < v"1.8.0" - UT = Union{LinearAlgebra.Diagonal{eltype(AT),BT}, - LinearAlgebra.LowerTriangular{eltype(AT),AT}, - LinearAlgebra.UpperTriangular{eltype(AT),AT}, - LinearAlgebra.LU{eltype(AT),AT}, - LinearAlgebra.QRCompactWY{eltype(AT),AT}} - else - UT = Union{LinearAlgebra.Diagonal{eltype(AT),BT}, - LinearAlgebra.LowerTriangular{eltype(AT),AT}, - LinearAlgebra.UpperTriangular{eltype(AT),AT}, - LinearAlgebra.LU{eltype(AT),AT,Vector{Int}}, - LinearAlgebra.QRPivoted{eltype(AT),AT,BT,Vector{Int}}} - end +@static if VERSION < v"1.8.0" + UT = Union{ + LinearAlgebra.Diagonal{eltype(AT), BT}, + LinearAlgebra.LowerTriangular{eltype(AT), AT}, + LinearAlgebra.UpperTriangular{eltype(AT), AT}, + LinearAlgebra.LU{eltype(AT), AT}, + LinearAlgebra.QRCompactWY{eltype(AT), AT} + } +else + UT = Union{ + LinearAlgebra.Diagonal{eltype(AT), BT}, + LinearAlgebra.LowerTriangular{eltype(AT), AT}, + LinearAlgebra.UpperTriangular{eltype(AT), AT}, + LinearAlgebra.LU{eltype(AT), AT, Vector{Int}}, + LinearAlgebra.QRPivoted{eltype(AT), AT, BT, Vector{Int}} + } +end - cache = NamedTuple{(Symbol("1"), Symbol("2"), Symbol("3"), Symbol("4")), - Tuple{typeof(res),typeof(dres),UT,typeof(cache_b)}}((cache_res, dres, - cache_A, - cache_b)) + cache = NamedTuple{(Symbol("1"),Symbol("2"), Symbol("3"), Symbol("4")), Tuple{typeof(res), typeof(dres), UT, typeof(cache_b)}}( + (cache_res, dres, cache_A, cache_b) + ) - return EnzymeRules.AugmentedReturn{typeof(retres),typeof(dres),typeof(cache)}(retres, - dres, - cache) + return EnzymeRules.AugmentedReturn{typeof(retres), typeof(dres), typeof(cache)}(retres, dres, cache) end -function EnzymeRules.reverse(config, func::Const{typeof(\)}, ::Type{RT}, cache, - A::Annotation{<:Array}, b::Annotation{<:Array}) where {RT} +function EnzymeRules.reverse(config, func::Const{typeof(\)}, ::Type{RT}, cache, A::Annotation{<:Array}, b::Annotation{<:Array}) where RT + y, dys, cache_A, cache_b = cache if !EnzymeRules.overwritten(config)[3] @@ -433,7 +412,7 @@ function EnzymeRules.reverse(config, func::Const{typeof(\)}, ::Type{RT}, cache, if typeof(A) <: Const ntuple(Val(EnzymeRules.width(config))) do i Base.@_inline_meta - return nothing + nothing end else A.dval @@ -450,7 +429,7 @@ function EnzymeRules.reverse(config, func::Const{typeof(\)}, ::Type{RT}, cache, if typeof(b) <: Const ntuple(Val(EnzymeRules.width(config))) do i Base.@_inline_meta - return nothing + nothing end else b.dval @@ -468,22 +447,24 @@ function EnzymeRules.reverse(config, func::Const{typeof(\)}, ::Type{RT}, cache, dy .= eltype(dy)(0) end - return (nothing, nothing) + return (nothing,nothing) end -const EnzymeTriangulars = Union{UpperTriangular, - LowerTriangular, - UnitUpperTriangular, - UnitLowerTriangular} +const EnzymeTriangulars = Union{ + UpperTriangular, + LowerTriangular, + UnitUpperTriangular, + UnitLowerTriangular +} -function EnzymeRules.augmented_primal(config, - func::Const{typeof(ldiv!)}, - ::Type{RT}, - Y::Annotation{YT}, - A::Annotation{AT}, - B::Annotation{BT}) where {RT,YT<:Array, - AT<:EnzymeTriangulars, - BT<:Array} +function EnzymeRules.augmented_primal( + config, + func::Const{typeof(ldiv!)}, + ::Type{RT}, + Y::Annotation{YT}, + A::Annotation{AT}, + B::Annotation{BT} +) where {RT, YT <: Array, AT <: EnzymeTriangulars, BT <: Array} cache_Y = EnzymeRules.overwritten(config)[1] ? copy(Y.val) : Y.val cache_A = EnzymeRules.overwritten(config)[2] ? copy(A.val) : A.val cache_A = compute_lu_cache(cache_A, B.val) @@ -491,19 +472,19 @@ function EnzymeRules.augmented_primal(config, primal = EnzymeRules.needs_primal(config) ? Y.val : nothing shadow = EnzymeRules.needs_shadow(config) ? Y.dval : nothing func.val(Y.val, A.val, B.val) - return EnzymeRules.AugmentedReturn{typeof(primal),typeof(shadow),Any}(primal, shadow, - (cache_Y, cache_A, - cache_B)) -end - -function EnzymeRules.reverse(config, - func::Const{typeof(ldiv!)}, - ::Type{RT}, - cache, - Y::Annotation{YT}, - A::Annotation{AT}, - B::Annotation{BT}) where {YT<:Array,RT,AT<:EnzymeTriangulars, - BT<:Array} + return EnzymeRules.AugmentedReturn{typeof(primal), typeof(shadow), Any}( + primal, shadow, (cache_Y, cache_A, cache_B)) +end + +function EnzymeRules.reverse( + config, + func::Const{typeof(ldiv!)}, + ::Type{RT}, + cache, + Y::Annotation{YT}, + A::Annotation{AT}, + B::Annotation{BT} +) where {YT <: Array, RT, AT <: EnzymeTriangulars, BT <: Array} if !isa(Y, Const) (cache_Yout, cache_A, cache_B) = cache for b in 1:EnzymeRules.width(config) @@ -529,65 +510,62 @@ _zero_unused_elements!(X, ::UnitUpperTriangular) = triu!(X, 1) _zero_unused_elements!(X, ::UnitLowerTriangular) = tril!(X, -1) @static if VERSION >= v"1.7-" - # Force a rule around hvcat_fill as it is type unstable if the tuple is not of the same type (e.g., int, float, int, float) - function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.hvcat_fill!)}, - ::Type{RT}, out::Annotation{AT}, - inp::Annotation{BT}) where {RT,AT<:Array, - BT<:Tuple} - primal = if EnzymeRules.needs_primal(config) - out.val - else - nothing - end - shadow = if EnzymeRules.needs_shadow(config) +# Force a rule around hvcat_fill as it is type unstable if the tuple is not of the same type (e.g., int, float, int, float) +function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.hvcat_fill!)}, ::Type{RT}, out::Annotation{AT}, inp::Annotation{BT}) where {RT, AT <: Array, BT <: Tuple} + primal = if EnzymeRules.needs_primal(config) + out.val + else + nothing + end + shadow = if EnzymeRules.needs_shadow(config) + out.dval + else + nothing + end + func.val(out.val, inp.val) + return EnzymeRules.AugmentedReturn(primal, shadow, nothing) +end + +function EnzymeRules.reverse(config, func::Const{typeof(Base.hvcat_fill!)}, ::Type{RT}, _, out::Annotation{AT}, inp::Annotation{BT}) where {RT, AT <: Array, BT <: Tuple} + nr, nc = size(out.val,1), size(out.val,2) + for b in 1:EnzymeRules.width(config) + da = if EnzymeRules.width(config) == 1 out.dval else - nothing + out.dval[b] end - func.val(out.val, inp.val) - return EnzymeRules.AugmentedReturn(primal, shadow, nothing) - end - - function EnzymeRules.reverse(config, func::Const{typeof(Base.hvcat_fill!)}, ::Type{RT}, - _, out::Annotation{AT}, - inp::Annotation{BT}) where {RT,AT<:Array,BT<:Tuple} - nr, nc = size(out.val, 1), size(out.val, 2) - for b in 1:EnzymeRules.width(config) - da = if EnzymeRules.width(config) == 1 - out.dval - else - out.dval[b] - end - i = 1 - j = 1 - if (typeof(inp) <: Active) - dinp = ntuple(Val(length(inp.val))) do k - Base.@_inline_meta - res = da[i, j] - da[i, j] = 0 - j += 1 - if j == nc + 1 - i += 1 - j = 1 - end - T = BT.parameters[k] - if T <: AbstractFloat - T(res) - else - T(0) - end + i = 1 + j = 1 + if (typeof(inp) <: Active) + dinp = ntuple(Val(length(inp.val))) do k + Base.@_inline_meta + res = da[i, j] + da[i, j] = 0 + j += 1 + if j == nc+1 + i += 1 + j = 1 + end + T = BT.parameters[k] + if T <: AbstractFloat + T(res) + else + T(0) end - return (nothing, dinp)::Tuple{Nothing,BT} end + return (nothing, dinp)::Tuple{Nothing, BT} end - return (nothing, nothing) end + return (nothing, nothing) +end end -function EnzymeRules.forward(::Const{typeof(sort!)}, - RT::Type{<:Union{Const,DuplicatedNoNeed,Duplicated}}, - xs::Duplicated{T}; - kwargs...) where {T<:AbstractArray{<:AbstractFloat}} +function EnzymeRules.forward( + ::Const{typeof(sort!)}, + RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, + xs::Duplicated{T}; + kwargs... + ) where {T <: AbstractArray{<:AbstractFloat}} inds = sortperm(xs.val; kwargs...) xs.val .= xs.val[inds] xs.dval .= xs.dval[inds] @@ -600,10 +578,12 @@ function EnzymeRules.forward(::Const{typeof(sort!)}, end end -function EnzymeRules.forward(::Const{typeof(sort!)}, - RT::Type{<:Union{Const,BatchDuplicatedNoNeed,BatchDuplicated}}, - xs::BatchDuplicated{T,N}; - kwargs...) where {T<:AbstractArray{<:AbstractFloat},N} +function EnzymeRules.forward( + ::Const{typeof(sort!)}, + RT::Type{<:Union{Const, BatchDuplicatedNoNeed, BatchDuplicated}}, + xs::BatchDuplicated{T, N}; + kwargs... + ) where {T <: AbstractArray{<:AbstractFloat}, N} inds = sortperm(xs.val; kwargs...) xs.val .= xs.val[inds] for i in 1:N @@ -618,11 +598,14 @@ function EnzymeRules.forward(::Const{typeof(sort!)}, end end -function EnzymeRules.augmented_primal(config::EnzymeRules.ConfigWidth{1}, - ::Const{typeof(sort!)}, - RT::Type{<:Union{Const,DuplicatedNoNeed,Duplicated}}, - xs::Duplicated{T}; - kwargs...) where {T<:AbstractArray{<:AbstractFloat}} + +function EnzymeRules.augmented_primal( + config::EnzymeRules.ConfigWidth{1}, + ::Const{typeof(sort!)}, + RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, + xs::Duplicated{T}; + kwargs... + ) where {T <: AbstractArray{<:AbstractFloat}} inds = sortperm(xs.val; kwargs...) xs.val .= xs.val[inds] xs.dval .= xs.dval[inds] @@ -639,23 +622,27 @@ function EnzymeRules.augmented_primal(config::EnzymeRules.ConfigWidth{1}, return EnzymeRules.AugmentedReturn(primal, shadow, inds) end -function EnzymeRules.reverse(config::EnzymeRules.ConfigWidth{1}, - ::Const{typeof(sort!)}, - RT::Type{<:Union{Const,DuplicatedNoNeed,Duplicated}}, - tape, - xs::Duplicated{T}; - kwargs...,) where {T<:AbstractArray{<:AbstractFloat}} +function EnzymeRules.reverse( + config::EnzymeRules.ConfigWidth{1}, + ::Const{typeof(sort!)}, + RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, + tape, + xs::Duplicated{T}; + kwargs..., + ) where {T <: AbstractArray{<:AbstractFloat}} inds = tape back_inds = sortperm(inds) xs.dval .= xs.dval[back_inds] return (nothing,) end -function EnzymeRules.forward(::Const{typeof(partialsort!)}, - RT::Type{<:Union{Const,DuplicatedNoNeed,Duplicated}}, - xs::Duplicated{T}, - k::Const{<:Union{Integer,OrdinalRange}}; - kwargs...) where {T<:AbstractArray{<:AbstractFloat}} +function EnzymeRules.forward( + ::Const{typeof(partialsort!)}, + RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, + xs::Duplicated{T}, + k::Const{<:Union{Integer, OrdinalRange}}; + kwargs... + ) where {T <: AbstractArray{<:AbstractFloat}} kv = k.val inds = collect(eachindex(xs.val)) partialsortperm!(inds, xs.val, kv; kwargs...) @@ -674,11 +661,13 @@ function EnzymeRules.forward(::Const{typeof(partialsort!)}, end end -function EnzymeRules.forward(::Const{typeof(partialsort!)}, - RT::Type{<:Union{Const,BatchDuplicatedNoNeed,BatchDuplicated}}, - xs::BatchDuplicated{T,N}, - k::Const{<:Union{Integer,OrdinalRange}}; - kwargs...) where {T<:AbstractArray{<:AbstractFloat},N} +function EnzymeRules.forward( + ::Const{typeof(partialsort!)}, + RT::Type{<:Union{Const, BatchDuplicatedNoNeed, BatchDuplicated}}, + xs::BatchDuplicated{T, N}, + k::Const{<:Union{Integer, OrdinalRange}}; + kwargs... + ) where {T <: AbstractArray{<:AbstractFloat}, N} kv = k.val inds = collect(eachindex(xs.val)) partialsortperm!(inds, xs.val, kv; kwargs...) @@ -703,13 +692,14 @@ function EnzymeRules.forward(::Const{typeof(partialsort!)}, end end -function EnzymeRules.augmented_primal(config::EnzymeRules.ConfigWidth{1}, - ::Const{typeof(partialsort!)}, - RT::Type{<:Union{Const,Active,DuplicatedNoNeed, - Duplicated}}, - xs::Duplicated{T}, - k::Const{<:Union{Integer,OrdinalRange}}; - kwargs...) where {T<:AbstractArray{<:AbstractFloat}} +function EnzymeRules.augmented_primal( + config::EnzymeRules.ConfigWidth{1}, + ::Const{typeof(partialsort!)}, + RT::Type{<:Union{Const, Active, DuplicatedNoNeed, Duplicated}}, + xs::Duplicated{T}, + k::Const{<:Union{Integer, OrdinalRange}}; + kwargs... + ) where {T <: AbstractArray{<:AbstractFloat}} kv = k.val inds = collect(eachindex(xs.val)) partialsortperm!(inds, xs.val, kv; kwargs...) @@ -728,15 +718,15 @@ function EnzymeRules.augmented_primal(config::EnzymeRules.ConfigWidth{1}, return EnzymeRules.AugmentedReturn(primal, shadow, inds) end -function EnzymeRules.reverse(config::EnzymeRules.ConfigWidth{1}, - ::Const{typeof(partialsort!)}, - dret::Union{Active, - Type{<:Union{Const,Active,DuplicatedNoNeed, - Duplicated}}}, - tape, - xs::Duplicated{T}, - k::Const{<:Union{Integer,OrdinalRange}}; - kwargs...,) where {T<:AbstractArray{<:AbstractFloat}} +function EnzymeRules.reverse( + config::EnzymeRules.ConfigWidth{1}, + ::Const{typeof(partialsort!)}, + dret::Union{Active, Type{<:Union{Const, Active, DuplicatedNoNeed, Duplicated}}}, + tape, + xs::Duplicated{T}, + k::Const{<:Union{Integer, OrdinalRange}}; + kwargs..., + ) where {T <: AbstractArray{<:AbstractFloat}} inds = tape kv = k.val if dret isa Active @@ -761,7 +751,7 @@ function EnzymeRules.forward(::Const{typeof(cholesky)}, RT::Type, A; kwargs...) dA = if isa(A, Const) ntuple(Val(N)) do i Base.@_inline_meta - return zero(A.val) + zero(A.val) end else N == 1 ? (A.dval,) : A.dval @@ -808,11 +798,13 @@ function _cholesky_forward(C::Cholesky, Ȧ) end end -function EnzymeRules.forward(func::Const{typeof(ldiv!)}, - RT::Type{<:Union{Const,Duplicated}}, - fact::Annotation{<:Cholesky}, - B::Annotation{<:AbstractVecOrMat}; - kwargs...) +function EnzymeRules.forward( + func::Const{typeof(ldiv!)}, + RT::Type{<:Union{Const, Duplicated}}, + fact::Annotation{<:Cholesky}, + B::Annotation{<:AbstractVecOrMat}; + kwargs... +) if B isa Const return func.val(fact.val, B.val; kwargs...) else @@ -860,13 +852,13 @@ function _ldiv_Cholesky_forward!(L, U, B, dL, dU, dB) return B, dB end -function EnzymeRules.augmented_primal(config, - func::Const{typeof(cholesky)}, - RT::Type, - A::Annotation{<:Union{Matrix, - LinearAlgebra.RealHermSym{<:Real, - <:Matrix}}}; - kwargs...) +function EnzymeRules.augmented_primal( + config, + func::Const{typeof(cholesky)}, + RT::Type, + A::Annotation{<:Union{Matrix,LinearAlgebra.RealHermSym{<:Real,<:Matrix}}}; + kwargs...) + fact = if EnzymeRules.needs_primal(config) || !(RT <: Const) cholesky(A.val; kwargs...) else @@ -884,7 +876,7 @@ function EnzymeRules.augmented_primal(config, else ntuple(Val(EnzymeRules.width(config))) do i Base.@_inline_meta - return Enzyme.make_zero(fact) + Enzyme.make_zero(fact) end end end @@ -893,14 +885,13 @@ function EnzymeRules.augmented_primal(config, return EnzymeRules.AugmentedReturn(fact_returned, dfact, cache) end -function EnzymeRules.reverse(config, - ::Const{typeof(cholesky)}, - RT::Type, - cache, - A::Annotation{<:Union{Matrix, - LinearAlgebra.RealHermSym{<:Real, - <:Matrix}}}; - kwargs...) +function EnzymeRules.reverse( + config, + ::Const{typeof(cholesky)}, + RT::Type, + cache, + A::Annotation{<:Union{Matrix,LinearAlgebra.RealHermSym{<:Real,<:Matrix}}}; + kwargs...) if !(RT <: Const) && !isa(A, Const) fact, dfact = cache dAs = EnzymeRules.width(config) == 1 ? (A.dval,) : A.dval @@ -957,15 +948,15 @@ function _realifydiag!(A) return A end -function EnzymeRules.augmented_primal(config, - func::Const{typeof(ldiv!)}, - RT::Type{<:Union{Const,DuplicatedNoNeed,Duplicated, - BatchDuplicatedNoNeed, - BatchDuplicated}}, - A::Annotation{<:Cholesky}, - B::Union{Const,DuplicatedNoNeed,Duplicated, - BatchDuplicatedNoNeed,BatchDuplicated}; - kwargs...) +function EnzymeRules.augmented_primal( + config, + func::Const{typeof(ldiv!)}, + RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated}}, + + A::Annotation{<:Cholesky}, + B::Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated}; + kwargs... +) cache_B = if !isa(A, Const) && !isa(B, Const) copy(B.val) else @@ -984,15 +975,16 @@ function EnzymeRules.augmented_primal(config, return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_B)) end -function EnzymeRules.reverse(config, - func::Const{typeof(ldiv!)}, - dret, - cache, - A::Annotation{<:Cholesky}, - B::Union{Const,DuplicatedNoNeed,Duplicated, - BatchDuplicatedNoNeed,BatchDuplicated}; - kwargs...) - if !isa(B, Const) +function EnzymeRules.reverse( + config, + func::Const{typeof(ldiv!)}, + dret, + cache, + A::Annotation{<:Cholesky}, + B::Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated}; + kwargs... +) + if !isa(B, Const) (cache_A, cache_B) = cache U = cache_A.U Z = isa(A, Const) ? nothing : U' \ cache_B diff --git a/test/internal_rules.jl b/test/internal_rules.jl index 048db4f681..7bae91b971 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -7,7 +7,7 @@ using FiniteDifferences using LinearAlgebra using SparseArrays using Test -using Random: Random +import Random struct TPair a::Float64 @@ -18,7 +18,7 @@ function sorterrfn(t, x) function lt(a, b) return a.a < b.a end - return first(sortperm(t; lt=lt)) * x + return first(sortperm(t, lt=lt)) * x end @testset "Sort rules" begin @@ -29,12 +29,10 @@ end end @test autodiff(Forward, f1, Duplicated(2.0, 1.0))[1] == 1 - @test autodiff(Forward, f1, BatchDuplicated(2.0, (1.0, 2.0)))[1] == - (var"1"=1.0, var"2"=2.0) + @test autodiff(Forward, f1, BatchDuplicated(2.0, (1.0, 2.0)))[1] == (var"1"=1.0, var"2"=2.0) @test autodiff(Reverse, f1, Active, Active(2.0))[1][1] == 1 @test autodiff(Forward, f1, Duplicated(4.0, 1.0))[1] == 0 - @test autodiff(Forward, f1, BatchDuplicated(4.0, (1.0, 2.0)))[1] == - (var"1"=0.0, var"2"=0.0) + @test autodiff(Forward, f1, BatchDuplicated(4.0, (1.0, 2.0)))[1] == (var"1"=0.0, var"2"=0.0) @test autodiff(Reverse, f1, Active, Active(4.0))[1][1] == 0 function f2(x) @@ -44,8 +42,7 @@ end end @test autodiff(Forward, f2, Duplicated(2.0, 1.0))[1] == -3 - @test autodiff(Forward, f2, BatchDuplicated(2.0, (1.0, 2.0)))[1] == - (var"1"=-3.0, var"2"=-6.0) + @test autodiff(Forward, f2, BatchDuplicated(2.0, (1.0, 2.0)))[1] == (var"1"=-3.0, var"2"=-6.0) @test autodiff(Reverse, f2, Active, Active(2.0))[1][1] == -3 function f3(x) @@ -54,8 +51,7 @@ end end @test autodiff(Forward, f3, Duplicated(1.5, 1.0))[1] == 1.0 - @test autodiff(Forward, f3, BatchDuplicated(1.5, (1.0, 2.0)))[1] == - (var"1"=1.0, var"2"=2.0) + @test autodiff(Forward, f3, BatchDuplicated(1.5, (1.0, 2.0)))[1] == (var"1"=1.0, var"2"=2.0) @test autodiff(Reverse, f3, Active(1.5))[1][1] == 1.0 @test autodiff(Reverse, f3, Active(2.5))[1][1] == 0.0 @@ -67,15 +63,13 @@ end @test autodiff(Forward, f4, Duplicated(1.5, 1.0))[1] == 1.5 @static if VERSION < v"1.7-" || VERSION >= v"1.8-" - @test autodiff(Forward, f4, BatchDuplicated(1.5, (1.0, 2.0)))[1] == - (var"1"=1.5, var"2"=3.0) + @test autodiff(Forward, f4, BatchDuplicated(1.5, (1.0, 2.0)))[1] == (var"1"=1.5, var"2"=3.0) end @test autodiff(Reverse, f4, Active(1.5))[1][1] == 1.5 @test autodiff(Reverse, f4, Active(4.0))[1][1] == 0.5 @test autodiff(Reverse, f4, Active(6.0))[1][1] == 0.0 - dd = Duplicated([TPair(1, 2), TPair(2, 3), TPair(0, 1)], - [TPair(0, 0), TPair(0, 0), TPair(0, 0)]) + dd = Duplicated([TPair(1, 2), TPair(2, 3), TPair(0, 1)], [TPair(0, 0), TPair(0, 0), TPair(0, 0)]) res = Enzyme.autodiff(Reverse, sorterrfn, dd, Active(1.0)) @test res[1][2] ≈ 3 @@ -93,9 +87,7 @@ end b = Float64[11, 13] db = zero(b) - forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, - Duplicated, Duplicated{typeof(A)}, - Duplicated{typeof(b)}) + forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Duplicated{typeof(A)}, Duplicated{typeof(b)}) tape, primal, shadow = forward(Const(\), Duplicated(A, dA), Duplicated(b, db)) @@ -112,9 +104,7 @@ end db = zero(b) - forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, - Duplicated, Const{typeof(A)}, - Duplicated{typeof(b)}) + forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Const{typeof(A)}, Duplicated{typeof(b)}) tape, primal, shadow = forward(Const(\), Const(A), Duplicated(b, db)) @@ -130,9 +120,7 @@ end dA = zero(A) - forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, - Duplicated, Duplicated{typeof(A)}, - Const{typeof(b)}) + forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Duplicated{typeof(A)}, Const{typeof(b)}) tape, primal, shadow = forward(Const(\), Duplicated(A, dA), Const(b)) @@ -148,109 +136,102 @@ end end @static if VERSION > v"1.8" - @testset "cholesky" begin - @testset "with wrapper arguments" begin - @testset for Te in (Float64,), TS in (Symmetric, Hermitian), uplo in (:U, :L) - @testset for TA in (Const, Duplicated), Tret in (Const, Duplicated) - _A = collect(exp(TS(rand(Te, 4, 4)))) - A = TS(_A, uplo) - are_activities_compatible(Tret, TA) || continue - test_forward(cholesky, Tret, (A, TA)) - test_reverse(cholesky, Tret, (A, TA)) - end +@testset "cholesky" begin + @testset "with wrapper arguments" begin + @testset for Te in (Float64,), TS in (Symmetric, Hermitian), uplo in (:U, :L) + @testset for TA in (Const, Duplicated), Tret in (Const, Duplicated) + _A = collect(exp(TS(rand(Te, 4, 4)))) + A = TS(_A, uplo) + are_activities_compatible(Tret, TA) || continue + test_forward(cholesky, Tret, (A, TA)) + test_reverse(cholesky, Tret, (A, TA)) end end - @testset "without wrapper arguments" begin - _square(A) = A * A' - @testset for Te in (Float64,) - @testset for TA in (Const, Duplicated), Tret in (Const, Duplicated) - A = rand(Te, 4, 4) - are_activities_compatible(Tret, TA) || continue - test_forward(cholesky ∘ _square, Tret, (A, TA)) - test_reverse(cholesky ∘ _square, Tret, (A, TA)) - end + end + @testset "without wrapper arguments" begin + _square(A) = A * A' + @testset for Te in (Float64,) + @testset for TA in (Const, Duplicated), Tret in (Const, Duplicated) + A = rand(Te, 4, 4) + are_activities_compatible(Tret, TA) || continue + test_forward(cholesky ∘ _square, Tret, (A, TA)) + test_reverse(cholesky ∘ _square, Tret, (A, TA)) end end end +end - @testset "Linear solve for `Cholesky`" begin - @testset for Te in (Float64,) - A = exp(Symmetric(rand(Te, 4, 4))) - C = cholesky(A) - B = rand(Te, 4, 4) - b = rand(Te, 4) - @testset for TC in (Const, Duplicated), TB in (Const, Duplicated), - Tret in (Const, Duplicated) - - @testset "$(size(_B))" for _B in (B, b) - are_activities_compatible(Tret, TC, TB) || continue - # Non-uniform activities are disabled due to unresolved questions - Tret == TC == TB && test_forward(\, Tret, (C, TC), (_B, TB)) - test_reverse(\, Tret, (C, TC), (_B, TB)) - end +@testset "Linear solve for `Cholesky`" begin + @testset for Te in (Float64,) + A = exp(Symmetric(rand(Te, 4, 4))) + C = cholesky(A) + B = rand(Te, 4, 4) + b = rand(Te, 4) + @testset for TC in (Const, Duplicated), TB in (Const, Duplicated), + Tret in (Const, Duplicated) + @testset "$(size(_B))" for _B in (B, b) + are_activities_compatible(Tret, TC, TB) || continue + # Non-uniform activities are disabled due to unresolved questions + Tret == TC == TB && test_forward(\, Tret, (C, TC), (_B, TB)) + test_reverse(\, Tret, (C, TC), (_B, TB)) end - @testset for TC in (Const, Duplicated), TB in (Const, Duplicated), - Tret in (Const, Duplicated) - - @testset "$(size(_B))" for _B in (B, b) - are_activities_compatible(Tret, TC, TB) || continue - # Non-uniform activities are disabled due to unresolved questions - Tret == TC == TB && test_forward(ldiv!, Tret, (C, TC), (_B, TB)) - Tret == TB && test_reverse(ldiv!, Tret, (C, TC), (_B, TB)) - end + end + @testset for TC in (Const, Duplicated), TB in (Const, Duplicated), + Tret in (Const, Duplicated) + @testset "$(size(_B))" for _B in (B, b) + are_activities_compatible(Tret, TC, TB) || continue + # Non-uniform activities are disabled due to unresolved questions + Tret == TC == TB && test_forward(ldiv!, Tret, (C, TC), (_B, TB)) + Tret == TB && test_reverse(ldiv!, Tret, (C, TC), (_B, TB)) end end end +end - @testset "Linear solve for triangular matrices" begin - @testset for T in (UpperTriangular, LowerTriangular, UnitUpperTriangular, - UnitLowerTriangular), - TE in (Float64, ComplexF64), sizeB in ((3,), (3, 3)) - - n = sizeB[1] - M = rand(TE, n, n) - B = rand(TE, sizeB...) - Y = zeros(TE, sizeB...) - A = T(M) - @testset "test through constructor" begin - _A = T(A) - function f!(Y, A, B, ::T) where {T} - ldiv!(Y, T(A), B) - return nothing - end - for TY in (Const, Duplicated, BatchDuplicated), - TM in (Const, Duplicated, BatchDuplicated), - TB in (Const, Duplicated, BatchDuplicated) - - are_activities_compatible(Const, TY, TM, TB) || continue - test_reverse(f!, Const, (Y, TY), (M, TM), (B, TB), (_A, Const)) - end +@testset "Linear solve for triangular matrices" begin + @testset for T in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular), + TE in (Float64, ComplexF64), sizeB in ((3,), (3, 3)) + n = sizeB[1] + M = rand(TE, n, n) + B = rand(TE, sizeB...) + Y = zeros(TE, sizeB...) + A = T(M) + @testset "test through constructor" begin + _A = T(A) + function f!(Y, A, B, ::T) where T + ldiv!(Y, T(A), B) + return nothing + end + for TY in (Const, Duplicated, BatchDuplicated), + TM in (Const, Duplicated, BatchDuplicated), + TB in (Const, Duplicated, BatchDuplicated) + are_activities_compatible(Const, TY, TM, TB) || continue + test_reverse(f!, Const, (Y, TY), (M, TM), (B, TB), (_A, Const)) end - @testset "test through `Adjoint` wrapper (regression test for #1306)" begin - # Test that we get the same derivative for `M` as for the adjoint of its - # (materialized) transpose. It's the same matrix, but represented differently - function f!(Y, A, B) - ldiv!(Y, A, B) - return nothing - end - A1 = T(M) - A2 = T(conj(permutedims(M))') - dA1 = make_zero(A1) - dA2 = make_zero(A2) - dB1 = make_zero(B) - dB2 = make_zero(B) - dY1 = rand(TE, sizeB...) - dY2 = copy(dY1) - autodiff(Reverse, f!, Duplicated(Y, dY1), Duplicated(A1, dA1), - Duplicated(B, dB1)) - autodiff(Reverse, f!, Duplicated(Y, dY2), Duplicated(A2, dA2), - Duplicated(B, dB2)) - @test dA1.data ≈ dA2.data - @test dB1 ≈ dB2 + end + @testset "test through `Adjoint` wrapper (regression test for #1306)" begin + # Test that we get the same derivative for `M` as for the adjoint of its + # (materialized) transpose. It's the same matrix, but represented differently + function f!(Y, A, B) + ldiv!(Y, A, B) + return nothing end + A1 = T(M) + A2 = T(conj(permutedims(M))') + dA1 = make_zero(A1) + dA2 = make_zero(A2) + dB1 = make_zero(B) + dB2 = make_zero(B) + dY1 = rand(TE, sizeB...) + dY2 = copy(dY1) + autodiff(Reverse, f!, Duplicated(Y, dY1), Duplicated(A1, dA1), Duplicated(B, dB1)) + autodiff(Reverse, f!, Duplicated(Y, dY2), Duplicated(A2, dA2), Duplicated(B, dB2)) + @test dA1.data ≈ dA2.data + @test dB1 ≈ dB2 end end end +end @testset "rand and randn rules" begin # Distributed as x + unit normal + uniform @@ -262,8 +243,7 @@ end Random.rand(d::MyDistribution) = rand(Random.default_rng(), d) # Outer rand should be differentiated through, and inner rand and randn should be ignored. - @test autodiff(Enzyme.Reverse, x -> rand(MyDistribution(x)), Active, Active(1.0)) == - ((1.0,),) + @test autodiff(Enzyme.Reverse, x -> rand(MyDistribution(x)), Active, Active(1.0)) == ((1.0,),) end end # InternalRules From 37fde99b40e3429a4e136aff2f1c78c2d4d38332 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Fri, 3 May 2024 10:03:32 +0200 Subject: [PATCH 44/61] Format changes --- src/internal_rules.jl | 37 ++++++++++++++++++------------------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index d6e862ff6e..f7518b8f24 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -948,15 +948,15 @@ function _realifydiag!(A) return A end -function EnzymeRules.augmented_primal( - config, - func::Const{typeof(ldiv!)}, - RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated}}, - - A::Annotation{<:Cholesky}, - B::Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated}; - kwargs... -) +function EnzymeRules.augmented_primal(config, + func::Const{typeof(ldiv!)}, + RT::Type{<:Union{Const,DuplicatedNoNeed,Duplicated, + BatchDuplicatedNoNeed, + BatchDuplicated}}, + A::Annotation{<:Cholesky}, + B::Union{Const,DuplicatedNoNeed,Duplicated, + BatchDuplicatedNoNeed,BatchDuplicated}; + kwargs...) cache_B = if !isa(A, Const) && !isa(B, Const) copy(B.val) else @@ -975,16 +975,15 @@ function EnzymeRules.augmented_primal( return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_B)) end -function EnzymeRules.reverse( - config, - func::Const{typeof(ldiv!)}, - dret, - cache, - A::Annotation{<:Cholesky}, - B::Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated}; - kwargs... -) - if !isa(B, Const) +function EnzymeRules.reverse(config, + func::Const{typeof(ldiv!)}, + dret, + cache, + A::Annotation{<:Cholesky}, + B::Union{Const,DuplicatedNoNeed,Duplicated, + BatchDuplicatedNoNeed,BatchDuplicated}; + kwargs...) + if !isa(B, Const) (cache_A, cache_B) = cache U = cache_A.U Z = isa(A, Const) ? nothing : U' \ cache_B From 0b15bec03eade5f1bcddf88c2670e507390f259f Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Fri, 3 May 2024 10:09:17 +0200 Subject: [PATCH 45/61] Format other lines --- src/internal_rules.jl | 45 ++++++++++++++++++++--------------------- test/internal_rules.jl | 46 ++++++++++++++++++++++-------------------- 2 files changed, 46 insertions(+), 45 deletions(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index f7518b8f24..2af7b03ee2 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -751,7 +751,7 @@ function EnzymeRules.forward(::Const{typeof(cholesky)}, RT::Type, A; kwargs...) dA = if isa(A, Const) ntuple(Val(N)) do i Base.@_inline_meta - zero(A.val) + return zero(A.val) end else N == 1 ? (A.dval,) : A.dval @@ -798,13 +798,11 @@ function _cholesky_forward(C::Cholesky, Ȧ) end end -function EnzymeRules.forward( - func::Const{typeof(ldiv!)}, - RT::Type{<:Union{Const, Duplicated}}, - fact::Annotation{<:Cholesky}, - B::Annotation{<:AbstractVecOrMat}; - kwargs... -) +function EnzymeRules.forward(func::Const{typeof(ldiv!)}, + RT::Type{<:Union{Const,Duplicated}}, + fact::Annotation{<:Cholesky}, + B::Annotation{<:AbstractVecOrMat}; + kwargs...) if B isa Const return func.val(fact.val, B.val; kwargs...) else @@ -852,13 +850,13 @@ function _ldiv_Cholesky_forward!(L, U, B, dL, dU, dB) return B, dB end -function EnzymeRules.augmented_primal( - config, - func::Const{typeof(cholesky)}, - RT::Type, - A::Annotation{<:Union{Matrix,LinearAlgebra.RealHermSym{<:Real,<:Matrix}}}; - kwargs...) - +function EnzymeRules.augmented_primal(config, + func::Const{typeof(cholesky)}, + RT::Type, + A::Annotation{<:Union{Matrix, + LinearAlgebra.RealHermSym{<:Real, + <:Matrix}}}; + kwargs...) fact = if EnzymeRules.needs_primal(config) || !(RT <: Const) cholesky(A.val; kwargs...) else @@ -876,7 +874,7 @@ function EnzymeRules.augmented_primal( else ntuple(Val(EnzymeRules.width(config))) do i Base.@_inline_meta - Enzyme.make_zero(fact) + return Enzyme.make_zero(fact) end end end @@ -885,13 +883,14 @@ function EnzymeRules.augmented_primal( return EnzymeRules.AugmentedReturn(fact_returned, dfact, cache) end -function EnzymeRules.reverse( - config, - ::Const{typeof(cholesky)}, - RT::Type, - cache, - A::Annotation{<:Union{Matrix,LinearAlgebra.RealHermSym{<:Real,<:Matrix}}}; - kwargs...) +function EnzymeRules.reverse(config, + ::Const{typeof(cholesky)}, + RT::Type, + cache, + A::Annotation{<:Union{Matrix, + LinearAlgebra.RealHermSym{<:Real, + <:Matrix}}}; + kwargs...) if !(RT <: Const) && !isa(A, Const) fact, dfact = cache dAs = EnzymeRules.width(config) == 1 ? (A.dval,) : A.dval diff --git a/test/internal_rules.jl b/test/internal_rules.jl index 7bae91b971..f033affd4c 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -161,32 +161,34 @@ end end end -@testset "Linear solve for `Cholesky`" begin - @testset for Te in (Float64,) - A = exp(Symmetric(rand(Te, 4, 4))) - C = cholesky(A) - B = rand(Te, 4, 4) - b = rand(Te, 4) - @testset for TC in (Const, Duplicated), TB in (Const, Duplicated), - Tret in (Const, Duplicated) - @testset "$(size(_B))" for _B in (B, b) - are_activities_compatible(Tret, TC, TB) || continue - # Non-uniform activities are disabled due to unresolved questions - Tret == TC == TB && test_forward(\, Tret, (C, TC), (_B, TB)) - test_reverse(\, Tret, (C, TC), (_B, TB)) + @testset "Linear solve for `Cholesky`" begin + @testset for Te in (Float64,) + A = exp(Symmetric(rand(Te, 4, 4))) + C = cholesky(A) + B = rand(Te, 4, 4) + b = rand(Te, 4) + @testset for TC in (Const, Duplicated), TB in (Const, Duplicated), + Tret in (Const, Duplicated) + + @testset "$(size(_B))" for _B in (B, b) + are_activities_compatible(Tret, TC, TB) || continue + # Non-uniform activities are disabled due to unresolved questions + Tret == TC == TB && test_forward(\, Tret, (C, TC), (_B, TB)) + test_reverse(\, Tret, (C, TC), (_B, TB)) + end end - end - @testset for TC in (Const, Duplicated), TB in (Const, Duplicated), - Tret in (Const, Duplicated) - @testset "$(size(_B))" for _B in (B, b) - are_activities_compatible(Tret, TC, TB) || continue - # Non-uniform activities are disabled due to unresolved questions - Tret == TC == TB && test_forward(ldiv!, Tret, (C, TC), (_B, TB)) - Tret == TB && test_reverse(ldiv!, Tret, (C, TC), (_B, TB)) + @testset for TC in (Const, Duplicated), TB in (Const, Duplicated), + Tret in (Const, Duplicated) + + @testset "$(size(_B))" for _B in (B, b) + are_activities_compatible(Tret, TC, TB) || continue + # Non-uniform activities are disabled due to unresolved questions + Tret == TC == TB && test_forward(ldiv!, Tret, (C, TC), (_B, TB)) + Tret == TB && test_reverse(ldiv!, Tret, (C, TC), (_B, TB)) + end end end end -end @testset "Linear solve for triangular matrices" begin @testset for T in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular), From 6637b0a2816db0b5ac8138819c216d2076367675 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Fri, 3 May 2024 10:29:43 +0200 Subject: [PATCH 46/61] Format remaining lines --- test/internal_rules.jl | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/test/internal_rules.jl b/test/internal_rules.jl index f033affd4c..10fe7c49b5 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -136,30 +136,30 @@ end end @static if VERSION > v"1.8" -@testset "cholesky" begin - @testset "with wrapper arguments" begin - @testset for Te in (Float64,), TS in (Symmetric, Hermitian), uplo in (:U, :L) - @testset for TA in (Const, Duplicated), Tret in (Const, Duplicated) - _A = collect(exp(TS(rand(Te, 4, 4)))) - A = TS(_A, uplo) - are_activities_compatible(Tret, TA) || continue - test_forward(cholesky, Tret, (A, TA)) - test_reverse(cholesky, Tret, (A, TA)) + @testset "cholesky" begin + @testset "with wrapper arguments" begin + @testset for Te in (Float64,), TS in (Symmetric, Hermitian), uplo in (:U, :L) + @testset for TA in (Const, Duplicated), Tret in (Const, Duplicated) + _A = collect(exp(TS(rand(Te, 4, 4)))) + A = TS(_A, uplo) + are_activities_compatible(Tret, TA) || continue + test_forward(cholesky, Tret, (A, TA)) + test_reverse(cholesky, Tret, (A, TA)) + end end end - end - @testset "without wrapper arguments" begin - _square(A) = A * A' - @testset for Te in (Float64,) - @testset for TA in (Const, Duplicated), Tret in (Const, Duplicated) - A = rand(Te, 4, 4) - are_activities_compatible(Tret, TA) || continue - test_forward(cholesky ∘ _square, Tret, (A, TA)) - test_reverse(cholesky ∘ _square, Tret, (A, TA)) + @testset "without wrapper arguments" begin + _square(A) = A * A' + @testset for Te in (Float64,) + @testset for TA in (Const, Duplicated), Tret in (Const, Duplicated) + A = rand(Te, 4, 4) + are_activities_compatible(Tret, TA) || continue + test_forward(cholesky ∘ _square, Tret, (A, TA)) + test_reverse(cholesky ∘ _square, Tret, (A, TA)) + end end end end -end @testset "Linear solve for `Cholesky`" begin @testset for Te in (Float64,) From 5983d9ecefadef8c0ba5b89fd98f415542fa7150 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Fri, 3 May 2024 15:26:50 +0200 Subject: [PATCH 47/61] Generalize `ldiv!` rule to `uplo = :L` --- src/internal_rules.jl | 7 ++++++- test/internal_rules.jl | 5 +++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 2af7b03ee2..a6f27a62ce 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -995,7 +995,12 @@ function EnzymeRules.reverse(config, ∂B = U \ dZ Ā = -dZ * Y' - Z * ∂B' dA = EnzymeRules.width(config) == 1 ? A.dval : A.dval[b] - dA.factors .+= UpperTriangular(Ā) + if A.val.uplo === 'U' + dA.factors .+= UpperTriangular(Ā) + else + dA.factors .+= LowerTriangular(Ā') + end + end end end diff --git a/test/internal_rules.jl b/test/internal_rules.jl index 10fe7c49b5..a256bec586 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -162,8 +162,9 @@ end end @testset "Linear solve for `Cholesky`" begin - @testset for Te in (Float64,) - A = exp(Symmetric(rand(Te, 4, 4))) + @testset for Te in (Float64,), uplo in (:U, :L) + _A = collect(exp(Symmetric(rand(Te, 4, 4)))) + A = Symmetric(_A, uplo) C = cholesky(A) B = rand(Te, 4, 4) b = rand(Te, 4) From f7d04c29a2e6ae41b54aef9bc309c5f5b37af348 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Fri, 3 May 2024 15:54:14 +0200 Subject: [PATCH 48/61] Link to issue about open question --- test/internal_rules.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/internal_rules.jl b/test/internal_rules.jl index a256bec586..03ceba1b83 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -174,6 +174,7 @@ end @testset "$(size(_B))" for _B in (B, b) are_activities_compatible(Tret, TC, TB) || continue # Non-uniform activities are disabled due to unresolved questions + # see https://github.com/EnzymeAD/Enzyme.jl/issues/1411 Tret == TC == TB && test_forward(\, Tret, (C, TC), (_B, TB)) test_reverse(\, Tret, (C, TC), (_B, TB)) end @@ -184,6 +185,7 @@ end @testset "$(size(_B))" for _B in (B, b) are_activities_compatible(Tret, TC, TB) || continue # Non-uniform activities are disabled due to unresolved questions + # see https://github.com/EnzymeAD/Enzyme.jl/issues/1411 Tret == TC == TB && test_forward(ldiv!, Tret, (C, TC), (_B, TB)) Tret == TB && test_reverse(ldiv!, Tret, (C, TC), (_B, TB)) end From b0c9d6a27c0db12948f94e5cd9697c28a468eb50 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Fri, 3 May 2024 16:05:55 +0200 Subject: [PATCH 49/61] Generalize tests --- test/internal_rules.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/test/internal_rules.jl b/test/internal_rules.jl index 03ceba1b83..38a31e75dd 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -162,10 +162,8 @@ end end @testset "Linear solve for `Cholesky`" begin - @testset for Te in (Float64,), uplo in (:U, :L) - _A = collect(exp(Symmetric(rand(Te, 4, 4)))) - A = Symmetric(_A, uplo) - C = cholesky(A) + @testset for Te in (Float64, ComplexF64), uplo in ('L', 'U') + C = Cholesky(I + rand(Te, 4, 4), uplo, 0) B = rand(Te, 4, 4) b = rand(Te, 4) @testset for TC in (Const, Duplicated), TB in (Const, Duplicated), From 308b6268239752b74258dac534929c5f63d2e301 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Sat, 4 May 2024 20:50:52 +0200 Subject: [PATCH 50/61] Refactor rule --- src/internal_rules.jl | 45 ++++++++++++++++++++++--------------------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index a6f27a62ce..b998c33a91 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -799,7 +799,7 @@ function _cholesky_forward(C::Cholesky, Ȧ) end function EnzymeRules.forward(func::Const{typeof(ldiv!)}, - RT::Type{<:Union{Const,Duplicated}}, + RT::Type{<:Union{Const,Duplicated,BatchDuplicated}}, fact::Annotation{<:Cholesky}, B::Annotation{<:AbstractVecOrMat}; kwargs...) @@ -809,20 +809,31 @@ function EnzymeRules.forward(func::Const{typeof(ldiv!)}, N = width(B) retval = B.val + L = fact.val.L + U = fact.val.U + + ldiv!(L, B.val) + for b in 1:N + dB = N == 1 ? B.dval : B.dval[b] + if !(fact isa Const) + dL = N == 1 ? fact.dval.L : fact.dval[b].L + mul!(dB, dL, B.val, -1, 1) + end + ldiv!(L, dB) + end + ldiv!(U, B.val) + for b in 1:N + dB = N == 1 ? B.dval : B.dval[b] + if !(fact isa Const) + dU = N == 1 ? fact.dval.U : fact.dval[b].U + mul!(dB, dU, B.val, -1, 1) + end + ldiv!(U, dB) + end + dretvals = ntuple(Val(N)) do b Base.@_inline_meta dB = N == 1 ? B.dval : B.dval[b] - if fact isa Const - ldiv!(fact.val, B.val) - ldiv!(fact.val, dB) - else - dfact = N == 1 ? fact.dval : fact.dval[b] - L = fact.val.L - U = fact.val.U - dL = dfact.L - dU = dfact.U - _ldiv_Cholesky_forward!(L, U, B.val, dL, dU, dB) - end return dB end @@ -840,16 +851,6 @@ function EnzymeRules.forward(func::Const{typeof(ldiv!)}, end end -function _ldiv_Cholesky_forward!(L, U, B, dL, dU, dB) - ldiv!(L, B) - mul!(dB, dL, B, -1, 1) - ldiv!(L, dB) - ldiv!(U, B) - mul!(dB, dU, B, -1, 1) - ldiv!(U, dB) - return B, dB -end - function EnzymeRules.augmented_primal(config, func::Const{typeof(cholesky)}, RT::Type, From 20877a0865245c1624611f0d9fdf062fd204549e Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Sat, 4 May 2024 20:52:51 +0200 Subject: [PATCH 51/61] Add tests for ` BatchDuplicated` --- test/internal_rules.jl | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/test/internal_rules.jl b/test/internal_rules.jl index 38a31e75dd..82d7adfa25 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -166,26 +166,30 @@ end C = Cholesky(I + rand(Te, 4, 4), uplo, 0) B = rand(Te, 4, 4) b = rand(Te, 4) - @testset for TC in (Const, Duplicated), TB in (Const, Duplicated), - Tret in (Const, Duplicated) + @testset for TC in (Const, Duplicated, BatchDuplicated), + TB in (Const, Duplicated, BatchDuplicated), + Tret in (Const, Duplicated, BatchDuplicated) @testset "$(size(_B))" for _B in (B, b) are_activities_compatible(Tret, TC, TB) || continue # Non-uniform activities are disabled due to unresolved questions # see https://github.com/EnzymeAD/Enzyme.jl/issues/1411 - Tret == TC == TB && test_forward(\, Tret, (C, TC), (_B, TB)) + Tret == TC == TB || continue + test_forward(\, Tret, (C, TC), (_B, TB)) test_reverse(\, Tret, (C, TC), (_B, TB)) end end - @testset for TC in (Const, Duplicated), TB in (Const, Duplicated), - Tret in (Const, Duplicated) + @testset for TC in (Const, Duplicated, BatchDuplicated), + TB in (Const, Duplicated, BatchDuplicated), + Tret in (Const, Duplicated, BatchDuplicated) @testset "$(size(_B))" for _B in (B, b) are_activities_compatible(Tret, TC, TB) || continue # Non-uniform activities are disabled due to unresolved questions # see https://github.com/EnzymeAD/Enzyme.jl/issues/1411 - Tret == TC == TB && test_forward(ldiv!, Tret, (C, TC), (_B, TB)) - Tret == TB && test_reverse(ldiv!, Tret, (C, TC), (_B, TB)) + Tret == TC == TB || continue + test_forward(ldiv!, Tret, (C, TC), (_B, TB)) + test_reverse(ldiv!, Tret, (C, TC), (_B, TB)) end end end From 3171823bb99a7d613abbb3173f58592d75cc9bbc Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Sat, 4 May 2024 23:04:57 +0200 Subject: [PATCH 52/61] Include all activities --- test/internal_rules.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/test/internal_rules.jl b/test/internal_rules.jl index 82d7adfa25..368b1ae735 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -162,12 +162,14 @@ end end @testset "Linear solve for `Cholesky`" begin + activities = (Const, Duplicated, DuplicatedNoNeed, BatchDuplicated, + BatchDuplicatedNoNeed) @testset for Te in (Float64, ComplexF64), uplo in ('L', 'U') C = Cholesky(I + rand(Te, 4, 4), uplo, 0) B = rand(Te, 4, 4) b = rand(Te, 4) - @testset for TC in (Const, Duplicated, BatchDuplicated), - TB in (Const, Duplicated, BatchDuplicated), + @testset for TC in activities, + TB in activities, Tret in (Const, Duplicated, BatchDuplicated) @testset "$(size(_B))" for _B in (B, b) @@ -179,8 +181,8 @@ end test_reverse(\, Tret, (C, TC), (_B, TB)) end end - @testset for TC in (Const, Duplicated, BatchDuplicated), - TB in (Const, Duplicated, BatchDuplicated), + @testset for TC in activities, + TB in activities, Tret in (Const, Duplicated, BatchDuplicated) @testset "$(size(_B))" for _B in (B, b) From 8be3391cf71c8a37203f5dd65e0afe591e6564af Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Sat, 18 May 2024 23:31:48 +0200 Subject: [PATCH 53/61] Include more activities --- test/internal_rules.jl | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/test/internal_rules.jl b/test/internal_rules.jl index 011548f189..7d48b3f178 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -137,10 +137,11 @@ end @static if VERSION > v"1.8" @testset "cholesky" begin + activities = (Const, Duplicated, BatchDuplicated,) @testset "with wrapper arguments" begin @testset for Te in (Float64,), TS in (Symmetric, Hermitian), uplo in (:U, :L) - @testset for TA in (Const, Duplicated), Tret in (Const, Duplicated) - _A = collect(exp(TS(rand(Te, 4, 4)))) + @testset for TA in activities, Tret in activities + _A = collect(exp(TS(rand(Te, 5, 5)))) A = TS(_A, uplo) are_activities_compatible(Tret, TA) || continue test_forward(cholesky, Tret, (A, TA)) @@ -151,8 +152,8 @@ end @testset "without wrapper arguments" begin _square(A) = A * A' @testset for Te in (Float64,) - @testset for TA in (Const, Duplicated), Tret in (Const, Duplicated) - A = rand(Te, 4, 4) + @testset for TA in activities, Tret in activities + A = rand(Te, 5, 5) are_activities_compatible(Tret, TA) || continue test_forward(cholesky ∘ _square, Tret, (A, TA)) test_reverse(cholesky ∘ _square, Tret, (A, TA)) @@ -165,8 +166,8 @@ end activities = (Const, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed) @testset for Te in (Float64, ComplexF64), uplo in ('L', 'U') - C = Cholesky(I + rand(Te, 4, 4), uplo, 0) - B = rand(Te, 4, 4) + C = Cholesky(I + rand(Te, 5, 5), uplo, 0) + B = rand(Te, 5, 5) b = rand(Te, 4) @testset for TC in activities, TB in activities, From b31ed351d38fe2f26897dcf80f3029574d7db954 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Sun, 19 May 2024 00:04:38 +0200 Subject: [PATCH 54/61] Fix typo --- test/internal_rules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/internal_rules.jl b/test/internal_rules.jl index 7d48b3f178..ee3ce3a238 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -168,7 +168,7 @@ end @testset for Te in (Float64, ComplexF64), uplo in ('L', 'U') C = Cholesky(I + rand(Te, 5, 5), uplo, 0) B = rand(Te, 5, 5) - b = rand(Te, 4) + b = rand(Te, 5) @testset for TC in activities, TB in activities, Tret in (Const, Duplicated, BatchDuplicated) From 5bf00578be51755217bdaef52de64939cbefd796 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Sun, 19 May 2024 01:30:56 +0200 Subject: [PATCH 55/61] Write testsets in a different way --- test/internal_rules.jl | 36 ++++++++++++++++-------------------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/test/internal_rules.jl b/test/internal_rules.jl index ee3ce3a238..292a0f2ce0 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -137,27 +137,23 @@ end @static if VERSION > v"1.8" @testset "cholesky" begin - activities = (Const, Duplicated, BatchDuplicated,) - @testset "with wrapper arguments" begin - @testset for Te in (Float64,), TS in (Symmetric, Hermitian), uplo in (:U, :L) - @testset for TA in activities, Tret in activities - _A = collect(exp(TS(rand(Te, 5, 5)))) - A = TS(_A, uplo) - are_activities_compatible(Tret, TA) || continue - test_forward(cholesky, Tret, (A, TA)) - test_reverse(cholesky, Tret, (A, TA)) - end + activities = (Const, Duplicated, BatchDuplicated) + _square(A) = A * adjoint(A) + @testset for (Te, TSs) in ( + Float64 => (Symmetric, Hermitian), + ), TA in activities, Tret in activities + @testset "without wrapper arguments" begin + A = rand(Te, 5, 5) + are_activities_compatible(Tret, TA) || continue + test_forward(cholesky ∘ _square, Tret, (A, TA)) + test_reverse(cholesky ∘ _square, Tret, (A, TA)) end - end - @testset "without wrapper arguments" begin - _square(A) = A * A' - @testset for Te in (Float64,) - @testset for TA in activities, Tret in activities - A = rand(Te, 5, 5) - are_activities_compatible(Tret, TA) || continue - test_forward(cholesky ∘ _square, Tret, (A, TA)) - test_reverse(cholesky ∘ _square, Tret, (A, TA)) - end + @testset "with wrapper arguments" for TS in TSs, uplo in (:U, :L) + _A = collect(exp(TS(I + rand(Te, 5, 5)))) + A = TS(_A, uplo) + are_activities_compatible(Tret, TA) || continue + test_forward(cholesky, Tret, (A, TA); fdm=FiniteDifferences.forward_fdm(5, 1)) + test_reverse(cholesky, Tret, (A, TA)) end end end From 250e24f424960b20e9eb07e0bddb49256295a0ce Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Mon, 20 May 2024 23:20:06 +0200 Subject: [PATCH 56/61] Test complex element type --- test/internal_rules.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/internal_rules.jl b/test/internal_rules.jl index 292a0f2ce0..baebfe27c9 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -141,6 +141,7 @@ end _square(A) = A * adjoint(A) @testset for (Te, TSs) in ( Float64 => (Symmetric, Hermitian), + ComplexF64 => (Hermitian,), ), TA in activities, Tret in activities @testset "without wrapper arguments" begin A = rand(Te, 5, 5) From 62dbbc2b4e8c2ff12b4c417d4b084d507c5af8ee Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Mon, 20 May 2024 23:28:19 +0200 Subject: [PATCH 57/61] Add comment regarding `I` in test --- test/internal_rules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/internal_rules.jl b/test/internal_rules.jl index baebfe27c9..01b8c9e7c9 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -163,7 +163,7 @@ end activities = (Const, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed) @testset for Te in (Float64, ComplexF64), uplo in ('L', 'U') - C = Cholesky(I + rand(Te, 5, 5), uplo, 0) + C = Cholesky(I + rand(Te, 5, 5), uplo, 0) # add `I` for numerical stability B = rand(Te, 5, 5) b = rand(Te, 5) @testset for TC in activities, From 75c076645656380efa9eb5160b107b9dbb9ca80c Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Mon, 20 May 2024 23:42:17 +0200 Subject: [PATCH 58/61] Work around issue #1456 --- test/internal_rules.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/internal_rules.jl b/test/internal_rules.jl index 01b8c9e7c9..99b900af3a 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -138,6 +138,8 @@ end @static if VERSION > v"1.8" @testset "cholesky" begin activities = (Const, Duplicated, BatchDuplicated) + # Workaround for issue #1456: + _realifydiag(A) = (A[diagind(A)] .= real(A[diagind(A)]); return A) _square(A) = A * adjoint(A) @testset for (Te, TSs) in ( Float64 => (Symmetric, Hermitian), @@ -146,8 +148,8 @@ end @testset "without wrapper arguments" begin A = rand(Te, 5, 5) are_activities_compatible(Tret, TA) || continue - test_forward(cholesky ∘ _square, Tret, (A, TA)) - test_reverse(cholesky ∘ _square, Tret, (A, TA)) + test_forward(cholesky ∘ _realifydiag ∘ _square, Tret, (A, TA)) + test_reverse(cholesky ∘ _realifydiag ∘ _square, Tret, (A, TA)) end @testset "with wrapper arguments" for TS in TSs, uplo in (:U, :L) _A = collect(exp(TS(I + rand(Te, 5, 5)))) From 4d02587ea9143f4f955674904620ef43b192eee4 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Tue, 21 May 2024 00:12:39 +0200 Subject: [PATCH 59/61] Increase coverage of rules --- src/internal_rules.jl | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index b998c33a91..378f7d12d5 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -855,8 +855,7 @@ function EnzymeRules.augmented_primal(config, func::Const{typeof(cholesky)}, RT::Type, A::Annotation{<:Union{Matrix, - LinearAlgebra.RealHermSym{<:Real, - <:Matrix}}}; + LinearAlgebra.RealHermSymComplexHerm}}; kwargs...) fact = if EnzymeRules.needs_primal(config) || !(RT <: Const) cholesky(A.val; kwargs...) @@ -889,8 +888,7 @@ function EnzymeRules.reverse(config, RT::Type, cache, A::Annotation{<:Union{Matrix, - LinearAlgebra.RealHermSym{<:Real, - <:Matrix}}}; + LinearAlgebra.RealHermSymComplexHerm}}; kwargs...) if !(RT <: Const) && !isa(A, Const) fact, dfact = cache @@ -898,7 +896,7 @@ function EnzymeRules.reverse(config, dfacts = EnzymeRules.width(config) == 1 ? (dfact,) : dfact for (dA, dfact) in zip(dAs, dfacts) - _dA = dA isa LinearAlgebra.RealHermSym ? dA.data : dA + _dA = dA isa LinearAlgebra.RealHermSymComplexHerm ? dA.data : dA if _dA !== dfact.factors Ā = _cholesky_pullback_shared_code(fact, dfact) _dA .+= Ā @@ -914,10 +912,12 @@ end # Copyright (c) 2018: Jarrett Revels. # https://github.com/JuliaDiff/ChainRules.jl/blob/9f1817a22404259113e230bef149a54d379a660b/src/rulesets/LinearAlgebra/factorization.jl#L507-L528 function _cholesky_pullback_shared_code(C, ΔC) + Δfactors = ΔC.factors Ā = similar(C.factors) if C.uplo === 'U' U = C.U Ū = ΔC.U + Ū = eltype(U) <: Real ? real(UpperTriangular(Δfactors)) : UpperTriangular(Δfactors) mul!(Ā, Ū, U') LinearAlgebra.copytri!(Ā, 'U', true) eltype(Ā) <: Real || _realifydiag!(Ā) @@ -928,6 +928,7 @@ function _cholesky_pullback_shared_code(C, ΔC) else # C.uplo === 'L' L = C.L L̄ = ΔC.L + L̄ = eltype(L) <: Real ? real(LowerTriangular(Δfactors)) : LowerTriangular(Δfactors) mul!(Ā, L', L̄) LinearAlgebra.copytri!(Ā, 'L', true) eltype(Ā) <: Real || _realifydiag!(Ā) From f5c7abf128e7ed7c161cddd11aef2a95bb469b7a Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Tue, 21 May 2024 00:13:45 +0200 Subject: [PATCH 60/61] SImplify test --- test/internal_rules.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/test/internal_rules.jl b/test/internal_rules.jl index 99b900af3a..4b361abb05 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -138,8 +138,6 @@ end @static if VERSION > v"1.8" @testset "cholesky" begin activities = (Const, Duplicated, BatchDuplicated) - # Workaround for issue #1456: - _realifydiag(A) = (A[diagind(A)] .= real(A[diagind(A)]); return A) _square(A) = A * adjoint(A) @testset for (Te, TSs) in ( Float64 => (Symmetric, Hermitian), @@ -148,8 +146,9 @@ end @testset "without wrapper arguments" begin A = rand(Te, 5, 5) are_activities_compatible(Tret, TA) || continue - test_forward(cholesky ∘ _realifydiag ∘ _square, Tret, (A, TA)) - test_reverse(cholesky ∘ _realifydiag ∘ _square, Tret, (A, TA)) + # `Enzyme._realifydiag!` is a workaround for issue #1456: + test_forward(cholesky ∘ Enzyme._realifydiag! ∘ _square, Tret, (A, TA)) + test_reverse(cholesky ∘ Enzyme._realifydiag! ∘ _square, Tret, (A, TA)) end @testset "with wrapper arguments" for TS in TSs, uplo in (:U, :L) _A = collect(exp(TS(I + rand(Te, 5, 5)))) From 510c7dddc6ef989f942e274d3f4a824a28591a85 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Tue, 21 May 2024 00:27:43 +0200 Subject: [PATCH 61/61] Fix test --- test/internal_rules.jl | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/test/internal_rules.jl b/test/internal_rules.jl index 4b361abb05..0af32751ef 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -138,7 +138,11 @@ end @static if VERSION > v"1.8" @testset "cholesky" begin activities = (Const, Duplicated, BatchDuplicated) - _square(A) = A * adjoint(A) + function _square(A) + S = A * adjoint(A) + S[diagind(S)] .= real.(S[diagind(S)]) # workaround for issue #1456: + return S + end @testset for (Te, TSs) in ( Float64 => (Symmetric, Hermitian), ComplexF64 => (Hermitian,), @@ -146,9 +150,8 @@ end @testset "without wrapper arguments" begin A = rand(Te, 5, 5) are_activities_compatible(Tret, TA) || continue - # `Enzyme._realifydiag!` is a workaround for issue #1456: - test_forward(cholesky ∘ Enzyme._realifydiag! ∘ _square, Tret, (A, TA)) - test_reverse(cholesky ∘ Enzyme._realifydiag! ∘ _square, Tret, (A, TA)) + test_forward(cholesky ∘ _square, Tret, (A, TA)) + test_reverse(cholesky ∘ _square, Tret, (A, TA)) end @testset "with wrapper arguments" for TS in TSs, uplo in (:U, :L) _A = collect(exp(TS(I + rand(Te, 5, 5))))