Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: generalize pooling implementation and add LP versions
Browse files Browse the repository at this point in the history
avik-pal committed Sep 5, 2024
1 parent e47f063 commit 404969b
Showing 5 changed files with 249 additions and 265 deletions.
3 changes: 3 additions & 0 deletions docs/src/api/Lux/layers.md
Original file line number Diff line number Diff line change
@@ -35,10 +35,13 @@ VariationalHiddenDropout
## Pooling Layers

```@docs
AdaptiveLPPool
AdaptiveMaxPool
AdaptiveMeanPool
GlobalLPPool
GlobalMaxPool
GlobalMeanPool
LPPool
MaxPool
MeanPool
```
6 changes: 4 additions & 2 deletions src/Lux.jl
Original file line number Diff line number Diff line change
@@ -58,6 +58,7 @@ include("layers/basic.jl")
include("layers/containers.jl")
include("layers/normalize.jl")
include("layers/conv.jl")
include("layers/pooling.jl")
include("layers/dropout.jl")
include("layers/recurrent.jl")
include("layers/extension.jl")
@@ -87,8 +88,9 @@ include("distributed/public_api.jl")
# Layers
export Chain, Parallel, SkipConnection, PairwiseFusion, BranchLayer, Maxout, RepeatedLayer
export Bilinear, Dense, Embedding, Scale
export Conv, ConvTranspose, MaxPool, MeanPool, GlobalMaxPool, GlobalMeanPool,
AdaptiveMaxPool, AdaptiveMeanPool, Upsample, PixelShuffle
export Conv, ConvTranspose, Upsample, PixelShuffle
export MaxPool, MeanPool, LPPool, GlobalMaxPool, GlobalMeanPool, GlobalLPPool,
AdaptiveMaxPool, AdaptiveMeanPool, AdaptiveLPPool
export AlphaDropout, Dropout, VariationalHiddenDropout
export BatchNorm, GroupNorm, InstanceNorm, LayerNorm
export WeightNorm
261 changes: 0 additions & 261 deletions src/layers/conv.jl
Original file line number Diff line number Diff line change
@@ -42,15 +42,6 @@ CRC.@non_differentiable conv_transpose_dims(::Any...)

conv_transpose(x, weight, cdims) = LuxLib.Impl.∇conv_data(x, weight, cdims)

function compute_adaptive_pooling_dims(x::AbstractArray, outsize)
insize = size(x)[1:(end - 2)]
stride = insize outsize
k = insize .- (outsize .- 1) .* stride
return PoolDims(x, k; padding=0, stride=stride)
end

CRC.@non_differentiable compute_adaptive_pooling_dims(::Any, ::Any)

function init_conv_weight(
rng::AbstractRNG, init_weight::F, filter::NTuple{N, <:IntegerType},
in_chs::IntegerType, out_chs::IntegerType, groups, σ::A) where {F, N, A}
@@ -508,255 +499,3 @@ end
function PixelShuffle(r::IntegerType)
return PixelShuffle(WrappedFunction(Base.Fix2(pixel_shuffle, r)))
end

@doc doc"""
MaxPool(window::NTuple; pad=0, stride=window)
Max pooling layer, which replaces all pixels in a block of size `window` with the maximum
value.
# Arguments
- `window`: Tuple of integers specifying the size of the window. Eg, for 2D pooling
`length(window) == 2`
## Keyword Arguments
- `stride`: Should each be either single integer, or a tuple with `N` integers
- `pad`: Specifies the number of elements added to the borders of the data array. It can
be
+ a single integer for equal padding all around,
+ a tuple of `N` integers, to apply the same padding at begin/end of each spatial
dimension,
+ a tuple of `2*N` integers, for asymmetric padding, or
+ the singleton `SamePad()`, to calculate padding such that
`size(output,d) == size(x,d) / stride` (possibly rounded) for each spatial
dimension.
# Extended Help
## Inputs
- `x`: Data satisfying `ndims(x) == N + 2`, i.e. `size(x) = (I_N, ..., I_1, C, N)`
## Returns
- Output of the pooling `y` of size `(O_N, ..., O_1, C, N)` where
```math
O_i = \left\lfloor\frac{I_i + p_i + p_{(i + N) \% |p|} - d_i \times (k_i - 1)}{s_i} + 1\right\rfloor
```
- Empty `NamedTuple()`
See also [`Conv`](@ref), [`MeanPool`](@ref), [`GlobalMaxPool`](@ref),
[`AdaptiveMaxPool`](@ref)
"""
@concrete struct MaxPool <: AbstractLuxLayer
k <: Tuple{Vararg{IntegerType}}
pad <: Tuple{Vararg{IntegerType}}
stride <: Tuple{Vararg{IntegerType}}
end

