diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 697ad00eab..da1553dcef 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -99,7 +99,9 @@ end function EnzymeRules.inactive_noinl(::typeof(Base.size), args...) return nothing end -function EnzymeRules.inactive_noinl(::typeof(Base.setindex!), ::IdDict{K, V}, ::K, ::V) where {K, V <:Integer} +function EnzymeRules.inactive_noinl( + ::typeof(Base.setindex!), ::IdDict{K,V}, ::K, ::V +) where {K,V<:Integer} return nothing end @@ -117,35 +119,45 @@ end @inline EnzymeRules.inactive_type(v::Type{T}) where {T<:AbstractString} = true @inline width(::Duplicated) = 1 -@inline width(::BatchDuplicated{T, N}) where {T, N} = N +@inline width(::BatchDuplicated{T,N}) where {T,N} = N @inline width(::DuplicatedNoNeed) = 1 -@inline width(::BatchDuplicatedNoNeed{T, N}) where {T, N} = N +@inline width(::BatchDuplicatedNoNeed{T,N}) where {T,N} = N -@inline width(::Type{Duplicated{T}}) where T = 1 -@inline width(::Type{BatchDuplicated{T, N}}) where {T, N} = N -@inline width(::Type{DuplicatedNoNeed{T}}) where T = 1 -@inline width(::Type{BatchDuplicatedNoNeed{T, N}}) where {T, N} = N +@inline width(::Type{Duplicated{T}}) where {T} = 1 +@inline width(::Type{BatchDuplicated{T,N}}) where {T,N} = N +@inline width(::Type{DuplicatedNoNeed{T}}) where {T} = 1 +@inline width(::Type{BatchDuplicatedNoNeed{T,N}}) where {T,N} = N # Note all of these forward mode definitions do not support runtime activity as # the do not keep the primal if shadow(x.y) == primal(x.y) -function EnzymeRules.forward(::Const{typeof(Base.deepcopy)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated) +function EnzymeRules.forward( + ::Const{typeof(Base.deepcopy)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated +) return deepcopy(x.dval) end -function EnzymeRules.forward(::Const{typeof(Base.deepcopy)}, ::Type{<:BatchDuplicatedNoNeed}, x::BatchDuplicated{T, N}) where {T, N} +function EnzymeRules.forward( + ::Const{typeof(Base.deepcopy)}, ::Type{<:BatchDuplicatedNoNeed}, x::BatchDuplicated{T,N} +) where {T,N} ntuple(Val(N)) do _ deepcopy(x.dval) end end # Deepcopy preserving the primal if runtime inactive -@inline function deepcopy_rtact(copied::RT, primal::RT, seen::IdDict, shadow::RT) where {RT <: Union{Integer, Char}} +@inline function deepcopy_rtact( + copied::RT, primal::RT, seen::IdDict, shadow::RT +) where {RT<:Union{Integer,Char}} return Base.deepcopy_internal(shadow, seen) end -@inline function deepcopy_rtact(copied::RT, primal::RT, seen::IdDict, shadow::RT) where {RT <: AbstractFloat} +@inline function deepcopy_rtact( + copied::RT, primal::RT, seen::IdDict, shadow::RT +) where {RT<:AbstractFloat} return Base.deepcopy_internal(shadow, seen) end -@inline function deepcopy_rtact(copied::RT, primal::RT, seen::IdDict, shadow::RT) where {RT <: Array} +@inline function deepcopy_rtact( + copied::RT, primal::RT, seen::IdDict, shadow::RT +) where {RT<:Array} if !haskey(seen, shadow) if primal === shadow return seen[shadow] = copied @@ -159,19 +171,28 @@ end return seen[shadow] end -function EnzymeRules.forward(func::Const{typeof(Base.deepcopy)}, ::Type{<:Duplicated}, x::Duplicated) +function EnzymeRules.forward( + func::Const{typeof(Base.deepcopy)}, ::Type{<:Duplicated}, x::Duplicated +) primal = func.val(x.val) return Duplicated(primal, deepcopy_rtact(primal, x.val, IdDict(), x.dval)) end -function EnzymeRules.forward(func::Const{typeof(Base.deepcopy)}, ::Type{<:BatchDuplicated}, x::BatchDuplicated{T, N}) where {T,N} +function EnzymeRules.forward( + func::Const{typeof(Base.deepcopy)}, ::Type{<:BatchDuplicated}, x::BatchDuplicated{T,N} +) where {T,N} primal = func.val(x.val) - return BatchDuplicated(primal, ntuple(Val(N)) do i - deepcopy_rtact(primal, x.val, IdDict(), x.dval[i]) - end) + return BatchDuplicated( + primal, + ntuple(Val(N)) do i + deepcopy_rtact(primal, x.val, IdDict(), x.dval[i]) + end, + ) end -function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.deepcopy)}, ::Type{RT}, x::Annotation{Ty}) where {RT, Ty} +function EnzymeRules.augmented_primal( + config, func::Const{typeof(Base.deepcopy)}, ::Type{RT}, x::Annotation{Ty} +) where {RT,Ty} primal = if EnzymeRules.needs_primal(config) func.val(x.val) else @@ -188,8 +209,9 @@ function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.deepcopy)} shadow = ntuple(Val(EnzymeRules.width(config))) do _ Base.@_inline_meta - Enzyme.make_zero(source, - #=copy_if_inactive=#Val(!EnzymeRules.needs_primal(config)) + Enzyme.make_zero( + source, + Val(!EnzymeRules.needs_primal(config)), #=copy_if_inactive=# ) end @@ -200,8 +222,9 @@ function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.deepcopy)} return EnzymeRules.AugmentedReturn(primal, shadow, shadow) end - -@inline function accumulate_into(into::RT, seen::IdDict, from::RT)::Tuple{RT,RT} where {RT<:Array} +@inline function accumulate_into( + into::RT, seen::IdDict, from::RT +)::Tuple{RT,RT} where {RT<:Array} if Enzyme.Compiler.guaranteed_const(RT) return (into, from) end @@ -216,9 +239,11 @@ end return seen[into] end -@inline function accumulate_into(into::RT, seen::IdDict, from::RT)::Tuple{RT,RT} where {RT<:AbstractFloat} +@inline function accumulate_into( + into::RT, seen::IdDict, from::RT +)::Tuple{RT,RT} where {RT<:AbstractFloat} if !haskey(seen, into) - seen[into] = (into+from, RT(0)) + seen[into] = (into + from, RT(0)) end return seen[into] end @@ -233,7 +258,9 @@ end return seen[into] end -function EnzymeRules.reverse(config, func::Const{typeof(Base.deepcopy)}, ::Type{RT}, shadow, x::Annotation{Ty}) where {RT, Ty} +function EnzymeRules.reverse( + config, func::Const{typeof(Base.deepcopy)}, ::Type{RT}, shadow, x::Annotation{Ty} +) where {RT,Ty} if EnzymeRules.width(config) == 1 accumulate_into(x.dval, IdDict(), shadow) else @@ -245,43 +272,80 @@ function EnzymeRules.reverse(config, func::Const{typeof(Base.deepcopy)}, ::Type{ return (nothing,) end -@inline function pmap_fwd(idx, tapes::Vector, thunk::ThunkTy, f::F, fargs::Vararg{Annotation, N}) where {ThunkTy, F, N} +@inline function pmap_fwd( + idx, tapes::Vector, thunk::ThunkTy, f::F, fargs::Vararg{Annotation,N} +) where {ThunkTy,F,N} @inbounds tapes[idx] = thunk(f, Const(idx), fargs...)[1] end -@inline function pmap_fwd(idx, tapes::Ptr, thunk::ThunkTy, f::F, fargs::Vararg{Annotation, N}) where {ThunkTy, F, N} - unsafe_store!(tapes, thunk(f, Const(idx), fargs...)[1], idx) +@inline function pmap_fwd( + idx, tapes::Ptr, thunk::ThunkTy, f::F, fargs::Vararg{Annotation,N} +) where {ThunkTy,F,N} + return unsafe_store!(tapes, thunk(f, Const(idx), fargs...)[1], idx) end -function EnzymeRules.augmented_primal(config, func::Const{typeof(Enzyme.pmap)}, ::Type{Const{Nothing}}, body::BodyTy, count, args::Vararg{Annotation, N}) where {BodyTy, N} - - config2 = ReverseModeSplit{false, false, EnzymeRules.width(config), EnzymeRules.overwritten(config)[2:end],InlineABI}() - fwd_thunk, rev_thunk = autodiff_thunk(config2, BodyTy, Const, typeof(count), map(typeof, args)...) +function EnzymeRules.augmented_primal( + config, + func::Const{typeof(Enzyme.pmap)}, + ::Type{Const{Nothing}}, + body::BodyTy, + count, + args::Vararg{Annotation,N}, +) where {BodyTy,N} + config2 = ReverseModeSplit{ + false, + false, + EnzymeRules.width(config), + EnzymeRules.overwritten(config)[2:end], + InlineABI, + }() + fwd_thunk, rev_thunk = autodiff_thunk( + config2, BodyTy, Const, typeof(count), map(typeof, args)... + ) TapeType = EnzymeRules.tape_type(fwd_thunk) tapes = if Enzyme.Compiler.any_jltypes(TapeType) Vector{TapeType}(undef, count.val) else - Base.unsafe_convert(Ptr{TapeType}, Libc.malloc(sizeof(TapeType)*count.val)) + Base.unsafe_convert(Ptr{TapeType}, Libc.malloc(sizeof(TapeType) * count.val)) end Enzyme.pmap(pmap_fwd, count.val, tapes, fwd_thunk, body, args...) return EnzymeRules.AugmentedReturn(nothing, nothing, tapes) end -@inline function pmap_rev(idx, tapes::Vector, thunk::ThunkTy, f::F, fargs::Vararg{Annotation, N}) where {ThunkTy, F, N} - thunk(f, Const(idx), fargs..., @inbounds tapes[idx]) +@inline function pmap_rev( + idx, tapes::Vector, thunk::ThunkTy, f::F, fargs::Vararg{Annotation,N} +) where {ThunkTy,F,N} + return thunk(f, Const(idx), fargs..., @inbounds tapes[idx]) end -@inline function pmap_rev(idx, tapes::Ptr, thunk::ThunkTy, f::F, fargs::Vararg{Annotation, N}) where {ThunkTy, F, N} - thunk(f, Const(idx), fargs..., unsafe_load(tapes, idx)) +@inline function pmap_rev( + idx, tapes::Ptr, thunk::ThunkTy, f::F, fargs::Vararg{Annotation,N} +) where {ThunkTy,F,N} + return thunk(f, Const(idx), fargs..., unsafe_load(tapes, idx)) end -function EnzymeRules.reverse(config, func::Const{typeof(Enzyme.pmap)}, ::Type{Const{Nothing}}, tapes, body::BodyTy, count, args::Vararg{Annotation, N}) where {BodyTy, N} - - config2 = ReverseModeSplit{false, false, EnzymeRules.width(config), EnzymeRules.overwritten(config)[2:end],InlineABI}() - fwd_thunk, rev_thunk = autodiff_thunk(config2, BodyTy, Const, typeof(count), map(typeof, args)...) +function EnzymeRules.reverse( + config, + func::Const{typeof(Enzyme.pmap)}, + ::Type{Const{Nothing}}, + tapes, + body::BodyTy, + count, + args::Vararg{Annotation,N}, +) where {BodyTy,N} + config2 = ReverseModeSplit{ + false, + false, + EnzymeRules.width(config), + EnzymeRules.overwritten(config)[2:end], + InlineABI, + }() + fwd_thunk, rev_thunk = autodiff_thunk( + config2, BodyTy, Const, typeof(count), map(typeof, args)... + ) Enzyme.pmap(pmap_rev, count.val, tapes, rev_thunk, body, args...) @@ -291,16 +355,14 @@ function EnzymeRules.reverse(config, func::Const{typeof(Enzyme.pmap)}, ::Type{Co Libc.free(tapes) end - return ntuple(Val(2+length(args))) do _ + return ntuple(Val(2 + length(args))) do _ Base.@_inline_meta nothing end end - - # From LinearAlgebra ~/.julia/juliaup/julia-1.10.0-beta3+0.x64.apple.darwin14/share/julia/stdlib/v1.10/LinearAlgebra/src/generic.jl:1110 -@inline function compute_lu_cache(cache_A::AT, b::BT) where {AT, BT} +@inline function compute_lu_cache(cache_A::AT, b::BT) where {AT,BT} LinearAlgebra.require_one_based_indexing(cache_A, b) m, n = size(cache_A) @@ -323,8 +385,9 @@ end # y=inv(A) B # dA −= z y^T # dB += z, where z = inv(A^T) dy -function EnzymeRules.augmented_primal(config, func::Const{typeof(\)}, ::Type{RT}, A::Annotation{AT}, b::Annotation{BT}) where {RT, AT <: Array, BT <: Array} - +function EnzymeRules.augmented_primal( + config, func::Const{typeof(\)}, ::Type{RT}, A::Annotation{AT}, b::Annotation{BT} +) where {RT,AT<:Array,BT<:Array} cache_A = if EnzymeRules.overwritten(config)[2] copy(A.val) else @@ -362,33 +425,42 @@ function EnzymeRules.augmented_primal(config, func::Const{typeof(\)}, ::Type{RT} nothing end -@static if VERSION < v"1.8.0" - UT = Union{ - LinearAlgebra.Diagonal{eltype(AT), BT}, - LinearAlgebra.LowerTriangular{eltype(AT), AT}, - LinearAlgebra.UpperTriangular{eltype(AT), AT}, - LinearAlgebra.LU{eltype(AT), AT}, - LinearAlgebra.QRCompactWY{eltype(AT), AT} - } -else - UT = Union{ - LinearAlgebra.Diagonal{eltype(AT), BT}, - LinearAlgebra.LowerTriangular{eltype(AT), AT}, - LinearAlgebra.UpperTriangular{eltype(AT), AT}, - LinearAlgebra.LU{eltype(AT), AT, Vector{Int}}, - LinearAlgebra.QRPivoted{eltype(AT), AT, BT, Vector{Int}} - } -end - - cache = NamedTuple{(Symbol("1"),Symbol("2"), Symbol("3"), Symbol("4")), Tuple{typeof(res), typeof(dres), UT, typeof(cache_b)}}( - (cache_res, dres, cache_A, cache_b) + @static if VERSION < v"1.8.0" + UT = Union{ + LinearAlgebra.Diagonal{eltype(AT),BT}, + LinearAlgebra.LowerTriangular{eltype(AT),AT}, + LinearAlgebra.UpperTriangular{eltype(AT),AT}, + LinearAlgebra.LU{eltype(AT),AT}, + LinearAlgebra.QRCompactWY{eltype(AT),AT}, + } + else + UT = Union{ + LinearAlgebra.Diagonal{eltype(AT),BT}, + LinearAlgebra.LowerTriangular{eltype(AT),AT}, + LinearAlgebra.UpperTriangular{eltype(AT),AT}, + LinearAlgebra.LU{eltype(AT),AT,Vector{Int}}, + LinearAlgebra.QRPivoted{eltype(AT),AT,BT,Vector{Int}}, + } + end + + cache = NamedTuple{ + (Symbol("1"), Symbol("2"), Symbol("3"), Symbol("4")), + Tuple{typeof(res),typeof(dres),UT,typeof(cache_b)}, + }((cache_res, dres, cache_A, cache_b)) + + return EnzymeRules.AugmentedReturn{typeof(retres),typeof(dres),typeof(cache)}( + retres, dres, cache ) - - return EnzymeRules.AugmentedReturn{typeof(retres), typeof(dres), typeof(cache)}(retres, dres, cache) end -function EnzymeRules.reverse(config, func::Const{typeof(\)}, ::Type{RT}, cache, A::Annotation{<:Array}, b::Annotation{<:Array}) where RT - +function EnzymeRules.reverse( + config, + func::Const{typeof(\)}, + ::Type{RT}, + cache, + A::Annotation{<:Array}, + b::Annotation{<:Array}, +) where {RT} y, dys, cache_A, cache_b = cache if !EnzymeRules.overwritten(config)[3] @@ -444,14 +516,11 @@ function EnzymeRules.reverse(config, func::Const{typeof(\)}, ::Type{RT}, cache, dy .= eltype(dy)(0) end - return (nothing,nothing) + return (nothing, nothing) end const EnzymeTriangulars = Union{ - UpperTriangular, - LowerTriangular, - UnitUpperTriangular, - UnitLowerTriangular + UpperTriangular,LowerTriangular,UnitUpperTriangular,UnitLowerTriangular } function EnzymeRules.augmented_primal( @@ -460,8 +529,8 @@ function EnzymeRules.augmented_primal( ::Type{RT}, Y::Annotation{YT}, A::Annotation{AT}, - B::Annotation{BT} -) where {RT, YT <: Array, AT <: EnzymeTriangulars, BT <: Array} + B::Annotation{BT}, +) where {RT,YT<:Array,AT<:EnzymeTriangulars,BT<:Array} cache_Y = EnzymeRules.overwritten(config)[1] ? copy(Y.val) : Y.val cache_A = EnzymeRules.overwritten(config)[2] ? copy(A.val) : A.val cache_A = compute_lu_cache(cache_A, B.val) @@ -469,8 +538,9 @@ function EnzymeRules.augmented_primal( primal = EnzymeRules.needs_primal(config) ? Y.val : nothing shadow = EnzymeRules.needs_shadow(config) ? Y.dval : nothing func.val(Y.val, A.val, B.val) - return EnzymeRules.AugmentedReturn{typeof(primal), typeof(shadow), Any}( - primal, shadow, (cache_Y, cache_A, cache_B)) + return EnzymeRules.AugmentedReturn{typeof(primal),typeof(shadow),Any}( + primal, shadow, (cache_Y, cache_A, cache_B) + ) end function EnzymeRules.reverse( @@ -480,8 +550,8 @@ function EnzymeRules.reverse( cache, Y::Annotation{YT}, A::Annotation{AT}, - B::Annotation{BT} -) where {YT <: Array, RT, AT <: EnzymeTriangulars, BT <: Array} + B::Annotation{BT}, +) where {YT<:Array,RT,AT<:EnzymeTriangulars,BT<:Array} if !isa(Y, Const) (cache_Yout, cache_A, cache_B) = cache for b in 1:EnzymeRules.width(config) @@ -507,62 +577,75 @@ _zero_unused_elements!(X, ::UnitUpperTriangular) = triu!(X, 1) _zero_unused_elements!(X, ::UnitLowerTriangular) = tril!(X, -1) @static if VERSION >= v"1.7-" -# Force a rule around hvcat_fill as it is type unstable if the tuple is not of the same type (e.g., int, float, int, float) -function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.hvcat_fill!)}, ::Type{RT}, out::Annotation{AT}, inp::Annotation{BT}) where {RT, AT <: Array, BT <: Tuple} - primal = if EnzymeRules.needs_primal(config) - out.val - else - nothing - end - shadow = if EnzymeRules.needs_shadow(config) - out.dval - else - nothing - end - func.val(out.val, inp.val) - return EnzymeRules.AugmentedReturn(primal, shadow, nothing) -end - -function EnzymeRules.reverse(config, func::Const{typeof(Base.hvcat_fill!)}, ::Type{RT}, _, out::Annotation{AT}, inp::Annotation{BT}) where {RT, AT <: Array, BT <: Tuple} - nr, nc = size(out.val,1), size(out.val,2) - for b in 1:EnzymeRules.width(config) - da = if EnzymeRules.width(config) == 1 + # Force a rule around hvcat_fill as it is type unstable if the tuple is not of the same type (e.g., int, float, int, float) + function EnzymeRules.augmented_primal( + config, + func::Const{typeof(Base.hvcat_fill!)}, + ::Type{RT}, + out::Annotation{AT}, + inp::Annotation{BT}, + ) where {RT,AT<:Array,BT<:Tuple} + primal = if EnzymeRules.needs_primal(config) + out.val + else + nothing + end + shadow = if EnzymeRules.needs_shadow(config) out.dval else - out.dval[b] + nothing end - i = 1 - j = 1 - if (typeof(inp) <: Active) - dinp = ntuple(Val(length(inp.val))) do k - Base.@_inline_meta - res = da[i, j] - da[i, j] = 0 - j += 1 - if j == nc+1 - i += 1 - j = 1 - end - T = BT.parameters[k] - if T <: AbstractFloat - T(res) - else - T(0) + func.val(out.val, inp.val) + return EnzymeRules.AugmentedReturn(primal, shadow, nothing) + end + + function EnzymeRules.reverse( + config, + func::Const{typeof(Base.hvcat_fill!)}, + ::Type{RT}, + _, + out::Annotation{AT}, + inp::Annotation{BT}, + ) where {RT,AT<:Array,BT<:Tuple} + nr, nc = size(out.val, 1), size(out.val, 2) + for b in 1:EnzymeRules.width(config) + da = if EnzymeRules.width(config) == 1 + out.dval + else + out.dval[b] + end + i = 1 + j = 1 + if (typeof(inp) <: Active) + dinp = ntuple(Val(length(inp.val))) do k + Base.@_inline_meta + res = da[i, j] + da[i, j] = 0 + j += 1 + if j == nc + 1 + i += 1 + j = 1 + end + T = BT.parameters[k] + if T <: AbstractFloat + T(res) + else + T(0) + end end + return (nothing, dinp)::Tuple{Nothing,BT} end - return (nothing, dinp)::Tuple{Nothing, BT} end + return (nothing, nothing) end - return (nothing, nothing) -end end function EnzymeRules.forward( - ::Const{typeof(sort!)}, - RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, - xs::Duplicated{T}; - kwargs... - ) where {T <: AbstractArray{<:AbstractFloat}} + ::Const{typeof(sort!)}, + RT::Type{<:Union{Const,DuplicatedNoNeed,Duplicated}}, + xs::Duplicated{T}; + kwargs..., +) where {T<:AbstractArray{<:AbstractFloat}} inds = sortperm(xs.val; kwargs...) xs.val .= xs.val[inds] xs.dval .= xs.dval[inds] @@ -576,11 +659,11 @@ function EnzymeRules.forward( end function EnzymeRules.forward( - ::Const{typeof(sort!)}, - RT::Type{<:Union{Const, BatchDuplicatedNoNeed, BatchDuplicated}}, - xs::BatchDuplicated{T, N}; - kwargs... - ) where {T <: AbstractArray{<:AbstractFloat}, N} + ::Const{typeof(sort!)}, + RT::Type{<:Union{Const,BatchDuplicatedNoNeed,BatchDuplicated}}, + xs::BatchDuplicated{T,N}; + kwargs..., +) where {T<:AbstractArray{<:AbstractFloat},N} inds = sortperm(xs.val; kwargs...) xs.val .= xs.val[inds] for i in 1:N @@ -595,14 +678,13 @@ function EnzymeRules.forward( end end - function EnzymeRules.augmented_primal( - config::EnzymeRules.ConfigWidth{1}, - ::Const{typeof(sort!)}, - RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, - xs::Duplicated{T}; - kwargs... - ) where {T <: AbstractArray{<:AbstractFloat}} + config::EnzymeRules.ConfigWidth{1}, + ::Const{typeof(sort!)}, + RT::Type{<:Union{Const,DuplicatedNoNeed,Duplicated}}, + xs::Duplicated{T}; + kwargs..., +) where {T<:AbstractArray{<:AbstractFloat}} inds = sortperm(xs.val; kwargs...) xs.val .= xs.val[inds] xs.dval .= xs.dval[inds] @@ -620,13 +702,13 @@ function EnzymeRules.augmented_primal( end function EnzymeRules.reverse( - config::EnzymeRules.ConfigWidth{1}, - ::Const{typeof(sort!)}, - RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, - tape, - xs::Duplicated{T}; - kwargs..., - ) where {T <: AbstractArray{<:AbstractFloat}} + config::EnzymeRules.ConfigWidth{1}, + ::Const{typeof(sort!)}, + RT::Type{<:Union{Const,DuplicatedNoNeed,Duplicated}}, + tape, + xs::Duplicated{T}; + kwargs..., +) where {T<:AbstractArray{<:AbstractFloat}} inds = tape back_inds = sortperm(inds) xs.dval .= xs.dval[back_inds] @@ -694,11 +776,7 @@ end # B(out) = inv(A) B(in) # dB(out) = inv(A) [ dB(in) - dA B(out) ] function EnzymeRules.forward( - func::Const{typeof(ldiv!)}, - RT::Type, - fact::Annotation{<:Cholesky}, - B; - kwargs... + func::Const{typeof(ldiv!)}, RT::Type, fact::Annotation{<:Cholesky}, B; kwargs... ) if isa(B, Const) @assert (RT <: Const) @@ -708,11 +786,15 @@ function EnzymeRules.forward( @assert !isa(B, Const) - retval = if !isa(fact, Const) || (RT <: Const) || (RT <: Duplicated) || (RT <: BatchDuplicated) - func.val(fact.val, B.val; kwargs...) - else - nothing - end + retval = + if !isa(fact, Const) || + (RT <: Const) || + (RT <: Duplicated) || + (RT <: BatchDuplicated) + func.val(fact.val, B.val; kwargs...) + else + nothing + end dretvals = ntuple(Val(N)) do b Base.@_inline_meta @@ -724,13 +806,12 @@ function EnzymeRules.forward( end if !isa(fact, Const) - dfact = if N == 1 fact.dval else fact.dval[b] end - + tmp = dfact.U * retval mul!(dB, dfact.L, tmp, -1, 1) end @@ -757,8 +838,8 @@ function EnzymeRules.augmented_primal( func::Const{typeof(cholesky)}, RT::Type, A::Annotation{<:Union{Matrix,LinearAlgebra.RealHermSym{<:Real,<:Matrix}}}; - kwargs...) - + kwargs..., +) fact = if EnzymeRules.needs_primal(config) || !(RT <: Const) cholesky(A.val; kwargs...) else @@ -791,7 +872,8 @@ function EnzymeRules.reverse( RT::Type, cache, A::Annotation{<:Union{Matrix,LinearAlgebra.RealHermSym{<:Real,<:Matrix}}}; - kwargs...) + kwargs..., +) if !(RT <: Const) && !isa(A, Const) fact, dfact = cache dAs = EnzymeRules.width(config) == 1 ? (A.dval,) : A.dval @@ -845,13 +927,14 @@ function _realifydiag!(A) end function EnzymeRules.augmented_primal( - config, - func::Const{typeof(ldiv!)}, - RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated}}, - - A::Annotation{<:Cholesky}, - B::Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated}; - kwargs... + config, + func::Const{typeof(ldiv!)}, + RT::Type{ + <:Union{Const,DuplicatedNoNeed,Duplicated,BatchDuplicatedNoNeed,BatchDuplicated} + }, + A::Annotation{<:Cholesky}, + B::Union{Const,DuplicatedNoNeed,Duplicated,BatchDuplicatedNoNeed,BatchDuplicated}; + kwargs..., ) cache_B = if !isa(A, Const) && !isa(B, Const) EnzymeRules.overwritten(config)[3] ? copy(B.val) : B.val @@ -877,10 +960,10 @@ function EnzymeRules.reverse( dret, cache, A::Annotation{<:Cholesky}, - B::Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated}; - kwargs... + B::Union{Const,DuplicatedNoNeed,Duplicated,BatchDuplicatedNoNeed,BatchDuplicated}; + kwargs..., ) - if !isa(B, Const) + if !isa(B, Const) (cache_A, cache_B) = cache Y = B.val U = cache_A.U diff --git a/test/internal_rules.jl b/test/internal_rules.jl index 9cf2de03fd..a2e7de178c 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -17,7 +17,7 @@ function sorterrfn(t, x) function lt(a, b) return a.a < b.a end - return first(sortperm(t, lt=lt)) * x + return first(sortperm(t; lt=lt)) * x end @testset "Sort rules" begin @@ -28,10 +28,12 @@ end end @test autodiff(Forward, f1, Duplicated(2.0, 1.0))[1] == 1 - @test autodiff(Forward, f1, BatchDuplicated(2.0, (1.0, 2.0)))[1] == (var"1"=1.0, var"2"=2.0) + @test autodiff(Forward, f1, BatchDuplicated(2.0, (1.0, 2.0)))[1] == + (var"1"=1.0, var"2"=2.0) @test autodiff(Reverse, f1, Active, Active(2.0))[1][1] == 1 @test autodiff(Forward, f1, Duplicated(4.0, 1.0))[1] == 0 - @test autodiff(Forward, f1, BatchDuplicated(4.0, (1.0, 2.0)))[1] == (var"1"=0.0, var"2"=0.0) + @test autodiff(Forward, f1, BatchDuplicated(4.0, (1.0, 2.0)))[1] == + (var"1"=0.0, var"2"=0.0) @test autodiff(Reverse, f1, Active, Active(4.0))[1][1] == 0 function f2(x) @@ -41,10 +43,13 @@ end end @test autodiff(Forward, f2, Duplicated(2.0, 1.0))[1] == -3 - @test autodiff(Forward, f2, BatchDuplicated(2.0, (1.0, 2.0)))[1] == (var"1"=-3.0, var"2"=-6.0) + @test autodiff(Forward, f2, BatchDuplicated(2.0, (1.0, 2.0)))[1] == + (var"1"=-3.0, var"2"=-6.0) @test autodiff(Reverse, f2, Active, Active(2.0))[1][1] == -3 - dd = Duplicated([TPair(1, 2), TPair(2, 3), TPair(0, 1)], [TPair(0, 0), TPair(0, 0), TPair(0, 0)]) + dd = Duplicated( + [TPair(1, 2), TPair(2, 3), TPair(0, 1)], [TPair(0, 0), TPair(0, 0), TPair(0, 0)] + ) res = Enzyme.autodiff(Reverse, sorterrfn, dd, Active(1.0)) @test res[1][2] ≈ 3 @@ -62,7 +67,13 @@ end b = Float64[11, 13] db = zero(b) - forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Duplicated{typeof(A)}, Duplicated{typeof(b)}) + forward, pullback = Enzyme.autodiff_thunk( + ReverseSplitNoPrimal, + Const{typeof(\)}, + Duplicated, + Duplicated{typeof(A)}, + Duplicated{typeof(b)}, + ) tape, primal, shadow = forward(Const(\), Duplicated(A, dA), Duplicated(b, db)) @@ -79,7 +90,13 @@ end db = zero(b) - forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Const{typeof(A)}, Duplicated{typeof(b)}) + forward, pullback = Enzyme.autodiff_thunk( + ReverseSplitNoPrimal, + Const{typeof(\)}, + Duplicated, + Const{typeof(A)}, + Duplicated{typeof(b)}, + ) tape, primal, shadow = forward(Const(\), Const(A), Duplicated(b, db)) @@ -95,7 +112,13 @@ end dA = zero(A) - forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Duplicated{typeof(A)}, Const{typeof(b)}) + forward, pullback = Enzyme.autodiff_thunk( + ReverseSplitNoPrimal, + Const{typeof(\)}, + Duplicated, + Duplicated{typeof(A)}, + Const{typeof(b)}, + ) tape, primal, shadow = forward(Const(\), Duplicated(A, dA), Const(b)) @@ -111,88 +134,98 @@ end end @static if VERSION > v"1.8" -@testset "Cholesky" begin - function cholesky_testfunction_symmetric(A, b, x1, x2) - C1 = cholesky(A * A') # test factorization without wrapper - C2 = cholesky(Symmetric(A * A')) # test factorization with wrapper - x1 .= C1 \ b # test linear solve with factorization object without wrapper - x2 .= C2 \ b # test linear solve with factorization object with wrapper - return sum(abs2, C1.L * C1.U) + sum(abs2, C2.L * C2.U) # test factorization itself - end - function cholesky_testfunction_hermitian(A, b, x1, x2) - C1 = cholesky(A * adjoint(A)) # test factorization without wrapper - C2 = cholesky(Hermitian(A * adjoint(A))) # test factorization with wrapper - x1 .= C1 \ b # test linear solve with factorization object without wrapper - x2 .= C2 \ b # test linear solve with factorization object with wrapper - return sum(abs2, C1.L * C1.U) + sum(abs2, C2.L * C2.U) # test factorization itself - end - @testset for (TE, testfunction) in ( - Float64 => cholesky_testfunction_symmetric, - Float64 => cholesky_testfunction_hermitian - ) - @testset for TA in (Const, Duplicated), - Tb in (Const, Duplicated), - Tx1 in (Const, Duplicated), - Tx2 in (Const, Duplicated) - A = rand(TE, 5, 5) - b = rand(TE, 5) - x1 = rand(TE, 5) - x2 = rand(TE, 5) - # ishermitian(A * adjoint(A)) || continue - @testset for Tret in (Const, Duplicated) - are_activities_compatible(Tret, TA, Tb, Tx1, Tx2) || continue - test_forward(testfunction, Tret, (A, TA), (b, Tb), (x1, Tx1), (x2, Tx2)) - end - @testset for Tret in (Const, Active) - are_activities_compatible(Tret, TA, Tb, Tx1, Tx2) || continue - test_reverse(testfunction, Tret, (A, TA), (b, Tb), (x1, Tx1), (x2, Tx2)) + @testset "Cholesky" begin + function cholesky_testfunction_symmetric(A, b, x1, x2) + C1 = cholesky(A * A') # test factorization without wrapper + C2 = cholesky(Symmetric(A * A')) # test factorization with wrapper + x1 .= C1 \ b # test linear solve with factorization object without wrapper + x2 .= C2 \ b # test linear solve with factorization object with wrapper + return sum(abs2, C1.L * C1.U) + sum(abs2, C2.L * C2.U) # test factorization itself + end + function cholesky_testfunction_hermitian(A, b, x1, x2) + C1 = cholesky(A * adjoint(A)) # test factorization without wrapper + C2 = cholesky(Hermitian(A * adjoint(A))) # test factorization with wrapper + x1 .= C1 \ b # test linear solve with factorization object without wrapper + x2 .= C2 \ b # test linear solve with factorization object with wrapper + return sum(abs2, C1.L * C1.U) + sum(abs2, C2.L * C2.U) # test factorization itself + end + @testset for (TE, testfunction) in ( + Float64 => cholesky_testfunction_symmetric, + Float64 => cholesky_testfunction_hermitian, + ) + @testset for TA in (Const, Duplicated), + Tb in (Const, Duplicated), + Tx1 in (Const, Duplicated), + Tx2 in (Const, Duplicated) + + A = rand(TE, 5, 5) + b = rand(TE, 5) + x1 = rand(TE, 5) + x2 = rand(TE, 5) + # ishermitian(A * adjoint(A)) || continue + @testset for Tret in (Const, Duplicated) + are_activities_compatible(Tret, TA, Tb, Tx1, Tx2) || continue + test_forward(testfunction, Tret, (A, TA), (b, Tb), (x1, Tx1), (x2, Tx2)) + end + @testset for Tret in (Const, Active) + are_activities_compatible(Tret, TA, Tb, Tx1, Tx2) || continue + test_reverse(testfunction, Tret, (A, TA), (b, Tb), (x1, Tx1), (x2, Tx2)) + end end end end -end -@testset "Linear solve for triangular matrices" begin - @testset for T in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular), - TE in (Float64, ComplexF64), sizeB in ((3,), (3, 3)) - n = sizeB[1] - M = rand(TE, n, n) - B = rand(TE, sizeB...) - Y = zeros(TE, sizeB...) - A = T(M) - @testset "test through constructor" begin - _A = T(A) - function f!(Y, A, B, ::T) where T - ldiv!(Y, T(A), B) - return nothing + @testset "Linear solve for triangular matrices" begin + @testset for T in ( + UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular + ), + TE in (Float64, ComplexF64), + sizeB in ((3,), (3, 3)) + + n = sizeB[1] + M = rand(TE, n, n) + B = rand(TE, sizeB...) + Y = zeros(TE, sizeB...) + A = T(M) + @testset "test through constructor" begin + _A = T(A) + function f!(Y, A, B, ::T) where {T} + ldiv!(Y, T(A), B) + return nothing + end + for TY in (Const, Duplicated, BatchDuplicated), + TM in (Const, Duplicated, BatchDuplicated), + TB in (Const, Duplicated, BatchDuplicated) + + are_activities_compatible(Const, TY, TM, TB) || continue + test_reverse(f!, Const, (Y, TY), (M, TM), (B, TB), (_A, Const)) + end end - for TY in (Const, Duplicated, BatchDuplicated), - TM in (Const, Duplicated, BatchDuplicated), - TB in (Const, Duplicated, BatchDuplicated) - are_activities_compatible(Const, TY, TM, TB) || continue - test_reverse(f!, Const, (Y, TY), (M, TM), (B, TB), (_A, Const)) + @testset "test through `Adjoint` wrapper (regression test for #1306)" begin + # Test that we get the same derivative for `M` as for the adjoint of its + # (materialized) transpose. It's the same matrix, but represented differently + function f!(Y, A, B) + ldiv!(Y, A, B) + return nothing + end + A1 = T(M) + A2 = T(conj(permutedims(M))') + dA1 = make_zero(A1) + dA2 = make_zero(A2) + dB1 = make_zero(B) + dB2 = make_zero(B) + dY1 = rand(TE, sizeB...) + dY2 = copy(dY1) + autodiff( + Reverse, f!, Duplicated(Y, dY1), Duplicated(A1, dA1), Duplicated(B, dB1) + ) + autodiff( + Reverse, f!, Duplicated(Y, dY2), Duplicated(A2, dA2), Duplicated(B, dB2) + ) + @test dA1.data ≈ dA2.data + @test dB1 ≈ dB2 end end - @testset "test through `Adjoint` wrapper (regression test for #1306)" begin - # Test that we get the same derivative for `M` as for the adjoint of its - # (materialized) transpose. It's the same matrix, but represented differently - function f!(Y, A, B) - ldiv!(Y, A, B) - return nothing - end - A1 = T(M) - A2 = T(conj(permutedims(M))') - dA1 = make_zero(A1) - dA2 = make_zero(A2) - dB1 = make_zero(B) - dB2 = make_zero(B) - dY1 = rand(TE, sizeB...) - dY2 = copy(dY1) - autodiff(Reverse, f!, Duplicated(Y, dY1), Duplicated(A1, dA1), Duplicated(B, dB1)) - autodiff(Reverse, f!, Duplicated(Y, dY2), Duplicated(A2, dA2), Duplicated(B, dB2)) - @test dA1.data ≈ dA2.data - @test dB1 ≈ dB2 - end end end -end end # InternalRules