From a233febe010a72c0fb36e6534316b3ef93b173d5 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 18 Aug 2022 10:53:18 -0400 Subject: [PATCH 1/6] rules for zip, map, simple comprehension --- src/ChainRules.jl | 1 + src/rulesets/Base/base.jl | 35 +++++++++- src/rulesets/Base/iterators.jl | 124 +++++++++++++++++++++++++++++++++ src/unzipped.jl | 15 ++++ 4 files changed, 174 insertions(+), 1 deletion(-) create mode 100644 src/rulesets/Base/iterators.jl diff --git a/src/ChainRules.jl b/src/ChainRules.jl index 28e73c166..3bec70cf3 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -43,6 +43,7 @@ include("rulesets/Base/indexing.jl") include("rulesets/Base/sort.jl") include("rulesets/Base/mapreduce.jl") include("rulesets/Base/broadcast.jl") +include("rulesets/Base/iterators.jl") include("rulesets/Distributed/nondiff.jl") diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index aa9489e52..5696b265e 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -204,7 +204,7 @@ function rrule(::typeof(Base.literal_pow), ::typeof(^), x::Real, ::Val{3}) end ##### -##### `map` +##### `map(f, ::Tuple...)` ##### # Ideally reverse mode should always iterate in reverse order. For `map` and broadcasting @@ -244,6 +244,39 @@ function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(map), f::F, xs::Tu return y, map_pullback end +##### +##### `map(f, ::AbstractArray...)` +##### + +function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(map), f::F, x::AbstractArray) where {F} + y, back = rrule_via_ad(cfg, Broadcast.broadcasted, f, x) # could be broadcast, but Yota likes this one + return Broadcast.materialize(y), back +end + +# Could accept Any? +# `_unmap_pad` is also used for `zip` +function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(map), f::F, x::AbstractArray, ys::AbstractArray...) where {F} + z, backs = unzip_map((xy...) -> rrule_via_ad(cfg, f, xy...)|>tup2, x, ys...) + # z, backs = unzip(map((xy...) -> rrule_via_ad(cfg, f, xy...)|>tup2, x, ys...)) + function map_pullback(dz) + df, dxy... = unzip_map(|>, unthunk(dz), backs) + # df, dxy... = unzip(map(|>, unthunk(dz), backs)) + return (NoTangent(), ProjectTo(sum(df)), map(_unmap_pad, (x, ys...), dxy)...) + end + z, map_pullback +end + +# function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(map), f::F, x::AbstractArray, ys::AbstractArray...) where {F} +# z, zip_back = rrule(zip, x, ys...) +# m, map_back = rrule(config, map, Splat(f), z) # maybe this is inefficient? +# function map_pullback(dm) +# _, dsplatf, dz = map_back(dm) +# _, dxys... = zip_back(dz) +# return (NoTangent(), 0, dxys...) +# end +# return m, map_back +# end + ##### ##### `task_local_storage` ##### diff --git a/src/rulesets/Base/iterators.jl b/src/rulesets/Base/iterators.jl new file mode 100644 index 000000000..8c801aff2 --- /dev/null +++ b/src/rulesets/Base/iterators.jl @@ -0,0 +1,124 @@ +tup2(x) = Tuple{Any,Any}(x) # temp fix for Diffractor + +##### +##### Comprehension: Iterators.map +##### + +# Comprehension does guarantee iteration order. Thus its gradient must reverse. + +function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(collect), gen::G) where {G<:Base.Generator} + # ys, backs = unzip_map(x -> rrule_via_ad(cfg, gen.f, x)|>tup2, gen.iter) + ys, backs = unzip(map(x -> rrule_via_ad(cfg, gen.f, x)|>tup2, gen.iter)) + proj_f = ProjectTo(gen.f) + proj_iter = ProjectTo(gen.iter) + function generator_pullback(dys_raw) + dys = unthunk(dys_raw) + # dfs, dxs = unzip_map(|>, Iterators.reverse(dys), Iterators.reverse(backs)) + dfs, dxs = unzip(map(|>, Iterators.reverse(dys), Iterators.reverse(backs))) + return (NoTangent(), Tangent{G}(; f = proj_f(sum(dfs)), iter = proj_iter(reverse!!(dxs)))) + end + ys, generator_pullback +end + +""" + reverse!!(x) + +Reverses `x` in-place if possible, according to `ChainRulesCore.is_inplaceable_destination`. +Only safe if you are quite sure nothing else closes over `x`. +""" +function reverse!!(x::AbstractArray) + if ChainRulesCore.is_inplaceable_destination(x) + Base.reverse!(x) + else + Base.reverse(x) + end +end +frule((_, xdot), ::typeof(reverse!!), x::AbstractArray) = reverse!!(x), reverse!!(xdot) +function rrule(::typeof(reverse!!), x::AbstractArray) + reverse!!_back(dy) = (NoTangent(), reverse(unthunk(dy))) + return reverse!!(x), reverse!!_back +end + +# Needed for Yota, but shouldn't these be automatic? +ChainRulesCore.rrule(::Type{<:Base.Generator}, f, iter) = Base.Generator(f, iter), dy -> (NoTangent(), dy.f, dy.iter) +ChainRulesCore.rrule(::Type{<:Iterators.ProductIterator}, iters) = Iterators.ProductIterator(iters), dy -> (NoTangent(), dy.iterators) + +#= + + Yota.grad(xs -> sum(abs, [sin(x) for x in xs]), [1,2,3]pi/3) +Diffractor.gradient(xs -> sum(abs, [sin(x) for x in xs]), [1,2,3]pi/3) + + Yota.grad((xs, ys) -> sum(abs, [atan(x/y) for x in xs, y in ys]), [1,2,3]pi/3, [4,5]) +Diffractor.gradient((xs, ys) -> sum(abs, [atan(x/y) for x in xs, y in ys]), [1,2,3]pi/3, [4,5]) + + Yota.grad(xs -> sum(abs, map(sin, xs)), [1,2,3]pi/3) +Diffractor.gradient(xs -> sum(abs, map(sin, xs)), [1,2,3]pi/3) # fails internally + + Yota.grad(xs -> sum(abs, [sin(x/y) for (x,y) in zip(xs, 1:2)]), [1,2,3]pi/3) +Diffractor.gradient(xs -> sum(abs, [sin(x/y) for (x,y) in zip(xs, 1:2)]), [1,2,3]pi/3) + + Yota.grad(xs -> sum(abs, map((x,y) -> sin(x/y), xs, 1:2)), [1,2,3]pi/3) +Diffractor.gradient(xs -> sum(abs, map((x,y) -> sin(x/y), xs, 1:2)), [1,2,3]pi/3) + + +@btime Yota.grad($(rand(1000))) do xs + sum(abs2, [sqrt(x) for x in xs]) +end +# Yota min 1.134 ms, mean 1.207 ms (22017 allocations, 548.50 KiB) +# Diffractor min 936.708 μs, mean 1.020 ms (18028 allocations, 611.25 KiB) +# without unzip_map min 734.292 μs, mean 810.341 μs (13063 allocations, 517.97 KiB) + +# Zygote min 6.117 μs, mean 11.287 μs (24 allocations, 40.31 KiB) + + +@btime Yota.grad($(rand(1000)), $(rand(1000))) do xs, ys + zs = map(xs, ys) do x, y + atan(x/y) + end + sum(abs2, zs) +end +# Yota + CR: min 2.643 ms, mean 2.781 ms (35011 allocations, 915.19 KiB) +# Diffractor + CR: min 1.184 ms, mean 1.285 ms (23026 allocations, 775.09 KiB) +# without unzip_map min 947.084 μs, mean 1.036 ms (18062 allocations, 697.86 KiB) + +# Zygote: min 21.291 μs, mean 36.456 μs (26 allocations, 79.59 KiB) + + +=# + + +##### +##### `zip` +##### + + +function rrule(::typeof(zip), xs::AbstractArray...) + function zip_pullback(dy) + @debug "zip array pullback" summary(dy) + dxs = _tangent_unzip(unthunk(dy)) + return (NoTangent(), map(_unmap_pad, xs, dxs)...) + end + function zip_pullback(dy::Tangent) + @debug "zip Tangent pullback" + return (NoTangent(), dy.is...) + end + zip_pullback(z::AbstractZero) = (NoTangent(), map(Returns(z), xs)) + return zip(xs...), zip_pullback +end + +_tangent_unzip(xs::AbstractArray{Tangent{T,B}}) where {T<:Tuple, B<:Tuple} = unzip(reinterpret(B, xs)) +_tangent_unzip(xs::AbstractArray) = unzip(xs) # Diffractor + +function _unmap_pad(x::AbstractArray, dx::AbstractArray) + if length(x) == length(dx) + ProjectTo(x)(reshape(dx, axes(x))) + else + i1 = firstindex(x) + ∇getindex(x, vec(dx), i1:i1+length(dx)-1) + # dx2 = vcat(vec(dx), similar(x, ZeroTangent, length(x) - length(dx))) + # ProjectTo(x)(reshape(dx2, axes(x))) + end +end + + + diff --git a/src/unzipped.jl b/src/unzipped.jl index fe5875e6f..8da3c30fd 100644 --- a/src/unzipped.jl +++ b/src/unzipped.jl @@ -70,6 +70,21 @@ end # will be useful for the gradient of `map` etc. +""" +unzip_map(f, args...) + +For a function `f` which returns a tuple, this is `== unzip(map(f, args...))`, +but performed using `StructArrays` for efficiency. +""" +function unzip_map(f::F, args...) where {F} + T = Broadcast.combine_eltypes(f, args) + if isconcretetype(T) + T <: Tuple || throw(ArgumentError("""unzip_map(f, args) only works on functions returning a tuple, + but f = $(sprint(show, f)) returns type T = $T""")) + end + return StructArrays.components(StructArray(Iterators.map(f, args...))) +end + ##### ##### unzip ##### From 6981622a4e6c3673e340dffb9e396cb13f16f4a0 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 20 Aug 2022 12:57:26 -0400 Subject: [PATCH 2/6] day two --- src/rulesets/Base/base.jl | 30 +++++------------ src/rulesets/Base/broadcast.jl | 25 ++++++++++++-- src/rulesets/Base/iterators.jl | 41 ++++++++--------------- src/unzipped.jl | 59 +++++++++++++++++++++++++++++++++ test/rulesets/Base/base.jl | 34 +++++++++++++++++++ test/rulesets/Base/iterators.jl | 26 +++++++++++++++ test/runtests.jl | 1 + test/unzipped.jl | 38 +++++++++++++++------ 8 files changed, 193 insertions(+), 61 deletions(-) create mode 100644 test/rulesets/Base/iterators.jl diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 5696b265e..dda4d7d6d 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -249,34 +249,22 @@ end ##### function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(map), f::F, x::AbstractArray) where {F} - y, back = rrule_via_ad(cfg, Broadcast.broadcasted, f, x) # could be broadcast, but Yota likes this one - return Broadcast.materialize(y), back + # y, back = rrule_via_ad(cfg, Broadcast.broadcasted, f, x) # Yota likes this one + # return Broadcast.materialize(y), back + y, back = rrule_via_ad(cfg, broadcast, f, x) # but testing like this one + return y, back end -# Could accept Any? -# `_unmap_pad` is also used for `zip` function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(map), f::F, x::AbstractArray, ys::AbstractArray...) where {F} + @debug "rrule(map, f, arrays...)" f z, backs = unzip_map((xy...) -> rrule_via_ad(cfg, f, xy...)|>tup2, x, ys...) - # z, backs = unzip(map((xy...) -> rrule_via_ad(cfg, f, xy...)|>tup2, x, ys...)) - function map_pullback(dz) - df, dxy... = unzip_map(|>, unthunk(dz), backs) - # df, dxy... = unzip(map(|>, unthunk(dz), backs)) - return (NoTangent(), ProjectTo(sum(df)), map(_unmap_pad, (x, ys...), dxy)...) + function map_pullback_2(dz) + df, dxy... = unzip_map_reversed(|>, unthunk(dz), backs) + return (NoTangent(), ProjectTo(f)(sum(df)), map(_unmap_pad, (x, ys...), dxy)...) end - z, map_pullback + z, map_pullback_2 end -# function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(map), f::F, x::AbstractArray, ys::AbstractArray...) where {F} -# z, zip_back = rrule(zip, x, ys...) -# m, map_back = rrule(config, map, Splat(f), z) # maybe this is inefficient? -# function map_pullback(dm) -# _, dsplatf, dz = map_back(dm) -# _, dxys... = zip_back(dz) -# return (NoTangent(), 0, dxys...) -# end -# return m, map_back -# end - ##### ##### `task_local_storage` ##### diff --git a/src/rulesets/Base/broadcast.jl b/src/rulesets/Base/broadcast.jl index 4fb83c4e7..dee50394a 100644 --- a/src/rulesets/Base/broadcast.jl +++ b/src/rulesets/Base/broadcast.jl @@ -120,8 +120,27 @@ function split_bc_inner(frule_fun::R, cfg::RuleConfig, f::F, arg) where {R,F} end # Path 4: The most generic, save all the pullbacks. Can be 1000x slower. -# Since broadcast makes no guarantee about order of calls, and un-fusing -# can change the number of calls, don't bother to try to reverse the iteration. +# While broadcast makes no guarantee about order of calls, it's cheap to reverse the iteration. + +#= + +julia> Yota.grad(xs -> sum(abs2, (x -> abs(x)).(xs)), [1,2,3.0]) +┌ Debug: split broadcasting generic +│ f = #69 (generic function with 1 method) +│ N = 1 +└ @ ChainRules ~/.julia/dev/ChainRules/src/rulesets/Base/broadcast.jl:126 +(14.0, (ZeroTangent(), [2.0, 4.0, 6.0])) + +julia> ENV["JULIA_DEBUG"] = nothing + +julia> @btime Yota.grad(xs -> sum(abs2, (x -> abs(x)).(xs)), $(rand(1000))); + min 1.321 ms, mean 1.434 ms (23010 allocations, 594.66 KiB) # with unzip_map, as before + min 1.279 ms, mean 1.393 ms (23029 allocations, 595.73 KiB) # with unzip_map_reversed + +julia> @btime Yota.grad(xs -> sum(abs2, abs.(xs)), $(randn(1000))); # Debug: split broadcasting derivative + min 2.144 μs, mean 6.620 μs (6 allocations, 23.88 KiB) + +=# function split_bc_pullbacks(cfg::RCR, f::F, args::Vararg{Any,N}) where {F,N} @debug("split broadcasting generic", f, N) @@ -129,7 +148,7 @@ function split_bc_pullbacks(cfg::RCR, f::F, args::Vararg{Any,N}) where {F,N} rrule_via_ad(cfg, f, a...) end function back_generic(dys) - deltas = unzip_broadcast(backs, dys) do back, dy # (could be map, sizes match) + deltas = unzip_map_reversed(backs, unthunk(dys)) do back, dy map(unthunk, back(dy)) end dargs = map(unbroadcast, args, Base.tail(deltas)) diff --git a/src/rulesets/Base/iterators.jl b/src/rulesets/Base/iterators.jl index 8c801aff2..f264b53d0 100644 --- a/src/rulesets/Base/iterators.jl +++ b/src/rulesets/Base/iterators.jl @@ -1,4 +1,4 @@ -tup2(x) = Tuple{Any,Any}(x) # temp fix for Diffractor +tup2(x) = Tuple{Any,Any}(x) # temp fix for Diffractor, https://github.com/JuliaDiff/Diffractor.jl/pull/86 ##### ##### Comprehension: Iterators.map @@ -7,38 +7,18 @@ tup2(x) = Tuple{Any,Any}(x) # temp fix for Diffractor # Comprehension does guarantee iteration order. Thus its gradient must reverse. function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(collect), gen::G) where {G<:Base.Generator} - # ys, backs = unzip_map(x -> rrule_via_ad(cfg, gen.f, x)|>tup2, gen.iter) - ys, backs = unzip(map(x -> rrule_via_ad(cfg, gen.f, x)|>tup2, gen.iter)) + @debug "collect generator" + ys, backs = unzip_map(x -> rrule_via_ad(cfg, gen.f, x)|>tup2, gen.iter) proj_f = ProjectTo(gen.f) proj_iter = ProjectTo(gen.iter) function generator_pullback(dys_raw) dys = unthunk(dys_raw) - # dfs, dxs = unzip_map(|>, Iterators.reverse(dys), Iterators.reverse(backs)) - dfs, dxs = unzip(map(|>, Iterators.reverse(dys), Iterators.reverse(backs))) - return (NoTangent(), Tangent{G}(; f = proj_f(sum(dfs)), iter = proj_iter(reverse!!(dxs)))) + dfs, dxs = unzip_map_reversed(|>, dys, backs) + return (NoTangent(), Tangent{G}(; f = proj_f(sum(dfs)), iter = proj_iter(dxs))) end ys, generator_pullback end -""" - reverse!!(x) - -Reverses `x` in-place if possible, according to `ChainRulesCore.is_inplaceable_destination`. -Only safe if you are quite sure nothing else closes over `x`. -""" -function reverse!!(x::AbstractArray) - if ChainRulesCore.is_inplaceable_destination(x) - Base.reverse!(x) - else - Base.reverse(x) - end -end -frule((_, xdot), ::typeof(reverse!!), x::AbstractArray) = reverse!!(x), reverse!!(xdot) -function rrule(::typeof(reverse!!), x::AbstractArray) - reverse!!_back(dy) = (NoTangent(), reverse(unthunk(dy))) - return reverse!!(x), reverse!!_back -end - # Needed for Yota, but shouldn't these be automatic? ChainRulesCore.rrule(::Type{<:Base.Generator}, f, iter) = Base.Generator(f, iter), dy -> (NoTangent(), dy.f, dy.iter) ChainRulesCore.rrule(::Type{<:Iterators.ProductIterator}, iters) = Iterators.ProductIterator(iters), dy -> (NoTangent(), dy.iterators) @@ -107,12 +87,15 @@ function rrule(::typeof(zip), xs::AbstractArray...) end _tangent_unzip(xs::AbstractArray{Tangent{T,B}}) where {T<:Tuple, B<:Tuple} = unzip(reinterpret(B, xs)) -_tangent_unzip(xs::AbstractArray) = unzip(xs) # Diffractor +_tangent_unzip(xs::AbstractArray) = unzip(xs) # temp fix for Diffractor +# This is like unbroadcast, except for map's stopping-short behaviour, not broadcast's extension. +# Closing over `x` lets us re-use ∇getindex. function _unmap_pad(x::AbstractArray, dx::AbstractArray) if length(x) == length(dx) ProjectTo(x)(reshape(dx, axes(x))) else + @debug "_unmap_pad is extending gradient" length(x) == length(dx) i1 = firstindex(x) ∇getindex(x, vec(dx), i1:i1+length(dx)-1) # dx2 = vcat(vec(dx), similar(x, ZeroTangent, length(x) - length(dx))) @@ -120,5 +103,9 @@ function _unmap_pad(x::AbstractArray, dx::AbstractArray) end end - +# For testing +function rrule(::ComposedFunction{typeof(collect), typeof(zip)}, xs::AbstractArray...) + y, back = rrule(zip, xs...) + return collect(y), back +end diff --git a/src/unzipped.jl b/src/unzipped.jl index 8da3c30fd..6fcc7eecf 100644 --- a/src/unzipped.jl +++ b/src/unzipped.jl @@ -85,6 +85,65 @@ function unzip_map(f::F, args...) where {F} return StructArrays.components(StructArray(Iterators.map(f, args...))) end +unzip_map(f::F, args::Tuple...) where {F} = unzip(map(f, args...)) + +unzip_map(f::F, args::AbstractGPUArray...) where {F} = unzip(map(f, args...)) + +function unzip_map_reversed(f::F, args...) where {F} + T = Broadcast.combine_eltypes(f, args) + if isconcretetype(T) + T <: Tuple || throw(ArgumentError("""unzip_map_reversed(f, args) only works on functions returning a tuple, + but f = $(sprint(show, f)) returns type T = $T""")) + end + len1 = length(first(args)) + if all(a -> length(a)==len1, args) + rev_args = map(Iterators.reverse, args) + outs = StructArrays.components(StructArray(Iterators.map(f, rev_args...))) + else + len = minimum(length, args) + rev_args = map(a -> Iterators.reverse(@view a[begin:begin+len-1]), args) + outs = StructArrays.components(StructArray(Iterators.map(f, rev_args...))) + end + return map(reverse!!, outs) +end + +function unzip_map_reversed(f::F, args::Tuple...) where {F} + len = minimum(length, args) + rev_args = map(a -> reverse(a[1:len]), args) + # vlen = Val(len) + # rev_args = map(args) do a + # reverse(ntuple(i -> a[i], vlen)) # does not infer better + # end + return map(reverse, unzip(map(f, rev_args...))) +end +# function unzip_map_reversed(f::F, args::Tuple{Vararg{Any, N}}...) where {F,N} +# rev_args = map(reverse, args) +# return map(reverse, unzip(map(f, rev_args...))) +# end + +""" + reverse!!(x) + +Reverses `x` in-place if possible, according to `ChainRulesCore.is_inplaceable_destination`. +Only safe if you are quite sure nothing else closes over `x`. +""" +function reverse!!(x::AbstractArray) + if ChainRulesCore.is_inplaceable_destination(x) + Base.reverse!(x) + else + Base.reverse(x) + end +end +reverse!!(x::AbstractArray{<:AbstractZero}) = x + +frule((_, xdot), ::typeof(reverse!!), x::AbstractArray) = reverse!!(x), reverse!!(xdot) + +function rrule(::typeof(reverse!!), x::AbstractArray) + reverse!!_back(dy) = (NoTangent(), reverse(unthunk(dy))) + return reverse!!(x), reverse!!_back +end + + ##### ##### unzip ##### diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 36452da1e..a7c166b0f 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -229,4 +229,38 @@ test_rrule(map, Multiplier(4.5), (6.7, 8.9), (0.1, 0.2, 0.3), check_inferred=false) end end + + @testset "map(f, ::Array)" begin + test_rrule(map, identity, [1.0, 2.0], check_inferred=false) + test_rrule(map, conj, [1, 2+im, 3.0]', check_inferred=false) + test_rrule(map, make_two_vec, [4.0, 5.0 + 6im], check_inferred=false) + # @test rrule(CFG, map, make_two_vec, [4.0, 5.0 + 6im])[2]([1:2, 3:4])[3] ≈ [1 + 2im, 3 + 4im] # FiniteDifferences DimensionMismatch + + @test_skip test_rrule(map, Multiplier(rand() + im), rand(3), check_inferred=false) + rrule(CFG, map, Multiplier(2.0), [3, 4, 5.0])[2]([10, 20, 30]) # (NoTangent(), Multiplier{Float64}(259.99999), [19.99999, 40.000, 60.000]) -- WTF? + @test_skip test_rrule(map, Multiplier(rand() + im) ⊢ NoTangent(), rand(3), check_inferred=false) # Expression: ad_cotangent isa NoTangent Evaluated: Multiplier{ComplexF64}(-3.7869064372333963 + 2.046139872866103im) isa NoTangent + + y1, bk1 = rrule(CFG, map, abs2, [1.0, 2.0, 3.0]) + @test y1 == [1, 4, 9] + @test bk1([4, 5, 6.0])[3] ≈ 2 .* (1:3) .* (4:6) + + y2, bk2 = rrule(CFG, map, Counter(), [11, 12, 13.0]) + @test y2 == map(Counter(), 11:13) + @test_skip bk2(ones(3))[3] == [93, 83, 73] # FiniteDifferences has incremented the counter very high + end + + @testset "map(f, ::Array, ::Array)" begin + test_rrule(map, +, [1.0, 2.0], [3.0, 4.0], check_inferred=false) # NoTangent does not match Union{NoTangent, ZeroTangent} + test_rrule(map, /, [1.0, 2.0], [3.0, 4.0, 5.0], check_inferred=false) + test_rrule(map, atan, [1, 2, 3.0], [4 5; 6 7.0], check_inferred=false) + + test_rrule(map, Multiplier(rand()), rand(3), rand(4), check_inferred=false) + + cnt3 = Counter() + y3, bk3 = rrule(CFG, map, cnt3, [1, 2, 3.0], [0, -1, -2, -33.3]) + @test y3 == 1:3 + @test cnt3 == Counter(3) + z3 = bk3([1, 1, 1000]) + @test z3[3] == [53, 33, 13000] + end end diff --git a/test/rulesets/Base/iterators.jl b/test/rulesets/Base/iterators.jl new file mode 100644 index 000000000..d9060d985 --- /dev/null +++ b/test/rulesets/Base/iterators.jl @@ -0,0 +1,26 @@ + +@testset "Comprehension" begin + @testset "simple" begin + y1, bk1 = rrule(CFG, collect, (i^2 for i in [1.0, 2.0, 3.0])) + @test y1 == [1,4,9] + t1 = bk1(4:6)[2] + @test t1 isa Tangent{<:Base.Generator} + @test t1.f == NoTangent() + @test t1.iter ≈ 2 .* (1:3) .* (4:6) + + y2, bk2 = rrule(CFG, collect, Iterators.map(Counter(), [11, 12, 13.0])) + @test y2 == map(Counter(), 11:13) + @test bk2(ones(3))[2].iter == [93, 83, 73] + end +end + +@testset "Iterators" begin + @testset "zip" begin + test_rrule(collect∘zip, rand(3), rand(3)) + test_rrule(collect∘zip, rand(2,2), rand(2,2), rand(2,2)) + test_rrule(collect∘zip, rand(4), rand(2,2)) + + test_rrule(collect∘zip, rand(3), rand(5)) + test_rrule(collect∘zip, rand(3,2), rand(5)) + end +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index a9f25c55c..eec120d4a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -62,6 +62,7 @@ end include_test("rulesets/Base/mapreduce.jl") include_test("rulesets/Base/sort.jl") include_test("rulesets/Base/broadcast.jl") + include_test("rulesets/Base/iterators.jl") include_test("unzipped.jl") # used primarily for broadcast diff --git a/test/unzipped.jl b/test/unzipped.jl index 97aaa23f5..1677d7c9d 100644 --- a/test/unzipped.jl +++ b/test/unzipped.jl @@ -1,20 +1,21 @@ -using ChainRules: unzip_broadcast, unzip #, unzip_map +using ChainRules: unzip_broadcast, unzip, unzip_map, unzip_map_reversed @testset "unzipped.jl" begin - @testset "basics: $(sprint(show, fun))" for fun in [unzip_broadcast, unzip∘map, unzip∘broadcast] # unzip_map, + @testset "basics: $(sprint(show, fun))" for fun in [unzip_broadcast, unzip∘map, unzip∘broadcast, unzip_map, unzip_map_reversed] @test_throws Exception fun(sqrt, 1:3) - @test fun(tuple, 1:3, 4:6) == ([1, 2, 3], [4, 5, 6]) - @test fun(tuple, [1, 10, 100]) == ([1, 10, 100],) - @test fun(tuple, 1:3, fill(nothing, 3)) == (1:3, fill(nothing, 3)) - @test fun(tuple, [1, 10, 100], fill(nothing, 3)) == ([1, 10, 100], fill(nothing, 3)) - @test fun(tuple, fill(nothing, 3), fill(nothing, 3)) == (fill(nothing, 3), fill(nothing, 3)) + @test @inferred(fun(tuple, 1:3, 4:6)) == ([1, 2, 3], [4, 5, 6]) + @test @inferred(fun(tuple, [1, 10, 100])) == ([1, 10, 100],) + @test @inferred(fun(tuple, 1:3, fill(nothing, 3))) == (1:3, fill(nothing, 3)) + @test @inferred(fun(tuple, [1, 10, 100], fill(nothing, 3))) == ([1, 10, 100], fill(nothing, 3)) + @test @inferred(fun(tuple, fill(nothing, 3), fill(nothing, 3))) == (fill(nothing, 3), fill(nothing, 3)) if contains(string(fun), "map") - @test fun(tuple, 1:3, 4:999) == ([1, 2, 3], [4, 5, 6]) + @test @inferred(fun(tuple, 1:3, 4:999)) == ([1, 2, 3], [4, 5, 6]) else - @test fun(tuple, [1,2,3], [4 5]) == ([1 1; 2 2; 3 3], [4 5; 4 5; 4 5]) + @test @inferred(fun(tuple, [1,2,3], [4 5])) == ([1 1; 2 2; 3 3], [4 5; 4 5; 4 5]) + @test @inferred(fun(tuple, [1,2,3], 6)) == ([1, 2, 3], [6, 6, 6]) end if contains(string(fun), "map") @@ -24,7 +25,24 @@ using ChainRules: unzip_broadcast, unzip #, unzip_map @test fun(tuple, (1,2,3), (7,)) == ((1, 2, 3), (7, 7, 7)) @test fun(tuple, (1,2,3), 8) == ((1, 2, 3), (8, 8, 8)) end - @test fun(tuple, (1,2,3), [4,5,6]) == ([1, 2, 3], [4, 5, 6]) # mix tuple & vector + @test @inferred(fun(tuple, (1,2,3), [4,5,6])) == ([1, 2, 3], [4, 5, 6]) # mix tuple & vector + end + + @testset "zip behaviour: $unzip_map" for unzip_map in [unzip_map, unzip_map_reversed] + check(f, args...) = @inferred(unzip_map(f, args...)) == unzip(map(f, args...)) + @test check(tuple, [1 2; 3 4], [5,6,7,8]) # makes a vector + @test check(tuple, [1 2; 3 4], [5,6,7]) + @test check(tuple, [1 2; 3 4], [5,6,7,8,9,10]) + end + + @testset "unzip_map_reversed" begin + cnt(x, y) = (x, y) .+ (CNT[] += 1) + CNT = Ref(0) + @test unzip_map_reversed(cnt, [10, 20], [30, 40, 50]) == ([12, 21], [32, 41]) + @test CNT[] == 2 + + CNT = Ref(0) + @test unzip_map_reversed(cnt, (10, 20, 99), (30, 40)) == ((12, 21), (32, 41)) end @testset "rrules" begin From 0adf5bdcae6644e7d6b54a6232ff11ea0c4233c5 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 28 Aug 2022 00:19:59 -0400 Subject: [PATCH 3/6] rm tup2, update times --- src/rulesets/Base/base.jl | 2 +- src/rulesets/Base/iterators.jl | 22 +++++++++------------- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index dda4d7d6d..affcdf21f 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -257,7 +257,7 @@ end function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(map), f::F, x::AbstractArray, ys::AbstractArray...) where {F} @debug "rrule(map, f, arrays...)" f - z, backs = unzip_map((xy...) -> rrule_via_ad(cfg, f, xy...)|>tup2, x, ys...) + z, backs = unzip_map((xy...) -> rrule_via_ad(cfg, f, xy...), x, ys...) function map_pullback_2(dz) df, dxy... = unzip_map_reversed(|>, unthunk(dz), backs) return (NoTangent(), ProjectTo(f)(sum(df)), map(_unmap_pad, (x, ys...), dxy)...) diff --git a/src/rulesets/Base/iterators.jl b/src/rulesets/Base/iterators.jl index f264b53d0..1ec727a36 100644 --- a/src/rulesets/Base/iterators.jl +++ b/src/rulesets/Base/iterators.jl @@ -1,5 +1,3 @@ -tup2(x) = Tuple{Any,Any}(x) # temp fix for Diffractor, https://github.com/JuliaDiff/Diffractor.jl/pull/86 - ##### ##### Comprehension: Iterators.map ##### @@ -8,7 +6,7 @@ tup2(x) = Tuple{Any,Any}(x) # temp fix for Diffractor, https://github.com/Julia function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(collect), gen::G) where {G<:Base.Generator} @debug "collect generator" - ys, backs = unzip_map(x -> rrule_via_ad(cfg, gen.f, x)|>tup2, gen.iter) + ys, backs = unzip_map(x -> rrule_via_ad(cfg, gen.f, x), gen.iter) proj_f = ProjectTo(gen.f) proj_iter = ProjectTo(gen.iter) function generator_pullback(dys_raw) @@ -28,8 +26,8 @@ ChainRulesCore.rrule(::Type{<:Iterators.ProductIterator}, iters) = Iterators.Pro Yota.grad(xs -> sum(abs, [sin(x) for x in xs]), [1,2,3]pi/3) Diffractor.gradient(xs -> sum(abs, [sin(x) for x in xs]), [1,2,3]pi/3) - Yota.grad((xs, ys) -> sum(abs, [atan(x/y) for x in xs, y in ys]), [1,2,3]pi/3, [4,5]) -Diffractor.gradient((xs, ys) -> sum(abs, [atan(x/y) for x in xs, y in ys]), [1,2,3]pi/3, [4,5]) + Yota.grad((xs, ys) -> sum(abs, [atan(x/y) for x in xs, y in ys]), [1,2,3]pi/3, [4,5]) # ERROR: all field arrays must have same shape +Diffractor.gradient((xs, ys) -> sum(abs, [atan(x/y) for x in xs, y in ys]), [1,2,3]pi/3, [4,5]) # ERROR: type Array has no field iterators Yota.grad(xs -> sum(abs, map(sin, xs)), [1,2,3]pi/3) Diffractor.gradient(xs -> sum(abs, map(sin, xs)), [1,2,3]pi/3) # fails internally @@ -44,11 +42,10 @@ Diffractor.gradient(xs -> sum(abs, map((x,y) -> sin(x/y), xs, 1:2)), [1,2,3]pi/3 @btime Yota.grad($(rand(1000))) do xs sum(abs2, [sqrt(x) for x in xs]) end -# Yota min 1.134 ms, mean 1.207 ms (22017 allocations, 548.50 KiB) -# Diffractor min 936.708 μs, mean 1.020 ms (18028 allocations, 611.25 KiB) -# without unzip_map min 734.292 μs, mean 810.341 μs (13063 allocations, 517.97 KiB) +# Yota min 759.000 μs, mean 800.754 μs (22041 allocations, 549.62 KiB) +# Diffractor min 559.000 μs, mean 622.464 μs (18051 allocations, 612.34 KiB) -# Zygote min 6.117 μs, mean 11.287 μs (24 allocations, 40.31 KiB) +# Zygote min 3.198 μs, mean 6.849 μs (20 allocations, 40.11 KiB) @btime Yota.grad($(rand(1000)), $(rand(1000))) do xs, ys @@ -57,11 +54,10 @@ end end sum(abs2, zs) end -# Yota + CR: min 2.643 ms, mean 2.781 ms (35011 allocations, 915.19 KiB) -# Diffractor + CR: min 1.184 ms, mean 1.285 ms (23026 allocations, 775.09 KiB) -# without unzip_map min 947.084 μs, mean 1.036 ms (18062 allocations, 697.86 KiB) +# Yota + CR: min 1.598 ms, mean 1.691 ms (38030 allocations, 978.75 KiB) +# Diffractor + CR: min 767.250 μs, mean 847.640 μs (26045 allocations, 838.66 KiB) -# Zygote: min 21.291 μs, mean 36.456 μs (26 allocations, 79.59 KiB) +# Zygote: min 13.417 μs, mean 22.896 μs (26 allocations, 79.59 KiB) -- 100x faster =# From f80ef943ac1c3fd0871547875f9b09f6d4b8762b Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 18 Oct 2022 01:01:02 -0400 Subject: [PATCH 4/6] fixup, rm many comments --- src/rulesets/Base/base.jl | 10 +++++-- src/rulesets/Base/broadcast.jl | 20 -------------- src/rulesets/Base/iterators.jl | 49 --------------------------------- test/rulesets/Base/base.jl | 2 ++ test/rulesets/Base/iterators.jl | 4 +-- 5 files changed, 11 insertions(+), 74 deletions(-) diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index affcdf21f..37dcec837 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -249,13 +249,17 @@ end ##### function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(map), f::F, x::AbstractArray) where {F} - # y, back = rrule_via_ad(cfg, Broadcast.broadcasted, f, x) # Yota likes this one - # return Broadcast.materialize(y), back - y, back = rrule_via_ad(cfg, broadcast, f, x) # but testing like this one + # Here map agrees with broadcast, and we have a meta-rule with 4 different paths, should be fast: + y, back = rrule_via_ad(cfg, broadcast, f, x) return y, back end function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(map), f::F, x::AbstractArray, ys::AbstractArray...) where {F} + if all(==(size(x)), map(size, ys)) + # Here too map agrees with broadcast, maybe the test could be more elegant? + y, back = rrule_via_ad(cfg, broadcast, f, x, ys...) + return y, back + end @debug "rrule(map, f, arrays...)" f z, backs = unzip_map((xy...) -> rrule_via_ad(cfg, f, xy...), x, ys...) function map_pullback_2(dz) diff --git a/src/rulesets/Base/broadcast.jl b/src/rulesets/Base/broadcast.jl index dee50394a..edc5c041f 100644 --- a/src/rulesets/Base/broadcast.jl +++ b/src/rulesets/Base/broadcast.jl @@ -122,26 +122,6 @@ end # Path 4: The most generic, save all the pullbacks. Can be 1000x slower. # While broadcast makes no guarantee about order of calls, it's cheap to reverse the iteration. -#= - -julia> Yota.grad(xs -> sum(abs2, (x -> abs(x)).(xs)), [1,2,3.0]) -┌ Debug: split broadcasting generic -│ f = #69 (generic function with 1 method) -│ N = 1 -└ @ ChainRules ~/.julia/dev/ChainRules/src/rulesets/Base/broadcast.jl:126 -(14.0, (ZeroTangent(), [2.0, 4.0, 6.0])) - -julia> ENV["JULIA_DEBUG"] = nothing - -julia> @btime Yota.grad(xs -> sum(abs2, (x -> abs(x)).(xs)), $(rand(1000))); - min 1.321 ms, mean 1.434 ms (23010 allocations, 594.66 KiB) # with unzip_map, as before - min 1.279 ms, mean 1.393 ms (23029 allocations, 595.73 KiB) # with unzip_map_reversed - -julia> @btime Yota.grad(xs -> sum(abs2, abs.(xs)), $(randn(1000))); # Debug: split broadcasting derivative - min 2.144 μs, mean 6.620 μs (6 allocations, 23.88 KiB) - -=# - function split_bc_pullbacks(cfg::RCR, f::F, args::Vararg{Any,N}) where {F,N} @debug("split broadcasting generic", f, N) ys3, backs = unzip_broadcast(args...) do a... diff --git a/src/rulesets/Base/iterators.jl b/src/rulesets/Base/iterators.jl index 1ec727a36..63df6a3d6 100644 --- a/src/rulesets/Base/iterators.jl +++ b/src/rulesets/Base/iterators.jl @@ -17,57 +17,10 @@ function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(collect), gen::G) whe ys, generator_pullback end -# Needed for Yota, but shouldn't these be automatic? -ChainRulesCore.rrule(::Type{<:Base.Generator}, f, iter) = Base.Generator(f, iter), dy -> (NoTangent(), dy.f, dy.iter) -ChainRulesCore.rrule(::Type{<:Iterators.ProductIterator}, iters) = Iterators.ProductIterator(iters), dy -> (NoTangent(), dy.iterators) - -#= - - Yota.grad(xs -> sum(abs, [sin(x) for x in xs]), [1,2,3]pi/3) -Diffractor.gradient(xs -> sum(abs, [sin(x) for x in xs]), [1,2,3]pi/3) - - Yota.grad((xs, ys) -> sum(abs, [atan(x/y) for x in xs, y in ys]), [1,2,3]pi/3, [4,5]) # ERROR: all field arrays must have same shape -Diffractor.gradient((xs, ys) -> sum(abs, [atan(x/y) for x in xs, y in ys]), [1,2,3]pi/3, [4,5]) # ERROR: type Array has no field iterators - - Yota.grad(xs -> sum(abs, map(sin, xs)), [1,2,3]pi/3) -Diffractor.gradient(xs -> sum(abs, map(sin, xs)), [1,2,3]pi/3) # fails internally - - Yota.grad(xs -> sum(abs, [sin(x/y) for (x,y) in zip(xs, 1:2)]), [1,2,3]pi/3) -Diffractor.gradient(xs -> sum(abs, [sin(x/y) for (x,y) in zip(xs, 1:2)]), [1,2,3]pi/3) - - Yota.grad(xs -> sum(abs, map((x,y) -> sin(x/y), xs, 1:2)), [1,2,3]pi/3) -Diffractor.gradient(xs -> sum(abs, map((x,y) -> sin(x/y), xs, 1:2)), [1,2,3]pi/3) - - -@btime Yota.grad($(rand(1000))) do xs - sum(abs2, [sqrt(x) for x in xs]) -end -# Yota min 759.000 μs, mean 800.754 μs (22041 allocations, 549.62 KiB) -# Diffractor min 559.000 μs, mean 622.464 μs (18051 allocations, 612.34 KiB) - -# Zygote min 3.198 μs, mean 6.849 μs (20 allocations, 40.11 KiB) - - -@btime Yota.grad($(rand(1000)), $(rand(1000))) do xs, ys - zs = map(xs, ys) do x, y - atan(x/y) - end - sum(abs2, zs) -end -# Yota + CR: min 1.598 ms, mean 1.691 ms (38030 allocations, 978.75 KiB) -# Diffractor + CR: min 767.250 μs, mean 847.640 μs (26045 allocations, 838.66 KiB) - -# Zygote: min 13.417 μs, mean 22.896 μs (26 allocations, 79.59 KiB) -- 100x faster - - -=# - - ##### ##### `zip` ##### - function rrule(::typeof(zip), xs::AbstractArray...) function zip_pullback(dy) @debug "zip array pullback" summary(dy) @@ -94,8 +47,6 @@ function _unmap_pad(x::AbstractArray, dx::AbstractArray) @debug "_unmap_pad is extending gradient" length(x) == length(dx) i1 = firstindex(x) ∇getindex(x, vec(dx), i1:i1+length(dx)-1) - # dx2 = vcat(vec(dx), similar(x, ZeroTangent, length(x) - length(dx))) - # ProjectTo(x)(reshape(dx2, axes(x))) end end diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index a7c166b0f..830a794dd 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -251,6 +251,8 @@ @testset "map(f, ::Array, ::Array)" begin test_rrule(map, +, [1.0, 2.0], [3.0, 4.0], check_inferred=false) # NoTangent does not match Union{NoTangent, ZeroTangent} + test_rrule(map, atan, [1 2; 3.0 4.0], [4 5; 6 7.0], check_inferred=false) # same shape => just broadcast + test_rrule(map, /, [1.0, 2.0], [3.0, 4.0, 5.0], check_inferred=false) test_rrule(map, atan, [1, 2, 3.0], [4 5; 6 7.0], check_inferred=false) diff --git a/test/rulesets/Base/iterators.jl b/test/rulesets/Base/iterators.jl index d9060d985..8e22c8ba8 100644 --- a/test/rulesets/Base/iterators.jl +++ b/test/rulesets/Base/iterators.jl @@ -10,7 +10,7 @@ y2, bk2 = rrule(CFG, collect, Iterators.map(Counter(), [11, 12, 13.0])) @test y2 == map(Counter(), 11:13) - @test bk2(ones(3))[2].iter == [93, 83, 73] + @test bk2(ones(3))[2].iter == [33, 23, 13] end end @@ -23,4 +23,4 @@ end test_rrule(collect∘zip, rand(3), rand(5)) test_rrule(collect∘zip, rand(3,2), rand(5)) end -end \ No newline at end of file +end From b020a2a32a3cb71fb35cf4da298331a0b7fb5312 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 19 Oct 2022 19:22:40 -0400 Subject: [PATCH 5/6] move rule to Iterators.Zip to not break Zygote --- src/rulesets/Base/iterators.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/rulesets/Base/iterators.jl b/src/rulesets/Base/iterators.jl index 63df6a3d6..355bd6ce2 100644 --- a/src/rulesets/Base/iterators.jl +++ b/src/rulesets/Base/iterators.jl @@ -21,7 +21,9 @@ end ##### `zip` ##### -function rrule(::typeof(zip), xs::AbstractArray...) +# Attaching the rule to `zip` breaks Zygote, whose rule is on `Iterators.Zip`. +# function rrule(::typeof(zip), xs::AbstractArray...) +function rrule(::Type{<:Iterators.Zip}, xs::Tuple{Vararg{AbstractArray}}) function zip_pullback(dy) @debug "zip array pullback" summary(dy) dxs = _tangent_unzip(unthunk(dy)) @@ -52,7 +54,7 @@ end # For testing function rrule(::ComposedFunction{typeof(collect), typeof(zip)}, xs::AbstractArray...) - y, back = rrule(zip, xs...) + y, back = rrule(Iterators.Zip, xs) return collect(y), back end From e781ae14a686a1b88aeb53d6814626424ebab661 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 19 Oct 2022 21:09:02 -0400 Subject: [PATCH 6/6] fixup many unzip things --- src/unzipped.jl | 57 ++++++++++++++++++++++++------------------------ test/unzipped.jl | 22 +++++++++++++++---- 2 files changed, 46 insertions(+), 33 deletions(-) diff --git a/src/unzipped.jl b/src/unzipped.jl index 6fcc7eecf..34b8d8d8f 100644 --- a/src/unzipped.jl +++ b/src/unzipped.jl @@ -66,12 +66,8 @@ end ##### map ##### -# `unzip_map` can use `StructArrays.components(StructArray(Iterators.map(f, args...)))`, -# will be useful for the gradient of `map` etc. - - """ -unzip_map(f, args...) + unzip_map(f, args...) For a function `f` which returns a tuple, this is `== unzip(map(f, args...))`, but performed using `StructArrays` for efficiency. @@ -86,9 +82,17 @@ function unzip_map(f::F, args...) where {F} end unzip_map(f::F, args::Tuple...) where {F} = unzip(map(f, args...)) +# unzip_map(f::F, args::NamedTuple...) where {F} = unzip(map(f, args...)) unzip_map(f::F, args::AbstractGPUArray...) where {F} = unzip(map(f, args...)) +""" + unzip_map_reversed(f, args...) + +For a pure function `f` which returns a tuple, this is `== unzip(map(f, args...))`. +But the order of evaluation is should be the reverse. +Does NOT handle `zip`-like behaviour. +""" function unzip_map_reversed(f::F, args...) where {F} T = Broadcast.combine_eltypes(f, args) if isconcretetype(T) @@ -96,30 +100,18 @@ function unzip_map_reversed(f::F, args...) where {F} but f = $(sprint(show, f)) returns type T = $T""")) end len1 = length(first(args)) - if all(a -> length(a)==len1, args) - rev_args = map(Iterators.reverse, args) - outs = StructArrays.components(StructArray(Iterators.map(f, rev_args...))) - else - len = minimum(length, args) - rev_args = map(a -> Iterators.reverse(@view a[begin:begin+len-1]), args) - outs = StructArrays.components(StructArray(Iterators.map(f, rev_args...))) - end - return map(reverse!!, outs) + all(a -> length(a)==len1, args) || error("unzip_map_reversed does not handle zip-like behaviour.") + return map(reverse!!, unzip_map(f, map(_safereverse, args)...)) end +# This avoids MethodError: no method matching iterate(::Base.Iterators.Reverse{Tangent{Tuple{Float64, Float64}, Tuple{Float64, Float64}}}) on 1.6 +_safereverse(x) = VERSION > v"1.7" ? Iterators.reverse(x) : reverse(x) + function unzip_map_reversed(f::F, args::Tuple...) where {F} - len = minimum(length, args) - rev_args = map(a -> reverse(a[1:len]), args) - # vlen = Val(len) - # rev_args = map(args) do a - # reverse(ntuple(i -> a[i], vlen)) # does not infer better - # end - return map(reverse, unzip(map(f, rev_args...))) + len1 = length(first(args)) + all(a -> length(a)==len1, args) || error("unzip_map_reversed does not handle zip-like behaviour.") + return map(reverse, unzip(map(f, map(reverse, args)...))) end -# function unzip_map_reversed(f::F, args::Tuple{Vararg{Any, N}}...) where {F,N} -# rev_args = map(reverse, args) -# return map(reverse, unzip(map(f, rev_args...))) -# end """ reverse!!(x) @@ -135,10 +127,11 @@ function reverse!!(x::AbstractArray) end end reverse!!(x::AbstractArray{<:AbstractZero}) = x +reverse!!(x) = reverse(x) -frule((_, xdot), ::typeof(reverse!!), x::AbstractArray) = reverse!!(x), reverse!!(xdot) +frule((_, xdot), ::typeof(reverse!!), x) = reverse!!(x), reverse!!(xdot) -function rrule(::typeof(reverse!!), x::AbstractArray) +function rrule(::typeof(reverse!!), x) reverse!!_back(dy) = (NoTangent(), reverse(unthunk(dy))) return reverse!!(x), reverse!!_back end @@ -181,10 +174,16 @@ end Expr(:tuple, each...) end -unzip(xs::AbstractArray{Tuple{T}}) where {T} = (reinterpret(T, xs),) # best case, no copy +function unzip(xs::AbstractArray{Tuple{T}}) where {T} + if isbitstype(T) + (reinterpret(T, xs),) # best case, no copy + else + (map(only, xs),) + end +end @generated function unzip(xs::AbstractArray{Ts}) where {Ts<:Tuple} - each = if count(!Base.issingletontype, Ts.parameters) < 2 + each = if count(!Base.issingletontype, Ts.parameters) < 2 && all(isbitstype, Ts.parameters) # good case, no copy of data, some trivial arrays [Base.issingletontype(T) ? :(similar(xs, $T)) : :(reinterpret($T, xs)) for T in Ts.parameters] else diff --git a/test/unzipped.jl b/test/unzipped.jl index 1677d7c9d..5d84b652d 100644 --- a/test/unzipped.jl +++ b/test/unzipped.jl @@ -2,7 +2,7 @@ using ChainRules: unzip_broadcast, unzip, unzip_map, unzip_map_reversed @testset "unzipped.jl" begin - @testset "basics: $(sprint(show, fun))" for fun in [unzip_broadcast, unzip∘map, unzip∘broadcast, unzip_map, unzip_map_reversed] + @testset "basics: $(sprint(show, fun))" for fun in [unzip_broadcast, unzip∘map, unzip∘broadcast, unzip_map] @test_throws Exception fun(sqrt, 1:3) @test @inferred(fun(tuple, 1:3, 4:6)) == ([1, 2, 3], [4, 5, 6]) @@ -27,22 +27,32 @@ using ChainRules: unzip_broadcast, unzip, unzip_map, unzip_map_reversed end @test @inferred(fun(tuple, (1,2,3), [4,5,6])) == ([1, 2, 3], [4, 5, 6]) # mix tuple & vector end - + @testset "zip behaviour: $unzip_map" for unzip_map in [unzip_map, unzip_map_reversed] check(f, args...) = @inferred(unzip_map(f, args...)) == unzip(map(f, args...)) + check_no_inferr(f, args...) = unzip_map(f, args...) == unzip(map(f, args...)) + @test check(tuple, [1 2; 3 4], [5,6,7,8]) # makes a vector + @test check_no_inferr(tuple, [1,2,3], (5,6,7)) + + unzip_map == unzip_map_reversed && continue # does not handle unequal lengths. + @test check(tuple, [1 2; 3 4], [5,6,7]) @test check(tuple, [1 2; 3 4], [5,6,7,8,9,10]) + + @test check_no_inferr(tuple, [1,2,3], (5,6,7,8)) + @test check_no_inferr(tuple, [1,2,3,4], (5,6,7)) + @test check_no_inferr(tuple, [1 2;3 4], (5,6,7)) end @testset "unzip_map_reversed" begin cnt(x, y) = (x, y) .+ (CNT[] += 1) CNT = Ref(0) - @test unzip_map_reversed(cnt, [10, 20], [30, 40, 50]) == ([12, 21], [32, 41]) + @test unzip_map_reversed(cnt, [10, 20], [30, 40]) == ([12, 21], [32, 41]) @test CNT[] == 2 CNT = Ref(0) - @test unzip_map_reversed(cnt, (10, 20, 99), (30, 40)) == ((12, 21), (32, 41)) + @test unzip_map_reversed(cnt, (10, 20), (30, 40)) == ((12, 21), (32, 41)) end @testset "rrules" begin @@ -76,6 +86,10 @@ using ChainRules: unzip_broadcast, unzip, unzip_map, unzip_map_reversed @test unzip([(1,), (3,), (5,)])[1] isa Base.ReinterpretArray @test unzip(((1,2), (3,4), (5,6))) == ((1, 3, 5), (2, 4, 6)) + + # Bug: these cases cannot be done by reinterpret + @test unzip([([1,2],), ([3,4],)]) == ([[1, 2], [3, 4]],) + @test unzip([(nothing, [1,2]), (nothing, [3,4])]) == ([nothing, nothing], [[1, 2], [3, 4]]) # test_rrule(unzip, [(1,2), (3,4), (5.0,6.0)], check_inferred=false) # DimensionMismatch: second dimension of A, 6, does not match length of x, 2