diff --git a/src/ChainRules.jl b/src/ChainRules.jl index 6d33a22e7..ed25ed23e 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -43,6 +43,7 @@ include("rulesets/Base/indexing.jl") include("rulesets/Base/sort.jl") include("rulesets/Base/mapreduce.jl") include("rulesets/Base/broadcast.jl") +include("rulesets/Base/iterators.jl") include("rulesets/Base/CoreLogging.jl") include("rulesets/Distributed/nondiff.jl") diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 6c66d19ee..83fad3f11 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -232,7 +232,7 @@ function rrule(::typeof(Base.literal_pow), ::typeof(^), x::Real, ::Val{3}) end ##### -##### `map` +##### `map(f, ::Tuple...)` ##### # Ideally reverse mode should always iterate in reverse order. For `map` and broadcasting @@ -272,6 +272,31 @@ function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(map), f::F, xs::Tu return y, map_pullback end +##### +##### `map(f, ::AbstractArray...)` +##### + +function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(map), f::F, x::AbstractArray) where {F} + # Here map agrees with broadcast, and we have a meta-rule with 4 different paths, should be fast: + y, back = rrule_via_ad(cfg, broadcast, f, x) + return y, back +end + +function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(map), f::F, x::AbstractArray, ys::AbstractArray...) where {F} + if all(==(size(x)), map(size, ys)) + # Here too map agrees with broadcast, maybe the test could be more elegant? + y, back = rrule_via_ad(cfg, broadcast, f, x, ys...) + return y, back + end + @debug "rrule(map, f, arrays...)" f + z, backs = unzip_map((xy...) -> rrule_via_ad(cfg, f, xy...), x, ys...) + function map_pullback_2(dz) + df, dxy... = unzip_map_reversed(|>, unthunk(dz), backs) + return (NoTangent(), ProjectTo(f)(sum(df)), map(_unmap_pad, (x, ys...), dxy)...) + end + z, map_pullback_2 +end + ##### ##### `task_local_storage` ##### diff --git a/src/rulesets/Base/broadcast.jl b/src/rulesets/Base/broadcast.jl index 4fb83c4e7..edc5c041f 100644 --- a/src/rulesets/Base/broadcast.jl +++ b/src/rulesets/Base/broadcast.jl @@ -120,8 +120,7 @@ function split_bc_inner(frule_fun::R, cfg::RuleConfig, f::F, arg) where {R,F} end # Path 4: The most generic, save all the pullbacks. Can be 1000x slower. -# Since broadcast makes no guarantee about order of calls, and un-fusing -# can change the number of calls, don't bother to try to reverse the iteration. +# While broadcast makes no guarantee about order of calls, it's cheap to reverse the iteration. function split_bc_pullbacks(cfg::RCR, f::F, args::Vararg{Any,N}) where {F,N} @debug("split broadcasting generic", f, N) @@ -129,7 +128,7 @@ function split_bc_pullbacks(cfg::RCR, f::F, args::Vararg{Any,N}) where {F,N} rrule_via_ad(cfg, f, a...) end function back_generic(dys) - deltas = unzip_broadcast(backs, dys) do back, dy # (could be map, sizes match) + deltas = unzip_map_reversed(backs, unthunk(dys)) do back, dy map(unthunk, back(dy)) end dargs = map(unbroadcast, args, Base.tail(deltas)) diff --git a/src/rulesets/Base/iterators.jl b/src/rulesets/Base/iterators.jl new file mode 100644 index 000000000..355bd6ce2 --- /dev/null +++ b/src/rulesets/Base/iterators.jl @@ -0,0 +1,60 @@ +##### +##### Comprehension: Iterators.map +##### + +# Comprehension does guarantee iteration order. Thus its gradient must reverse. + +function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(collect), gen::G) where {G<:Base.Generator} + @debug "collect generator" + ys, backs = unzip_map(x -> rrule_via_ad(cfg, gen.f, x), gen.iter) + proj_f = ProjectTo(gen.f) + proj_iter = ProjectTo(gen.iter) + function generator_pullback(dys_raw) + dys = unthunk(dys_raw) + dfs, dxs = unzip_map_reversed(|>, dys, backs) + return (NoTangent(), Tangent{G}(; f = proj_f(sum(dfs)), iter = proj_iter(dxs))) + end + ys, generator_pullback +end + +##### +##### `zip` +##### + +# Attaching the rule to `zip` breaks Zygote, whose rule is on `Iterators.Zip`. +# function rrule(::typeof(zip), xs::AbstractArray...) +function rrule(::Type{<:Iterators.Zip}, xs::Tuple{Vararg{AbstractArray}}) + function zip_pullback(dy) + @debug "zip array pullback" summary(dy) + dxs = _tangent_unzip(unthunk(dy)) + return (NoTangent(), map(_unmap_pad, xs, dxs)...) + end + function zip_pullback(dy::Tangent) + @debug "zip Tangent pullback" + return (NoTangent(), dy.is...) + end + zip_pullback(z::AbstractZero) = (NoTangent(), map(Returns(z), xs)) + return zip(xs...), zip_pullback +end + +_tangent_unzip(xs::AbstractArray{Tangent{T,B}}) where {T<:Tuple, B<:Tuple} = unzip(reinterpret(B, xs)) +_tangent_unzip(xs::AbstractArray) = unzip(xs) # temp fix for Diffractor + +# This is like unbroadcast, except for map's stopping-short behaviour, not broadcast's extension. +# Closing over `x` lets us re-use ∇getindex. +function _unmap_pad(x::AbstractArray, dx::AbstractArray) + if length(x) == length(dx) + ProjectTo(x)(reshape(dx, axes(x))) + else + @debug "_unmap_pad is extending gradient" length(x) == length(dx) + i1 = firstindex(x) + ∇getindex(x, vec(dx), i1:i1+length(dx)-1) + end +end + +# For testing +function rrule(::ComposedFunction{typeof(collect), typeof(zip)}, xs::AbstractArray...) + y, back = rrule(Iterators.Zip, xs) + return collect(y), back +end + diff --git a/src/unzipped.jl b/src/unzipped.jl index fe5875e6f..34b8d8d8f 100644 --- a/src/unzipped.jl +++ b/src/unzipped.jl @@ -66,8 +66,75 @@ end ##### map ##### -# `unzip_map` can use `StructArrays.components(StructArray(Iterators.map(f, args...)))`, -# will be useful for the gradient of `map` etc. +""" + unzip_map(f, args...) + +For a function `f` which returns a tuple, this is `== unzip(map(f, args...))`, +but performed using `StructArrays` for efficiency. +""" +function unzip_map(f::F, args...) where {F} + T = Broadcast.combine_eltypes(f, args) + if isconcretetype(T) + T <: Tuple || throw(ArgumentError("""unzip_map(f, args) only works on functions returning a tuple, + but f = $(sprint(show, f)) returns type T = $T""")) + end + return StructArrays.components(StructArray(Iterators.map(f, args...))) +end + +unzip_map(f::F, args::Tuple...) where {F} = unzip(map(f, args...)) +# unzip_map(f::F, args::NamedTuple...) where {F} = unzip(map(f, args...)) + +unzip_map(f::F, args::AbstractGPUArray...) where {F} = unzip(map(f, args...)) + +""" + unzip_map_reversed(f, args...) + +For a pure function `f` which returns a tuple, this is `== unzip(map(f, args...))`. +But the order of evaluation is should be the reverse. +Does NOT handle `zip`-like behaviour. +""" +function unzip_map_reversed(f::F, args...) where {F} + T = Broadcast.combine_eltypes(f, args) + if isconcretetype(T) + T <: Tuple || throw(ArgumentError("""unzip_map_reversed(f, args) only works on functions returning a tuple, + but f = $(sprint(show, f)) returns type T = $T""")) + end + len1 = length(first(args)) + all(a -> length(a)==len1, args) || error("unzip_map_reversed does not handle zip-like behaviour.") + return map(reverse!!, unzip_map(f, map(_safereverse, args)...)) +end + +# This avoids MethodError: no method matching iterate(::Base.Iterators.Reverse{Tangent{Tuple{Float64, Float64}, Tuple{Float64, Float64}}}) on 1.6 +_safereverse(x) = VERSION > v"1.7" ? Iterators.reverse(x) : reverse(x) + +function unzip_map_reversed(f::F, args::Tuple...) where {F} + len1 = length(first(args)) + all(a -> length(a)==len1, args) || error("unzip_map_reversed does not handle zip-like behaviour.") + return map(reverse, unzip(map(f, map(reverse, args)...))) +end + +""" + reverse!!(x) + +Reverses `x` in-place if possible, according to `ChainRulesCore.is_inplaceable_destination`. +Only safe if you are quite sure nothing else closes over `x`. +""" +function reverse!!(x::AbstractArray) + if ChainRulesCore.is_inplaceable_destination(x) + Base.reverse!(x) + else + Base.reverse(x) + end +end +reverse!!(x::AbstractArray{<:AbstractZero}) = x +reverse!!(x) = reverse(x) + +frule((_, xdot), ::typeof(reverse!!), x) = reverse!!(x), reverse!!(xdot) + +function rrule(::typeof(reverse!!), x) + reverse!!_back(dy) = (NoTangent(), reverse(unthunk(dy))) + return reverse!!(x), reverse!!_back +end ##### @@ -107,10 +174,16 @@ end Expr(:tuple, each...) end -unzip(xs::AbstractArray{Tuple{T}}) where {T} = (reinterpret(T, xs),) # best case, no copy +function unzip(xs::AbstractArray{Tuple{T}}) where {T} + if isbitstype(T) + (reinterpret(T, xs),) # best case, no copy + else + (map(only, xs),) + end +end @generated function unzip(xs::AbstractArray{Ts}) where {Ts<:Tuple} - each = if count(!Base.issingletontype, Ts.parameters) < 2 + each = if count(!Base.issingletontype, Ts.parameters) < 2 && all(isbitstype, Ts.parameters) # good case, no copy of data, some trivial arrays [Base.issingletontype(T) ? :(similar(xs, $T)) : :(reinterpret($T, xs)) for T in Ts.parameters] else diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 25c755f55..6aa2567b3 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -257,4 +257,40 @@ end test_rrule(map, Multiplier(4.5), (6.7, 8.9), (0.1, 0.2, 0.3), check_inferred=false) end end + + @testset "map(f, ::Array)" begin + test_rrule(map, identity, [1.0, 2.0], check_inferred=false) + test_rrule(map, conj, [1, 2+im, 3.0]', check_inferred=false) + test_rrule(map, make_two_vec, [4.0, 5.0 + 6im], check_inferred=false) + # @test rrule(CFG, map, make_two_vec, [4.0, 5.0 + 6im])[2]([1:2, 3:4])[3] ≈ [1 + 2im, 3 + 4im] # FiniteDifferences DimensionMismatch + + @test_skip test_rrule(map, Multiplier(rand() + im), rand(3), check_inferred=false) + rrule(CFG, map, Multiplier(2.0), [3, 4, 5.0])[2]([10, 20, 30]) # (NoTangent(), Multiplier{Float64}(259.99999), [19.99999, 40.000, 60.000]) -- WTF? + @test_skip test_rrule(map, Multiplier(rand() + im) ⊢ NoTangent(), rand(3), check_inferred=false) # Expression: ad_cotangent isa NoTangent Evaluated: Multiplier{ComplexF64}(-3.7869064372333963 + 2.046139872866103im) isa NoTangent + + y1, bk1 = rrule(CFG, map, abs2, [1.0, 2.0, 3.0]) + @test y1 == [1, 4, 9] + @test bk1([4, 5, 6.0])[3] ≈ 2 .* (1:3) .* (4:6) + + y2, bk2 = rrule(CFG, map, Counter(), [11, 12, 13.0]) + @test y2 == map(Counter(), 11:13) + @test_skip bk2(ones(3))[3] == [93, 83, 73] # FiniteDifferences has incremented the counter very high + end + + @testset "map(f, ::Array, ::Array)" begin + test_rrule(map, +, [1.0, 2.0], [3.0, 4.0], check_inferred=false) # NoTangent does not match Union{NoTangent, ZeroTangent} + test_rrule(map, atan, [1 2; 3.0 4.0], [4 5; 6 7.0], check_inferred=false) # same shape => just broadcast + + test_rrule(map, /, [1.0, 2.0], [3.0, 4.0, 5.0], check_inferred=false) + test_rrule(map, atan, [1, 2, 3.0], [4 5; 6 7.0], check_inferred=false) + + test_rrule(map, Multiplier(rand()), rand(3), rand(4), check_inferred=false) + + cnt3 = Counter() + y3, bk3 = rrule(CFG, map, cnt3, [1, 2, 3.0], [0, -1, -2, -33.3]) + @test y3 == 1:3 + @test cnt3 == Counter(3) + z3 = bk3([1, 1, 1000]) + @test z3[3] == [53, 33, 13000] + end end diff --git a/test/rulesets/Base/iterators.jl b/test/rulesets/Base/iterators.jl new file mode 100644 index 000000000..8e22c8ba8 --- /dev/null +++ b/test/rulesets/Base/iterators.jl @@ -0,0 +1,26 @@ + +@testset "Comprehension" begin + @testset "simple" begin + y1, bk1 = rrule(CFG, collect, (i^2 for i in [1.0, 2.0, 3.0])) + @test y1 == [1,4,9] + t1 = bk1(4:6)[2] + @test t1 isa Tangent{<:Base.Generator} + @test t1.f == NoTangent() + @test t1.iter ≈ 2 .* (1:3) .* (4:6) + + y2, bk2 = rrule(CFG, collect, Iterators.map(Counter(), [11, 12, 13.0])) + @test y2 == map(Counter(), 11:13) + @test bk2(ones(3))[2].iter == [33, 23, 13] + end +end + +@testset "Iterators" begin + @testset "zip" begin + test_rrule(collect∘zip, rand(3), rand(3)) + test_rrule(collect∘zip, rand(2,2), rand(2,2), rand(2,2)) + test_rrule(collect∘zip, rand(4), rand(2,2)) + + test_rrule(collect∘zip, rand(3), rand(5)) + test_rrule(collect∘zip, rand(3,2), rand(5)) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 768f7c208..d2ffc34f3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -63,6 +63,7 @@ end include_test("rulesets/Base/mapreduce.jl") include_test("rulesets/Base/sort.jl") include_test("rulesets/Base/broadcast.jl") + include_test("rulesets/Base/iterators.jl") include_test("unzipped.jl") # used primarily for broadcast diff --git a/test/unzipped.jl b/test/unzipped.jl index 4215f3a6e..d705decb5 100644 --- a/test/unzipped.jl +++ b/test/unzipped.jl @@ -1,20 +1,21 @@ -using ChainRules: unzip_broadcast, unzip #, unzip_map +using ChainRules: unzip_broadcast, unzip, unzip_map, unzip_map_reversed @testset "unzipped.jl" begin - @testset "basics: $(sprint(show, fun))" for fun in [unzip_broadcast, unzip∘map, unzip∘broadcast] # unzip_map, + @testset "basics: $(sprint(show, fun))" for fun in [unzip_broadcast, unzip∘map, unzip∘broadcast, unzip_map] @test_throws Exception fun(sqrt, 1:3) - @test fun(tuple, 1:3, 4:6) == ([1, 2, 3], [4, 5, 6]) - @test fun(tuple, [1, 10, 100]) == ([1, 10, 100],) - @test fun(tuple, 1:3, fill(nothing, 3)) == (1:3, fill(nothing, 3)) - @test fun(tuple, [1, 10, 100], fill(nothing, 3)) == ([1, 10, 100], fill(nothing, 3)) - @test fun(tuple, fill(nothing, 3), fill(nothing, 3)) == (fill(nothing, 3), fill(nothing, 3)) + @test @inferred(fun(tuple, 1:3, 4:6)) == ([1, 2, 3], [4, 5, 6]) + @test @inferred(fun(tuple, [1, 10, 100])) == ([1, 10, 100],) + @test @inferred(fun(tuple, 1:3, fill(nothing, 3))) == (1:3, fill(nothing, 3)) + @test @inferred(fun(tuple, [1, 10, 100], fill(nothing, 3))) == ([1, 10, 100], fill(nothing, 3)) + @test @inferred(fun(tuple, fill(nothing, 3), fill(nothing, 3))) == (fill(nothing, 3), fill(nothing, 3)) if contains(string(fun), "map") - @test fun(tuple, 1:3, 4:999) == ([1, 2, 3], [4, 5, 6]) + @test @inferred(fun(tuple, 1:3, 4:999)) == ([1, 2, 3], [4, 5, 6]) else - @test fun(tuple, [1,2,3], [4 5]) == ([1 1; 2 2; 3 3], [4 5; 4 5; 4 5]) + @test @inferred(fun(tuple, [1,2,3], [4 5])) == ([1 1; 2 2; 3 3], [4 5; 4 5; 4 5]) + @test @inferred(fun(tuple, [1,2,3], 6)) == ([1, 2, 3], [6, 6, 6]) end if contains(string(fun), "map") @@ -24,7 +25,34 @@ using ChainRules: unzip_broadcast, unzip #, unzip_map @test fun(tuple, (1,2,3), (7,)) == ((1, 2, 3), (7, 7, 7)) @test fun(tuple, (1,2,3), 8) == ((1, 2, 3), (8, 8, 8)) end - @test fun(tuple, (1,2,3), [4,5,6]) == ([1, 2, 3], [4, 5, 6]) # mix tuple & vector + @test @inferred(fun(tuple, (1,2,3), [4,5,6])) == ([1, 2, 3], [4, 5, 6]) # mix tuple & vector + end + + @testset "zip behaviour: $unzip_map" for unzip_map in [unzip_map, unzip_map_reversed] + check(f, args...) = @inferred(unzip_map(f, args...)) == unzip(map(f, args...)) + check_no_inferr(f, args...) = unzip_map(f, args...) == unzip(map(f, args...)) + + @test check(tuple, [1 2; 3 4], [5,6,7,8]) # makes a vector + @test check_no_inferr(tuple, [1,2,3], (5,6,7)) + + unzip_map == unzip_map_reversed && continue # does not handle unequal lengths. + + @test check(tuple, [1 2; 3 4], [5,6,7]) + @test check(tuple, [1 2; 3 4], [5,6,7,8,9,10]) + + @test check_no_inferr(tuple, [1,2,3], (5,6,7,8)) + @test check_no_inferr(tuple, [1,2,3,4], (5,6,7)) + @test check_no_inferr(tuple, [1 2;3 4], (5,6,7)) + end + + @testset "unzip_map_reversed" begin + cnt(x, y) = (x, y) .+ (CNT[] += 1) + CNT = Ref(0) + @test unzip_map_reversed(cnt, [10, 20], [30, 40]) == ([12, 21], [32, 41]) + @test CNT[] == 2 + + CNT = Ref(0) + @test unzip_map_reversed(cnt, (10, 20), (30, 40)) == ((12, 21), (32, 41)) end @testset "rrules" begin @@ -58,6 +86,10 @@ using ChainRules: unzip_broadcast, unzip #, unzip_map @test unzip([(1,), (3,), (5,)])[1] isa Base.ReinterpretArray @test unzip(((1,2), (3,4), (5,6))) == ((1, 3, 5), (2, 4, 6)) + + # Bug: these cases cannot be done by reinterpret + @test unzip([([1,2],), ([3,4],)]) == ([[1, 2], [3, 4]],) + @test unzip([(nothing, [1,2]), (nothing, [3,4])]) == ([nothing, nothing], [[1, 2], [3, 4]]) # test_rrule(unzip, [(1,2), (3,4), (5.0,6.0)], check_inferred=false) # DimensionMismatch: second dimension of A, 6, does not match length of x, 2