Recurrent cells cannot be chained with other layers #1155
Comments
Flux's RNNCell design was deeply problematic, and their latest iteration (v0.15 and v0.16) is inspired by the current design in Lux. You need to wrap the cell in a https://lux.csail.mit.edu/stable/api/Lux/layers#Lux.StatefulRecurrentCell or https://lux.csail.mit.edu/stable/api/Lux/layers#Lux.Recurrence, depending on your use case.
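Based on the wrapper suggested above, a minimal sketch of chaining a cell with a `Dense` layer might look like the following. This is untested illustration code, not from the thread: the dimension names `Ny`/`Nh`, the batch size, and the sequence length are made up, and API details may differ between Lux versions.

```julia
using Lux, Random

Ny, Nh = 4, 8
rng = Random.default_rng()

# Recurrence iterates the wrapped cell over a sequence and (by default)
# returns only the final output, which is a plain array — so it composes
# with Dense inside a Chain.
model = Chain(Recurrence(RNNCell(Ny => Nh)), Dense(Nh => Ny))

ps, st = Lux.setup(rng, model)

# A sequence given as a vector of (features, batch) matrices, one per timestep.
x = [rand(rng, Float32, Ny, 16) for _ in 1:10]

y, st_new = model(x, ps, st)  # y has shape (Ny, 16)
```

`StatefulRecurrentCell` is the alternative when you want to step the cell one timestep at a time while it carries its hidden state in `st`.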
Thanks. But wrapping the cell in …
Here is one: https://lux.csail.mit.edu/stable/tutorials/beginner/3_SimpleRNN. Generally I recommend writing out the model using …
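Writing the model out explicitly, as the linked tutorial does, could look roughly like this. This is a hedged sketch, not the tutorial's exact code: it assumes Lux's `@compact` macro and the `cell((x, carry))` calling convention, and the names are invented.

```julia
using Lux, Random

Ny, Nh = 4, 8

# Loop the cell over the sequence by hand, then apply the Dense head.
# `x` is assumed to be a vector of (Ny, batch) matrices, one per timestep.
model = @compact(cell=RNNCell(Ny => Nh), dense=Dense(Nh => Ny)) do x
    out, carry = cell(x[1])            # first step initializes the carry
    for xt in x[2:end]
        out, carry = cell((xt, carry)) # subsequent steps take (input, carry)
    end
    @return dense(out)                 # project the final hidden state
end

rng = Random.default_rng()
ps, st = Lux.setup(rng, model)
x = [rand(rng, Float32, Ny, 16) for _ in 1:10]
y, _ = model(x, ps, st)
```

Writing the loop out makes it explicit how the carry flows between timesteps, which is exactly what a bare `Chain` cannot express.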
Thanks.
The recurrent cells (`RNNCell`, `LSTMCell`, etc.) output a tuple whose first element contains the hidden state of shape `(out_dims, batch_size)`, instead of the output array directly, as was the case in Flux.jl. Indeed, in Flux one was able to build a standard RNN architecture just with `Chain(RNNCell(Ny => Nh), Dense(Nh => Ny))`. This seems not to be possible in Lux, because the first output is no longer an array, and one gets an error.

Why was this changed from Flux? Can it be reverted?