diff --git a/src/make_zero.jl b/src/make_zero.jl index 3634f6e2cd..dbb320706b 100644 --- a/src/make_zero.jl +++ b/src/make_zero.jl @@ -1,46 +1,3 @@ -const _RealOrComplexFloat = Union{AbstractFloat,Complex{<:AbstractFloat}} - -@inline function EnzymeCore.make_zero(prev::FT) where {FT<:_RealOrComplexFloat} - return Base.zero(prev)::FT -end - -@inline function EnzymeCore.make_zero( - ::Type{FT}, - @nospecialize(seen::IdDict), - prev::FT, - @nospecialize(_::Val{copy_if_inactive}=Val(false)), -) where {FT<:_RealOrComplexFloat,copy_if_inactive} - return EnzymeCore.make_zero(prev)::FT -end - -@inline function EnzymeCore.make_zero(prev::Array{FT,N}) where {FT<:_RealOrComplexFloat,N} - # convert because Base.zero may return different eltype when FT is not concrete - return convert(Array{FT,N}, Base.zero(prev))::Array{FT,N} -end - -@inline function EnzymeCore.make_zero( - ::Type{Array{FT,N}}, - seen::IdDict, - prev::Array{FT,N}, - @nospecialize(_::Val{copy_if_inactive}=Val(false)), -) where {FT<:_RealOrComplexFloat,N,copy_if_inactive} - if haskey(seen, prev) - return seen[prev]::Array{FT,N} - end - newa = EnzymeCore.make_zero(prev) - seen[prev] = newa - return newa::Array{FT,N} -end - -@inline function EnzymeCore.make_zero( - ::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false) -) where {RT,copy_if_inactive} - isleaftype(_) = false - isleaftype(::Type{<:Union{_RealOrComplexFloat,Array{<:_RealOrComplexFloat}}}) = true - f(p) = EnzymeCore.make_zero(Core.Typeof(p), seen, p, Val(copy_if_inactive)) - return recursive_map(RT, f, seen, (prev,), Val(copy_if_inactive), isleaftype)::RT -end - recursive_map(f::F, xs::T...) where {F,T} = recursive_map(T, f, IdDict(), xs)::T @inline function recursive_map( @@ -59,24 +16,6 @@ recursive_map(f::F, xs::T...) where {F,T} = recursive_map(T, f, IdDict(), xs)::T return _recursive_map(RT, f, seen, xs, Val(copy_if_inactive), isleaftype)::RT end -@inline function _recursive_map( - ::Type{RT}, f::F, seen::IdDict, xs::NTuple{N,RT}, args... -) where {RT<:Array,F,N} - if haskey(seen, xs) - return seen[xs]::RT - end - y = RT(undef, size(first(xs))) - seen[xs] = y - for I in eachindex(xs...) - if all(x -> isassigned(x, I), xs) - xIs = ntuple(j -> xs[j][I], N) - ST = Core.Typeof(first(xIs)) - @inbounds y[I] = recursive_map(ST, f, seen, xIs, args...) - end - end - return y -end - @inline function _recursive_map( ::Type{RT}, f::F, seen::IdDict, xs::NTuple{N,RT}, args... ) where {RT,F,N} @@ -127,6 +66,103 @@ end return y end +@inline function _recursive_map( + ::Type{RT}, f::F, seen::IdDict, xs::NTuple{N,RT}, args... +) where {RT<:Array,F,N} + if haskey(seen, xs) + return seen[xs]::RT + end + y = RT(undef, size(first(xs))) + seen[xs] = y + for I in eachindex(xs...) + if all(x -> isassigned(x, I), xs) + xIs = ntuple(j -> xs[j][I], N) + ST = Core.Typeof(first(xIs)) + @inbounds y[I] = recursive_map(ST, f, seen, xIs, args...) + end + end + return y +end + +@inline function recursive_map!(f::F, y::T, xs::T...) where {F,T} + return recursive_map!(f, y, Base.IdSet(), xs)::Nothing +end + +@inline function recursive_map!( + f::F, y::T, seen::Base.IdSet, xs::NTuple{N,T}, isleaftype::L=Returns(false) +) where {F,T,N,L} + if guaranteed_const_nongen(T, nothing) + return nothing + elseif isleaftype(T) + # If there exist T such that isleaftype(T) and T has mutable content that is not + # guaranteed const, including mutables nested inside immutables like Tuple{Vector}, + # then f must have a corresponding mutating method: + f(y, xs...) + return nothing + end + return _recursive_map!(f, y, seen, xs, isleaftype)::Nothing +end + +@inline function _recursive_map!( + f::F, y::T, seen, xs::NTuple{N,T}, isleaftype +) where {F,T,N} + if y in seen + return nothing + end + @assert !Base.isabstracttype(T) + @assert Base.isconcretetype(T) + nf = fieldcount(T) + if nf == 0 + return nothing + end + push!(seen, y) + for i = 1:nf + if isdefined(y, i) && all(x -> isdefined(x, i), xs) + yi = getfield(y, i) + xis = ntuple(j -> getfield(xs[j], i), N) + SBT = Core.Typeof(yi) + activitystate = active_reg_inner(SBT, (), nothing, Val(false)) + if activitystate == AnyState + continue + elseif activitystate == DupState + recursive_map!(f, yi, seen, xis, isleaftype) + else + yi = recursive_map_immutable!(f, yi, seen, xis, isleaftype) + if Base.isconst(T, i) + ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), y, i - 1, yi) + else + setfield!(y, i, yi) + end + end + end + end + return nothing +end + +@inline function _recursive_map!( + f::F, y::Array{T,M}, seen, xs::NTuple{N,Array{T,M}}, isleaftype +) where {F,T,M,N} + if y in seen + return nothing + end + push!(seen, y) + for I in eachindex(y, xs...) + if isassigned(y, I) && all(x -> isassigned(x, I), xs) + yvalue = y[I] + xvalues = ntuple(j -> xs[j][I], N) + SBT = Core.Typeof(yvalue) + if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# + @inbounds y[I] = recursive_map_immutable!( + f, yvalue, seen, xvalues, isleaftype + ) + else + recursive_map!(f, yvalue, seen, xvalues, isleaftype) + end + end + end + return nothing +end + @inline function recursive_map_immutable!(f::F, y::T, xs::T...) where {F,T} return recursive_map_immutable!(f, y, Base.IdSet(), xs)::T end @@ -185,20 +221,47 @@ end return newy end -@inline function EnzymeCore.make_zero!(prev::Array{T,N}) where {T<:_RealOrComplexFloat,N} - fill!(prev, zero(T)) - return nothing +const _RealOrComplexFloat = Union{AbstractFloat,Complex{<:AbstractFloat}} + +@inline function EnzymeCore.make_zero( + ::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false) +) where {RT,copy_if_inactive} + isleaftype(_) = false + isleaftype(::Type{<:Union{_RealOrComplexFloat,Array{<:_RealOrComplexFloat}}}) = true + f(p) = EnzymeCore.make_zero(Core.Typeof(p), seen, p, Val(copy_if_inactive)) + return recursive_map(RT, f, seen, (prev,), Val(copy_if_inactive), isleaftype)::RT end -@inline function EnzymeCore.make_zero!( - prev::Array{T,N}, seen::Base.IdSet, -) where {T<:_RealOrComplexFloat,N} - if prev in seen - return nothing +@inline function EnzymeCore.make_zero( + ::Type{FT}, + @nospecialize(seen::IdDict), + prev::FT, + @nospecialize(_::Val{copy_if_inactive}=Val(false)), +) where {FT<:_RealOrComplexFloat,copy_if_inactive} + return EnzymeCore.make_zero(prev)::FT +end + +@inline function EnzymeCore.make_zero(prev::FT) where {FT<:_RealOrComplexFloat} + return Base.zero(prev)::FT +end + +@inline function EnzymeCore.make_zero( + ::Type{Array{FT,N}}, + seen::IdDict, + prev::Array{FT,N}, + @nospecialize(_::Val{copy_if_inactive}=Val(false)), +) where {FT<:_RealOrComplexFloat,N,copy_if_inactive} + if haskey(seen, prev) + return seen[prev]::Array{FT,N} end - push!(seen, prev) - EnzymeCore.make_zero!(prev) - return nothing + newa = EnzymeCore.make_zero(prev) + seen[prev] = newa + return newa::Array{FT,N} +end + +@inline function EnzymeCore.make_zero(prev::Array{FT,N}) where {FT<:_RealOrComplexFloat,N} + # convert because Base.zero may return different eltype when FT is not concrete + return convert(Array{FT,N}, Base.zero(prev))::Array{FT,N} end @inline function EnzymeCore.make_zero!(prev, seen::Base.IdSet=Base.IdSet()) @@ -213,81 +276,18 @@ end return recursive_map!(f, prev, seen, (prev,), isleaftype)::Nothing end -@inline function recursive_map!(f::F, y::T, xs::T...) where {F,T} - return recursive_map!(f, y, Base.IdSet(), xs)::Nothing -end - -@inline function recursive_map!( - f::F, y::T, seen::Base.IdSet, xs::NTuple{N,T}, isleaftype::L=Returns(false) -) where {F,T,N,L} - if guaranteed_const_nongen(T, nothing) - return nothing - elseif isleaftype(T) - # If there exist T such that isleaftype(T) and T has mutable content that is not - # guaranteed const, including mutables nested inside immutables like Tuple{Vector}, - # then f must have a corresponding mutating method: - f(y, xs...) - return nothing - end - return _recursive_map!(f, y, seen, xs, isleaftype)::Nothing -end - -@inline function _recursive_map!( - f::F, y::Array{T,M}, seen, xs::NTuple{N,Array{T,M}}, isleaftype -) where {F,T,M,N} - if y in seen +@inline function EnzymeCore.make_zero!( + prev::Array{T,N}, seen::Base.IdSet, +) where {T<:_RealOrComplexFloat,N} + if prev in seen return nothing end - push!(seen, y) - for I in eachindex(y, xs...) - if isassigned(y, I) && all(x -> isassigned(x, I), xs) - yvalue = y[I] - xvalues = ntuple(j -> xs[j][I], N) - SBT = Core.Typeof(yvalue) - if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# - @inbounds y[I] = recursive_map_immutable!( - f, yvalue, seen, xvalues, isleaftype - ) - else - recursive_map!(f, yvalue, seen, xvalues, isleaftype) - end - end - end + push!(seen, prev) + EnzymeCore.make_zero!(prev) return nothing end -@inline function _recursive_map!( - f::F, y::T, seen, xs::NTuple{N,T}, isleaftype -) where {F,T,N} - if y in seen - return nothing - end - @assert !Base.isabstracttype(T) - @assert Base.isconcretetype(T) - nf = fieldcount(T) - if nf == 0 - return nothing - end - push!(seen, y) - for i = 1:nf - if isdefined(y, i) && all(x -> isdefined(x, i), xs) - yi = getfield(y, i) - xis = ntuple(j -> getfield(xs[j], i), N) - SBT = Core.Typeof(yi) - activitystate = active_reg_inner(SBT, (), nothing, Val(false)) - if activitystate == AnyState - continue - elseif activitystate == DupState - recursive_map!(f, yi, seen, xis, isleaftype) - else - yi = recursive_map_immutable!(f, yi, seen, xis, isleaftype) - if Base.isconst(T, i) - ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), y, i - 1, yi) - else - setfield!(y, i, yi) - end - end - end - end +@inline function EnzymeCore.make_zero!(prev::Array{T,N}) where {T<:_RealOrComplexFloat,N} + fill!(prev, zero(T)) return nothing end