[USGS-R#146] option to pass tf model to predict fxns
jsadler2 committed Dec 30, 2021
1 parent 829a4eb commit 5900421
Showing 1 changed file with 72 additions and 51 deletions.
123 changes: 72 additions & 51 deletions river_dl/predict.py
@@ -19,11 +19,12 @@ def unscale_output(y_scl, y_std, y_mean, y_vars, log_vars=None):
unscale output data given a standard deviation and a mean value for the
outputs
:param y_scl: [pd dataframe] scaled output data (predicted or observed)
:param y_std:[numpy array] array of standard deviation of variables_to_log [n_out]
:param y_std:[numpy array] array of standard deviation of variables_to_log
[n_out]
:param y_mean:[numpy array] array of variable means [n_out]
:param y_vars: [list-like] y_dataset variable names
:param log_vars: [list-like] which variables_to_log (if any) were logged in data
prep
:param log_vars: [list-like] which variables_to_log (if any) were logged in
data prep
:return: unscaled data
"""
y_unscaled = y_scl.copy()
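For orientation, a minimal sketch of the unscaling these parameter descriptions imply, assuming a reverse z-score step followed by exponentiating any logged variables (the actual implementation is in unscale_output; the values below are made up):

import numpy as np
import pandas as pd

# made-up scaled outputs for two variables
y_scl = pd.DataFrame({"temp_c": [-0.5, 0.0, 0.5], "log_q": [0.1, 0.2, 0.3]})
y_std = np.array([4.0, 1.5])    # per-variable standard deviations [n_out]
y_mean = np.array([12.0, 2.0])  # per-variable means [n_out]
y_vars = ["temp_c", "log_q"]
log_vars = ["log_q"]            # variables that were logged in data prep

y_unscaled = y_scl.copy()
# undo the z-score scaling column by column
y_unscaled[y_vars] = y_scl[y_vars] * y_std + y_mean
for v in log_vars:
    # assumed: logged variables are exponentiated back to original units
    y_unscaled[v] = np.exp(y_unscaled[v])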
@@ -64,41 +65,52 @@ def load_model_from_weights(


def predict_from_io_data(
model_type,
model_weights_dir,
hidden_size,
io_data,
partition,
outfile,
model=None,
model_type=None,
model_weights_dir=None,
hidden_size=None,
partition=None,
outfile=None,
log_vars=False,
num_tasks=1,
trn_offset = 1.0,
tst_val_offset = 1.0,
trn_offset=1.0,
tst_val_offset=1.0,
):
"""
make predictions from trained model
make predictions from a trained model. You can pass either
1) a compiled model with weights loaded as the "model" argument, or
2) "model_type", associated parameters, and the weights directory; the
model will be compiled and the weights loaded
:param io_data: [str] directory to prepped data file
:param model: [compiled TF model] a TF model compiled with weights loaded
:param model_type: [str] model to use either 'rgcn', 'lstm', or 'gru'
:param model_weights_dir: [str] directory to saved model weights
:param io_data: [str] directory to prepped data file
:param hidden_size: [int] the number of hidden units in model
:param partition: [str] must be 'trn' or 'tst'; whether you want to predict
for the train or the dev period
:param outfile: [str] the file where the output data should be stored
:param log_vars: [list-like] which variables_to_log (if any) were logged in data
prep
:param log_vars: [list-like] which variables_to_log (if any) were logged in
data prep
:param num_tasks: [int] number of tasks (variables_to_log to be predicted)
:param trn_offset: [str] value for the training offset
:param tst_val_offset: [str] value for the testing and validation offset
:return: [pd dataframe] predictions
"""
if not model and not model_type:
raise ValueError("You must either pass a model or a model_type")

io_data = get_data_if_file(io_data)
model = load_model_from_weights(
model_type,
model_weights_dir,
hidden_size,
io_data.get("dist_matrix"),
num_tasks=num_tasks,
)

if not model:
model = load_model_from_weights(
model_type,
model_weights_dir,
hidden_size,
io_data.get("dist_matrix"),
num_tasks=num_tasks,
)

if partition == "trn":
keep_frac = trn_offset
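A hedged usage sketch of the two calling styles this change enables; keyword arguments are used since the exact positional order may differ, and the paths, hidden size, and the pre-built rgcn_model variable are illustrative placeholders rather than values from this commit:

from river_dl.predict import predict_from_io_data

# Option 1: pass a compiled TF model whose weights are already loaded
# (rgcn_model is assumed to have been built and loaded elsewhere)
preds = predict_from_io_data(
    model=rgcn_model,
    io_data="prepped.npz",
    partition="trn",
    outfile="preds_trn.feather",
)

# Option 2: give model_type plus the weights directory and let the
# function compile the model and load the weights (previous behavior)
preds = predict_from_io_data(
    model_type="rgcn",
    model_weights_dir="trained_weights/",
    hidden_size=20,
    io_data="prepped.npz",
    partition="trn",
    outfile="preds_trn.feather",
)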
@@ -146,8 +158,8 @@ def predict(
:param y_means:[np array] the means of the y_dataset data
:param y_vars:[np array] the variable names of the y_dataset data
:param outfile: [str] the file where the output data should be stored
:param log_vars: [list-like] which variables_to_log (if any) were logged in data
prep
:param log_vars: [list-like] which variables_to_log (if any) were logged in
data prep
:return: out predictions
"""
num_segs = len(np.unique(pred_ids))
@@ -234,8 +246,8 @@ def predict_one_date_range(
index (e.g., 'seg_id_nat')
:param time_idx_name: [str] name of column that is used for temporal index
(usually 'time')
:param log_vars: [list-like] which variables_to_log (if any) were logged in data
prep
:param log_vars: [list-like] which variables_to_log (if any) were logged in
data prep
:param keep_last_frac: [float] fraction of the predictions to keep starting
from the *end* of the predictions (0-1). (1 means you keep all of the
predictions, .75 means you keep the final three quarters of the predictions)
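The keep_last_frac behavior described above can be illustrated with a short, self-contained sketch (inferred from the docstring, not code from this commit):

import numpy as np

seq_len = 365
keep_last_frac = 0.75              # keep the final three quarters
preds = np.arange(seq_len)         # stand-in for one predicted sequence
n_keep = int(np.round(seq_len * keep_last_frac))
preds_kept = preds[-n_keep:]       # keep predictions counted from the end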
@@ -299,9 +311,10 @@ def predict_from_arbitrary_data(
pred_start_date,
pred_end_date,
train_io_data,
model_weights_dir,
model_type,
hidden_size,
model=None,
model_weights_dir=None,
model_type=None,
hidden_size=None,
spatial_idx_name="seg_id_nat",
time_idx_name="date",
seq_len=365,
@@ -311,16 +324,21 @@
):
"""
make predictions given raw data that is potentially independent from the
data used to train the model
data used to train the model. For the model, you can pass either
1) a compiled model with weights loaded as the "model" argument, or
2) "model_type", associated parameters, and the weights directory; the
model will be compiled and the weights loaded
:param raw_data_file: [str] path to zarr dataset with x data that you want
to use to make predictions
:param pred_start_date: [str] start date of predictions (fmt: YYYY-MM-DD)
:param pred_end_date: [str] end date of predictions (fmt: YYYY-MM-DD)
:param train_io_data: [str or np NpzFile] the path to or the loaded data
that was used to train the model. This file must contain the variables_to_log
names, the standard deviations, and the means of the X and Y variables_to_log. Only
in with this information can the model be used properly
that was used to train the model. This file must contain the
variables_to_log names, the standard deviations, and the means of the X and
Y variables_to_log. Only with this information can the model be used
properly
:param model: [compiled TF model] a TF model compiled with weights loaded
:param model_weights_dir: [str] path to the directory where the TF model
weights are stored
:param model_type: [str] model to use either 'rgcn', 'lstm', or 'gru'
@@ -332,29 +350,32 @@
:param seq_len: [int] length of input sequences given to model
:param dist_matrix: [np array] the distance matrix if using 'rgcn'. if not
provided, will look for it in the "train_io_data" file.
:param flow_in_temp: [bool] whether the flow should be an input into temp
for the rgcn model
:param log_vars: [list-like] which variables_to_log (if any) were logged in data
prep
:param log_vars: [list-like] which variables_to_log (if any) were logged in
data prep
:param num_tasks: [int] number of tasks (variables_to_log to be predicted)
:return: [pd dataframe] the predictions
"""
if not model and not model_type:
raise ValueError("You must either pass a model or a model_type")

train_io_data = get_data_if_file(train_io_data)

if model_type == "rgcn":
if not dist_matrix:
dist_matrix = train_io_data.get("dist_matrix")
if not isinstance(dist_matrix, np.ndarray):
raise ValueError(
"model type is 'rgcn', but there is no" "distance matrix"
)

model = load_model_from_weights(
model_type,
model_weights_dir,
hidden_size,
dist_matrix,
num_tasks=num_tasks,
)
if not model:
if model_type == "rgcn":
if not dist_matrix:
dist_matrix = train_io_data.get("dist_matrix")
if not isinstance(dist_matrix, np.ndarray):
raise ValueError(
"model type is 'rgcn', but there is no" "distance matrix"
)

model = load_model_from_weights(
model_type,
model_weights_dir,
hidden_size,
dist_matrix,
num_tasks=num_tasks,
)

ds = xr.open_zarr(raw_data_file)
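A hedged usage sketch for the updated function, mirroring the docstring above; the zarr path, dates, prepped-data path, and the pre-loaded lstm_model variable are placeholders for illustration:

from river_dl.predict import predict_from_arbitrary_data

# pass a compiled TF model directly (the new option in this commit)
preds = predict_from_arbitrary_data(
    raw_data_file="raw_inputs.zarr",
    pred_start_date="2021-01-01",
    pred_end_date="2021-12-31",
    train_io_data="prepped.npz",
    model=lstm_model,  # assumed: compiled model with weights loaded
)

# or fall back to compiling the model from model_type and saved weights
preds = predict_from_arbitrary_data(
    raw_data_file="raw_inputs.zarr",
    pred_start_date="2021-01-01",
    pred_end_date="2021-12-31",
    train_io_data="prepped.npz",
    model_type="lstm",
    model_weights_dir="trained_weights/",
    hidden_size=20,
)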

