-
Notifications
You must be signed in to change notification settings - Fork 89
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Rules for map
, zip
and some comprehensions
#671
base: main
Are you sure you want to change the base?
Changes from all commits
a233feb
6981622
0adf5bd
f80ef94
b020a2a
e781ae1
f0722c9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
##### | ||
##### Comprehension: Iterators.map | ||
##### | ||
|
||
# A comprehension guarantees its iteration order, so its gradient must be computed in reverse order. | ||
|
||
# Reverse rule for `collect(gen)` where `gen` is a `Base.Generator`.
#
# Forward pass: `rrule_via_ad` is called on `gen.f` for every element of `gen.iter`;
# `unzip_map` separates the primal outputs from the per-element pullbacks.
# Pullback: each per-element pullback is applied to the matching cotangent through
# `unzip_map_reversed` (presumably walking the elements in reverse order, as the name
# suggests). The result is a `Tangent` for the generator, with `f` the projected sum
# of the function tangents and `iter` the projected element tangents.
function rrule(
    cfg::RuleConfig{>:HasReverseMode}, ::typeof(collect), gen::G
) where {G<:Base.Generator}
    @debug "collect generator"
    ys, backs = unzip_map(gen.iter) do x
        rrule_via_ad(cfg, gen.f, x)
    end
    project_f = ProjectTo(gen.f)
    project_iter = ProjectTo(gen.iter)
    function generator_pullback(dys_raw)
        # `unthunk` the incoming cotangent before iterating over it.
        dfs, dxs = unzip_map_reversed(|>, unthunk(dys_raw), backs)
        gen_tangent = Tangent{G}(; f = project_f(sum(dfs)), iter = project_iter(dxs))
        return (NoTangent(), gen_tangent)
    end
    return ys, generator_pullback
end
|
||
##### | ||
##### `zip` | ||
##### | ||
|
||
# Attaching the rule to `zip` breaks Zygote, whose rule is on `Iterators.Zip`. | ||
# function rrule(::typeof(zip), xs::AbstractArray...) | ||
# Rule for building an `Iterators.Zip` from a tuple of arrays `xs`.
#
# Returns `zip(xs...)` and a pullback with three methods. Every method returns
# `(NoTangent(), dx1, dx2, ...)`, i.e. one tangent per zipped array:
#   * generic cotangent: unthunked, split per array by `_tangent_unzip`, then each
#     piece is matched to its array by `_unmap_pad`;
#   * `Tangent` cotangent: its `is` field is splatted directly;
#   * `AbstractZero` cotangent: the same zero is returned once for every array.
function rrule(::Type{<:Iterators.Zip}, xs::Tuple{Vararg{AbstractArray}})
    function zip_pullback(dy)
        @debug "zip array pullback" summary(dy)
        dxs = _tangent_unzip(unthunk(dy))
        return (NoTangent(), map(_unmap_pad, xs, dxs)...)
    end
    function zip_pullback(dy::Tangent)
        @debug "zip Tangent pullback"
        return (NoTangent(), dy.is...)
    end
    # The splat is required: without it this method returned a nested tuple
    # `(NoTangent(), (z, z, ...))`, disagreeing with the two methods above.
    zip_pullback(z::AbstractZero) = (NoTangent(), map(Returns(z), xs)...)
    return zip(xs...), zip_pullback
end
|
||
# Array of tuple `Tangent`s: reinterpret the elements as `B` (the second type
# parameter of `Tangent{T,B}`) and hand the result to `unzip`.
function _tangent_unzip(xs::AbstractArray{Tangent{T,B}}) where {T<:Tuple,B<:Tuple}
    as_backing = reinterpret(B, xs)
    return unzip(as_backing)
end

# Fallback for any other array: pass it to `unzip` unchanged (temp fix for Diffractor).
function _tangent_unzip(xs::AbstractArray)
    return unzip(xs)
end
|
||
# This is like unbroadcast, except for map's stopping-short behaviour, not broadcast's extension. | ||
# Closing over `x` lets us re-use ∇getindex. | ||
function _unmap_pad(x::AbstractArray, dx::AbstractArray)
    if length(x) == length(dx)
        # Same number of elements: give `dx` the axes of `x`, then project onto `x`.
        return ProjectTo(x)(reshape(dx, axes(x)))
    else
        # Log both lengths; the previous `length(x) == length(dx)` was always `false`
        # in this branch and carried no information.
        @debug "_unmap_pad is extending gradient" length(x) length(dx)
        # `dx` covers only the first `length(dx)` linear indices of `x`; `∇getindex`
        # is used to build the gradient for the whole of `x` from that slice.
        i1 = firstindex(x)
        return ∇getindex(x, vec(dx), i1:(i1 + length(dx) - 1))
    end
end
|
||
# For testing: a rule for `collect∘zip` on arrays. It reuses the `Iterators.Zip` rule,
# collects the lazy primal, and forwards that rule's pullback unchanged.
function rrule(::ComposedFunction{typeof(collect),typeof(zip)}, xs::AbstractArray...)
    zipped, zip_back = rrule(Iterators.Zip, xs)
    return (collect(zipped), zip_back)
