feature requests for RNNs #2514
Comments
Would it be possible to add to the list the option to use different initializers for the input matrix and the recurrent matrix? This is provided by both Keras/TF and Flax. This should be as straightforward as
function RNNCell((in, out)::Pair, σ=relu;
                 kernel_init = glorot_uniform,
                 recurrent_kernel_init = glorot_uniform,
                 bias = true)
    Wi = kernel_init(out, in)
    U = recurrent_kernel_init(out, out)
    b = create_bias(Wi, bias, size(Wi, 1))
    return RNNCell(σ, Wi, U, b)
end
I can also open a quick PR on this if needed. |
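To make the request above concrete, here is a minimal, self-contained sketch in plain Julia that does not depend on Flux. `SketchCell`, `glorot_sketch`, and `orthogonal_sketch` are hypothetical names invented for this illustration; they are not Flux's API, and the keyword names simply mirror the ones proposed in the comment above.

```julia
using LinearAlgebra

# Hypothetical Glorot-uniform initializer: entries lie in [-limit, limit).
glorot_sketch(out, in) = (rand(out, in) .- 0.5) .* (2 * sqrt(6 / (in + out)))

# Hypothetical orthogonal initializer; this sketch only handles square matrices.
function orthogonal_sketch(rows, cols)
    rows == cols || throw(ArgumentError("this sketch only handles square matrices"))
    return Matrix(qr(randn(rows, rows)).Q)
end

# Toy cell (an assumption, not Flux's RNNCell) with separate input and recurrent weights.
struct SketchCell{F, M, V}
    σ::F
    Wi::M
    Wh::M
    b::V
end

function SketchCell((in, out)::Pair, σ = tanh;
                    kernel_init = glorot_sketch,
                    recurrent_kernel_init = glorot_sketch)
    Wi = kernel_init(out, in)             # input-to-hidden weights
    Wh = recurrent_kernel_init(out, out)  # hidden-to-hidden weights
    return SketchCell(σ, Wi, Wh, zeros(out))
end

# One recurrence step: returns the new hidden state.
(c::SketchCell)(x, h) = c.σ.(c.Wi * x .+ c.Wh * h .+ c.b)

# Glorot for the input matrix, orthogonal for the recurrent matrix.
cell = SketchCell(3 => 4; recurrent_kernel_init = orthogonal_sketch)
h = cell(ones(3), zeros(4))
```

The point of keeping the two keywords separate is that the recurrent matrix is always square, so initializers that only make sense for square matrices can be used there without affecting the input matrix.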
yes! PR welcome |
Following up on this, should we also have an option to choose the init for the bias? |
We don't do it for feedforward layers; if someone wants a non-zero bias, they can just change it manually in the constructor. |
Hey, a couple of questions on this feature request again. Would something simple like
function initialstates(rnn::RNNCell; init_state = zeros)
    state = init_state(size(rnn.Wh, 2))
    return state
end
function initialstates(lstm::LSTMCell; init_state = zeros, init_cstate = zeros)
    state = init_state(size(lstm.Wh, 2))
    cstate = init_cstate(size(lstm.Wh, 2))
    return state, cstate
end
suffice, or were you looking for something more? Maybe more control over the type would be needed. |
I would just have
If different initializations are needed, we could add an |
So this way we would simply do
function (rnn::RNNCell)(inp::AbstractVecOrMat)
    state = initialstates(rnn)
    return rnn(inp, state)
end
to keep compatibility with the current version, right? I think your point is good; additionally, no other library provides a specific |
Trying to tackle adding
struct TestRNN{A, B}
    cells::A
    dropout_layer::B
end
Flux.@layer :expand TestRNN
function TestRNN((in_size, out_size)::Pair;
                 n_layers::Int=1,
                 dropout::Float64=0.0,
                 kwargs...)
    cells = []
    for i in 1:n_layers
        # only the first layer sees the raw input size
        tin_size = i == 1 ? in_size : out_size
        push!(cells, RNNCell(tin_size => out_size; kwargs...))
    end
    if dropout > 0.0
        dropout_layer = Dropout(dropout)
    else
        dropout_layer = nothing
    end
    return TestRNN(cells, dropout_layer)
end
function (rnn::TestRNN)(inp, state)
    @assert ndims(inp) == 2 || ndims(inp) == 3
    output = []
    num_layers = length(rnn.cells)
    for inp_t in eachslice(inp, dims=2)
        new_states = []
        for (idx_cell, cell) in enumerate(rnn.cells)
            new_state = cell(inp_t, state[:, idx_cell])
            new_states = vcat(new_states, [new_state])
            inp_t = new_state
            # apply dropout between layers, but not after the last one
            if rnn.dropout_layer isa Dropout && idx_cell < num_layers
                inp_t = rnn.dropout_layer(inp_t)
            end
        end
        state = stack(new_states)
        output = vcat(output, [inp_t])
    end
    output = stack(output, dims=2)
    return output, state
end
|
I think we don't need this additional complexity. A simple stacked RNN can be constructed as a chain:
stacked_rnn = Chain(LSTM(3 => 3), Dropout(0.5), LSTM(3 => 3))
If control of the initial states is also needed, it is not hard to define a custom struct for the job:
struct StackedRNN{L,S}
    layers::L
    states0::S
end
function StackedRNN(d, num_layers)
    layers = [LSTM(d => d) for _ in 1:num_layers]
    states0 = [Flux.initialstates(l) for l in layers]
    return StackedRNN(layers, states0)
end
function (m::StackedRNN)(x)
    for (layer, state0) in zip(m.layers, m.states0)
        x = layer(x, state0)
    end
    return x
end
I think it is enough to document this in the guide. |
I assume
should be tackled at the
function scan(cell, x, state; return_state = false)
    y = []
    for x_t in eachslice(x, dims = 2)
        yt, state = cell(x_t, state)
        y = vcat(y, [yt])
    end
    if !(return_state)
        return stack(y, dims = 2)
    else
        return stack(y, dims = 2), state
    end
end
and then
struct RNN{M}
    cell::M
    return_state::Bool
end
@layer :noexpand RNN
initialstates(rnn::RNN) = initialstates(rnn.cell)
function RNN((in, out)::Pair, σ = tanh; return_state = false, cell_kwargs...)
    cell = RNNCell(in => out, σ; cell_kwargs...)
    return RNN(cell, return_state)
end
(rnn::RNN)(x::AbstractArray) = rnn(x, initialstates(rnn))
function (m::RNN)(x::AbstractArray, h)
    @assert ndims(x) == 2 || ndims(x) == 3
    # [x] = [in, L] or [in, L, B]
    # [h] = [out] or [out, B]
    return scan(m.cell, x, h; return_state = m.return_state)
end
Is this the direction you wanted to take this in? |
Yes, it could be something like that, although in order to preserve type stability it would be better to embed the
struct RNN{S, M}
    cell::M
end
function RNN((in, out)::Pair, σ = tanh; return_state = false, cell_kwargs...)
    cell = RNNCell(in => out, σ; cell_kwargs...)
    # both type parameters must be given, since S cannot be inferred from the field
    if return_state
        return RNN{:return_state, typeof(cell)}(cell)
    else
        return RNN{:no_return_state, typeof(cell)}(cell)
    end
end
function scan(cell, x, state; return_state = false)
    y = []
    for x_t in eachslice(x, dims = 2)
        yt, state = cell(x_t, state)
        y = vcat(y, [yt])
    end
    return stack(y, dims = 2), state
end
function (m::RNN{:no_return_state})(x::AbstractArray, h)
    return scan(m.cell, x, h; return_state = false)[1]
end
function (m::RNN{:return_state})(x::AbstractArray, h)
    return scan(m.cell, x, h; return_state = true)
end
|
Makes sense! This way I also don't think we need the
struct RNN{S, M}
    cell::M
end
function RNN((in, out)::Pair, σ = tanh; return_state = false, cell_kwargs...)
    cell = RNNCell(in => out, σ; cell_kwargs...)
    return RNN{return_state, typeof(cell)}(cell)
end
function scan(cell, x, state) # return_state keyword removed
    y = []
    for x_t in eachslice(x, dims = 2)
        yt, state = cell(x_t, state)
        y = vcat(y, [yt])
    end
    return stack(y, dims = 2), state
end
function (rnn::RNN{false})(inp::AbstractArray, state)
    return first(scan(rnn.cell, inp, state))
end
function (rnn::RNN{true})(inp::AbstractArray, state)
    return scan(rnn.cell, inp, state)
end
|
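The type-parameter approach discussed above can be tried without Flux using a small stand-in. This is only a sketch under stated assumptions: `ToyCell` and `ToyRNN` are hypothetical names, the cell is assumed to return an `(output, state)` tuple as in the `scan` drafts above, and `scan` here is simplified to matrix inputs of size `(in, L)`, using `reduce(hcat, ...)` instead of `stack`.

```julia
# Toy stand-in for a recurrent cell (hypothetical, not Flux's RNNCell).
# It returns (output, new_state); both are the new hidden state.
struct ToyCell{T}
    W::Matrix{T}
    U::Matrix{T}
end
function (c::ToyCell)(x, h)
    h_new = tanh.(c.W * x .+ c.U * h)
    return h_new, h_new
end

# Wrapper whose first type parameter S is a Bool recording return_state.
struct ToyRNN{S, M}
    cell::M
end
ToyRNN(cell; return_state = false) = ToyRNN{return_state, typeof(cell)}(cell)

# Run the cell over the columns of x (one column per time step).
function scan(cell, x::AbstractMatrix, state)
    ys = typeof(state)[]
    for x_t in eachcol(x)
        y_t, state = cell(x_t, state)
        push!(ys, y_t)
    end
    return reduce(hcat, ys), state
end

# Dispatch on the Bool type parameter instead of branching on a field.
(rnn::ToyRNN{false})(x, state) = first(scan(rnn.cell, x, state))
(rnn::ToyRNN{true})(x, state) = scan(rnn.cell, x, state)

cell = ToyCell(fill(0.1, 2, 3), fill(0.1, 2, 2))
x = ones(3, 4)   # 3 features, 4 time steps
h0 = zeros(2)
y = ToyRNN(cell)(x, h0)                              # outputs only
y2, h = ToyRNN(cell; return_state = true)(x, h0)     # outputs and final state
```

Because the return shape is fixed by the wrapper's type, each forward method has a single concrete return type; the keyword constructor itself still chooses the type from a runtime value.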
After the redesign in #2500, here is a list of potential improvements for recurrent layers and recurrent cells:
- add an option in constructors to have trainable initial state: let's keep it simple (and also follow PyTorch) by not having this; it can be part of the outer model
- implement the num_layers argument for stacked RNNs (issue "Stacked RNN in Flux.jl?" #2452) (see "feature requests for RNNs" #2514 (comment))
- add dropout (issue "recurrent dropout" #1040) (see "feature requests for RNNs" #2514 (comment))
- Bidirectional (... for RNN layers #1790)
- initialstates function. It could be useful in the LSTM case where the initial state is more complicated (two vectors). (done in "Adding initialstates function to RNNs" #2541)
- Recur (but maybe this is confusing) or Recurrence as in Lux. ("Recurrence layer" #2549)