[USGS-R#146] simplified training fxn
rm trainer class, model as input, files as input
jsadler2 committed Dec 1, 2021
1 parent c1fa37b commit 90e5716
Showing 1 changed file with 81 additions and 234 deletions.
315 changes: 81 additions & 234 deletions river_dl/train.py
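For orientation, here is a minimal usage sketch of the simplified interface. It is not part of the commit: it assumes, per the commit message and the diff below, that train_model now takes a pre-compiled model and NumPy arrays directly; the prepped-data path, hidden-unit count, output locations, and training settings are hypothetical placeholders.

import numpy as np
import tensorflow as tf
from river_dl.RGCN import RGCNModel
from river_dl.train import train_model

# Hypothetical prepped-data file; keys follow those referenced in the diff below.
io_data = np.load("prepped.npz", allow_pickle=True)

# The caller now builds and compiles the model before passing it in
# (previously done inside the removed trainer class).
model = RGCNModel(20, num_tasks=1, A=io_data["dist_matrix"])
model.compile(
    optimizer=tf.optimizers.Adam(learning_rate=0.01),
    loss=tf.keras.losses.MeanSquaredError(),  # stand-in for a river_dl loss function
)

trained_model = train_model(
    model=model,
    x_trn=io_data["x_trn"],
    y_trn=io_data["y_obs_trn"],
    epochs=100,
    batch_size=len(np.unique(io_data["ids_trn"])),  # number of segments, mirroring the removed batch-size logic
    x_val=io_data["x_val"],
    y_val=io_data["y_obs_val"],
    weight_dir="output/trained_weights/",            # hypothetical output locations
    best_val_weight_dir="output/best_val_weights/",
    log_file="output/finetune_log.csv",
    time_file="output/training_time.txt",
    early_stop_patience=10,
)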
@@ -1,182 +1,50 @@
import os
import random
import numpy as np
from numpy.lib.npyio import NpzFile
import datetime
import tensorflow as tf
from river_dl.RGCN import RGCNModel
from river_dl.loss_functions import weighted_masked_rmse_gw
from river_dl.rnns import LSTMModel, GRUModel


def get_data_if_file(d):
"""
    Rudimentary check of whether the data (.npz file) is already loaded; if not, load it
    :param d: [dict, NpzFile, or str] loaded data or the path to an .npz file
    :return: the loaded data as a dict or NpzFile
"""
if isinstance(d, NpzFile) or isinstance(d, dict):
return d
else:
return np.load(d, allow_pickle=True)


# This is a training engine that initializes our model and contains routines for pretraining and finetuning
class trainer():
def __init__(self, model, optimizer, loss_fn, weights=None):
self.model = model
self.model.compile(optimizer=optimizer, loss=loss_fn)
if weights:
self.model.load_weights(weights)

def pre_train(self, x, y, epochs, batch_size, out_dir):

## Set up training log callback
csv_log = tf.keras.callbacks.CSVLogger(
os.path.join(out_dir, f"pretrain_log.csv")
)

# Use generic fit statement
self.model.fit(
x=x,
y=y,
epochs=epochs,
batch_size=batch_size,
callbacks=[csv_log],
)
# Save the pretrained weights
self.model.save_weights(os.path.join(out_dir, "pretrained_weights/"))
return self.model

    def fine_tune(self, x, y, x_val, y_val, epochs, batch_size, out_dir, early_stop_patience=None, use_cpu=False):
# Specify our training log
csv_log = tf.keras.callbacks.CSVLogger(
os.path.join(out_dir, "finetune_log.csv")
)

        # Set up early stopping rounds if desired; setting this to the total number of epochs is the same as not using it
if not early_stop_patience:
early_stop_patience = epochs

early_stop = tf.keras.callbacks.EarlyStopping(
monitor='val_loss', min_delta=0, patience=early_stop_patience, restore_best_weights=False,
verbose=1)

# Save alternate weight file that saves the best validation weights
best_val = tf.keras.callbacks.ModelCheckpoint(
os.path.join(out_dir, 'best_val_weights/'), monitor='val_loss', verbose=0, save_best_only=True,
save_weights_only=True, mode='min', save_freq='epoch')

# Ensure that training happens on CPU if using the GW loss function
if use_cpu:
with tf.device('/CPU:0'):
self.model.fit(
x=x,
y=y,
epochs=epochs,
batch_size=batch_size,
callbacks=[csv_log, early_stop, best_val],
validation_data=(x_val, y_val)
)
else:
self.model.fit(
x=x,
y=y,
epochs=epochs,
batch_size=batch_size,
callbacks=[csv_log, early_stop, best_val],
validation_data=(x_val, y_val)
)

# Save our trained weights
self.model.save_weights(os.path.join(out_dir, f"trained_weights/"))
return self.model


def train_model(
io_data,
model,
x_trn,
y_trn,
epochs,
hidden_units,
loss_func,
out_dir,
    batch_size,
    model_type="rgcn",
x_val=None,
y_val=None,
weight_dir=None,
best_val_weight_dir=None,
log_file=None,
time_file=None,
seed=None,
dropout=0,
recurrent_dropout=0,
num_tasks=1,
    learning_rate=0.01,
    train_type='pre',
    early_stop_patience=None,
    limit_pretrain=False,
use_cpu=False
):
"""
    train the model (lstm, gru, or rgcn)
:param io_data: [dict or str] input and output data for model
:param model: [compiled TF model] a TF model compiled with a loss function
and an optimizer
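    :param x_trn: [np array] training input data
    :param y_trn: [np array] training observation (target) data
    :param batch_size: [int] number of samples per training batch
    :param x_val: [np array] validation input data
    :param y_val: [np array] validation observation data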
:param epochs: [int] number of train epochs
    :param hidden_units: [int] number of hidden units
:param loss_func: [function] loss function that the model will be fit to
:param out_dir: [str] directory where the output files should be written
:param model_type: [str] which model to use (either 'lstm', 'rgcn', or
'gru')
:param weight_dir: [str] path to directory where trained weights will be
saved from the last training epoch
:param best_val_weight_dir: [str] path to directory where trained weights
will be saved from the training epoch with the best validation performance
:param log_file: [str] path to file where training log will be saved
:param time_file: [str] path to file where training time will be written
:param seed: [int] random seed
    :param recurrent_dropout: [float] value between 0 and 1 for the probability
    of a recurrent element being zeroed
    :param dropout: [float] value between 0 and 1 for the probability of an
    input element being zeroed
    :param num_tasks: [int] number of tasks (variables to be predicted)
:param learning_rate: [float] the learning rate
:param train_type: [str] Either pretraining (pre) or finetuning (finetune)
    :param early_stop_patience: [int] number of epochs with no improvement after
    which training will be stopped. Default is None, meaning that training will
    continue for all specified epochs
    :param limit_pretrain: [bool] If True, limits pretraining to just the
    training partition. If False (default), pretrains on all available data.
    :param use_cpu: [bool] If True, ensures that training happens on CPU. This
    can be desirable in some cases (e.g., when using the GW loss function)
    :return: [tf model] trained model
"""
if train_type not in ['pre','finetune']:
raise ValueError("Specify train_type as either pre or finetune")

if tf.test.gpu_device_name():
print("Default GPU Device: {}".format(tf.test.gpu_device_name()))
else:
print("Not using GPU")

start_time = datetime.datetime.now()
io_data = get_data_if_file(io_data)

n_seg = len(np.unique(io_data["ids_trn"]))
if n_seg > 1:
batch_size = n_seg
else:
num_years = io_data["x_trn"].shape[0]
batch_size = num_years

if model_type == "lstm":
model = LSTMModel(
hidden_units,
num_tasks=num_tasks,
recurrent_dropout=recurrent_dropout,
dropout=dropout,
)
elif model_type == "rgcn":
dist_matrix = io_data["dist_matrix"]
model = RGCNModel(
hidden_units,
num_tasks=num_tasks,
A=dist_matrix,
rand_seed=seed,
dropout=dropout,
recurrent_dropout=recurrent_dropout,
)
elif model_type == "gru":
model = GRUModel(
hidden_units,
num_tasks=num_tasks,
recurrent_dropout=recurrent_dropout,
dropout=dropout,
)
else:
raise ValueError(
f"The 'model_type' provided ({model_type}) is not supported"
)

if seed:
os.environ["PYTHONHASHSEED"] = str(seed)
os.environ["TF_CUDNN_DETERMINISTIC"] = "1"
@@ -185,89 +53,68 @@ def train_model(
np.random.seed(seed)
random.seed(seed)

optimizer = tf.optimizers.Adam(learning_rate=learning_rate)

# pretrain
if train_type == 'pre':

        ## Create a dummy directory for Snakemake if you don't want pretraining
if epochs == 0:
os.makedirs(os.path.join(out_dir, "pretrained_weights/"), exist_ok = True)
print("Dummy directory created without pretraining. Set epochs to >0 to pretrain")
else:
# Pull out variables from the IO data
if limit_pretrain:
x_trn_pre = io_data["x_trn"]
y_trn_pre = io_data["y_pre_trn"]
else:
x_trn_pre = io_data["x_pre_full"]
y_trn_pre = io_data["y_pre_full"]

# Initialize our model within the training engine
engine = trainer(model, optimizer, loss_func)

# Call the pretraining routine from the training engine
model = engine.pre_train(x_trn_pre,y_trn_pre,epochs, batch_size,out_dir)

# Log our training times
pre_train_time = datetime.datetime.now()
pre_train_time_elapsed = pre_train_time - start_time
print(f"Pretraining time: {pre_train_time_elapsed}")
out_time_file = os.path.join(out_dir, "training_time.txt")

with open(out_time_file, "w") as f:
f.write(
f"elapsed time pretrain (includes building graph):\
{pre_train_time_elapsed} \n"
)
    # Set up early stopping rounds if desired; setting this to the total number
    # of epochs is the same as not using it
callbacks = []
if early_stop_patience:
early_stop = tf.keras.callbacks.EarlyStopping(
monitor='val_loss', min_delta=0, patience=early_stop_patience, restore_best_weights=False,
verbose=1)
callbacks.append(early_stop)

# finetune
if train_type == 'finetune':

# Load pretrain weights if they exist
if os.path.exists(os.path.join(out_dir, "pretrained_weights/checkpoint")):
weights = os.path.join(out_dir, "pretrained_weights/")
else:
weights = None
#model.load_weights(os.path.join(out_dir, "pretrained_weights/"))
if log_file:
csv_log = tf.keras.callbacks.CSVLogger(log_file)
callbacks.append(csv_log)

# Initialize our model within the training engine
engine = trainer(model, optimizer, loss_func, weights)
print(best_val_weight_dir)
if best_val_weight_dir:
best_val = tf.keras.callbacks.ModelCheckpoint(
best_val_weight_dir, monitor='val_loss', verbose=0, save_best_only=True,
save_weights_only=True, mode='min', save_freq='epoch')
callbacks.append(best_val)

# Specify our variables
y_trn_obs = io_data["y_obs_trn"]
x_trn = io_data["x_trn"]
y_val_obs = io_data['y_obs_val']
x_val = io_data['x_val']

if "GW_trn_reshape" in io_data.files:
temp_air_index = np.where(io_data['x_vars'] == 'seg_tave_air')[0]
air_unscaled = io_data['x_trn'][:, :, temp_air_index] * io_data['x_std'][temp_air_index] + \
io_data['x_mean'][temp_air_index]
y_trn_obs = np.concatenate(
[io_data["y_obs_trn"], io_data["GW_trn_reshape"], air_unscaled], axis=2
)
air_val = io_data['x_val'][:, :, temp_air_index] * io_data['x_std'][temp_air_index] + io_data['x_mean'][
temp_air_index]
y_val_obs = np.concatenate(
[io_data["y_obs_val"], io_data["GW_val_reshape"], air_val], axis=2
)
# Run the finetuning within the training engine on CPU for the GW loss function
model = engine.fine_tune(x_trn, y_trn_obs, x_val, y_val_obs, epochs, batch_size, out_dir, early_stop_patience, use_cpu=True)

else:
# Run the finetuning within the training engine on default device
model = engine.fine_tune(x_trn, y_trn_obs, x_val, y_val_obs, epochs, batch_size, out_dir, early_stop_patience)
if isinstance(x_val, np.ndarray) and isinstance(y_val, np.ndarray):
validation_data = (x_val, y_val)
else:
validation_data = None

# Log our training time
finetune_time = datetime.datetime.now()
finetune_time_elapsed = finetune_time - start_time
print(f"Finetuning time: {finetune_time_elapsed}")
out_time_file = os.path.join(out_dir, "training_time.txt")
with open(out_time_file, "a") as f:
f.write(
f"elapsed time finetune (includes building graph):\
            {finetune_time_elapsed}\n"
        )
# train the model
start_time = datetime.datetime.now()
if use_cpu:
with tf.device('/CPU:0'):
model.fit(
x=x_trn,
y=y_trn,
epochs=epochs,
batch_size=batch_size,
callbacks=callbacks,
validation_data=validation_data
)
else:
model.fit(
x=x_trn,
y=y_trn,
epochs=epochs,
batch_size=batch_size,
callbacks=callbacks,
validation_data=validation_data
)

# write out time
end_time = datetime.datetime.now()
time_elapsed = end_time - start_time
print(f"Training time: {time_elapsed}")
if time_file:
with open(time_file, "a") as f:
f.write(f"elapsed training time: {time_elapsed}\n")


# Save our trained weights
if weight_dir:
model.save_weights(weight_dir)

# Save alternate weight file that saves the best validation weights

return model
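After training, the weights written by save_weights can be restored onto a freshly built model with standard Keras calls. A minimal sketch (not part of the commit), assuming the same model configuration used for training and hypothetical paths:

import numpy as np
from river_dl.RGCN import RGCNModel

io_data = np.load("prepped.npz", allow_pickle=True)  # hypothetical prepped-data file

# Rebuild the model with the same configuration used for training,
# then load either the final or the best-validation weights.
model = RGCNModel(20, num_tasks=1, A=io_data["dist_matrix"])
model.load_weights("output/best_val_weights/")  # or "output/trained_weights/"

preds = model(io_data["x_val"])  # forward pass to generate predictions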