function MaxPool(k::Tuple{Vararg{IntegerType}}; pad=0, stride=k)
stride = Utils.expand(Val(length(k)), stride)
pad = calc_padding(pad, k, 1, stride)
@argcheck allequal(length, (stride, k))

return MaxPool(k, pad, stride)
end

function (m::MaxPool)(x, _, st::NamedTuple)
return maxpool(x, PoolDims(x, m.k; padding=m.pad, m.stride)), st
end

function Base.show(io::IO, m::MaxPool)
print(io, "MaxPool(", m.k)
all(==(0), m.pad) || print(io, ", pad=", PrettyPrinting.tuple_string(m.pad))
m.stride == m.k || print(io, ", stride=", PrettyPrinting.tuple_string(m.stride))
print(io, ")")
end

@doc doc"""
MeanPool(window::NTuple; pad=0, stride=window)
Mean pooling layer, which replaces all pixels in a block of size `window` with the mean
value.
# Arguments
- `window`: Tuple of integers specifying the size of the window. Eg, for 2D pooling
`length(window) == 2`
## Keyword Arguments
- `stride`: Should each be either single integer, or a tuple with `N` integers
- `pad`: Specifies the number of elements added to the borders of the data array. It can
be
+ a single integer for equal padding all around,
+ a tuple of `N` integers, to apply the same padding at begin/end of each spatial
dimension,
+ a tuple of `2*N` integers, for asymmetric padding, or
+ the singleton `SamePad()`, to calculate padding such that
`size(output,d) == size(x,d) / stride` (possibly rounded) for each spatial
dimension.
# Extended Help
## Inputs
- `x`: Data satisfying `ndims(x) == N + 2`, i.e. `size(x) = (I_N, ..., I_1, C, N)`
## Returns
- Output of the pooling `y` of size `(O_N, ..., O_1, C, N)` where
```math
O_i = \left\lfloor\frac{I_i + p_i + p_{(i + N) \% |p|} - d_i \times (k_i - 1)}{s_i} + 1\right\rfloor
```
- Empty `NamedTuple()`
See also [`Conv`](@ref), [`MaxPool`](@ref), [`GlobalMeanPool`](@ref),
[`AdaptiveMeanPool`](@ref)
"""
@concrete struct MeanPool <: AbstractLuxLayer
k <: Tuple{Vararg{IntegerType}}
pad <: Tuple{Vararg{IntegerType}}
stride <: Tuple{Vararg{IntegerType}}
end

function MeanPool(k::Tuple{Vararg{IntegerType}}; pad=0, stride=k)
stride = Utils.expand(Val(length(k)), stride)
pad = calc_padding(pad, k, 1, stride)
@argcheck allequal(length, (stride, k))

return MeanPool(k, pad, stride)
end

function (m::MeanPool)(x, _, st::NamedTuple)
return meanpool(x, PoolDims(x, m.k; padding=m.pad, m.stride)), st
end

function Base.show(io::IO, m::MeanPool)
print(io, "MeanPool(", m.k)
all(==(0), m.pad) || print(io, ", pad=", PrettyPrinting.tuple_string(m.pad))
m.stride == m.k || print(io, ", stride=", PrettyPrinting.tuple_string(m.stride))
print(io, ")")
end

"""
GlobalMaxPool()
Global Max Pooling layer. Transforms (w,h,c,b)-shaped input into (1,1,c,b)-shaped output,
by performing max pooling on the complete (w,h)-shaped feature maps.
## Inputs
- `x`: Data satisfying `ndims(x) > 2`, i.e. `size(x) = (I_N, ..., I_1, C, N)`
## Returns
- Output of the pooling `y` of size `(1, ..., 1, C, N)`
- Empty `NamedTuple()`
See also [`MaxPool`](@ref), [`AdaptiveMaxPool`](@ref), [`GlobalMeanPool`](@ref)
"""
struct GlobalMaxPool <: AbstractLuxLayer end

function (g::GlobalMaxPool)(x, _, st::NamedTuple)
return maxpool(x, PoolDims(x, size(x)[1:(end - 2)])), st
end

