Skip to content

Commit

Permalink
improvements to stack
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Oct 15, 2022
1 parent e247fb5 commit 9d5dde1
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 5 deletions.
35 changes: 30 additions & 5 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Return `x` reshaped into an array one dimensionality higher than `x`,
where `dims` indicates in which dimension `x` is extended.
`dims` can be an integer between 1 and `ndims(x)+1`.
See also [`flatten`](@ref), [`stack`](@ref).
Expand Down Expand Up @@ -33,8 +34,9 @@ julia> unsqueeze(xs, dims=1)
[1, 2] [3, 4] [5, 6]
```
"""
function unsqueeze(x::AbstractArray; dims::Int)
sz = ntuple(i -> i < dims ? size(x, i) : i == dims ? 1 : size(x, i - 1), ndims(x) + 1)
function unsqueeze(x::AbstractArray{T,N}; dims::Int) where {T, N}
# @assert 1 <= dims <= N + 1
sz = ntuple(i -> i < dims ? size(x, i) : i == dims ? 1 : size(x, i - 1), N + 1)
return reshape(x, sz)
end

Expand All @@ -59,9 +61,11 @@ Base.show_function(io::IO, u::Base.Fix2{typeof(_unsqueeze)}, ::Bool) = print(io,
stack(xs; dims)
Concatenate the given array of arrays `xs` into a single array along the
given dimension `dims`.
given dimension `dims`. All arrays need to be of the same size.
The number of dimension in the final arrays is one more than the number
of dimensions in the input arrays.
See also [`stack`](@ref) and [`batch`](@ref).
See also [`unsqueeze`](@ref), [`unstack`](@ref) and [`batch`](@ref).
# Examples
Expand Down Expand Up @@ -98,7 +102,28 @@ julia> stack(xs, dims=3)
6
```
"""
stack(xs; dims::Int) = cat(unsqueeze.(xs; dims)...; dims)
function stack(xs; dims::Int)
N = ndims(xs[1])
if dims <= N
vs = unsqueeze.(xs; dims)
else
vs = xs
end
if dims == 1
return reduce(vcat, vs)
elseif dims === 2
return reduce(hcat, vs)
else
return reduce((x, y) -> cat(x, y; dims=dims), vs)
end
end

function rrule(::typeof(stack), xs; dims::Int)
function stack_pullback(Δ)
return (NoTangent(), unstack(unthunk(Δ); dims=dims))
end
return stack(xs; dims=dims), stack_pullback
end

"""
unstack(xs; dims)
Expand Down
13 changes: 13 additions & 0 deletions test/test_utils.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,17 @@

"""
Test gradients through zygote.
# Arguments
- `f`: function to test
- `xs`: inputs to `f`
# Keyword Arguments
Keyword arguments are passed to `rrule`.
- `fkwargs`: keyword arguments to `f`
"""
function test_zygote(f, xs...; kws...)
config = ZygoteRuleConfig()
test_rrule(config, f, xs...; kws..., rrule_f = rrule_via_ad)
Expand Down
12 changes: 12 additions & 0 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
@test @inferred(unsqueeze(x; dims=4)) == reshape(x, 2, 3, 2, 1)

@test unsqueeze(dims=2)(x) == unsqueeze(x, dims=2)

@test_throws AssertionError unsqueeze(rand(2,2), dims=4)
end

@testset "stack and unstack" begin
Expand All @@ -19,6 +21,16 @@ end
@test unstack(stacked_array, dims=2) == unstacked_array
@test stack(unstacked_array, dims=2) == stacked_array
@test stack(unstack(stacked_array, dims=1), dims=1) == stacked_array

for d in (1,2,3)
test_zygote(stack, [x,2x], fkwargs=(; dims=d), check_inferred=false)
end

# Issue #121
a = [[1] for i in 1:10000]
@test size(stack(a, dims=1)) == (10000, 1)
@test size(stack(a, dims=2)) == (1, 10000)
@test size(stack(a, dims=3)) == (1, 1, 10000)
end

@testset "batch and unbatch" begin
Expand Down

0 comments on commit 9d5dde1

Please sign in to comment.