Skip to content
This repository has been archived by the owner on Jun 2, 2023. It is now read-only.

Commit

Permalink
can provide h_/c_init to initalize rnn
Browse files Browse the repository at this point in the history
  • Loading branch information
jsadler2 committed Jun 7, 2021
1 parent 1c1641a commit 8362981
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions river_dl/rnns.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def __init__(
input element to be zero
"""
super().__init__()
self.hidden_size = hidden_size
self.num_tasks = num_tasks
self.rnn_layer = layers.LSTM(
hidden_size,
Expand All @@ -33,6 +34,10 @@ def __init__(

@tf.function
def call(self, inputs, **kwargs):
batch_size = inputs.shape[0]
h_init = kwargs.get("h_init", tf.zeros([batch_size, self.hidden_size]))
c_init = kwargs.get("c_init", tf.zeros([batch_size, self.hidden_size]))
self.rnn_layer.reset_states(states=[h_init, c_init])
x, h, c = self.rnn_layer(inputs)
self.states = h, c
if self.num_tasks == 1:
Expand Down

0 comments on commit 8362981

Please sign in to comment.