[#106] taking train_step out in rnns
(just add losses together for multitask)
jsadler2 committed Jun 4, 2021
1 parent e939460 commit 17bfab8
Showing 1 changed file with 22 additions and 145 deletions.
167 changes: 22 additions & 145 deletions river_dl/rnns.py
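As the commit message notes, dropping the custom train_step means multitask training now just adds the per-variable losses together. A minimal sketch of what such a summed loss could look like, assuming a hypothetical masked helper rmse_masked_one_var (the project's real loss functions live in river_dl.loss_functions and are not shown in this diff):

import tensorflow as tf


def rmse_masked_one_var(y_true, y_pred, var_idx):
    # hypothetical helper: RMSE for one target column, ignoring NaN observations
    obs = y_true[:, :, var_idx]
    pred = y_pred[:, :, var_idx]
    mask = tf.logical_not(tf.math.is_nan(obs))
    return tf.sqrt(
        tf.reduce_mean(tf.square(tf.boolean_mask(obs, mask) - tf.boolean_mask(pred, mask)))
    )


def summed_multitask_loss(lambdas=(1, 1)):
    # add the per-task losses together, optionally weighting each task
    def loss(y_true, y_pred):
        return (lambdas[0] * rmse_masked_one_var(y_true, y_pred, 0)
                + lambdas[1] * rmse_masked_one_var(y_true, y_pred, 1))
    return loss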
@@ -2,13 +2,13 @@
from __future__ import print_function, division
import tensorflow as tf
from tensorflow.keras import layers
from river_dl.loss_functions import nnse_masked_one_var, nnse_one_var_samplewise


class SingletaskLSTMModel(tf.keras.Model):
class LSTMModel(tf.keras.Model):
def __init__(
self,
hidden_size,
num_tasks=1,
recurrent_dropout=0,
dropout=0,
):
@@ -18,6 +18,7 @@ def __init__(
:param dropout: [float] value between 0 and 1 for the probability of an input element to be zero
"""
super().__init__()
self.num_tasks = num_tasks
self.rnn_layer = layers.LSTM(
hidden_size,
return_sequences=True,
@@ -27,163 +28,39 @@ def __init__(
dropout=dropout
)
self.dense_main = layers.Dense(1, name="dense_main")
if self.num_tasks == 2:
self.dense_aux = layers.Dense(1, name="dense_aux")
self.h = None
self.c = None

@tf.function
def call(self, inputs, **kwargs):
x, self.h, self.c = self.rnn_layer(inputs)
main_prediction = self.dense_main(x)
return main_prediction


class MultitaskLSTMModel(tf.keras.Model):
def __init__(
self,
hidden_size,
gradient_correction=False,
lambdas=(1, 1),
recurrent_dropout=0,
dropout=0,
grad_log_file=None,
):
"""
:param hidden_size: [int] the number of hidden units
:param gradient_correction: [bool]
:param lambdas: [array-like] weights to multiply the loss from each target
variable by
:param recurrent_dropout: [float] value between 0 and 1 for the probability of a recurrent element to be zero
:param dropout: [float] value between 0 and 1 for the probability of an input element to be zero
:param grad_log_file: [str] location of gradient log file
"""
super().__init__()
self.gradient_correction = gradient_correction
self.grad_log_file = grad_log_file
self.lambdas = lambdas
self.rnn_layer = layers.LSTM(
hidden_size,
return_sequences=True,
stateful=True,
return_state=True,
name="rnn_shared",
recurrent_dropout=recurrent_dropout,
dropout=dropout
)
self.dense_main = layers.Dense(1, name="dense_main")
self.dense_aux = layers.Dense(1, name="dense_aux")
self.h = None
self.c = None

@tf.function
def call(self, inputs, **kwargs):
x, self.h, self.c = self.rnn_layer(inputs)
main_prediction = self.dense_main(x)
aux_prediction = self.dense_aux(x)
return tf.concat([main_prediction, aux_prediction], axis=2)

@tf.function
def train_step(self, data):
x, y = data

# If I don't do one forward pass before starting the gradient tape,
# the thing hangs
_ = self(x)
with tf.GradientTape(persistent=True) as tape:
y_pred = self(x, training=True) # forward pass

loss_main = nnse_one_var_samplewise(y, y_pred, 0, self.tasks)
loss_aux = nnse_one_var_samplewise(y, y_pred, 1, self.tasks)

trainable_vars = self.trainable_variables

main_out_vars = get_variables(trainable_vars, "dense_main")
aux_out_vars = get_variables(trainable_vars, "dense_aux")
shared_vars = get_variables(trainable_vars, "rnn_shared")

# get gradients
gradient_main_out = tape.gradient(loss_main, main_out_vars)
gradient_aux_out = tape.gradient(loss_aux, aux_out_vars)
gradient_shared_main = tape.gradient(loss_main, shared_vars)
gradient_shared_aux = tape.gradient(loss_aux, shared_vars)

if self.gradient_correction:
# adjust auxiliary gradient
gradient_shared_aux = adjust_gradient_list(
gradient_shared_main, gradient_shared_aux, self.grad_log_file
)
combined_gradient = combine_gradients_list(
gradient_shared_main, gradient_shared_aux, lambdas=self.lambdas
)

# apply gradients
self.optimizer.apply_gradients(zip(gradient_main_out, main_out_vars))
self.optimizer.apply_gradients(zip(gradient_aux_out, aux_out_vars))
self.optimizer.apply_gradients(zip(combined_gradient, shared_vars))
return {"loss_main": loss_main, "loss_aux": loss_aux}


class SingletaskGRUModel(SingletaskLSTMModel):
if self.num_tasks == 1:
main_prediction = self.dense_main(x)
return main_prediction
elif self.num_tasks == 2:
main_prediction = self.dense_main(x)
aux_prediction = self.dense_aux(x)
return tf.concat([main_prediction, aux_prediction], axis=2)
else:
raise ValueError(f'This model only supports 1 or 2 tasks (not {self.num_tasks})')
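A quick usage sketch of the consolidated class, with made-up shapes (3 input features, 2 target variables); the summed_multitask_loss passed to compile is the hypothetical helper sketched above the diff:

import numpy as np

# toy forcing data: (batch, timesteps, features)
x = np.random.rand(4, 20, 3).astype("float32")

model = LSTMModel(hidden_size=16, num_tasks=2)
y_hat = model(x)      # first call builds the layers
print(y_hat.shape)    # (4, 20, 2) -> one output column per task

# with train_step removed, training can go through the stock compile/fit loop
model.compile(optimizer="adam", loss=summed_multitask_loss(lambdas=(1, 1)))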


class GRUModel(LSTMModel):
def __init__(
self,
hidden_size,
num_tasks=1,
dropout=0,
recurrent_dropout=0,
):
"""
:param hidden_size: [int] the number of hidden units
"""
super().__init__(hidden_size)
self.rnn_layer = layers.GRU(
hidden_size, return_sequences=True, name="rnn_shared"
)


class MultitaskGRUModel(MultitaskLSTMModel):
def __init__(
self,
hidden_size,
lambdas=(1, 1)
):
"""
:param hidden_size: [int] the number of hidden units
"""
super().__init__(hidden_size, lambdas=lambdas)
super().__init__(hidden_size, num_tasks=num_tasks)
self.rnn_layer = layers.GRU(
hidden_size, return_sequences=True, name="rnn_shared"
hidden_size, recurrent_dropout=recurrent_dropout, dropout=dropout, return_sequences=True, name="rnn_shared"
)
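As sketched in this diff, the GRU layer is created without return_state, while the call inherited from LSTMModel unpacks three return values; a Keras GRU also exposes a single hidden state rather than the LSTM's (h, c) pair. A GRU variant that tracks state might therefore need an override along these lines (this subclass is an illustration, not part of the commit):

import tensorflow as tf
from tensorflow.keras import layers


class StatefulGRUModel(GRUModel):
    # hypothetical variant: keep the GRU's single final state
    def __init__(self, hidden_size, num_tasks=1, dropout=0, recurrent_dropout=0):
        super().__init__(hidden_size, num_tasks=num_tasks, dropout=dropout,
                         recurrent_dropout=recurrent_dropout)
        # rebuild the recurrent layer so it also returns its final state
        self.rnn_layer = layers.GRU(
            hidden_size,
            recurrent_dropout=recurrent_dropout,
            dropout=dropout,
            return_sequences=True,
            return_state=True,
            name="rnn_shared",
        )

    @tf.function
    def call(self, inputs, **kwargs):
        # a GRU returns (sequences, final_state), not the LSTM's three outputs
        x, self.h = self.rnn_layer(inputs)
        if self.num_tasks == 1:
            return self.dense_main(x)
        main_prediction = self.dense_main(x)
        aux_prediction = self.dense_aux(x)
        return tf.concat([main_prediction, aux_prediction], axis=2)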


def adjust_gradient(main_grad, aux_grad, logfile=None):
# flatten tensors
main_grad_flat = tf.reshape(main_grad, [-1])
aux_grad_flat = tf.reshape(aux_grad, [-1])

# project and adjust
projection = (
tf.minimum(tf.reduce_sum(main_grad_flat * aux_grad_flat), 0)
* main_grad_flat
/ tf.reduce_sum(main_grad_flat * main_grad_flat)
)
if logfile:
logfile = "file://" + logfile
tf.print(tf.reduce_sum(projection), output_stream=logfile, sep=",")
projection = tf.cond(
tf.math.is_nan(tf.reduce_sum(projection)),
lambda: tf.zeros(aux_grad_flat.shape),
lambda: projection,
)
adjusted = aux_grad_flat - projection
return tf.reshape(adjusted, aux_grad.shape)
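adjust_gradient, one of the helpers removed along with the custom train_step, projects the auxiliary gradient away from the main one whenever the two conflict, i.e. whenever their dot product is negative. A small worked example with made-up gradients:

import tensorflow as tf

main_grad = tf.constant([[1.0, 0.0], [0.0, 1.0]])
aux_grad = tf.constant([[-1.0, 2.0], [0.5, -0.5]])

adjusted = adjust_gradient(main_grad, aux_grad)

flat = lambda t: tf.reshape(t, [-1])
print(tf.reduce_sum(flat(main_grad) * flat(aux_grad)).numpy())   # -1.5: the gradients conflict
print(tf.reduce_sum(flat(main_grad) * flat(adjusted)).numpy())   # 0.0: the conflicting component is removed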


def get_variables(trainable_variables, name):
return [v for v in trainable_variables if name in v.name]


def combine_gradients_list(main_grads, aux_grads, lambdas=(1, 1)):
return [lambdas[0] * main_grads[i] + lambdas[1] * aux_grads[i] for i in range(len(main_grads))]


def adjust_gradient_list(main_grads, aux_grads, logfile=None):
return [
adjust_gradient(main_grads[i], aux_grads[i], logfile)
for i in range(len(main_grads))
]
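combine_gradients_list, also removed here, is just an element-wise weighted sum over the shared-layer gradients, with lambdas setting the relative weight of the main and auxiliary tasks; for example:

import tensorflow as tf

main_grads = [tf.constant([1.0, 2.0]), tf.constant([[0.5]])]
aux_grads = [tf.constant([0.0, -1.0]), tf.constant([[2.0]])]

combined = combine_gradients_list(main_grads, aux_grads, lambdas=(1, 0.5))
# combined[0] -> [1.0, 1.5]; combined[1] -> [[1.5]]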
