-
Notifications
You must be signed in to change notification settings - Fork 89
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Attach rule to mapfoldl_impl
not foldl
#569
base: main
Are you sure you want to change the base?
Changes from all commits
63d7b14
78b8230
6ca4726
e17b0f1
c698574
1961aca
1b8e121
7b28e10
dd4f952
f18f61d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -417,137 +417,188 @@ end | |||||||||||||||||||||
end | ||||||||||||||||||||||
|
||||||||||||||||||||||
##### | ||||||||||||||||||||||
##### `foldl` | ||||||||||||||||||||||
##### `mapfoldl(f, g, ::Tuple)` | ||||||||||||||||||||||
##### | ||||||||||||||||||||||
|
||||||||||||||||||||||
using Base: mapfoldl_impl | ||||||||||||||||||||||
|
||||||||||||||||||||||
# For tuples there should be no harm in handling `map` first. | ||||||||||||||||||||||
# This will also catch `mapreduce`. | ||||||||||||||||||||||
|
||||||||||||||||||||||
function rrule( | ||||||||||||||||||||||
cfg::RuleConfig{>:HasReverseMode}, ::typeof(mapfoldl_impl), f::F, op::G, init, x::Tuple; | ||||||||||||||||||||||
) where {F,G} | ||||||||||||||||||||||
y, backmap = rrule(cfg, map, f, x) | ||||||||||||||||||||||
z, backred = rrule(cfg, Base.mapfoldl_impl, identity, op, init, y) | ||||||||||||||||||||||
function mapfoldl_pullback_tuple(dz) | ||||||||||||||||||||||
_, _, dop, dinit, dy = backred(dz) | ||||||||||||||||||||||
_, df, dx = backmap(dy) | ||||||||||||||||||||||
return (NoTangent(), df, dop, dinit, dx) | ||||||||||||||||||||||
end | ||||||||||||||||||||||
return z, mapfoldl_pullback_tuple | ||||||||||||||||||||||
end | ||||||||||||||||||||||
|
||||||||||||||||||||||
##### | ||||||||||||||||||||||
##### `foldl(f, ::Tuple)` | ||||||||||||||||||||||
##### | ||||||||||||||||||||||
|
||||||||||||||||||||||
# `foldl` guarantees to execute `f` in order, left to right. So it makes sense even when | ||||||||||||||||||||||
# this `f` is stateful, in which case the gradient must be calculated in the reverse order. | ||||||||||||||||||||||
# this `f` is stateful, in which case the gradient must be calculated in the reverse order. | ||||||||||||||||||||||
|
||||||||||||||||||||||
# The implementation aims to be efficient for both tuples and arrays, although using accumulate | ||||||||||||||||||||||
# to carry intermediate results along creates arrays of tuples which could be avoided; using a | ||||||||||||||||||||||
# loop can be a few times faster. Note also that it does not return a gradient for `init`. | ||||||||||||||||||||||
# The rule is attached to `Base.mapfoldl_impl` because this gets the `init` keyword as an argument, | ||||||||||||||||||||||
# which is handled below. For tuples, `reduce` also comes here. | ||||||||||||||||||||||
|
||||||||||||||||||||||
function rrule( | ||||||||||||||||||||||
config::RuleConfig{>:HasReverseMode}, ::typeof(foldl), op::G, x::Union{AbstractArray, Tuple}; | ||||||||||||||||||||||
init=_InitialValue() | ||||||||||||||||||||||
config::RuleConfig{>:HasReverseMode}, | ||||||||||||||||||||||
::typeof(Base.mapfoldl_impl), | ||||||||||||||||||||||
::typeof(identity), | ||||||||||||||||||||||
op::G, | ||||||||||||||||||||||
init::Base._InitialValue, | ||||||||||||||||||||||
x::Tuple; | ||||||||||||||||||||||
) where {G} | ||||||||||||||||||||||
list, start = if init === _InitialValue() | ||||||||||||||||||||||
_drop1(x), first(x) | ||||||||||||||||||||||
else | ||||||||||||||||||||||
# Case with init keyword is simpler to understand first! | ||||||||||||||||||||||
_reshape1(x, :), init # (vec is for Julia 1.0, accumulate is fussy) | ||||||||||||||||||||||
end | ||||||||||||||||||||||
hobbits = accumulate(list; init=(start, nothing)) do (a,_), b | ||||||||||||||||||||||
hobbits = accumulate(Base.tail(x); init=(first(x), nothing)) do (a, _), b | ||||||||||||||||||||||
# Here `a` is what we would normally cary forward, and `_` ignores | ||||||||||||||||||||||
# the previous iteration's pullback function (needed later), | ||||||||||||||||||||||
# while `b` is the fresh input from `list` as usual. | ||||||||||||||||||||||
c, back = rrule_via_ad(config, op, a, b) # LHS is just documentation here! | ||||||||||||||||||||||
c, back = rrule_via_ad(config, op, a, b) | ||||||||||||||||||||||
# We don't really need to store every `c`, last one is `foldl` output. | ||||||||||||||||||||||
# (The name, BTW, is because "there and back again" is the subtitle of Tolkien's book.) | ||||||||||||||||||||||
end | ||||||||||||||||||||||
y = first(last(hobbits)) | ||||||||||||||||||||||
axe = axes(x) | ||||||||||||||||||||||
project = ProjectTo(x) | ||||||||||||||||||||||
function unfoldl(dy) | ||||||||||||||||||||||
trio = accumulate(_reverse1(hobbits); init=(0, dy, 0)) do (_, dc, _), (_, back) | ||||||||||||||||||||||
function foldl_pullback_tuple(dy) | ||||||||||||||||||||||
trio = accumulate(reverse(hobbits); init=(0, dy, 0)) do (_, dc, _), (_, back) | ||||||||||||||||||||||
ds, da, db = back(dc) | ||||||||||||||||||||||
# Don't need to store every `da`, need one for the next iteration + maybe last | ||||||||||||||||||||||
# Don't need to store every `da`, need one for the next iteration + the last. | ||||||||||||||||||||||
end | ||||||||||||||||||||||
dop = sum(first, trio) | ||||||||||||||||||||||
dx = map(last, _reverse1(trio)) | ||||||||||||||||||||||
if init === _InitialValue() | ||||||||||||||||||||||
# `hobbits` is one short | ||||||||||||||||||||||
dx = _vcat1(trio[end][2], dx) | ||||||||||||||||||||||
end | ||||||||||||||||||||||
return (NoTangent(), dop, project(_reshape1(dx, axe))) | ||||||||||||||||||||||
dx = (trio[end][2], reverse(map(last, trio))...) | ||||||||||||||||||||||
return (NoTangent(), NoTangent(), ProjectTo(op)(dop), NoTangent(), project(dx)) | ||||||||||||||||||||||
end | ||||||||||||||||||||||
return y, unfoldl | ||||||||||||||||||||||
return y, foldl_pullback_tuple | ||||||||||||||||||||||
end | ||||||||||||||||||||||
|
||||||||||||||||||||||
function rrule( | ||||||||||||||||||||||
config::RuleConfig{>:HasReverseMode}, | ||||||||||||||||||||||
::typeof(Base.mapfoldl_impl), | ||||||||||||||||||||||
::typeof(identity), | ||||||||||||||||||||||
op::G, | ||||||||||||||||||||||
init, | ||||||||||||||||||||||
x::Tuple; | ||||||||||||||||||||||
) where {G} | ||||||||||||||||||||||
# Trivial case handled here to avoid ambiguities (and necc. because of Base.tail below) | ||||||||||||||||||||||
foldl_pullback_empty(dy) = (NoTangent(), NoTangent(), NoTangent(), dy, NoTangent()) | ||||||||||||||||||||||
isempty(x) && return init, foldl_pullback_empty | ||||||||||||||||||||||
|
||||||||||||||||||||||
# Treat `init` by simply appending it to the `x`: | ||||||||||||||||||||||
y, back = rrule(config, Base.mapfoldl_impl, identity, op, Base._InitialValue(), (init, x...)) | ||||||||||||||||||||||
project_x = ProjectTo(x) | ||||||||||||||||||||||
project_in = ProjectTo(init) | ||||||||||||||||||||||
function foldl_pullback_tuple_init(dy) | ||||||||||||||||||||||
_, _, dop, _, dxplus = back(dy) | ||||||||||||||||||||||
return (NoTangent(), NoTangent(), dop, project_in(first(dxplus)), project_x(Base.tail(dxplus))) | ||||||||||||||||||||||
end | ||||||||||||||||||||||
return y, foldl_pullback_tuple_init | ||||||||||||||||||||||
end | ||||||||||||||||||||||
|
||||||||||||||||||||||
##### | ||||||||||||||||||||||
##### Iterator-or-Tuple functions | ||||||||||||||||||||||
##### `foldl(f, ::Array)` | ||||||||||||||||||||||
##### | ||||||||||||||||||||||
|
||||||||||||||||||||||
# This zoo of underscore functions helps `foldl` & `accumulate` handle both tuples and arrays, | ||||||||||||||||||||||
# and also provides some alternatives for versions of Julia where iterators weren't supported. | ||||||||||||||||||||||
# Inspired by `Base._reverse`, used in defn of `foldr`. | ||||||||||||||||||||||
# The implementation was originally for both tuples and arrays, although using accumulate | ||||||||||||||||||||||
# to carry intermediate results along creates arrays of tuples which could be avoided. | ||||||||||||||||||||||
# Using a loop can be a few times faster, this should be replaced: | ||||||||||||||||||||||
# https://github.com/FluxML/Zygote.jl/issues/644#issuecomment-628762305 | ||||||||||||||||||||||
|
||||||||||||||||||||||
# To support 2nd derivatives, some may need their own gradient rules. And _drop1 should perhaps | ||||||||||||||||||||||
# be replaced by _peel1 like Iterators.peel | ||||||||||||||||||||||
# Note also that it does not return a gradient for `init`, now marked `@not_implemented`. | ||||||||||||||||||||||
|
||||||||||||||||||||||
_reverse1(x) = Iterators.reverse(x) | ||||||||||||||||||||||
_drop1(x) = Iterators.drop(x, 1) | ||||||||||||||||||||||
_zip2(x, y) = zip(x, y) # for `accumulate`, below | ||||||||||||||||||||||
|
||||||||||||||||||||||
_reverse1(x::Tuple) = reverse(x) | ||||||||||||||||||||||
_drop1(x::Tuple) = Base.tail(x) | ||||||||||||||||||||||
_zip2(x::Tuple{Vararg{Any,N}}, y::Tuple{Vararg{Any,N}}) where N = ntuple(i -> (x[i],y[i]), N) | ||||||||||||||||||||||
|
||||||||||||||||||||||
struct _InitialValue end # Old versions don't have `Base._InitialValue` | ||||||||||||||||||||||
function rrule( | ||||||||||||||||||||||
config::RuleConfig{>:HasReverseMode}, ::typeof(Base.mapfoldl_impl), ::typeof(identity), op::G, init, x::Union{AbstractArray, Tuple}; | ||||||||||||||||||||||
) where {G} | ||||||||||||||||||||||
start, list = if init === Base._InitialValue() | ||||||||||||||||||||||
Iterators.peel(x) | ||||||||||||||||||||||
else | ||||||||||||||||||||||
# Case with init keyword is simpler to understand first! | ||||||||||||||||||||||
init, x | ||||||||||||||||||||||
end | ||||||||||||||||||||||
hobbits = accumulate(list; init=(start, nothing)) do (a, _), b | ||||||||||||||||||||||
c, back = rrule_via_ad(config, op, a, b) | ||||||||||||||||||||||
end | ||||||||||||||||||||||
y = first(last(hobbits)) | ||||||||||||||||||||||
axe = axes(x) | ||||||||||||||||||||||
project = ProjectTo(x) | ||||||||||||||||||||||
function unfoldl(dy) | ||||||||||||||||||||||
trio = accumulate(Iterators.reverse(hobbits); init=(0, dy, 0)) do (_, dc, _), (_, back) | ||||||||||||||||||||||
ds, da, db = back(dc) | ||||||||||||||||||||||
end | ||||||||||||||||||||||
dop = sum(first, trio) | ||||||||||||||||||||||
dx = map(last, Iterators.reverse(trio)) | ||||||||||||||||||||||
if init === Base._InitialValue() # `hobbits` is one short | ||||||||||||||||||||||
dx = _vcat1(trio[end][2], dx) | ||||||||||||||||||||||
end | ||||||||||||||||||||||
d_init = @not_implemented "gradient for foldl does not at present include init, sorry" | ||||||||||||||||||||||
Comment on lines
+536
to
+539
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Would this work? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Probably! It's been a while, but my memory is that I mostly got tired of making tests, so thought I'd leave that for later. |
||||||||||||||||||||||
return (NoTangent(), NoTangent(), dop, d_init, project(reshape(dx, axe))) | ||||||||||||||||||||||
end | ||||||||||||||||||||||
return y, unfoldl | ||||||||||||||||||||||
end | ||||||||||||||||||||||
|
||||||||||||||||||||||
_vcat1(x, ys::AbstractVector) = vcat(x, ys) | ||||||||||||||||||||||
_vcat1(x::AbstractArray, ys::AbstractVector) = vcat([x], ys) | ||||||||||||||||||||||
_vcat1(x, ys::Tuple) = (x, ys...) | ||||||||||||||||||||||
|
||||||||||||||||||||||
_reshape1(x::AbstractArray, axe) = reshape(x, axe) | ||||||||||||||||||||||
_reshape1(x::Tuple, axe) = x | ||||||||||||||||||||||
|
||||||||||||||||||||||
_no_tuple_tangent(dx::Tangent) = ChainRulesCore.backing(dx) | ||||||||||||||||||||||
_no_tuple_tangent(dx) = dx | ||||||||||||||||||||||
|
||||||||||||||||||||||
|
||||||||||||||||||||||
##### | ||||||||||||||||||||||
##### `accumulate` | ||||||||||||||||||||||
##### | ||||||||||||||||||||||
|
||||||||||||||||||||||
# Like `foldl` this by definition works in order, so it makes sense to allow stateful `f`. | ||||||||||||||||||||||
|
||||||||||||||||||||||
# Also like `foldl`, the version with a keyword `init` can't easily be given a gradient. | ||||||||||||||||||||||
# Move it down to: `_accumulate!(op, B, A::AbstractVector, dims::Nothing, init::Nothing)` | ||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. But we don't at present support getting back a gradient for init, except if it's I am a little uncomfortable putting rules on mutating fuctions. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The intention was to move both to functions with positional Then it could have a gradient for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I do think this is unsafe the same way that The originally envisaged use case was that the 2nd derivative of |
||||||||||||||||||||||
|
||||||||||||||||||||||
function rrule( | ||||||||||||||||||||||
config::RuleConfig{>:HasReverseMode}, ::typeof(accumulate), op::G, x::Union{AbstractArray, Tuple}; | ||||||||||||||||||||||
init=_InitialValue(), dims=nothing | ||||||||||||||||||||||
config::RuleConfig{>:HasReverseMode}, | ||||||||||||||||||||||
::typeof(Base._accumulate!), | ||||||||||||||||||||||
op::G, y::AbstractVector, | ||||||||||||||||||||||
x::AbstractVector, | ||||||||||||||||||||||
dims::Nothing, | ||||||||||||||||||||||
init, | ||||||||||||||||||||||
) where {G} | ||||||||||||||||||||||
isnothing(dims) || dims == 1 && x isa Base.AbstractVecOrTuple || throw( | ||||||||||||||||||||||
"accumulate(op, x; dims) is not currently supported by ChainRules, sorry" | ||||||||||||||||||||||
# It's not supported by AD either, so no point calling back, and no regression: | ||||||||||||||||||||||
# gradient(x -> sum(accumulate(/, x, dims=1)), rand(3,4)) | ||||||||||||||||||||||
# ERROR: Mutating arrays is not supported | ||||||||||||||||||||||
) | ||||||||||||||||||||||
list, start = if init === _InitialValue() | ||||||||||||||||||||||
_drop1(x), first(x) | ||||||||||||||||||||||
|
||||||||||||||||||||||
start, list = if init === nothing | ||||||||||||||||||||||
Iterators.peel(x) | ||||||||||||||||||||||
else | ||||||||||||||||||||||
x, init | ||||||||||||||||||||||
something(init), x | ||||||||||||||||||||||
end | ||||||||||||||||||||||
hobbits = accumulate(list; init = (start, nothing)) do (a, _), b | ||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||
c, back = rrule_via_ad(config, op, a, b) | ||||||||||||||||||||||
end | ||||||||||||||||||||||
y = map(first, hobbits) | ||||||||||||||||||||||
if init === _InitialValue() | ||||||||||||||||||||||
if init === nothing | ||||||||||||||||||||||
# `hobbits` is one short, and first one doesn't invoke `op` | ||||||||||||||||||||||
y = _vcat1(first(x), y) | ||||||||||||||||||||||
y[1] = first(x) | ||||||||||||||||||||||
map!(first, @view(y[2:end]), hobbits) | ||||||||||||||||||||||
else | ||||||||||||||||||||||
map!(first, y, hobbits) | ||||||||||||||||||||||
end | ||||||||||||||||||||||
axe = axes(x) | ||||||||||||||||||||||
project = ProjectTo(x) | ||||||||||||||||||||||
function decumulate(dy) | ||||||||||||||||||||||
dy_plain = _no_tuple_tangent(unthunk(dy)) | ||||||||||||||||||||||
rev_list = if init === _InitialValue() | ||||||||||||||||||||||
# Here we rely on `zip` to stop early. Begin explicit with _reverse1(_drop1(...)) | ||||||||||||||||||||||
# gets "no method matching iterate(::Base.Iterators.Reverse{Base.Iterators.Drop{Array{" | ||||||||||||||||||||||
_zip2(_reverse1(hobbits), _reverse1(dy_plain)) | ||||||||||||||||||||||
else | ||||||||||||||||||||||
_zip2(_reverse1(hobbits), _reverse1(dy_plain)) | ||||||||||||||||||||||
end | ||||||||||||||||||||||
dy_plain = unthunk(dy) | ||||||||||||||||||||||
rev_list = zip(Iterators.reverse(hobbits), Iterators.reverse(dy_plain)) | ||||||||||||||||||||||
# Here we rely on `zip` to stop early when init === nothing. Begin explicit with Iterators.reverse(Iterators.drop(..., 1)) | ||||||||||||||||||||||
# gets "no method matching iterate(::Base.Iterators.Reverse{Base.Iterators.Drop{Array{" | ||||||||||||||||||||||
trio = accumulate(rev_list; init=(0, ZeroTangent(), 0)) do (_, dc, _), ((_, back), dz) | ||||||||||||||||||||||
ds, da, db = back(dc + dz) | ||||||||||||||||||||||
# Don't need to store every 'da', but need for next iteration, and the last one. | ||||||||||||||||||||||
end | ||||||||||||||||||||||
dop = sum(first, trio) | ||||||||||||||||||||||
dx = map(last, _reverse1(trio)) | ||||||||||||||||||||||
if init == _InitialValue() | ||||||||||||||||||||||
dx = map(last, Iterators.reverse(trio)) | ||||||||||||||||||||||
if init == nothing | ||||||||||||||||||||||
# `hobbits` is one short, and the first one is weird | ||||||||||||||||||||||
dx = _vcat1(trio[end][2] + dy_plain[1], dx) | ||||||||||||||||||||||
end | ||||||||||||||||||||||
return (NoTangent(), dop, project(_reshape1(dx, axe))) | ||||||||||||||||||||||
dy = @not_implemented "no gradient for `B` in `accumulate!(f, B, A)`, the rule intends to support `accumulate` only" | ||||||||||||||||||||||
d_init_not = @not_implemented "gradient for accumulate does not at present include init, sorry" | ||||||||||||||||||||||
d_init = init === nothing ? NoTangent() : Tangent{typeof(init)}(; value = d_init_not) | ||||||||||||||||||||||
return (NoTangent(), dop, dy, project(reshape(dx, axe)), NoTangent(), d_init) | ||||||||||||||||||||||
end | ||||||||||||||||||||||
return _reshape1(y, axe), decumulate | ||||||||||||||||||||||
return reshape(y, axe), decumulate | ||||||||||||||||||||||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there any way to not capture the accumulated outputs (
c
s) in the pullback? It seems easy enough for tuples usingmap
, but I'm unsure if the extra allocation would be welcomed for arrays.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, you can write a
for
loop like this: FluxML/Zygote.jl#644 (comment) . IMO this array method should probably be replaced, but not today.Carrying
c
by updating a variable from insideaccumulate
was very slow, IIRC it hits the closure issue.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I actually tried that in FluxML/Flux.jl#2003. The main challenges are nested differentiation and handling the case when
typeof(x |> f) != typeof(x |> f |> f)
(you must widen, which means preallocating an array is impossible withoutreturn_type
shenanigans).So, assuming type inference cooperates, the accumulate approach seems no less promising. Would there be any objections to a post-processing step like the following which allows the GC to clean up intermediate outputs before the pullback?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes the mutation has these problems, but was much quicker, maybe it can be used when safe.
The intention with writing
foldl
in terms ofaccumulate
was to allow for 2nd derivatives, but not sure this actually works right now.Re saving memory, we can add something like
unzip_accumulate(f, xs; init) = StructArrays.components(StructArray(Iterators.accumulate(f, xs; init)))
to free the bits we don't need anymore.But this PR would like to kick the can down the road on such improvements.
(And others -- it returns
@not_implemented
foraccumulate
's init, might be easy to do better, but tired of adding tests... at least it's no longer wrong.)