Skip to content

Commit

Permalink
fixup many unzip things
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Apr 25, 2023
1 parent 5f87eb9 commit 16e03e6
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 33 deletions.
57 changes: 28 additions & 29 deletions src/unzipped.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,8 @@ end
##### map
#####

# `unzip_map` can use `StructArrays.components(StructArray(Iterators.map(f, args...)))`,
# will be useful for the gradient of `map` etc.


"""
unzip_map(f, args...)
unzip_map(f, args...)
For a function `f` which returns a tuple, this is `== unzip(map(f, args...))`,
but performed using `StructArrays` for efficiency.
Expand All @@ -86,40 +82,36 @@ function unzip_map(f::F, args...) where {F}
end

unzip_map(f::F, args::Tuple...) where {F} = unzip(map(f, args...))
# unzip_map(f::F, args::NamedTuple...) where {F} = unzip(map(f, args...))

unzip_map(f::F, args::AbstractGPUArray...) where {F} = unzip(map(f, args...))

"""
unzip_map_reversed(f, args...)
For a pure function `f` which returns a tuple, this is `== unzip(map(f, args...))`.
But the order of evaluation is should be the reverse.
Does NOT handle `zip`-like behaviour.
"""
function unzip_map_reversed(f::F, args...) where {F}
T = Broadcast.combine_eltypes(f, args)
if isconcretetype(T)
T <: Tuple || throw(ArgumentError("""unzip_map_reversed(f, args) only works on functions returning a tuple,
but f = $(sprint(show, f)) returns type T = $T"""))
end
len1 = length(first(args))
if all(a -> length(a)==len1, args)
rev_args = map(Iterators.reverse, args)
outs = StructArrays.components(StructArray(Iterators.map(f, rev_args...)))
else
len = minimum(length, args)
rev_args = map(a -> Iterators.reverse(@view a[begin:begin+len-1]), args)
outs = StructArrays.components(StructArray(Iterators.map(f, rev_args...)))
end
return map(reverse!!, outs)
all(a -> length(a)==len1, args) || error("unzip_map_reversed does not handle zip-like behaviour.")
return map(reverse!!, unzip_map(f, map(_safereverse, args)...))
end

# This avoids MethodError: no method matching iterate(::Base.Iterators.Reverse{Tangent{Tuple{Float64, Float64}, Tuple{Float64, Float64}}}) on 1.6
_safereverse(x) = VERSION > v"1.7" ? Iterators.reverse(x) : reverse(x)

function unzip_map_reversed(f::F, args::Tuple...) where {F}
len = minimum(length, args)
rev_args = map(a -> reverse(a[1:len]), args)
# vlen = Val(len)
# rev_args = map(args) do a
# reverse(ntuple(i -> a[i], vlen)) # does not infer better
# end
return map(reverse, unzip(map(f, rev_args...)))
len1 = length(first(args))
all(a -> length(a)==len1, args) || error("unzip_map_reversed does not handle zip-like behaviour.")
return map(reverse, unzip(map(f, map(reverse, args)...)))
end
# function unzip_map_reversed(f::F, args::Tuple{Vararg{Any, N}}...) where {F,N}
# rev_args = map(reverse, args)
# return map(reverse, unzip(map(f, rev_args...)))
# end

"""
reverse!!(x)
Expand All @@ -135,10 +127,11 @@ function reverse!!(x::AbstractArray)
end
end
reverse!!(x::AbstractArray{<:AbstractZero}) = x
reverse!!(x) = reverse(x)

frule((_, xdot), ::typeof(reverse!!), x::AbstractArray) = reverse!!(x), reverse!!(xdot)
frule((_, xdot), ::typeof(reverse!!), x) = reverse!!(x), reverse!!(xdot)

function rrule(::typeof(reverse!!), x::AbstractArray)
function rrule(::typeof(reverse!!), x)
reverse!!_back(dy) = (NoTangent(), reverse(unthunk(dy)))
return reverse!!(x), reverse!!_back
end
Expand Down Expand Up @@ -181,10 +174,16 @@ end
Expr(:tuple, each...)
end

