Add more Duplicated methods for Enzyme.jl support (#2471)
* add more Duplicated methods

* update macro to zero, show

* make informative errors if you use Duplicated without loading Enzyme

* note on macro

* fix some tests

* add an Enzyme docs page

* tweaks & tests

* typos

* news, docs

* let Flux own the function update! to avoid piracy

* Revert "let Flux own the function update! to avoid piracy"

This reverts commit ca5a20f.

* demand Optimisers PR

* fixup

* force depwarns

* allow aux in withgradient

* disallow Active

* disallow trivial Duplicated

* don't use ReverseWithPrimal in gradient

* tweak

* giant post-rebase fixup after everything was moved around... all earlier commits are a mess now, probably

* clean up more rebase mess

* fix docs

* try out Ref for withgradient

* don't own `_make_zero!`

* add explicit errors for 2nd order

* more rebase problems

Co-authored-by: Carlo Lucibello <[email protected]>

* teach Flux.state about Duplicated

* another explicit error for Zygote mistake

* ahem

* don't use Enzyme's make_zero!, fix some bugs

* maybe this works?

* see if CI likes these

* turns out train! does have tests

* enzyme tests

* fix tests?

* minor comments

---------

Co-authored-by: Carlo Lucibello <[email protected]>
mcabbott and CarloLucibello authored Dec 3, 2024
1 parent cb76e9d commit f5d25e5
Showing 15 changed files with 633 additions and 36 deletions.
6 changes: 4 additions & 2 deletions NEWS.md
@@ -2,7 +2,7 @@

See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a complete list of PRs merged before each release.

## v0.15.0
## v0.15.0 (December 2024)
* Recurrent layers have undergone a complete redesign in [PR 2500](https://github.com/FluxML/Flux.jl/pull/2500).
* `RNNCell`, `LSTMCell`, and `GRUCell` are now exported and provide functionality for single time-step processing: `rnncell(x_t, h_t) -> h_{t+1}`.
* `RNN`, `LSTM`, and `GRU` no longer store the hidden state internally; it has to be explicitly passed to the layer. Moreover, they now process entire sequences at once, rather than one element at a time: `rnn(x, h) -> h′`.
@@ -12,6 +12,8 @@ See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a complete list of PRs merged before each release.
Now Flux re-exports the optimisers from Optimisers.jl. Most users will be unaffected by this change.
The module is still available for now, but will be removed in a future release.
* Most Flux layers will [re-use memory via `NNlib.bias_act!`](https://github.com/FluxML/Flux.jl/pull/2327), when possible.
* Further support for Enzyme.jl, via methods of `Flux.gradient(loss, Duplicated(model))`.
Flux now owns & exports `gradient`, but without `Duplicated` this still defaults to calling Zygote.jl.
* `Flux.params` has been deprecated. Use Zygote's explicit differentiation instead,
`gradient(m -> loss(m, x, y), model)`, or use `Flux.trainables(model)` to get the trainable parameters.
* Flux now requires Functors.jl v0.5. This new release of Functors assumes all types to be functors by default. Therefore, applying `@layer` or `@functor` to a type is no longer strictly necessary for Flux's models. However, it is still recommended to use `@layer Model` for additional functionality like pretty printing.
@@ -40,7 +42,7 @@ See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a complete list of PRs merged before each release.
* After a deprecations cycle, the macro `@epochs` and the functions `Flux.stop`, `Flux.skip`, `Flux.zeros`, `Flux.ones` have been removed.

## v0.13.17
* Apple's Metal GPU acceleration preliminary support via the extension mechanism.
* Apple's Metal GPU acceleration preliminary support via the extension mechanism.

## v0.13.16
* Most greek-letter keyword arguments are deprecated in favour of ascii.
4 changes: 3 additions & 1 deletion Project.toml
@@ -6,6 +6,7 @@ version = "0.15.0-DEV"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
@@ -48,14 +49,15 @@ ChainRulesCore = "1.12"
Compat = "4.10.0"
Enzyme = "0.13"
Functors = "0.5"
EnzymeCore = "0.7.7, 0.8.4"
MLDataDevices = "1.4.2"
MLUtils = "0.4"
MPI = "0.20.19"
MacroTools = "0.5"
NCCL = "0.1.1"
NNlib = "0.9.22"
OneHotArrays = "0.2.4"
Optimisers = "0.4"
Optimisers = "0.4.1"
Preferences = "1"
ProgressLogging = "0.1"
Reexport = "1.0"
1 change: 1 addition & 0 deletions docs/make.jl
@@ -36,6 +36,7 @@ makedocs(
"Flat vs. Nested" => "reference/destructure.md",
"Callback Helpers" => "reference/training/callbacks.md",
"Gradients -- Zygote.jl" => "reference/training/zygote.md",
"Gradients -- Enzyme.jl" => "reference/training/enzyme.md",
"Transfer Data to GPU -- MLDataDevices.jl" => "reference/data/mldatadevices.md",
"Batching Data -- MLUtils.jl" => "reference/data/mlutils.md",
"OneHotArrays.jl" => "reference/data/onehot.md",
95 changes: 95 additions & 0 deletions docs/src/reference/training/enzyme.md
@@ -0,0 +1,95 @@

# [Automatic Differentiation using Enzyme.jl](@id autodiff-enzyme)

[Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) is a new package for automatic differentiation.
Like Zygote.jl, calling `gradient(f, x)` causes it to hook into the compiler and transform the code that is executed while calculating `f(x)`, in order to produce code for `∂f/∂x`.
But it does so much later in the optimisation process (on LLVM IR instead of Julia's untyped IR), which you can [read about here](https://proceedings.nips.cc/paper/2020/file/9332c513ef44b682e9347822c2e457ac-Paper.pdf).
It needs far fewer custom rules than Zygote/ChainRules, and in particular is able to support mutation of arrays.
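
For instance, here is a tiny sketch (using Enzyme directly, with hypothetical names, and assuming `using Flux, Enzyme`) of differentiating a function that mutates a buffer, something Zygote's rules cannot handle:

```julia
julia> function sum_of_squares!(buffer, x)
           buffer .= x .^ 2   # in-place mutation
           sum(buffer)
       end;

julia> x = randn32(5);

julia> Enzyme.gradient(Reverse, x -> sum_of_squares!(similar(x), x), x)[1] ≈ 2 .* x  # d/dx sum(x.^2) = 2x
true
```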

Flux now builds in support for this, using Enzyme's own `Duplicated` type.
Calling `Duplicated` on any Flux model which was defined using `@layer` will allocate space for the gradient,
and passing that to `gradient` (or `withgradient`, or `train!`) will then use Enzyme instead of Zygote.
The gradient functions still return the gradient as usual, which can then be passed to `update!`:

```julia
julia> using Flux, Enzyme

julia> model = Chain(Dense(28^2 => 32, sigmoid), Dense(32 => 10), softmax); # from model zoo

julia> dup_model = Enzyme.Duplicated(model) # this allocates space for the gradient
Duplicated(
Chain(
Dense(784 => 32, σ), # 25_120 parameters
Dense(32 => 10), # 330 parameters
NNlib.softmax,
),
# norm(∇) ≈ 0.0f0
) # Total: 4 arrays, 25_450 parameters, 199.391 KiB.

julia> x1 = randn32(28*28, 1); # fake image

julia> y1 = [i==3 for i in 0:9]; # fake label

julia> grads_f = Flux.gradient((m,x,y) -> sum(abs2, m(x) .- y), dup_model, Const(x1), Const(y1)) # uses Enzyme
((layers = ((weight = Float32[-0.010354728 0.032972857
-0.0014538406], σ = nothing), nothing),), nothing, nothing)
```

The gradient returned here is also stored within `dup_model`.
Both share the same arrays -- what is returned is not a copy, just a view of the same memory (wrapped in `NamedTuple`s instead of `struct`s).
They will all be set to zero when you call `gradient` again, then replaced with the new values.
Alternatively, `gradient(f, args...; zero=false)` will add the new gradient to what's already stored.
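
For example, here is a minimal sketch (assuming `dup_model`, `x1`, `y1` from above, plus a hypothetical second batch `x2`, `y2`) of accumulating the gradient over two batches with `zero=false`:

```julia
julia> x2, y2 = randn32(28*28, 1), [i==5 for i in 0:9];  # hypothetical second batch

julia> loss(m, x, y) = sum(abs2, m(x) .- y);

julia> Flux.gradient(loss, dup_model, Const(x1), Const(y1));  # zeroes dup_model.dval, then stores this gradient

julia> Flux.gradient(loss, dup_model, Const(x2), Const(y2); zero=false);  # adds the second batch's gradient on top
```

After the second call, `dup_model.dval` holds the sum of the two per-batch gradients.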

Writing `Const(x1)` is optional; a plain `x1` is treated as implicitly constant.
Any set of `Duplicated` and `Const` arguments may appear in any order, so long as there is at least one `Duplicated`.
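
As a further sketch (again assuming the variables from above), any argument can be marked `Duplicated`, not just the model. Here the input image is too, so its gradient is returned alongside the model's:

```julia
julia> dup_x1 = Duplicated(x1, Enzyme.make_zero(x1));  # space for ∂loss/∂x1

julia> grads = Flux.gradient((m,x,y) -> sum(abs2, m(x) .- y), dup_model, dup_x1, Const(y1));

julia> size(grads[2])  # gradient with respect to x1; grads[3] is nothing, as y1 is Const
(784, 1)
```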

The gradient `grads_f[1]` can be passed to `update!` as usual.
But for convenience, you may also use what is stored within `Duplicated`.
These are equivalent ways to perform an update step:

```julia
julia> opt_state = Flux.setup(Adam(), model)

julia> ans == Flux.setup(Adam(), dup_model)

julia> Flux.update!(opt_state, model, grads_f[1]) # exactly as for Zygote gradients

julia> Flux.update!(opt_state, dup_model) # equivalent new path, Enzyme only
```

Instead of using these Flux functions, you can also use Enzyme's own functions directly.
`Enzyme.gradient` works like this:

```julia
julia> grads_e = Enzyme.gradient(Reverse, (m,x,y) -> sum(abs2, m(x) .- y), model, Const(x1), Const(y1))
(Chain(Dense(784 => 32, σ), Dense(32 => 10), softmax), nothing, nothing)

julia> grads_f[1].layers[2].bias ≈ grads_e[1].layers[2].bias
true
```

Note that what `Enzyme.gradient` returns is an object like `deepcopy(model)`, of the same type: `grads_e[1] isa Chain`.
But its fields contain the same gradient values as `grads_f[1]`.
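
A quick sketch of the difference in structure (assuming `grads_e` and `grads_f` from above):

```julia
julia> grads_e[1] isa Chain  # Enzyme.gradient returns a model-shaped object
true

julia> grads_f[1] isa NamedTuple  # Flux.gradient returns Zygote-style nested NamedTuples
true
```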

There is also a method of `train!` which similarly takes `Duplicated(model)`:

```julia
julia> opt_state = Flux.setup(Adam(0), model);

julia> Flux.train!((m,x,y) -> sum(abs2, m(x) .- y), dup_model, [(x1, y1)], opt_state)
```

## Second-order AD

If you calculate a gradient within the loss function, then training will involve 2nd derivatives.
While this is in principle supported by Zygote.jl, there are many bugs, and Enzyme.jl is probably a better choice.
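
For concreteness, here is a purely illustrative sketch of such a loss (a hypothetical gradient-penalty term; no claim is made here about which AD combinations handle training on it):

```julia
# The inner `gradient` call means that differentiating penalised_loss
# during training requires second derivatives.
function penalised_loss(m, x, y)
    inner = Flux.gradient(x -> sum(abs2, m(x) .- y), x)[1]  # ∂(data loss)/∂x
    sum(abs2, m(x) .- y) + 0.1f0 * sum(abs2, inner)
end
```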

## Listing

```@docs
Flux.gradient(f, args::Union{Flux.EnzymeCore.Const, Flux.EnzymeCore.Duplicated}...)
Flux.withgradient(f, args::Union{Flux.EnzymeCore.Const, Flux.EnzymeCore.Duplicated}...)
Flux.train!(loss, model::Flux.EnzymeCore.Duplicated, data, opt)
```

Enzyme.jl has [its own extensive documentation](https://enzymead.github.io/Enzyme.jl/stable/).
4 changes: 3 additions & 1 deletion docs/src/reference/training/zygote.md
@@ -4,8 +4,10 @@ CollapsedDocStrings = true

# [Automatic Differentiation using Zygote.jl](@id autodiff-zygote)

Flux re-exports the `gradient` from [Zygote](https://github.com/FluxML/Zygote.jl), and uses this function within [`train!`](@ref Flux.train!) to differentiate the model. Zygote has its own [documentation](https://fluxml.ai/Zygote.jl/dev/), in particular listing some [important limitations](https://fluxml.ai/Zygote.jl/dev/limitations/).
Flux's `gradient` function uses [Zygote](https://github.com/FluxML/Zygote.jl) by default, and also uses this function within [`train!`](@ref Flux.train!) to differentiate the model.
Zygote has its own [documentation](https://fluxml.ai/Zygote.jl/dev/), in particular listing some [important limitations](https://fluxml.ai/Zygote.jl/dev/limitations/).

Flux also has support for Enzyme.jl, documented [on its own page](@ref autodiff-enzyme).

## Explicit style

109 changes: 97 additions & 12 deletions ext/FluxEnzymeExt/FluxEnzymeExt.jl
@@ -1,32 +1,117 @@
module FluxEnzymeExt

using Flux
import Flux.Train: train!, _rule_to_state
import Flux.Train: _enzyme_train!

import Optimisers
import Functors
import Enzyme
using Enzyme: EnzymeRules, Active, Const, Duplicated, autodiff, ReverseWithPrimal
using Enzyme: EnzymeRules, Active, Const, Duplicated, autodiff, ReverseWithPrimal, DuplicatedNoNeed
using Enzyme: autodiff_thunk, Reverse, ReverseSplitWithPrimal
using ProgressLogging: @withprogress, @logprogress

_make_zero_internal!(x::AbstractArray) = fill!(x, 0)
_make_zero_internal!(x) = x
_make_zero!(model) = fmap(_make_zero_internal!, model)
EnzymeRules.inactive(::typeof(Flux.Losses._check_sizes), args...) = true

_applyloss(loss, model, d...) = loss(model, d...)
### gradient & withgradient

EnzymeRules.inactive(::typeof(Flux.Losses._check_sizes), args...) = true
# We can't use Enzyme.make_zero! to reset Duplicated, as it complains about e.g. LayerNorm having immutable differentiable fields
# After https://github.com/EnzymeAD/Enzyme.jl/pull/1961 probably this can be `make_zero!(Ref(dup.dval))`
_make_zero!(model) = Functors.fmapstructure(_make_zero_inner!, model)
function _make_zero_inner!(x::AbstractArray{<:Number})
    Optimisers.isnumeric(x) || return
    Optimisers.maywrite(x) || error("can't handle this")
    fill!(x, zero(eltype(x)))
    nothing
end
_make_zero_inner!(x) = nothing # any other Functors leaf type

#= # This _make_zero! matches what Flux allows elsewhere:
julia> Flux.setup(Adam(), (1:3.)')
ERROR: model must be fully mutable for `train!` to work, got `x::StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}`.
If `x .+= dx` is in fact ok, define `Optimisers.maywrite(::StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}) = true`
=#
# Perhaps canonical way for Enzyme is more like this:
# function _make_zero!(x::AbstractArray{<:Number})
# if Enzyme.guess_activity(typeof(x), Reverse) <: Duplicated
# fill!(x, zero(eltype(x)))
# elseif Enzyme.guess_activity(typeof(x), Reverse) <: Const
# # that's OK
# else
# error("not sure what it should do for Active?")
# end
# end

function train!(loss, model::Duplicated, data, rule::Optimisers.AbstractRule; cb = nothing)
train!(loss, model, data, _rule_to_state(model, rule); cb)
function Flux._enzyme_gradient(f, args::Union{Const, Duplicated}...; zero::Bool=true)
    for x in args
        zero && x isa Duplicated && _make_zero!(x.dval)
        _check_mutable(x)
    end
    Enzyme.autodiff(Reverse, Const(f), Active, args...)
    map(_grad_or_nothing, args)
end

function train!(loss, model::Duplicated, data, opt; cb = nothing)
_check_mutable(x::Const) = nothing
_check_mutable(x::Duplicated) = Functors.anymutable(x) || error(
    """`Flux.gradient(f, Duplicated(x), ...)` expects `x` to contain mutable parameter arrays."""
)

# This function strips the returned gradient to be Zygote-like:
_grad_or_nothing(dup::Duplicated) = Flux.fmapstructure(_grad_or_nothing, dup.dval; prune=nothing)
_grad_or_nothing(::Const) = nothing
_grad_or_nothing(x) = Optimisers.isnumeric(x) ? x : nothing

function Flux._enzyme_withgradient(f, args::Union{Const, Duplicated}...; zero::Bool=true)
    for x in args
        zero && x isa Duplicated && _make_zero!(x.dval)
        _check_mutable(x)
    end

    # Take I, doesn't allow for aux at all.
    # _, val = Enzyme.autodiff(ReverseWithPrimal, f, Active, args...)

    # Take II, using split mode.
    forward, reverse = autodiff_thunk(ReverseSplitWithPrimal, Const{typeof(f)}, Active, map(typeof, args)...)
    tape, result, shadow_result = forward(Const(f), args...)
    reverse(Const(f), args..., _sensitivity(result), tape)

    # Take III, it may be more efficient to have the function write the loss into Ref(0.0)?
    # dup_loss = DuplicatedNoNeed(Ref(0f0), Ref(1f0))
    # # result = autodiff(Reverse, Const(_ref_loss!), Const, dup_loss, Const(f), args...)
    # _, result = autodiff(ReverseWithPrimal, Const(_ref_loss!), Const, dup_loss, Const(f), args...)

    (; val = result, grad = map(_grad_or_nothing, args))
end

@inline _sensitivity(y::Real) = one(y)
@inline _sensitivity(ys::Tuple{Real,Vararg}) = (one(ys[1]), Enzyme.make_zero(Base.tail(ys))...)
@inline _sensitivity(ys::NamedTuple{S, <:Tuple{Real,Vararg}}) where S = NamedTuple{S}(_sensitivity(Tuple(ys)))
_sensitivity(y) = error("""`Flux.withgradient(f, xs...)` expects that `y = f(xs...)` is a real number,
    or else a Tuple or NamedTuple whose first element is a real number.""")

function _ref_loss!(out::Ref, f, args...)  # for Take III above
    val = f(args...)
    out[] = _get_loss(val)  # saves loss by mutation
    val  # returns the whole thing
end

@inline _get_loss(y::Real) = y
@inline _get_loss(ys::Tuple{Real,Vararg}) = ys[1]
@inline _get_loss(ys::NamedTuple{S, <:Tuple{Real,Vararg}}) where S = ys[1]
_get_loss(y) = error("""`Flux.withgradient(f, xs...)` expects that `y = f(xs...)` is a real number,
    or else a Tuple or NamedTuple whose first element is a real number.""")

### Flux.Train, for train!

_applyloss(loss, model, d...) = loss(model, d...)

function _enzyme_train!(loss, model::Duplicated, data, opt; cb = nothing)
isnothing(cb) || error("""train! does not support callback functions.
For more control use a loop with `gradient` and `update!`.""")
@withprogress for (i,d) in enumerate(data)
d_splat = d isa Tuple ? d : (d,)

_make_zero!(model.dval)
_, l = Enzyme.autodiff(ReverseWithPrimal, _applyloss,
_, l = Enzyme.autodiff(ReverseWithPrimal, _applyloss,
Active, Const(loss), model, map(Const, d_splat)...)

if !isfinite(l)
@@ -39,4 +124,4 @@ function train!(loss, model::Duplicated, data, opt; cb = nothing)
end
end

end # FluxEnzymeExt
end # FluxEnzymeExt
16 changes: 10 additions & 6 deletions src/Flux.jl
@@ -10,16 +10,19 @@ using MacroTools: @forward
@reexport using NNlib
using NNlib: conv, ∇conv_data, depthwiseconv, output_size
using MLUtils
using Adapt, OneHotArrays
using Functors: Functors, fmap, fmapstructure

using Optimisers: Optimisers, destructure, freeze!, thaw!, adjust!, trainables, update!
import Optimisers: trainable
@reexport using Optimisers

using Random: default_rng

using Zygote, ChainRulesCore
using Zygote: @adjoint, gradient, pullback
using Zygote: @adjoint, pullback
using Zygote.ForwardDiff: value
export gradient
using EnzymeCore: EnzymeCore

@reexport using MLDataDevices: MLDataDevices, supported_gpu_backends, reset_gpu_device!,
default_device_rng,
@@ -53,11 +56,12 @@ export Chain, Dense, Embedding, EmbeddingBag,
# utils
outputsize, state, create_bias, @layer,
# from OneHotArrays.jl
onehot, onehotbatch, onecold,
onehot, onehotbatch, onecold,
# from Train
setup, train!,
# from Optimsers.jl
destructure, freeze!, thaw!, adjust!, trainables, update!, trainable,
withgradient,
# init
glorot_uniform,
glorot_normal,
@@ -89,13 +93,13 @@ export Chain, Dense, Embedding, EmbeddingBag,
tversky_loss,
))

include("gradient.jl")
export gradient

include("train.jl")
using .Train
using .Train: setup

using Adapt, OneHotArrays
using Functors: Functors, fmap, fmapstructure

include("utils.jl")
include("functor.jl")