"""
GlobalMeanPool()
Global Mean Pooling layer. Transforms (w,h,c,b)-shaped input into (1,1,c,b)-shaped output,
by performing mean pooling on the complete (w,h)-shaped feature maps.
## Inputs
- `x`: Data satisfying `ndims(x) > 2`, i.e. `size(x) = (I_N, ..., I_1, C, N)`
## Returns
- Output of the pooling `y` of size `(1, ..., 1, C, N)`
- Empty `NamedTuple()`
See also [`MeanPool`](@ref), [`AdaptiveMeanPool`](@ref), [`GlobalMaxPool`](@ref)
"""
struct GlobalMeanPool <: AbstractLuxLayer end

function (g::GlobalMeanPool)(x, _, st::NamedTuple)
return meanpool(x, PoolDims(x, size(x)[1:(end - 2)])), st
end

"""
AdaptiveMaxPool(out::NTuple)
Adaptive Max Pooling layer. Calculates the necessary window size such that its output has
`size(y)[1:N] == out`.
## Arguments
- `out`: Size of the first `N` dimensions for the output
## Inputs
- `x`: Expects as input an array with `ndims(x) == N+2`, i.e. channel and batch
dimensions, after the `N` feature dimensions, where `N = length(out)`.
## Returns
- Output of size `(out..., C, N)`
- Empty `NamedTuple()`
See also [`MaxPool`](@ref), [`AdaptiveMeanPool`](@ref).
"""
struct AdaptiveMaxPool{S, O <: Tuple{Vararg{IntegerType}}} <: AbstractLuxLayer
out::O
AdaptiveMaxPool(out) = new{length(out) + 2, typeof(out)}(out)
end

function (a::AdaptiveMaxPool{S})(x::AbstractArray{T, S}, _, st::NamedTuple) where {S, T}
return maxpool(x, compute_adaptive_pooling_dims(x, a.out)), st
end

Base.show(io::IO, a::AdaptiveMaxPool) = print(io, "AdaptiveMaxPool(", a.out, ")")

"""
AdaptiveMeanPool(out::NTuple)
Adaptive Mean Pooling layer. Calculates the necessary window size such that its output has
`size(y)[1:N] == out`.
## Arguments
- `out`: Size of the first `N` dimensions for the output
## Inputs
- `x`: Expects as input an array with `ndims(x) == N+2`, i.e. channel and batch
dimensions, after the `N` feature dimensions, where `N = length(out)`.
## Returns
- Output of size `(out..., C, N)`
- Empty `NamedTuple()`
See also [`MeanPool`](@ref), [`AdaptiveMaxPool`](@ref).
"""
struct AdaptiveMeanPool{S, O <: Tuple{Vararg{IntegerType}}} <: AbstractLuxLayer
out::O
AdaptiveMeanPool(out) = new{length(out) + 2, typeof(out)}(out)
end

function (a::AdaptiveMeanPool{S})(x::AbstractArray{T, S}, _, st::NamedTuple) where {S, T}
return meanpool(x, compute_adaptive_pooling_dims(x, a.out)), st
end