end
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -257,4 +257,40 @@ end | |
test_rrule(map, Multiplier(4.5), (6.7, 8.9), (0.1, 0.2, 0.3), check_inferred=false) | ||
end | ||
end | ||
|
||
@testset "map(f, ::Array)" begin
    # Finite-difference checks; inference is not asserted for these rules.
    test_rrule(map, identity, [1.0, 2.0]; check_inferred=false)
    test_rrule(map, conj, [1, 2+im, 3.0]'; check_inferred=false)
    test_rrule(map, make_two_vec, [4.0, 5.0 + 6im]; check_inferred=false)
    # @test rrule(CFG, map, make_two_vec, [4.0, 5.0 + 6im])[2]([1:2, 3:4])[3] ≈ [1 + 2im, 3 + 4im] # FiniteDifferences DimensionMismatch

    # Stateful / struct-valued `f`: these comparisons are currently skipped.
    @test_skip test_rrule(map, Multiplier(rand() + im), rand(3); check_inferred=false)
    # Not asserted. Observed result, so far unexplained:
    # (NoTangent(), Multiplier{Float64}(259.99999), [19.99999, 40.000, 60.000])
    rrule(CFG, map, Multiplier(2.0), [3, 4, 5.0])[2]([10, 20, 30])
    # Expression: ad_cotangent isa NoTangent
    # Evaluated: Multiplier{ComplexF64}(-3.7869064372333963 + 2.046139872866103im) isa NoTangent
    @test_skip test_rrule(map, Multiplier(rand() + im) ⊢ NoTangent(), rand(3); check_inferred=false)

    # Hand-checked gradient of `abs2`: d/dx x^2 = 2x, scaled by the cotangent.
    squares, squares_back = rrule(CFG, map, abs2, [1.0, 2.0, 3.0])
    @test squares == [1, 4, 9]
    @test squares_back([4, 5, 6.0])[3] ≈ 2 .* (1:3) .* (4:6)

    counted, counted_back = rrule(CFG, map, Counter(), [11, 12, 13.0])
    @test counted == map(Counter(), 11:13)
    # FiniteDifferences has incremented the counter very high
    @test_skip counted_back(ones(3))[3] == [93, 83, 73]
end
|
||
@testset "map(f, ::Array, ::Array)" begin
    # NoTangent does not match Union{NoTangent, ZeroTangent}
    # Review note (truncated in the source): "As for broadcasting, many rules fail
    # inference tests only because of ..."
    test_rrule(map, +, [1.0, 2.0], [3.0, 4.0]; check_inferred=false)
    # same shape => just broadcast
    test_rrule(map, atan, [1 2; 3.0 4.0], [4 5; 6 7.0]; check_inferred=false)

    # Arguments of unequal length / shape.
    test_rrule(map, /, [1.0, 2.0], [3.0, 4.0, 5.0]; check_inferred=false)
    test_rrule(map, atan, [1, 2, 3.0], [4 5; 6 7.0]; check_inferred=false)

    test_rrule(map, Multiplier(rand()), rand(3), rand(4); check_inferred=false)

    # Stateful function: the counter must have advanced exactly three times.
    counter = Counter()
    ys, back = rrule(CFG, map, counter, [1, 2, 3.0], [0, -1, -2, -33.3])
    @test ys == 1:3
    @test counter == Counter(3)
    grads = back([1, 1, 1000])
    @test grads[3] == [53, 33, 13000]
end
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
|
||
@testset "Comprehension" begin
    @testset "simple" begin
        # Generator with a stateless function: only `iter` carries a gradient.
        squares, squares_back = rrule(CFG, collect, (i^2 for i in [1.0, 2.0, 3.0]))
        @test squares == [1, 4, 9]
        gen_tangent = squares_back(4:6)[2]
        @test gen_tangent isa Tangent{<:Base.Generator}
        @test gen_tangent.f == NoTangent()
        @test gen_tangent.iter ≈ 2 .* (1:3) .* (4:6)

        # Generator over a stateful `Counter`.
        counted, counted_back = rrule(
            CFG, collect, Iterators.map(Counter(), [11, 12, 13.0])
        )
        @test counted == map(Counter(), 11:13)
        @test counted_back(ones(3))[2].iter == [33, 23, 13]
    end
end
|
||
@testset "Iterators" begin
    @testset "zip" begin
        # Arguments with the same number of elements.
        test_rrule(collect ∘ zip, rand(3), rand(3))
        test_rrule(collect ∘ zip, rand(2, 2), rand(2, 2), rand(2, 2))
        test_rrule(collect ∘ zip, rand(4), rand(2, 2))

        # Arguments with different numbers of elements.
        test_rrule(collect ∘ zip, rand(3), rand(5))
        test_rrule(collect ∘ zip, rand(3, 2), rand(5))
    end
end
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since generic broadcasting is slow anyway, I am changing my mind: I now think it should reverse the order of iteration. Even though the order isn't guaranteed by Julia, it is perhaps better that the rule at least makes the forward and reverse passes match.