Skip to content

Commit

Permalink
adding return_state to all recurrent layers
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinuzziFrancesco committed Jan 9, 2025
1 parent 01d46ac commit 832bf8e
Show file tree
Hide file tree
Showing 11 changed files with 56 additions and 36 deletions.
6 changes: 2 additions & 4 deletions src/cells/fastrnn_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,7 @@ end
@layer :noexpand FastRNN

function FastRNN((input_size, hidden_size)::Pair, activation = tanh_fast;
return_state = false,
kwargs...)
return_state::Bool = false, kwargs...)
cell = FastRNNCell(input_size => hidden_size, activation; kwargs...)
return FastRNN{return_state, typeof(cell)}(cell)
end
Expand Down Expand Up @@ -290,8 +289,7 @@ end
@layer :noexpand FastGRNN

function FastGRNN((input_size, hidden_size)::Pair, activation = tanh_fast;
return_state = false,
kwargs...)
return_state::Bool = false, kwargs...)
cell = FastGRNNCell(input_size => hidden_size, activation; kwargs...)
return FastGRNN{return_state, typeof(cell)}(cell)
end
Expand Down
3 changes: 1 addition & 2 deletions src/cells/indrnn_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,7 @@ end
@layer :noexpand IndRNN

function IndRNN((input_size, hidden_size)::Pair, σ = tanh;
return_state = false,
kwargs...)
return_state::Bool = false, kwargs...)
cell = IndRNNCell(input_size => hidden_size, σ; kwargs...)
return IndRNN{return_state, typeof(cell)}(cell)
end
Expand Down
3 changes: 1 addition & 2 deletions src/cells/lightru_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,7 @@ end
@layer :noexpand LightRU

function LightRU((input_size, hidden_size)::Pair;
return_state = false,
kwargs...)
return_state::Bool = false, kwargs...)
cell = LightRUCell(input_size => hidden_size; kwargs...)
return LightRU{return_state, typeof(cell)}(cell)
end
Expand Down
3 changes: 1 addition & 2 deletions src/cells/ligru_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,7 @@ end
@layer :noexpand LiGRU

function LiGRU((input_size, hidden_size)::Pair;
return_state = false,
kwargs...)
return_state::Bool = false, kwargs...)
cell = LiGRUCell(input_size => hidden_size; kwargs...)
return LiGRU{return_state, typeof(cell)}(cell)
end
Expand Down
3 changes: 1 addition & 2 deletions src/cells/mgu_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,7 @@ end
@layer :noexpand MGU

function MGU((input_size, hidden_size)::Pair;
return_state = false,
kwargs...)
return_state::Bool = false, kwargs...)
cell = MGUCell(input_size => hidden_size; kwargs...)
return MGU{return_state, typeof(cell)}(cell)
end
Expand Down
25 changes: 17 additions & 8 deletions src/cells/mut_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,16 +120,19 @@ h_{t+1} &= \tanh(U_h (r \odot h_t) + \tanh(W_h x_t) + b_h) \odot z \\
## Returns
- New hidden states `new_states` as an array of size `hidden_size x len x batch_size`.
When `return_state = true` it returns a tuple of the hidden stats `new_states` and
the last state of the iteration.
"""
struct MUT1{M} <: AbstractRecurrentLayer
struct MUT1{S,M} <: AbstractRecurrentLayer
cell::M
end

@layer :noexpand MUT1

function MUT1((input_size, hidden_size)::Pair; kwargs...)
function MUT1((input_size, hidden_size)::Pair;
return_state::Bool = false, kwargs...)
cell = MUT1Cell(input_size => hidden_size; kwargs...)
return MUT1(cell)
return MUT1{return_state, typeof(cell)}(cell)
end

function Base.show(io::IO, mut::MUT1)
Expand Down Expand Up @@ -258,16 +261,19 @@ h_{t+1} &= \tanh(U_h (r \odot h_t) + W_h x_t + b_h) \odot z \\
## Returns
- New hidden states `new_states` as an array of size `hidden_size x len x batch_size`.
When `return_state = true` it returns a tuple of the hidden stats `new_states` and
the last state of the iteration.
"""
struct MUT2{M} <: AbstractRecurrentLayer
struct MUT2{S,M} <: AbstractRecurrentLayer
cell::M
end

@layer :noexpand MUT2

function MUT2((input_size, hidden_size)::Pair; kwargs...)
function MUT2((input_size, hidden_size)::Pair;
return_state::Bool = false, kwargs...)
cell = MUT2Cell(input_size => hidden_size; kwargs...)
return MUT2(cell)
return MUT2{return_state, typeof(cell)}(cell)
end

function Base.show(io::IO, mut::MUT2)
Expand Down Expand Up @@ -395,16 +401,19 @@ h_{t+1} &= \tanh(U_h (r \odot h_t) + W_h x_t + b_h) \odot z \\
## Returns
- New hidden states `new_states` as an array of size `hidden_size x len x batch_size`.
When `return_state = true` it returns a tuple of the hidden stats `new_states` and
the last state of the iteration.
"""
struct MUT3{M} <: AbstractRecurrentLayer
cell::M
end

@layer :noexpand MUT3

function MUT3((input_size, hidden_size)::Pair; kwargs...)
function MUT3((input_size, hidden_size)::Pair;
return_state::Bool = false, kwargs...)
cell = MUT3Cell(input_size => hidden_size; kwargs...)
return MUT3(cell)
return MUT3{return_state, typeof(cell)}(cell)
end

function Base.show(io::IO, mut::MUT3)
Expand Down
9 changes: 6 additions & 3 deletions src/cells/nas_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -203,16 +203,19 @@ h_{\text{new}} &= \tanh(c_{\text{new}} \cdot l_5)
## Returns
- New hidden states `new_states` as an array of size `hidden_size x len x batch_size`.
When `return_state = true` it returns a tuple of the hidden stats `new_states` and
the last state of the iteration.
"""
struct NAS{M} <: AbstractRecurrentLayer
struct NAS{S,M} <: AbstractRecurrentLayer
cell::M
end

