Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rules for map, zip and some comprehensions #671

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/ChainRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ include("rulesets/Base/indexing.jl")
include("rulesets/Base/sort.jl")
include("rulesets/Base/mapreduce.jl")
include("rulesets/Base/broadcast.jl")
include("rulesets/Base/iterators.jl")
include("rulesets/Base/CoreLogging.jl")

include("rulesets/Distributed/nondiff.jl")
Expand Down
27 changes: 26 additions & 1 deletion src/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ function rrule(::typeof(Base.literal_pow), ::typeof(^), x::Real, ::Val{3})
end

#####
##### `map`
##### `map(f, ::Tuple...)`
#####

# Ideally reverse mode should always iterate in reverse order. For `map` and broadcasting
Expand Down Expand Up @@ -272,6 +272,31 @@ function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(map), f::F, xs::Tu
return y, map_pullback
end

#####
##### `map(f, ::AbstractArray...)`
#####

function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(map), f::F, x::AbstractArray) where {F}
    # With a single array argument `map` agrees with `broadcast`, so the work is
    # delegated to the broadcast rule (noted by the author as a meta-rule with
    # 4 different paths, which should be fast).
    result = rrule_via_ad(cfg, broadcast, f, x)
    out, pullback = result
    return (out, pullback)
end

function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(map), f::F, x::AbstractArray, ys::AbstractArray...) where {F}
    # Fast path: when every array has the size of `x`, `map` agrees with `broadcast`,
    # so the broadcast rule is used (maybe this test could be more elegant?).
    sizes_agree = all(a -> size(a) == size(x), ys)
    if sizes_agree
        out, pullback = rrule_via_ad(cfg, broadcast, f, x, ys...)
        return (out, pullback)
    end

    # Generic path: sizes differ, so `map` stops at the shortest argument.
    # One pullback is stored per element of the result.
    @debug "rrule(map, f, arrays...)" f
    z, backs = unzip_map(x, ys...) do xy...
        rrule_via_ad(cfg, f, xy...)
    end

    function map_pullback_2(dz)
        # Each stored pullback is applied to its cotangent (`dz_i |> back_i`),
        # iterating in the reverse of the forward order.
        df, dxy... = unzip_map_reversed(|>, unthunk(dz), backs)
        f_bar = ProjectTo(f)(sum(df))
        # `_unmap_pad` brings each argument's gradient to the shape of that argument.
        args_bar = map(_unmap_pad, (x, ys...), dxy)
        return (NoTangent(), f_bar, args_bar...)
    end

    return z, map_pullback_2
end

#####
##### `task_local_storage`
#####
Expand Down
5 changes: 2 additions & 3 deletions src/rulesets/Base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,16 +120,15 @@ function split_bc_inner(frule_fun::R, cfg::RuleConfig, f::F, arg) where {R,F}
end

# Path 4: The most generic, save all the pullbacks. Can be 1000x slower.
# Since broadcast makes no guarantee about order of calls, and un-fusing
# can change the number of calls, don't bother to try to reverse the iteration.
# While broadcast makes no guarantee about order of calls, it's cheap to reverse the iteration.

function split_bc_pullbacks(cfg::RCR, f::F, args::Vararg{Any,N}) where {F,N}
@debug("split broadcasting generic", f, N)
ys3, backs = unzip_broadcast(args...) do a...
rrule_via_ad(cfg, f, a...)
end
function back_generic(dys)
deltas = unzip_broadcast(backs, dys) do back, dy # (could be map, sizes match)
deltas = unzip_map_reversed(backs, unthunk(dys)) do back, dy
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since generic broadcasting is slow anyway, I may be changing my mind: perhaps it should reverse the order of iteration. Even though the order isn't guaranteed by Julia, it is probably better that the rule at least makes the forward and reverse passes match.

map(unthunk, back(dy))
end
dargs = map(unbroadcast, args, Base.tail(deltas))
Expand Down
60 changes: 60 additions & 0 deletions src/rulesets/Base/iterators.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#####
##### Comprehension: Iterators.map
#####

# Comprehension does guarantee iteration order. Thus its gradient must reverse.

