squash PR 1407, eleven commits, 2020
Michael Abbott committed Jan 22, 2021
1 parent 5483a12 commit fa93442
Showing 8 changed files with 99 additions and 179 deletions.
1 change: 0 additions & 1 deletion src/Flux.jl
@@ -33,7 +33,6 @@ using CUDA
const use_cuda = Ref(false)

include("utils.jl")
include("zeros.jl")
include("onehot.jl")
include("functor.jl")

17 changes: 14 additions & 3 deletions src/deprecations.jl
@@ -3,7 +3,18 @@
@deprecate InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum) InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum, nothing)
@deprecate BatchNorm(λ, β, γ, μ, σ², ϵ, momentum) BatchNorm(λ, β, γ, μ, σ², ϵ, momentum, nothing)
@deprecate GroupNorm(G, λ, β, γ, μ, σ², ϵ, momentum) GroupNorm(G, λ, β, γ, μ, σ², ϵ, momentum, nothing)

@deprecate outdims(f, inputsize) outputsize(f, inputsize)
@deprecate Conv(; weight, bias, activation=identity, kws...) Conv(weight, bias, activation; kws...)
@deprecate ConvTranspose(; weight, bias, activation=identity, kws...) ConvTranspose(weight, bias, activation; kws...)
@deprecate DepthwiseConv(; weight, bias, activation=identity, kws...) DepthwiseConv(weight, bias, activation; kws...)

# Was used both Dense(10, 2, initb = Zeros) and Dense(rand(2,10), Zeros())
struct Zeros
function Zeros()
@warn "Zeros() is deprecated, please simply use bias=false instead" maxlog=3
false
end
end
Zeros(args...) = Zeros()
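For context (not part of the diff): with this shim, the two old spellings mentioned in the comment above keep working but only emit a warning, while new code passes `bias=false`. A minimal migration sketch, assuming the post-commit constructors:

```julia
using Flux

# old: Dense(10, 2, initb = Zeros)  or  Dense(rand(2, 10), Zeros())
d  = Dense(10, 2; bias = false)                 # keyword form: no trainable bias
d2 = Dense(rand(Float32, 2, 10), false, relu)   # explicit weight matrix, bias off

length(Flux.params(d))  # 1: only the weight matrix is trainable
```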
56 changes: 30 additions & 26 deletions src/layers/basic.jl
@@ -67,69 +67,73 @@ end
extraChain(::Tuple{}, x) = ()



"""
-Dense(in, out, σ=identity; initW=glorot_uniform, initb=zeros, bias=true)
+Dense(in, out, σ=identity; bias=true)
Dense(W, b, σ=identity)
Create a traditional `Dense` layer with in×out weight matrix `W` and
bias vector `b` of length `out`. The forward pass is given by:
y = σ.(W * x .+ b)
-The input `x` must be a vector of length `in`, a batch of vectors represented
-as an `in × N` matrix, or a higher order tensor where all dimensions
-after the first one will be treated as batch dimensions.
-The out `y` will be a vector of length `out` or a batch whose first
-dimension is `out` and the remaining dimensions are the same as in the input.
+The input `x` must be a vector of length `in`, a batch of vectors represented
+as an `in × N` matrix, or any array with `size(x,1) == in`.
+The output `y` will be a vector of length `out`, or a batch with `size(y) == (out, size(x)[2:end]...)`.
-Setting `bias` to `false` will switch the bias off for the layer.
+Setting `bias=false` creates a layer without bias parameters.
-`initW` and `initb` are callables used to initialize weights and biases respectively,
-through the calls `initW(out, in)` and `initb(out)`.
+Two additional keywords, `initW=glorot_uniform` and `initb=Flux.zeros`, control the
+initialisation of parameters when using the first constructor.
# Examples
-```julia-repl
+```jldoctest
julia> d = Dense(5, 2)
Dense(5, 2)
julia> d(rand(Float32, 5))
2-element Array{Float32,1}:
-0.16210233
0.123119034
julia> d(rand(Float32, 5, 64)) |> size
(2, 64)
julia> d = Dense(5, 2; bias=false)
Dense(5, 2)
julia> d1 = Dense(ones(2,5), false, tanh)
Dense(5, 2, tanh; bias=false)
julia> d1(ones(5))
2-element Array{Float64,1}:
0.9999092042625951
0.9999092042625951
julia> params(d1) # no trainable bias
Params([[1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0]])
```
"""
-struct Dense{F,S<:AbstractArray,T<:Union{Zeros, AbstractVector}}
+struct Dense{F,S<:AbstractMatrix,T}
W::S
b::T
σ::F
end

-Dense(W, b) = Dense(W, b, identity)
+Dense(W::AbstractMatrix, b) = Dense(W, b, identity)

function Dense(in::Integer, out::Integer, σ = identity;
-initW = glorot_uniform, initb = zeros, bias=true)
-return Dense(initW(out, in), create_bias(bias, initb, out), σ)
+initW = glorot_uniform, initb = zeros, bias = true)
+W = initW(out, in)
+b = create_bias(bias, initb, out)
+Dense(W, b, σ)
end

@functor Dense

function (a::Dense)(x::AbstractArray)
W, b, σ = a.W, a.b, a.σ
sz = size(x)
x = reshape(x, sz[1], :) # reshape to handle dims > 1 as batch dimensions
x = σ.(W*x .+ b)
-return reshape(x, :, sz[2:end]...)
+reshape(x, :, sz[2:end]...)
end
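A small usage sketch (not part of the diff) of how this reshape treats every dimension after the first as a batch dimension:

```julia
using Flux

d = Dense(5, 2)
x = rand(Float32, 5, 7, 3)   # size(x, 1) == 5 == in; the trailing dims are batch

y = d(x)
size(y)                      # (2, 7, 3): first dim is `out`, batch dims unchanged
```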

function Base.show(io::IO, l::Dense)
print(io, "Dense(", size(l.W, 2), ", ", size(l.W, 1))
l.σ == identity || print(io, ", ", l.σ)
l.b == false && print(io, "; bias=false")
print(io, ")")
end

34 changes: 17 additions & 17 deletions src/layers/conv.jl
@@ -6,6 +6,10 @@ _paddims(x::Tuple, y::Tuple) = (x..., y[(end - (length(y) - length(x) - 1)):end]
expand(N, i::Tuple) = i
expand(N, i::Integer) = ntuple(_ -> i, N)

conv_reshape_bias(c) = c.bias isa AbstractVector ?
reshape(c.bias, map(_->1, c.stride)..., :, 1) :
c.bias
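This new helper reshapes a bias vector so it broadcasts against the convolution output: one singleton dimension per spatial dimension (taken from `stride`), then the channel dimension, then a singleton batch dimension; a `false` bias is passed through untouched. A rough sketch of the effect for a 2-D convolution (illustration only, not from the commit):

```julia
# A 2-D Conv with 7 output channels holds a length-7 bias vector;
# reshaping it to (1, 1, 7, 1) lets it broadcast over a (W, H, 7, N) output.
b = rand(Float32, 7)
size(reshape(b, 1, 1, :, 1))   # (1, 1, 7, 1)
```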

"""
SamePad()
@@ -96,7 +100,7 @@ end

"""
Conv(weight::AbstractArray, bias, [activation; stride, pad, dilation])
Constructs a convolutional layer with the given weight and bias.
Accepts the same keywords (and has the same defaults) as the `Conv((4,4), 3=>7, relu)`
method.
@@ -117,7 +121,7 @@ julia> params(c1) |> length
2
```
"""
-function Conv(w::AbstractArray{T,N}, b::Union{Bool, Zeros, AbstractVector{T}}, σ = identity;
+function Conv(w::AbstractArray{T,N}, b::Union{Bool,AbstractVector{T}}, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N}
stride = expand(Val(N-2), stride)
dilation = expand(Val(N-2), dilation)
@@ -152,9 +156,8 @@ convfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
function (c::Conv)(x::AbstractArray)
# TODO: breaks gpu broadcast :(
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
-σ, b = c.σ, reshape(c.bias, ntuple(_->1, length(c.stride))..., :, 1)
cdims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
-σ.(conv(x, c.weight, cdims) .+ b)
+(c.σ).(conv(x, c.weight, cdims) .+ conv_reshape_bias(c))
end
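A usage sketch (not part of the diff) of the `Conv(weight, bias, activation)` method from this file with the bias turned off; `false` simply broadcasts as zero in the forward pass above:

```julia
using Flux

w = rand(Float32, 3, 3, 3, 7)    # 3×3 filters, 3 input channels, 7 output channels
c = Conv(w, false, relu)         # bias stored as `false`
x = rand(Float32, 32, 32, 3, 1)

size(c(x))                       # (30, 30, 7, 1)
```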

function Base.show(io::IO, l::Conv)
@@ -207,16 +210,16 @@ end

"""
ConvTranspose(weight::AbstractArray, bias, [activation; stride, pad, dilation])
Constructs a layer with the given weight and bias arrays.
Accepts the same keywords as the `ConvTranspose((4,4), 3=>7, relu)` method.
"""
-function ConvTranspose(w::AbstractArray{T,N}, b::Union{Bool, Zeros, AbstractVector{T}}, σ = identity;
+function ConvTranspose(w::AbstractArray{T,N}, b::Union{Bool, AbstractVector{T}}, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N}
stride = expand(Val(N-2), stride)
dilation = expand(Val(N-2), dilation)
pad = calc_padding(ConvTranspose, pad, size(w)[1:N-2], dilation, stride)
bias = create_bias(b, zeros, size(w, N-1))
return ConvTranspose(σ, w, bias, stride, pad, dilation)
end

@@ -248,9 +251,8 @@ end

function (c::ConvTranspose)(x::AbstractArray)
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
-σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
cdims = conv_transpose_dims(c, x)
-σ.(∇conv_data(x, c.weight, cdims) .+ b)
+(c.σ).(∇conv_data(x, c.weight, cdims) .+ conv_reshape_bias(c))
end

function Base.show(io::IO, l::ConvTranspose)
@@ -304,11 +306,11 @@ end

"""
DepthwiseConv(weight::AbstractArray, bias, [activation; stride, pad, dilation])
Constructs a layer with the given weight and bias arrays.
Accepts the same keywords as the `DepthwiseConv((4,4), 3=>6, relu)` method.
"""
-function DepthwiseConv(w::AbstractArray{T,N}, b::Union{Bool, Zeros, AbstractVector{T}}, σ = identity;
+function DepthwiseConv(w::AbstractArray{T,N}, b::Union{Bool,AbstractVector{T}}, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N}
stride = expand(Val(N-2), stride)
dilation = expand(Val(N-2), dilation)
@@ -341,9 +343,8 @@ depthwiseconvfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
init = glorot_uniform) where N = init(filter..., div(ch[2], ch[1]), ch[1])

function (c::DepthwiseConv)(x)
-σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
cdims = DepthwiseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
-σ.(depthwiseconv(x, c.weight, cdims) .+ b)
+(c.σ).(depthwiseconv(x, c.weight, cdims) .+ conv_reshape_bias(c))
end

function Base.show(io::IO, l::DepthwiseConv)
@@ -392,11 +393,11 @@ end

"""
CrossCor(weight::AbstractArray, bias, [activation; stride, pad, dilation])
Constructs a layer with the given weight and bias arrays.
Accepts the same keywords as the `CrossCor((4,4), 3=>7, relu)` method.
"""
-function CrossCor(w::AbstractArray{T,N}, b::Union{Bool, Zeros, AbstractVector{T}}, σ = identity;
+function CrossCor(w::AbstractArray{T,N}, b::Union{Bool,AbstractVector{T}} = true, σ = identity;
stride = 1, pad = 0, dilation = 1) where {T,N}
stride = expand(Val(N-2), stride)
dilation = expand(Val(N-2), dilation)
@@ -422,9 +423,8 @@
function (c::CrossCor)(x::AbstractArray)
# TODO: breaks gpu broadcast :(
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
-σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
cdims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
-σ.(crosscor(x, c.weight, cdims) .+ b)
+(c.σ).(crosscor(x, c.weight, cdims) .+ conv_reshape_bias(c))
end

function Base.show(io::IO, l::CrossCor)
21 changes: 12 additions & 9 deletions src/utils.jl
@@ -225,18 +225,21 @@ ones(dims...) = Base.ones(Float32, dims...)
zeros(dims...) = Base.zeros(Float32, dims...)

"""
-create_bias(shallcreate::Bool, iftrue, dims...)
-create_bias(x, ::Any...)
+create_bias(bias::Bool, iftrue, dims...)
-Return a bias parameter for a layer.
+Return a bias parameter for a layer, based on the value given
+to the constructor's keyword `bias=bias`.
-Essentially handles the allowed input options for the `bias` keyword:
-If `false`: Return the `Zeros` type which turns bias off.
-If `true` : Return the result of `iftrue(dims)`.
-If not a boolean, return self to handle the case of bias=somearray.
+* `bias == true` creates `iftrue(dims...)`, typically a dense vector of zeros.
+* `bias == false` returns `false`, to indicate no trainable bias.
+* `bias::AbstractArray` uses the array provided. It checks size but not eltype.
"""
-create_bias(shallcreate::Bool, iftrue, dims...) = shallcreate ? iftrue(dims...) : Zeros()
-create_bias(x, ::Any...) = x
+function create_bias(bias, iftrue, dims...)
+bias===true && return iftrue(dims...)
+bias===false && return false
+size(bias) == dims || throw(DimensionMismatch("expected bias of size $dims, but got $(size(bias))"))
+return bias
+end
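A minimal sketch (not part of the diff) of the three accepted forms; `create_bias` is an internal helper, called directly here only for illustration:

```julia
using Flux

Flux.create_bias(true, Flux.zeros, 4)      # 4-element Float32 vector of zeros
Flux.create_bias(false, Flux.zeros, 4)     # false, i.e. no trainable bias
Flux.create_bias(rand(4), Flux.zeros, 4)   # the given array, after a size check
```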

"""
unsqueeze(xs, dim)
49 changes: 0 additions & 49 deletions src/zeros.jl

This file was deleted.

2 changes: 1 addition & 1 deletion test/optimise.jl
@@ -14,7 +14,7 @@ using Random
Nesterov(), RMSProp(), Momentum()]
Random.seed!(42)
w′ = randn(10, 10)
-b = Flux.Zeros()
+b = false
loss(x) = Flux.Losses.mse(w*x, w′*x .+ b)
for t = 1: 10^5
θ = params([w′, b])