Attach rule to mapfoldl_impl, not foldl #569
base: main
Conversation
What are the TODOs that would make this no longer [draft]?
I have forgotten. But one was to decide how unhappy we are about this, and in general about hooking deep inside Base. I didn't see a nicer way to hook on.
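(For context, a minimal sanity check -- using 1.8-era internals, which could change -- that Base.mapfoldl_impl is the positional-argument funnel behind foldl, with the keyword init absorbed before it:)

julia> x = Tuple(randn(3));

julia> foldl(/, x) == Base.mapfoldl_impl(identity, /, Base._InitialValue(), x)
true

julia> foldl(/, x; init=1.0) == Base.mapfoldl_impl(identity, /, 1.0, x)
true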
Following up on this, could the rule be opted out of?
fff84b5 separates out the Tuple method, for which this way of writing the rule makes more sense. It is, strangely, much slower than the tagged version inside AD. And it breaks Yota. Benchmarks before and after:
julia> xn = Tuple(randn(10));
julia> xm = Tuple(rand(10,10) for _ in 1:10);
# Zygote
julia> @btime Zygote.gradient(x -> foldl(/, x), $xn);
min 47.318 ns, mean 77.223 ns (1 allocation, 16 bytes) # before
min 4.381 μs, mean 4.692 μs (37 allocations, 2.16 KiB) # after -- 100x slower
julia> @btime Zygote.gradient(x -> sum(abs2, foldl(*, x)), $xm);
min 17.667 μs, mean 77.964 μs (53 allocations, 29.52 KiB) # before
min 19.708 μs, mean 23.791 μs (69 allocations, 27.48 KiB) # after
julia> @btime Zygote.gradient(x -> Base.afoldl(/, x...), $xn); # no rule -- much slower
min 130.500 μs, mean 135.423 μs (413 allocations, 16.33 KiB)
julia> @btime Zygote.gradient(x -> sum(abs2, Base.afoldl(*, x...)), $xm);
min 143.500 μs, mean 151.413 μs (384 allocations, 40.30 KiB)
# Diffractor
julia> @btime Diffractor.gradient(x -> foldl(/, x), $xn);
min 29.271 ns, mean 30.017 ns (0 allocations) # before
min 350.632 ns, mean 400.959 ns (6 allocations, 672 bytes) # after -- 10x slower
julia> @btime Diffractor.gradient(x -> sum(abs2, foldl(*, x)), $xm);
min 13.666 μs, mean 16.422 μs (29 allocations, 25.38 KiB) # before
min 162.584 μs, mean 218.275 μs (357 allocations, 168.42 KiB); # after
julia> @btime Diffractor.gradient(x -> Base.afoldl(/, x...), $xn); # no rule -- better than Zygote
min 352.882 ns, mean 419.163 ns (6 allocations, 672 bytes)
julia> @btime Diffractor.gradient(x -> sum(abs2, Base.afoldl(/, x...)), $xm)
min 163.125 μs, mean 204.721 μs (357 allocations, 168.42 KiB)
# Yota
julia> @btime Yota.grad(x -> foldl(/, x), $xn);
min 182.790 ns, mean 657.142 ns (3 allocations, 208 bytes) # before
ERROR: No deriative rule found for op %3 = foldl(/, %2)::Float64, try defining it... # after -- fails
julia> @btime Yota.grad(x -> sum(abs2, foldl(*, x)), $xm);
min 8.583 μs, mean 50.186 μs (21 allocations, 16.19 KiB)
julia> Yota.grad(x -> Base.afoldl(/, x...), xn);
ERROR: syntax: Slot objects should not occur in an AST
# Checking pieces?
julia> yyy = Yota.YotaRuleConfig()
julia> @code_warntype rrule(yyy, foldl, /, xn) # before
julia> @code_warntype rrule(yyy, foldl, /, xn)[2](1.0)
julia> @code_warntype rrule(yyy, Base.mapfoldl_impl, identity, /, Base._InitialValue(), xn) # after
julia> @code_warntype rrule(yyy, Base.mapfoldl_impl, identity, /, Base._InitialValue(), xn)[2](1.0)
julia> @btime rrule($yyy, foldl, /, $xn)[2](1.0);
min 29.271 ns, mean 30.036 ns (0 allocations)
julia> @btime rrule($yyy, Base.mapfoldl_impl, identity, /, Base._InitialValue(), $xn)[2](1.0);
min 29.271 ns, mean 29.753 ns (0 allocations)
I don't see how. I think you can only opt out of functions which have rules, and those need to be called to work.
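(For reference: opting out is spelled ChainRulesCore.@opt_out, and it declines an existing rule for one signature so AD recurses into the code instead -- which is why there must be a rule to decline. A hypothetical sketch, with MyIter an invented type:)

using ChainRulesCore

struct MyIter                 # invented container type, for illustration only
    data::Vector{Float64}
end

# Decline the generic rule for this one signature; AD backends that respect
# opt-outs will then differentiate through Base's code for MyIter:
ChainRulesCore.@opt_out rrule(
    ::RuleConfig{>:HasReverseMode},
    ::typeof(Base.mapfoldl_impl),
    ::Any, ::Any, ::Any, ::MyIter,
)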
9af7a64 also adds … Maybe also worth noting: moving the rule to Base.mapfoldl_impl …
Trying a bit to track this down today, I think the slowdown is just some quirk of Zygote's handling of keywords. So it's not the rule's fault, and anything which fixes that should help:

using Diffractor, ChainRulesCore
ChainRulesCore._backing_error(::Type{<:Base.Pairs{Symbol}}, ::Type{<:NamedTuple}, _) = nothing
# Solves same error as https://github.com/JuliaDiff/ChainRulesCore.jl/pull/503
xn = Tuple(randn(10));
@btime Diffractor.gradient(x -> foldl(/, x), $xn);
# min 29.313 ns, mean 29.545 ns (0 allocations) before (old rule on foldl)
# min 29.313 ns, mean 29.522 ns (0 allocations) after (new rule on Base.mapfoldl_impl)
@btime Diffractor.gradient(x -> Base.mapfoldl_impl(identity, /, Base._InitialValue(), x), $xn);
# min 47.625 μs, mean 53.596 μs (569 allocations, 33.16 KiB) before -- i.e. with no rule, just Base, NB μs
_foldl(op::G, itr; kw...) where {G} = _mapfoldl(identity, op, itr; kw...)
_mapfoldl(f::F, op::G, itr; init=Base._InitialValue()) where {F,G} = Base.mapfoldl_impl(f, op, init, itr)
@btime Diffractor.gradient(x -> _foldl(/, x), $xn);
# min 56.542 μs, mean 62.279 μs (672 allocations, 38.78 KiB) before -- i.e. with no rule, just Base, NB μs
import Zygote
@btime Zygote.gradient(x -> foldl(/, x), $xn);
# min 47.402 ns, mean 48.592 ns (1 allocation, 16 bytes) before
# min 4.482 μs, mean 9.120 μs (37 allocations, 2.16 KiB) after -- this I didn't like, above
# Same with Zygote#master, thus including https://github.com/FluxML/Zygote.jl/pull/1286
@btime Zygote.gradient(x -> Base.mapfoldl_impl(identity, /, Base._InitialValue(), x), $xn);
# min 152.667 μs, mean 157.707 μs (494 allocations, 26.44 KiB) before -- i.e. using no rule, just Base, NB μs
# min 47.402 ns, mean 82.826 ns (1 allocation, 16 bytes) after -- so the issue is Zygote & keywords
using Yota
@btime Yota.grad(x -> foldl(/, x), $xn);
# min 235.140 ns, mean 251.834 ns (3 allocations, 208 bytes) before
# error afterwards, doesn't track further?
ChainRulesCore.@non_differentiable Base._InitialValue()
@btime Yota.grad(x -> Base.mapfoldl_impl(identity, /, Base._InitialValue(), x), $xn);
# min 231.805 ns, mean 250.267 ns (3 allocations, 208 bytes) after

So I think we should merge this, if tests pass etc.
Zygote tries to diff through the kwsorter definition (i.e. https://docs.julialang.org/en/v1/devdocs/functions/#Keyword-arguments), which includes control flow. It's very difficult to make this type stable because it requires saving a different set of pullbacks for each branch (does anybody know how Diffractor does this?), but FluxML/Zygote.jl#1195 might help with runtime overhead.
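(To see the control flow in question, one can lower a toy keyword function -- nothing PR-specific, just an illustration:)

julia> g(x; init=0) = x + init;

julia> @code_lowered g(1.0; init=2.0)
# prints the generated keyword sorter rather than g's body: code that unpacks
# the keyword named tuple, branches to fill in defaults, then calls a hidden
# positional body (something like var"#g#1"(init, g, x)). Those branches are
# the control flow Zygote ends up differentiating through.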
    init, x
end
hobbits = accumulate(list; init=(start, nothing)) do (a, _), b
    c, back = rrule_via_ad(config, op, a, b)
Is there any way to not capture the accumulated outputs (the `c`s) in the pullback? It seems easy enough for tuples using `map`, but I'm unsure if the extra allocation would be welcomed for arrays.
Yes, you can write a `for` loop like this: FluxML/Zygote.jl#644 (comment). IMO this array method should probably be replaced, but not today.
Carrying `c` by updating a variable from inside `accumulate` was very slow; IIRC it hits the closure issue.
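(The closure issue is the classic boxed-variable problem, julia#15276. A minimal sketch, with boxed_sum an invented name:)

# `total` is re-assigned from inside a closure, so Julia boxes it (Core.Box)
# and every read/write becomes type-unstable -- the same trap as carrying `c`
# by mutating a variable captured by the `accumulate` kernel:
function boxed_sum(xs)
    total = zero(eltype(xs))
    foreach(x -> (total += x; nothing), xs)
    return total   # @code_warntype reports total::Core.Box here
end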
I actually tried that in FluxML/Flux.jl#2003. The main challenges are nested differentiation and handling the case when `typeof(x |> f) != typeof(x |> f |> f)` (you must widen, which means preallocating an array is impossible without `return_type` shenanigans).
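(A concrete instance of that type change, for illustration:)

# Each application nests the type one level deeper, so an output array
# preallocated from the first element's type is wrong by step two:
julia> f = x -> (x, x);

julia> typeof(f(1))
Tuple{Int64, Int64}

julia> typeof(f(f(1)))
Tuple{Tuple{Int64, Int64}, Tuple{Int64, Int64}}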
So, assuming type inference cooperates, the `accumulate` approach seems no less promising. Would there be any objections to a post-processing step like the following, which allows the GC to clean up intermediate outputs before the pullback runs?
# ... y = first(last(hobbits))
# If outputs are (recursively) allocated inline, we're less worried about memory
# overhead, and the GC can't free them individually anyway.
if !isbitstype(eltype(hobbits))
    # drop each stored output, keeping only its pullback
    hobbits = map(((_, pb),) -> (nothing, pb), hobbits)
end
# axe = axes(x) ...
Yes, the mutation has these problems, but it was much quicker; maybe it can be used when safe.

The intention with writing `foldl` in terms of `accumulate` was to allow for 2nd derivatives, but I'm not sure this actually works right now.

Re saving memory, we can add something like

unzip_accumulate(f, xs; init) = StructArrays.components(StructArray(Iterators.accumulate(f, xs; init)))

to free the bits we don't need anymore.
julia> accumulate([1,2,3], init=(4,5)) do prev, this
this .+ prev
end
3-element Vector{Tuple{Int64, Int64}}:
(5, 6)
(7, 8)
(10, 11)
julia> unzip_accumulate([1,2,3], init=(4,5)) do prev, this
this .+ prev
end
([5, 7, 10], [6, 8, 11])
But this PR would like to kick the can down the road on such improvements. (And others -- it returns `@not_implemented` for `accumulate`'s init; that might be easy to do better, but I'm tired of adding tests... at least it's no longer wrong.)
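(For context, `@not_implemented` produces a `ChainRulesCore.NotImplemented` tangent carrying a message, rather than a silent zero -- a quick illustration:)

julia> using ChainRulesCore

julia> d = @not_implemented "gradient for init is not written yet";

julia> d isa ChainRulesCore.NotImplemented
true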
After looking into Diffractor, I think whatever it does happens outside the actual AD transform (perhaps leaving control flow intact is enough), but the ability to have unused branches/blocks in the keyword sorter pruned in the final IR does wonders for type stability. Inspired by this, FluxML/Zygote.jl#446 (comment) has some thoughts on how we might do something similar there.
The remaining test failure is 1.8 on x86:
It also happened at https://github.com/JuliaDiff/ChainRules.jl/runs/7933271950?check_suite_focus=true with #667 (no longer needed).
Can this be merged? It's not the last word, as noted above, but it is a step forwards. |
#####
##### `accumulate`
#####

# Like `foldl` this by definition works in order, so it makes sense to allow stateful `f`.

# Also like `foldl`, the version with a keyword `init` can't easily be given a gradient.
# Move it down to: `_accumulate!(op, B, A::AbstractVector, dims::Nothing, init::Nothing)` |
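(A quick check that this internal signature really is where the keyword version lands, and that `init` arrives wrapped in `Some` -- again 1.8-era internals, subject to change:)

julia> x = [1.0, 2.0, 3.0];

julia> accumulate(/, x) == Base._accumulate!(/, similar(x), x, nothing, nothing)
true

julia> accumulate(/, x; init=2.0) == Base._accumulate!(/, similar(x), x, nothing, Some(2.0))
true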
But we don't at present support getting back a gradient for `init`, except if it's `nothing`, and then it doesn't matter. So we might as well put this on `accumulate`?

I am a little uncomfortable putting rules on mutating functions. Though perhaps this one is safe, as we are always fully overwriting `y` and never reading it? A comment to that effect would be good if so.
The intention was to move both to functions with positional `init`, and this mutating function was the best option I could find in Base's dispatch. Then it could have a gradient for `init`. I didn't get around to writing one, mostly got tired of fighting tests. But at least step 1 makes step 2 easier; it would be a small PR. And for now it returns `@not_implemented`, which is better than a silent zero, in theory at least.
I do think this is unsafe the same way that `fill!` is unsafe. Except that in practice, I think it's much less likely to cause problems, as anyone who gets to `accumulate!` has probably been trained out of hoping that mutation will work.

The originally envisaged use case was that the 2nd derivative of `foldl` would involve this `accumulate` gradient. But I don't recall whether I ever checked that.
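(One way to check would be a nested gradient, e.g. with Zygote -- whether this currently runs is exactly the open question:)

using Zygote

x = [1.0, 2.0, 3.0]

# The outer gradient must differentiate the rule's pullback, whose forward
# pass is written with `accumulate` -- so this exercises that gradient:
Zygote.gradient(x) do y
    sum(abs2, Zygote.gradient(z -> foldl(*, z), y)[1])
end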
else
-    x, init
+    something(init), x
end
hobbits = accumulate(list; init = (start, nothing)) do (a, _), b
Suggested change:
- hobbits = accumulate(list; init = (start, nothing)) do (a, _), b
+ # "The Hobbit", or "There and Back Again"
+ hobbits = accumulate(list; init = (start, nothing)) do (a, _), b
If we can resolve the comment about why we do `_accumulate!`, and if it is actually safe, I think we can merge this. If it breaks stuff we can always revert.

If you think this is good to go then we can merge it.
if init === Base._InitialValue() # `hobbits` is one short
    dx = _vcat1(trio[end][2], dx)
end
d_init = @not_implemented "gradient for foldl does not at present include init, sorry"
Suggested change:
- if init === Base._InitialValue() # `hobbits` is one short
-     dx = _vcat1(trio[end][2], dx)
- end
- d_init = @not_implemented "gradient for foldl does not at present include init, sorry"
+ if init === Base._InitialValue() # `hobbits` is one short
+     dx = _vcat1(trio[end][2], dx)
+     d_init = NoTangent()
+ else
+     d_init = trio[end][2]
+ end
Would this work?
Probably!
It's been a while, but my memory is that I mostly got tired of making tests, so thought I'd leave that for later.
Co-authored-by: Frames White <[email protected]>
Closes #567, perhaps in the minimal way, by attaching these rules to internal functions which take positional arguments. The gradient for `init` is just `@not_implemented` for now.

One nice effect is that I think `foldr` may work too.

One weird effect is that `accumulate!(f, y, x)` will work, silently overwriting `y`. It does return a NotImplemented, maybe that helps. Xref #521.

Non-vector shapes like `accumulate(f, ::Matrix)` take a different path, via `Iterators.accumulate`, and will miss the rule. So will `accumulate(f, ::Tuple)`. Maybe for that case Base's code is OK.

Closes #672. Probably closes FluxML/Zygote.jl#1297.
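Regarding `foldr`: in 1.8-era Base, `foldr` funnels to `Base.mapfoldr_impl`, which (if memory serves) is defined as a flipped left fold over the reversed iterator, roughly `mapfoldl_impl(f, FlipArgs(op), init, Iterators.reverse(itr))` -- hence it lands on the newly-ruled function. A quick check of the funnel, assuming those internals:

julia> x = Tuple(randn(3));

julia> foldr(/, x) == Base.mapfoldr_impl(identity, /, Base._InitialValue(), x)
true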