Skip to content

Commit

Permalink
Add recursive_map and base make_zero(!) on it
Browse files Browse the repository at this point in the history
  • Loading branch information
danielwe committed Oct 31, 2024
1 parent 2c8a581 commit 72bda99
Show file tree
Hide file tree
Showing 8 changed files with 1,277 additions and 565 deletions.
14 changes: 9 additions & 5 deletions ext/EnzymeStaticArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,15 @@ end
end
end

@inline function Enzyme.EnzymeCore.make_zero(x::FT)::FT where {FT<:SArray}
return Base.zero(x)
end
@inline function Enzyme.EnzymeCore.make_zero(x::FT)::FT where {FT<:MArray}
return Base.zero(x)
# SArrays and MArrays don't need special treatment for `make_zero(!)` to work or be correct,
# but in case their dedicated `zero` and `fill!` methods are more efficient than
# `make_zero(!)`s generic recursion, we opt into treating them as leaves when they have
# isbits eltypes (non-isbits eltypes excluded as the dedicated `zero` and `fill!` methods
# don't support those).
@inline function Enzyme.EnzymeCore.isvectortype(
::Type{<:Union{SArray{S,T},MArray{S,T}}}
) where {S,T}
return isbitstype(T) && Enzyme.Compiler.RecursiveMap.isscalartype(T)
end

end
82 changes: 75 additions & 7 deletions lib/EnzymeCore/src/EnzymeCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -501,28 +501,96 @@ function autodiff_thunk end
function autodiff_deferred_thunk end

"""
make_zero(prev::T, ::Val{copy_if_inactive}=Val(false))::T
make_zero(::Type{T}, seen::IdDict, prev::T, ::Val{copy_if_inactive}=Val(false))::T
Recursively make a zero'd copy of the value `prev` of type `T`. The argument `copy_if_inactive` specifies
what to do if the type `T` is guaranteed to be inactive, use the primal (the default) or still copy the value.
Extending this method for custom types is rarely needed. For new plain array types like GPU
arrays, extending [`isvectortype`](@ref) is sufficient as long as the array type implements
`Base.zero`.
"""
function make_zero end

"""
make_zero!(val::T, seen::IdSet{Any}=IdSet())::Nothing
make_zero!(val::T, seen::IdDict=IdDict())::Nothing
Recursively set a variables differentiable fields to zero. Only applicable for types `T`
that are mutable or hold all differentiable values in mutable containers (e.g.,
`Tuple{Vector{Float64}}`).
Recursively set a variables differentiable fields to zero. Only applicable for mutable types `T`.
Extending this method for custom types is rarely needed. For new plain mutable array types
like GPU arrays, extending [`isvectortype`](@ref) is sufficient as long as the array type
implements `Base.zero` and `Base.fill!`.
"""
function make_zero! end

"""
make_zero(prev::T)
isvectortype(::Type{T})::Bool
Trait defining types whose values should be considered leaf nodes when [`make_zero`](@ref)
and [`make_zero!`](@ref) recurse through an object.
By default, `isvectortype(T) == true` for `T` such that `isscalartype(T) == true` or
`T <: Union{Array{U},GenericMemory{_,U}}` where `isscalartype(U) == true`.
A new plain array type, for example a GPU array, may extend this as follows:
```julia
@inline EnzymeCore.isvectortype(::Type{<:GPUArray{U}}) where {U} = isscalartype(U)
```
Such a type should implement `Base.zero` and, if mutable, `Base.fill!`. (If this is not
feasible, an alternative is to add methods `EnzymeCore.make_zero(arr::T)::T` and, if
mutable, `EnzymeCore.make_zero!(arr::T)::Nothing`; such methods will also be picked up by
recursive calls.)
Helper function to recursively make zero.
Such extensions are mostly relevant for the lowest-level of abstraction of memory at which
vector space operations like addition and scalar multiplication are supported, the
prototypical case being `Array`. Regular Julia structs with vector space-like semantics
should normally not extend `isvectorspace`; `make_zero(!)` will recurse into them and act
directly on their backing arrays, just like how Enzyme treats them when differentiating. For
example, structured matrix wrappers and sparse array types that are backed by `Array` should
not extend `isvectortype`.
See also [`isscalartype`](@ref).
"""
@inline function make_zero(prev::T, ::Val{copy_if_inactive}=Val(false)) where {T, copy_if_inactive}
make_zero(Core.Typeof(prev), IdDict(), prev, Val(copy_if_inactive))
end
function isvectortype end

"""
isscalartype(::Type{T})::Bool
Trait defining a subset of [`isvectortype`](@ref) types that should not be considered
composite, such that even if the type is mutable, [`make_zero!`](@ref) will not try to zero
values of the type in-place. For example, `BigFloat` is a mutable type but does not support
in-place mutation through any Julia API, and `isscalartype(BigFloat) == true` ensures that
`make_zero!` will not try to mutate `BigFloat` values.[^BigFloat]
By default, `isscalartype(T) == true` for `T <: AbstractFloat` and
`T <: Complex{<:AbstractFloat}`.
A hypothetical new real number type with Enzyme support should in most cases simply subtype
`AbstractFloat` and inherit the `isscalartype` trait that way. If this is not appropriate,
the function can be extended as follows:
```julia
@inline EnzymeCore.isscalartype(::Type{<:NewReal}) = true
@inline EnzymeCore.isscalartype(::Type{<:Complex{<:NewReal}}) = true
```
In either case, the type should implement `Base.zero`. (If this is not feasible, an
alternative is to add a method `EnzymeCore.make_zero(x::T)::T`; such a method will also be
picked up by recursive calls.)
See also [`isvectortype`](@ref).
[^BigFloat]: Enzyme does not support differentiating `BigFloat` as of this writing; it is
mentioned here only to illustrate that it would be inappropriate to use traits like
`ismutable` or `isbitstype` to choose between in-place and out-of-place zeroing,
demonstrating the need for a dedicated `isscalartype` trait.
"""
function isscalartype end

function tape_type end

Expand Down
3 changes: 1 addition & 2 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1132,8 +1132,6 @@ struct Tape{TapeTy,ShadowTy,ResT}
shadow_return::ShadowTy
end

include("make_zero.jl")

function nested_codegen!(mode::API.CDerivativeMode, mod::LLVM.Module, f, tt, world)
funcspec = my_methodinstance(typeof(f), tt, world)
nested_codegen!(mode, mod, funcspec, world)
Expand Down Expand Up @@ -7610,6 +7608,7 @@ end
end

# Recursively return x + f(y), where y is active, otherwise x
include("recursive_map.jl")

@inline function recursive_add(
x::T,
Expand Down
Loading

0 comments on commit 72bda99

Please sign in to comment.