@layer :noexpand NAS

function NAS((input_size, hidden_size)::Pair; kwargs...)
function NAS((input_size, hidden_size)::Pair;
return_state::Bool = false, kwargs...)
cell = NASCell(input_size => hidden_size; kwargs...)
return NAS(cell)
return NAS{return_state, typeof(cell)}(cell)
end

function Base.show(io::IO, nas::NAS)
Expand Down
9 changes: 6 additions & 3 deletions src/cells/peepholelstm_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,16 +120,19 @@ h_t &= o_t \odot \sigma_h(c_t).
## Returns
- New hidden states `new_states` as an array of size `hidden_size x len x batch_size`.
When `return_state = true` it returns a tuple of the hidden stats `new_states` and
the last state of the iteration.
"""
struct PeepholeLSTM{M} <: AbstractRecurrentLayer
struct PeepholeLSTM{S,M} <: AbstractRecurrentLayer
cell::M
end

@layer :noexpand PeepholeLSTM

function PeepholeLSTM((input_size, hidden_size)::Pair; kwargs...)
function PeepholeLSTM((input_size, hidden_size)::Pair;
return_state::Bool = false, kwargs...)
cell = PeepholeLSTMCell(input_size => hidden_size; kwargs...)
return PeepholeLSTM(cell)
return PeepholeLSTM{return_state, typeof(cell)}(cell)
end

function Base.show(io::IO, peepholelstm::PeepholeLSTM)
Expand Down
9 changes: 6 additions & 3 deletions src/cells/ran_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,16 +129,19 @@ h_t &= g(c_t)
## Returns
- New hidden states `new_states` as an array of size `hidden_size x len x batch_size`.
When `return_state = true` it returns a tuple of the hidden stats `new_states` and
the last state of the iteration.
"""
struct RAN{M} <: AbstractRecurrentLayer
struct RAN{S,M} <: AbstractRecurrentLayer
cell::M
end

@layer :noexpand RAN

function RAN((input_size, hidden_size)::Pair; kwargs...)
function RAN((input_size, hidden_size)::Pair;
return_state::Bool = false, kwargs...)
cell = RANCell(input_size => hidden_size; kwargs...)
return RAN(cell)
return RAN{return_state, typeof(cell)}(cell)
end

function Base.show(io::IO, ran::RAN)
Expand Down
9 changes: 6 additions & 3 deletions src/cells/scrn_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,16 +131,19 @@ y_t &= f(U_y h_t + W_y s_t)
## Returns
- New hidden states `new_states` as an array of size `hidden_size x len x batch_size`.
When `return_state = true` it returns a tuple of the hidden stats `new_states` and
the last state of the iteration.
"""
struct SCRN{M} <: AbstractRecurrentLayer
struct SCRN{S,M} <: AbstractRecurrentLayer
cell::M
end

@layer :noexpand SCRN

function SCRN((input_size, hidden_size)::Pair; kwargs...)
function SCRN((input_size, hidden_size)::Pair;
return_state::Bool = false, kwargs...)
cell = SCRNCell(input_size => hidden_size; kwargs...)
return SCRN(cell)
return SCRN{return_state, typeof(cell)}(cell)
end

function Base.show(io::IO, scrn::SCRN)
Expand Down
13 changes: 9 additions & 4 deletions src/generics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ function (rcell::AbstractRecurrentCell)(inp::AbstractVecOrMat)
return rcell(inp, state)
end

abstract type AbstractRecurrentLayer end
abstract type AbstractRecurrentLayer{S} end

function initialstates(rlayer::AbstractRecurrentLayer)
return initialstates(rlayer.cell)
Expand All @@ -27,9 +27,14 @@ function (rlayer::AbstractRecurrentLayer)(inp::AbstractVecOrMat)
return rlayer(inp, state)
end

function (rlayer::AbstractRecurrentLayer)(
inp::AbstractArray,
state::Union{AbstractVecOrMat, Tuple{AbstractVecOrMat, AbstractVecOrMat}})
function (rlayer::AbstractRecurrentLayer{false})(inp::AbstractArray,
state::Union{AbstractVecOrMat, Tuple{AbstractVecOrMat, AbstractVecOrMat}})
@assert ndims(inp) == 2 || ndims(inp) == 3
return first(scan(rlayer.cell, inp, state))
end

function (rlayer::AbstractRecurrentLayer{true})(inp::AbstractArray,
state::Union{AbstractVecOrMat, Tuple{AbstractVecOrMat, AbstractVecOrMat}})
@assert ndims(inp) == 2 || ndims(inp) == 3
return scan(rlayer.cell, inp, state)
end

0 comments on commit 832bf8e

Please sign in to comment.