Recurrent cells cannot be chained with other layers #1155

Closed
bertini97 opened this issue Jan 2, 2025 · 4 comments

Comments

@bertini97

The recurrent cells (RNNCell, LSTMCell, etc) output the following:

  • Tuple containing
    • Output $h_{\text{new}}$ of shape (out_dims, batch_size)
    • Tuple containing new hidden state $h_{\text{new}}$
  • Updated model state
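Here is a minimal sketch of that return value in practice. It assumes Lux.jl v1.x and uses arbitrary sizes chosen for illustration. Calling the cell yields ((h_new, (h_new,)), st), so the output has to be unpacked before it can reach another layer. On the next step, the carry is passed back in as a tuple together with the input.

```julia
using Lux, Random

rng = Random.default_rng()
cell = RNNCell(3 => 5)            # in_dims = 3, out_dims = 5 (illustrative sizes)
ps, st = Lux.setup(rng, cell)

x = rand(rng, Float32, 3, 4)      # one time step: (in_dims, batch_size)

# With a bare array, the cell builds its own initial hidden state.
(h, carry), st = cell(x, ps, st)
# h     : output h_new, size (5, 4)
# carry : 1-tuple (h_new,), the hidden state to feed back in

# The next step takes the input and the carry together as a tuple.
(h2, carry2), st = cell((x, carry), ps, st)
```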

The first element is a tuple instead of the output array (the hidden state), as it was in Flux.jl. In Flux, one could build a standard RNN architecture with an output layer simply as Chain(RNNCell(Ny => Nh), Dense(Nh => Ny)).

This does not seem possible in Lux, because the first output is no longer an array, and one gets the error:

ERROR: MethodError: no method matching (::Dense{…})(::Tuple{…}, ::@NamedTuple{}, ::@NamedTuple{})
The object of type `Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}` exists, but no method is defined for this combination of argument types when trying to treat it as a callable object.

Closest candidates are:
  (::Dense)(::AbstractArray, ::Any, ::NamedTuple)
   @ Lux ~/.julia/packages/Lux/fMnM0/src/layers/basic.jl:339

Why was this changed from Flux? Can it be reverted?

@avik-pal
Member

avik-pal commented Jan 2, 2025

Flux's RNNCell design was deeply problematic, and its latest iteration (v0.15 and v0.16) is inspired by the current design in Lux.

You need to wrap the cell in a StatefulRecurrentCell (https://lux.csail.mit.edu/stable/api/Lux/layers#Lux.StatefulRecurrentCell) or a Recurrence (https://lux.csail.mit.edu/stable/api/Lux/layers#Lux.Recurrence), depending on your use case.
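For the per-time-step pattern from the original question, here is a minimal sketch. It assumes Lux.jl v1.x, and the sizes Nx, Nh and Ny are illustrative. StatefulRecurrentCell keeps the carry in the layer state, so its output is a plain array that Dense accepts. The hidden state only carries over if the returned st is passed back in on the next call. The helper run_sequence is hypothetical and exists only for this example.

```julia
using Lux, Random

rng = Random.default_rng()
Nx, Nh, Ny = 2, 8, 2     # illustrative input, hidden and output sizes

# The wrapper stores the carry in `st`, so the layer returns only h_new.
model = Chain(StatefulRecurrentCell(RNNCell(Nx => Nh)), Dense(Nh => Ny))
ps, st = Lux.setup(rng, model)

xs = [rand(rng, Float32, Nx, 4) for _ in 1:10]   # 10 time steps, batch of 4

# Hypothetical helper: run the model step by step, threading the state.
function run_sequence(model, xs, ps, st)
    ys = Matrix{Float32}[]
    for xt in xs
        yt, st = model(xt, ps, st)   # reuse the updated state so h carries over
        push!(ys, yt)
    end
    return ys, st
end

ys, st = run_sequence(model, xs, ps, st)
```

Because the carry lives in st, starting a new, unrelated sequence means starting again from a fresh state, such as the one returned by Lux.setup.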

@bertini97
Author

Thanks. But wrapping the cell in Recurrence inside the Chain means the whole input sequence is passed to the cell and processed at once, so the output layer receives a Vector of Vectors and doesn't know what to do with it. Is there any documentation on how to implement the simplest RNN in Lux?
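Here is a minimal sketch of the two Recurrence modes. It assumes Lux.jl v1.x and illustrative sizes. With the default return_sequence=false, Recurrence returns only the output of the last time step, and Dense accepts that directly. With return_sequence=true, it returns one output per step, and a following Dense cannot consume that. In this sketch, the head is therefore applied to each element outside the Chain.

```julia
using Lux, Random

rng = Random.default_rng()
Nx, Nh, Ny = 2, 8, 2                 # illustrative sizes

# Default ordering (BatchLastIndex): x has size (features, seq_len, batch_size).
x = rand(rng, Float32, Nx, 10, 4)

# Many-to-one: only the last hidden state reaches Dense.
model = Chain(Recurrence(RNNCell(Nx => Nh)), Dense(Nh => Ny))
ps, st = Lux.setup(rng, model)
y_last, _ = model(x, ps, st)         # size (Ny, 4)

# Many-to-many: keep every hidden state, then apply the head per step.
rnn = Recurrence(RNNCell(Nx => Nh); return_sequence=true)
head = Dense(Nh => Ny)
ps_rnn, st_rnn = Lux.setup(rng, rnn)
ps_head, st_head = Lux.setup(rng, head)
hs, _ = rnn(x, ps_rnn, st_rnn)       # 10 outputs, each of size (Nh, 4)
ys = [first(head(h, ps_head, st_head)) for h in hs]
```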

@avik-pal
Member

avik-pal commented Jan 2, 2025

> Is there any documentation on how to implement the simplest RNN in Lux?

Here is one: https://lux.csail.mit.edu/stable/tutorials/beginner/3_SimpleRNN Generally, I recommend writing out RNN models with @compact instead of using Chain.
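In the spirit of that recommendation, here is a minimal sketch of a many-to-many RNN written with @compact. It assumes Lux.jl v1.x, where the body ends with @return, and uses illustrative sizes. This is an illustration, not the linked tutorial's code. Inside the body, cell and head are called with only their input, and @compact threads their parameters and states. The explicit loop makes the recurrence visible; it is written for the forward pass. For gradient-based training, an approach that avoids push! may be preferable.

```julia
using Lux, Random

rng = Random.default_rng()
Nx, Nh, Ny = 2, 8, 2    # illustrative sizes

model = @compact(cell = RNNCell(Nx => Nh), head = Dense(Nh => Ny)) do x
    # x has size (Nx, seq_len, batch_size); step through time explicitly.
    h, carry = cell(x[:, 1, :])              # first step creates the hidden state
    ys = [head(h)]
    for t in 2:size(x, 2)
        h, carry = cell((x[:, t, :], carry)) # feed the carry back in
        push!(ys, head(h))
    end
    @return stack(ys; dims=2)                # (Ny, seq_len, batch_size)
end

ps, st = Lux.setup(rng, model)
y, st = model(rand(rng, Float32, Nx, 10, 4), ps, st)
```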

@bertini97
Author

Thanks.
