Skip to content

Commit

Permalink
unifying docstrings layout
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinuzziFrancesco committed Jan 23, 2025
1 parent 729b39a commit 9194d61
Show file tree
Hide file tree
Showing 12 changed files with 61 additions and 40 deletions.
10 changes: 5 additions & 5 deletions src/cells/fastrnn_cell.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#https://arxiv.org/abs/1901.02358
@doc raw"""
FastRNNCell((input_size => hidden_size), [activation];
FastRNNCell(input_size => hidden_size, [activation];
init_kernel = glorot_uniform,
init_recurrent_kernel = glorot_uniform,
bias = true)
Expand Down Expand Up @@ -84,8 +84,8 @@ function Base.show(io::IO, fastrnn::FastRNNCell)
end

@doc raw"""
FastRNN((input_size => hidden_size), [activation];
return_state = false, kwargs...)
FastRNN(input_size => hidden_size, [activation];
return_state = false, kwargs...)
[Fast recurrent neural network](https://arxiv.org/abs/1901.02358).
See [`FastRNNCell`](@ref) for a layer that processes a single sequence.
Expand Down Expand Up @@ -150,7 +150,7 @@ function Base.show(io::IO, fastrnn::FastRNN)
end

@doc raw"""
FastGRNNCell((input_size => hidden_size), [activation];
FastGRNNCell(input_size => hidden_size, [activation];
init_kernel = glorot_uniform,
init_recurrent_kernel = glorot_uniform,
bias = true)
Expand Down Expand Up @@ -240,7 +240,7 @@ function Base.show(io::IO, fastgrnn::FastGRNNCell)
end

@doc raw"""
FastGRNN((input_size => hidden_size), [activation];
FastGRNN(input_size => hidden_size, [activation];
return_state = false, kwargs...)
[Fast recurrent neural network](https://arxiv.org/abs/1901.02358).
Expand Down
4 changes: 2 additions & 2 deletions src/cells/indrnn_cell.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#https://arxiv.org/pdf/1803.04831

@doc raw"""
IndRNNCell((input_size => hidden_size), σ=relu;
IndRNNCell(input_size => hidden_size, σ=relu;
init_kernel = glorot_uniform,
init_recurrent_kernel = glorot_uniform,
bias = true)
Expand Down Expand Up @@ -72,7 +72,7 @@ function Base.show(io::IO, indrnn::IndRNNCell)
end

@doc raw"""
IndRNN((input_size, hidden_size), σ = tanh;
IndRNN(input_size => hidden_size, σ = tanh;
return_state = false, kwargs...)
[Independently recurrent network](https://arxiv.org/pdf/1803.04831).
Expand Down
4 changes: 2 additions & 2 deletions src/cells/lightru_cell.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#https://www.mdpi.com/2079-9292/13/16/3204

@doc raw"""
LightRUCell((input_size => hidden_size);
LightRUCell(input_size => hidden_size;
init_kernel = glorot_uniform,
init_recurrent_kernel = glorot_uniform,
bias = true)
Expand Down Expand Up @@ -79,7 +79,7 @@ function Base.show(io::IO, lightru::LightRUCell)
end

@doc raw"""
LightRU((input_size => hidden_size);
LightRU(input_size => hidden_size;
return_state = false, kwargs...)
[Light recurrent unit network](https://www.mdpi.com/2079-9292/13/16/3204).
Expand Down
4 changes: 2 additions & 2 deletions src/cells/ligru_cell.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#https://arxiv.org/pdf/1803.10225
@doc raw"""
LiGRUCell((input_size => hidden_size);
LiGRUCell(input_size => hidden_size;
init_kernel = glorot_uniform,
init_recurrent_kernel = glorot_uniform,
bias = true)
Expand Down Expand Up @@ -79,7 +79,7 @@ function Base.show(io::IO, ligru::LiGRUCell)
end

@doc raw"""
LiGRU((input_size => hidden_size);
LiGRU(input_size => hidden_size;
return_state = false, kwargs...)
[Light gated recurrent network](https://arxiv.org/pdf/1803.10225).
Expand Down
4 changes: 2 additions & 2 deletions src/cells/mgu_cell.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#https://arxiv.org/pdf/1603.09420
@doc raw"""
MGUCell((input_size => hidden_size);
MGUCell(input_size => hidden_size;
init_kernel = glorot_uniform,
init_recurrent_kernel = glorot_uniform,
bias = true)
Expand Down Expand Up @@ -77,7 +77,7 @@ function Base.show(io::IO, mgu::MGUCell)
end

@doc raw"""
MGU((input_size => hidden_size);
MGU(input_size => hidden_size;
return_state = false, kwargs...)
[Minimal gated unit network](https://arxiv.org/pdf/1603.09420).
Expand Down
20 changes: 14 additions & 6 deletions src/cells/mut_cell.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#https://proceedings.mlr.press/v37/jozefowicz15.pdf
@doc raw"""
MUT1Cell((input_size => hidden_size);
MUT1Cell(input_size => hidden_size;
init_kernel = glorot_uniform,
init_recurrent_kernel = glorot_uniform,
bias = true)
Expand Down Expand Up @@ -81,7 +81,9 @@ function Base.show(io::IO, mut::MUT1Cell)
end

@doc raw"""
MUT1((input_size => hidden_size); kwargs...)
MUT1(input_size => hidden_size;
return_state=false,
kwargs...)
[Mutated unit 1 network](https://proceedings.mlr.press/v37/jozefowicz15.pdf).
See [`MUT1Cell`](@ref) for a layer that processes a single sequence.
Expand All @@ -92,6 +94,7 @@ See [`MUT1Cell`](@ref) for a layer that processes a single sequence.
- `init_kernel`: initializer for the input to hidden weights
- `init_recurrent_kernel`: initializer for the hidden to hidden weights
- `bias`: include a bias or not. Default is `true`
- `return_state`: Option to return the last state together with the output. Default is `false`.
# Equations
```math
Expand Down Expand Up @@ -145,7 +148,7 @@ function Base.show(io::IO, mut::MUT1)
end

@doc raw"""
MUT2Cell((input_size => hidden_size);
MUT2Cell(input_size => hidden_size;
init_kernel = glorot_uniform,
init_recurrent_kernel = glorot_uniform,
bias = true)
Expand Down Expand Up @@ -225,7 +228,9 @@ function Base.show(io::IO, mut::MUT2Cell)
end

@doc raw"""
MUT2Cell((input_size => hidden_size); kwargs...)
MUT2Cell(input_size => hidden_size;
return_state=false,
kwargs...)
[Mutated unit 2 network](https://proceedings.mlr.press/v37/jozefowicz15.pdf).
See [`MUT2Cell`](@ref) for a layer that processes a single sequence.
Expand All @@ -236,6 +241,7 @@ See [`MUT2Cell`](@ref) for a layer that processes a single sequence.
- `init_kernel`: initializer for the input to hidden weights
- `init_recurrent_kernel`: initializer for the hidden to hidden weights
- `bias`: include a bias or not. Default is `true`
- `return_state`: Option to return the last state together with the output. Default is `false`.
# Equations
```math
Expand Down Expand Up @@ -289,7 +295,7 @@ function Base.show(io::IO, mut::MUT2)
end

@doc raw"""
MUT3Cell((input_size => hidden_size);
MUT3Cell(input_size => hidden_size;
init_kernel = glorot_uniform,
init_recurrent_kernel = glorot_uniform,
bias = true)
Expand Down Expand Up @@ -368,7 +374,8 @@ function Base.show(io::IO, mut::MUT3Cell)
end

@doc raw"""
MUT3((input_size => hidden_size); kwargs...)
MUT3(input_size => hidden_size;
return_state = false, kwargs...)
[Mutated unit 3 network](https://proceedings.mlr.press/v37/jozefowicz15.pdf).
See [`MUT3Cell`](@ref) for a layer that processes a single sequence.
Expand All @@ -379,6 +386,7 @@ See [`MUT3Cell`](@ref) for a layer that processes a single sequence.
- `init_kernel`: initializer for the input to hidden weights
- `init_recurrent_kernel`: initializer for the hidden to hidden weights
- `bias`: include a bias or not. Default is `true`
- `return_state`: Option to return the last state together with the output. Default is `false`.
# Equations
```math
Expand Down
7 changes: 5 additions & 2 deletions src/cells/nas_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
# limitations under the License.

@doc raw"""
NASCell((input_size => hidden_size);
NASCell(input_size => hidden_size;
init_kernel = glorot_uniform,
init_recurrent_kernel = glorot_uniform,
bias = true)
Expand Down Expand Up @@ -145,14 +145,17 @@ function Base.show(io::IO, nas::NASCell)
end

@doc raw"""
NAS((input_size => hidden_size)::Pair; kwargs...)
NAS(input_size => hidden_size;
return_state = false,
kwargs...)
[Neural Architecture Search unit](https://arxiv.org/pdf/1611.01578).
See [`NASCell`](@ref) for a layer that processes a single sequence.
# Arguments
- `return_state`: Option to return the last state together with the output. Default is `false`.
- `input_size => hidden_size`: input and inner dimension of the layer
- `init_kernel`: initializer for the input to hidden weights
- `init_recurrent_kernel`: initializer for the hidden to hidden weights
Expand Down
7 changes: 5 additions & 2 deletions src/cells/peepholelstm_cell.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#https://www.jmlr.org/papers/volume3/gers02a/gers02a.pdf
@doc raw"""
PeepholeLSTMCell((input_size => hidden_size);
PeepholeLSTMCell(input_size => hidden_size;
init_kernel = glorot_uniform,
init_recurrent_kernel = glorot_uniform,
bias = true)
Expand Down Expand Up @@ -78,13 +78,16 @@ function Base.show(io::IO, lstm::PeepholeLSTMCell)
end

@doc raw"""
PeepholeLSTM((input_size => hidden_size); kwargs...)
PeepholeLSTM(input_size => hidden_size;
return_state=false,
kwargs...)
[Peephole long short term memory network](https://www.jmlr.org/papers/volume3/gers02a/gers02a.pdf).
See [`PeepholeLSTMCell`](@ref) for a layer that processes a single sequence.
# Arguments
- `return_state`: Option to return the last state together with the output. Default is `false`.
- `input_size => hidden_size`: input and inner dimension of the layer
- `init_kernel`: initializer for the input to hidden weights
- `init_recurrent_kernel`: initializer for the hidden to hidden weights
Expand Down
6 changes: 4 additions & 2 deletions src/cells/ran_cell.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#https://arxiv.org/pdf/1705.07393
@doc raw"""
RANCell((input_size => hidden_size)::Pair;
RANCell(input_size => hidden_size;
init_kernel = glorot_uniform,
init_recurrent_kernel = glorot_uniform,
bias = true)
Expand Down Expand Up @@ -80,13 +80,15 @@ function Base.show(io::IO, ran::RANCell)
end

@doc raw"""
RAN(input_size => hidden_size; kwargs...)
RAN(input_size => hidden_size;
return_state = false, kwargs...)
[Recurrent Additive Network cell](https://arxiv.org/pdf/1705.07393).
See [`RANCell`](@ref) for a layer that processes a single sequence.
# Arguments
- `return_state`: Option to return the last state together with the output. Default is `false`.
- `input_size => hidden_size`: input and inner dimension of the layer
- `init_kernel`: initializer for the input to hidden weights
- `init_recurrent_kernel`: initializer for the hidden to hidden weights
Expand Down
11 changes: 7 additions & 4 deletions src/cells/rhn_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#https://github.com/jzilly/RecurrentHighwayNetworks/blob/master/rhn.py#L138C1-L180C60

"""
RHNCellUnit((input_size => hidden_size)::Pair;
RHNCellUnit(input_size => hidden_size;
init_kernel = glorot_uniform,
bias = true)
"""
Expand Down Expand Up @@ -44,8 +44,8 @@ function Base.show(io::IO, rhn::RHNCellUnit)
end

@doc raw"""
RHNCell((input_size => hidden_size), depth=3;
couple_carry::Bool = true,
RHNCell(input_size => hidden_size, [depth];
couple_carry = true,
cell_kwargs...)
[Recurrent highway network](https://arxiv.org/pdf/1607.03474).
Expand Down Expand Up @@ -139,14 +139,17 @@ end

# TODO fix implementation here
@doc raw"""
RHN((input_size => hidden_size), depth=3; kwargs...)
RHN(input_size => hidden_size, [depth];
return_state = false,
kwargs...)
[Recurrent highway network](https://arxiv.org/pdf/1607.03474).
See [`RHNCellUnit`](@ref) for the unit component of this layer.
See [`RHNCell`](@ref) for a layer that processes a single sequence.
# Arguments
- `return_state`: Option to return the last state together with the output. Default is `false`.
- `input_size => hidden_size`: input and inner dimension of the layer
- `depth`: depth of the recurrence. Default is 3
- `couple_carry`: couples the carry gate and the transform gate. Default `true`
Expand Down
8 changes: 5 additions & 3 deletions src/cells/scrn_cell.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#https://arxiv.org/pdf/1412.7753

@doc raw"""
SCRNCell((input_size => hidden_size);
SCRNCell(input_size => hidden_size;
init_kernel = glorot_uniform,
init_recurrent_kernel = glorot_uniform,
bias = true,
Expand Down Expand Up @@ -87,17 +87,19 @@ function Base.show(io::IO, scrn::SCRNCell)
end

@doc raw"""
SCRN((input_size => hidden_size);
SCRN(input_size => hidden_size;
init_kernel = glorot_uniform,
init_recurrent_kernel = glorot_uniform,
bias = true,
alpha = 0.0)
alpha = 0.0,
return_state = false)
[Structurally constrained recurrent unit](https://arxiv.org/pdf/1412.7753).
See [`SCRNCell`](@ref) for a layer that processes a single sequence.
# Arguments
- `return_state`: Option to return the last state together with the output. Default is `false`.
- `input_size => hidden_size`: input and inner dimension of the layer
- `init_kernel`: initializer for the input to hidden weights
- `init_recurrent_kernel`: initializer for the hidden to hidden weights
Expand Down
16 changes: 8 additions & 8 deletions src/generics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,20 @@ function (rlayer::AbstractRecurrentLayer{false})(inp::AbstractArray,
state::Union{AbstractVecOrMat, Tuple{AbstractVecOrMat, AbstractVecOrMat}})
@assert ndims(inp) == 2 || ndims(inp) == 3
@assert typeof(state)==typeof(initialstates(rlayer)) """\n
The layer $rlayer is calling states not supported by its
forward method. Check if this is a single or double return
recurrent layer, and adjust your inputs accordingly.
"""
The layer $rlayer is calling states not supported by its
forward method. Check if this is a single or double return
recurrent layer, and adjust your inputs accordingly.
"""
return first(scan(rlayer.cell, inp, state))
end

function (rlayer::AbstractRecurrentLayer{true})(inp::AbstractArray,
state::Union{AbstractVecOrMat, Tuple{AbstractVecOrMat, AbstractVecOrMat}})
@assert ndims(inp) == 2 || ndims(inp) == 3
@assert typeof(state)==typeof(initialstates(rlayer)) """\n
The layer $rlayer is calling states not supported by its
forward method. Check if this is a single or double return
recurrent layer, and adjust your inputs accordingly.
"""
The layer $rlayer is calling states not supported by its
forward method. Check if this is a single or double return
recurrent layer, and adjust your inputs accordingly.
"""
return scan(rlayer.cell, inp, state)
end

0 comments on commit 9194d61

Please sign in to comment.