From 832bf8e4451cad2cac0f8460694370028b04a7ad Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Thu, 9 Jan 2025 11:52:52 +0100 Subject: [PATCH] adding return_state to all recurrent layers --- src/cells/fastrnn_cell.jl | 6 ++---- src/cells/indrnn_cell.jl | 3 +-- src/cells/lightru_cell.jl | 3 +-- src/cells/ligru_cell.jl | 3 +-- src/cells/mgu_cell.jl | 3 +-- src/cells/mut_cell.jl | 25 +++++++++++++++++-------- src/cells/nas_cell.jl | 9 ++++++--- src/cells/peepholelstm_cell.jl | 9 ++++++--- src/cells/ran_cell.jl | 9 ++++++--- src/cells/scrn_cell.jl | 9 ++++++--- src/generics.jl | 13 +++++++++---- 11 files changed, 56 insertions(+), 36 deletions(-) diff --git a/src/cells/fastrnn_cell.jl b/src/cells/fastrnn_cell.jl index 4555dea..3cd1f41 100644 --- a/src/cells/fastrnn_cell.jl +++ b/src/cells/fastrnn_cell.jl @@ -134,8 +134,7 @@ end @layer :noexpand FastRNN function FastRNN((input_size, hidden_size)::Pair, activation = tanh_fast; - return_state = false, - kwargs...) + return_state::Bool = false, kwargs...) cell = FastRNNCell(input_size => hidden_size, activation; kwargs...) return FastRNN{return_state, typeof(cell)}(cell) end @@ -290,8 +289,7 @@ end @layer :noexpand FastGRNN function FastGRNN((input_size, hidden_size)::Pair, activation = tanh_fast; - return_state = false, - kwargs...) + return_state::Bool = false, kwargs...) cell = FastGRNNCell(input_size => hidden_size, activation; kwargs...) return FastGRNN{return_state, typeof(cell)}(cell) end diff --git a/src/cells/indrnn_cell.jl b/src/cells/indrnn_cell.jl index ee38562..9bedebf 100644 --- a/src/cells/indrnn_cell.jl +++ b/src/cells/indrnn_cell.jl @@ -117,8 +117,7 @@ end @layer :noexpand IndRNN function IndRNN((input_size, hidden_size)::Pair, σ = tanh; - return_state = false, - kwargs...) + return_state::Bool = false, kwargs...) cell = IndRNNCell(input_size => hidden_size, σ; kwargs...) return IndRNN{return_state, typeof(cell)}(cell) end diff --git a/src/cells/lightru_cell.jl b/src/cells/lightru_cell.jl index d10e626..ddb3f8a 100644 --- a/src/cells/lightru_cell.jl +++ b/src/cells/lightru_cell.jl @@ -128,8 +128,7 @@ end @layer :noexpand LightRU function LightRU((input_size, hidden_size)::Pair; - return_state = false, - kwargs...) + return_state::Bool = false, kwargs...) cell = LightRUCell(input_size => hidden_size; kwargs...) return LightRU{return_state, typeof(cell)}(cell) end diff --git a/src/cells/ligru_cell.jl b/src/cells/ligru_cell.jl index bd54370..1b5619e 100644 --- a/src/cells/ligru_cell.jl +++ b/src/cells/ligru_cell.jl @@ -131,8 +131,7 @@ end @layer :noexpand LiGRU function LiGRU((input_size, hidden_size)::Pair; - return_state = false, - kwargs...) + return_state::Bool = false, kwargs...) cell = LiGRUCell(input_size => hidden_size; kwargs...) return LiGRU{return_state, typeof(cell)}(cell) end diff --git a/src/cells/mgu_cell.jl b/src/cells/mgu_cell.jl index 3a78c00..bb7c115 100644 --- a/src/cells/mgu_cell.jl +++ b/src/cells/mgu_cell.jl @@ -127,8 +127,7 @@ end @layer :noexpand MGU function MGU((input_size, hidden_size)::Pair; - return_state = false, - kwargs...) + return_state::Bool = false, kwargs...) cell = MGUCell(input_size => hidden_size; kwargs...) return MGU{return_state, typeof(cell)}(cell) end diff --git a/src/cells/mut_cell.jl b/src/cells/mut_cell.jl index 37a39c8..be5c226 100644 --- a/src/cells/mut_cell.jl +++ b/src/cells/mut_cell.jl @@ -120,16 +120,19 @@ h_{t+1} &= \tanh(U_h (r \odot h_t) + \tanh(W_h x_t) + b_h) \odot z \\ ## Returns - New hidden states `new_states` as an array of size `hidden_size x len x batch_size`. + When `return_state = true` it returns a tuple of the hidden stats `new_states` and + the last state of the iteration. """ -struct MUT1{M} <: AbstractRecurrentLayer +struct MUT1{S,M} <: AbstractRecurrentLayer cell::M end @layer :noexpand MUT1 -function MUT1((input_size, hidden_size)::Pair; kwargs...) +function MUT1((input_size, hidden_size)::Pair; + return_state::Bool = false, kwargs...) cell = MUT1Cell(input_size => hidden_size; kwargs...) - return MUT1(cell) + return MUT1{return_state, typeof(cell)}(cell) end function Base.show(io::IO, mut::MUT1) @@ -258,16 +261,19 @@ h_{t+1} &= \tanh(U_h (r \odot h_t) + W_h x_t + b_h) \odot z \\ ## Returns - New hidden states `new_states` as an array of size `hidden_size x len x batch_size`. + When `return_state = true` it returns a tuple of the hidden stats `new_states` and + the last state of the iteration. """ -struct MUT2{M} <: AbstractRecurrentLayer +struct MUT2{S,M} <: AbstractRecurrentLayer cell::M end @layer :noexpand MUT2 -function MUT2((input_size, hidden_size)::Pair; kwargs...) +function MUT2((input_size, hidden_size)::Pair; + return_state::Bool = false, kwargs...) cell = MUT2Cell(input_size => hidden_size; kwargs...) - return MUT2(cell) + return MUT2{return_state, typeof(cell)}(cell) end function Base.show(io::IO, mut::MUT2) @@ -395,6 +401,8 @@ h_{t+1} &= \tanh(U_h (r \odot h_t) + W_h x_t + b_h) \odot z \\ ## Returns - New hidden states `new_states` as an array of size `hidden_size x len x batch_size`. + When `return_state = true` it returns a tuple of the hidden stats `new_states` and + the last state of the iteration. """ struct MUT3{M} <: AbstractRecurrentLayer cell::M @@ -402,9 +410,10 @@ end @layer :noexpand MUT3 -function MUT3((input_size, hidden_size)::Pair; kwargs...) +function MUT3((input_size, hidden_size)::Pair; + return_state::Bool = false, kwargs...) cell = MUT3Cell(input_size => hidden_size; kwargs...) - return MUT3(cell) + return MUT3{return_state, typeof(cell)}(cell) end function Base.show(io::IO, mut::MUT3) diff --git a/src/cells/nas_cell.jl b/src/cells/nas_cell.jl index a3fc819..042f94e 100644 --- a/src/cells/nas_cell.jl +++ b/src/cells/nas_cell.jl @@ -203,16 +203,19 @@ h_{\text{new}} &= \tanh(c_{\text{new}} \cdot l_5) ## Returns - New hidden states `new_states` as an array of size `hidden_size x len x batch_size`. + When `return_state = true` it returns a tuple of the hidden stats `new_states` and + the last state of the iteration. """ -struct NAS{M} <: AbstractRecurrentLayer +struct NAS{S,M} <: AbstractRecurrentLayer cell::M end @layer :noexpand NAS -function NAS((input_size, hidden_size)::Pair; kwargs...) +function NAS((input_size, hidden_size)::Pair; + return_state::Bool = false, kwargs...) cell = NASCell(input_size => hidden_size; kwargs...) - return NAS(cell) + return NAS{return_state, typeof(cell)}(cell) end function Base.show(io::IO, nas::NAS) diff --git a/src/cells/peepholelstm_cell.jl b/src/cells/peepholelstm_cell.jl index bbc8862..aeabea9 100644 --- a/src/cells/peepholelstm_cell.jl +++ b/src/cells/peepholelstm_cell.jl @@ -120,16 +120,19 @@ h_t &= o_t \odot \sigma_h(c_t). ## Returns - New hidden states `new_states` as an array of size `hidden_size x len x batch_size`. + When `return_state = true` it returns a tuple of the hidden stats `new_states` and + the last state of the iteration. """ -struct PeepholeLSTM{M} <: AbstractRecurrentLayer +struct PeepholeLSTM{S,M} <: AbstractRecurrentLayer cell::M end @layer :noexpand PeepholeLSTM -function PeepholeLSTM((input_size, hidden_size)::Pair; kwargs...) +function PeepholeLSTM((input_size, hidden_size)::Pair; + return_state::Bool = false, kwargs...) cell = PeepholeLSTMCell(input_size => hidden_size; kwargs...) - return PeepholeLSTM(cell) + return PeepholeLSTM{return_state, typeof(cell)}(cell) end function Base.show(io::IO, peepholelstm::PeepholeLSTM) diff --git a/src/cells/ran_cell.jl b/src/cells/ran_cell.jl index af82f51..6355851 100644 --- a/src/cells/ran_cell.jl +++ b/src/cells/ran_cell.jl @@ -129,16 +129,19 @@ h_t &= g(c_t) ## Returns - New hidden states `new_states` as an array of size `hidden_size x len x batch_size`. + When `return_state = true` it returns a tuple of the hidden stats `new_states` and + the last state of the iteration. """ -struct RAN{M} <: AbstractRecurrentLayer +struct RAN{S,M} <: AbstractRecurrentLayer cell::M end @layer :noexpand RAN -function RAN((input_size, hidden_size)::Pair; kwargs...) +function RAN((input_size, hidden_size)::Pair; + return_state::Bool = false, kwargs...) cell = RANCell(input_size => hidden_size; kwargs...) - return RAN(cell) + return RAN{return_state, typeof(cell)}(cell) end function Base.show(io::IO, ran::RAN) diff --git a/src/cells/scrn_cell.jl b/src/cells/scrn_cell.jl index af1aa68..9db88cc 100644 --- a/src/cells/scrn_cell.jl +++ b/src/cells/scrn_cell.jl @@ -131,16 +131,19 @@ y_t &= f(U_y h_t + W_y s_t) ## Returns - New hidden states `new_states` as an array of size `hidden_size x len x batch_size`. + When `return_state = true` it returns a tuple of the hidden stats `new_states` and + the last state of the iteration. """ -struct SCRN{M} <: AbstractRecurrentLayer +struct SCRN{S,M} <: AbstractRecurrentLayer cell::M end @layer :noexpand SCRN -function SCRN((input_size, hidden_size)::Pair; kwargs...) +function SCRN((input_size, hidden_size)::Pair; + return_state::Bool = false, kwargs...) cell = SCRNCell(input_size => hidden_size; kwargs...) - return SCRN(cell) + return SCRN{return_state, typeof(cell)}(cell) end function Base.show(io::IO, scrn::SCRN) diff --git a/src/generics.jl b/src/generics.jl index ab0cb3a..01de690 100644 --- a/src/generics.jl +++ b/src/generics.jl @@ -16,7 +16,7 @@ function (rcell::AbstractRecurrentCell)(inp::AbstractVecOrMat) return rcell(inp, state) end -abstract type AbstractRecurrentLayer end +abstract type AbstractRecurrentLayer{S} end function initialstates(rlayer::AbstractRecurrentLayer) return initialstates(rlayer.cell) @@ -27,9 +27,14 @@ function (rlayer::AbstractRecurrentLayer)(inp::AbstractVecOrMat) return rlayer(inp, state) end -function (rlayer::AbstractRecurrentLayer)( - inp::AbstractArray, - state::Union{AbstractVecOrMat, Tuple{AbstractVecOrMat, AbstractVecOrMat}}) +function (rlayer::AbstractRecurrentLayer{false})(inp::AbstractArray, + state::Union{AbstractVecOrMat, Tuple{AbstractVecOrMat, AbstractVecOrMat}}) + @assert ndims(inp) == 2 || ndims(inp) == 3 + return first(scan(rlayer.cell, inp, state)) +end + +function (rlayer::AbstractRecurrentLayer{true})(inp::AbstractArray, + state::Union{AbstractVecOrMat, Tuple{AbstractVecOrMat, AbstractVecOrMat}}) @assert ndims(inp) == 2 || ndims(inp) == 3 return scan(rlayer.cell, inp, state) end