From 72bda996ec6e1cad6eec04b9d59c2f859f016696 Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Wed, 30 Oct 2024 15:07:49 -0700 Subject: [PATCH 1/8] Add recursive_map and base make_zero(!) on it --- ext/EnzymeStaticArraysExt.jl | 14 +- lib/EnzymeCore/src/EnzymeCore.jl | 82 +++- src/compiler.jl | 3 +- src/make_zero.jl | 538 ------------------------ src/recursive_map.jl | 501 ++++++++++++++++++++++ test/abi.jl | 13 - test/recursive_map.jl | 690 +++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 8 files changed, 1277 insertions(+), 565 deletions(-) delete mode 100644 src/make_zero.jl create mode 100644 src/recursive_map.jl create mode 100644 test/recursive_map.jl diff --git a/ext/EnzymeStaticArraysExt.jl b/ext/EnzymeStaticArraysExt.jl index c2639a4c99..a2d65a4653 100644 --- a/ext/EnzymeStaticArraysExt.jl +++ b/ext/EnzymeStaticArraysExt.jl @@ -32,11 +32,15 @@ end end end -@inline function Enzyme.EnzymeCore.make_zero(x::FT)::FT where {FT<:SArray} - return Base.zero(x) -end -@inline function Enzyme.EnzymeCore.make_zero(x::FT)::FT where {FT<:MArray} - return Base.zero(x) +# SArrays and MArrays don't need special treatment for `make_zero(!)` to work or be correct, +# but in case their dedicated `zero` and `fill!` methods are more efficient than +# `make_zero(!)`s generic recursion, we opt into treating them as leaves when they have +# isbits eltypes (non-isbits eltypes excluded as the dedicated `zero` and `fill!` methods +# don't support those). +@inline function Enzyme.EnzymeCore.isvectortype( + ::Type{<:Union{SArray{S,T},MArray{S,T}}} +) where {S,T} + return isbitstype(T) && Enzyme.Compiler.RecursiveMap.isscalartype(T) end end diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index c751aaac38..298a45d5f0 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -501,28 +501,96 @@ function autodiff_thunk end function autodiff_deferred_thunk end """ + make_zero(prev::T, ::Val{copy_if_inactive}=Val(false))::T make_zero(::Type{T}, seen::IdDict, prev::T, ::Val{copy_if_inactive}=Val(false))::T Recursively make a zero'd copy of the value `prev` of type `T`. The argument `copy_if_inactive` specifies what to do if the type `T` is guaranteed to be inactive, use the primal (the default) or still copy the value. + +Extending this method for custom types is rarely needed. For new plain array types like GPU +arrays, extending [`isvectortype`](@ref) is sufficient as long as the array type implements +`Base.zero`. """ function make_zero end """ - make_zero!(val::T, seen::IdSet{Any}=IdSet())::Nothing + make_zero!(val::T, seen::IdDict=IdDict())::Nothing + +Recursively set a variables differentiable fields to zero. Only applicable for types `T` +that are mutable or hold all differentiable values in mutable containers (e.g., +`Tuple{Vector{Float64}}`). -Recursively set a variables differentiable fields to zero. Only applicable for mutable types `T`. +Extending this method for custom types is rarely needed. For new plain mutable array types +like GPU arrays, extending [`isvectortype`](@ref) is sufficient as long as the array type +implements `Base.zero` and `Base.fill!`. """ function make_zero! end """ - make_zero(prev::T) + isvectortype(::Type{T})::Bool + +Trait defining types whose values should be considered leaf nodes when [`make_zero`](@ref) +and [`make_zero!`](@ref) recurse through an object. + +By default, `isvectortype(T) == true` for `T` such that `isscalartype(T) == true` or +`T <: Union{Array{U},GenericMemory{_,U}}` where `isscalartype(U) == true`. + +A new plain array type, for example a GPU array, may extend this as follows: + +```julia +@inline EnzymeCore.isvectortype(::Type{<:GPUArray{U}}) where {U} = isscalartype(U) +``` + +Such a type should implement `Base.zero` and, if mutable, `Base.fill!`. (If this is not +feasible, an alternative is to add methods `EnzymeCore.make_zero(arr::T)::T` and, if +mutable, `EnzymeCore.make_zero!(arr::T)::Nothing`; such methods will also be picked up by +recursive calls.) -Helper function to recursively make zero. +Such extensions are mostly relevant for the lowest-level of abstraction of memory at which +vector space operations like addition and scalar multiplication are supported, the +prototypical case being `Array`. Regular Julia structs with vector space-like semantics +should normally not extend `isvectorspace`; `make_zero(!)` will recurse into them and act +directly on their backing arrays, just like how Enzyme treats them when differentiating. For +example, structured matrix wrappers and sparse array types that are backed by `Array` should +not extend `isvectortype`. + +See also [`isscalartype`](@ref). """ -@inline function make_zero(prev::T, ::Val{copy_if_inactive}=Val(false)) where {T, copy_if_inactive} - make_zero(Core.Typeof(prev), IdDict(), prev, Val(copy_if_inactive)) -end +function isvectortype end + +""" + isscalartype(::Type{T})::Bool + +Trait defining a subset of [`isvectortype`](@ref) types that should not be considered +composite, such that even if the type is mutable, [`make_zero!`](@ref) will not try to zero +values of the type in-place. For example, `BigFloat` is a mutable type but does not support +in-place mutation through any Julia API, and `isscalartype(BigFloat) == true` ensures that +`make_zero!` will not try to mutate `BigFloat` values.[^BigFloat] + +By default, `isscalartype(T) == true` for `T <: AbstractFloat` and +`T <: Complex{<:AbstractFloat}`. + +A hypothetical new real number type with Enzyme support should in most cases simply subtype +`AbstractFloat` and inherit the `isscalartype` trait that way. If this is not appropriate, +the function can be extended as follows: + +```julia +@inline EnzymeCore.isscalartype(::Type{<:NewReal}) = true +@inline EnzymeCore.isscalartype(::Type{<:Complex{<:NewReal}}) = true +``` + +In either case, the type should implement `Base.zero`. (If this is not feasible, an +alternative is to add a method `EnzymeCore.make_zero(x::T)::T`; such a method will also be +picked up by recursive calls.) + +See also [`isvectortype`](@ref). + +[^BigFloat]: Enzyme does not support differentiating `BigFloat` as of this writing; it is +mentioned here only to illustrate that it would be inappropriate to use traits like +`ismutable` or `isbitstype` to choose between in-place and out-of-place zeroing, +demonstrating the need for a dedicated `isscalartype` trait. +""" +function isscalartype end function tape_type end diff --git a/src/compiler.jl b/src/compiler.jl index 22f0c21b13..a246322c39 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1132,8 +1132,6 @@ struct Tape{TapeTy,ShadowTy,ResT} shadow_return::ShadowTy end -include("make_zero.jl") - function nested_codegen!(mode::API.CDerivativeMode, mod::LLVM.Module, f, tt, world) funcspec = my_methodinstance(typeof(f), tt, world) nested_codegen!(mode, mod, funcspec, world) @@ -7610,6 +7608,7 @@ end end # Recursively return x + f(y), where y is active, otherwise x +include("recursive_map.jl") @inline function recursive_add( x::T, diff --git a/src/make_zero.jl b/src/make_zero.jl deleted file mode 100644 index f2fd055c61..0000000000 --- a/src/make_zero.jl +++ /dev/null @@ -1,538 +0,0 @@ - -@inline function EnzymeCore.make_zero(x::FT)::FT where {FT<:AbstractFloat} - return Base.zero(x) -end -@inline function EnzymeCore.make_zero(x::Complex{FT})::Complex{FT} where {FT<:AbstractFloat} - return Base.zero(x) -end -@inline function EnzymeCore.make_zero( - x::Array{FT,N}, -)::Array{FT,N} where {FT<:AbstractFloat,N} - return Base.zero(x) -end -@inline function EnzymeCore.make_zero( - x::Array{Complex{FT},N}, -)::Array{Complex{FT},N} where {FT<:AbstractFloat,N} - return Base.zero(x) -end - - -@static if VERSION < v"1.11-" -else -@inline function EnzymeCore.make_zero( - x::GenericMemory{kind, FT}, -)::GenericMemory{kind, FT} where {FT<:AbstractFloat,kind} - return Base.zero(x) -end -@inline function EnzymeCore.make_zero( - x::GenericMemory{kind, Complex{FT}}, -)::GenericMemory{kind, Complex{FT}} where {FT<:AbstractFloat,kind} - return Base.zero(x) -end -end - - -@inline function EnzymeCore.make_zero( - ::Type{Array{FT,N}}, - seen::IdDict, - prev::Array{FT,N}, - ::Val{copy_if_inactive} = Val(false), -)::Array{FT,N} where {copy_if_inactive,FT<:AbstractFloat,N} - if haskey(seen, prev) - return seen[prev] - end - newa = Base.zero(prev) - seen[prev] = newa - return newa -end -@inline function EnzymeCore.make_zero( - ::Type{Array{Complex{FT},N}}, - seen::IdDict, - prev::Array{Complex{FT},N}, - ::Val{copy_if_inactive} = Val(false), -)::Array{Complex{FT},N} where {copy_if_inactive,FT<:AbstractFloat,N} - if haskey(seen, prev) - return seen[prev] - end - newa = Base.zero(prev) - seen[prev] = newa - return newa -end - -@static if VERSION < v"1.11-" -else -@inline function EnzymeCore.make_zero( - ::Type{GenericMemory{kind, FT}}, - seen::IdDict, - prev::GenericMemory{kind, FT}, - ::Val{copy_if_inactive} = Val(false), -)::GenericMemory{kind, FT} where {copy_if_inactive,FT<:AbstractFloat,kind} - if haskey(seen, prev) - return seen[prev] - end - newa = Base.zero(prev) - seen[prev] = newa - return newa -end -@inline function EnzymeCore.make_zero( - ::Type{GenericMemory{kind, Complex{FT}}}, - seen::IdDict, - prev::GenericMemory{kind, Complex{FT}}, - ::Val{copy_if_inactive} = Val(false), -)::GenericMemory{kind, Complex{FT}} where {copy_if_inactive,FT<:AbstractFloat,kind} - if haskey(seen, prev) - return seen[prev] - end - newa = Base.zero(prev) - seen[prev] = newa - return newa -end -end - -@inline function EnzymeCore.make_zero( - ::Type{RT}, - seen::IdDict, - prev::RT, - ::Val{copy_if_inactive} = Val(false), -)::RT where {copy_if_inactive,RT<:AbstractFloat} - return RT(0) -end - -@inline function EnzymeCore.make_zero( - ::Type{Complex{RT}}, - seen::IdDict, - prev::Complex{RT}, - ::Val{copy_if_inactive} = Val(false), -)::Complex{RT} where {copy_if_inactive,RT<:AbstractFloat} - return RT(0) -end - -@inline function EnzymeCore.make_zero( - ::Type{RT}, - seen::IdDict, - prev::RT, - ::Val{copy_if_inactive} = Val(false), -)::RT where {copy_if_inactive,RT<:Array} - if haskey(seen, prev) - return seen[prev] - end - if guaranteed_const_nongen(RT, nothing) - return copy_if_inactive ? Base.deepcopy_internal(prev, seen) : prev - end - newa = RT(undef, size(prev)) - seen[prev] = newa - for I in eachindex(prev) - if isassigned(prev, I) - pv = prev[I] - innerty = Core.Typeof(pv) - @inbounds newa[I] = - EnzymeCore.make_zero(innerty, seen, pv, Val(copy_if_inactive)) - end - end - return newa -end - -@static if VERSION < v"1.11-" -else -@inline function EnzymeCore.make_zero( - ::Type{RT}, - seen::IdDict, - prev::RT, - ::Val{copy_if_inactive} = Val(false), -)::RT where {copy_if_inactive,RT<:GenericMemory} - if haskey(seen, prev) - return seen[prev] - end - if guaranteed_const_nongen(RT, nothing) - return copy_if_inactive ? Base.deepcopy_internal(prev, seen) : prev - end - newa = RT(undef, size(prev)) - seen[prev] = newa - for I in eachindex(prev) - if isassigned(prev, I) - pv = prev[I] - innerty = Core.Typeof(pv) - @inbounds newa[I] = - EnzymeCore.make_zero(innerty, seen, pv, Val(copy_if_inactive)) - end - end - return newa -end -end - -@inline function EnzymeCore.make_zero( - ::Type{RT}, - seen::IdDict, - prev::RT, - ::Val{copy_if_inactive} = Val(false), -)::RT where {copy_if_inactive,RT<:Tuple} - return ntuple(length(prev)) do i - Base.@_inline_meta - EnzymeCore.make_zero(RT.parameters[i], seen, prev[i], Val(copy_if_inactive)) - end -end - -@inline function EnzymeCore.make_zero( - ::Type{NamedTuple{A,RT}}, - seen::IdDict, - prev::NamedTuple{A,RT}, - ::Val{copy_if_inactive} = Val(false), -)::NamedTuple{A,RT} where {copy_if_inactive,A,RT} - return NamedTuple{A,RT}(EnzymeCore.make_zero(RT, seen, RT(prev), Val(copy_if_inactive))) -end - -@inline function EnzymeCore.make_zero( - ::Type{Core.Box}, - seen::IdDict, - prev::Core.Box, - ::Val{copy_if_inactive} = Val(false), -) where {copy_if_inactive} - if haskey(seen, prev) - return seen[prev] - end - prev2 = prev.contents - res = Core.Box() - seen[prev] = res - res.contents = Base.Ref( - EnzymeCore.make_zero(Core.Typeof(prev2), seen, prev2, Val(copy_if_inactive)), - ) - return res -end - -@inline function EnzymeCore.make_zero( - ::Type{RT}, - seen::IdDict, - prev::RT, - ::Val{copy_if_inactive} = Val(false), -)::RT where {copy_if_inactive,RT} - if guaranteed_const_nongen(RT, nothing) - return copy_if_inactive ? Base.deepcopy_internal(prev, seen) : prev - end - if haskey(seen, prev) - return seen[prev] - end - @assert !Base.isabstracttype(RT) - @assert Base.isconcretetype(RT) - nf = fieldcount(RT) - - if ismutable(prev) - y = ccall(:jl_new_struct_uninit, Any, (Any,), RT)::RT - seen[prev] = y - for i = 1:nf - if isdefined(prev, i) - xi = getfield(prev, i) - T = Core.Typeof(xi) - xi = EnzymeCore.make_zero(T, seen, xi, Val(copy_if_inactive)) - if Base.isconst(RT, i) - ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), y, i-1, xi) - else - setfield!(y, i, xi) - end - end - end - return y - end - - if nf == 0 - return prev - end - - flds = Vector{Any}(undef, nf) - for i = 1:nf - if isdefined(prev, i) - xi = getfield(prev, i) - xi = EnzymeCore.make_zero(Core.Typeof(xi), seen, xi, Val(copy_if_inactive)) - flds[i] = xi - else - nf = i - 1 # rest of tail must be undefined values - break - end - end - y = ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), RT, flds, nf) - seen[prev] = y - return y -end - -function make_zero_immutable!(prev::T, seen::S)::T where {T<:AbstractFloat,S} - zero(T) -end - -function make_zero_immutable!( - prev::Complex{T}, - seen::S, -)::Complex{T} where {T<:AbstractFloat,S} - zero(T) -end - -function make_zero_immutable!(prev::T, seen::S)::T where {T<:Tuple,S} - ntuple(Val(length(T.parameters))) do i - Base.@_inline_meta - make_zero_immutable!(prev[i], seen) - end -end - -function make_zero_immutable!(prev::NamedTuple{a,b}, seen::S)::NamedTuple{a,b} where {a,b,S} - NamedTuple{a,b}(ntuple(Val(length(T.parameters))) do i - Base.@_inline_meta - make_zero_immutable!(prev[a[i]], seen) - end) -end - - -function make_zero_immutable!(prev::T, seen::S)::T where {T,S} - if guaranteed_const_nongen(T, nothing) - return prev - end - @assert !ismutable(prev) - - RT = Core.Typeof(prev) - @assert !Base.isabstracttype(RT) - @assert Base.isconcretetype(RT) - nf = fieldcount(RT) - - flds = Vector{Any}(undef, nf) - for i = 1:nf - if isdefined(prev, i) - xi = getfield(prev, i) - ST = Core.Typeof(xi) - flds[i] = if active_reg_inner(ST, (), nothing, Val(true)) == ActiveState #=justActive=# - make_zero_immutable!(xi, seen) - else - EnzymeCore.make_zero!(xi, seen) - xi - end - else - nf = i - 1 # rest of tail must be undefined values - break - end - end - ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), RT, flds, nf)::T -end - -@inline function EnzymeCore.make_zero!( - prev::Base.RefValue{T}, - seen::ST, -)::Nothing where {T<:AbstractFloat,ST} - T[] = zero(T) - nothing -end - -@inline function EnzymeCore.make_zero!( - prev::Base.RefValue{Complex{T}}, - seen::ST, -)::Nothing where {T<:AbstractFloat,ST} - T[] = zero(Complex{T}) - nothing -end - -@inline function EnzymeCore.make_zero!( - prev::Array{T,N}, - seen::ST, -)::Nothing where {T<:AbstractFloat,N,ST} - fill!(prev, zero(T)) - nothing -end - -@inline function EnzymeCore.make_zero!( - prev::Array{Complex{T},N}, - seen::ST, -)::Nothing where {T<:AbstractFloat,N,ST} - fill!(prev, zero(Complex{T})) - nothing -end - -@static if VERSION < v"1.11-" -else -@inline function EnzymeCore.make_zero!( - prev::GenericMemory{kind, T}, - seen::ST, -)::Nothing where {T<:AbstractFloat,kind,ST} - fill!(prev, zero(T)) - nothing -end - -@inline function EnzymeCore.make_zero!( - prev::GenericMemory{kind, Complex{T}}, - seen::ST, -)::Nothing where {T<:AbstractFloat,kind,ST} - fill!(prev, zero(Complex{T})) - nothing -end -end - -@inline function EnzymeCore.make_zero!( - prev::Base.RefValue{T}, -)::Nothing where {T<:AbstractFloat} - EnzymeCore.make_zero!(prev, nothing) - nothing -end - -@inline function EnzymeCore.make_zero!( - prev::Base.RefValue{Complex{T}}, -)::Nothing where {T<:AbstractFloat} - EnzymeCore.make_zero!(prev, nothing) - nothing -end - -@inline function EnzymeCore.make_zero!(prev::Array{T,N})::Nothing where {T<:AbstractFloat,N} - EnzymeCore.make_zero!(prev, nothing) - nothing -end - -@inline function EnzymeCore.make_zero!( - prev::Array{Complex{T},N}, -)::Nothing where {T<:AbstractFloat,N} - EnzymeCore.make_zero!(prev, nothing) - nothing -end - -@inline function EnzymeCore.make_zero!(prev::Array{T,N}, seen::ST)::Nothing where {T,N,ST} - if guaranteed_const_nongen(T, nothing) - return - end - if in(seen, prev) - return - end - push!(seen, prev) - - for I in eachindex(prev) - if isassigned(prev, I) - pv = prev[I] - SBT = Core.Typeof(pv) - if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# - @inbounds prev[I] = make_zero_immutable!(pv, seen) - nothing - else - EnzymeCore.make_zero!(pv, seen) - nothing - end - end - end - nothing -end - -@static if VERSION < v"1.11-" -else -@inline function EnzymeCore.make_zero!(prev::GenericMemory{kind, T})::Nothing where {T<:AbstractFloat,kind} - EnzymeCore.make_zero!(prev, nothing) - nothing -end - -@inline function EnzymeCore.make_zero!( - prev::GenericMemory{kind, Complex{T}}, -)::Nothing where {T<:AbstractFloat, kind} - EnzymeCore.make_zero!(prev, nothing) - nothing -end - -@inline function EnzymeCore.make_zero!(prev::GenericMemory{kind, T}, seen::ST)::Nothing where {T,kind,ST} - if guaranteed_const_nongen(T, nothing) - return - end - if in(seen, prev) - return - end - push!(seen, prev) - - for I in eachindex(prev) - if isassigned(prev, I) - pv = prev[I] - SBT = Core.Typeof(pv) - if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# - @inbounds prev[I] = make_zero_immutable!(pv, seen) - nothing - else - EnzymeCore.make_zero!(pv, seen) - nothing - end - end - end - nothing -end -end - - -@inline function EnzymeCore.make_zero!( - prev::Base.RefValue{T}, - seen::ST, -)::Nothing where {T,ST} - if guaranteed_const_nongen(T, nothing) - return - end - if in(seen, prev) - return - end - push!(seen, prev) - - pv = prev[] - SBT = Core.Typeof(pv) - if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# - prev[] = make_zero_immutable!(pv, seen) - nothing - else - EnzymeCore.make_zero!(pv, seen) - nothing - end - nothing -end - -@inline function EnzymeCore.make_zero!(prev::Core.Box, seen::ST)::Nothing where {ST} - pv = prev.contents - T = Core.Typeof(pv) - if guaranteed_const_nongen(T, nothing) - return - end - if in(seen, prev) - return - end - push!(seen, prev) - SBT = Core.Typeof(pv) - if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# - prev.contents = EnzymeCore.make_zero_immutable!(pv, seen) - nothing - else - EnzymeCore.make_zero!(pv, seen) - nothing - end - nothing -end - -@inline function EnzymeCore.make_zero!( - prev::T, - seen::S = Base.IdSet{Any}(), -)::Nothing where {T,S} - if guaranteed_const_nongen(T, nothing) - return - end - if in(prev, seen) - return - end - @assert !Base.isabstracttype(T) - @assert Base.isconcretetype(T) - nf = fieldcount(T) - - - if nf == 0 - return - end - - push!(seen, prev) - - for i = 1:nf - if isdefined(prev, i) - xi = getfield(prev, i) - SBT = Core.Typeof(xi) - if guaranteed_const_nongen(SBT, nothing) - continue - end - if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=# - setfield!(prev, i, make_zero_immutable!(xi, seen)) - nothing - else - EnzymeCore.make_zero!(xi, seen) - nothing - end - end - end - return -end diff --git a/src/recursive_map.jl b/src/recursive_map.jl new file mode 100644 index 0000000000..aa6ad2ce50 --- /dev/null +++ b/src/recursive_map.jl @@ -0,0 +1,501 @@ +module RecursiveMap + +using EnzymeCore: EnzymeCore, isvectortype, isscalartype +using ..Compiler: ActiveState, active_reg_inner, guaranteed_const_nongen, splatnew + +""" + y = recursive_map( + [seen::IdDict,] + f, + xs::NTuple{N,T}, + ::Val{copy_if_inactive}=Val(false), + )::T where {T,N,copy_if_inactive} + newy = recursive_map( + [seen::IdDict,] + f, + (; y, xs)::@NamedTuple{y::T,xs::NTuple{N,T}}, + ::Val{copy_if_inactive}=Val(false), + )::T where {T,N,copy_if_inactive} + +!!! warning + Internal function, documented for developer convenience but not covered by semver API + stability guarantees + +Recurse through `N` objects `xs = (x1::T, x2::T, ..., xN::T)` of the same type, mapping the +function `f` over every differentiable value encountered and building a new object `y::T` +from the resulting values `yi = f(x1i, ..., xNi)`. + +The trait `EnzymeCore.isvectortype`(@ref) determines which values are considered +differentiable leaf nodes at which recursion terminates and `f` is invoked. See the +docstring for that and the related [`EnzymeCore.isscalartype`](@ref) for more information. + +An existing object `y::T` can be passed by replacing the tuple `xs` with a NamedTuple +`(; y, xs)`, in which case `y` is updated "partially-in-place": any parts of `y` that are +mutable or non-differentiable are reused in the returned object `newy`, while immutable +differentiable parts are handled out-of-place as if `y` were not passed (this can be seen as +a recursive generalization of the BangBang.jl idiom). If `y` itself is mutable, it is +modified in-place and returned, such that `newy === y`. + +The recursion and mapping operates on the structure of `T` as defined by struct fields and +plain array elements, not on the values provided through an iteration or array interface. +For example, given a structured matrix wrapper or sparse array type, this function recurses +into the struct type and the plain arrays held within, rather than operating on the array +that the type notionally represents. + +# Arguments + +* `seen::IdDict` (optional): Dictionary for tracking object identity as needed to construct + `y` such that its internal graph of object references is identical to that of the `xs`, + including cycles (i.e., recursive substructures) and multiple paths to the same objects. + If not provided, an `IdDict` will be allocated internally if required. + +* `f`: Function mapping leaf nodes within the `xs` to the corresponding leaf node in `y`, + that is, `yi = f(x1i::U, ..., xNi::U)::U`. The function `f` must be applicable to the type + of every leaf node, and must return a value of the same type as its arguments. + + When an existing object `y` is passed and contains leaf nodes of a non-isbits non-scalar + type `U`, `f` should also have a partially-in-place method + `newyi === f(yi::U, x1i::U, ..., xNi::U)::U` that modifies and reuses any mutable parts of + `yi`; in particular, if `U` is a mutable type, this method should return `newyi === yi`. + If a non-isbits type `U` should always be handled using the out-of-place signature, extend + [`EnzymeCore.isscalartype`](@ref) such that `isscalartype(U) == true`. + + See [`EnzymeCore.isvectortype`](@ref) and [`EnzymeCore.isscalartype`)(@ref) for more + details about leaf types and scalar types. + +* `xs::NTuple{N,T}` or `(; y, xs)::@NamedTuple{y::T,xs::NTuple{N,T}}`: Tuple of `N` objects + of the same type `T`, or NamedTuple combining this Tuple with an existing object `y::T` + that can be partially or fully reused in the returned object. + + The first object `x1 = first(xs)` is the reference for graph structure and + non-differentiable values when constructing the returned object. In particular: + * When `y` is not passed, the returned object takes any non-differentiable parts from + `x1`. (When `y` is passed, its non-differentiable parts are reused in the returned + object, unless they are not initialized, in which case they are taken from `x1`.) + * The graph of object references in `x1` is the one which is reproduced in the returned + object. For each instance of multiple paths and cycles within `x1`, the same structure + must be present in the other objects `x2, ..., xN`, otherwise the corresponding values + in `y` would not be uniquely defined. However, `x2, ..., xN` may contain multiple paths + or cycles that are not present in `x1`; these do not affect the structure of `y`. + * If any values within `x1` are not initialized (that is, struct fields are undefined or + array elements are unassigned), they are left uninitialized in the returned object. If + any such values are mutable and `y` is passed, the corresponding value in `y` must not + already be initialized, since initialized values cannot be nulled. Conversely, for every + value in `x1` that is initialized, the corresponding values in `x2, ..., xN` must also + be initialized, such that the corresponding value of `y` can be computed (however, + values in `x2, ..., xN` can be initialized while the corresponding value in `x1` is not; + such values are ignored.) + +* `Val(copy_if_inactive)::Val{::Bool}` (optional): When a non-differentiable part of `x1` is + included in the returned object, either because an object `y` is not passed or this part + of `y` is not initialized, `copy_if_inactive` determines how it is included: if + `copy_if_inactive == false`, it is shared as `yi = x1i`; if `copy_if_inactive == true`, it + is deep-copied, more-or-less as `yi = deepcopy(x1i)` (the difference is that when `x1` has + several non-differentiable parts, object identity is tracked across the multiple + deep-copies such that the object reference graph is reproduced also within the inactive + parts.) +""" +function recursive_map end + +## type aliases, for generic handling of out-of-place and partially-in-place recursive_map +const XTup{N,T} = NTuple{N,T} +const YXTup{N,T} = @NamedTuple{y::T,xs::XTup{N,T}} +const XTupOrYXTup{N,T} = Union{XTup{N,T},YXTup{N,T}} + +@inline xtup(xs::XTup) = xs +@inline xtup((; xs)::YXTup) = xs + +@static if VERSION < v"1.11-" + const Arraylike{U} = Array{U} +else + const Arraylike{U} = Union{Array{U},GenericMemory{kind,U} where {kind}} +end + +## main entry point +@inline function recursive_map( + f::F, yxs::XTupOrYXTup{N,T}, copy_if_inactive::Val=Val(false) +) where {F,N,T} + # determine whether or not an IdDict is needed for this T + if isbitstype(T) || ( + guaranteed_const_nongen(T, nothing) && !needscopy(yxs, copy_if_inactive) + ) + y = recursive_map(nothing, f, yxs, copy_if_inactive) + else + y = recursive_map(IdDict(), f, yxs, copy_if_inactive) + end + return y::T +end + +## recursive methods +@inline function recursive_map( + seen::Union{Nothing,IdDict}, + f::F, + yxs::XTupOrYXTup{N,T}, + copy_if_inactive::Val=Val(false), +) where {F,N,T} + # determine whether to continue recursion, copy/share, or retrieve from cache + xs = xtup(yxs) + if guaranteed_const_nongen(T, nothing) + y = maybecopy(seen, yxs, copy_if_inactive) + elseif isbitstype(T) # no need to track identity or pass y in this branch + y = recursive_map_inner(nothing, f, xs, copy_if_inactive) + elseif hascache(seen, xs) + y = getcached(seen, xs) + else + y = recursive_map_inner(seen, f, yxs, copy_if_inactive) + end + return y::T +end + +@inline function recursive_map_inner( + seen, f::F, yxs::XTupOrYXTup{N,T}, args::Vararg{Any,M} +) where {F,N,T,M} + # forward to appropriate handler for leaf vs. mutable vs. immutable type + @assert !isabstracttype(T) + @assert isconcretetype(T) + if isvectortype(T) + y = recursive_map_leaf(seen, f, yxs) + elseif ismutabletype(T) + y = recursive_map_mutable(seen, f, yxs, args...) + else + y = recursive_map_immutable(seen, f, yxs, args...) + end + return y::T +end + +@inline function recursive_map_mutable( + seen, f::F, xs::XTup{N,T}, args::Vararg{Any,M} +) where {F,N,T,M} + # out-of-place mutable handler: construct y + @assert ismutabletype(T) + x1, xtail... = xs + nf = fieldcount(T) + if (!(T <: Arraylike)) && all(isbitstype, fieldtypes(T)) && all(i -> isdefined(x1, i), 1:nf) + # fast path when all fields are defined and all fieldtypes are bitstypes (the latter + # preventing circular references, which are incompatible with the fast path) + check_initialized(xtail, 1:nf) + fieldtup = ntuple(Val(nf)) do i + @inline + recursive_map_index(i, seen, f, xs, args...) + end + y = splatnew(T, fieldtup) + cache!(seen, y, xs) + else # handle both structs, arrays, and memory through generic helpers + y = _similar(x1) + cache!(seen, y, xs) + @inbounds for i in _eachindex(y, xs...) + if isinitialized(x1, i) + check_initialized(xtail, i) + yi = recursive_map_index(i, seen, f, xs, args...) + setvalue(y, i, yi) + end + end + end + return y::T +end + +@inline function recursive_map_mutable( + seen, f!!::F, (; y, xs)::YXTup{N,T}, args::Vararg{Any,M} +) where {F,N,T,M} + # in-place mutable handler: set/update values in y + @assert ismutabletype(T) + cache!(seen, y, xs) + x1, xtail... = xs + @inbounds for i in _eachindex(y, xs...) + # handle both structs, arrays, and memory through generic helpers + if isinitialized(x1, i) + check_initialized(xtail, i) + newyi = recursive_map_index(i, seen, f!!, (; y, xs), args...) + setvalue(y, i, newyi) + else + check_initialized((y,), i, false) + end + end + return y::T +end + +@inline function recursive_map_immutable( + seen, f::F, yxs::XTupOrYXTup{N,T}, args::Vararg{Any,M} +) where {F,N,T,M} + # immutable handler: construct y/newy + @assert !ismutabletype(T) + x1, xtail... = xtup(yxs) + nf = fieldcount(T) + if nf == 0 # nothing to do; assume inactive + y = maybecopy(seen, yxs, args...) + elseif isdefined(x1, nf) # fast path when all fields are defined + check_initialized(xtail, nf) + fieldtup = ntuple(Val(nf)) do i + @inline + recursive_map_index(i, seen, f, yxs, args...) + end + y = splatnew(T, fieldtup) + else + flds = Vector{Any}(undef, nf) + @inbounds for i in 1:nf + if isdefined(x1, i) + check_initialized(xtail, i) + flds[i] = recursive_map_index(i, seen, f, yxs, args...) + else + nf = i - 1 # rest of tail must be undefined values + break + end + end + y = ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), T, flds, nf) + end + return y::T +end + +Base.@propagate_inbounds function recursive_map_index( + i, seen, f::F, xs::XTup, args::Vararg{Any,M} +) where {F,M} + # out-of-place recursive handler: extract value i from each of the xs; call + # recursive_map to obtain yi + xis = getvalues(xs, i) + yi = recursive_map(seen, f, xis, args...) + return yi::Core.Typeof(first(xis)) +end + +Base.@propagate_inbounds function recursive_map_index( + i, seen, f!!::F, (; y, xs)::YXTup, args::Vararg{Any,M} +) where {F,M} + # partially-in-place recursive handler: extract value i from each of the xs and, if + # initialized, from y; call recursive_map to obtain newyi + xis = getvalues(xs, i) + if isinitialized(y, i) + yi = getvalue(y, i) + newyi = recursive_map(seen, f!!, (; y=yi, xs=xis), args...) + else + newyi = recursive_map(seen, f!!, xis, args...) + end + return newyi::Core.Typeof(first(xis)) +end + +## leaf handlers +function recursive_map_leaf(seen, f::F, xs::XTup{N,T}) where {F,N,T} + # out-of-place + y = f(xs...) + if !isbitstype(T) + cache!(seen, y, xs) + end + return y::T +end + +function recursive_map_leaf(seen, f!!::F, (; y, xs)::YXTup{N,T}) where {F,N,T} + # partially-in-place + if isbitstype(T) || isscalartype(T) + newy = f!!(xs...) + else # !isbitstype(T) + newy = f!!(y, xs...) + if ismutabletype(T) + @assert newy === y + end + end + if !isbitstype(T) + cache!(seen, newy, xs) + end + return newy::T +end + +## helpers +# vector/scalar trait implementation +@inline EnzymeCore.isvectortype(::Type{T}) where {T} = isscalartype(T) +@inline EnzymeCore.isvectortype(::Type{<:Arraylike{U}}) where {U} = isscalartype(U) + +@inline EnzymeCore.isscalartype(::Type{<:AbstractFloat}) = true +@inline EnzymeCore.isscalartype(::Type{<:Complex{<:AbstractFloat}}) = true +@inline EnzymeCore.isscalartype(::Type) = false + +# generic handling of mutable structs, arrays, and memory +@inline _similar(::T) where {T} = ccall(:jl_new_struct_uninit, Any, (Any,), T)::T +@inline _similar(x::T) where {T<:Arraylike} = similar(x)::T +@inline _eachindex(xs::T...) where {T} = 1:fieldcount(T) +@inline _eachindex(xs::Arraylike...) = eachindex(xs...) +@inline isinitialized(x, i) = isdefined(x, i) +Base.@propagate_inbounds isinitialized(x::Arraylike, i) = isassigned(x, i) +@inline getvalue(x, i) = getfield(x, i) +Base.@propagate_inbounds getvalue(x::Arraylike, i) = x[i] +@inline setvalue(x, i, v) = setfield_force!(x, i, v) +Base.@propagate_inbounds setvalue(x::Arraylike, i, v) = (x[i] = v; nothing) + +Base.@propagate_inbounds function getvalues(xs::XTup{N}, i) where {N} + return ntuple(Val(N)) do j + Base.@_propagate_inbounds_meta + getvalue(xs[j], i) + end +end + +@inline function setfield_force!(y::T, i, newyi) where {T} + if Base.isconst(T, i) + ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), y, i - 1, newyi) + else + setfield!(y, i, newyi) + end + return nothing +end + +# generic inactive handler: sharing/copying (out-of-place) or leaving unchanged (in-place) +@inline maybecopy(_, (; y)::YXTup{N,T}, _) where {N,T} = y::T +@inline function maybecopy(seen, xs::XTup{N,T}, copy) where {N,T} + if needscopy(xs, copy) + y = Base.deepcopy_internal(first(xs), seen) + else + y = first(xs) + end + return y::T +end + +@inline needscopy(::YXTup, _) = false +@inline needscopy(::XTup{N,T}, ::Val{copy}) where {N,T,copy} = (copy && !isbitstype(T)) + +# validating cache handlers +@inline function cache!(seen::IdDict, y::T, xs::XTup{N,T}) where {N,T} + x1, xtail... = xs + seen[x1] = (y, xtail...) + return nothing +end + +@inline hascache(seen, xs::XTup) = haskey(seen, first(xs)) + +@inline function getcached(seen::IdDict, xs::XTup{N,T}) where {N,T} + x1, xtail... = xs + y, xtail_... = seen[x1]::XTup{N,T} + check_identical(xtail, xtail_) # check compatible topology + return y::T +end + +## in-place wrapper +""" + recursive_map!( + [seen::IdDict,] + f!!, + y::T, + xs::NTuple{N,T}, + isleaftype=Returns(false), + ::Val{copy_if_inactive}=Val(false), + )::Nothing where {T,N,copy_if_inactive} + +!!! warning + Internal function, documented for developer convenience but not covered by semver API + stability guarantees + +Recurse through `N` objects `xs = (x1::T, x2::T, ..., xN::T)` of the same type, mapping the +function `f!!` over every differentiable value encountered and updating new mutable object +`y::T` in-place with the resulting values. + +This is a wrapper that calls +`recursive_map([seen,] f!!, (; y, xs), isleaftype, Val(copy_if_inactive))`, but only accepts +types `T` that are mutable (or, trivially, entirely non-differentiable), and enforces a +fully in-place update of `y`. See [`recursive_map`](@ref) for details. +""" +function recursive_map!(f!!::F, y::T, xs::XTup{N,T}, args::Vararg{Any,M}) where {F,N,T,M} + check_notactive(T) + newy = recursive_map(f!!, (; y, xs), args...) + @assert newy === y + return nothing +end + +function recursive_map!( + seen::IdDict, f!!::F, y::T, xs::XTup{N,T}, args::Vararg{Any,M} +) where {F,N,T,M} + check_notactive(T) + newy = recursive_map(seen, f!!, (; y, xs), args...) + @assert newy === y + return nothing +end + +## argument checkers +Base.@propagate_inbounds function check_initialized(xs, indices, value=true) + for xj in xs + for i in indices + if isinitialized(xj, i) != value + throw_initialized() + end + end + end + return nothing +end + +@inline function check_identical(x1, x2) + if x1 !== x2 + throw_identical() + end + return nothing +end + +@inline function check_notactive(::Type{T}) where {T} + if active_reg_inner(T, (), nothing, Val(true)) == ActiveState # justActive + throw_notactive() + end + return nothing +end + +@noinline function throw_initialized() + msg = "recursive_map(!) called on objects whose undefined fields/unassigned elements " + msg *= "don't line up" + throw(ArgumentError(msg)) +end + +@noinline function throw_identical() + msg = "recursive_map(!) called on objects whose topology don't match" + throw(ArgumentError(msg)) +end + +@noinline function throw_notactive() + msg = "recursive_map! called on objects containing immutable differentiable elements" + throw(ArgumentError(msg)) +end + +### make_zero(!) +## entry points, with default handling of leaves +function EnzymeCore.make_zero(prev::T, args::Vararg{Any,M}) where {T,M} + if EnzymeCore.isvectortype(T) + if length(args) > 0 # pick up custom methods for custom vector types + new = EnzymeCore.make_zero(prev) + else # default implementation + # convert because zero may produce different eltype when eltype(T) is abstract + new = convert(T, zero(prev)) + end + else + new = recursive_map(make_zero_f!!, (prev,), args...) + end + return new::T +end + +function EnzymeCore.make_zero!(prev::T) where {T} + @assert !EnzymeCore.isscalartype(T) # sanity check + if EnzymeCore.isvectortype(T) # default implementation + fill!(prev, false) + else + recursive_map!(make_zero_f!!, prev, (prev,)) + end + return nothing +end + +## low-level interface, for bringing your own IdDict +function EnzymeCore.make_zero( + ::Type{T}, seen::IdDict, prev::T, args::Vararg{Any,M} +) where {T,M} + return recursive_map(seen, make_zero_f!!, (prev,), args...)::T +end + +function EnzymeCore.make_zero!(prev, seen::IdDict) + recursive_map!(seen, make_zero_f!!, prev, (prev,)) + return nothing +end + +## the mapped function: assert valid leaf type and call back into make_zero(!) +function make_zero_f!!(prev::T) where {T} + @assert EnzymeCore.isvectortype(T) # otherwise infinite loop + return EnzymeCore.make_zero(prev)::T +end + +function make_zero_f!!(pout::T, pin::T) where {T} + @assert !EnzymeCore.isscalartype(T) # not appropriate for in-place handler + @assert EnzymeCore.isvectortype(T) # otherwise infinite loop + @assert pout === pin + EnzymeCore.make_zero!(pout) + return pout::T +end + +end # module RecursiveMap diff --git a/test/abi.jl b/test/abi.jl index 1c62741ef1..b6898ac1ba 100644 --- a/test/abi.jl +++ b/test/abi.jl @@ -489,19 +489,6 @@ mulsin(x) = sin(x[1] * x[2]) @test Enzyme.autodiff(ForwardWithPrimal, () -> Enzyme.within_autodiff())[1] end -mutable struct ConstVal - x::Float64 - const y::Float64 -end - -@testset "Make Zero" begin - v = ConstVal(2.0, 3.0) - dv = make_zero(v) - @test dv isa ConstVal - @test dv.x ≈ 0.0 - @test dv.y ≈ 0.0 -end - @testset "Type inference" begin x = ones(10) @inferred autodiff(Enzyme.Reverse, abssum, Duplicated(x,x)) diff --git a/test/recursive_map.jl b/test/recursive_map.jl new file mode 100644 index 0000000000..3111263bd0 --- /dev/null +++ b/test/recursive_map.jl @@ -0,0 +1,690 @@ +module RecursiveMapTests + +using Enzyme +using StaticArrays +using Test + +# Universal getters/setters for built-in and custom containers/wrappers +getx(w::Base.RefValue) = w[] +getx(w::Core.Box) = w.contents +getx(w) = first(w) +gety(w) = last(w) + +setx!(w::Base.RefValue, x) = (w[] = x) +setx!(w::Core.Box, x) = (w.contents = x) +setx!(w, x) = (w[begin] = x) +sety!(w, y) = (w[end] = y) + +# non-isbits MArray doesn't support setindex!, so requires a little hack +function setx!(w::MArray{S,T}, x) where {S,T} + if isbitstype(T) + w[begin] = x + else + w.data = (x, Base.tail(w.data)...) + end + return x +end + +function sety!(w::MArray{S,T}, y) where {S,T} + if isbitstype(T) + w[end] = y + else + w.data = (Base.front(w.data)..., y) + end + return y +end + +struct Empty end + +mutable struct MutableEmpty end + +Base.:(==)(::MutableEmpty, ::MutableEmpty) = true + +struct Wrapper{T} + x::T +end + +Base.:(==)(a::Wrapper, b::Wrapper) = (a === b) || (a.x == b.x) +getx(a::Wrapper) = a.x + +mutable struct MutableWrapper{T} + x::T +end + +Base.:(==)(a::MutableWrapper, b::MutableWrapper) = (a === b) || (a.x == b.x) + +getx(a::MutableWrapper) = a.x +setx!(a::MutableWrapper, x) = (a.x = x) + +struct DualWrapper{Tx,Ty} + x::Tx + y::Ty +end + +DualWrapper{T}(x::T, y) where {T} = DualWrapper{T,typeof(y)}(x, y) + +function Base.:(==)(a::DualWrapper, b::DualWrapper) + return (a === b) || ((a.x == b.x) && (a.y == b.y)) +end + +getx(a::DualWrapper) = a.x +gety(a::DualWrapper) = a.y + +mutable struct MutableDualWrapper{Tx,Ty} + x::Tx + y::Ty +end + +MutableDualWrapper{T}(x::T, y) where {T} = MutableDualWrapper{T,typeof(y)}(x, y) + +function Base.:(==)(a::MutableDualWrapper, b::MutableDualWrapper) + return (a === b) || ((a.x == b.x) && (a.y == b.y)) +end + +getx(a::MutableDualWrapper) = a.x +gety(a::MutableDualWrapper) = a.y + +setx!(a::MutableDualWrapper, x) = (a.x = x) +sety!(a::MutableDualWrapper, y) = (a.y = y) + +struct Incomplete{T} + s::String + x::Float64 + w::T + z # not initialized + Incomplete(s, x, w) = new{typeof(w)}(s, x, w) +end + +function Base.:(==)(a::Incomplete, b::Incomplete) + (a === b) && return true + ((a.s == b.s) && (a.x == b.x) && (a.w == b.w)) || return false + if isdefined(a, :z) && isdefined(b, :z) + (a.z == b.z) || return false + elseif isdefined(a, :z) || isdefined(b, :z) + return false + end + return true +end + +mutable struct MutableIncomplete{T} + s::String + const x::Float64 + y::Float64 + z # not initialized + w::T + function MutableIncomplete(s, x, y, w) + ret = new{typeof(w)}(s, x, y) + ret.w = w + return ret + end +end + +function Base.:(==)(a::MutableIncomplete, b::MutableIncomplete) + (a === b) && return true + if (a.s != b.s) || (a.x != b.x) || (a.y != b.y) || (a.w != b.w) + return false + end + if isdefined(a, :z) && isdefined(b, :z) + (a.z == b.z) || return false + elseif isdefined(a, :z) || isdefined(b, :z) + return false + end + return true +end + +mutable struct CustomVector{T} <: AbstractVector{T} + data::Vector{T} +end + +Base.:(==)(a::CustomVector, b::CustomVector) = (a === b) || (a.data == b.data) + +function Enzyme.EnzymeCore.isvectortype(::Type{CustomVector{T}}) where {T} + return Enzyme.EnzymeCore.isscalartype(T) +end + +function Enzyme.EnzymeCore.make_zero(prev::CV) where {CV<:CustomVector{<:AbstractFloat}} + @info "make_zero(::CustomVector)" + return CustomVector(zero(prev.data))::CV +end + +function Enzyme.EnzymeCore.make_zero!(prev::CustomVector{<:AbstractFloat}) + @info "make_zero!(::CustomVector)" + fill!(prev.data, false) + return nothing +end + +const scalartypes = [Float32, ComplexF32, Float64, ComplexF64, BigFloat, Complex{BigFloat}] + +const inactivetup = ("a", Empty(), MutableEmpty()) +const inactivearr = [inactivetup] + +const wrappers = [ + (name="Tuple{X}", f=tuple, N=1, mutable=false, typed=true), + (name="@NamedTuple{x::X}", f=(NamedTuple{(:x,)} ∘ tuple), N=1, mutable=false, typed=true), + (name="struct{X}", f=Wrapper, N=1, mutable=false, typed=true), + + (name="@NamedTuple{x}", f=(@NamedTuple{x} ∘ tuple), N=1, mutable=false, typed=false), + (name="struct{Any}", f=Wrapper{Any}, N=1, mutable=false, typed=false), + + (name="Array{X}", f=(x -> [x]), N=1, mutable=true, typed=true), + (name="Base.RefValue{X}", f=Ref, N=1, mutable=true, typed=true), + (name="mutable struct{X}", f=MutableWrapper, N=1, mutable=true, typed=true), + + (name="Array{Any}", f=(x -> Any[x]), N=1, mutable=true, typed=false), + (name="Base.RefValue{Any}", f=Ref{Any}, N=1, mutable=true, typed=false), + (name="Core.Box", f=Core.Box, N=1, mutable=true, typed=false), + (name="mutable struct{Any}", f=MutableWrapper{Any}, N=1, mutable=true, typed=false), + + (name="Tuple{X,Y}", f=tuple, N=2, mutable=false, typed=true), + (name="@NamedTuple{x::X,y::Y}", f=(NamedTuple{(:x, :y)} ∘ tuple), N=2, mutable=false, typed=true), + (name="struct{X,Y}", f=DualWrapper, N=2, mutable=false, typed=true), + + (name="@NamedTuple{x,y::Y}", f=((x, y) -> @NamedTuple{x,y::typeof(y)}((x, y))), N=2, mutable=false, typed=:partial), + (name="struct{Any,Y}", f=DualWrapper{Any}, N=2, mutable=false, typed=:partial), + + (name="@NamedTuple{x,y}", f=@NamedTuple{x,y} ∘ tuple, N=2, mutable=false, typed=false), + (name="struct{Any}", f=DualWrapper{Any,Any}, N=2, mutable=false, typed=false), + + (name="mutable struct{X,Y}", f=MutableDualWrapper, N=2, mutable=true, typed=true), + + (name="Array{promote_type(X,Y)}", f=((x, y) -> [x, y]), N=2, mutable=true, typed=:promoted), + (name="mutable struct{Any,Y}", f=MutableDualWrapper{Any}, N=2, mutable=true, typed=:partial), + + (name="Array{Any}", f=((x, y) -> Any[x, y]), N=2, mutable=true, typed=false), + (name="mutable struct{Any,Any}", f=MutableDualWrapper{Any,Any}, N=2, mutable=true, typed=false), + + # StaticArrays extension + (name="SVector{1,X}", f=SVector{1} ∘ tuple, N=1, mutable=false, typed=true), + (name="SVector{1,Any}", f=SVector{1,Any} ∘ tuple, N=1, mutable=false, typed=false), + (name="MVector{1,X}", f=MVector{1} ∘ tuple, N=1, mutable=true, typed=true), + (name="MVector{1,Any}", f=MVector{1,Any} ∘ tuple, N=1, mutable=true, typed=false), + (name="SVector{2,promote_type(X,Y)}", f=SVector{2} ∘ tuple, N=2, mutable=false, typed=:promoted), + (name="SVector{2,Any}", f=SVector{2,Any} ∘ tuple, N=2, mutable=false, typed=false), + (name="MVector{2,promote_type(X,Y)}", f=MVector{2} ∘ tuple, N=2, mutable=true, typed=:promoted), + (name="MVector{2,Any}", f=MVector{2,Any} ∘ tuple, N=2, mutable=true, typed=false), +] + +@static if VERSION < v"1.11-" +else +_memory(x::Vector) = Memory{eltype(x)}(x) +push!( + wrappers, + (name="Memory{X}", f=(x -> _memory([x])), N=1, mutable=true, typed=true), + (name="Memory{Any}", f=(x -> _memory(Any[x])), N=1, mutable=true, typed=false), + (name="Memory{promote_type(X,Y)}", f=((x, y) -> _memory([x, y])), N=2, mutable=true, typed=:promoted), + (name="Memory{Any}", f=((x, y) -> _memory(Any[x, y])), N=2, mutable=true, typed=false), +) +end + +function test_make_zero() + @testset "scalars" begin + @testset "$T" for T in scalartypes + x = oneunit(T) + x_makez = make_zero(x) + @test typeof(x_makez) === T # correct type + @test x_makez == zero(T) # correct value + @test x == oneunit(T) # no mutation of original (relevant for BigFloat) + end + end + @testset "nested types" begin + @testset "$T in $(wrapper.name)" for + T in scalartypes, wrapper in filter(w -> (w.N == 1), wrappers) + x = oneunit(T) + w = wrapper.f(x) + w_makez = make_zero(w) + @test typeof(w_makez) === typeof(w) # correct type + @test typeof(getx(w_makez)) === T # correct type + @test getx(w_makez) == zero(T) # correct value + @test getx(w) === x # no mutation of original + @test x == oneunit(T) # no mutation of original (relevant for BigFloat) + @testset "doubly included in $(dualwrapper.name)" for + dualwrapper in filter(w -> (w.N == 2), wrappers) + w_inner = wrapper.f(x) + d_outer = dualwrapper.f(w_inner, w_inner) + d_outer_makez = make_zero(d_outer) + @test typeof(d_outer_makez) === typeof(d_outer) # correct type + @test typeof(getx(d_outer_makez)) === typeof(w_inner) # correct type + @test typeof(getx(getx(d_outer_makez))) === T # correct type + @test getx(d_outer_makez) === gety(d_outer_makez) # correct topology + @test getx(getx(d_outer_makez)) == zero(T) # correct value + @test getx(d_outer) === gety(d_outer) # no mutation of original + @test getx(d_outer) === w_inner # no mutation of original + @test getx(w_inner) === x # no mutation of original + @test x == oneunit(T) # no mutation of original (relevant for BigFloat) + d_inner = dualwrapper.f(x, x) + w_outer = wrapper.f(d_inner) + w_outer_makez = make_zero(w_outer) + @test typeof(w_outer_makez) === typeof(w_outer) # correct type + @test typeof(getx(w_outer_makez)) === typeof(d_inner) # correct type + @test typeof(getx(getx(w_outer_makez))) === T # correct type + @test getx(getx(w_outer_makez)) == gety(getx(w_outer_makez)) # correct topology + @test getx(getx(w_outer_makez)) == zero(T) # correct value + @test getx(w_outer) === d_inner # no mutation of original + @test getx(d_inner) === gety(d_inner) # no mutation of original + @test getx(d_inner) === x # no mutation of original + @test x == oneunit(T) # no mutation of original (relevant for BigFloat) + if wrapper.mutable && !dualwrapper.mutable + # some code paths can only be hit with three layers of wrapping: + # mutable(immutable(mutable(scalar))) + @testset "all wrapped in $(outerwrapper.name)" for + outerwrapper in filter(w -> ((w.N == 1) && w.mutable), wrappers) + w_inner = wrapper.f(x) + d_middle = dualwrapper.f(w_inner, w_inner) + w_outer = outerwrapper.f(d_middle) + w_outer_makez = make_zero(w_outer) + @test typeof(w_outer_makez) === typeof(w_outer) # correct type + @test typeof(getx(w_outer_makez)) === typeof(d_middle) # correct type + @test typeof(getx(getx(w_outer_makez))) === typeof(w_inner) # correct type + @test typeof(getx(getx(getx(w_outer_makez)))) === T # correct type + @test getx(getx(w_outer_makez)) === gety(getx(w_outer_makez)) # correct topology + @test getx(getx(getx(w_outer_makez))) == zero(T) # correct value + @test getx(w_outer) === d_middle # no mutation of original + @test getx(d_middle) === gety(d_middle) # no mutation of original + @test getx(d_middle) === w_inner # no mutation of original + @test getx(w_inner) === x # no mutation of original + @test x == oneunit(T) # no mutation of original (relevant for BigFloat) + end + end + end + end + end + @testset "inactive" begin + @testset "in $(wrapper.name)" for wrapper in wrappers + if wrapper.N == 1 + w = wrapper.f(inactivearr) + w_makez = make_zero(w) + if wrapper.typed == true + @test w_makez === w # preserved wrapper identity if guaranteed const + end + @test typeof(w_makez) === typeof(w) # correct type + @test getx(w_makez) === inactivearr # preserved identity + @test inactivearr[1] === inactivetup # preserved value + @test getx(w) === inactivearr # no mutation of original + else # wrapper.N == 2 + @testset "multiple references" begin + w = wrapper.f(inactivearr, inactivearr) + w_makez = make_zero(w) + if wrapper.typed == true + @test w_makez === w # preserved wrapper identity if guaranteed const + end + @test typeof(w_makez) === typeof(w) # correct type + @test getx(w_makez) === gety(w_makez) # preserved topology + @test getx(w_makez) === inactivearr # preserved identity + @test inactivearr[1] === inactivetup # preserved value + @test getx(w) === gety(w) # no mutation of original + @test getx(w) === inactivearr # no mutation of original + end + @testset "alongside active" begin + a = [1.0] + w = wrapper.f(a, inactivearr) + w_makez = make_zero(w) + @test typeof(w_makez) === typeof(w) # correct type + @test typeof(getx(w_makez)) === typeof(a) # correct type + @test getx(w_makez) == [0.0] # correct value + @test gety(w_makez) === inactivearr # preserved inactive identity + @test inactivearr[1] === inactivetup # preserved inactive value + @test getx(w) === a # no mutation of original + @test a[1] === 1.0 # no mutation of original + @test gety(w) === inactivearr # no mutation of original + if wrapper.typed == :partial + # above: untyped active / typed inactive + # below: untyped inactive / typed active + w = wrapper.f(inactivearr, a) + w_makez = make_zero(w) + @test typeof(w_makez) === typeof(w) # correct type + @test getx(w_makez) === inactivearr # preserved inactive identity + @test inactivearr[1] === inactivetup # preserved inactive value + @test typeof(gety(w_makez)) === typeof(a) # correct type + @test gety(w_makez) == [0.0] # correct value + @test getx(w) === inactivearr # no mutation of original + @test gety(w) === a # no mutation of original + @test a[1] === 1.0 # no mutation of original + end + end + end + end + @testset "copy_if_inactive $value" for (value, args) in [ + ("unspecified", ()), + ("= false", (Val(false),)), + ("= true", (Val(true),)), + ] + a = [1.0] + w = Any[a, inactivearr, inactivearr] + w_makez = make_zero(w, args...) + @test typeof(w_makez) === typeof(w) # correct type + @test typeof(w_makez[1]) === typeof(a) # correct type + @test w_makez[1] == [0.0] # correct value + @test w_makez[2] === w_makez[3] # correct topology (topology should propagate even when copy_if_inactive = Val(true)) + @test w[1] === a # no mutation of original + @test a[1] === 1.0 # no mutation of original + @test w[2] === w[3] # no mutation of original + @test w[2] === inactivearr # no mutation of original + @test inactivearr[1] === inactivetup # no mutation of original + if args == (Val(true),) + @test typeof(w_makez[2]) === typeof(inactivearr) # correct type + @test w_makez[2] == inactivearr # correct value + @test w_makez[2][1] !== inactivetup # correct identity + else + @test w_makez[2] === inactivearr # correct value/type/identity + end + end + end + @testset "heterogeneous containers" begin + scalars, scalarsz = oneunit.(scalartypes), zero.(scalartypes) + wraps, wrapsz = Wrapper.(scalars), Wrapper.(scalarsz) + mwraps, mwrapsz = MutableWrapper.(scalars), MutableWrapper.(scalarsz) + items = (inactivetup..., scalars..., wraps..., mwraps...) + itemsz = (inactivetup..., scalarsz..., wrapsz..., mwrapsz...) + labels = Symbol.("i" .* string.(1:length(items))) + @testset "$name" for (name, c, cz) in [ + ("Tuple", Tuple(items), Tuple(itemsz)), + ("NamedTuple", NamedTuple(labels .=> items), NamedTuple(labels .=> itemsz)), + ("Array", collect(items), collect(itemsz)), + ] + c_makez = make_zero(c) + @test typeof(c_makez) === typeof(c) # correct type + @test all(typeof(czj) === typeof(cj) for (czj, cj) in zip(c_makez, c)) # correct type + @test c_makez == cz # correct value + @test all(czj === inj for (czj, inj) in zip(c_makez, inactivetup)) # preserved inactive identities + @test all(cj === itj for (cj, itj) in zip(c, items)) # no mutation of original + @test all(m.x == oneunit(m.x) for m in mwraps) # no mutation of original + end + end + @testset "circular references" begin + @testset "$(wrapper.name)" for wrapper in ( + filter(w -> (w.mutable && (w.typed in (:partial, false))), wrappers) + ) + a = [1.0] + if wrapper.N == 1 + w = wrapper.f(nothing) + setx!(w, (w, a)) + else + w = wrapper.f(nothing, a) + setx!(w, w) + end + w_makez = @test_nowarn try + # catch errors to get failed test instead of "exception outside of a @test" + make_zero(w) + catch e + showerror(stderr, e) + end + if wrapper.N == 1 + xz, yz = getx(w_makez) + x, y = getx(w) + else + xz, yz = getx(w_makez), gety(w_makez) + x, y = getx(w), gety(w) + end + @test typeof(w_makez) === typeof(w) # correct type + @test typeof(xz) === typeof(w) # correct type + @test typeof(yz) === typeof(a) # correct type + @test xz === w_makez # correct self-reference + @test yz == [0.0] # correct value + @test x === w # no mutation of original + @test y === a # no mutation of original + @test a[1] === 1.0 # no mutation of original + end + end + @testset "bring your own IdDict" begin + a = [1.0] + seen = IdDict() + a_makez = make_zero(typeof(a), seen, a) + @test typeof(a_makez) === typeof(a) # correct type + @test a_makez == [0.0] # correct value + @test a[1] === 1.0 # no mutation of original + @test haskey(seen, a) # original added to IdDict + @test seen[a] === (a_makez,) # original points to zeroed value + end + @testset "custom leaf type" begin + a = [1.0] + v = CustomVector(a) + # include optional arg Val(false) to avoid calling the custom method directly; + # it should still be invoked + v_makez = @test_logs (:info, "make_zero(::CustomVector)") make_zero(v, Val(false)) + @test typeof(v_makez) === typeof(v) # correct type + @test typeof(v_makez.data) === typeof(a) # correct type + @test v_makez == CustomVector([0.0]) # correct value + @test v.data === a # no mutation of original + @test a[1] === 1.0 # no mutation of original + end + @testset "undefined fields/unassigned elements" begin + @testset "array w inactive/active/mutable/unassigned" begin + a = [1.0] + values = ("a", 1.0, a) + arr = Vector{Any}(undef, 4) + arr[1:3] .= values + arr_makez = make_zero(arr) + @views begin + @test typeof(arr_makez) === typeof(arr) # correct type + @test all(typeof.(arr_makez[1:3]) .=== typeof.(values)) # correct type + @test arr_makez[1:3] == ["a", 0.0, [0.0]] # correct value + @test !isassigned(arr_makez, 4) # propagated undefined + @test all(arr[1:3] .=== values) # no mutation of original + @test !isassigned(arr, 4) # no mutation of original + @test a[1] === 1.0 # no mutation of original + end + end + @testset "struct w inactive/active/mutable/undefined" begin + a = [1.0] + incomplete = Incomplete("a", 1.0, a) + incomplete_makez = make_zero(incomplete) + @test typeof(incomplete_makez) === typeof(incomplete) # correct type + @test typeof(incomplete_makez.w) === typeof(a) # correct type + @test incomplete_makez == Incomplete("a", 0.0, [0.0]) # correct value, propagated undefined + @test a[1] === 1.0 # no mutation of original + end + @testset "mutable struct w inactive/const active/active/mutable/undefined" begin + a = [1.0] + incomplete = MutableIncomplete("a", #=const=#1.0, 1.0, a) + incomplete_makez = make_zero(incomplete) + @test typeof(incomplete_makez) === typeof(incomplete) # correct type + @test typeof(incomplete_makez.w) === typeof(a) # correct type + @test incomplete_makez == MutableIncomplete("a", 0.0, 0.0, [0.0]) # correct value, propagated undefined + @test incomplete == MutableIncomplete("a", 1.0, 1.0, a) # no mutation of original + @test incomplete.w === a # no mutation of original + @test a[1] === 1.0 # no mutation of original + end + end + return nothing +end + +function test_make_zero!() + @testset "nested types" begin + @testset "$T in $(wrapper.name)" for + T in scalartypes, wrapper in filter(w -> (w.N == 1), wrappers) + x = oneunit(T) + if wrapper.mutable + w = wrapper.f(x) + make_zero!(w) + @test typeof(getx(w)) === T # preserved type + @test getx(w) == zero(T) # correct value + @test x == oneunit(T) # no mutation of scalar (relevant for BigFloat) + end + @testset "doubly included in $(dualwrapper.name)" for dualwrapper in ( + filter(w -> ((w.N == 2) && (w.mutable || wrapper.mutable)), wrappers) + ) + w_inner = wrapper.f(x) + d_outer = dualwrapper.f(w_inner, w_inner) + make_zero!(d_outer) + @test typeof(getx(d_outer)) === typeof(w_inner) # preserved type + @test typeof(getx(getx(d_outer))) === T # preserved type + @test getx(getx(d_outer)) == zero(T) # correct value + @test getx(d_outer) === gety(d_outer) # preserved topology + @test x == oneunit(T) # no mutation of scalar (relevant for BigFloat) + if wrapper.mutable + @test getx(d_outer) === w_inner # preserved identity + end + d_inner = dualwrapper.f(x, x) + w_outer = wrapper.f(d_inner) + make_zero!(w_outer) + @test typeof(getx(w_outer)) === typeof(d_inner) # preserved type + @test typeof(getx(getx(w_outer))) === T # preserved type + @test getx(getx(w_outer)) == zero(T) # correct value + @test getx(getx(w_outer)) === gety(getx(w_outer)) # preserved topology + @test x == oneunit(T) # no mutation of scalar (relevant for BigFloat) + if dualwrapper.mutable + @test getx(w_outer) === d_inner # preserved identity + end + if wrapper.mutable && !dualwrapper.mutable + # some code paths can only be hit with three layers of wrapping: + # mutable(immutable(mutable(scalar))) + @assert !dualwrapper.mutable # sanity check + @testset "all wrapped in $(outerwrapper.name)" for + outerwrapper in filter(w -> ((w.N == 1) && w.mutable), wrappers) + w_inner = wrapper.f(x) + d_middle = dualwrapper.f(w_inner, w_inner) + w_outer = outerwrapper.f(d_middle) + make_zero!(w_outer) + @test typeof(getx(w_outer)) === typeof(d_middle) # preserved type + @test typeof(getx(getx(w_outer))) === typeof(w_inner) # preserved type + @test typeof(getx(getx(getx(w_outer)))) === T # preserved type + @test getx(getx(getx(w_outer))) == zero(T) # correct value + @test getx(getx(w_outer)) === gety(getx(w_outer)) # preserved topology + @test getx(getx(w_outer)) === w_inner # preserved identity + @test x == oneunit(T) # no mutation of scalar (relevant for BigFloat) + end + end + end + end + end + @testset "inactive" begin + @testset "in $(wrapper.name)" for + wrapper in filter(w -> (w.mutable || (w.typed == true)), wrappers) + if wrapper.N == 1 + w = wrapper.f(inactivearr) + make_zero!(w) + @test getx(w) === inactivearr # preserved identity + @test inactivearr[1] === inactivetup # preserved value + else # wrapper.N == 2 + @testset "multiple references" begin + w = wrapper.f(inactivearr, inactivearr) + make_zero!(w) + @test getx(w) === gety(w) # preserved topology + @test getx(w) === inactivearr # preserved identity + @test inactivearr[1] === inactivetup # preserved value + end + @testset "alongside active" begin + a = [1.0] + w = wrapper.f(a, inactivearr) + make_zero!(w) + @test getx(w) === a # preserved identity + @test a[1] === 0.0 # correct value + @test gety(w) === inactivearr # preserved inactive identity + @test inactivearr[1] === inactivetup # preserved inactive value + end + end + end + end + @testset "heterogeneous containers" begin + mwraps = MutableWrapper.(oneunit.(scalartypes)) + mwrapsz = MutableWrapper.(zero.(scalartypes)) + items = (inactivetup..., mwraps...) + itemsz = (inactivetup..., mwrapsz...) + labels = Symbol.("i" .* string.(1:length(items))) + @testset "$name" for (name, c, cz) in [ + ("Tuple", Tuple(items), Tuple(itemsz)), + ("NamedTuple", NamedTuple(labels .=> items), NamedTuple(labels .=> itemsz)), + ("Array", collect(items), collect(itemsz)), + ] + make_zero!(c) + @test all(cj === itj for (cj, itj) in zip(c, items)) # preserved identities + @test c == cz # correct value + end + end + @testset "circular references" begin + @testset "$(wrapper.name)" for wrapper in ( + filter(w -> (w.mutable && (w.typed in (:partial, false))), wrappers) + ) + a = [1.0] + if wrapper.N == 1 + w = wrapper.f(nothing) + setx!(w, (w, a)) + else + w = wrapper.f(nothing, a) + setx!(w, w) + end + @test_nowarn try + # catch errors to get failed test instead of "exception outside of a @test" + make_zero!(w) + catch e + showerror(stderr, e) + end + if wrapper.N == 1 + x, y = getx(w) + else + x, y = getx(w), gety(w) + end + @test x === w # preserved self-referential identity + @test y === a # preserved identity + @test a[1] === 0.0 # correct value + end + end + @testset "bring your own IdDict" begin + a = [1.0] + seen = IdDict() + make_zero!(a, seen) + @test a[1] === 0.0 # correct value + @test haskey(seen, a) # object added to IdDict + @test seen[a] === (a,) # object points to zeroed value, i.e., itself + end + @testset "custom leaf type" begin + a = [1.0] + v = CustomVector(a) + # bringing own IdDict to avoid calling the custom method directly; + # it should still be invoked + @test_logs (:info, "make_zero!(::CustomVector)") make_zero!(v, IdDict()) + @test v.data === a # preserved identity + @test a[1] === 0.0 # correct value + end + @testset "undefined fields/unassigned elements" begin + @testset "array w inactive/active/mutable/unassigned" begin + a = [1.0] + values = ("a", 1.0, a) + arr = Vector{Any}(undef, 4) + arr[1:3] .= values + make_zero!(arr) + @views begin + @test all(typeof.(arr[1:3]) .=== typeof.(values)) # preserved types + @test arr[1:3] == ["a", 0.0, [0.0]] # correct value + @test arr[3] === a # preserved identity + @test !isassigned(arr, 4) # preserved unassigned + end + end + @testset "struct w inactive/active/mutable/undefined" begin + a = [1.0] + incompletearr = [Incomplete("a", 1.0, a)] + make_zero!(incompletearr) + @test incompletearr == [Incomplete("a", 0.0, [0.0])] # correct value, preserved undefined + @test incompletearr[1].w === a # preserved identity + end + @testset "mutable struct w inactive/const active/active/mutable/undefined" begin + a = [1.0] + incomplete = MutableIncomplete("a", #=const=#1.0, 1.0, a) + make_zero!(incomplete) + @test incomplete == MutableIncomplete("a", 0.0, 0.0, [0.0]) # correct value, preserved undefined + @test incomplete.w === a # preserved identity + end + @testset "Array{Tuple{struct w undefined}} (issue #1935)" begin + # old implementation of make_zero! triggered #1935 + # new implementation would work regardless due to limited use of justActive + a = [1.0] + incomplete = Incomplete("a", 1.0, a) + incompletetuparr = [(incomplete,)] + make_zero!(incompletetuparr) + @test typeof(incompletetuparr[1]) === typeof((incomplete,)) # preserved type + @test incompletetuparr == [(Incomplete("a", 0.0, [0.0]),)] # correct value + @test incompletetuparr[1][1].w === a # preserved identity + end + end + @testset "active/mixed type error" begin + @test_throws ArgumentError make_zero!((1.0,)) + @test_throws ArgumentError make_zero!((1.0, [1.0])) + @test_throws ArgumentError make_zero!((Incomplete("a", 1.0, 1.0im),)) # issue #1935 + end + return nothing +end + +@testset "make_zero" test_make_zero() +@testset "make_zero!" test_make_zero!() + +end # module RecurisveMapTests diff --git a/test/runtests.jl b/test/runtests.jl index d28461d26a..e8b69a5441 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -74,6 +74,7 @@ end include("abi.jl") include("typetree.jl") include("optimize.jl") +include("recursive_map.jl") include("rules.jl") include("rrules.jl") From b3f7426c45d64bb0d030e318c6234246d8c9e41d Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Thu, 31 Oct 2024 12:23:51 -0700 Subject: [PATCH 2/8] Fix typos and some minor things --- ext/EnzymeStaticArraysExt.jl | 2 +- lib/EnzymeCore/src/EnzymeCore.jl | 2 +- src/recursive_map.jl | 4 ++-- test/recursive_map.jl | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/ext/EnzymeStaticArraysExt.jl b/ext/EnzymeStaticArraysExt.jl index a2d65a4653..3ab82c392e 100644 --- a/ext/EnzymeStaticArraysExt.jl +++ b/ext/EnzymeStaticArraysExt.jl @@ -40,7 +40,7 @@ end @inline function Enzyme.EnzymeCore.isvectortype( ::Type{<:Union{SArray{S,T},MArray{S,T}}} ) where {S,T} - return isbitstype(T) && Enzyme.Compiler.RecursiveMap.isscalartype(T) + return isbitstype(T) && Enzyme.EnzymeCore.isscalartype(T) end end diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index 298a45d5f0..dc2668e60c 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -538,7 +538,7 @@ By default, `isvectortype(T) == true` for `T` such that `isscalartype(T) == true A new plain array type, for example a GPU array, may extend this as follows: ```julia -@inline EnzymeCore.isvectortype(::Type{<:GPUArray{U}}) where {U} = isscalartype(U) +@inline EnzymeCore.isvectortype(::Type{<:NewArray{U}}) where {U} = EnzymeCore.isscalartype(U) ``` Such a type should implement `Base.zero` and, if mutable, `Base.fill!`. (If this is not diff --git a/src/recursive_map.jl b/src/recursive_map.jl index aa6ad2ce50..2f7fba3cec 100644 --- a/src/recursive_map.jl +++ b/src/recursive_map.jl @@ -309,8 +309,8 @@ end # generic handling of mutable structs, arrays, and memory @inline _similar(::T) where {T} = ccall(:jl_new_struct_uninit, Any, (Any,), T)::T @inline _similar(x::T) where {T<:Arraylike} = similar(x)::T -@inline _eachindex(xs::T...) where {T} = 1:fieldcount(T) -@inline _eachindex(xs::Arraylike...) = eachindex(xs...) +@inline _eachindex(::Vararg{T,M}) where {T,M} = 1:fieldcount(T) +@inline _eachindex(xs::Vararg{Arraylike,M}) where {M} = eachindex(xs...) @inline isinitialized(x, i) = isdefined(x, i) Base.@propagate_inbounds isinitialized(x::Arraylike, i) = isassigned(x, i) @inline getvalue(x, i) = getfield(x, i) diff --git a/test/recursive_map.jl b/test/recursive_map.jl index 3111263bd0..278499201a 100644 --- a/test/recursive_map.jl +++ b/test/recursive_map.jl @@ -687,4 +687,4 @@ end @testset "make_zero" test_make_zero() @testset "make_zero!" test_make_zero!() -end # module RecurisveMapTests +end # module RecursiveMapTests From 1778ebeb74c7b4341cd7d9287283d8493606a3d5 Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Thu, 31 Oct 2024 14:30:17 -0700 Subject: [PATCH 3/8] Fix Holomorphic tests Eventually, `recursive_accumulate` should be rewritten on top of a new `VectorSpace` wrapper built on `recursive_map`. Until then, this will do. --- src/Enzyme.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 2e8643744b..64f6c0be20 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -451,10 +451,10 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) # compute the correct complex derivative in reverse mode by propagating the conjugate return values # then subtracting twice the imaginary component to get the correct result - for (k, v) in seen + for (k, (v,)) in seen Compiler.recursive_accumulate(k, v, refn_seed) end - for (k, v) in seen2 + for (k, (v,)) in seen2 Compiler.recursive_accumulate(k, v, imfn_seed) end From 9f60f1d55f920c7dd3e9d30451cf0ab7cee7383d Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Sat, 2 Nov 2024 09:27:47 -0700 Subject: [PATCH 4/8] Docstring tweaks --- lib/EnzymeCore/src/EnzymeCore.jl | 24 +++++++++++++----------- src/recursive_map.jl | 5 +++-- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index dc2668e60c..d4eb65941b 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -504,25 +504,27 @@ function autodiff_deferred_thunk end make_zero(prev::T, ::Val{copy_if_inactive}=Val(false))::T make_zero(::Type{T}, seen::IdDict, prev::T, ::Val{copy_if_inactive}=Val(false))::T -Recursively make a zero'd copy of the value `prev` of type `T`. The argument `copy_if_inactive` specifies -what to do if the type `T` is guaranteed to be inactive, use the primal (the default) or still copy the value. +Recursively make a copy of the value `prev` of type `T` in which all differentiable values +are set to zero. The argument `copy_if_inactive` specifies what to do if the type `T` or any +of its constituent parts is guaranteed to be inactive: use `prev`s instance (the default) or +make a copy. -Extending this method for custom types is rarely needed. For new plain array types like GPU -arrays, extending [`isvectortype`](@ref) is sufficient as long as the array type implements -`Base.zero`. +Extending this method for custom types is rarely needed. For new types that shouldn't be +recursed into, such as a GPU array type, extending [`isvectortype`](@ref) is sufficient as +long as the type implements `Base.zero`. """ function make_zero end """ make_zero!(val::T, seen::IdDict=IdDict())::Nothing -Recursively set a variables differentiable fields to zero. Only applicable for types `T` +Recursively set a variable's differentiable fields to zero. Only applicable for types `T` that are mutable or hold all differentiable values in mutable containers (e.g., `Tuple{Vector{Float64}}`). -Extending this method for custom types is rarely needed. For new plain mutable array types -like GPU arrays, extending [`isvectortype`](@ref) is sufficient as long as the array type -implements `Base.zero` and `Base.fill!`. +Extending this method for custom types is rarely needed. For new mutable types that +shouldn't be recursed into, such as a GPU array type, extending [`isvectortype`](@ref) is +sufficient as long as the type implements `Base.zero` and `Base.fill!`. """ function make_zero! end @@ -535,7 +537,7 @@ and [`make_zero!`](@ref) recurse through an object. By default, `isvectortype(T) == true` for `T` such that `isscalartype(T) == true` or `T <: Union{Array{U},GenericMemory{_,U}}` where `isscalartype(U) == true`. -A new plain array type, for example a GPU array, may extend this as follows: +A new leaf type, such as example a GPU array type, may extend this as follows: ```julia @inline EnzymeCore.isvectortype(::Type{<:NewArray{U}}) where {U} = EnzymeCore.isscalartype(U) @@ -564,7 +566,7 @@ function isvectortype end Trait defining a subset of [`isvectortype`](@ref) types that should not be considered composite, such that even if the type is mutable, [`make_zero!`](@ref) will not try to zero values of the type in-place. For example, `BigFloat` is a mutable type but does not support -in-place mutation through any Julia API, and `isscalartype(BigFloat) == true` ensures that +in-place mutation through any Julia API; `isscalartype(BigFloat) == true` ensures that `make_zero!` will not try to mutate `BigFloat` values.[^BigFloat] By default, `isscalartype(T) == true` for `T <: AbstractFloat` and diff --git a/src/recursive_map.jl b/src/recursive_map.jl index 2f7fba3cec..f027d8ccd6 100644 --- a/src/recursive_map.jl +++ b/src/recursive_map.jl @@ -385,8 +385,9 @@ function `f!!` over every differentiable value encountered and updating new muta This is a wrapper that calls `recursive_map([seen,] f!!, (; y, xs), isleaftype, Val(copy_if_inactive))`, but only accepts -types `T` that are mutable (or, trivially, entirely non-differentiable), and enforces a -fully in-place update of `y`. See [`recursive_map`](@ref) for details. +types `T` in which all differentiable values can be updated in-place (including, trivially, +types that don't contain any differentiable values), and enforces a fully in-place update of +`y`. See [`recursive_map`](@ref) for details. """ function recursive_map!(f!!::F, y::T, xs::XTup{N,T}, args::Vararg{Any,M}) where {F,N,T,M} check_notactive(T) From dbac66be8140686a7673728426e2432fdc222052 Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Sat, 2 Nov 2024 18:02:07 -0700 Subject: [PATCH 5/8] Add fast path for recursive_map on vector types --- lib/EnzymeCore/src/EnzymeCore.jl | 5 +++++ src/recursive_map.jl | 19 +++++++++++++------ 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index d4eb65941b..3348cb54ed 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -556,6 +556,11 @@ directly on their backing arrays, just like how Enzyme treats them when differen example, structured matrix wrappers and sparse array types that are backed by `Array` should not extend `isvectortype`. +If a vector type `T` is also non-differentiable, `isvectortype` takes precedence, that is, +`make_zero(!)` will attempt to zero its values rather than share/copy them (out-of-place) or +skip them (in-place). This is for performance reasons, but should almost never be relevant +for behavior, as the two traits should be mutually exclusive. + See also [`isscalartype`](@ref). """ function isvectortype end diff --git a/src/recursive_map.jl b/src/recursive_map.jl index f027d8ccd6..f804bbf959 100644 --- a/src/recursive_map.jl +++ b/src/recursive_map.jl @@ -26,8 +26,12 @@ function `f` over every differentiable value encountered and building a new obje from the resulting values `yi = f(x1i, ..., xNi)`. The trait `EnzymeCore.isvectortype`(@ref) determines which values are considered -differentiable leaf nodes at which recursion terminates and `f` is invoked. See the -docstring for that and the related [`EnzymeCore.isscalartype`](@ref) for more information. +differentiable leaf nodes at which recursion terminates and `f` is invoked. This trait takes +precedence over being non-differentiable, that is, if a type is both, it's values are passed +to `f`, not copied/shared according to `copy_if_inactive` (this is for performance reasons +and should almost never be relevant for behavior, as the two traits should be mutually +exclusive). See the docstring for [`EnzymeCore.isvectortype`](@ref) and the related +[`EnzymeCore.isscalartype`](@ref) for more information. An existing object `y::T` can be passed by replacing the tuple `xs` with a NamedTuple `(; y, xs)`, in which case `y` is updated "partially-in-place": any parts of `y` that are @@ -116,7 +120,9 @@ end f::F, yxs::XTupOrYXTup{N,T}, copy_if_inactive::Val=Val(false) ) where {F,N,T} # determine whether or not an IdDict is needed for this T - if isbitstype(T) || ( + if isvectortype(T) + y = recursive_map_leaf(nothing, f, yxs) + elseif isbitstype(T) || ( guaranteed_const_nongen(T, nothing) && !needscopy(yxs, copy_if_inactive) ) y = recursive_map(nothing, f, yxs, copy_if_inactive) @@ -135,7 +141,8 @@ end ) where {F,N,T} # determine whether to continue recursion, copy/share, or retrieve from cache xs = xtup(yxs) - if guaranteed_const_nongen(T, nothing) + if (!isvectortype(T)) && guaranteed_const_nongen(T, nothing) + # check `!isvectortype` first for consistency with the above method y = maybecopy(seen, yxs, copy_if_inactive) elseif isbitstype(T) # no need to track identity or pass y in this branch y = recursive_map_inner(nothing, f, xs, copy_if_inactive) @@ -275,7 +282,7 @@ end function recursive_map_leaf(seen, f::F, xs::XTup{N,T}) where {F,N,T} # out-of-place y = f(xs...) - if !isbitstype(T) + if (!isnothing(seen)) && (!isbitstype(T)) cache!(seen, y, xs) end return y::T @@ -291,7 +298,7 @@ function recursive_map_leaf(seen, f!!::F, (; y, xs)::YXTup{N,T}) where {F,N,T} @assert newy === y end end - if !isbitstype(T) + if (!isnothing(seen)) && (!isbitstype(T)) cache!(seen, newy, xs) end return newy::T From 2f4d7eb6fd757a6d14d991c2794389d6598edfd5 Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Sat, 2 Nov 2024 18:07:55 -0700 Subject: [PATCH 6/8] Make recursive_map use guaranteed_const by default But make it customizable via an argument. Dynamic dispatch in guaranteed_const_nongen is too much of a performance killer --- src/recursive_map.jl | 56 ++++++++++++++++++++++++++------------------ 1 file changed, 33 insertions(+), 23 deletions(-) diff --git a/src/recursive_map.jl b/src/recursive_map.jl index f804bbf959..31c466aa77 100644 --- a/src/recursive_map.jl +++ b/src/recursive_map.jl @@ -1,7 +1,7 @@ module RecursiveMap using EnzymeCore: EnzymeCore, isvectortype, isscalartype -using ..Compiler: ActiveState, active_reg_inner, guaranteed_const_nongen, splatnew +using ..Compiler: ActiveState, active_reg_inner, guaranteed_const, splatnew """ y = recursive_map( @@ -9,12 +9,14 @@ using ..Compiler: ActiveState, active_reg_inner, guaranteed_const_nongen, splatn f, xs::NTuple{N,T}, ::Val{copy_if_inactive}=Val(false), + isinactive=guaranteed_const, )::T where {T,N,copy_if_inactive} newy = recursive_map( [seen::IdDict,] f, (; y, xs)::@NamedTuple{y::T,xs::NTuple{N,T}}, ::Val{copy_if_inactive}=Val(false), + isinactive=guaranteed_const, )::T where {T,N,copy_if_inactive} !!! warning @@ -27,11 +29,12 @@ from the resulting values `yi = f(x1i, ..., xNi)`. The trait `EnzymeCore.isvectortype`(@ref) determines which values are considered differentiable leaf nodes at which recursion terminates and `f` is invoked. This trait takes -precedence over being non-differentiable, that is, if a type is both, it's values are passed -to `f`, not copied/shared according to `copy_if_inactive` (this is for performance reasons -and should almost never be relevant for behavior, as the two traits should be mutually -exclusive). See the docstring for [`EnzymeCore.isvectortype`](@ref) and the related -[`EnzymeCore.isscalartype`](@ref) for more information. +precedence over being non-differentiable as defined by `isinactive`, that is, if a type is +both, it's values are passed to `f`, not copied/shared according to `copy_if_inactive` (this +is for performance reasons and should almost never be relevant for behavior, as the two +traits should be mutually exclusive). See the docstring for +[`EnzymeCore.isvectortype`](@ref) and the related [`EnzymeCore.isscalartype`](@ref) for more +information. An existing object `y::T` can be passed by replacing the tuple `xs` with a NamedTuple `(; y, xs)`, in which case `y` is updated "partially-in-place": any parts of `y` that are @@ -98,6 +101,9 @@ that the type notionally represents. several non-differentiable parts, object identity is tracked across the multiple deep-copies such that the object reference graph is reproduced also within the inactive parts.) + +* `isinactive` (optional): Callable that determines whether a type is non-differentiable and + hence treated according to `copy_if_inactive`. """ function recursive_map end @@ -117,14 +123,15 @@ end ## main entry point @inline function recursive_map( - f::F, yxs::XTupOrYXTup{N,T}, copy_if_inactive::Val=Val(false) -) where {F,N,T} + f::F, + yxs::XTupOrYXTup{N,T}, + copy_if_inactive::Val=Val(false), + isinactive::C=guaranteed_const, +) where {F,N,T,C} # determine whether or not an IdDict is needed for this T if isvectortype(T) y = recursive_map_leaf(nothing, f, yxs) - elseif isbitstype(T) || ( - guaranteed_const_nongen(T, nothing) && !needscopy(yxs, copy_if_inactive) - ) + elseif isbitstype(T) || (isinactive(T) && !needscopy(yxs, copy_if_inactive)) y = recursive_map(nothing, f, yxs, copy_if_inactive) else y = recursive_map(IdDict(), f, yxs, copy_if_inactive) @@ -138,18 +145,19 @@ end f::F, yxs::XTupOrYXTup{N,T}, copy_if_inactive::Val=Val(false), -) where {F,N,T} + isinactive::C=guaranteed_const, +) where {F,N,T,C} # determine whether to continue recursion, copy/share, or retrieve from cache xs = xtup(yxs) - if (!isvectortype(T)) && guaranteed_const_nongen(T, nothing) + if (!isvectortype(T)) && isinactive(T) # check `!isvectortype` first for consistency with the above method y = maybecopy(seen, yxs, copy_if_inactive) elseif isbitstype(T) # no need to track identity or pass y in this branch - y = recursive_map_inner(nothing, f, xs, copy_if_inactive) + y = recursive_map_inner(nothing, f, xs, copy_if_inactive, isinactive) elseif hascache(seen, xs) y = getcached(seen, xs) else - y = recursive_map_inner(seen, f, yxs, copy_if_inactive) + y = recursive_map_inner(seen, f, yxs, copy_if_inactive, isinactive) end return y::T end @@ -222,19 +230,19 @@ end end @inline function recursive_map_immutable( - seen, f::F, yxs::XTupOrYXTup{N,T}, args::Vararg{Any,M} + seen, f::F, yxs::XTupOrYXTup{N,T}, copy_if_inactive, args::Vararg{Any,M} ) where {F,N,T,M} # immutable handler: construct y/newy @assert !ismutabletype(T) x1, xtail... = xtup(yxs) nf = fieldcount(T) if nf == 0 # nothing to do; assume inactive - y = maybecopy(seen, yxs, args...) + y = maybecopy(seen, yxs, copy_if_inactive) elseif isdefined(x1, nf) # fast path when all fields are defined check_initialized(xtail, nf) fieldtup = ntuple(Val(nf)) do i @inline - recursive_map_index(i, seen, f, yxs, args...) + recursive_map_index(i, seen, f, yxs, copy_if_inactive, args...) end y = splatnew(T, fieldtup) else @@ -242,7 +250,7 @@ end @inbounds for i in 1:nf if isdefined(x1, i) check_initialized(xtail, i) - flds[i] = recursive_map_index(i, seen, f, yxs, args...) + flds[i] = recursive_map_index(i, seen, f, yxs, copy_if_inactive, args...) else nf = i - 1 # rest of tail must be undefined values break @@ -455,6 +463,8 @@ end end ### make_zero(!) +# XXX: you can pass a custom `isinactive` as the last argument to both `make_zero` and +# `make_zero!`, but it's currently undocumented ## entry points, with default handling of leaves function EnzymeCore.make_zero(prev::T, args::Vararg{Any,M}) where {T,M} if EnzymeCore.isvectortype(T) @@ -470,12 +480,12 @@ function EnzymeCore.make_zero(prev::T, args::Vararg{Any,M}) where {T,M} return new::T end -function EnzymeCore.make_zero!(prev::T) where {T} +function EnzymeCore.make_zero!(prev::T, args::Vararg{Any,M}) where {T,M} @assert !EnzymeCore.isscalartype(T) # sanity check if EnzymeCore.isvectortype(T) # default implementation fill!(prev, false) else - recursive_map!(make_zero_f!!, prev, (prev,)) + recursive_map!(make_zero_f!!, prev, (prev,), Val(false), args...) end return nothing end @@ -487,8 +497,8 @@ function EnzymeCore.make_zero( return recursive_map(seen, make_zero_f!!, (prev,), args...)::T end -function EnzymeCore.make_zero!(prev, seen::IdDict) - recursive_map!(seen, make_zero_f!!, prev, (prev,)) +function EnzymeCore.make_zero!(prev, seen::IdDict, args::Vararg{Any,M}) where {M} + recursive_map!(seen, make_zero_f!!, prev, (prev,), Val(false), args...) return nothing end From 047666561adcb810708bc9de7e324086f201e0aa Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Sat, 2 Nov 2024 18:43:36 -0700 Subject: [PATCH 7/8] fixup! Make recursive_map use guaranteed_const by default --- src/recursive_map.jl | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/recursive_map.jl b/src/recursive_map.jl index 31c466aa77..8509ac8c99 100644 --- a/src/recursive_map.jl +++ b/src/recursive_map.jl @@ -1,7 +1,7 @@ module RecursiveMap using EnzymeCore: EnzymeCore, isvectortype, isscalartype -using ..Compiler: ActiveState, active_reg_inner, guaranteed_const, splatnew +using ..Compiler: ActiveState, guaranteed_const, guaranteed_nonactive, splatnew """ y = recursive_map( @@ -132,9 +132,9 @@ end if isvectortype(T) y = recursive_map_leaf(nothing, f, yxs) elseif isbitstype(T) || (isinactive(T) && !needscopy(yxs, copy_if_inactive)) - y = recursive_map(nothing, f, yxs, copy_if_inactive) + y = recursive_map(nothing, f, yxs, copy_if_inactive, isinactive) else - y = recursive_map(IdDict(), f, yxs, copy_if_inactive) + y = recursive_map(IdDict(), f, yxs, copy_if_inactive, isinactive) end return y::T end @@ -388,6 +388,7 @@ end xs::NTuple{N,T}, isleaftype=Returns(false), ::Val{copy_if_inactive}=Val(false), + isinactive=guaranteed_const, )::Nothing where {T,N,copy_if_inactive} !!! warning @@ -405,7 +406,7 @@ types that don't contain any differentiable values), and enforces a fully in-pla `y`. See [`recursive_map`](@ref) for details. """ function recursive_map!(f!!::F, y::T, xs::XTup{N,T}, args::Vararg{Any,M}) where {F,N,T,M} - check_notactive(T) + check_nonactive(T) newy = recursive_map(f!!, (; y, xs), args...) @assert newy === y return nothing @@ -414,7 +415,7 @@ end function recursive_map!( seen::IdDict, f!!::F, y::T, xs::XTup{N,T}, args::Vararg{Any,M} ) where {F,N,T,M} - check_notactive(T) + check_nonactive(T) newy = recursive_map(seen, f!!, (; y, xs), args...) @assert newy === y return nothing @@ -439,9 +440,9 @@ end return nothing end -@inline function check_notactive(::Type{T}) where {T} - if active_reg_inner(T, (), nothing, Val(true)) == ActiveState # justActive - throw_notactive() +@inline function check_nonactive(::Type{T}) where {T} + if !guaranteed_nonactive(T) + throw_nonactive() end return nothing end @@ -457,7 +458,7 @@ end throw(ArgumentError(msg)) end -@noinline function throw_notactive() +@noinline function throw_nonactive() msg = "recursive_map! called on objects containing immutable differentiable elements" throw(ArgumentError(msg)) end From 530c780380f4cf7cd1374fbe524b8b8d514bcc31 Mon Sep 17 00:00:00 2001 From: Daniel Wennberg Date: Sat, 2 Nov 2024 19:55:40 -0700 Subject: [PATCH 8/8] Factor out default arg values from logic Avoid the situation where a method with optional args calls another method where the same args are optional. It's a recipe for silent bugs where a non-default arg gets dropped. --- src/recursive_map.jl | 63 ++++++++++++++++++++++++-------------------- 1 file changed, 35 insertions(+), 28 deletions(-) diff --git a/src/recursive_map.jl b/src/recursive_map.jl index 8509ac8c99..282ab1d32f 100644 --- a/src/recursive_map.jl +++ b/src/recursive_map.jl @@ -122,30 +122,37 @@ else end ## main entry point -@inline function recursive_map( +function recursive_map( f::F, yxs::XTupOrYXTup{N,T}, copy_if_inactive::Val=Val(false), isinactive::C=guaranteed_const, ) where {F,N,T,C} - # determine whether or not an IdDict is needed for this T + # set default argument values; determine whether an IdDict is needed for this T if isvectortype(T) - y = recursive_map_leaf(nothing, f, yxs) + y = _recursive_map_leaf(nothing, f, yxs) elseif isbitstype(T) || (isinactive(T) && !needscopy(yxs, copy_if_inactive)) - y = recursive_map(nothing, f, yxs, copy_if_inactive, isinactive) + y = _recursive_map(nothing, f, yxs, copy_if_inactive, isinactive) else - y = recursive_map(IdDict(), f, yxs, copy_if_inactive, isinactive) + y = _recursive_map(IdDict(), f, yxs, copy_if_inactive, isinactive) end return y::T end -## recursive methods -@inline function recursive_map( +function recursive_map( seen::Union{Nothing,IdDict}, f::F, yxs::XTupOrYXTup{N,T}, copy_if_inactive::Val=Val(false), isinactive::C=guaranteed_const, +) where {F,N,T,C} + # set default argument values + return _recursive_map(seen, f, yxs, copy_if_inactive, isinactive)::T +end + +## recursive methods +@inline function _recursive_map( + seen, f::F, yxs::XTupOrYXTup{N,T}, copy_if_inactive, isinactive::C ) where {F,N,T,C} # determine whether to continue recursion, copy/share, or retrieve from cache xs = xtup(yxs) @@ -153,32 +160,32 @@ end # check `!isvectortype` first for consistency with the above method y = maybecopy(seen, yxs, copy_if_inactive) elseif isbitstype(T) # no need to track identity or pass y in this branch - y = recursive_map_inner(nothing, f, xs, copy_if_inactive, isinactive) + y = _recursive_map_inner(nothing, f, xs, copy_if_inactive, isinactive) elseif hascache(seen, xs) y = getcached(seen, xs) else - y = recursive_map_inner(seen, f, yxs, copy_if_inactive, isinactive) + y = _recursive_map_inner(seen, f, yxs, copy_if_inactive, isinactive) end return y::T end -@inline function recursive_map_inner( +@inline function _recursive_map_inner( seen, f::F, yxs::XTupOrYXTup{N,T}, args::Vararg{Any,M} ) where {F,N,T,M} # forward to appropriate handler for leaf vs. mutable vs. immutable type @assert !isabstracttype(T) @assert isconcretetype(T) if isvectortype(T) - y = recursive_map_leaf(seen, f, yxs) + y = _recursive_map_leaf(seen, f, yxs) elseif ismutabletype(T) - y = recursive_map_mutable(seen, f, yxs, args...) + y = _recursive_map_mutable(seen, f, yxs, args...) else - y = recursive_map_immutable(seen, f, yxs, args...) + y = _recursive_map_immutable(seen, f, yxs, args...) end return y::T end -@inline function recursive_map_mutable( +@inline function _recursive_map_mutable( seen, f::F, xs::XTup{N,T}, args::Vararg{Any,M} ) where {F,N,T,M} # out-of-place mutable handler: construct y @@ -191,7 +198,7 @@ end check_initialized(xtail, 1:nf) fieldtup = ntuple(Val(nf)) do i @inline - recursive_map_index(i, seen, f, xs, args...) + _recursive_map_index(i, seen, f, xs, args...) end y = splatnew(T, fieldtup) cache!(seen, y, xs) @@ -201,7 +208,7 @@ end @inbounds for i in _eachindex(y, xs...) if isinitialized(x1, i) check_initialized(xtail, i) - yi = recursive_map_index(i, seen, f, xs, args...) + yi = _recursive_map_index(i, seen, f, xs, args...) setvalue(y, i, yi) end end @@ -209,7 +216,7 @@ end return y::T end -@inline function recursive_map_mutable( +@inline function _recursive_map_mutable( seen, f!!::F, (; y, xs)::YXTup{N,T}, args::Vararg{Any,M} ) where {F,N,T,M} # in-place mutable handler: set/update values in y @@ -220,7 +227,7 @@ end # handle both structs, arrays, and memory through generic helpers if isinitialized(x1, i) check_initialized(xtail, i) - newyi = recursive_map_index(i, seen, f!!, (; y, xs), args...) + newyi = _recursive_map_index(i, seen, f!!, (; y, xs), args...) setvalue(y, i, newyi) else check_initialized((y,), i, false) @@ -229,7 +236,7 @@ end return y::T end -@inline function recursive_map_immutable( +@inline function _recursive_map_immutable( seen, f::F, yxs::XTupOrYXTup{N,T}, copy_if_inactive, args::Vararg{Any,M} ) where {F,N,T,M} # immutable handler: construct y/newy @@ -242,7 +249,7 @@ end check_initialized(xtail, nf) fieldtup = ntuple(Val(nf)) do i @inline - recursive_map_index(i, seen, f, yxs, copy_if_inactive, args...) + _recursive_map_index(i, seen, f, yxs, copy_if_inactive, args...) end y = splatnew(T, fieldtup) else @@ -250,7 +257,7 @@ end @inbounds for i in 1:nf if isdefined(x1, i) check_initialized(xtail, i) - flds[i] = recursive_map_index(i, seen, f, yxs, copy_if_inactive, args...) + flds[i] = _recursive_map_index(i, seen, f, yxs, copy_if_inactive, args...) else nf = i - 1 # rest of tail must be undefined values break @@ -261,17 +268,17 @@ end return y::T end -Base.@propagate_inbounds function recursive_map_index( +Base.@propagate_inbounds function _recursive_map_index( i, seen, f::F, xs::XTup, args::Vararg{Any,M} ) where {F,M} # out-of-place recursive handler: extract value i from each of the xs; call # recursive_map to obtain yi xis = getvalues(xs, i) - yi = recursive_map(seen, f, xis, args...) + yi = _recursive_map(seen, f, xis, args...) return yi::Core.Typeof(first(xis)) end -Base.@propagate_inbounds function recursive_map_index( +Base.@propagate_inbounds function _recursive_map_index( i, seen, f!!::F, (; y, xs)::YXTup, args::Vararg{Any,M} ) where {F,M} # partially-in-place recursive handler: extract value i from each of the xs and, if @@ -279,15 +286,15 @@ Base.@propagate_inbounds function recursive_map_index( xis = getvalues(xs, i) if isinitialized(y, i) yi = getvalue(y, i) - newyi = recursive_map(seen, f!!, (; y=yi, xs=xis), args...) + newyi = _recursive_map(seen, f!!, (; y=yi, xs=xis), args...) else - newyi = recursive_map(seen, f!!, xis, args...) + newyi = _recursive_map(seen, f!!, xis, args...) end return newyi::Core.Typeof(first(xis)) end ## leaf handlers -function recursive_map_leaf(seen, f::F, xs::XTup{N,T}) where {F,N,T} +function _recursive_map_leaf(seen, f::F, xs::XTup{N,T}) where {F,N,T} # out-of-place y = f(xs...) if (!isnothing(seen)) && (!isbitstype(T)) @@ -296,7 +303,7 @@ function recursive_map_leaf(seen, f::F, xs::XTup{N,T}) where {F,N,T} return y::T end -function recursive_map_leaf(seen, f!!::F, (; y, xs)::YXTup{N,T}) where {F,N,T} +function _recursive_map_leaf(seen, f!!::F, (; y, xs)::YXTup{N,T}) where {F,N,T} # partially-in-place if isbitstype(T) || isscalartype(T) newy = f!!(xs...)