diff --git a/src/utils.jl b/src/utils.jl index 81d3147..915ef4d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -35,7 +35,7 @@ julia> unsqueeze(xs, dims=1) ``` """ function unsqueeze(x::AbstractArray{T,N}; dims::Int) where {T, N} - # @assert 1 <= dims <= N + 1 + @assert 1 <= dims <= N + 1 sz = ntuple(i -> i < dims ? size(x, i) : i == dims ? 1 : size(x, i - 1), N + 1) return reshape(x, sz) end @@ -61,9 +61,7 @@ Base.show_function(io::IO, u::Base.Fix2{typeof(_unsqueeze)}, ::Bool) = print(io, stack(xs; dims) Concatenate the given array of arrays `xs` into a single array along the -given dimension `dims`. All arrays need to be of the same size. -The number of dimension in the final arrays is one more than the number -of dimensions in the input arrays. +new dimension `dims`. All arrays need to be of the same size. See also [`unsqueeze`](@ref), [`unstack`](@ref) and [`batch`](@ref).