function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(collect), gen::G) where {G<:Base.Generator}
    @debug "collect generator"
    # Forward pass: differentiate `gen.f` at each element, keeping one pullback each.
    ys, backs = unzip_map(gen.iter) do x
        rrule_via_ad(cfg, gen.f, x)
    end
    project_f = ProjectTo(gen.f)
    project_iter = ProjectTo(gen.iter)

    function generator_pullback(dys_raw)
        cotangents = unthunk(dys_raw)
        # Apply each pullback to its cotangent, iterating in reverse order.
        dfs, dxs = unzip_map_reversed(|>, cotangents, backs)
        f_bar = project_f(sum(dfs))
        iter_bar = project_iter(dxs)
        return (NoTangent(), Tangent{G}(; f = f_bar, iter = iter_bar))
    end

    return ys, generator_pullback
end

#####
##### `zip`
#####

# Attaching the rule to `zip` breaks Zygote, whose rule is on `Iterators.Zip`.
# function rrule(::typeof(zip), xs::AbstractArray...)
function rrule(::Type{<:Iterators.Zip}, xs::Tuple{Vararg{AbstractArray}})
function zip_pullback(dy)
@debug "zip array pullback" summary(dy)
dxs = _tangent_unzip(unthunk(dy))
return (NoTangent(), map(_unmap_pad, xs, dxs)...)
end
function zip_pullback(dy::Tangent)
@debug "zip Tangent pullback"
return (NoTangent(), dy.is...)
end
Copy link
Member Author

@mcabbott mcabbott Aug 25, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`zip_pullback(dy::Tangent)` is here because Zygote's rule needed this. I'm not sure it's tested, nor whether it is in fact required. When trying to cook up examples that hit this method, using Diffractor or Yota, I just get errors.

zip_pullback(z::AbstractZero) = (NoTangent(), map(Returns(z), xs))
return zip(xs...), zip_pullback
end

# Array of tuple `Tangent`s: reinterpret the elements as their backing type `B`
# and unzip that.
function _tangent_unzip(xs::AbstractArray{Tangent{T,B}}) where {T<:Tuple, B<:Tuple}
    backing = reinterpret(B, xs)
    return unzip(backing)
end

# Fallback for any other array: unzip it directly (temp fix for Diffractor).
function _tangent_unzip(xs::AbstractArray)
    return unzip(xs)
end

# This is like unbroadcast, except for map's stopping-short behaviour, not broadcast's extension.
# Closing over `x` lets us re-use ∇getindex.
#
# `x` is the original argument and `dx` the gradient collected for it.
# - Equal lengths: `dx` is reshaped to `axes(x)` and projected with `ProjectTo(x)`.
# - Otherwise: `dx` is passed to `∇getindex` as the gradient of the first
#   `length(dx)` linear indices of `x`, starting from `firstindex(x)`.
function _unmap_pad(x::AbstractArray, dx::AbstractArray)
    if length(x) == length(dx)
        return ProjectTo(x)(reshape(dx, axes(x)))
    else
        # Log both lengths: the comparison `length(x) == length(dx)` is always
        # `false` in this branch, so logging it carried no information.
        @debug "_unmap_pad is extending gradient" length(x) length(dx)
        i1 = firstindex(x)
        return ∇getindex(x, vec(dx), i1:i1+length(dx)-1)
    end
end

# For testing: a rule for `collect∘zip` that re-uses the pullback of the
# `Iterators.Zip` rule and only collects the primal result.
function rrule(::ComposedFunction{typeof(collect), typeof(zip)}, xs::AbstractArray...)
    zipped, zip_back = rrule(Iterators.Zip, xs)
    return (collect(zipped), zip_back)
end

81 changes: 77 additions & 4 deletions src/unzipped.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,75 @@
##### map
#####

# `unzip_map` can use `StructArrays.components(StructArray(Iterators.map(f, args...)))`,
# will be useful for the gradient of `map` etc.
"""
    unzip_map(f, args...)

Map `f`, which must return a tuple, over `args` and return a tuple of arrays.
The result is `== unzip(map(f, args...))`, but is built with `StructArrays`
for efficiency.

Throws an `ArgumentError` when the inferred return type of `f` is concrete
and is not a `Tuple`.
"""
function unzip_map(f::F, args...) where {F}
    T = Broadcast.combine_eltypes(f, args)
    # The check only fires when inference gives a concrete type.
    if isconcretetype(T) && !(T <: Tuple)
        throw(ArgumentError("""unzip_map(f, args) only works on functions returning a tuple,
            but f = $(sprint(show, f)) returns type T = $T"""))
    end
    lazy = Iterators.map(f, args...)
    return StructArrays.components(StructArray(lazy))
