Skip to content

Commit

Permalink
use Base definition of stack
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Oct 15, 2022
1 parent ce7b203 commit 15933c1
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 71 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
name = "MLUtils"
uuid = "f1d291b0-491e-4a28-83b9-f70985020b54"
authors = ["Carlo Lucibello <[email protected]> and contributors"]
version = "0.2.11"
version = "0.3.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
FLoops = "cc61a311-1640-44b5-9fba-1b764f453329"
FoldsThreads = "9c68100b-dfe1-47cf-94c8-95104e173443"
Expand Down
7 changes: 6 additions & 1 deletion src/MLUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ import ChainRulesCore: rrule
using ChainRulesCore: @non_differentiable, unthunk, AbstractZero,
NoTangent, ZeroTangent, ProjectTo

if VERSION < v"1.9.0-DEV.1163"
import Compat: stack
else
import Base: stack
end

include("observation.jl")
export numobs,
Expand Down Expand Up @@ -66,7 +71,7 @@ export batch,
ones_like,
rand_like,
randn_like,
stack,
stack, # in Base since julia v1.9
unbatch,
unsqueeze,
unstack,
Expand Down
1 change: 0 additions & 1 deletion src/deprecations.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Deprecations v0.1
@deprecate stack(x, dims) stack(x; dims=dims)
@deprecate unstack(x, dims) unstack(x; dims=dims)
@deprecate unsqueeze(x::AbstractArray, dims::Int) unsqueeze(x; dims=dims)
@deprecate unsqueeze(dims::Int) unsqueeze(dims=dims)
Expand Down
66 changes: 0 additions & 66 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,72 +57,6 @@ _unsqueeze(x, dims) = unsqueeze(x; dims)

Base.show_function(io::IO, u::Base.Fix2{typeof(_unsqueeze)}, ::Bool) = print(io, "unsqueeze(dims=", u.x, ")")

"""
stack(xs; dims)
Concatenate the given array of arrays `xs` into a single array along the
new dimension `dims`. All arrays need to be of the same size.
See also [`unsqueeze`](@ref), [`unstack`](@ref) and [`batch`](@ref).
# Examples
```jldoctest
julia> xs = [[1, 2], [3, 4], [5, 6]]
3-element Vector{Vector{Int64}}:
[1, 2]
[3, 4]
[5, 6]
julia> stack(xs, dims=1)
3×2 Matrix{Int64}:
1 2
3 4
5 6
julia> stack(xs, dims=2)
2×3 Matrix{Int64}:
1 3 5
2 4 6
julia> stack(xs, dims=3)
2×1×3 Array{Int64, 3}:
[:, :, 1] =
1
2
[:, :, 2] =
3
4
[:, :, 3] =
5
6
```
"""
function stack(xs; dims::Int)
N = ndims(xs[1])
if dims <= N
vs = unsqueeze.(xs; dims)
else
vs = xs
end
if dims == 1
return reduce(vcat, vs)
elseif dims === 2
return reduce(hcat, vs)
else
return reduce((x, y) -> cat(x, y; dims=dims), vs)
end
end

function rrule(::typeof(stack), xs; dims::Int)
function stack_pullback(Δ)
return (NoTangent(), unstack(unthunk(Δ); dims=dims))
end
return stack(xs; dims=dims), stack_pullback
end

"""
unstack(xs; dims)
Expand Down
3 changes: 1 addition & 2 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ end
x = randn(3,3)
stacked = stack([x, x], dims=2)
@test size(stacked) == (3,2,3)
@test_broken @inferred(stack([x, x], dims=2)) == stacked
@test @inferred(stack([x, x], dims=2)) == stacked

stacked_array=[ 8 9 3 5; 9 6 6 9; 9 1 7 2; 7 4 10 6 ]
unstacked_array=[[8, 9, 9, 7], [9, 6, 1, 4], [3, 6, 7, 10], [5, 9, 2, 6]]
Expand All @@ -30,7 +30,6 @@ end
a = [[1] for i in 1:10000]
@test size(stack(a, dims=1)) == (10000, 1)
@test size(stack(a, dims=2)) == (1, 10000)
@test size(stack(a, dims=3)) == (1, 1, 10000)
end

@testset "batch and unbatch" begin
Expand Down

0 comments on commit 15933c1

Please sign in to comment.