From 16e03e650fabea9bddf1c28c1f305df62f837df6 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 19 Oct 2022 21:09:02 -0400 Subject: [PATCH] fixup many unzip things --- src/unzipped.jl | 57 ++++++++++++++++++++++++------------------------ test/unzipped.jl | 22 +++++++++++++++---- 2 files changed, 46 insertions(+), 33 deletions(-) diff --git a/src/unzipped.jl b/src/unzipped.jl index 6fcc7eecf..34b8d8d8f 100644 --- a/src/unzipped.jl +++ b/src/unzipped.jl @@ -66,12 +66,8 @@ end ##### map ##### -# `unzip_map` can use `StructArrays.components(StructArray(Iterators.map(f, args...)))`, -# will be useful for the gradient of `map` etc. - - """ -unzip_map(f, args...) + unzip_map(f, args...) For a function `f` which returns a tuple, this is `== unzip(map(f, args...))`, but performed using `StructArrays` for efficiency. @@ -86,9 +82,17 @@ function unzip_map(f::F, args...) where {F} end unzip_map(f::F, args::Tuple...) where {F} = unzip(map(f, args...)) +# unzip_map(f::F, args::NamedTuple...) where {F} = unzip(map(f, args...)) unzip_map(f::F, args::AbstractGPUArray...) where {F} = unzip(map(f, args...)) +""" + unzip_map_reversed(f, args...) + +For a pure function `f` which returns a tuple, this is `== unzip(map(f, args...))`. +But the order of evaluation is should be the reverse. +Does NOT handle `zip`-like behaviour. +""" function unzip_map_reversed(f::F, args...) where {F} T = Broadcast.combine_eltypes(f, args) if isconcretetype(T) @@ -96,30 +100,18 @@ function unzip_map_reversed(f::F, args...) where {F} but f = $(sprint(show, f)) returns type T = $T""")) end len1 = length(first(args)) - if all(a -> length(a)==len1, args) - rev_args = map(Iterators.reverse, args) - outs = StructArrays.components(StructArray(Iterators.map(f, rev_args...))) - else - len = minimum(length, args) - rev_args = map(a -> Iterators.reverse(@view a[begin:begin+len-1]), args) - outs = StructArrays.components(StructArray(Iterators.map(f, rev_args...))) - end - return map(reverse!!, outs) + all(a -> length(a)==len1, args) || error("unzip_map_reversed does not handle zip-like behaviour.") + return map(reverse!!, unzip_map(f, map(_safereverse, args)...)) end +# This avoids MethodError: no method matching iterate(::Base.Iterators.Reverse{Tangent{Tuple{Float64, Float64}, Tuple{Float64, Float64}}}) on 1.6 +_safereverse(x) = VERSION > v"1.7" ? Iterators.reverse(x) : reverse(x) + function unzip_map_reversed(f::F, args::Tuple...) where {F} - len = minimum(length, args) - rev_args = map(a -> reverse(a[1:len]), args) - # vlen = Val(len) - # rev_args = map(args) do a - # reverse(ntuple(i -> a[i], vlen)) # does not infer better - # end - return map(reverse, unzip(map(f, rev_args...))) + len1 = length(first(args)) + all(a -> length(a)==len1, args) || error("unzip_map_reversed does not handle zip-like behaviour.") + return map(reverse, unzip(map(f, map(reverse, args)...))) end -# function unzip_map_reversed(f::F, args::Tuple{Vararg{Any, N}}...) where {F,N} -# rev_args = map(reverse, args) -# return map(reverse, unzip(map(f, rev_args...))) -# end """ reverse!!(x) @@ -135,10 +127,11 @@ function reverse!!(x::AbstractArray) end end reverse!!(x::AbstractArray{<:AbstractZero}) = x +reverse!!(x) = reverse(x) -frule((_, xdot), ::typeof(reverse!!), x::AbstractArray) = reverse!!(x), reverse!!(xdot) +frule((_, xdot), ::typeof(reverse!!), x) = reverse!!(x), reverse!!(xdot) -function rrule(::typeof(reverse!!), x::AbstractArray) +function rrule(::typeof(reverse!!), x) reverse!!_back(dy) = (NoTangent(), reverse(unthunk(dy))) return reverse!!(x), reverse!!_back end @@ -181,10 +174,16 @@ end Expr(:tuple, each...) end -unzip(xs::AbstractArray{Tuple{T}}) where {T} = (reinterpret(T, xs),) # best case, no copy +function unzip(xs::AbstractArray{Tuple{T}}) where {T} + if isbitstype(T) + (reinterpret(T, xs),) # best case, no copy + else + (map(only, xs),) + end +end @generated function unzip(xs::AbstractArray{Ts}) where {Ts<:Tuple} - each = if count(!Base.issingletontype, Ts.parameters) < 2 + each = if count(!Base.issingletontype, Ts.parameters) < 2 && all(isbitstype, Ts.parameters) # good case, no copy of data, some trivial arrays [Base.issingletontype(T) ? :(similar(xs, $T)) : :(reinterpret($T, xs)) for T in Ts.parameters] else diff --git a/test/unzipped.jl b/test/unzipped.jl index 1677d7c9d..5d84b652d 100644 --- a/test/unzipped.jl +++ b/test/unzipped.jl @@ -2,7 +2,7 @@ using ChainRules: unzip_broadcast, unzip, unzip_map, unzip_map_reversed @testset "unzipped.jl" begin - @testset "basics: $(sprint(show, fun))" for fun in [unzip_broadcast, unzip∘map, unzip∘broadcast, unzip_map, unzip_map_reversed] + @testset "basics: $(sprint(show, fun))" for fun in [unzip_broadcast, unzip∘map, unzip∘broadcast, unzip_map] @test_throws Exception fun(sqrt, 1:3) @test @inferred(fun(tuple, 1:3, 4:6)) == ([1, 2, 3], [4, 5, 6]) @@ -27,22 +27,32 @@ using ChainRules: unzip_broadcast, unzip, unzip_map, unzip_map_reversed end @test @inferred(fun(tuple, (1,2,3), [4,5,6])) == ([1, 2, 3], [4, 5, 6]) # mix tuple & vector end - + @testset "zip behaviour: $unzip_map" for unzip_map in [unzip_map, unzip_map_reversed] check(f, args...) = @inferred(unzip_map(f, args...)) == unzip(map(f, args...)) + check_no_inferr(f, args...) = unzip_map(f, args...) == unzip(map(f, args...)) + @test check(tuple, [1 2; 3 4], [5,6,7,8]) # makes a vector + @test check_no_inferr(tuple, [1,2,3], (5,6,7)) + + unzip_map == unzip_map_reversed && continue # does not handle unequal lengths. + @test check(tuple, [1 2; 3 4], [5,6,7]) @test check(tuple, [1 2; 3 4], [5,6,7,8,9,10]) + + @test check_no_inferr(tuple, [1,2,3], (5,6,7,8)) + @test check_no_inferr(tuple, [1,2,3,4], (5,6,7)) + @test check_no_inferr(tuple, [1 2;3 4], (5,6,7)) end @testset "unzip_map_reversed" begin cnt(x, y) = (x, y) .+ (CNT[] += 1) CNT = Ref(0) - @test unzip_map_reversed(cnt, [10, 20], [30, 40, 50]) == ([12, 21], [32, 41]) + @test unzip_map_reversed(cnt, [10, 20], [30, 40]) == ([12, 21], [32, 41]) @test CNT[] == 2 CNT = Ref(0) - @test unzip_map_reversed(cnt, (10, 20, 99), (30, 40)) == ((12, 21), (32, 41)) + @test unzip_map_reversed(cnt, (10, 20), (30, 40)) == ((12, 21), (32, 41)) end @testset "rrules" begin @@ -76,6 +86,10 @@ using ChainRules: unzip_broadcast, unzip, unzip_map, unzip_map_reversed @test unzip([(1,), (3,), (5,)])[1] isa Base.ReinterpretArray @test unzip(((1,2), (3,4), (5,6))) == ((1, 3, 5), (2, 4, 6)) + + # Bug: these cases cannot be done by reinterpret + @test unzip([([1,2],), ([3,4],)]) == ([[1, 2], [3, 4]],) + @test unzip([(nothing, [1,2]), (nothing, [3,4])]) == ([nothing, nothing], [[1, 2], [3, 4]]) # test_rrule(unzip, [(1,2), (3,4), (5.0,6.0)], check_inferred=false) # DimensionMismatch: second dimension of A, 6, does not match length of x, 2