end

# Tuples: use plain `map` followed by `unzip`.
function unzip_map(f::F, args::Tuple...) where {F}
    return unzip(map(f, args...))
end
# unzip_map(f::F, args::NamedTuple...) where {F} = unzip(map(f, args...))

# GPU arrays: likewise use plain `map` followed by `unzip`.
function unzip_map(f::F, args::AbstractGPUArray...) where {F}
    return unzip(map(f, args...))
end

"""
    unzip_map_reversed(f, args...)

For a pure function `f` which returns a tuple, this is `== unzip(map(f, args...))`,
but the order of evaluation should be the reverse.

Does NOT handle `zip`-like behaviour: an error is thrown unless all `args`
have the same length.
"""
function unzip_map_reversed(f::F, args...) where {F}
    T = Broadcast.combine_eltypes(f, args)
    # The check only fires when inference gives a concrete type.
    if isconcretetype(T) && !(T <: Tuple)
        throw(ArgumentError("""unzip_map_reversed(f, args) only works on functions returning a tuple,
            but f = $(sprint(show, f)) returns type T = $T"""))
    end
    n = length(first(args))
    for a in args
        length(a) == n || error("unzip_map_reversed does not handle zip-like behaviour.")
    end
    # Map over the reversed arguments, then flip each output back to forward order.
    flipped = map(_safereverse, args)
    return map(reverse!!, unzip_map(f, flipped...))
end

# This avoids MethodError: no method matching iterate(::Base.Iterators.Reverse{Tangent{Tuple{Float64, Float64}, Tuple{Float64, Float64}}}) on 1.6
# Newer Julia versions get the lazy `Iterators.reverse`, older ones an eager `reverse`.
function _safereverse(x)
    if VERSION > v"1.7"
        return Iterators.reverse(x)
    else
        return reverse(x)
    end
end

# Tuple method: same contract as the generic method, using eager `reverse`.
# All tuples must have the same length.
function unzip_map_reversed(f::F, args::Tuple...) where {F}
    n = length(first(args))
    all(t -> length(t) == n, args) || error("unzip_map_reversed does not handle zip-like behaviour.")
    flipped = map(reverse, args)
    results = map(f, flipped...)
    # Flip each output tuple back into forward order.
    return map(reverse, unzip(results))
end

"""
    reverse!!(x)

Reverse `x`, working in-place when `ChainRulesCore.is_inplaceable_destination(x)`
says that is possible, and returning a reversed copy otherwise.
Only safe if you are quite sure nothing else closes over `x`.
"""
function reverse!!(x::AbstractArray)
    return ChainRulesCore.is_inplaceable_destination(x) ? Base.reverse!(x) : Base.reverse(x)
end

# An array of zero tangents is returned as-is.
function reverse!!(x::AbstractArray{<:AbstractZero})
    return x
end

# Anything which is not an array falls back to `reverse`.
function reverse!!(x)
    return reverse(x)
end

# Forward rule: the primal and its tangent are both reversed with `reverse!!`.
function frule((_, xdot), ::typeof(reverse!!), x)
    return (reverse!!(x), reverse!!(xdot))
end

# Reverse rule: the pullback reverses the (unthunked) cotangent with `reverse`.
function rrule(::typeof(reverse!!), x)
    function reverse!!_back(dy)
        return (NoTangent(), reverse(unthunk(dy)))
    end
    y = reverse!!(x)
    return y, reverse!!_back
end


#####
Expand Down Expand Up @@ -107,10 +174,16 @@
Expr(:tuple, each...)
end

unzip(xs::AbstractArray{Tuple{T}}) where {T} = (reinterpret(T, xs),) # best case, no copy
# Array of 1-tuples. For bits types `reinterpret` gives the result with no copy;
# other element types are unwrapped one by one with `only`.
function unzip(xs::AbstractArray{Tuple{T}}) where {T}
    isbitstype(T) || return (map(only, xs),)
    return (reinterpret(T, xs),)  # best case, no copy
end

@generated function unzip(xs::AbstractArray{Ts}) where {Ts<:Tuple}
each = if count(!Base.issingletontype, Ts.parameters) < 2
each = if count(!Base.issingletontype, Ts.parameters) < 2 && all(isbitstype, Ts.parameters)

