diff --git a/ext/EnzymeStaticArraysExt.jl b/ext/EnzymeStaticArraysExt.jl index ef955ebd9b..14d18a2835 100644 --- a/ext/EnzymeStaticArraysExt.jl +++ b/ext/EnzymeStaticArraysExt.jl @@ -32,50 +32,11 @@ end end end -@inline function Enzyme.EnzymeCore.make_zero( - prev::FT -) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:SArray{S,T}} - return Base.zero(prev)::FT -end -@inline function Enzyme.EnzymeCore.make_zero( - prev::FT -) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T}} - return Base.zero(prev)::FT -end - -@inline function Enzyme.EnzymeCore.make_zero( - ::Type{FT}, seen::IdDict, prev::FT, ::Val{copy_if_inactive} = Val(false) -) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:SArray{S,T},copy_if_inactive} - return Base.zero(prev)::FT -end -@inline function Enzyme.EnzymeCore.make_zero( - ::Type{FT}, seen::IdDict, prev::FT, ::Val{copy_if_inactive} = Val(false) -) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T},copy_if_inactive} - if haskey(seen, prev) - return seen[prev] - end - new = Base.zero(prev)::FT - seen[prev] = new - return new -end - -@inline function Enzyme.EnzymeCore.make_zero!( - prev::FT, seen -) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T}} - if !isnothing(seen) - if prev in seen - return nothing - end - push!(seen, prev) - end - fill!(prev, zero(T)) - return nothing -end -@inline function Enzyme.EnzymeCore.make_zero!( - prev::FT -) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T}} - Enzyme.EnzymeCore.make_zero!(prev, nothing) - return nothing +# SArrays and MArrays don't need special treatment for `make_zero(!)` to work or be correct, +# but in case their dedicated `zero` and `fill!` methods are more efficient than +# `make_zero(!)`s recursion, we opt into treating them as leaves. 
+@inline function Enzyme.EnzymeCore.isvectortype(::Type{<:StaticArray{S,T}}) where {S,T} + return isbitstype(T) && Enzyme.EnzymeCore.isscalartype(T) end end diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl index f949664b6a..6829f88fea 100644 --- a/lib/EnzymeCore/src/EnzymeCore.jl +++ b/lib/EnzymeCore/src/EnzymeCore.jl @@ -506,28 +506,103 @@ function autodiff_thunk end function autodiff_deferred_thunk end """ + make_zero(prev::T, ::Val{copy_if_inactive}=Val(false))::T make_zero(::Type{T}, seen::IdDict, prev::T, ::Val{copy_if_inactive}=Val(false))::T -Recursively make a zero'd copy of the value `prev` of type `T`. The argument `copy_if_inactive` specifies -what to do if the type `T` is guaranteed to be inactive, use the primal (the default) or still copy the value. +Recursively make a copy of the value `prev::T` in which all differentiable values are +zeroed. The argument `copy_if_inactive` specifies what to do if the type `T` or any +of its constituent parts is guaranteed to be inactive (non-differentiable): reuse `prev`s +instance (the default) or make a copy. + +Extending this method for custom types is rarely needed. If you implement a new type, such +as a GPU array type, on which `make_zero` should directly invoke `zero` when the eltype is +scalar, it is sufficient to implement `Base.zero` and make sure your type subtypes +`DenseArray`. (If subtyping `DenseArray` is not appropriate, extend +[`EnzymeCore.isvectortype`](@ref) instead.) """ function make_zero end """ - make_zero!(val::T, seen::IdSet{Any}=IdSet())::Nothing + make_zero!(val::T, [seen::IdDict])::Nothing + +Recursively set a variable's differentiable values to zero. Only applicable for types `T` +that are mutable or hold all differentiable values in mutable storage (e.g., +`Tuple{Vector{Float64}}` qualifies but `Tuple{Float64}` does not). The recursion skips over +parts of `val` that are guaranteed to be inactive. 
-Recursively set a variables differentiable fields to zero. Only applicable for mutable types `T`. +Extending this method for custom types is rarely needed. If you implement a new mutable +type, such as a GPU array type, on which `make_zero!` should directly invoke +`fill!(x, false)` when the eltype is scalar, it is sufficient to implement `Base.zero`, +`Base.fill!`, and make sure your type subtypes `DenseArray`. (If subtyping `DenseArray` is +not appropriate, extend [`EnzymeCore.isvectortype`](@ref) instead.) """ function make_zero! end """ - make_zero(prev::T) + isvectortype(::Type{T})::Bool -Helper function to recursively make zero. -""" -@inline function make_zero(prev::T, ::Val{copy_if_inactive}=Val(false)) where {T, copy_if_inactive} - make_zero(Core.Typeof(prev), IdDict(), prev, Val(copy_if_inactive)) +Trait defining types whose values should be considered leaf nodes when [`make_zero`](@ref) +and [`make_zero!`](@ref) recurse through an object. + +By default, `isvectortype(T) == true` when `isscalartype(T) == true` or when +`T <: DenseArray{U}` where `U` is a bitstype and `isscalartype(U) == true`. + +A new vector type, such as a GPU array type, should normally subtype `DenseArray` and +inherit `isvectortype` that way. However, if this is not appropriate, `isvectortype` may be +extended directly as follows: + +```julia +@inline function EnzymeCore.isvectortype(::Type{T}) where {T<:NewArray} + U = eltype(T) + return isbitstype(U) && EnzymeCore.isscalartype(U) end +``` + +Such a type should implement `Base.zero` and, if mutable, `Base.fill!`. + +Extending `isvectortype` is mostly relevant for the lowest-level of abstraction of memory at +which vector space operations like addition and scalar multiplication are supported, the +prototypical case being `Array`. 
Regular Julia structs with vector space-like semantics +should normally not extend `isvectortype`; `make_zero(!)` will recurse into them and act +directly on their backing arrays, just like how Enzyme treats them when differentiating. For +example, structured matrix wrappers and sparse array types that are backed by `Array` should +not extend `isvectortype`. + +See also [`isscalartype`](@ref). +""" +function isvectortype end + +""" + isscalartype(::Type{T})::Bool + +Trait defining a subset of [`isvectortype`](@ref) types that should not be considered +composite, such that even if the type is mutable, [`make_zero!`](@ref) will not try to zero +values of the type in-place. For example, `BigFloat` is a mutable type but does not support +in-place mutation through any Julia API; `isscalartype(BigFloat) == true` ensures that +`make_zero!` will not try to mutate `BigFloat` values.[^BigFloat] + +By default, `isscalartype(T) == true` and `isscalartype(Complex{T}) == true` for concrete +types `T <: AbstractFloat`. + +A hypothetical new real number type with Enzyme support should usually subtype +`AbstractFloat` and inherit the `isscalartype` trait that way. If this is not appropriate, +the function can be extended as follows: + +```julia +@inline EnzymeCore.isscalartype(::Type{NewReal}) = true +@inline EnzymeCore.isscalartype(::Type{Complex{NewReal}}) = true +``` + +In either case, the type should implement `Base.zero`. + +See also [`isvectortype`](@ref). + +[^BigFloat]: Enzyme does not support differentiating `BigFloat` as of this writing; it is +mentioned here only to demonstrate that it would be inappropriate to use traits like +`ismutable` or `isbitstype` to choose between in-place and out-of-place zeroing, +showing the need for a dedicated `isscalartype` trait. 
+""" +function isscalartype end function tape_type end diff --git a/src/Enzyme.jl b/src/Enzyme.jl index e24ff41cdb..17a205a8b1 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -463,12 +463,8 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) # compute the correct complex derivative in reverse mode by propagating the conjugate return values # then subtracting twice the imaginary component to get the correct result - for (k, v) in seen - Compiler.recursive_accumulate(k, v, refn_seed) - end - for (k, v) in seen2 - Compiler.recursive_accumulate(k, v, imfn_seed) - end + Compiler.accumulate_seen!(refn_seed, seen) + Compiler.accumulate_seen!(imfn_seed, seen2) fused = fuse_complex_results(results, args...) diff --git a/src/analyses/activity.jl b/src/analyses/activity.jl index 61d2f35ab7..8fa5331d77 100644 --- a/src/analyses/activity.jl +++ b/src/analyses/activity.jl @@ -414,7 +414,7 @@ Base.@assume_effects :removable :foldable :nothrow @inline function guaranteed_c return res end -Base.@assume_effects :removable :foldable :nothrow @inline function guaranteed_const_nongen(::Type{T}, world)::Bool where {T} +Base.@assume_effects :removable :foldable :nothrow @inline function guaranteed_const_nongen(::Type{T}, world=nothing)::Bool where {T} rt = active_reg_inner(T, (), world) res = rt == AnyState return res @@ -427,6 +427,11 @@ Base.@assume_effects :removable :foldable :nothrow @inline function guaranteed_n return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState end +Base.@assume_effects :removable :foldable :nothrow @inline function guaranteed_nonactive_nongen(::Type{T}, world=nothing)::Bool where {T} + rt = Enzyme.Compiler.active_reg_inner(T, (), world) + return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState +end + """ Enzyme.guess_activity(::Type{T}, mode::Enzyme.Mode) diff --git a/src/compiler.jl b/src/compiler.jl index b61ec5854f..6151ee140d 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -315,7 +315,7 @@ const 
JuliaGlobalNameMap = Dict{String,Any}( include("absint.jl") include("llvm/transforms.jl") include("llvm/passes.jl") -include("typeutils/make_zero.jl") +include("typeutils/recursive_maps.jl") function nested_codegen!(mode::API.CDerivativeMode, mod::LLVM.Module, @nospecialize(f), @nospecialize(tt::Type), world::UInt) funcspec = my_methodinstance(mode == API.DEM_ForwardMode ? Forward : Reverse, typeof(f), tt, world) diff --git a/src/internal_rules.jl b/src/internal_rules.jl index 04aca1a66a..1fdd874378 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -253,47 +253,6 @@ function EnzymeRules.augmented_primal( return EnzymeRules.AugmentedReturn(primal, shadow, shadow) end - -@inline function accumulate_into( - into::RT, - seen::IdDict, - from::RT, -)::Tuple{RT,RT} where {RT<:Array} - if Enzyme.Compiler.guaranteed_const(RT) - return (into, from) - end - if !haskey(seen, into) - seen[into] = (into, from) - for i in eachindex(from) - tup = accumulate_into(into[i], seen, from[i]) - @inbounds into[i] = tup[1] - @inbounds from[i] = tup[2] - end - end - return seen[into] -end - -@inline function accumulate_into( - into::RT, - seen::IdDict, - from::RT, -)::Tuple{RT,RT} where {RT<:AbstractFloat} - if !haskey(seen, into) - seen[into] = (into + from, RT(0)) - end - return seen[into] -end - -@inline function accumulate_into(into::RT, seen::IdDict, from::RT)::Tuple{RT,RT} where {RT} - if Enzyme.Compiler.guaranteed_const(RT) - return (into, from) - end - if !haskey(seen, into) - throw(AssertionError("Unknown type to accumulate into: $RT")) - end - return seen[into] -end - function EnzymeRules.reverse( config::EnzymeRules.RevConfig, func::Const{typeof(Base.deepcopy)}, @@ -302,15 +261,8 @@ function EnzymeRules.reverse( x::Annotation{Ty}, ) where {RT,Ty} if EnzymeRules.needs_shadow(config) - if EnzymeRules.width(config) == 1 - accumulate_into(x.dval, IdDict(), shadow) - else - for i = 1:EnzymeRules.width(config) - accumulate_into(x.dval[i], IdDict(), shadow[i]) - end - 
end + Compiler.accumulate_into!(x.dval, shadow) end - return (nothing,) end diff --git a/src/typeutils/make_zero.jl b/src/typeutils/make_zero.jl deleted file mode 100644 index 5c7b49a749..0000000000 --- a/src/typeutils/make_zero.jl +++ /dev/null @@ -1,587 +0,0 @@ -@inline function EnzymeCore.make_zero(x::FT)::FT where {FT<:AbstractFloat} - return Base.zero(x) -end -@inline function EnzymeCore.make_zero(x::Complex{FT})::Complex{FT} where {FT<:AbstractFloat} - return Base.zero(x) -end -@inline function EnzymeCore.make_zero( - x::Array{FT,N}, -)::Array{FT,N} where {FT<:AbstractFloat,N} - return Base.zero(x) -end -@inline function EnzymeCore.make_zero( - x::Array{Complex{FT},N}, -)::Array{Complex{FT},N} where {FT<:AbstractFloat,N} - return Base.zero(x) -end - - -@static if VERSION < v"1.11-" -else -@inline function EnzymeCore.make_zero( - x::GenericMemory{kind, FT}, -)::GenericMemory{kind, FT} where {FT<:AbstractFloat,kind} - return Base.zero(x) -end -@inline function EnzymeCore.make_zero( - x::GenericMemory{kind, Complex{FT}}, -)::GenericMemory{kind, Complex{FT}} where {FT<:AbstractFloat,kind} - return Base.zero(x) -end -end - - -@inline function EnzymeCore.make_zero( - ::Type{Array{FT,N}}, - seen::IdDict, - prev::Array{FT,N}, - ::Val{copy_if_inactive} = Val(false), -)::Array{FT,N} where {copy_if_inactive,FT<:AbstractFloat,N} - if haskey(seen, prev) - return seen[prev] - end - newa = Base.zero(prev) - seen[prev] = newa - return newa -end -@inline function EnzymeCore.make_zero( - ::Type{Array{Complex{FT},N}}, - seen::IdDict, - prev::Array{Complex{FT},N}, - ::Val{copy_if_inactive} = Val(false), -)::Array{Complex{FT},N} where {copy_if_inactive,FT<:AbstractFloat,N} - if haskey(seen, prev) - return seen[prev] - end - newa = Base.zero(prev) - seen[prev] = newa - return newa -end - -@static if VERSION < v"1.11-" -else -@inline function EnzymeCore.make_zero( - ::Type{GenericMemory{kind, FT}}, - seen::IdDict, - prev::GenericMemory{kind, FT}, - ::Val{copy_if_inactive} = 
Val(false), -)::GenericMemory{kind, FT} where {copy_if_inactive,FT<:AbstractFloat,kind} - if haskey(seen, prev) - return seen[prev] - end - newa = Base.zero(prev) - seen[prev] = newa - return newa -end -@inline function EnzymeCore.make_zero( - ::Type{GenericMemory{kind, Complex{FT}}}, - seen::IdDict, - prev::GenericMemory{kind, Complex{FT}}, - ::Val{copy_if_inactive} = Val(false), -)::GenericMemory{kind, Complex{FT}} where {copy_if_inactive,FT<:AbstractFloat,kind} - if haskey(seen, prev) - return seen[prev] - end - newa = Base.zero(prev) - seen[prev] = newa - return newa -end -end - -@inline function EnzymeCore.make_zero( - ::Type{RT}, - seen::IdDict, - prev::RT, - ::Val{copy_if_inactive} = Val(false), -)::RT where {copy_if_inactive,RT<:AbstractFloat} - return RT(0) -end - -@inline function EnzymeCore.make_zero( - ::Type{Complex{RT}}, - seen::IdDict, - prev::Complex{RT}, - ::Val{copy_if_inactive} = Val(false), -)::Complex{RT} where {copy_if_inactive,RT<:AbstractFloat} - return Complex{RT}(0) -end - -@inline function EnzymeCore.make_zero( - ::Type{RT}, - seen::IdDict, - prev::RT, - ::Val{copy_if_inactive} = Val(false), -)::RT where {copy_if_inactive,RT<:Array} - if haskey(seen, prev) - return seen[prev] - end - if guaranteed_const_nongen(RT, nothing) - return copy_if_inactive ? Base.deepcopy_internal(prev, seen) : prev - end - newa = RT(undef, size(prev)) - seen[prev] = newa - for I in eachindex(prev) - if isassigned(prev, I) - pv = prev[I] - innerty = Core.Typeof(pv) - @inbounds newa[I] = - EnzymeCore.make_zero(innerty, seen, pv, Val(copy_if_inactive)) - end - end - return newa -end - -@static if VERSION < v"1.11-" -else -@inline function EnzymeCore.make_zero( - ::Type{RT}, - seen::IdDict, - prev::RT, - ::Val{copy_if_inactive} = Val(false), -)::RT where {copy_if_inactive,RT<:GenericMemory} - if haskey(seen, prev) - return seen[prev] - end - if guaranteed_const_nongen(RT, nothing) - return copy_if_inactive ? 
Base.deepcopy_internal(prev, seen) : prev - end - newa = RT(undef, size(prev)) - seen[prev] = newa - for I in eachindex(prev) - if isassigned(prev, I) - pv = prev[I] - innerty = Core.Typeof(pv) - @inbounds newa[I] = - EnzymeCore.make_zero(innerty, seen, pv, Val(copy_if_inactive)) - end - end - return newa -end -end - -@inline function EnzymeCore.make_zero( - ::Type{RT}, - seen::IdDict, - prev::RT, - ::Val{copy_if_inactive} = Val(false), -)::RT where {copy_if_inactive,RT<:Tuple} - return ntuple(length(prev)) do i - Base.@_inline_meta - EnzymeCore.make_zero(RT.parameters[i], seen, prev[i], Val(copy_if_inactive)) - end -end - -@inline function EnzymeCore.make_zero( - ::Type{NamedTuple{A,RT}}, - seen::IdDict, - prev::NamedTuple{A,RT}, - ::Val{copy_if_inactive} = Val(false), -)::NamedTuple{A,RT} where {copy_if_inactive,A,RT} - prevtup = RT(prev) - TT = Core.Typeof(prevtup) # RT can be abstract - return NamedTuple{A,RT}(EnzymeCore.make_zero(TT, seen, prevtup, Val(copy_if_inactive))) -end - -@inline function EnzymeCore.make_zero( - ::Type{Core.Box}, - seen::IdDict, - prev::Core.Box, - ::Val{copy_if_inactive} = Val(false), -) where {copy_if_inactive} - if haskey(seen, prev) - return seen[prev] - end - prev2 = prev.contents - res = Core.Box() - seen[prev] = res - res.contents = EnzymeCore.make_zero(Core.Typeof(prev2), seen, prev2, Val(copy_if_inactive)) - return res -end - -@inline function EnzymeCore.make_zero( - ::Type{RT}, - seen::IdDict, - prev::RT, - ::Val{copy_if_inactive} = Val(false), -)::RT where {copy_if_inactive,RT} - if guaranteed_const_nongen(RT, nothing) - return copy_if_inactive ? 
Base.deepcopy_internal(prev, seen) : prev - end - if haskey(seen, prev) - return seen[prev] - end - @assert !Base.isabstracttype(RT) - @assert Base.isconcretetype(RT) - nf = fieldcount(RT) - if ismutable(prev) - y = ccall(:jl_new_struct_uninit, Any, (Any,), RT)::RT - seen[prev] = y - for i = 1:nf - if isdefined(prev, i) - xi = getfield(prev, i) - T = Core.Typeof(xi) - xi = EnzymeCore.make_zero(T, seen, xi, Val(copy_if_inactive)) - if Base.isconst(RT, i) - ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), y, i-1, xi) - else - setfield!(y, i, xi) - end - end - end - return y - end - if nf == 0 - return prev - end - flds = Vector{Any}(undef, nf) - for i = 1:nf - if isdefined(prev, i) - xi = getfield(prev, i) - xi = EnzymeCore.make_zero(Core.Typeof(xi), seen, xi, Val(copy_if_inactive)) - flds[i] = xi - else - nf = i - 1 # rest of tail must be undefined values - break - end - end - y = ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), RT, flds, nf) - seen[prev] = y - return y -end - -function make_zero_immutable!(prev::T, seen::S)::T where {T<:AbstractFloat,S} - return zero(T) -end - -function make_zero_immutable!( - prev::Complex{T}, - seen::S, -)::Complex{T} where {T<:AbstractFloat,S} - return zero(Complex{T}) -end - -function make_zero_immutable!(prev::T, seen::S)::T where {T<:Tuple,S} - if guaranteed_const_nongen(T, nothing) - return prev # unreachable from make_zero! - end - ntuple(Val(length(T.parameters))) do i - Base.@_inline_meta - p = prev[i] - SBT = Core.Typeof(p) - if guaranteed_const_nongen(SBT, nothing) - p # covered by several tests even if not shown in coverage - elseif !ismutabletype(SBT) - make_zero_immutable!(p, seen) - else - EnzymeCore.make_zero!(p, seen) - p - end - end -end - -function make_zero_immutable!(prev::NamedTuple{a,b}, seen::S)::NamedTuple{a,b} where {a,b,S} - if guaranteed_const_nongen(NamedTuple{a,b}, nothing) - return prev # unreachable from make_zero! 
- end - NamedTuple{a,b}(ntuple(Val(length(b.parameters))) do i - Base.@_inline_meta - p = prev[a[i]] - SBT = Core.Typeof(p) - if guaranteed_const_nongen(SBT, nothing) - p # covered by several tests even if not shown in coverage - elseif !ismutabletype(SBT) - make_zero_immutable!(p, seen) - else - EnzymeCore.make_zero!(p, seen) - p - end - end) -end - - -function make_zero_immutable!(prev::T, seen::S)::T where {T,S} - if guaranteed_const_nongen(T, nothing) - return prev # unreachable from make_zero! - end - @assert !ismutabletype(T) - @assert !Base.isabstracttype(T) - @assert Base.isconcretetype(T) - nf = fieldcount(T) - flds = Vector{Any}(undef, nf) - for i = 1:nf - if isdefined(prev, i) - xi = getfield(prev, i) - ST = Core.Typeof(xi) - flds[i] = if guaranteed_const_nongen(ST, nothing) - xi - elseif !ismutabletype(ST) - make_zero_immutable!(xi, seen) - else - EnzymeCore.make_zero!(xi, seen) - xi - end - else - nf = i - 1 # rest of tail must be undefined values - break - end - end - return ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), T, flds, nf)::T -end - -@inline function EnzymeCore.make_zero!( - prev::Base.RefValue{T}, - seen::ST, -)::Nothing where {T<:AbstractFloat,ST} - if !isnothing(seen) - if prev in seen - return nothing - end - push!(seen, prev) - end - prev[] = zero(T) - return nothing -end - -@inline function EnzymeCore.make_zero!( - prev::Base.RefValue{Complex{T}}, - seen::ST, -)::Nothing where {T<:AbstractFloat,ST} - if !isnothing(seen) - if prev in seen - return nothing - end - push!(seen, prev) - end - prev[] = zero(Complex{T}) - return nothing -end - -@inline function EnzymeCore.make_zero!( - prev::Array{T,N}, - seen::ST, -)::Nothing where {T<:AbstractFloat,N,ST} - if !isnothing(seen) - if prev in seen - return nothing - end - push!(seen, prev) - end - fill!(prev, zero(T)) - return nothing -end - -@inline function EnzymeCore.make_zero!( - prev::Array{Complex{T},N}, - seen::ST, -)::Nothing where {T<:AbstractFloat,N,ST} - if !isnothing(seen) - 
if prev in seen - return nothing - end - push!(seen, prev) - end - fill!(prev, zero(Complex{T})) - return nothing -end - -@static if VERSION < v"1.11-" -else -@inline function EnzymeCore.make_zero!( - prev::GenericMemory{kind, T}, - seen::ST, -)::Nothing where {T<:AbstractFloat,kind,ST} - if !isnothing(seen) - if prev in seen - return nothing - end - push!(seen, prev) - end - fill!(prev, zero(T)) - return nothing -end - -@inline function EnzymeCore.make_zero!( - prev::GenericMemory{kind, Complex{T}}, - seen::ST, -)::Nothing where {T<:AbstractFloat,kind,ST} - if !isnothing(seen) - if prev in seen - return nothing - end - push!(seen, prev) - end - fill!(prev, zero(Complex{T})) - return nothing -end -end - -@inline function EnzymeCore.make_zero!( - prev::Base.RefValue{T}, -)::Nothing where {T<:AbstractFloat} - EnzymeCore.make_zero!(prev, nothing) - return nothing -end - -@inline function EnzymeCore.make_zero!( - prev::Base.RefValue{Complex{T}}, -)::Nothing where {T<:AbstractFloat} - EnzymeCore.make_zero!(prev, nothing) - return nothing -end - -@inline function EnzymeCore.make_zero!(prev::Array{T,N})::Nothing where {T<:AbstractFloat,N} - EnzymeCore.make_zero!(prev, nothing) - return nothing -end - -@inline function EnzymeCore.make_zero!( - prev::Array{Complex{T},N}, -)::Nothing where {T<:AbstractFloat,N} - EnzymeCore.make_zero!(prev, nothing) - return nothing -end - -@inline function EnzymeCore.make_zero!(prev::Array{T,N}, seen::ST)::Nothing where {T,N,ST} - if guaranteed_const_nongen(T, nothing) - return nothing - end - if prev in seen - return nothing - end - push!(seen, prev) - for I in eachindex(prev) - if isassigned(prev, I) - pv = prev[I] - SBT = Core.Typeof(pv) - if guaranteed_const_nongen(SBT, nothing) - continue - elseif !ismutabletype(SBT) - @inbounds prev[I] = make_zero_immutable!(pv, seen) - else - EnzymeCore.make_zero!(pv, seen) - end - end - end - return nothing -end - -@static if VERSION < v"1.11-" -else -@inline function 
EnzymeCore.make_zero!(prev::GenericMemory{kind, T})::Nothing where {T<:AbstractFloat,kind} - EnzymeCore.make_zero!(prev, nothing) - return nothing -end - -@inline function EnzymeCore.make_zero!( - prev::GenericMemory{kind, Complex{T}}, -)::Nothing where {T<:AbstractFloat, kind} - EnzymeCore.make_zero!(prev, nothing) - return nothing -end - -@inline function EnzymeCore.make_zero!(prev::GenericMemory{kind, T}, seen::ST)::Nothing where {T,kind,ST} - if guaranteed_const_nongen(T, nothing) - return nothing - end - if prev in seen - return nothing - end - push!(seen, prev) - for I in eachindex(prev) - if isassigned(prev, I) - pv = prev[I] - SBT = Core.Typeof(pv) - if guaranteed_const_nongen(SBT, nothing) - continue - elseif !ismutabletype(SBT) - @inbounds prev[I] = make_zero_immutable!(pv, seen) - else - EnzymeCore.make_zero!(pv, seen) - end - end - end - return nothing -end -end - - -@inline function EnzymeCore.make_zero!( - prev::Base.RefValue{T}, - seen::ST, -)::Nothing where {T,ST} - if guaranteed_const_nongen(T, nothing) - return nothing - end - if prev in seen - return nothing - end - push!(seen, prev) - pv = prev[] - SBT = Core.Typeof(pv) - if guaranteed_const_nongen(SBT, nothing) - return nothing - elseif !ismutabletype(SBT) - prev[] = make_zero_immutable!(pv, seen) - else - EnzymeCore.make_zero!(pv, seen) - end - return nothing -end - -@inline function EnzymeCore.make_zero!(prev::Core.Box, seen::ST)::Nothing where {ST} - if prev in seen - return nothing - end - push!(seen, prev) - pv = prev.contents - SBT = Core.Typeof(pv) - if guaranteed_const_nongen(SBT, nothing) - return nothing - elseif !ismutabletype(SBT) - prev.contents = make_zero_immutable!(pv, seen) - else - EnzymeCore.make_zero!(pv, seen) - end - return nothing -end - -@inline function EnzymeCore.make_zero!(prev::T, seen::S)::Nothing where {T,S} - if guaranteed_const_nongen(T, nothing) - return nothing - end - if prev in seen - return nothing - end - @assert !Base.isabstracttype(T) - @assert 
Base.isconcretetype(T) - nf = fieldcount(T) - if nf == 0 - return nothing - end - push!(seen, prev) - for i = 1:nf - if isdefined(prev, i) - xi = getfield(prev, i) - SBT = Core.Typeof(xi) - activitystate = active_reg_inner(SBT, (), nothing) - if activitystate == AnyState # guaranteed_const - continue - elseif ismutabletype(T) && !ismutabletype(SBT) - yi = make_zero_immutable!(xi, seen) - if Base.isconst(T, i) - ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), prev, i-1, yi) - else - setfield!(prev, i, yi) - end - elseif activitystate == DupState - EnzymeCore.make_zero!(xi, seen) - else - msg = "cannot set $xi to zero in-place, as it contains differentiable values in immutable positions" - throw(ArgumentError(msg)) - end - end - end - return nothing -end - -@inline EnzymeCore.make_zero!(prev) = EnzymeCore.make_zero!(prev, Base.IdSet()) diff --git a/src/typeutils/recursive_add.jl b/src/typeutils/recursive_add.jl index 039f7d3d0c..6de0ca910b 100644 --- a/src/typeutils/recursive_add.jl +++ b/src/typeutils/recursive_add.jl @@ -1,86 +1,116 @@ -# Recursively return x + f(y), where y is active, otherwise x +using .RecursiveMaps: RecursiveMaps, recursive_map, recursive_map! -@inline function recursive_add( - x::T, - y::T, - f::F = identity, - forcelhs::F2 = guaranteed_const, -) where {T,F,F2} - if forcelhs(T) - return x - end - splatnew(T, ntuple(Val(fieldcount(T))) do i - Base.@_inline_meta - prev = getfield(x, i) - next = getfield(y, i) - recursive_add(prev, next, f, forcelhs) - end) -end +""" + recursive_add(x::T, y::T, f=identity, forcelhs=guaranteed_const_nongen) + +!!! warning + Internal function, documented for developer convenience but not covered by semver API + stability guarantees + +Recursively construct `z::T` such that `zi = xi + f(yi)` where `zi`, `xi`, and `yi` are +corresponding values from `z`, `x`, and `y`. In other words, this is a recursive +generalization of `x .+ f.(y)`. 
-@inline function recursive_add( - x::T, - y::T, - f::F = identity, - forcelhs::F2 = guaranteed_const, -) where {T<:AbstractFloat,F,F2} - if forcelhs(T) - return x +The function `f` must return values of the same type as its argument. + +The optional argument `forcelhs` takes a callable such that if `forcelhs(S) == true`, values +`zi::S` will be set to `zi = xi`. The default returns true for non-differentiable types, +such that `zi = xi + f(yi)` applies to differentiable values, while `zi = xi` applies to +non-differentiable values. +""" +function recursive_add( + x::T, y::T, f::F=identity, forcelhs::L=guaranteed_const_nongen +) where {T,F,L} + function addf(xi::S, yi::S) where {S} + @assert EnzymeCore.isvectortype(S) + return ((xi + f(yi))::S,) end - return x + f(y) + return only(recursive_map(addf, Val(1), (x, y), Val(false), forcelhs))::T end -@inline function recursive_add( - x::T, - y::T, - f::F = identity, - forcelhs::F2 = guaranteed_const, -) where {T<:Complex,F,F2} - if forcelhs(T) - return x +""" + accumulate_seen!(f, seen::IdDict) + +!!! warning + Internal function, documented for developer convenience but not covered by semver API + stability guarantees + +Recursively accumulate from values into keys, generalizing key .+= f.(value), for each +key-value pair in `seen::IdDict` where each key must be a mutable object or non-isbits +vector type instance mapping to another object of the same type and structure. Typically +`seen` is populated by `make_zero` (or some other single-argument invocation of +`recursive_map`), mapping components of its argument to the corresponding component of the +returned value. + +The recursion stops at instances of types that are themselves cached by `make_zero` +(`recursive_map`), as these objects should have their own entries in `seen`. The recursion +also stops at inactive objects that would not be zeroed by `make_zero`. 
+""" +function accumulate_seen!(f::F, seen::IdDict) where {F} + for (k, v) in seen + _accumulate_seen_item!(f, k, v) end - return x + f(y) + return nothing end -@inline mutable_register(::Type{T}) where {T<:Integer} = true -@inline mutable_register(::Type{T}) where {T<:AbstractFloat} = false -@inline mutable_register(::Type{Complex{T}}) where {T<:AbstractFloat} = false -@inline mutable_register(::Type{T}) where {T<:Tuple} = false -@inline mutable_register(::Type{T}) where {T<:NamedTuple} = false -@inline mutable_register(::Type{Core.Box}) = true -@inline mutable_register(::Type{T}) where {T<:Array} = true -@inline mutable_register(::Type{T}) where {T} = ismutabletype(T) - -# Recursively In-place accumulate(aka +=). E.g. generalization of x .+= f(y) -@inline function recursive_accumulate(x::Array{T}, y::Array{T}, f::F = identity) where {T,F} - if !mutable_register(T) - for I in eachindex(x) - prev = x[I] - @inbounds x[I] = recursive_add(x[I], (@inbounds y[I]), f, mutable_register) - end +function _accumulate_seen_item!(f::F, k::T, v::T) where {F,T} + function addf!!(ki::S, vi::S) where {S} + @assert EnzymeCore.isvectortype(S) + return ((ki .+ f.(vi))::S,) + end + function addf!!(ki::S, _ki::S, vi::S) where {S} + @assert !EnzymeCore.isscalartype(S) + @assert EnzymeCore.isvectortype(S) + @assert ki === _ki + ki .+= f.(vi) + return (ki::S,) + end + @inline function isinactive_or_cachedtype(::Type{T}) where {T} + return guaranteed_const_nongen(T) || RecursiveMaps.iscachedtype(T) + end + RecursiveMaps.check_nonactive(T) + if !guaranteed_const_nongen(T) + newks = RecursiveMaps.recursive_map_inner( + nothing, addf!!, (k,), (k, v), Val(false), isinactive_or_cachedtype + ) + @assert only(newks) === k end + return nothing end +""" + accumulate_into!(into::T, from::T) -# Recursively In-place accumulate(aka +=). E.g. 
generalization of x .+= f(y) -@inline function recursive_accumulate(x::Core.Box, y::Core.Box, f::F = identity) where {F} - recursive_accumulate(x.contents, y.contents, seen, f) -end +!!! warning + Internal function, documented for developer convenience but not covered by semver API + stability guarantees + +Recursively accumulate from `from` into `into` and zero `from`, such that `into_i += from_i` +and `from_i = 0`, where `into_i` and `from_i` are corresponding values within `into` and +`from`. In other words, this is a recursive generalization of -@inline function recursive_accumulate(x::T, y::T, f::F = identity) where {T,F} - @assert !Base.isabstracttype(T) - @assert Base.isconcretetype(T) - nf = fieldcount(T) +```julia +into .+= from +from .= 0 +``` - for i = 1:nf - if isdefined(x, i) - xi = getfield(x, i) - ST = Core.Typeof(xi) - if !mutable_register(ST) - @assert ismutable(x) - yi = getfield(y, i) - nexti = recursive_add(xi, yi, f, mutable_register) - setfield!(x, i, nexti) - end - end +The accumulation and zeroing is only applied to differentiable values; non-differentiable +values within both `into` and `from` are left untouched. 
+""" +function accumulate_into!(into::T, from::T) where {T} + # may not show in coverage but both base cases are covered via deepcopy custom rule tests + function accumulate_into!!(into_i::S, from_i::S) where {S} + @assert EnzymeCore.isvectortype(S) + return ((into_i + from_i)::S, convert(S, zero(from_i))::S) + end + function accumulate_into!!(into_i::S, from_i::S, _into_i::S, _from_i::S) where {S} + @assert !EnzymeCore.isscalartype(S) + @assert EnzymeCore.isvectortype(S) + @assert (into_i === _into_i) && (from_i === _from_i) + into_i .+= from_i + fill!(from_i, false) + return (into_i::S, from_i::S) end + recursive_map!(accumulate_into!!, (into, from), (into, from)) + return nothing end diff --git a/src/typeutils/recursive_maps.jl b/src/typeutils/recursive_maps.jl new file mode 100644 index 0000000000..49bb669f8d --- /dev/null +++ b/src/typeutils/recursive_maps.jl @@ -0,0 +1,672 @@ +module RecursiveMaps + +using EnzymeCore: EnzymeCore, isvectortype, isscalartype +using ..Compiler: guaranteed_const_nongen, guaranteed_nonactive_nongen + +### traits defining active leaf types for recursive_map +@inline EnzymeCore.isvectortype(::Type{T}) where {T} = isscalartype(T) +@inline function EnzymeCore.isvectortype(::Type{<:DenseArray{U}}) where {U} + return isbitstype(U) && isscalartype(U) +end + +@inline EnzymeCore.isscalartype(::Type) = false +@inline EnzymeCore.isscalartype(::Type{T}) where {T<:AbstractFloat} = isconcretetype(T) +@inline function EnzymeCore.isscalartype(::Type{Complex{T}}) where {T<:AbstractFloat} + return isconcretetype(T) +end + +### recursive_map: walk arbitrary objects and map a function over scalar and vector leaves +""" + ys = recursive_map( + [seen::Union{Nothing,IdDict},] + f, + ::Val{Nout} + xs::NTuple{Nin,T}, + ::Val{copy_if_inactive}=Val(false), + isinactivetype=guaranteed_const_nongen, + )::T + newys = recursive_map( + [seen::Union{Nothing,IdDict},] + f, + ys::NTuple{Nout,T}, + xs::NTuple{Nin,T}, + ::Val{copy_if_inactive}=Val(false), + 
isinactivetype=guaranteed_const_nongen, + )::NTuple{Nout,T} + +!!! warning + Internal function, documented for developer convenience but not covered by semver API + stability guarantees + +Recurse through `Nin` objects `xs = (x1::T, x2::T, ..., xNin::T)` of the same type, mapping the +function `f` over every differentiable value encountered and building `Nout` new objects +`(y1::T, ...)` from the resulting values `(y1_i, ...) = f(x1_i, ..., xNin_i)`. Only +`Nout == 1` and `Nout == 2` are supported. + +The trait [`EnzymeCore.isvectortype`](@ref) determines which values are considered +differentiable leaf nodes at which recursion terminates and `f` is invoked. See the +docstring for [`EnzymeCore.isvectortype`](@ref) and the related +[`EnzymeCore.isscalartype`](@ref) for more information. + +A tuple of existing objects `ys = (y1::T, ...)` can be passed, in which case the `ys` are +updated "partially-in-place": any parts of the `ys` that are mutable or non-differentiable +are reused in the returned object tuple `newys`, while immutable differentiable parts are +handled out-of-place as if the `ys` were not passed (this can be seen as a recursive +generalization of the BangBang.jl idiom). If `T` itself is a mutable type, the `ys` are +modified in-place and returned, such that `newys === ys`. + +The recursion and mapping operate on the structure of `T` as defined by struct fields and +plain array elements, not on the values provided through an iteration or array interface. +For example, given a structured matrix wrapper or sparse array type, this function recurses +into the struct type and the plain arrays held within, rather than operating on the array +that the type notionally represents. 
+ +# Arguments + +* `seen::Union{IdDict,Nothing}` (optional): Dictionary for tracking object identity as + needed to construct the `ys` such that their internal graph of object references is identical to + that of the `xs`, including cycles (i.e., recursive substructures) and multiple paths to + the same objects. If not provided, an `IdDict` will be allocated internally if required. + + If `nothing` is provided, object identity is not tracked. In this case, objects with + multiple references are duplicated such that the object reference graph of the `ys` becomes a + tree, cycles lead to infinite recursion and stack overflow, and `copy_if_inactive == true` + will likely cause errors. This is useful only in specific cases. + +* `f`: Function mapping leaf nodes within the `xs` to the corresponding leaf nodes in the + `ys`, that is, `(y1_i, ...) = f(x1_i::U, ..., xNin_i::U)::NTuple{Nout,U}`. The function + `f` must be applicable to the type of every leaf node, and must return a tuple of values + of the same type as its arguments. + + When an existing object tuple `ys` is passed and contains leaf nodes of a non-isbits + non-scalar type `U`, `f` should also have a partially-in-place method + `(newy1_i, ...) === f(y1_i::U, ..., yNout_i::U, x1_i::U, ..., xNin_i::U)::NTuple{Nout,U}` + that modifies and reuses any mutable parts of the `yj_i`; in particular, if `U` is a + mutable type, this method should return `newyj_i === yj_i`. If a non-isbits type `U` + should always be handled using the out-of-place signature, extend + [`EnzymeCore.isscalartype`](@ref) such that `isscalartype(U) == true`. + + See [`EnzymeCore.isvectortype`](@ref) and [`EnzymeCore.isscalartype`](@ref) for more + details about leaf types and scalar types. + +* `::Val{Nout}` or `ys::NTuple{Nout,T}`: For out-of-place operation, pass `Val(Nout)` where + `Nout in (1, 2)` is the length of the tuple returned by `f`, that is, the length of the + expected return value `ys` (this is required; `Nout` is never inferred). 
For + partially-in-place operation, pass the existing tuple `ys::NTuple{Nout,T}` containing the + values to be modified. + +* `xs::NTuple{N,T}`: Tuple of `N` objects of the same type `T` over which `f` is mapped. + + The first object `x1 = first(xs)` is the reference for graph structure and + non-differentiable values when constructing the returned object. In particular: + * When `ys` is not passed, the returned objects take any non-differentiable parts from + `x1`. (When `ys` is passed, its non-differentiable parts are kept unchanged in the + returned object, unless they are not initialized, in which case they are taken from + `x1`.) + * The graph of object references in `x1` is the one which is reproduced in the returned + object. For each instance of multiple paths and cycles within `x1`, the same structure + must be present in the other objects `x2, ..., xN`, otherwise the corresponding values + in the `ys` would not be uniquely defined. However, `x2, ..., xN` may contain multiple + paths or cycles that are not present in `x1`; these do not affect the structure of `ys`. + * If any values within `x1` are not initialized (that is, struct fields are undefined or + array elements are unassigned), they are left uninitialized in the returned object. If + any such values are mutable and `ys` is passed, the corresponding value in `y` must not + already be initialized, since initialized values cannot be nulled. Conversely, for every + value in `x1` that is initialized, the corresponding values in `x2, ..., xN` must also + be initialized, such that the corresponding values of the `ys` can be computed (however, + values in `x2, ..., xN` can be initialized while the corresponding value in `x1` is not; + such values are ignored.) 
+ +* `::Val{copy_if_inactive::Bool}` (optional): When a non-differentiable part of `x1` is + included in the returned object, either because an object tuple `ys` is not passed or this + part of the `ys` is not initialized, `copy_if_inactive` determines how: if + `copy_if_inactive == false`, it is shared as `yj_i = x1_i`; if `copy_if_inactive == true`, + it is deep-copied, more-or-less as `yj_i = deepcopy(x1_i)` (the difference is that when + `x1` has several non-differentiable parts, object identity is tracked across the multiple + deep-copies such that the object reference graph is reproduced also within the inactive + parts.) + +* `isinactivetype` (optional): Callable mapping types to `Bool` to determine whether the + type should be treated according to `copy_if_inactive` (`true`) or recursed into (`false`). +""" +function recursive_map end + +## type alias for unified handling of out-of-place and partially-in-place recursive_map +const YS{Nout,T} = Union{Val{Nout},NTuple{Nout,T}} +@inline hasvalues(::T) where {T<:YS} = hasvalues(T) +@inline hasvalues(::Type{<:Val}) = false +@inline hasvalues(::Type{<:NTuple}) = true + +## main entry point: set default arguments, allocate IdDict if needed, exit early if possible +function recursive_map( + f::F, + ys::YS{Nout,T}, + xs::NTuple{Nin,T}, + copy_if_inactive::Val=Val(false), + isinactivetype::L=guaranteed_const_nongen, +) where {F,Nout,Nin,T,L} + check_nout(ys) + newys = if isinactivetype(T) + recursive_map_inactive(nothing, ys, xs, copy_if_inactive) + elseif isvectortype(T) || isbitstype(T) + recursive_map_inner(nothing, f, ys, xs, copy_if_inactive, isinactivetype) + else + recursive_map_inner(IdDict(), f, ys, xs, copy_if_inactive, isinactivetype) + end + return newys::NTuple{Nout,T} +end + +## recursive methods +function recursive_map( + seen::Union{Nothing,IdDict}, + f::F, + ys::YS{Nout,T}, + xs::NTuple{Nin,T}, + copy_if_inactive::Val=Val(false), + isinactivetype::L=guaranteed_const_nongen, +) where {F,Nout,Nin,T,L} 
+ # determine whether to continue recursion, copy/share, or retrieve from cache + check_nout(ys) + newys = if isinactivetype(T) + recursive_map_inactive(seen, ys, xs, copy_if_inactive) + elseif isbitstype(T) # no object identity to track in this branch + recursive_map_inner(nothing, f, ys, xs, copy_if_inactive, isinactivetype) + elseif hascache(seen, xs) + getcached(seen, Val(Nout), xs) + else + recursive_map_inner(seen, f, ys, xs, copy_if_inactive, isinactivetype) + end + return newys::NTuple{Nout,T} +end + +@inline function recursive_map_inner( + seen, f::F, ys::YS{Nout,T}, xs::NTuple{Nin,T}, copy_if_inactive, isinactivetype::L +) where {F,Nout,Nin,T,L} + # forward to appropriate handler for leaf vs. mutable vs. immutable type + @assert !isabstracttype(T) + @assert isconcretetype(T) + newys = if isvectortype(T) + recursive_map_leaf(seen, f, ys, xs) + elseif ismutabletype(T) + recursive_map_mutable(seen, f, ys, xs, copy_if_inactive, isinactivetype) + else + recursive_map_immutable(seen, f, ys, xs, copy_if_inactive, isinactivetype) + end + return newys::NTuple{Nout,T} +end + +@inline function recursive_map_mutable( + seen, f::F, ys::YS{Nout,T}, xs::NTuple{Nin,T}, copy_if_inactive, isinactivetype::L +) where {F,Nout,Nin,T,L} + @assert ismutabletype(T) + if !hasvalues(ys) && !(T <: DenseArray) && all(isbitstype, fieldtypes(T)) + # fast path for out-of-place handling when all fields are bitstypes, which rules + # out undefined fields and circular references + newys = recursive_map_new(seen, f, ys, xs, copy_if_inactive, isinactivetype) + maybecache!(seen, newys, xs) + else + newys = if hasvalues(ys) + ys + else + x1 = first(xs) + ntuple(_ -> (@inline; _similar(x1)), Val(Nout)) + end + maybecache!(seen, newys, xs) + recursive_map_mutable_inner!(seen, f, newys, ys, xs, copy_if_inactive, isinactivetype) + end + return newys::NTuple{Nout,T} +end + +@inline function recursive_map_mutable_inner!( + seen, + f::F, + newys::NTuple{Nout,T}, + ys::YS{Nout,T}, + 
xs::NTuple{Nin,T}, + copy_if_inactive, + isinactivetype::L, +) where {F,Nout,Nin,T<:DenseArray,L} + if (Nout == 1) && isbitstype(eltype(T)) + newy = only(newys) + if hasvalues(ys) + y = only(ys) + broadcast!(newy, y, xs...) do y_i, xs_i... + only(recursive_map(nothing, f, (y_i,), xs_i, copy_if_inactive, isinactivetype)) + end + else + broadcast!(newy, xs...) do xs_i... + only(recursive_map(nothing, f, Val(1), xs_i, copy_if_inactive, isinactivetype)) + end + end + else + @inbounds for i in eachindex(newys..., xs...) + recursive_map_item!(i, seen, f, newys, ys, xs, copy_if_inactive, isinactivetype) + end + end + return nothing +end + +@generated function recursive_map_mutable_inner!( + seen, + f::F, + newys::NTuple{Nout,T}, + ys::YS{Nout,T}, + xs::NTuple{Nin,T}, + copy_if_inactive, + isinactivetype::L, +) where {F,Nout,Nin,T,L} + return quote + @inline + Base.Cartesian.@nexprs $(fieldcount(T)) i -> @inbounds begin + recursive_map_item!(i, seen, f, newys, ys, xs, copy_if_inactive, isinactivetype) + end + return nothing + end +end + +@inline function recursive_map_immutable( + seen, f::F, ys::YS{Nout,T}, xs::NTuple{Nin,T}, copy_if_inactive, isinactivetype::L +) where {F,Nout,Nin,T,L} + @assert !ismutabletype(T) + nf = fieldcount(T) + if nf == 0 # nothing to do (also no known way to hit this branch) + newys = recursive_map_inactive(seen, ys, xs, Val(false)) + else + newys = if isinitialized(first(xs), nf) # fast path when all fields are defined + check_allinitialized(Base.tail(xs), nf) + recursive_map_new(seen, f, ys, xs, copy_if_inactive, isinactivetype) + else + recursive_map_immutable_inner(seen, f, ys, xs, copy_if_inactive, isinactivetype) + end + # maybecache! 
_should_ be a no-op here; call it anyway for consistency + maybecache!(seen, newys, xs) + end + return newys::NTuple{Nout,T} +end + +@generated function recursive_map_immutable_inner( + seen, f::F, ys::YS{Nout,T}, xs::NTuple{Nin,T}, copy_if_inactive, isinactivetype::L +) where {F,Nout,Nin,T,L} + nf = fieldcount(T) + return quote + @inline + x1, xtail = first(xs), Base.tail(xs) + fields = Base.@ntuple $Nout _ -> Vector{Any}(undef, $(nf - 1)) + Base.Cartesian.@nexprs $(nf - 1) i -> begin # unrolled loop over struct fields + @inbounds if isinitialized(x1, i) + check_allinitialized(xtail, i) + newys_i = recursive_map_item( + i, seen, f, ys, xs, copy_if_inactive, isinactivetype + ) + Base.Cartesian.@nexprs $Nout j -> (fields[j][i] = newys_i[j]) + else + return new_structvs(T, fields, i - 1) + end + end + @assert !isinitialized(x1, $nf) + return new_structvs(T, fields, $(nf - 1)) + end +end + +@generated function recursive_map_new( + seen, f::F, ys::YS{Nout,T}, xs::NTuple{Nin,T}, copy_if_inactive, isinactivetype::L +) where {F,Nout,Nin,T,L} + # direct construction of fully initialized non-cyclic structs + nf = fieldcount(T) + return quote + @inline + Base.Cartesian.@nexprs $nf i -> @inbounds begin + newys_i = recursive_map_item(i, seen, f, ys, xs, copy_if_inactive, isinactivetype) + end + newys = Base.@ntuple $Nout j -> begin + $(Expr(:splatnew, :T, :(Base.@ntuple $nf i -> newys_i[j]))) + end + return newys::NTuple{Nout,T} + end +end + +Base.@propagate_inbounds function recursive_map_item!( + i, + seen, + f::F, + newys::NTuple{Nout,T}, + ys::YS{Nout,T}, + xs::NTuple{Nin,T}, + copy_if_inactive, + isinactivetype::L, +) where {F,Nout,Nin,T,L} + if isinitialized(first(xs), i) + check_allinitialized(Base.tail(xs), i) + newys_i = recursive_map_item(i, seen, f, ys, xs, copy_if_inactive, isinactivetype) + setitems!(newys, i, newys_i) + elseif hasvalues(ys) + check_allinitialized(ys, i, false) + end + return nothing +end + +Base.@propagate_inbounds function recursive_map_item( + 
i, seen, f::F, ys::YS{Nout,T}, xs::NTuple{Nin,T}, copy_if_inactive, isinactivetype::L +) where {F,Nout,Nin,T,L} + # recurse into the xs and apply recursive_map to items with index i + xs_i = getitems(xs, i) + newys_i = if hasvalues(ys) && isinitialized(first(ys), i) + check_allinitialized(Base.tail(ys), i) + ys_i = getitems(ys, i) + recursive_map_barrier!!(seen, f, ys_i..., copy_if_inactive, isinactivetype, xs_i...) + else + recursive_map_barrier(seen, f, Val(Nout), copy_if_inactive, isinactivetype, xs_i...) + end + return newys_i +end + +# function barriers such that abstractly typed items trigger minimal runtime dispatch +function recursive_map_barrier( + seen, f::F, ::Val{Nout}, copy_if_inactive::Val, isinactivetype::L, xs_i::Vararg{ST,Nin} +) where {F,Nout,Nin,ST,L} + return recursive_map( + seen, f, Val(Nout), xs_i, copy_if_inactive, isinactivetype + )::NTuple{Nout,ST} +end + +function recursive_map_barrier!!( + seen, f::F, y_i::ST, copy_if_inactive::Val, isinactivetype::L, xs_i::Vararg{ST,Nin} +) where {F,Nin,ST,L} + return recursive_map(seen, f, (y_i,), xs_i, copy_if_inactive, isinactivetype)::NTuple{1,ST} +end + +function recursive_map_barrier!!( # may not show in coverage but is covered via accumulate_into! 
TODO: ensure coverage via VectorSpace once implemented + seen, + f::F, + y1_i::ST, + y2_i::ST, + copy_if_inactive::Val, + isinactivetype::L, + xs_i::Vararg{ST,Nin} +) where {F,Nin,ST,L} + ys_i = (y1_i, y2_i) + return recursive_map(seen, f, ys_i, xs_i, copy_if_inactive, isinactivetype)::NTuple{2,ST} +end + +## recursion base case handlers +@inline function recursive_map_leaf( + seen, f::F, ys::YS{Nout,T}, xs::NTuple{Nin,T} +) where {F,Nout,Nin,T} + # apply the mapped function to leaf values + if !hasvalues(ys) || isbitstype(T) || isscalartype(T) + newys = f(xs...)::NTuple{Nout,T} + else # !isbitstype(T) + newys = f(ys..., xs...)::NTuple{Nout,T} + if ismutabletype(T) + @assert newys === ys + end + end + maybecache!(seen, newys, xs) + return newys::NTuple{Nout,T} +end + +@inline function recursive_map_inactive( + _, ys::NTuple{Nout,T}, xs::NTuple{Nin,T}, ::Val{copy_if_inactive} +) where {Nout,Nin,T,copy_if_inactive} + return ys::NTuple{Nout,T} +end + +@inline function recursive_map_inactive( + seen, ::Val{Nout}, (x1,)::NTuple{Nin,T}, ::Val{copy_if_inactive} +) where {Nout,Nin,T,copy_if_inactive} + @inline + y = if copy_if_inactive && !isbitstype(T) + if isnothing(seen) + deepcopy(x1) + else + Base.deepcopy_internal(x1, seen) + end + else + x1 + end + return ntuple(_ -> (@inline; y), Val(Nout))::NTuple{Nout,T} +end + +### recursive_map!: fully in-place wrapper around recursive_map +""" + recursive_map!( + [seen::Union{Nothing,IdDict},] + f!!, + ys::NTuple{Nout,T}, + xs::NTuple{Nin,T}, + [::Val{copy_if_inactive},] + )::Nothing + +!!! warning + Internal function, documented for developer convenience but not covered by semver API + stability guarantees + +Recurse through `Nin` objects `xs = (x1::T, x2::T, ..., xNin::T)` of the same type, mapping +the function `f!!` over every differentiable value encountered and updating `(y1::T, ...)` +in-place with the resulting values. 
+ +This is a simple wrapper that verifies that `T` is a type where all differentiable values +can be updated in-place, calls `recursive_map`, and verifies that the returned value is +indeed identically the same tuple `ys`. See [`recursive_map`](@ref) for details. +""" +function recursive_map! end + +function recursive_map!( + f!!::F, ys::NTuple{Nout,T}, xs::NTuple{Nin,T}, copy_if_inactives::Vararg{Val,M} +) where {F,Nout,Nin,T,M} + @assert M <= 1 + check_nonactive(T) + newys = recursive_map(f!!, ys, xs, copy_if_inactives...) + @assert newys === ys + return nothing +end + +function recursive_map!( + seen::Union{Nothing,IdDict}, + f!!::F, + ys::NTuple{Nout,T}, + xs::NTuple{Nin,T}, + copy_if_inactives::Vararg{Val,M}, +) where {F,Nout,Nin,T,M} + @assert M <= 1 + check_nonactive(T) + newys = recursive_map(seen, f!!, ys, xs, copy_if_inactives...) + @assert newys === ys + return nothing +end + +### recursive_map helpers +@generated function new_structvs(::Type{T}, fields::NTuple{N,Vector{Any}}, nfields_) where {T,N} + return quote + @inline + return Base.@ntuple $N j -> begin + ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), T, fields[j], nfields_)::T + end + end +end + +@inline _similar(::T) where {T} = ccall(:jl_new_struct_uninit, Any, (Any,), T)::T +@inline _similar(x::T) where {T<:DenseArray} = similar(x)::T +Base.@propagate_inbounds isinitialized(x, i) = isdefined(x, i) +Base.@propagate_inbounds isinitialized(x::DenseArray, i) = isassigned(x, i) +Base.@propagate_inbounds getitem(x, i) = getfield(x, i) +Base.@propagate_inbounds getitem(x::DenseArray, i) = x[i] +Base.@propagate_inbounds setitem!(x, i, v) = setfield_force!(x, i, v) +Base.@propagate_inbounds setitem!(x::DenseArray, i, v) = (x[i] = v; nothing) + +Base.@propagate_inbounds function setfield_force!(x::T, i, v) where {T} + if Base.isconst(T, i) + ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), x, i - 1, v) + else + setfield!(x, i, v) + end + return nothing +end + +Base.@propagate_inbounds function 
getitems((x1, xtail...)::Tuple{T,T,Vararg{T,N}}, i) where {T,N} + return (getitem(x1, i), getitems(xtail, i)...) +end + +Base.@propagate_inbounds getitems((x1,)::Tuple{T}, i) where {T} = (getitem(x1, i),) + +Base.@propagate_inbounds function setitems!( # may not show in coverage but is covered via accumulate_into! TODO: ensure coverage via VectorSpace once implemented + (x1, xtail...)::Tuple{T,T,Vararg{T,N}}, i, (v1, vtail...)::Tuple{ST,ST,Vararg{ST,N}} +) where {T,ST,N} + setitem!(x1, i, v1) + setitems!(xtail, i, vtail) + return nothing +end + +Base.@propagate_inbounds function setitems!((x1,)::Tuple{T}, i, (v1,)::Tuple{ST}) where {T,ST} + setitem!(x1, i, v1) + return nothing +end + +## cache (seen) helpers +@inline function iscachedtype(::Type{T}) where {T} + # cache all mutable types and any non-isbits types that are also leaf types + return ismutabletype(T) || ((!isbitstype(T)) && isvectortype(T)) +end + +@inline shouldcache(::IdDict, ::Type{T}) where {T} = iscachedtype(T) +@inline shouldcache(::Nothing, ::Type{T}) where {T} = false + +@inline function maybecache!(seen, newys::NTuple{Nout,T}, (x1, xtail...)::NTuple{Nin,T}) where {Nout,Nin,T} + if shouldcache(seen, T) + seen[x1] = if (Nout == 1) && (Nin == 1) + only(newys) + else # may not show in coverage but is covered via accumulate_into! TODO: ensure coverage via VectorSpace once implemented + (newys..., xtail...) + end + end + return nothing +end + +@inline function hascache(seen, (x1,)::NTuple{Nin,T}) where {Nin,T} + return shouldcache(seen, T) ? haskey(seen, x1) : false +end + +@inline function getcached(seen::IdDict, ::Val{Nout}, (x1, xtail...)::NTuple{Nin,T}) where {Nout,Nin,T} + newys = if (Nout == 1) && (Nin == 1) + (seen[x1]::T,) + else # may not show in coverage but is covered via accumulate_into! 
TODO: ensure coverage via VectorSpace once implemented + cache = seen[x1]::NTuple{(Nout + Nin - 1),T} + cachedtail = cache[(Nout+1):end] + check_identical(cachedtail, xtail) # check compatible layout + cache[1:Nout] + end + return newys::NTuple{Nout,T} +end + +## argument validation +@inline function check_nout(::YS{Nout}) where {Nout} + if Nout > 2 + throw_nout() + end +end + +Base.@propagate_inbounds function check_initialized(x, i, initialized=true) + if isinitialized(x, i) != initialized + throw_initialized() # TODO: hit this when VectorSpace implemented + end + return nothing +end + +Base.@propagate_inbounds function check_allinitialized( # TODO: hit this when VectorSpace implemented + (x1, xtail...)::Tuple{T,T,Vararg{T,N}}, i, initialized=true +) where {T,N} + check_initialized(x1, i, initialized) + check_allinitialized(xtail, i, initialized) + return nothing +end + +Base.@propagate_inbounds function check_allinitialized( + (x1,)::Tuple{T}, i, initialized=true +) where {T} + check_initialized(x1, i, initialized) + return nothing +end + +Base.@propagate_inbounds check_allinitialized(::Tuple{}, i, initialized=true) = nothing + +@inline function check_identical(u, v) # TODO: hit this when VectorSpace implemented + if u !== v + throw_identical() + end + return nothing +end + +@inline function check_nonactive(::Type{T}) where {T} + if !guaranteed_nonactive_nongen(T) + throw_nonactive() + end + return nothing +end + +# TODO: hit all of these via check_* when VectorSpace implemented +@noinline function throw_nout() + throw(ArgumentError("recursive_map(!) only supports mapping to 1 or 2 outputs")) +end + +@noinline function throw_initialized() + msg = "recursive_map(!) called on objects whose undefined fields/unassigned elements " + msg *= "don't line up" + throw(ArgumentError(msg)) +end + +@noinline function throw_identical() + msg = "recursive_map(!) 
called on objects whose layout don't match" + throw(ArgumentError(msg)) +end + +@noinline function throw_nonactive() + msg = "recursive_map! called on objects containing immutable differentiable values" + throw(ArgumentError(msg)) +end + +### EnzymeCore.make_zero(!) implementation +function EnzymeCore.make_zero(prev::T, copy_if_inactives::Vararg{Val,M}) where {T,M} + @assert M <= 1 + new = if iszero(M) && !guaranteed_const_nongen(T) && isvectortype(T) # fallback + # guaranteed_const has precedence over isvectortype for consistency with recursive_map + convert(T, zero(prev)) # convert because zero(prev)::T may fail when eltype(T) is abstract + else + only(recursive_map(_make_zero!!, Val(1), (prev,), copy_if_inactives...)) + end + return new::T +end + +function EnzymeCore.make_zero!(val::T, seens::Vararg{IdDict,M}) where {T,M} + @assert M <= 1 + @assert !isscalartype(T) # not appropriate for in-place handler + if iszero(M) && !guaranteed_const_nongen(T) && isvectortype(T) # fallback + # isinactivetype has precedence over isvectortype for consistency with recursive_map + fill!(val, false) + else + recursive_map!(seens..., _make_zero!!, (val,), (val,)) + end + return nothing +end + +function _make_zero!!(prev::T) where {T} + @assert isvectortype(T) # otherwise infinite loop + return (EnzymeCore.make_zero(prev)::T,) +end + +function _make_zero!!(val::T, _val::T) where {T} + @assert !isscalartype(T) # not appropriate for in-place handler + @assert isvectortype(T) # otherwise infinite loop + @assert val === _val + EnzymeCore.make_zero!(val) + return (val::T,) +end + +# alternative entry point for passing custom IdDict +function EnzymeCore.make_zero( + ::Type{T}, seen::IdDict, prev::T, copy_if_inactives::Vararg{Val,M} +) where {T,M} + @assert M <= 1 + return only(recursive_map(seen, _make_zero!!, Val(1), (prev,), copy_if_inactives...))::T +end + +end # module RecursiveMaps diff --git a/test/Project.toml b/test/Project.toml index fbc6d754fe..667d94ba1f 100644 --- 
a/test/Project.toml +++ b/test/Project.toml @@ -14,6 +14,7 @@ LLVM_jll = "86de99a1-58d6-5da7-8064-bd56ce2e322c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" +Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" diff --git a/test/make_zero.jl b/test/recursive_maps.jl similarity index 50% rename from test/make_zero.jl rename to test/recursive_maps.jl index cbe2f2159f..fc903e5139 100644 --- a/test/make_zero.jl +++ b/test/recursive_maps.jl @@ -1,12 +1,16 @@ -module MakeZeroTests +module RecursiveMapTests using Enzyme +using JLArrays +using Logging using StaticArrays using Test # Universal getters/setters for built-in and custom containers/wrappers getx(w::Base.RefValue) = w[] getx(w::Core.Box) = w.contents +getx(w::JLArray) = JLArrays.@allowscalar first(w) +gety(w::JLArray) = JLArrays.@allowscalar last(w) getx(w) = first(w) gety(w) = last(w) @@ -87,20 +91,27 @@ gety(a::MutableDualWrapper) = a.y setx!(a::MutableDualWrapper, x) = (a.x = x) sety!(a::MutableDualWrapper, y) = (a.y = y) -struct Incomplete{T} +struct Incomplete{T,U} s::String x::Float64 w::T + y::U # possibly not initialized z # not initialized - Incomplete(s, x, w) = new{typeof(w)}(s, x, w) + Incomplete(s, x, w) = new{typeof(w),Any}(s, x, w) + Incomplete(s, x, w, y) = new{typeof(w),typeof(y)}(s, x, w, y) end function Base.:(==)(a::Incomplete, b::Incomplete) (a === b) && return true ((a.s == b.s) && (a.x == b.x) && (a.w == b.w)) || return false - if isdefined(a, :z) && isdefined(b, :z) - (a.z == b.z) || return false - elseif isdefined(a, :z) || isdefined(b, :z) + if isdefined(a, :y) && isdefined(b, :y) + (a.y == b.y) || return false # compare the optional field `y` itself; `w` was already checked above + if isdefined(a, :z) && isdefined(b, :z) + (a.z == b.z) || return false + elseif isdefined(a, :z) || isdefined(b, :z) + 
return false + end + elseif isdefined(a, :y) || isdefined(b, :y) return false end return true @@ -132,40 +143,27 @@ function Base.:(==)(a::MutableIncomplete, b::MutableIncomplete) return true end -mutable struct CustomVector{T} <: AbstractVector{T} +mutable struct CustomVector{T} data::Vector{T} end Base.:(==)(a::CustomVector, b::CustomVector) = (a === b) || (a.data == b.data) -function Enzyme.EnzymeCore.make_zero( - ::Type{CV}, seen::IdDict, prev::CV, ::Val{copy_if_inactive} -) where {CV<:CustomVector{<:AbstractFloat},copy_if_inactive} +function Enzyme.EnzymeCore.isvectortype(::Type{CustomVector{T}}) where {T} + return Enzyme.EnzymeCore.isscalartype(T) +end + +function Enzyme.EnzymeCore.make_zero(prev::CV) where {CV<:CustomVector{<:AbstractFloat}} @info "make_zero(::CustomVector)" - if haskey(seen, prev) - return seen[prev] - end - new = CustomVector(zero(prev.data))::CV - seen[prev] = new - return new + return CustomVector(zero(prev.data))::CV end -function Enzyme.EnzymeCore.make_zero!(prev::CustomVector{<:AbstractFloat}, seen)::Nothing +function Enzyme.EnzymeCore.make_zero!(prev::CustomVector{<:AbstractFloat}) @info "make_zero!(::CustomVector)" - if !isnothing(seen) - if prev in seen - return nothing - end - push!(seen, prev) - end fill!(prev.data, false) return nothing end -function Enzyme.EnzymeCore.make_zero!(prev::CustomVector{<:AbstractFloat}) - return Enzyme.EnzymeCore.make_zero!(prev, nothing) -end - struct WithIO{F} # issue 2091 v::Vector{Float64} callback::F @@ -186,55 +184,60 @@ macro test_noerr(expr) end end -const scalartypes = [Float32, ComplexF32, Float64, ComplexF64] +const scalartypes = [Float32, ComplexF32, Float64, ComplexF64, BigFloat, Complex{BigFloat}] -const inactivetup = ("a", Empty(), MutableEmpty()) +const inactivebits = (1, Empty()) +const inactivetup = (inactivebits, "a", MutableEmpty()) const inactivearr = [inactivetup] const wrappers = [ - (name="Tuple{X}", f=tuple, N=1, mutable=false, typed=true), - (name="@NamedTuple{x::X}", 
f=(NamedTuple{(:x,)} ∘ tuple), N=1, mutable=false, typed=true), - (name="struct{X}", f=Wrapper, N=1, mutable=false, typed=true), + (name="Tuple{X}", f=tuple, N=1, mutable=false, typed=true, bitsonly=false), + (name="@NamedTuple{x::X}", f=(NamedTuple{(:x,)} ∘ tuple), N=1, mutable=false, typed=true, bitsonly=false), + (name="struct{X}", f=Wrapper, N=1, mutable=false, typed=true, bitsonly=false), - (name="@NamedTuple{x}", f=(@NamedTuple{x} ∘ tuple), N=1, mutable=false, typed=false), - (name="struct{Any}", f=Wrapper{Any}, N=1, mutable=false, typed=false), + (name="@NamedTuple{x}", f=(@NamedTuple{x} ∘ tuple), N=1, mutable=false, typed=false, bitsonly=false), + (name="struct{Any}", f=Wrapper{Any}, N=1, mutable=false, typed=false, bitsonly=false), - (name="Array{X}", f=(x -> [x]), N=1, mutable=true, typed=true), - (name="Base.RefValue{X}", f=Ref, N=1, mutable=true, typed=true), - (name="mutable struct{X}", f=MutableWrapper, N=1, mutable=true, typed=true), + (name="Array{X}", f=(x -> [x]), N=1, mutable=true, typed=true, bitsonly=false), + (name="Base.RefValue{X}", f=Ref, N=1, mutable=true, typed=true, bitsonly=false), + (name="mutable struct{X}", f=MutableWrapper, N=1, mutable=true, typed=true, bitsonly=false), - (name="Array{Any}", f=(x -> Any[x]), N=1, mutable=true, typed=false), - (name="Base.RefValue{Any}", f=Ref{Any}, N=1, mutable=true, typed=false), - (name="Core.Box", f=Core.Box, N=1, mutable=true, typed=false), - (name="mutable struct{Any}", f=MutableWrapper{Any}, N=1, mutable=true, typed=false), + (name="Array{Any}", f=(x -> Any[x]), N=1, mutable=true, typed=false, bitsonly=false), + (name="Base.RefValue{Any}", f=Ref{Any}, N=1, mutable=true, typed=false, bitsonly=false), + (name="Core.Box", f=Core.Box, N=1, mutable=true, typed=false, bitsonly=false), + (name="mutable struct{Any}", f=MutableWrapper{Any}, N=1, mutable=true, typed=false, bitsonly=false), - (name="Tuple{X,Y}", f=tuple, N=2, mutable=false, typed=true), - (name="@NamedTuple{x::X,y::Y}", 
f=(NamedTuple{(:x, :y)} ∘ tuple), N=2, mutable=false, typed=true), - (name="struct{X,Y}", f=DualWrapper, N=2, mutable=false, typed=true), + (name="Tuple{X,Y}", f=tuple, N=2, mutable=false, typed=true, bitsonly=false), + (name="@NamedTuple{x::X,y::Y}", f=(NamedTuple{(:x, :y)} ∘ tuple), N=2, mutable=false, typed=true, bitsonly=false), + (name="struct{X,Y}", f=DualWrapper, N=2, mutable=false, typed=true, bitsonly=false), - (name="@NamedTuple{x,y::Y}", f=((x, y) -> @NamedTuple{x,y::typeof(y)}((x, y))), N=2, mutable=false, typed=:partial), - (name="struct{Any,Y}", f=DualWrapper{Any}, N=2, mutable=false, typed=:partial), + (name="@NamedTuple{x,y::Y}", f=((x, y) -> @NamedTuple{x,y::typeof(y)}((x, y))), N=2, mutable=false, typed=:partial, bitsonly=false), + (name="struct{Any,Y}", f=DualWrapper{Any}, N=2, mutable=false, typed=:partial, bitsonly=false), - (name="@NamedTuple{x,y}", f=@NamedTuple{x,y} ∘ tuple, N=2, mutable=false, typed=false), - (name="struct{Any}", f=DualWrapper{Any,Any}, N=2, mutable=false, typed=false), + (name="@NamedTuple{x,y}", f=(@NamedTuple{x,y} ∘ tuple), N=2, mutable=false, typed=false, bitsonly=false), + (name="struct{Any}", f=DualWrapper{Any,Any}, N=2, mutable=false, typed=false, bitsonly=false), - (name="mutable struct{X,Y}", f=MutableDualWrapper, N=2, mutable=true, typed=true), + (name="mutable struct{X,Y}", f=MutableDualWrapper, N=2, mutable=true, typed=true, bitsonly=false), - (name="Array{promote_type(X,Y)}", f=((x, y) -> [x, y]), N=2, mutable=true, typed=:promoted), - (name="mutable struct{Any,Y}", f=MutableDualWrapper{Any}, N=2, mutable=true, typed=:partial), + (name="Array{promote_type(X,Y)}", f=((x, y) -> [x, y]), N=2, mutable=true, typed=:promoted, bitsonly=false), + (name="mutable struct{Any,Y}", f=MutableDualWrapper{Any}, N=2, mutable=true, typed=:partial, bitsonly=false), - (name="Array{Any}", f=((x, y) -> Any[x, y]), N=2, mutable=true, typed=false), - (name="mutable struct{Any,Any}", f=MutableDualWrapper{Any,Any}, N=2, mutable=true, 
typed=false), + (name="Array{Any}", f=((x, y) -> Any[x, y]), N=2, mutable=true, typed=false, bitsonly=false), + (name="mutable struct{Any,Any}", f=MutableDualWrapper{Any,Any}, N=2, mutable=true, typed=false, bitsonly=false), # StaticArrays extension - (name="SVector{1,X}", f=SVector{1} ∘ tuple, N=1, mutable=false, typed=true), - (name="SVector{1,Any}", f=SVector{1,Any} ∘ tuple, N=1, mutable=false, typed=false), - (name="MVector{1,X}", f=MVector{1} ∘ tuple, N=1, mutable=true, typed=true), - (name="MVector{1,Any}", f=MVector{1,Any} ∘ tuple, N=1, mutable=true, typed=false), - (name="SVector{2,promote_type(X,Y)}", f=SVector{2} ∘ tuple, N=2, mutable=false, typed=:promoted), - (name="SVector{2,Any}", f=SVector{2,Any} ∘ tuple, N=2, mutable=false, typed=false), - (name="MVector{2,promote_type(X,Y)}", f=MVector{2} ∘ tuple, N=2, mutable=true, typed=:promoted), - (name="MVector{2,Any}", f=MVector{2,Any} ∘ tuple, N=2, mutable=true, typed=false), + (name="SVector{1,X}", f=(SVector{1} ∘ tuple), N=1, mutable=false, typed=true, bitsonly=false), + (name="SVector{1,Any}", f=(SVector{1,Any} ∘ tuple), N=1, mutable=false, typed=false, bitsonly=false), + (name="MVector{1,X}", f=(MVector{1} ∘ tuple), N=1, mutable=true, typed=true, bitsonly=false), + (name="MVector{1,Any}", f=(MVector{1,Any} ∘ tuple), N=1, mutable=true, typed=false, bitsonly=false), + (name="SVector{2,promote_type(X,Y)}", f=(SVector{2} ∘ tuple), N=2, mutable=false, typed=:promoted, bitsonly=false), + (name="SVector{2,Any}", f=(SVector{2,Any} ∘ tuple), N=2, mutable=false, typed=false, bitsonly=false), + (name="MVector{2,promote_type(X,Y)}", f=(MVector{2} ∘ tuple), N=2, mutable=true, typed=:promoted, bitsonly=false), + (name="MVector{2,Any}", f=(MVector{2,Any} ∘ tuple), N=2, mutable=true, typed=false, bitsonly=false), + + # GPUArrays extension + (name="JLArray{X}", f=(x -> JLArray([x])), N=1, mutable=true, typed=true, bitsonly=true), + (name="JLArray{promote_type(X,Y)}", f=((x, y) -> JLArray([x, y])), N=2, mutable=true, 
typed=:promoted, bitsonly=true), ] @static if VERSION < v"1.11-" @@ -242,10 +245,10 @@ else _memory(x::Vector) = Memory{eltype(x)}(x) push!( wrappers, - (name="Memory{X}", f=(x -> _memory([x])), N=1, mutable=true, typed=true), - (name="Memory{Any}", f=(x -> _memory(Any[x])), N=1, mutable=true, typed=false), - (name="Memory{promote_type(X,Y)}", f=((x, y) -> _memory([x, y])), N=2, mutable=true, typed=:promoted), - (name="Memory{Any}", f=((x, y) -> _memory(Any[x, y])), N=2, mutable=true, typed=false), + (name="Memory{X}", f=(x -> _memory([x])), N=1, mutable=true, typed=true, bitsonly=false), + (name="Memory{Any}", f=(x -> _memory(Any[x])), N=1, mutable=true, typed=false, bitsonly=false), + (name="Memory{promote_type(X,Y)}", f=((x, y) -> _memory([x, y])), N=2, mutable=true, typed=:promoted, bitsonly=false), + (name="Memory{Any}", f=((x, y) -> _memory(Any[x, y])), N=2, mutable=true, typed=false, bitsonly=false), ) end @@ -260,8 +263,10 @@ function test_make_zero() end end @testset "nested types" begin - @testset "$T in $(wrapper.name)" for - T in scalartypes, wrapper in filter(w -> (w.N == 1), wrappers) + @testset "$T in $(wrapper.name)" for T in scalartypes, wrapper in filter( + w -> (w.N == 1), wrappers + ) + (!wrapper.bitsonly || isbitstype(T)) || continue x = oneunit(T) w = wrapper.f(x) w_makez = make_zero(w) @@ -270,37 +275,44 @@ function test_make_zero() @test getx(w_makez) == zero(T) # correct value @test getx(w) === x # no mutation of original @test x == oneunit(T) # no mutation of original (relevant for BigFloat) - @testset "doubly included in $(dualwrapper.name)" for - dualwrapper in filter(w -> (w.N == 2), wrappers) + @testset "doubly included in $(dualwrapper.name)" for dualwrapper in filter( + w -> (w.N == 2), wrappers + ) + (!dualwrapper.bitsonly || isbitstype(T)) || continue w_inner = wrapper.f(x) - d_outer = dualwrapper.f(w_inner, w_inner) - d_outer_makez = make_zero(d_outer) - @test typeof(d_outer_makez) === typeof(d_outer) # correct type - @test 
typeof(getx(d_outer_makez)) === typeof(w_inner) # correct type - @test typeof(getx(getx(d_outer_makez))) === T # correct type - @test getx(d_outer_makez) === gety(d_outer_makez) # correct topology - @test getx(getx(d_outer_makez)) == zero(T) # correct value - @test getx(d_outer) === gety(d_outer) # no mutation of original - @test getx(d_outer) === w_inner # no mutation of original - @test getx(w_inner) === x # no mutation of original - @test x == oneunit(T) # no mutation of original (relevant for BigFloat) + if !dualwrapper.bitsonly || isbits(w_inner) + d_outer = dualwrapper.f(w_inner, w_inner) + d_outer_makez = make_zero(d_outer) + @test typeof(d_outer_makez) === typeof(d_outer) # correct type + @test typeof(getx(d_outer_makez)) === typeof(w_inner) # correct type + @test typeof(getx(getx(d_outer_makez))) === T # correct type + @test getx(d_outer_makez) === gety(d_outer_makez) # correct layout + @test getx(getx(d_outer_makez)) == zero(T) # correct value + @test getx(d_outer) === gety(d_outer) # no mutation of original + @test getx(d_outer) === w_inner # no mutation of original + @test getx(w_inner) === x # no mutation of original + @test x == oneunit(T) # no mutation of original (relevant for BigFloat) + end d_inner = dualwrapper.f(x, x) - w_outer = wrapper.f(d_inner) - w_outer_makez = make_zero(w_outer) - @test typeof(w_outer_makez) === typeof(w_outer) # correct type - @test typeof(getx(w_outer_makez)) === typeof(d_inner) # correct type - @test typeof(getx(getx(w_outer_makez))) === T # correct type - @test getx(getx(w_outer_makez)) == gety(getx(w_outer_makez)) # correct topology - @test getx(getx(w_outer_makez)) == zero(T) # correct value - @test getx(w_outer) === d_inner # no mutation of original - @test getx(d_inner) === gety(d_inner) # no mutation of original - @test getx(d_inner) === x # no mutation of original - @test x == oneunit(T) # no mutation of original (relevant for BigFloat) - if wrapper.mutable && !dualwrapper.mutable + if !wrapper.bitsonly || 
isbits(d_inner) + w_outer = wrapper.f(d_inner) + w_outer_makez = make_zero(w_outer) + @test typeof(w_outer_makez) === typeof(w_outer) # correct type + @test typeof(getx(w_outer_makez)) === typeof(d_inner) # correct type + @test typeof(getx(getx(w_outer_makez))) === T # correct type + @test getx(getx(w_outer_makez)) == gety(getx(w_outer_makez)) # correct layout + @test getx(getx(w_outer_makez)) == zero(T) # correct value + @test getx(w_outer) === d_inner # no mutation of original + @test getx(d_inner) === gety(d_inner) # no mutation of original + @test getx(d_inner) === x # no mutation of original + @test x == oneunit(T) # no mutation of original (relevant for BigFloat) + end + if wrapper.mutable && !dualwrapper.mutable && !dualwrapper.bitsonly # some code paths can only be hit with three layers of wrapping: # mutable(immutable(mutable(scalar))) - @testset "all wrapped in $(outerwrapper.name)" for - outerwrapper in filter(w -> ((w.N == 1) && w.mutable), wrappers) + @testset "all wrapped in $(outerwrapper.name)" for outerwrapper in filter( + w -> ((w.N == 1) && w.mutable && !w.bitsonly), wrappers + ) w_inner = wrapper.f(x) d_middle = dualwrapper.f(w_inner, w_inner) w_outer = outerwrapper.f(d_middle) @@ -309,7 +321,7 @@ function test_make_zero() @test typeof(getx(w_outer_makez)) === typeof(d_middle) # correct type @test typeof(getx(getx(w_outer_makez))) === typeof(w_inner) # correct type @test typeof(getx(getx(getx(w_outer_makez)))) === T # correct type - @test getx(getx(w_outer_makez)) === gety(getx(w_outer_makez)) # correct topology + @test getx(getx(w_outer_makez)) === gety(getx(w_outer_makez)) # correct layout @test getx(getx(getx(w_outer_makez))) == zero(T) # correct value @test getx(w_outer) === d_middle # no mutation of original @test getx(d_middle) === gety(d_middle) # no mutation of original @@ -324,54 +336,100 @@ function test_make_zero() @testset "inactive" begin @testset "in $(wrapper.name)" for wrapper in wrappers if wrapper.N == 1 - w = 
wrapper.f(inactivearr) - w_makez = make_zero(w) - if wrapper.typed == true - @test w_makez === w # preserved wrapper identity if guaranteed const + for (inactive, condition) in [ + (inactivebits, true), + (inactivearr, !wrapper.bitsonly), + ] + condition || continue + w = wrapper.f(inactive) + w_makez = make_zero(w) + if wrapper.typed in (true, :promoted) + if w isa JLArray # needs JLArray activity + @test_broken w_makez === w + else + @test w_makez === w # preserved wrapper identity if guaranteed const + end + end + @test typeof(w_makez) === typeof(w) # correct type + @test getx(w_makez) === inactive # preserved identity + @test getx(w) === inactive # no mutation of original + if inactive === inactivearr + @test inactivearr[1] === inactivetup # preserved value + end + end + @testset "mixed" begin + for (inactive, mixed, condition) in [ + (inactivebits, (1.0, inactivebits), true), + (inactivearr, [1.0, inactivearr], !wrapper.bitsonly), + ] + condition || continue + w = wrapper.f(mixed) + w_makez = make_zero(w) + @test typeof(w_makez) === typeof(w) # correct type + @test typeof(getx(w_makez)) === typeof(mixed) # correct type + @test getx(w_makez)[1] === 0.0 # correct value + @test getx(w_makez)[2] === inactive # preserved inactive identity + @test getx(w) === mixed # no mutation of original + if inactive === inactivearr + @test inactivearr[1] === inactivetup # preserved inactive value + @test mixed[1] === 1.0 # no mutation of original + @test mixed[2] === inactivearr # no mutation of original + end + end end - @test typeof(w_makez) === typeof(w) # correct type - @test getx(w_makez) === inactivearr # preserved identity - @test inactivearr[1] === inactivetup # preserved value - @test getx(w) === inactivearr # no mutation of original else # wrapper.N == 2 @testset "multiple references" begin - w = wrapper.f(inactivearr, inactivearr) - w_makez = make_zero(w) - if wrapper.typed == true - @test w_makez === w # preserved wrapper identity if guaranteed const + for 
(inactive, condition) in [ + (inactivebits, true), + (inactivearr, !wrapper.bitsonly), + ] + condition || continue + w = wrapper.f(inactive, inactive) + w_makez = make_zero(w) + if wrapper.typed in (true, :promoted) + if w isa JLArray # needs JLArray activity + @test_broken w_makez === w + else + @test w_makez === w # preserved wrapper identity if guaranteed const + end + end + @test typeof(w_makez) === typeof(w) # correct type + @test getx(w_makez) === gety(w_makez) # preserved layout + @test getx(w_makez) === inactive # preserved identity + @test getx(w) === gety(w) # no mutation of original + @test getx(w) === inactive # no mutation of original + if inactive === inactivearr + @test inactivearr[1] === inactivetup # preserved value + end + end end - @test typeof(w_makez) === typeof(w) # correct type - @test getx(w_makez) === gety(w_makez) # preserved topology - @test getx(w_makez) === inactivearr # preserved identity - @test inactivearr[1] === inactivetup # preserved value - @test getx(w) === gety(w) # no mutation of original - @test getx(w) === inactivearr # no mutation of original end - @testset "alongside active" begin - a = [1.0] - w = wrapper.f(a, inactivearr) - w_makez = make_zero(w) - @test typeof(w_makez) === typeof(w) # correct type - @test typeof(getx(w_makez)) === typeof(a) # correct type - @test getx(w_makez) == [0.0] # correct value - @test gety(w_makez) === inactivearr # preserved inactive identity - @test inactivearr[1] === inactivetup # preserved inactive value - @test getx(w) === a # no mutation of original - @test a[1] === 1.0 # no mutation of original - @test gety(w) === inactivearr # no mutation of original - if wrapper.typed == :partial - # above: untyped active / typed inactive - # below: untyped inactive / typed active - w = wrapper.f(inactivearr, a) + if !wrapper.bitsonly + @testset "mixed" begin + a = [1.0] + w = wrapper.f(a, inactivearr) w_makez = make_zero(w) @test typeof(w_makez) === typeof(w) # correct type - @test getx(w_makez) === inactivearr
# preserved inactive identity + @test typeof(getx(w_makez)) === typeof(a) # correct type + @test getx(w_makez) == [0.0] # correct value + @test gety(w_makez) === inactivearr # preserved inactive identity @test inactivearr[1] === inactivetup # preserved inactive value - @test typeof(gety(w_makez)) === typeof(a) # correct type - @test gety(w_makez) == [0.0] # correct value - @test getx(w) === inactivearr # no mutation of original - @test gety(w) === a # no mutation of original + @test getx(w) === a # no mutation of original @test a[1] === 1.0 # no mutation of original + @test gety(w) === inactivearr # no mutation of original + if wrapper.typed == :partial + # above: untyped active / typed inactive + # below: untyped inactive / typed active + w = wrapper.f(inactivearr, a) + w_makez = make_zero(w) + @test typeof(w_makez) === typeof(w) # correct type + @test getx(w_makez) === inactivearr # preserved inactive identity + @test inactivearr[1] === inactivetup # preserved inactive value + @test typeof(gety(w_makez)) === typeof(a) # correct type + @test gety(w_makez) == [0.0] # correct value + @test getx(w) === inactivearr # no mutation of original + @test gety(w) === a # no mutation of original + @test a[1] === 1.0 # no mutation of original + end end end end @@ -387,7 +445,7 @@ function test_make_zero() @test typeof(w_makez) === typeof(w) # correct type @test typeof(w_makez[1]) === typeof(a) # correct type @test w_makez[1] == [0.0] # correct value - @test w_makez[2] === w_makez[3] # correct topology (topology should propagate even when copy_if_inactive = Val(true)) + @test w_makez[2] === w_makez[3] # correct layout (layout should propagate even when copy_if_inactive = Val(true)) @test w[1] === a # no mutation of original @test a[1] === 1.0 # no mutation of original @test w[2] === w[3] # no mutation of original @@ -423,9 +481,26 @@ function test_make_zero() @test all(m.x == oneunit(m.x) for m in mwraps) # no mutation of original end end + @testset "heterogeneous float arrays" 
begin + b1r, b2r = big"1.0", big"2.0" + b1i, b2i = big"1.0" * im, big"2.0" * im + ar = AbstractFloat[1.0f0, 1.0, b1r, b1r, b2r] + ai = Complex{<:AbstractFloat}[1.0f0im, 1.0im, b1i, b1i, b2i] + for (a, btype) in [(ar, typeof(b1r)), (ai, typeof(b1i))] + a_makez = make_zero(a) + @test a_makez[1] === zero(a[1]) + @test a_makez[2] === zero(a[2]) + @test typeof(a_makez[3]) === btype + @test a_makez[3] == 0 + @test a_makez[4] === a_makez[3] + @test typeof(a_makez[5]) === btype + @test a_makez[5] == 0 + @test a_makez[5] !== a_makez[3] + end + end @testset "circular references" begin - @testset "$(wrapper.name)" for wrapper in ( - filter(w -> (w.mutable && (w.typed in (:partial, false))), wrappers) + @testset "$(wrapper.name)" for wrapper in filter( + w -> (w.mutable && (w.typed in (:partial, false))), wrappers ) a = [1.0] if wrapper.N == 1 @@ -475,6 +550,37 @@ function test_make_zero() @test v.data === a # no mutation of original @test a[1] === 1.0 # no mutation of original end + @testset "runtime inactive" begin + # verify that MutableWrapper is seen as active + a = MutableWrapper(1.0) + @assert !EnzymeRules.inactive_type(typeof(a)) + a_makez = make_zero(a) + @test a_makez == MutableWrapper(0.0) + + # mark MutableWrapper as inactive + @eval @inline EnzymeRules.inactive_type(::Type{<:MutableWrapper}) = true + + # verify that MutableWrapper is seen as inactive and shared/copied according to + # copy_if_inactive + @assert a.x === 1.0 # sanity check + a_makez = @invokelatest make_zero(a) + @test a_makez == a # equal + @assert a.x === 1.0 # sanity check + a_makez = @invokelatest make_zero(a, Val(false)) + @test a_makez === a # identical + @assert a.x === 1.0 # sanity check + a_makez = @invokelatest make_zero(a, Val(true)) + @test a_makez !== a # not identical + @test a_makez == a # but equal + + # mark MutableWrapper as active again + @eval @inline EnzymeRules.inactive_type(::Type{<:MutableWrapper}) = false + + # verify that MutableWrapper is seen as active + @assert a.x === 
1.0 # sanity check + a_makez = @invokelatest make_zero(a) + @test a_makez == MutableWrapper(0.0) + end @testset "undefined fields/unassigned elements" begin @testset "array w inactive/active/mutable/unassigned" begin a = [1.0] @@ -494,12 +600,22 @@ function test_make_zero() end @testset "struct w inactive/active/mutable/undefined" begin a = [1.0] - incomplete = Incomplete("a", 1.0, a) - incomplete_makez = make_zero(incomplete) - @test typeof(incomplete_makez) === typeof(incomplete) # correct type - @test typeof(incomplete_makez.w) === typeof(a) # correct type - @test incomplete_makez == Incomplete("a", 0.0, [0.0]) # correct value, propagated undefined - @test a[1] === 1.0 # no mutation of original + @testset "single undefined" begin + incomplete = Incomplete("a", 1.0, a, nothing) + incomplete_makez = make_zero(incomplete) + @test typeof(incomplete_makez) === typeof(incomplete) # correct type + @test typeof(incomplete_makez.w) === typeof(a) # correct type + @test incomplete_makez == Incomplete("a", 0.0, [0.0], nothing) # correct value, propagated undefined + @test a[1] === 1.0 # no mutation of original + end + @testset "multiple undefined" begin + incomplete = Incomplete("a", 1.0, a) + incomplete_makez = make_zero(incomplete) + @test typeof(incomplete_makez) === typeof(incomplete) # correct type + @test typeof(incomplete_makez.w) === typeof(a) # correct type + @test incomplete_makez == Incomplete("a", 0.0, [0.0]) # correct value, propagated undefined + @test a[1] === 1.0 # no mutation of original + end end @testset "mutable struct w inactive/const active/active/mutable/undefined" begin a = [1.0] @@ -524,8 +640,10 @@ end function test_make_zero!() @testset "nested types" begin - @testset "$T in $(wrapper.name)" for - T in scalartypes, wrapper in filter(w -> (w.N == 1), wrappers) + @testset "$T in $(wrapper.name)" for T in scalartypes, wrapper in filter( + w -> (w.N == 1), wrappers + ) + (!wrapper.bitsonly || isbitstype(T)) || continue x = oneunit(T) if 
wrapper.mutable w = wrapper.f(x) @@ -537,34 +655,39 @@ function test_make_zero!() @testset "doubly included in $(dualwrapper.name)" for dualwrapper in ( filter(w -> ((w.N == 2) && (w.mutable || wrapper.mutable)), wrappers) ) + (!dualwrapper.bitsonly || isbitstype(T)) || continue w_inner = wrapper.f(x) - d_outer = dualwrapper.f(w_inner, w_inner) - make_zero!(d_outer) - @test typeof(getx(d_outer)) === typeof(w_inner) # preserved type - @test typeof(getx(getx(d_outer))) === T # preserved type - @test getx(getx(d_outer)) == zero(T) # correct value - @test getx(d_outer) === gety(d_outer) # preserved topology - @test x == oneunit(T) # no mutation of scalar (relevant for BigFloat) - if wrapper.mutable - @test getx(d_outer) === w_inner # preserved identity + if !dualwrapper.bitsonly || isbits(w_inner) + d_outer = dualwrapper.f(w_inner, w_inner) + make_zero!(d_outer) + @test typeof(getx(d_outer)) === typeof(w_inner) # preserved type + @test typeof(getx(getx(d_outer))) === T # preserved type + @test getx(getx(d_outer)) == zero(T) # correct value + @test getx(d_outer) === gety(d_outer) # preserved layout + @test x == oneunit(T) # no mutation of scalar (relevant for BigFloat) + if wrapper.mutable + @test getx(d_outer) === w_inner # preserved identity + end end d_inner = dualwrapper.f(x, x) - w_outer = wrapper.f(d_inner) - make_zero!(w_outer) - @test typeof(getx(w_outer)) === typeof(d_inner) # preserved type - @test typeof(getx(getx(w_outer))) === T # preserved type - @test getx(getx(w_outer)) == zero(T) # correct value - @test getx(getx(w_outer)) === gety(getx(w_outer)) # preserved topology - @test x == oneunit(T) # no mutation of scalar (relevant for BigFloat) - if dualwrapper.mutable - @test getx(w_outer) === d_inner # preserved identity + if !wrapper.bitsonly || isbits(d_inner) + w_outer = wrapper.f(d_inner) + make_zero!(w_outer) + @test typeof(getx(w_outer)) === typeof(d_inner) # preserved type + @test typeof(getx(getx(w_outer))) === T # preserved type + @test 
getx(getx(w_outer)) == zero(T) # correct value + @test getx(getx(w_outer)) === gety(getx(w_outer)) # preserved layout + @test x == oneunit(T) # no mutation of scalar (relevant for BigFloat) + if dualwrapper.mutable + @test getx(w_outer) === d_inner # preserved identity + end end - if wrapper.mutable && !dualwrapper.mutable + if wrapper.mutable && !dualwrapper.mutable && !dualwrapper.bitsonly # some code paths can only be hit with three layers of wrapping: # mutable(immutable(mutable(scalar))) - @assert !dualwrapper.mutable # sanity check - @testset "all wrapped in $(outerwrapper.name)" for - outerwrapper in filter(w -> ((w.N == 1) && w.mutable), wrappers) + @testset "all wrapped in $(outerwrapper.name)" for outerwrapper in filter( + w -> ((w.N == 1) && w.mutable && !w.bitsonly), wrappers + ) w_inner = wrapper.f(x) d_middle = dualwrapper.f(w_inner, w_inner) w_outer = outerwrapper.f(d_middle) @@ -573,7 +696,7 @@ function test_make_zero!() @test typeof(getx(getx(w_outer))) === typeof(w_inner) # preserved type @test typeof(getx(getx(getx(w_outer)))) === T # preserved type @test getx(getx(getx(w_outer))) == zero(T) # correct value - @test getx(getx(w_outer)) === gety(getx(w_outer)) # preserved topology + @test getx(getx(w_outer)) === gety(getx(w_outer)) # preserved layout @test getx(getx(w_outer)) === w_inner # preserved identity @test x == oneunit(T) # no mutation of scalar (relevant for BigFloat) end @@ -582,29 +705,64 @@ function test_make_zero!() end end @testset "inactive" begin - @testset "in $(wrapper.name)" for - wrapper in filter(w -> (w.mutable || (w.typed == true)), wrappers) + @testset "in $(wrapper.name)" for wrapper in filter( + w -> (w.mutable || (w.typed == true)), wrappers + ) if wrapper.N == 1 - w = wrapper.f(inactivearr) - make_zero!(w) - @test getx(w) === inactivearr # preserved identity - @test inactivearr[1] === inactivetup # preserved value + for (inactive, condition) in [ + (inactivebits, true), + (inactivearr, !wrapper.bitsonly), + ] + condition 
|| continue + w = wrapper.f(inactive) + make_zero!(w) + @test getx(w) === inactive # preserved identity + if inactive === inactivearr + @test inactivearr[1] === inactivetup # preserved value + end + end + @testset "mixed" begin + for (inactive, mixed, condition) in [ + (inactivebits, (1.0, inactivebits), wrapper.mutable), + (inactivearr, [1.0, inactivearr], !wrapper.bitsonly), + ] + condition || continue + w = wrapper.f(mixed) + make_zero!(w) + @test getx(w)[1] === 0.0 + @test getx(w)[2] === inactive + if inactive === inactivearr + @test getx(w) === mixed # preserved identity + @test inactivearr[1] === inactivetup # preserved value + end + end + end else # wrapper.N == 2 @testset "multiple references" begin - w = wrapper.f(inactivearr, inactivearr) - make_zero!(w) - @test getx(w) === gety(w) # preserved topology - @test getx(w) === inactivearr # preserved identity - @test inactivearr[1] === inactivetup # preserved value + for (inactive, condition) in [ + (inactivebits, true), + (inactivearr, !wrapper.bitsonly), + ] + condition || continue + w = wrapper.f(inactive, inactive) + make_zero!(w) + @test getx(w) === gety(w) # preserved layout + @test getx(w) === inactive # preserved identity + if inactive === inactivearr + @test inactivearr[1] === inactivetup # preserved value + end + end end - @testset "alongside active" begin - a = [1.0] - w = wrapper.f(a, inactivearr) - make_zero!(w) - @test getx(w) === a # preserved identity - @test a[1] === 0.0 # correct value - @test gety(w) === inactivearr # preserved inactive identity - @test inactivearr[1] === inactivetup # preserved inactive value + if !wrapper.bitsonly + @testset "mixed" begin + a = [1.0] + w = wrapper.f(a, inactivearr) + make_zero!(w) + @test getx(w) === a # preserved identity + @test a[1] === 0.0 # correct value + @test gety(w) === inactivearr # preserved inactive identity + @test inactivearr[1] === inactivetup # preserved inactive value + end end end end @@ -625,9 +783,27 @@ function test_make_zero!() @test 
c == cz # correct value end end + @testset "heterogeneous float arrays" begin + b1r, b2r = big"1.0", big"2.0" + b1i, b2i = big"1.0" * im, big"2.0" * im + ar = AbstractFloat[1.0f0, 1.0, b1r, b1r, b2r] + ai = Complex{<:AbstractFloat}[1.0f0im, 1.0im, b1i, b1i, b2i] + for (a, btype) in [(ar, typeof(b1r)), (ai, typeof(b1i))] + a1, a2 = a[1], a[2] + make_zero!(a) + @test a[1] === zero(a1) + @test a[2] === zero(a2) + @test typeof(a[3]) === btype + @test a[3] == 0 + @test a[4] === a[3] + @test typeof(a[5]) === btype + @test a[5] == 0 + @test a[5] !== a[3] + end + end @testset "circular references" begin - @testset "$(wrapper.name)" for wrapper in ( - filter(w -> (w.mutable && (w.typed in (:partial, false))), wrappers) + @testset "$(wrapper.name)" for wrapper in filter( + w -> (w.mutable && (w.typed in (:partial, false))), wrappers ) a = [1.0] if wrapper.N == 1 @@ -648,22 +824,46 @@ function test_make_zero!() @test a[1] === 0.0 # correct value end end - @testset "bring your own IdSet" begin + @testset "bring your own IdDict" begin a = [1.0] - seen = Base.IdSet() + seen = IdDict() make_zero!(a, seen) - @test a[1] === 0.0 # correct value - @test (a in seen) # object added to IdSet + @test a[1] === 0.0 # correct value + @test haskey(seen, a) # object added to IdDict + @test seen[a] === a # object points to zeroed value, i.e., itself end @testset "custom leaf type" begin a = [1.0] v = CustomVector(a) - # bringing own IdSet to avoid calling the custom method directly; + # bringing own IdDict to avoid calling the custom method directly; # it should still be invoked - @test_logs (:info, "make_zero!(::CustomVector)") make_zero!(v, Base.IdSet()) + @test_logs (:info, "make_zero!(::CustomVector)") make_zero!(v, IdDict()) @test v.data === a # preserved identity @test a[1] === 0.0 # correct value end + @testset "runtime inactive" begin + # verify that MutableWrapper is seen as active + a = MutableWrapper(1.0) + @assert !EnzymeRules.inactive_type(typeof(a)) + make_zero!(a) + @test a == 
MutableWrapper(0.0) + + # mark MutableWrapper as inactive + @eval @inline EnzymeRules.inactive_type(::Type{<:MutableWrapper}) = true + + # verify that MutableWrapper is seen as inactive + a.x = 1.0 + @invokelatest make_zero!(a) + @test a == MutableWrapper(1.0) + + # mark MutableWrapper as active again + @eval @inline EnzymeRules.inactive_type(::Type{<:MutableWrapper}) = false + + # verify that MutableWrapper is seen as active + a.x = 1.0 + @invokelatest make_zero!(a) + @test a == MutableWrapper(0.0) + end @testset "undefined fields/unassigned elements" begin @testset "array w inactive/active/mutable/unassigned" begin a = [1.0] @@ -693,7 +893,7 @@ function test_make_zero!() @test incomplete.w === a # preserved identity end @testset "Array{Tuple{struct w undefined}} (issue #1935)" begin - # old implementation triggered #1935 + # old implementation of make_zero! triggered #1935 # new implementation would work regardless due to limited use of justActive a = [1.0] incomplete = Incomplete("a", 1.0, a) @@ -719,7 +919,7 @@ function test_make_zero!() return nothing end -@testset "make_zero" test_make_zero() -@testset "make_zero!" test_make_zero!() +end # module RecursiveMapTests -end # module MakeZeroTests +@testset "make_zero" RecursiveMapTests.test_make_zero() +@testset "make_zero!" 
RecursiveMapTests.test_make_zero!() diff --git a/test/runtests.jl b/test/runtests.jl index 5c5d70d9fa..dc8e583788 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -75,7 +75,7 @@ include("abi.jl") include("typetree.jl") include("passes.jl") include("optimize.jl") -include("make_zero.jl") +include("recursive_maps.jl") include("rules.jl") include("rrules.jl") @@ -440,6 +440,25 @@ make3() = (1.0, 2.0, 3.0) da = [2.7] @test autodiff(Forward, sumdeepcopy, Duplicated(a, da))[1] ≈ 2.7 + # Nested containers to test nontrivial recursion in deepcopy reverse rule + b = [[3.14]] + db = [[0.0]] + sumdeepcopy_nested(x) = sum(sum, deepcopy(x)) + autodiff(Reverse, sumdeepcopy_nested, Duplicated(b, db)) + @test db[1][1] ≈ 1.0 + + c_inner = [3.14] + dc_inner = [0.0] + c = [c_inner, c_inner] + dc = [dc_inner, dc_inner] + autodiff(Reverse, sumdeepcopy_nested, Duplicated(c, dc)) + @test dc[1] === dc[2] + @test dc[1][1] ≈ 2.0 + + d = [(3.14,)] + dd = [(0.0,)] + autodiff(Reverse, sumdeepcopy_nested, Duplicated(d, dd)) + @test dd[1][1] ≈ 1.0 end @testset "Deferred and deferred thunk" begin @@ -533,94 +552,70 @@ end @test_throws MethodError autodiff(ReverseHolomorphic, mul3, Active, Active(z)) @test_throws MethodError autodiff(ReverseHolomorphic, mul3, Active{Complex}, Active(z)) - vals = Complex{Float64}[3.4 + 2.7im] - dvals = Complex{Float64}[0.0] - autodiff(ReverseHolomorphic, sum, Active, Duplicated(vals, dvals)) - @test vals[1] ≈ 3.4 + 2.7im - @test dvals[1] ≈ 1.0 + function reverse_holomorphic_array_tests( + f, val, dval_expected; val_expected=val, ret=Active, mapf=true + ) + vals = ComplexF64[val] + dvals = ComplexF64[zero(val)] + autodiff(ReverseHolomorphic, f, ret, Duplicated(vals, dvals)) + @test vals[1] ≈ val_expected + @test dvals[1] ≈ dval_expected - sumsq(x) = sum(x .* x) + # Use tuple to test out-of-place accumulate_seen! base case + tvals = [(ComplexF64(val),)] + dtvals = [(ComplexF64(zero(val)),)] + ft = mapf ? 
v -> first(map(f, v)) : f + autodiff(ReverseHolomorphic, ft, ret, Duplicated(tvals, dtvals)) + @test tvals[1][1] ≈ val_expected + @test dtvals[1][1] ≈ dval_expected + end - vals = Complex{Float64}[3.4 + 2.7im] - dvals = Complex{Float64}[0.0] - autodiff(ReverseHolomorphic, sumsq, Active, Duplicated(vals, dvals)) - @test vals[1] ≈ 3.4 + 2.7im - @test dvals[1] ≈ 2 * (3.4 + 2.7im) + @testset "sum" reverse_holomorphic_array_tests(sum, 3.4 + 2.7im, 1.0) + + sumsq(x) = sum(x .* x) + @testset "sumsq" reverse_holomorphic_array_tests(sumsq, 3.4 + 2.7im, 2(3.4 + 2.7im)) sumsq2(x) = sum(abs2.(x)) - vals = Complex{Float64}[3.4 + 2.7im] - dvals = Complex{Float64}[0.0] - autodiff(ReverseHolomorphic, sumsq2, Active, Duplicated(vals, dvals)) - @test vals[1] ≈ 3.4 + 2.7im - @test dvals[1] ≈ 2 * (3.4 + 2.7im) + @testset "sumsq2" reverse_holomorphic_array_tests(sumsq2, 3.4 + 2.7im, 2(3.4 + 2.7im)) sumsq2C(x) = Complex{Float64}(sum(abs2.(x))) - vals = Complex{Float64}[3.4 + 2.7im] - dvals = Complex{Float64}[0.0] - autodiff(ReverseHolomorphic, sumsq2C, Active, Duplicated(vals, dvals)) - @test vals[1] ≈ 3.4 + 2.7im - @test dvals[1] ≈ 3.4 - 2.7im - - sumsq3(x) = sum(x .* conj(x)) - vals = Complex{Float64}[3.4 + 2.7im] - dvals = Complex{Float64}[0.0] - autodiff(ReverseHolomorphic, sumsq3, Active, Duplicated(vals, dvals)) - @test vals[1] ≈ 3.4 + 2.7im - @test dvals[1] ≈ 3.4 - 2.7im - - sumsq3R(x) = Float64(sum(x .* conj(x))) - vals = Complex{Float64}[3.4 + 2.7im] - dvals = Complex{Float64}[0.0] - autodiff(ReverseHolomorphic, sumsq3R, Active, Duplicated(vals, dvals)) - @test vals[1] ≈ 3.4 + 2.7im - @test dvals[1] ≈ 2 * (3.4 + 2.7im) + @testset "sumsq2C" reverse_holomorphic_array_tests(sumsq2C, 3.4 + 2.7im, 3.4 - 2.7im) + + sumsq3(x) = sum(x .* conj.(x)) + @testset "sumsq3" reverse_holomorphic_array_tests(sumsq3, 3.4 + 2.7im, 3.4 - 2.7im) + + sumsq3R(x) = Float64(sum(x .* conj.(x))) + @testset "sumsq3R" reverse_holomorphic_array_tests(sumsq3R, 3.4 + 2.7im, 2(3.4 + 2.7im)) function setinact(z) 
- z[1] *= 2 + z[1] = 2 .* z[1] # works for both [x] and [(x,)] nothing end - - vals = Complex{Float64}[3.4 + 2.7im] - dvals = Complex{Float64}[0.0] - autodiff(ReverseHolomorphic, setinact, Const, Duplicated(vals, dvals)) - @test vals[1] ≈ 2 * (3.4 + 2.7im) - @test dvals[1] ≈ 0.0 - + @testset "setinact" reverse_holomorphic_array_tests( + setinact, 3.4 + 2.7im, 0.0; val_expected=2(3.4 + 2.7im), ret=Const, mapf=false + ) function setinact2(z) - z[1] *= 2 + z[1] = 2 .* z[1] # works for both [x] and [(x,)] return 0.0+1.0im end - - vals = Complex{Float64}[3.4 + 2.7im] - dvals = Complex{Float64}[0.0] - autodiff(ReverseHolomorphic, setinact2, Const, Duplicated(vals, dvals)) - @test vals[1] ≈ 2 * (3.4 + 2.7im) - @test dvals[1] ≈ 0.0 - - vals = Complex{Float64}[3.4 + 2.7im] - dvals = Complex{Float64}[0.0] - autodiff(ReverseHolomorphic, setinact2, Active, Duplicated(vals, dvals)) - @test vals[1] ≈ 2 * (3.4 + 2.7im) - @test dvals[1] ≈ 0.0 - + @testset "setinact2 Const" reverse_holomorphic_array_tests( + setinact2, 3.4 + 2.7im, 0.0; val_expected=2(3.4 + 2.7im), ret=Const, mapf=false + ) + @testset "setinact2 Active" reverse_holomorphic_array_tests( + setinact2, 3.4 + 2.7im, 0.0; val_expected=2(3.4 + 2.7im), ret=Active, mapf=false + ) function setact(z) - z[1] *= 2 - return z[1] + z[1] = 2 .* z[1] # works for both [x] and [(x,)] + return z[1][1] # returns scalar for both [x] and [(x,)] end - - vals = Complex{Float64}[3.4 + 2.7im] - dvals = Complex{Float64}[0.0] - autodiff(ReverseHolomorphic, setact, Const, Duplicated(vals, dvals)) - @test vals[1] ≈ 2 * (3.4 + 2.7im) - @test dvals[1] ≈ 0.0 - - vals = Complex{Float64}[3.4 + 2.7im] - dvals = Complex{Float64}[0.0] - autodiff(ReverseHolomorphic, setact, Active, Duplicated(vals, dvals)) - @test vals[1] ≈ 2 * (3.4 + 2.7im) - @test dvals[1] ≈ 2.0 + @testset "setact Const" reverse_holomorphic_array_tests( + setact, 3.4 + 2.7im, 0.0; val_expected=2(3.4 + 2.7im), ret=Const, mapf=false + ) + @testset "setact Active" 
reverse_holomorphic_array_tests( + setact, 3.4 + 2.7im, 2.0; val_expected=2(3.4 + 2.7im), ret=Active, mapf=false + ) function upgrade(z) z = ComplexF64(z)