unzip(xs::AbstractArray{Tuple{T}}) where {T} = (reinterpret(T, xs),) # best case, no copy
function unzip(xs::AbstractArray{Tuple{T}}) where {T}
if isbitstype(T)
(reinterpret(T, xs),) # best case, no copy
else
(map(only, xs),)
end
end

@generated function unzip(xs::AbstractArray{Ts}) where {Ts<:Tuple}
each = if count(!Base.issingletontype, Ts.parameters) < 2
each = if count(!Base.issingletontype, Ts.parameters) < 2 && all(isbitstype, Ts.parameters)
# good case, no copy of data, some trivial arrays
[Base.issingletontype(T) ? :(similar(xs, $T)) : :(reinterpret($T, xs)) for T in Ts.parameters]
else
Expand Down
22 changes: 18 additions & 4 deletions test/unzipped.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
using ChainRules: unzip_broadcast, unzip, unzip_map, unzip_map_reversed

@testset "unzipped.jl" begin
@testset "basics: $(sprint(show, fun))" for fun in [unzip_broadcast, unzipmap, unzipbroadcast, unzip_map, unzip_map_reversed]
@testset "basics: $(sprint(show, fun))" for fun in [unzip_broadcast, unzipmap, unzipbroadcast, unzip_map]
@test_throws Exception fun(sqrt, 1:3)

@test @inferred(fun(tuple, 1:3, 4:6)) == ([1, 2, 3], [4, 5, 6])
Expand All @@ -27,22 +27,32 @@ using ChainRules: unzip_broadcast, unzip, unzip_map, unzip_map_reversed
end
@test @inferred(fun(tuple, (1,2,3), [4,5,6])) == ([1, 2, 3], [4, 5, 6]) # mix tuple & vector
end

@testset "zip behaviour: $unzip_map" for unzip_map in [unzip_map, unzip_map_reversed]
check(f, args...) = @inferred(unzip_map(f, args...)) == unzip(map(f, args...))
check_no_inferr(f, args...) = unzip_map(f, args...) == unzip(map(f, args...))

@test check(tuple, [1 2; 3 4], [5,6,7,8]) # makes a vector
@test check_no_inferr(tuple, [1,2,3], (5,6,7))

unzip_map == unzip_map_reversed && continue # does not handle unequal lengths.

@test check(tuple, [1 2; 3 4], [5,6,7])
@test check(tuple, [1 2; 3 4], [5,6,7,8,9,10])

@test check_no_inferr(tuple, [1,2,3], (5,6,7,8))
@test check_no_inferr(tuple, [1,2,3,4], (5,6,7))
@test check_no_inferr(tuple, [1 2;3 4], (5,6,7))
end

@testset "unzip_map_reversed" begin
cnt(x, y) = (x, y) .+ (CNT[] += 1)
CNT = Ref(0)
@test unzip_map_reversed(cnt, [10, 20], [30, 40, 50]) == ([12, 21], [32, 41])
@test unzip_map_reversed(cnt, [10, 20], [30, 40]) == ([12, 21], [32, 41])
@test CNT[] == 2

CNT = Ref(0)
@test unzip_map_reversed(cnt, (10, 20, 99), (30, 40)) == ((12, 21), (32, 41))
@test unzip_map_reversed(cnt, (10, 20), (30, 40)) == ((12, 21), (32, 41))
end

@testset "rrules" begin
Expand Down Expand Up @@ -76,6 +86,10 @@ using ChainRules: unzip_broadcast, unzip, unzip_map, unzip_map_reversed
@test unzip([(1,), (3,), (5,)])[1] isa Base.ReinterpretArray

@test unzip(((1,2), (3,4), (5,6))) == ((1, 3, 5), (2, 4, 6))

# Bug: these cases cannot be done by reinterpret
@test unzip([([1,2],), ([3,4],)]) == ([[1, 2], [3, 4]],)
@test unzip([(nothing, [1,2]), (nothing, [3,4])]) == ([nothing, nothing], [[1, 2], [3, 4]])

# test_rrule(unzip, [(1,2), (3,4), (5.0,6.0)], check_inferred=false) # DimensionMismatch: second dimension of A, 6, does not match length of x, 2

Expand Down

0 comments on commit 16e03e6

Please sign in to comment.