Check warning on line 186 in src/unzipped.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: src/unzipped.jl:186:- each = if count(!Base.issingletontype, Ts.parameters) < 2 && all(isbitstype, Ts.parameters) src/unzipped.jl:187:- # good case, no copy of data, some trivial arrays src/unzipped.jl:188:- [Base.issingletontype(T) ? :(similar(xs, $T)) : :(reinterpret($T, xs)) for T in Ts.parameters] src/unzipped.jl:189:- else src/unzipped.jl:190:- [:(map($(Get(i)), xs)) for i in 1:length(fieldnames(Ts))] src/unzipped.jl:191:- end src/unzipped.jl:192:- Expr(:tuple, each...) src/unzipped.jl:201:+ each = src/unzipped.jl:202:+ if count(!Base.issingletontype, Ts.parameters) < 2 && all(isbitstype, Ts.parameters) src/unzipped.jl:203:+ # good case, no copy of data, some trivial arrays src/unzipped.jl:204:+ [ src/unzipped.jl:205:+ Base.issingletontype(T) ? :(similar(xs, $T)) : :(reinterpret($T, xs)) for src/unzipped.jl:206:+ T in Ts.parameters src/unzipped.jl:207:+ ] src/unzipped.jl:208:+ else src/unzipped.jl:209:+ [:(map($(Get(i)), xs)) for i in 1:length(fieldnames(Ts))] src/unzipped.jl:210:+ end src/unzipped.jl:211:+ return Expr(:tuple, each...)
# good case, no copy of data, some trivial arrays
[Base.issingletontype(T) ? :(similar(xs, $T)) : :(reinterpret($T, xs)) for T in Ts.parameters]
else
Expand Down
36 changes: 36 additions & 0 deletions test/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -257,4 +257,40 @@ end
test_rrule(map, Multiplier(4.5), (6.7, 8.9), (0.1, 0.2, 0.3), check_inferred=false)
end
end