Base.show(io::IO, a::AdaptiveMeanPool) = print(io, "AdaptiveMeanPool(", a.out, ")")
4 changes: 2 additions & 2 deletions src/layers/normalize.jl
Original file line number Diff line number Diff line change
@@ -323,8 +323,8 @@ Use `Lux.testmode` during inference.
## Example
```jldoctest
julia> Chain(Dense(784 => 64), InstanceNorm(64, relu), Dense(64 => 10),
InstanceNorm(10, relu))
julia> Chain(Dense(784 => 64), InstanceNorm(64, relu; affine=true), Dense(64 => 10),
InstanceNorm(10, relu; affine=true))
Chain(
layer_1 = Dense(784 => 64), # 50_240 parameters
layer_2 = InstanceNorm(64, relu, affine=true, track_stats=false), # 128 parameters, plus 1
240 changes: 240 additions & 0 deletions src/layers/pooling.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
abstract type AbstractPoolMode end

CRC.@non_differentiable (::AbstractPoolMode)(::Any...)

@concrete struct GenericPoolMode <: AbstractPoolMode
kernel_size <: Tuple{Vararg{IntegerType}}
stride <: Tuple{Vararg{IntegerType}}
pad <: Tuple{Vararg{IntegerType}}
dilation <: Tuple{Vararg{IntegerType}}
end

(m::GenericPoolMode)(x) = PoolDims(x, m.kernel_size; padding=m.pad, m.stride, m.dilation)

struct GlobalPoolMode <: AbstractPoolMode end

(::GlobalPoolMode)(x) = PoolDims(x, size(x)[1:(end - 2)])

@concrete struct AdaptivePoolMode <: AbstractPoolMode
out_size <: Tuple{Vararg{IntegerType}}
end

function (m::AdaptivePoolMode)(x)
in_size = size(x)[1:(end - 2)]
stride = in_size m.out_size
kernel_size = in_size .- (m.out_size .- 1) .* stride
return PoolDims(x, kernel_size; padding=0, stride, dilation=1)
end

symbol_to_pool_mode(::StaticSymbol{:generic}) = GenericPoolMode
symbol_to_pool_mode(::StaticSymbol{:global}) = GlobalPoolMode
symbol_to_pool_mode(::StaticSymbol{:adaptive}) = AdaptivePoolMode

abstract type AbstractPoolOp end

struct MaxPoolOp <: AbstractPoolOp end
(m::MaxPoolOp)(x, pdims) = maxpool(x, pdims)

struct MeanPoolOp <: AbstractPoolOp end
(m::MeanPoolOp)(x, pdims) = meanpool(x, pdims)

@concrete struct LpPoolOp <: AbstractPoolOp
p
end
(m::LpPoolOp)(x, pdims) = lpnormpool(x, pdims; m.p)

symbol_to_pool_op(::StaticSymbol{:max}, _) = MaxPoolOp()
symbol_to_pool_op(::StaticSymbol{:mean}, _) = MeanPoolOp()
symbol_to_pool_op(::StaticSymbol{:lp}, p) = LpPoolOp(p)

@concrete struct PoolingLayer <: AbstractLuxLayer
mode <: AbstractPoolMode
op <: AbstractPoolOp
end

function PoolingLayer(mode::SymbolType, op::SymbolType,
arg::Union{Nothing, Tuple{Vararg{IntegerType}}}=nothing;
stride=arg, pad=0, dilation=1, p=2)
return PoolingLayer(symbol_to_pool_mode(static(mode)),
symbol_to_pool_op(static(op), p), arg; stride, pad, dilation)
end

function PoolingLayer(::Type{GenericPoolMode}, op::AbstractPoolOp,
kernel_size::Tuple{Vararg{IntegerType}}; stride=kernel_size, pad=0, dilation=1)
stride = Utils.expand(Val(length(kernel_size)), stride)
pad = calc_padding(pad, kernel_size, dilation, stride)
dilation = Utils.expand(Val(length(kernel_size)), dilation)
@argcheck allequal(length, (stride, kernel_size, dilation))

return PoolingLayer(GenericPoolMode(kernel_size, stride, pad, dilation), op)
end

function PoolingLayer(::Type{AdaptivePoolMode}, op::AbstractPoolOp,
out_size::Tuple{Vararg{IntegerType}}; kwargs...)
return PoolingLayer(AdaptivePoolMode(out_size), op)
end

function PoolingLayer(::Type{GlobalPoolMode}, op::AbstractPoolOp, ::Nothing; kwargs...)
return PoolingLayer(GlobalPoolMode(), op)
end

(m::PoolingLayer)(x, _, st::NamedTuple) = m.op(x, m.mode(x)), st

for layer_op in (:Max, :Mean, :LP)
op = Symbol(lowercase(string(layer_op)))

layer_name = Symbol(layer_op, :Pool)
extra_kwargs = layer_op == :LP ? ", p=2" : ""
layer_docstring = """
$(layer_name)(window; stride=window, pad=0, dilation=1$(extra_kwargs))
$(layer_op) Pooling layer, which replaces all pixels in a block of size `window` with
the reduction operation: $(op).
## Arguments
- `window`: Tuple of integers specifying the size of the window. Eg, for 2D pooling
`length(window) == 2`
## Keyword Arguments
- `stride`: Should each be either single integer, or a tuple with `N` integers
- `dilation`: Should each be either single integer, or a tuple with `N` integers
- `pad`: Specifies the number of elements added to the borders of the data array. It can
be
+ a single integer for equal padding all around,
+ a tuple of `N` integers, to apply the same padding at begin/end of each spatial
dimension,
+ a tuple of `2*N` integers, for asymmetric padding, or
+ the singleton `SamePad()`, to calculate padding such that
`size(output,d) == size(x,d) / stride` (possibly rounded) for each spatial
dimension.
# Extended Help
## Inputs
- `x`: Data satisfying `ndims(x) == N + 2`, i.e. `size(x) = (I_N, ..., I_1, C, N)`
## Returns
- Output of the pooling `y` of size `(O_N, ..., O_1, C, N)` where
```math
O_i = \\left\\lfloor\\frac{I_i + p_i + p_{(i + N) \\% |p|} - d_i \\times (k_i - 1)}{s_i} + 1\\right\\rfloor
```
- Empty `NamedTuple()`
"""

global_layer_name = Symbol(:Global, layer_name)
extra_kwargs = layer_op == :LP ? "; p=2" : ""
global_pooling_docstring = """
$(global_layer_name)($(extra_kwargs))
Global $(layer_op) Pooling layer. Transforms `(w, h, c, b)`-shaped input into
`(1, 1, c, b)`-shaped output, by performing mean pooling on the complete `(w, h)`-shaped
feature maps.
## Inputs
- `x`: Data satisfying `ndims(x) > 2`, i.e. `size(x) = (I_N, ..., I_1, C, N)`
## Returns
- Output of the pooling `y` of size `(1, ..., 1, C, N)`
- Empty `NamedTuple()`
"""

adaptive_layer_name = Symbol(:Adaptive, layer_name)
adaptive_pooling_docstring = """
$(adaptive_layer_name)(output_size$(extra_kwargs))
Adaptive $(layer_op) Pooling layer. Calculates the necessary window size such that
its output has `size(y)[1:N] == output_size`.
## Arguments
- `output_size`: Size of the first `N` dimensions for the output
## Inputs
- `x`: Expects as input an array with `ndims(x) == N + 2`, i.e. channel and batch
dimensions, after the `N` feature dimensions, where `N = length(output_size)`.
## Returns
- Output of size `(out..., C, N)`
- Empty `NamedTuple()`
"""

@eval begin
# Generic Pooling Layer
@doc $(layer_docstring) @concrete struct $(layer_name) <:
AbstractLuxWrapperLayer{:layer}
layer <: PoolingLayer
end

function $(layer_name)(
window::Tuple{Vararg{IntegerType}}; stride=window, pad=0, dilation=1, p=2)
return $(layer_name)(PoolingLayer(static(:generic), static($(Meta.quot(op))),
window; stride, pad, dilation, p))
end

function Base.show(io::IO, ::MIME"text/plain", m::$(layer_name))
kernel_size = m.layer.mode.kernel_size
print(io, string($(Meta.quot(layer_name))), "($(kernel_size)")
pad = m.layer.mode.pad
all(==(0), pad) || print(io, ", pad=", PrettyPrinting.tuple_string(pad))
stride = m.layer.mode.stride
stride == kernel_size ||
print(io, ", stride=", PrettyPrinting.tuple_string(stride))
dilation = m.layer.mode.dilation
all(==(1), dilation) ||
print(io, ", dilation=", PrettyPrinting.tuple_string(dilation))
if $(Meta.quot(op)) == :lp
a.layer.op.p == 2 || print(io, ", p=", a.layer.op.p)
end
print(io, ")")
end

# Global Pooling Layer
@doc $(global_pooling_docstring) @concrete struct $(global_layer_name) <:
AbstractLuxWrapperLayer{:layer}
layer <: PoolingLayer
end

function $(global_layer_name)(; p=2)
return $(global_layer_name)(PoolingLayer(static(:global), $(Meta.quot(op)); p))
end

function Base.show(io::IO, ::MIME"text/plain", g::$(global_layer_name))
print(io, string($(Meta.quot(global_layer_name))), "(")
if $(Meta.quot(op)) == :lp
a.layer.op.p == 2 || print(io, ", p=", a.layer.op.p)
end
print(io, ")")
end

# Adaptive Pooling Layer
@doc $(adaptive_pooling_docstring) @concrete struct $(adaptive_layer_name) <:
AbstractLuxWrapperLayer{:layer}
layer <: PoolingLayer
end

function $(adaptive_layer_name)(out_size::Tuple{Vararg{IntegerType}}; p=2)
return $(adaptive_layer_name)(PoolingLayer(
static(:adaptive), $(Meta.quot(op)), out_size; p))
end

function Base.show(io::IO, ::MIME"text/plain", a::$(adaptive_layer_name))
print(io, string($(Meta.quot(adaptive_layer_name))), "(", a.layer.mode.out_size)
if $(Meta.quot(op)) == :lp
a.layer.op.p == 2 || print(io, ", p=", a.layer.op.p)
end
print(io, ")")
end
end
end

0 comments on commit 404969b

Please sign in to comment.