feature requests for RNNs #2514

Open · 3 of 9 tasks
CarloLucibello opened this issue Nov 4, 2024 · 12 comments

@CarloLucibello (Member) commented Nov 4, 2024

After the redesign in #2500, here is a list of potential improvements for recurrent layers and recurrent cells

@MartinuzziFrancesco (Contributor)

Would it be possible to add to the list the option to use different initializers for the input matrix and the recurrent matrix? This is provided by both Keras/TF and Flax.

This should be as straightforward as

function RNNCell((in, out)::Pair, σ=relu;
    kernel_init = glorot_uniform,
    recurrent_kernel_init = glorot_uniform,
    bias = true)
    Wi = kernel_init(out, in)             # input-to-hidden weights
    U = recurrent_kernel_init(out, out)   # hidden-to-hidden (recurrent) weights
    b = create_bias(Wi, bias, size(Wi, 1))
    return RNNCell(σ, Wi, U, b)
end
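
For illustration, usage might then look like this (hypothetical, assuming the constructor sketched above; glorot_normal is just a stand-in for any Flux initializer):

# hypothetical usage of the constructor above
cell = RNNCell(3 => 5; kernel_init = glorot_uniform,
               recurrent_kernel_init = glorot_normal)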

I can also open a quick PR on this if needed

@CarloLucibello (Member, Author)

yes! PR welcome

@MartinuzziFrancesco (Contributor)

Following up on this, should we also have an option to choose the init for the bias?

@CarloLucibello (Member, Author)

We don't do it for feedforward layers; if someone wants a non-zero bias they can just change it manually after construction: layer.bias .= ...
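
For instance, a minimal sketch with a Dense layer (the 0.1f0 value is arbitrary):

using Flux
layer = Dense(3 => 4)   # bias is created as zeros by default
layer.bias .= 0.1f0     # overwrite in place with a non-zero constant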

@MartinuzziFrancesco (Contributor)

Hey, a couple of questions on this feature request again. Would the initialstates function be like the current function (rnn::RNN)(x::AbstractVecOrMat), but returning just the state?

Would something simple like

function initialstates(rnn::RNNCell; init_state = zeros)
    state = init_state(size(rnn.Wh, 2))    # hidden state, sized to match Wh
    return state
end

function initialstates(lstm::LSTMCell; init_state = zeros, init_cstate = zeros)
    state = init_state(size(lstm.Wh, 2))   # hidden state
    cstate = init_cstate(size(lstm.Wh, 2)) # cell state
    return state, cstate
end

suffice, or were you looking for something more? Maybe more control over the type would be needed.

@CarloLucibello (Member, Author)

I would just have

initialstates(c::RNNCell) = zeros_like(c.Wh, size(c.Wh, 2))
initialstates(c::LSTMCell) = (zeros_like(c.Wh, size(c.Wh, 2)), zeros_like(c.Wh, size(c.Wh, 2)))

If different initializations are needed, we could add an init_state to the constructor, but it may be better to let the user handle it as part of the model, for simplicity and flexibility (e.g. making the initial state trainable). I don't have a strong opinion though.
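
A minimal sketch of that last option (the TrainableState0 wrapper is hypothetical, not an existing Flux type):

using Flux

# hypothetical wrapper: state0 is a field of a @layer struct, hence trainable
struct TrainableState0{M,S}
    rnn::M
    state0::S
end
Flux.@layer TrainableState0

(m::TrainableState0)(x) = m.rnn(x, m.state0)

model = TrainableState0(RNN(3 => 5), zeros(Float32, 5))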

@MartinuzziFrancesco (Contributor)

so this way we would simply do

function (rnn::RNNCell)(inp::AbstractVecOrMat)
    state = initialstates(rnn)   # fall back to a default initial state
    return rnn(inp, state)
end

to keep compatibility with the current version, right?

I think your point is good. Additionally, no other library provides a specific init for the state, so that's probably a little overkill. I'll push something along these lines later.

@MartinuzziFrancesco (Contributor)

Trying to tackle adding num_layers and dropout, with Bidirectional as well once I figure out the general direction. I just wanted to ask if the current approach is in line with what you had in mind:

struct TestRNN{A, B}
    cells::A
    dropout_layer::B
end

Flux.@layer :expand TestRNN

function TestRNN((in_size, out_size)::Pair;
    n_layers::Int=1,
    dropout::Float64=0.0,
    kwargs...)
    cells = []

    for i in 1:n_layers
        # first layer maps in_size => out_size, deeper layers out_size => out_size
        tin_size = i == 1 ? in_size : out_size
        push!(cells, RNNCell(tin_size => out_size; kwargs...))
    end

    if dropout > 0.0
        dropout_layer = Dropout(dropout)
    else
        dropout_layer = nothing
    end

    return TestRNN(cells, dropout_layer)
end

function (rnn::TestRNN)(inp, state)
    @assert ndims(inp) == 2 || ndims(inp) == 3
    output = []
    num_layers = length(rnn.cells)

    for inp_t in eachslice(inp, dims=2)
        new_states = []
        for (idx_cell, cell) in enumerate(rnn.cells)
            new_state = cell(inp_t, state[:, idx_cell])
            new_states = vcat(new_states, [new_state])
            inp_t = new_state

            # apply dropout between layers, but not after the last one
            if rnn.dropout_layer isa Dropout && idx_cell < num_layers
                inp_t = rnn.dropout_layer(inp_t)
            end
        end
        state = stack(new_states)
        output = vcat(output, [inp_t])
    end
    output = stack(output, dims=2)
    return output, state
end

@CarloLucibello (Member, Author)

I think we don't need this additional complexity. A simple stacked RNN can be constructed as a chain:

stacked_rnn = Chain(LSTM(3 => 3), Dropout(0.5), LSTM(3 => 3))

If control of the initial states is also needed, it is not hard to define a custom struct for the job:

struct StackedRNN{L,S}
    layers::L
    states0::S
end

function StackedRNN(d, num_layers)
    layers = [LSTM(d => d) for _ in 1:num_layers]
    states0 = [Flux.initialstates(l) for l in layers]
    return StackedRNN(layers, states0)
end

function (m::StackedRNN)(x)
   for (layer, state0) in zip(m.layers, m.states0)
       x = layer(x, state0)
   end
   return x
end
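
Usage would then look something like this (a sketch assuming the struct above, with arbitrary sizes):

m = StackedRNN(4, 3)       # three stacked LSTM(4 => 4) layers
x = rand(Float32, 4, 10)   # (features, sequence length)
y = m(x)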

I think it is enough to document this in the guide
https://github.com/FluxML/Flux.jl/blob/master/docs/src/guide/models/recurrence.md

@MartinuzziFrancesco (Contributor)

I assume

add an option in the RNN/GRU/LSTM constructor to return the last state along with the output

should be tackled at the scan level, maybe something like

function scan(cell, x, state; return_state = false)
  y = []
  for x_t in eachslice(x, dims = 2)
    yt, state = cell(x_t, state)
    y = vcat(y, [yt])
  end
  if !return_state
    return stack(y, dims = 2)
  else
    return stack(y, dims = 2), state
  end
end

and then

struct RNN{M}
  cell::M
  return_state::Bool
end

@layer :noexpand RNN

initialstates(rnn::RNN) = initialstates(rnn.cell)

function RNN((in, out)::Pair, σ = tanh; return_state = false, cell_kwargs...)
  cell = RNNCell(in => out, σ; cell_kwargs...)
  return RNN(cell, return_state)
end

(rnn::RNN)(x::AbstractArray) = rnn(x, initialstates(rnn))

function (m::RNN)(x::AbstractArray, h)
  @assert ndims(x) == 2 || ndims(x) == 3
  # [x] = [in, L] or [in, L, B]
  # [h] = [out] or [out, B]
  return scan(m.cell, x, h; return_state = m.return_state)
end
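
For reference, usage of this sketch would be along the lines of (hypothetical, building on the code above):

rnn = RNN(3 => 5; return_state = true)
x = rand(Float32, 3, 10)   # [in, L]
y, h = rnn(x)              # output sequence plus final state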

is this the direction you wanted to take this in?

@CarloLucibello (Member, Author)

Yes, it could be something like that, although in order to preserve type stability it would be better to embed the return_state information as a parameter of the type (it could be a symbol or a bool).

struct RNN{S, M}
  cell::M
end

function RNN((in, out)::Pair, σ = tanh; return_state = false, cell_kwargs...)
  cell = RNNCell(in => out, σ; cell_kwargs...)
  if return_state
      return RNN{:return_state, typeof(cell)}(cell)
  else
      return RNN{:no_return_state, typeof(cell)}(cell)
  end
end

function scan(cell, x, state; return_state = false)
  y = []
  for x_t in eachslice(x, dims = 2)
    yt, state = cell(x_t, state)
    y = vcat(y, [yt])
  end
  return stack(y, dims = 2), state
end

function (m::RNN{:no_return_state})(x::AbstractArray, h)
  return scan(m.cell, x, h)[1]
end
function (m::RNN{:return_state})(x::AbstractArray, h)
  return scan(m.cell, x, h)
end

@MartinuzziFrancesco (Contributor)

Makes sense! This way I also don't think we need the return_state in scan, so building on your code, would something like this work?

struct RNN{S, M}
  cell::M
end

function RNN((in, out)::Pair, σ = tanh; return_state = false, cell_kwargs...)
  cell = RNNCell(in => out, σ; cell_kwargs...)
  return RNN{return_state, typeof(cell)}(cell)
end

function scan(cell, x, state)  # return_state keyword no longer needed
  y = []
  for x_t in eachslice(x, dims = 2)
    yt, state = cell(x_t, state)
    y = vcat(y, [yt])
  end
  return stack(y, dims = 2), state
end

function (rnn::RNN{false})(inp::AbstractArray, state)
  return first(scan(rnn.cell, inp, state))
end
function (rnn::RNN{true})(inp::AbstractArray, state)
  return scan(rnn.cell, inp, state)
end
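
Usage would be along these lines (a sketch; it assumes the cell-level initialstates method discussed earlier):

rnn = RNN(3 => 5; return_state = true)   # constructs an RNN{true, <:RNNCell}
x = rand(Float32, 3, 10)
y, state = rnn(x, initialstates(rnn.cell))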