# Tests for the single-array `map` rule.
# NOTE(review): `make_two_vec`, `Multiplier`, `Counter` and `CFG` are test helpers
# defined outside this view; their exact behaviour should be confirmed there.
@testset "map(f, ::Array)" begin
# Checks through `test_rrule`; inference checks are switched off throughout.
test_rrule(map, identity, [1.0, 2.0], check_inferred=false)
test_rrule(map, conj, [1, 2+im, 3.0]', check_inferred=false)
test_rrule(map, make_two_vec, [4.0, 5.0 + 6im], check_inferred=false)
# @test rrule(CFG, map, make_two_vec, [4.0, 5.0 + 6im])[2]([1:2, 3:4])[3] ≈ [1 + 2im, 3 + 4im] # FiniteDifferences DimensionMismatch

# Skipped for now; the trailing comments record the observed failures.
@test_skip test_rrule(map, Multiplier(rand() + im), rand(3), check_inferred=false)
# Not an assertion: this call only checks that the rule and its pullback run.
rrule(CFG, map, Multiplier(2.0), [3, 4, 5.0])[2]([10, 20, 30]) # (NoTangent(), Multiplier{Float64}(259.99999), [19.99999, 40.000, 60.000]) -- inexact values, cause not yet understood
@test_skip test_rrule(map, Multiplier(rand() + im) ⊢ NoTangent(), rand(3), check_inferred=false) # Expression: ad_cotangent isa NoTangent Evaluated: Multiplier{ComplexF64}(-3.7869064372333963 + 2.046139872866103im) isa NoTangent

# Direct check against the analytic gradient of abs2: d/dx x^2 = 2x, times the cotangent.
y1, bk1 = rrule(CFG, map, abs2, [1.0, 2.0, 3.0])
@test y1 == [1, 4, 9]
@test bk1([4, 5, 6.0])[3] ≈ 2 .* (1:3) .* (4:6)

# Stateful function: the forward result must match a plain `map` with a fresh Counter.
y2, bk2 = rrule(CFG, map, Counter(), [11, 12, 13.0])
@test y2 == map(Counter(), 11:13)
@test_skip bk2(ones(3))[3] == [93, 83, 73] # FiniteDifferences has incremented the counter very high
end

@testset "map(f, ::Array, ::Array)" begin
test_rrule(map, +, [1.0, 2.0], [3.0, 4.0], check_inferred=false) # NoTangent does not match Union{NoTangent, ZeroTangent}
Comment on lines +280 to +281
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As for broadcasting, many rules fail inference tests only because of `Union{NoTangent, ZeroTangent}`. Why do we have two types again? I am unclear as to what difference they encode and why this is worth doing.

test_rrule(map, atan, [1 2; 3.0 4.0], [4 5; 6 7.0], check_inferred=false) # same shape => just broadcast

test_rrule(map, /, [1.0, 2.0], [3.0, 4.0, 5.0], check_inferred=false)
test_rrule(map, atan, [1, 2, 3.0], [4 5; 6 7.0], check_inferred=false)

test_rrule(map, Multiplier(rand()), rand(3), rand(4), check_inferred=false)

cnt3 = Counter()
y3, bk3 = rrule(CFG, map, cnt3, [1, 2, 3.0], [0, -1, -2, -33.3])
@test y3 == 1:3
@test cnt3 == Counter(3)
z3 = bk3([1, 1, 1000])
@test z3[3] == [53, 33, 13000]
end
end
26 changes: 26 additions & 0 deletions test/rulesets/Base/iterators.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@

# Tests for the `collect(::Base.Generator)` rule.
# NOTE(review): `CFG` and `Counter` are test helpers defined outside this view.
@testset "Comprehension" begin
@testset "simple" begin
# Generator over a vector: the gradient of i^2 is 2i, times the cotangent 4:6.
y1, bk1 = rrule(CFG, collect, (i^2 for i in [1.0, 2.0, 3.0]))
@test y1 == [1,4,9]
t1 = bk1(4:6)[2]
# The cotangent for the generator is a structural Tangent with fields `f` and `iter`.
@test t1 isa Tangent{<:Base.Generator}
@test t1.f == NoTangent()
@test t1.iter ≈ 2 .* (1:3) .* (4:6)

# Stateful function: the forward pass must agree with a plain `map`.
# The descending expected values presumably reflect the pullbacks running in
# reverse order -- confirm against the definition of `Counter`.
y2, bk2 = rrule(CFG, collect, Iterators.map(Counter(), [11, 12, 13.0]))
@test y2 == map(Counter(), 11:13)
@test bk2(ones(3))[2].iter == [33, 23, 13]
end
end

# Tests for the `zip` rule, exercised through `collect∘zip`.
@testset "Iterators" begin
@testset "zip" begin
# Arguments with equal numbers of elements (the last case has different shapes).
test_rrule(collect∘zip, rand(3), rand(3))
test_rrule(collect∘zip, rand(2,2), rand(2,2), rand(2,2))
test_rrule(collect∘zip, rand(4), rand(2,2))

# Arguments with different numbers of elements.
test_rrule(collect∘zip, rand(3), rand(5))
test_rrule(collect∘zip, rand(3,2), rand(5))
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ end
include_test("rulesets/Base/mapreduce.jl")
include_test("rulesets/Base/sort.jl")
include_test("rulesets/Base/broadcast.jl")
include_test("rulesets/Base/iterators.jl")

include_test("unzipped.jl") # used primarily for broadcast

Expand Down
52 changes: 42 additions & 10 deletions test/unzipped.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@

using ChainRules: unzip_broadcast, unzip #, unzip_map
using ChainRules: unzip_broadcast, unzip, unzip_map, unzip_map_reversed

@testset "unzipped.jl" begin
@testset "basics: $(sprint(show, fun))" for fun in [unzip_broadcast, unzip∘map, unzip∘broadcast] # unzip_map,
@testset "basics: $(sprint(show, fun))" for fun in [unzip_broadcast, unzip∘map, unzip∘broadcast, unzip_map]
@test_throws Exception fun(sqrt, 1:3)

@test fun(tuple, 1:3, 4:6) == ([1, 2, 3], [4, 5, 6])
@test fun(tuple, [1, 10, 100]) == ([1, 10, 100],)
@test fun(tuple, 1:3, fill(nothing, 3)) == (1:3, fill(nothing, 3))
@test fun(tuple, [1, 10, 100], fill(nothing, 3)) == ([1, 10, 100], fill(nothing, 3))
@test fun(tuple, fill(nothing, 3), fill(nothing, 3)) == (fill(nothing, 3), fill(nothing, 3))
@test @inferred(fun(tuple, 1:3, 4:6)) == ([1, 2, 3], [4, 5, 6])
@test @inferred(fun(tuple, [1, 10, 100])) == ([1, 10, 100],)
@test @inferred(fun(tuple, 1:3, fill(nothing, 3))) == (1:3, fill(nothing, 3))
@test @inferred(fun(tuple, [1, 10, 100], fill(nothing, 3))) == ([1, 10, 100], fill(nothing, 3))
@test @inferred(fun(tuple, fill(nothing, 3), fill(nothing, 3))) == (fill(nothing, 3), fill(nothing, 3))

if contains(string(fun), "map")
@test fun(tuple, 1:3, 4:999) == ([1, 2, 3], [4, 5, 6])
@test @inferred(fun(tuple, 1:3, 4:999)) == ([1, 2, 3], [4, 5, 6])
else
@test fun(tuple, [1,2,3], [4 5]) == ([1 1; 2 2; 3 3], [4 5; 4 5; 4 5])
@test @inferred(fun(tuple, [1,2,3], [4 5])) == ([1 1; 2 2; 3 3], [4 5; 4 5; 4 5])
@test @inferred(fun(tuple, [1,2,3], 6)) == ([1, 2, 3], [6, 6, 6])
end

if contains(string(fun), "map")
Expand All @@ -24,7 +25,34 @@ using ChainRules: unzip_broadcast, unzip #, unzip_map
@test fun(tuple, (1,2,3), (7,)) == ((1, 2, 3), (7, 7, 7))
@test fun(tuple, (1,2,3), 8) == ((1, 2, 3), (8, 8, 8))
end
@test fun(tuple, (1,2,3), [4,5,6]) == ([1, 2, 3], [4, 5, 6]) # mix tuple & vector
@test @inferred(fun(tuple, (1,2,3), [4,5,6])) == ([1, 2, 3], [4, 5, 6]) # mix tuple & vector
end

# Runs once per function; note that the loop variable shadows the imported
# `unzip_map`. Each case compares against the reference `unzip(map(f, args...))`.
@testset "zip behaviour: $unzip_map" for unzip_map in [unzip_map, unzip_map_reversed]
# `check` also requires the call to be inferred; `check_no_inferr` does not.
check(f, args...) = @inferred(unzip_map(f, args...)) == unzip(map(f, args...))
check_no_inferr(f, args...) = unzip_map(f, args...) == unzip(map(f, args...))

# Equal lengths, different shapes or container types.
@test check(tuple, [1 2; 3 4], [5,6,7,8]) # makes a vector
@test check_no_inferr(tuple, [1,2,3], (5,6,7))

unzip_map == unzip_map_reversed && continue # does not handle unequal lengths.

# Unequal lengths: these cases run only for the first function in the list.
@test check(tuple, [1 2; 3 4], [5,6,7])
@test check(tuple, [1 2; 3 4], [5,6,7,8,9,10])

@test check_no_inferr(tuple, [1,2,3], (5,6,7,8))
@test check_no_inferr(tuple, [1,2,3,4], (5,6,7))
@test check_no_inferr(tuple, [1 2;3 4], (5,6,7))
end

# Checks the evaluation order of `unzip_map_reversed` with a counting function.
@testset "unzip_map_reversed" begin
# `cnt` adds the running call count to both inputs, so the result shows call order.
cnt(x, y) = (x, y) .+ (CNT[] += 1)
CNT = Ref(0)
# The last elements get +1 (called first) and the first elements get +2 (called second).
@test unzip_map_reversed(cnt, [10, 20], [30, 40]) == ([12, 21], [32, 41])
@test CNT[] == 2

# Same check for tuples, with the counter reset.
CNT = Ref(0)
@test unzip_map_reversed(cnt, (10, 20), (30, 40)) == ((12, 21), (32, 41))
end

@testset "rrules" begin
Expand Down Expand Up @@ -58,6 +86,10 @@ using ChainRules: unzip_broadcast, unzip #, unzip_map
@test unzip([(1,), (3,), (5,)])[1] isa Base.ReinterpretArray

@test unzip(((1,2), (3,4), (5,6))) == ((1, 3, 5), (2, 4, 6))

# Bug: these cases cannot be done by reinterpret
@test unzip([([1,2],), ([3,4],)]) == ([[1, 2], [3, 4]],)
@test unzip([(nothing, [1,2]), (nothing, [3,4])]) == ([nothing, nothing], [[1, 2], [3, 4]])

# test_rrule(unzip, [(1,2), (3,4), (5.0,6.0)], check_inferred=false) # DimensionMismatch: second dimension of A, 6, does not match length of x, 2

Expand Down
Loading