diff --git a/Snakefile b/Snakefile index e8a0256..5e2b1ae 100644 --- a/Snakefile +++ b/Snakefile @@ -1,6 +1,6 @@ import os -from river_dl.preproc_utils import prep_data +from river_dl.preproc_utils import prep_all_data from river_dl.evaluate import combined_metrics from river_dl.postproc_utils import plot_obs from river_dl.predict import predict_from_io_data @@ -20,15 +20,20 @@ rule all: rule prep_io_data: input: - config['obs_temp'], - config['obs_flow'], config['sntemp_file'], + config['obs_file'], config['dist_matrix'], output: "{outdir}/prepped.npz" run: - prep_data(input[0], input[1], input[2], input[3], + prep_all_data( + x_data_file=input[0], + pretrain_file=input[0], + y_data_file=input[1], + distfile=input[2], x_vars=config['x_vars'], + y_vars_pretrain=['seg_tave_water', 'seg_outflow'], + y_vars_finetune=['temp_c', 'discharge_cms'], catch_prop_file=None, exclude_file=None, train_start_date=config['train_start_date'], @@ -37,8 +42,7 @@ rule prep_io_data: val_end_date=config['val_end_date'], test_start_date=config['test_start_date'], test_end_date=config['test_end_date'], - primary_variable=config['primary_variable'], - log_q=False, segs=None, + segs=None, out_file=output[0]) @@ -90,7 +94,7 @@ rule make_predictions: predict_from_io_data(model_type='rgcn', model_weights_dir=model_dir, hidden_size=config['hidden_size'], io_data=input[1], partition=wildcards.partition, outfile=output[0], - logged_q=False, num_tasks=2) + num_tasks=2) def get_grp_arg(wildcards): @@ -106,8 +110,7 @@ def get_grp_arg(wildcards): rule combine_metrics: input: - config['obs_temp'], - config['obs_flow'], + config['obs_file'], "{outdir}/trn_preds.feather", "{outdir}/val_preds.feather" output: @@ -116,10 +119,9 @@ rule combine_metrics: params: grp_arg = get_grp_arg run: - combined_metrics(obs_temp=input[0], - obs_flow=input[1], - pred_trn=input[2], - pred_val=input[3], + combined_metrics(obs_file=input[0], + pred_trn=input[1], + pred_val=input[2], group=params.grp_arg, outfile=output[0]) diff --git a/config.yml b/config.yml index 22f22d6..ee00448 100644 --- a/config.yml +++ b/config.yml @@ -1,15 +1,13 @@ # Input files -obs_flow: "data_DRB/obs_flow_subset" -obs_temp: "data_DRB/obs_temp_subset" -sntemp_file: "data_DRB/uncal_sntemp_input_output_subset" -dist_matrix: "data_DRB/distance_matrix_subset.npz" +obs_file: "../drb-dl-model/data/in/obs_flow_temp_subset" +sntemp_file: "../drb-dl-model/data/in/uncal_sntemp_input_output_subset" +dist_matrix: "../drb-dl-model/data/in/distance_matrix_subset.npz" -out_dir: "output_DRB_test" +out_dir: "out_103_test" code_dir: "river_dl" x_vars: ["seg_rain", "seg_tave_air", "seginc_swrad", "seg_length", "seginc_potet", "seg_slope", "seg_humid", "seg_elev"] -primary_variable: "temp" lambdas: [100,100] @@ -29,6 +27,6 @@ test_end_date: - '1985-09-30' - '2021-09-30' -pt_epochs: 20 -ft_epochs: 10 +pt_epochs: 2 +ft_epochs: 2 hidden_size: 20 diff --git a/river_dl/evaluate.py b/river_dl/evaluate.py index edd08f2..07cbbc6 100644 --- a/river_dl/evaluate.py +++ b/river_dl/evaluate.py @@ -8,8 +8,8 @@ def filter_negative_preds(y_true, y_pred): """ filters out negative predictions and prints a warning if there are >5% of predictions as negative - :param y_true: [array-like] observed y values - :param y_pred: [array-like] predicted y values + :param y_true: [array-like] observed y_dataset values + :param y_pred: [array-like] predicted y_dataset values :return: [array-like] filtered data """ # print a warning if there are a lot of negatives @@ -29,8 +29,8 @@ def filter_negative_preds(y_true, 
y_pred): def rmse_logged(y_true, y_pred): """ compute the rmse of the logged data - :param y_true: [array-like] observed y values - :param y_pred: [array-like] predicted y values + :param y_true: [array-like] observed y_dataset values + :param y_pred: [array-like] predicted y_dataset values :return: [float] the rmse of the logged data """ y_true, y_pred = filter_negative_preds(y_true, y_pred) @@ -40,8 +40,8 @@ def rmse_logged(y_true, y_pred): def nse_logged(y_true, y_pred): """ compute the rmse of the logged data - :param y_true: [array-like] observed y values - :param y_pred: [array-like] predicted y values + :param y_true: [array-like] observed y_dataset values + :param y_pred: [array-like] predicted y_dataset values :return: [float] the nse of the logged data """ y_true, y_pred = filter_negative_preds(y_true, y_pred) @@ -52,8 +52,8 @@ def filter_by_percentile(y_true, y_pred, percentile, less_than=True): """ filter an array by a percentile of the observations. The data less than or greater than if `less_than=False`) will be changed to NaN - :param y_true: [array-like] observed y values - :param y_pred: [array-like] predicted y values + :param y_true: [array-like] observed y_dataset values + :param y_pred: [array-like] predicted y_dataset values :param percentile: [number] percentile number 0-100 :param less_than: [bool] whether you want the data *less than* the percentile. If False, the data greater than the percentile will remain. @@ -72,8 +72,8 @@ def filter_by_percentile(y_true, y_pred, percentile, less_than=True): def percentile_metric(y_true, y_pred, metric, percentile, less_than=True): """ compute an evaluation metric for a specified percentile of the observations - :param y_true: [array-like] observed y values - :param y_pred: [array-like] predicted y values + :param y_true: [array-like] observed y_dataset values + :param y_pred: [array-like] predicted y_dataset values :param metric: [function] metric function :param percentile: [number] percentile number 0-100 :param less_than: [bool] whether you want the data *less than* the @@ -134,16 +134,25 @@ def calc_metrics(df): return pd.Series(metrics) -def overall_metrics( - pred_file, obs_file, variable, partition, group=None, outfile=None +def partition_metrics( + pred_file, + obs_file, + partition, + spatial_idx_name="seg_id_nat", + time_idx_name="date", + group=None, + outfile=None ): """ calculate metrics for a certain group (or no group at all) for a given partition and variable :param pred_file: [str] path to predictions feather file :param obs_file: [str] path to observations zarr file - :param variable: [str] variable for which the metrics are being calculated :param partition: [str] data partition for which metrics are calculated + :param spatial_idx_name: [str] name of column that is used for spatial + index (e.g., 'seg_id_nat') + :param time_idx_name: [str] name of column that is used for temporal index + (usually 'time') :param group: [str or list] which group the metrics should be computed for. 
Currently only supports 'seg_id_nat' (segment-wise metrics), 'month' (month-wise metrics), ['seg_id_nat', 'month'] (metrics broken out by segment @@ -151,52 +160,67 @@ def overall_metrics( :param outfile: [str] file where the metrics should be written :return: [pd dataframe] the condensed metrics """ - data = fmt_preds_obs(pred_file, obs_file, variable) - data.reset_index(inplace=True) - if not group: - metrics = calc_metrics(data) - # need to convert to dataframe and transpose so it looks like the others - metrics = pd.DataFrame(metrics).T - elif group == "seg_id_nat": - metrics = data.groupby("seg_id_nat").apply(calc_metrics).reset_index() - elif group == "month": - metrics = ( - data.groupby(data["date"].dt.month) + var_data = fmt_preds_obs(pred_file, obs_file, spatial_idx_name, + time_idx_name) + var_metrics_list = [] + + for data_var, data in var_data.items(): + data.reset_index(inplace=True) + if not group: + metrics = calc_metrics(data) + # need to convert to dataframe and transpose so it looks like the + # others + metrics = pd.DataFrame(metrics).T + elif group == "seg_id_nat": + metrics = data.groupby(spatial_idx_name).apply(calc_metrics).reset_index() + elif group == "month": + metrics = ( + data.groupby( + data[time_idx_name].dt.month) .apply(calc_metrics) .reset_index() - ) - elif group == ["seg_id_nat", "month"]: - metrics = ( - data.groupby([data["date"].dt.month, "seg_id_nat"]) + ) + elif group == ["seg_id_nat", "month"]: + metrics = ( + data.groupby( + [data[time_idx_name].dt.month, + spatial_idx_name]) .apply(calc_metrics) .reset_index() - ) - else: - raise ValueError("group value not valid") - metrics["variable"] = variable - metrics["partition"] = partition + ) + else: + raise ValueError("group value not valid") + + metrics["variable"] = data_var + metrics["partition"] = partition + var_metrics_list.append(metrics) + var_metrics = pd.concat(var_metrics_list) if outfile: - metrics.to_csv(outfile, header=True, index=False) - return metrics + var_metrics.to_csv(outfile, header=True, index=False) + return var_metrics def combined_metrics( - obs_temp, - obs_flow, + obs_file, pred_trn=None, pred_val=None, pred_tst=None, + spatial_idx_name="seg_id_nat", + time_idx_name="date", group=None, outfile=None, ): """ calculate the metrics for flow and temp and training and test sets for a given grouping + :param obs_file: [str] path to observations zarr file :param pred_trn: [str] path to training prediction feather file :param pred_val: [str] path to validation prediction feather file :param pred_tst: [str] path to testing prediction feather file - :param obs_temp: [str] path to observations temperature zarr file - :param obs_flow: [str] path to observations flow zarr file + :param spatial_idx_name: [str] name of column that is used for spatial + index (e.g., 'seg_id_nat') + :param time_idx_name: [str] name of column that is used for temporal index + (usually 'time') :param group: [str or list] which group the metrics should be computed for. 
Currently only supports 'seg_id_nat' (segment-wise metrics), 'month' (month-wise metrics), ['seg_id_nat', 'month'] (metrics broken out by segment @@ -206,17 +230,29 @@ def combined_metrics( """ df_all = [] if pred_trn: - trn_temp = overall_metrics(pred_trn, obs_temp, "temp", "trn", group) - trn_flow = overall_metrics(pred_trn, obs_flow, "flow", "trn", group) - df_all.extend([trn_temp, trn_flow]) + trn_metrics = partition_metrics(pred_file=pred_trn, + obs_file=obs_file, + partition="trn", + spatial_idx_name=spatial_idx_name, + time_idx_name=time_idx_name, + group=group) + df_all.extend([trn_metrics]) if pred_val: - val_temp = overall_metrics(pred_val, obs_temp, "temp", "val", group) - val_flow = overall_metrics(pred_val, obs_flow, "flow", "val", group) - df_all.extend([val_temp, val_flow]) + val_metrics = partition_metrics(pred_file=pred_val, + obs_file=obs_file, + partition="val", + spatial_idx_name=spatial_idx_name, + time_idx_name=time_idx_name, + group=group) + df_all.extend([val_metrics]) if pred_tst: - tst_temp = overall_metrics(pred_tst, obs_temp, "temp", "tst", group) - tst_flow = overall_metrics(pred_tst, obs_flow, "flow", "tst", group) - df_all.extend([tst_temp, tst_flow]) + tst_metrics = partition_metrics(pred_file=pred_tst, + obs_file=obs_file, + partition="tst", + spatial_idx_name=spatial_idx_name, + time_idx_name=time_idx_name, + group=group) + df_all.extend([tst_metrics]) df_all = pd.concat(df_all, axis=0) if outfile: df_all.to_csv(outfile, index=False) diff --git a/river_dl/gw_utils.py b/river_dl/gw_utils.py index a8d5d45..b44f88b 100644 --- a/river_dl/gw_utils.py +++ b/river_dl/gw_utils.py @@ -241,7 +241,7 @@ def prep_annual_signal_data( :param io_data_file: [str] the prepped data file :param train_start_date, train_end_date, val_start_date,val_end_date,test_start_date,test_end_date: [str] the start and end dates of the training, validation, and testing periods - :param gwVarList: [str] list of groundwater-relevant variables + :param gwVarList: [str] list of groundwater-relevant variables_to_log :param out_file: [str] file to where the values will be written :param water_temp_pbm_col: str with the column name of the process-based model predicted water temperatures in degrees C :param water_temp_obs_col: str with the column name of the observed water temperatures in degrees C @@ -288,7 +288,7 @@ def prep_annual_signal_data( GW_trn_scale['Ar_obs'] = (GW_trn['Ar_obs']-np.nanmean(GW_trn['Ar_obs']))/np.nanstd(GW_trn['Ar_obs']) GW_trn_scale['delPhi_obs'] = (GW_trn['delPhi_obs']-np.nanmean(GW_trn['delPhi_obs']))/np.nanstd(GW_trn['delPhi_obs']) - #add the GW data to the y dataset + #add the GW data to the y_dataset dataset preppedData = np.load(io_data_file) data = {k:v for k, v in preppedData.items() if not k.startswith("GW")} data['GW_trn_reshape']=make_GW_dataset(GW_trn_scale,obs_trn,gwVarList) @@ -363,7 +363,7 @@ def make_GW_dataset (GW_data,x_data,varList): prepares a GW-relevant dataset for the GW loss function that can be combined with y_true :param GW_data: [dataframe] dataframe of annual temperature signal properties by segment :param x_data: [str] observation dataset - :param varList: [str] variables to keep in the final dataset + :param varList: [str] variables_to_log to keep in the final dataset :returns: GW dataset that is reshaped to match the shape of the first 2 dimensions of the y_true dataset """ #make a dataframe with all combinations of segment and date and then join the annual temperature signal properties dataframe to it @@ -473,12 +473,12 @@ def 
calc_gw_metrics(trnFile,tstFile,valFile,outFile,figFile1, figFile2, pbm_name for x in range(len(thisData['{}_obs'.format(thisMetric)])): thisColor = colorDict[thisData.group[x]] ax.plot([thisData['{}_obs'.format(thisMetric+"_low")][x],thisData['{}_obs'.format(thisMetric+"_high")][x]],[thisData['{}_pred'.format(thisMetric)][x],thisData['{}_pred'.format(thisMetric)][x]], color=thisColor) -# ax.scatter(x=thisData['{}_obs'.format(thisMetric)],y=thisData['{}_pred'.format(thisMetric)],label="RGCN",color="blue") +# ax.scatter(x=thisData['{}_obs'.format(thisMetric)],y_dataset=thisData['{}_pred'.format(thisMetric)],label="RGCN",color="blue") for thisGroup in np.unique(thisData['group']): thisColor = colorDict[thisGroup] ax.scatter(x=thisData.loc[thisData.group==thisGroup,'{}_obs'.format(thisMetric)],y=thisData.loc[thisData.group==thisGroup,'{}_pred'.format(thisMetric)],label="RGCN - %s"%thisGroup,color=thisColor) -# ax.scatter(x=thisData['{}_obs'.format(thisMetric)],y=thisData['{}_sntemp'.format(thisMetric)],label="SNTEMP",color="red") +# ax.scatter(x=thisData['{}_obs'.format(thisMetric)],y_dataset=thisData['{}_sntemp'.format(thisMetric)],label="SNTEMP",color="red") for i, label in enumerate(thisData.seg_id_nat): ax.annotate(int(label), (thisData['{}_obs'.format(thisMetric)][i],thisData['{}_pred'.format(thisMetric)][i])) if thisFig==1: diff --git a/river_dl/loss_functions.py b/river_dl/loss_functions.py index 1e5f5d3..ea57129 100644 --- a/river_dl/loss_functions.py +++ b/river_dl/loss_functions.py @@ -88,7 +88,7 @@ def multitask_kge(lambdas): def multitask_loss(lambdas, loss_func): """ - calculate a weighted multi-task loss for a given number of variables with a + calculate a weighted multi-task loss for a given number of variables_to_log with a given loss function :param lambdas: [array-like float] The factor that losses will be multiplied by before being added together. 
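An illustrative sketch (not the library's TensorFlow implementation) of the weighted multi-task loss that the multitask_loss docstring above describes: each task's loss is scaled by its lambda (e.g., lambdas: [100, 100] in config.yml) and the scaled losses are added together. rmse_masked is a hypothetical stand-in for the repo's masked loss functions.

import numpy as np

def rmse_masked(y_true, y_pred):
    # ignore NaN observations, as the masked losses in loss_functions.py do
    mask = np.isfinite(y_true)
    return np.sqrt(np.mean((y_true[mask] - y_pred[mask]) ** 2))

def weighted_multitask_loss(lambdas, loss_func=rmse_masked):
    # return a loss over [batch, seq_len, n_tasks] arrays: one loss per task,
    # scaled by that task's lambda, then summed
    def loss(y_true, y_pred):
        return sum(
            lam * loss_func(y_true[..., task], y_pred[..., task])
            for task, lam in enumerate(lambdas)
        )
    return loss

# usage (two tasks, e.g., temperature and flow):
# loss_fn = weighted_multitask_loss([100, 100])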
diff --git a/river_dl/postproc_utils.py b/river_dl/postproc_utils.py index 6b4b517..6f990fc 100644 --- a/river_dl/postproc_utils.py +++ b/river_dl/postproc_utils.py @@ -27,38 +27,48 @@ def load_if_not_df(pred_data): else: return pred_data - -def trim_obs(obs, preds): +def trim_obs(obs, preds, spatial_idx_name="seg_id_nat", time_idx_name="date"): obs_trim = obs.reset_index() trim_preds = preds.reset_index() obs_trim = obs_trim[ - (obs_trim.date >= trim_preds.date.min()) - & (obs_trim.date <= trim_preds.date.max()) - & (obs_trim.seg_id_nat.isin(trim_preds.seg_id_nat.unique())) + (obs_trim[time_idx_name] >= trim_preds[time_idx_name].min()) + & (obs_trim[time_idx_name] <= trim_preds[time_idx_name].max()) + & (obs_trim[spatial_idx_name].isin(trim_preds[spatial_idx_name].unique())) ] - return obs_trim.set_index(["date", "seg_id_nat"]) + return obs_trim.set_index([time_idx_name, spatial_idx_name]) -def fmt_preds_obs(pred_data, obs_file, variable): +def fmt_preds_obs(pred_data, + obs_file, + spatial_idx_name="seg_id_nat", + time_idx_name="date"): """ combine predictions and observations in one dataframe :param pred_data:[str] filepath to the predictions file :param obs_file:[str] filepath to the observations file - :param variable: [str] either 'flow' or 'temp' + :param spatial_idx_name: [str] name of column that is used for spatial + index (e.g., 'seg_id_nat') + :param time_idx_name: [str] name of column that is used for temporal index + (usually 'time') """ - obs_var, seg_var = get_var_names(variable) pred_data = load_if_not_df(pred_data) - # pred_data.loc[:, "seg_id_nat"] = pred_data["seg_id_nat"].astype(int) - if {"date", "seg_id_nat"}.issubset(pred_data.columns): - pred_data.set_index(["date", "seg_id_nat"], inplace=True) + + if {time_idx_name, spatial_idx_name}.issubset(pred_data.columns): + pred_data.set_index([time_idx_name, spatial_idx_name], inplace=True) obs = xr.open_zarr(obs_file).to_dataframe() - obs_cln = obs[[obs_var]] - obs_cln.columns = ["obs"] - preds = pred_data[[seg_var]] - preds.columns = ["pred"] - obs_cln_trim = trim_obs(obs_cln, preds) - combined = preds.join(obs_cln_trim) - return combined + variables_data = {} + + for var_name in pred_data.columns: + obs_var = obs.copy() + obs_var = obs_var[[var_name]] + obs_var.columns = ["obs"] + preds_var = pred_data[[var_name]] + preds_var.columns = ["pred"] + # trimming obs to preds speeds up following join greatly + obs_var = trim_obs(obs_var, preds_var, spatial_idx_name, time_idx_name) + combined = preds_var.join(obs_var) + variables_data[var_name] = combined + return variables_data def plot_obs(prepped_data, variable, outfile, partition="trn"): @@ -98,9 +108,9 @@ def plot_ts(pred_file, obs_file, variable, out_file): def prepped_array_to_df(data_array, dates, ids, col_names): """ - convert prepped x or y data in numpy array to pandas df + convert prepped x or y_dataset data in numpy array to pandas df (reshape and make into pandas DFs) - :param data_array:[numpy array] array of x or y data [nbatch, seq_len, + :param data_array:[numpy array] array of x or y_dataset data [nbatch, seq_len, n_out] :param dates:[numpy array] array of dates [nbatch, seq_len, n_out] :param ids: [numpy array] array of seg_ids [nbatch, seq_len, n_out] diff --git a/river_dl/predict.py b/river_dl/predict.py index 45e07a4..4788dd5 100644 --- a/river_dl/predict.py +++ b/river_dl/predict.py @@ -14,25 +14,25 @@ from river_dl.train import get_data_if_file -def unscale_output(y_scl, y_std, y_mean, y_vars, logged_q=False): +def unscale_output(y_scl, y_std, y_mean, 
y_vars, log_vars=None): """ unscale output data given a standard deviation and a mean value for the outputs :param y_scl: [pd dataframe] scaled output data (predicted or observed) - :param y_std:[numpy array] array of standard deviation of variables [n_out] + :param y_std:[numpy array] array of standard deviation of variables_to_log [n_out] :param y_mean:[numpy array] array of variable means [n_out] - :param y_vars: [list-like] y variable names - :param logged_q: [bool] whether the model predicted log of discharge. if - true, the exponent of the discharge will be executed - :return: + :param y_vars: [list-like] y_dataset variable names + :param log_vars: [list-like] which variables_to_log (if any) were logged in data + prep + :return: unscaled data """ y_unscaled = y_scl.copy() # I'm replacing just the variable columns. I have to specify because, at # least in some cases, there are other columns (e.g., "seg_id_nat" and # date") y_unscaled[y_vars] = (y_scl[y_vars] * y_std) + y_mean - if logged_q: - y_unscaled["seg_outflow"] = np.exp(y_unscaled["seg_outflow"]) + if log_vars: + y_unscaled[log_vars] = np.exp(y_unscaled[log_vars]) return y_unscaled @@ -45,7 +45,7 @@ def load_model_from_weights( :param model_weights_dir: [str] directory to saved model weights :param hidden_size: [int] the number of hidden units in model :param dist_matrix: [np array] the distance matrix if using 'rgcn' - :param num_tasks: [int] number of tasks (variables to be predicted) + :param num_tasks: [int] number of tasks (variables_to_log to be predicted) :return: TF model """ if model_type == "rgcn": @@ -70,8 +70,8 @@ def predict_from_io_data( io_data, partition, outfile, + log_vars=False, num_tasks=1, - logged_q=False, ): """ make predictions from trained model @@ -82,9 +82,9 @@ def predict_from_io_data( :param partition: [str] must be 'trn' or 'tst'; whether you want to predict for the train or the dev period :param outfile: [str] the file where the output data should be stored - :param logged_q: [bool] whether the discharge was logged in training. if - True the exponent of the discharge will be taken in the model unscaling - :param num_tasks: [int] number of tasks (variables to be predicted) + :param log_vars: [list-like] which variables_to_log (if any) were logged in data + prep + :param num_tasks: [int] number of tasks (variables_to_log to be predicted) :return: [pd dataframe] predictions """ io_data = get_data_if_file(io_data) @@ -105,13 +105,13 @@ def predict_from_io_data( model, io_data[f"x_{partition}"], io_data[f"ids_{partition}"], - io_data[f"dates_{partition}"], - io_data[f"y_std"], - io_data[f"y_mean"], - io_data[f"y_vars"], + io_data[f"times_{partition}"], + io_data["y_std"], + io_data["y_mean"], + io_data["y_obs_vars"], keep_last_frac=keep_frac, outfile=outfile, - logged_q=logged_q, + log_vars=log_vars, ) return preds @@ -126,7 +126,7 @@ def predict( y_vars, keep_last_frac=1.0, outfile=None, - logged_q=False, + log_vars=False, ): """ use trained model to make predictions @@ -138,12 +138,12 @@ def predict( :param keep_last_frac: [float] fraction of the predictions to keep starting from the *end* of the predictions (0-1). 
(1 means you keep all of the predictions, .75 means you keep the final three quarters of the predictions) - :param y_stds:[np array] the standard deviation of the y data - :param y_means:[np array] the means of the y data - :param y_vars:[np array] the variable names of the y data + :param y_stds:[np array] the standard deviation of the y_dataset data + :param y_means:[np array] the means of the y_dataset data + :param y_vars:[np array] the variable names of the y_dataset data :param outfile: [str] the file where the output data should be stored - :param logged_q: [str] whether the discharge was logged in training. if True - the exponent of the discharge will be taken in the model unscaling + :param log_vars: [list-like] which variables_to_log (if any) were logged in data + prep :return: out predictions """ num_segs = len(np.unique(pred_ids)) @@ -157,7 +157,7 @@ def predict( y_pred_pp = prepped_array_to_df(y_pred, pred_dates, pred_ids, y_vars,) - y_pred_pp = unscale_output(y_pred_pp, y_stds, y_means, y_vars, logged_q,) + y_pred_pp = unscale_output(y_pred_pp, y_stds, y_means, y_vars, log_vars,) if outfile: y_pred_pp.to_feather(outfile) @@ -209,7 +209,9 @@ def predict_one_date_range( seq_len, start_date, end_date, - logged_q=False, + spatial_idx_name="seg_id_nat", + time_idx_name="date", + log_vars=None, keep_last_frac=1.0, offset=0.5, swap_halves_of_first_seq=False, @@ -224,8 +226,12 @@ def predict_one_date_range( :param seq_len: [int] length of the prediction sequences (usu. 365) :param start_date: [str or date] the start date of the predictions :param end_date: [str or date] the end date of the predictions - :param logged_q: [bool] whether the model predicted log of discharge. if - true, the exponent of the discharge will be executed + :param spatial_idx_name: [str] name of column that is used for spatial + index (e.g., 'seg_id_nat') + :param time_idx_name: [str] name of column that is used for temporal index + (usually 'time') + :param log_vars: [list-like] which variables_to_log (if any) were logged in data + prep :param keep_last_frac: [float] fraction of the predictions to keep starting from the *end* of the predictions (0-1). 
(1 means you keep all of the predictions, .75 means you keep the final three quarters of the predictions) @@ -240,12 +246,28 @@ def predict_one_date_range( """ ds_x_scaled = ds_x_scaled[train_io_data["x_cols"]] x_data = ds_x_scaled.sel(date=slice(start_date, end_date)) - x_batches = convert_batch_reshape(x_data, seq_len=seq_len, offset=offset) + x_batches = convert_batch_reshape( + x_data, + seq_len=seq_len, + offset=offset, + spatial_idx_name=spatial_idx_name, + time_idx_name=time_idx_name, + ) x_batch_ids = coord_as_reshaped_array( - x_data, "seg_id_nat", seq_len=seq_len, offset=offset + x_data, + "seg_id_nat", + seq_len=seq_len, + offset=offset, + spatial_idx_name=spatial_idx_name, + time_idx_name=time_idx_name, ) x_batch_dates = coord_as_reshaped_array( - x_data, "date", seq_len=seq_len, offset=offset + x_data, + "date", + seq_len=seq_len, + offset=offset, + spatial_idx_name=spatial_idx_name, + time_idx_name=time_idx_name, ) num_segs = len(np.unique(x_batch_ids)) @@ -261,9 +283,9 @@ def predict_one_date_range( x_batch_dates, train_io_data["y_std"], train_io_data["y_mean"], - train_io_data["y_vars"], + train_io_data["y_obs_vars"], keep_last_frac=keep_last_frac, - logged_q=logged_q, + log_vars=log_vars, ) return predictions @@ -276,10 +298,12 @@ def predict_from_arbitrary_data( model_weights_dir, model_type, hidden_size, - num_tasks=1, + spatial_idx_name="seg_id_nat", + time_idx_name="date", seq_len=365, dist_matrix=None, - logged_q=False, + log_vars=None, + num_tasks=1, ): """ make predictions given raw data that is potentially independent from the @@ -290,19 +314,24 @@ def predict_from_arbitrary_data( :param pred_start_date: [str] start date of predictions (fmt: YYYY-MM-DD) :param pred_end_date: [str] end date of predictions (fmt: YYYY-MM-DD) :param train_io_data: [str or np NpzFile] the path to or the loaded data - that was used to train the model. This file must contain the variables - names, the standard deviations, and the means of the X and Y variables. Only + that was used to train the model. This file must contain the variables_to_log + names, the standard deviations, and the means of the X and Y variables_to_log. Only in with this information can the model be used properly :param model_weights_dir: [str] path to the directory where the TF model weights are stored :param model_type: [str] model to use either 'rgcn', 'lstm', or 'gru' :param hidden_size: [int] the number of hidden units in model - :param num_tasks: [int] number of tasks (variables to be predicted) + :param spatial_idx_name: [str] name of column that is used for spatial + index (e.g., 'seg_id_nat') + :param time_idx_name: [str] name of column that is used for temporal index + (usually 'time') :param seq_len: [int] length of input sequences given to model :param dist_matrix: [np array] the distance matrix if using 'rgcn'. if not provided, will look for it in the "train_io_data" file. - :param logged_q: [bool] whether the model predicted log of discharge. 
if - true, the exponent of the discharge will be executed + :param flow_in_temp: [bool] whether the flow should be an input into temp + for the rgcn model + :param log_vars: [list-like] which variables_to_log (if any) were logged in data + prep :return: [pd dataframe] the predictions """ train_io_data = get_data_if_file(train_io_data) @@ -345,7 +374,9 @@ def predict_from_arbitrary_data( seq_len, inputs_start_date, pred_end_date, - logged_q, + log_vars=log_vars, + spatial_idx_name=spatial_idx_name, + time_idx_name=time_idx_name, keep_last_frac=0.5, offset=0.5, ) @@ -360,7 +391,9 @@ def predict_from_arbitrary_data( seq_len, pred_start_date, start_dates_end, - logged_q, + log_vars=log_vars, + spatial_idx_name=spatial_idx_name, + time_idx_name=time_idx_name, keep_last_frac=1, offset=0.5, swap_halves_of_first_seq=True, @@ -379,7 +412,9 @@ def predict_from_arbitrary_data( seq_len, end_dates_start, end_date_end, - logged_q, + log_vars=log_vars, + spatial_idx_name=spatial_idx_name, + time_idx_name=time_idx_name, keep_last_frac=1, offset=1, ) diff --git a/river_dl/preproc_utils.py b/river_dl/preproc_utils.py index 3261642..4896dd9 100644 --- a/river_dl/preproc_utils.py +++ b/river_dl/preproc_utils.py @@ -23,10 +23,12 @@ def scale(dataset, std=None, mean=None): return scaled, std, mean -def sel_partition_data(dataset, start_dates, end_dates): +def sel_partition_data(dataset, time_idx_name, start_dates, end_dates): """ select the data from a date range or a set of date ranges :param dataset: [xr dataset] input or output data with date dimension + :param time_idx_name: [str] name of column that is used for temporal index + (usually 'time') :param start_dates: [str or list] fmt: "YYYY-MM-DD"; date(s) to start period (can have multiple discontinuos periods) :param end_dates: [str or list] fmt: "YYYY-MM-DD"; date(s) to end period @@ -36,7 +38,7 @@ def sel_partition_data(dataset, start_dates, end_dates): # if it just one date range if isinstance(start_dates, str): if isinstance(end_dates, str): - return dataset.sel(date=slice(start_dates, end_dates)) + return dataset.sel({time_idx_name: slice(start_dates, end_dates)}) else: raise ValueError("start_dates is str but not end_date") # if it's a list of date ranges @@ -45,8 +47,8 @@ def sel_partition_data(dataset, start_dates, end_dates): data_list = [] for i in range(len(start_dates)): date_slice = slice(start_dates[i], end_dates[i]) - data_list.append(dataset.sel(date=date_slice)) - return xr.concat(data_list, dim="date") + data_list.append(dataset.sel({time_idx_name: date_slice})) + return xr.concat(data_list, dim=time_idx_name) else: raise ValueError("start_dates and end_dates must have same length") else: @@ -55,6 +57,7 @@ def sel_partition_data(dataset, start_dates, end_dates): def separate_trn_tst( dataset, + time_idx_name, train_start_date, train_end_date, val_start_date, @@ -67,31 +70,39 @@ def separate_trn_tst( dates. This assumes your training data is in one continuous block and all the dates that are not in the training are in the testing. 
:param dataset: [xr dataset] input or output data with dims + :param time_idx_name: [str] name of column that is used for temporal index + (usually 'time') :param train_start_date: [str or list] fmt: "YYYY-MM-DD"; date(s) to start - train period (can have multiple discontinuos periods) + train period (can have multiple discontinuous periods) :param train_end_date: [str or list] fmt: "YYYY-MM-DD"; date(s) to end train - period (can have multiple discontinuos periods) + period (can have multiple discontinuous periods) :param val_start_date: [str or list] fmt: "YYYY-MM-DD"; date(s) to start - validation period (can have multiple discontinuos periods) + validation period (can have multiple discontinuous periods) :param val_end_date: [str or list] fmt: "YYYY-MM-DD"; date(s) to end - validation period (can have multiple discontinuos periods) + validation period (can have multiple discontinuous periods) :param test_start_date: [str or list] fmt: "YYYY-MM-DD"; date(s) to start - test period (can have multiple discontinuos periods) + test period (can have multiple discontinuous periods) :param test_end_date: [str or list] fmt: "YYYY-MM-DD"; date(s) to end test - period (can have multiple discontinuos periods) + period (can have multiple discontinuous periods) """ - train = sel_partition_data(dataset, train_start_date, train_end_date) - val = sel_partition_data(dataset, val_start_date, val_end_date) - test = sel_partition_data(dataset, test_start_date, test_end_date) + train = sel_partition_data( + dataset, time_idx_name, train_start_date, train_end_date + ) + val = sel_partition_data( + dataset, time_idx_name, val_start_date, val_end_date + ) + test = sel_partition_data( + dataset, time_idx_name, test_start_date, test_end_date + ) return train, val, test -def split_into_batches(data_array, seq_len=365, offset=1): +def split_into_batches(data_array, seq_len=365, offset=1.0): """ split training data into batches with size of batch_size :param data_array: [numpy array] array of training data with dims [nseg, ndates, nfeat] - :param seq_len: [int] length of sequences (i.e., 365) + :param seq_len: [int] length of sequences (e.g., 365) :param offset: [float] 0-1, how to offset the batches (e.g., 0.5 means that the first batch will be 0-365 and the second will be 182-547) :return: [numpy array] batched data with dims [nbatches, nseg, seq_len @@ -108,25 +119,19 @@ def split_into_batches(data_array, seq_len=365, offset=1): return combined -def read_multiple_obs(obs_files, x_data): +def read_obs(obs_file, y_vars, x_data): """ read and format multiple observation files. we read in the pretrain data to make sure we have the same indexing. 
- :param obs_files: [list] list of filenames of observation files - :param pre_train_file: [str] the file of pre_training data + :param x_data: [xr.Dataset] xarray dataset used to match spatial and + temporal domain + :param y_vars: [list of str] which variables_to_log to prepare data for + :param obs_file: [list] filenames of observation file :return: [xr dataset] the observations in the same time """ - obs = [x_data.sortby(["seg_id_nat", "date"])] - for filename in obs_files: - ds = xr.open_zarr(filename) - obs.append(ds) - if "site_id" in ds.variables: - del ds["site_id"] - obs = xr.merge(obs, join="left") - obs = obs[["temp_c", "discharge_cms"]] - obs = obs.rename( - {"temp_c": "seg_tave_water", "discharge_cms": "seg_outflow"} - ) + ds = xr.open_zarr(obs_file) + obs = xr.merge([x_data, ds], join="left") + obs = obs[y_vars] return obs @@ -148,6 +153,8 @@ def prep_catch_props(x_data_ts, catch_prop_file, replace_nan_with_mean=True): read catch property file and join with ts data :param x_data_ts: [xr dataset] timeseries x-data :param catch_prop_file: [str] the feather file of catchment attributes + :param replace_nan_with_mean: [bool] if true, any nan will be replaced with + the mean of that variable :return: [xr dataset] merged datasets """ df_catch_props = pd.read_feather(catch_prop_file) @@ -163,7 +170,7 @@ def prep_catch_props(x_data_ts, catch_prop_file, replace_nan_with_mean=True): def reshape_for_training(data): """ reshape the data for training - :param data: training data (either x or y or mask) dims: [nbatch, nseg, + :param data: training data (either x or y_dataset or mask) dims: [nbatch, nseg, len_seq, nfeat/nout] :return: reshaped data [nbatch * nseg, len_seq, nfeat/nout] """ @@ -190,10 +197,10 @@ def get_exclude_start_end(exclude_grp): def get_exclude_vars(exclude_grp): """ - get the variables to exclude for the exclude group + get the variables_to_log to exclude for the exclude group :param exclude_grp: [dict] dictionary representing the exclude group from the exclude yml file - :return: [list] variables to exclude + :return: [list] variables_to_log to exclude """ variable = exclude_grp.get("variable") if not variable or variable == "both": @@ -230,7 +237,7 @@ def get_exclude_seg_ids(exclude_grp, all_segs): def exclude_segments(y_data, exclude_segs): """ exclude segments from being trained on by setting their weights as zero - :param y_data:[xr dataset] y data. this is used to get the dimensions + :param y_data:[xr dataset] y_dataset data. this is used to get the dimensions :param exclude_segs: [list] list of segments to exclude in the loss calculation :return: @@ -255,7 +262,7 @@ def exclude_segments(y_data, exclude_segs): def initialize_weights(y_data, initial_val=1): """ initialize all weights with a value. - :param y_data:[xr dataset] y data. this is used to get the dimensions + :param y_data:[xr dataset] y_dataset data. this is used to get the dimensions :param initial_val: [num] a number to initialize the weights with. should be between 0 and 1 (inclusive) :return: [xr dataset] dataset weights initialized with a uniform value @@ -285,6 +292,7 @@ def reduce_training_data_random( :param reduce_amount: [float] fraction to reduce the training data by. 
For example, if 0.9, a random 90% of the training data will be set to nan :param out_file: [str] file to which the reduced dataset will be written + :param segs: [array-like] segments to reduce data of :return: [xarray dataset] updated weights (nan where reduced) """ # read in an convert to dataframe @@ -370,23 +378,33 @@ def reduce_training_data_continuous( return reduced_ds -def convert_batch_reshape(dataset, seq_len=365, offset=1): +def convert_batch_reshape( + dataset, + spatial_idx_name="seg_id_nat", + time_idx_name="date", + seq_len=365, + offset=1.0 +): """ convert xarray dataset into numpy array, swap the axes, batch the array and reshape for training :param dataset: [xr dataset] data to be batched - :param seq_len: [int] length of sequences (i.e., 365) + :param spatial_idx_name: [str] name of column that is used for spatial + index (e.g., 'seg_id_nat') + :param time_idx_name: [str] name of column that is used for temporal index + (usually 'time') + :param seq_len: [int] length of sequences (e.g., 365) :param offset: [float] 0-1, how to offset the batches (e.g., 0.5 means that the first batch will be 0-365 and the second will be 182-547) :return: [numpy array] batched and reshaped dataset """ # convert xr.dataset to numpy array - dataset = dataset.transpose("seg_id_nat", "date") + dataset = dataset.transpose(spatial_idx_name, time_idx_name) arr = dataset.to_array().values # if the dataset is empty, just return it as is - if dataset.date.size == 0: + if dataset[time_idx_name].size == 0: return arr # before [nfeat, nseg, ndates]; after [nseg, ndates, nfeat] @@ -403,7 +421,27 @@ def convert_batch_reshape(dataset, seq_len=365, offset=1): return reshaped -def coord_as_reshaped_array(dataset, coord_name, seq_len=365, offset=1): +def coord_as_reshaped_array( + dataset, + coord_name, + spatial_idx_name="seg_id_nat", + time_idx_name="date", + seq_len=365, + offset=1.0, +): + """ + convert an xarray coordinate to an xarray data array and reshape that array + :param dataset: + :param coord_name: [str] the name of the coordinate to convert/reshape + :param spatial_idx_name: [str] name of column that is used for spatial + index (e.g., 'seg_id_nat') + :param time_idx_name: [str] name of column that is used for temporal index + (usually 'time') + :param seq_len: [int] length of sequences (e.g., 365) + :param offset: [float] 0-1, how to offset the batches (e.g., 0.5 means that + the first batch will be 0-365 and the second will be 182-547) + :return: + """ # I need one variable name. 
It can be any in the dataset, but I'll use the # first first_var = next(iter(dataset.data_vars.keys())) @@ -411,7 +449,11 @@ def coord_as_reshaped_array(dataset, coord_name, seq_len=365, offset=1): new_var_name = coord_name + "1" dataset[new_var_name] = coord_array reshaped_np_arr = convert_batch_reshape( - dataset[[new_var_name]], seq_len=seq_len, offset=offset + dataset[[new_var_name]], + spatial_idx_name, + time_idx_name, + seq_len=seq_len, + offset=offset, ) return reshaped_np_arr @@ -420,35 +462,152 @@ def check_if_finite(xarr): assert np.isfinite(xarr.to_array().values).all() -def log_discharge(y): +def log_variables(y_dataset, variables_to_log): """ - take the log of discharge - :param y: [xr dataset] the y data + take the log of given variables + :param variables_to_log: [list of str] variables to take the log of + :param y_dataset: [xr dataset] the y data :return: [xr dataset] the data logged """ - y["seg_outflow"].load() - y["seg_outflow"].loc[:, :] = y["seg_outflow"] + 1e-6 - y["seg_outflow"].loc[:, :] = xr.ufuncs.log(y["seg_outflow"]) - return y + for v in variables_to_log: + y_dataset[v].load() + y_dataset[v].loc[:, :] = y_dataset[v] + 1e-6 + y_dataset[v].loc[:, :] = xr.ufuncs.log(y_dataset[v]) + return y_dataset -def prep_data( - obs_temper_file, - obs_flow_file, - pretrain_file, - distfile, +def prep_y_data( + y_data_file, + y_vars, + x_data, train_start_date, train_end_date, val_start_date, val_end_date, test_start_date, test_end_date, - x_vars=None, - y_vars=None, - primary_variable="flow", + spatial_idx_name="seg_id_nat", + time_idx_name="date", + seq_len=365, + log_vars=None, + exclude_file=None, + normalize_y=True, + y_type="obs", + y_std=None, + y_mean=None, +): + """ + prepare y_dataset data + + :param y_data_file: [str] temperature observations file + :param y_vars: [str or list of str] target variable(s) + :param x_data: [xr.Dataset] xarray dataset used to match spatial and + temporal domain + :param spatial_idx_name: [str] name of column that is used for spatial + index (e.g., 'seg_id_nat') + :param time_idx_name: [str] name of column that is used for temporal index + (usually 'time') + :param train_start_date: [str or list] fmt: "YYYY-MM-DD"; date(s) to start + train period (can have multiple discontinuous periods) + :param train_end_date: [str or list] fmt: "YYYY-MM-DD"; date(s) to end train + period (can have multiple discontinuous periods) + :param val_start_date: [str or list] fmt: "YYYY-MM-DD"; date(s) to start + validation period (can have multiple discontinuous periods) + :param val_end_date: [str or list] fmt: "YYYY-MM-DD"; date(s) to end + validation period (can have multiple discontinuous periods) + :param test_start_date: [str or list] fmt: "YYYY-MM-DD"; date(s) to start + test period (can have multiple discontinuous periods) + :param test_end_date: [str or list] fmt: "YYYY-MM-DD"; date(s) to end test + period (can have multiple discontinuous periods) + :param seq_len: [int] length of sequences (e.g., 365) + :param log_vars: [list-like] which variables_to_log (if any) to take log of + :param exclude_file: [str] path to exclude file + :param normalize_y: [bool] whether or not to normalize the y_dataset values + :param y_type: [str] "obs" if observations or "pre" if pretraining + :param y_std: [array-like] standard deviations of y_dataset variables_to_log + :param y_mean: [array-like] means of y_dataset variables_to_log + :returns: training and testing data along with the means and standard + deviations of the training input and output data + """ + # I 
assume that if `y_vars` is a string only one variable has been passed + # so I put that in a list which is what the rest of the functions expect + if isinstance(y_vars, str): + y_vars = [y_vars] + + y_data = read_obs(y_data_file, y_vars, x_data) + + y_trn, y_val, y_tst = separate_trn_tst( + y_data, + time_idx_name, + train_start_date, + train_end_date, + val_start_date, + val_end_date, + test_start_date, + test_end_date, + ) + + if log_vars: + y_trn = log_variables(y_trn, log_vars) + + # filter pretrain/finetune y_dataset + if exclude_file: + exclude_segs = read_exclude_segs_file(exclude_file) + y_wgts = exclude_segments(y_trn, exclude_segs=exclude_segs) + else: + y_wgts = initialize_weights(y_trn) + + if normalize_y: + # scale y_dataset training data and get the mean and std + if not isinstance(y_std, xr.Dataset) or not isinstance( + y_mean, xr.Dataset + ): + y_trn, y_std, y_mean = scale(y_trn) + else: + y_trn, _, _ = scale(y_trn) + + data = { + f"y_{y_type}_trn": convert_batch_reshape( + y_trn, spatial_idx_name, time_idx_name, seq_len=seq_len + ), + f"y_{y_type}_wgts": convert_batch_reshape( + y_wgts, spatial_idx_name, time_idx_name, seq_len=seq_len + ), + f"y_{y_type}_val": convert_batch_reshape( + y_val, spatial_idx_name, time_idx_name, offset=0.5, seq_len=seq_len + ), + f"y_{y_type}_tst": convert_batch_reshape( + y_tst, spatial_idx_name, time_idx_name, offset=0.5, seq_len=seq_len + ), + "y_std": y_std.to_array().values, + "y_mean": y_mean.to_array().values, + f"y_{y_type}_vars": y_vars, + } + return data + + +def prep_all_data( + x_data_file, + y_data_file, + train_start_date, + train_end_date, + val_start_date, + val_end_date, + test_start_date, + test_end_date, + x_vars, + y_vars_finetune=None, + y_vars_pretrain=None, + spatial_idx_name="seg_id_nat", + time_idx_name="date", + seq_len=365, + pretrain_file=None, + distfile=None, + dist_idx_name="rowcolnames", + dist_type="updown", catch_prop_file=None, exclude_file=None, - log_q=False, + log_y_vars=False, out_file=None, segs=None, normalize_y=True, @@ -457,60 +616,92 @@ def prep_data( prepare input and output data for DL model training read in and process data into training and testing datasets. the training and testing data are scaled to have a std of 1 and a mean of zero - :param obs_temper_file: [str] temperature observations file (csv) - :param obs_flow_file:[str] discharge observations file (csv) - :param pretrain_file: [str] the file with the pretraining data (SNTemp data) - :param distfile: [str] path to the distance matrix .npz file + :param x_data_file: [str] path to Zarr file with x data. Data should have + a spatial coordinate and a time coordinate that are specified in the + `spatial_idx_name` and `time_idx_name` arguments + :param y_data_file: [str] observations Zarr file. 
Data should have a spatial + coordinate and a time coordinate that are specified in the + spatial_idx_name` and `time_idx_name` arguments :param train_start_date: [str or list] fmt: "YYYY-MM-DD"; date(s) to start - train period (can have multiple discontinuos periods) + train period (can have multiple discontinuous periods) :param train_end_date: [str or list] fmt: "YYYY-MM-DD"; date(s) to end train - period (can have multiple discontinuos periods) + period (can have multiple discontinuous periods) :param val_start_date: [str or list] fmt: "YYYY-MM-DD"; date(s) to start - validation period (can have multiple discontinuos periods) + validation period (can have multiple discontinuous periods) :param val_end_date: [str or list] fmt: "YYYY-MM-DD"; date(s) to end - validation period (can have multiple discontinuos periods) + validation period (can have multiple discontinuous periods) :param test_start_date: [str or list] fmt: "YYYY-MM-DD"; date(s) to start - test period (can have multiple discontinuos periods) + test period (can have multiple discontinuous periods) :param test_end_date: [str or list] fmt: "YYYY-MM-DD"; date(s) to end test - period (can have multiple discontinuos periods) - :param x_vars: [list] variables that should be used as input. If None, all - of the variables will be used - :param primary_variable: [str] which variable the model should focus on - 'temp' or 'flow'. This determines the order of the variables. + period (can have multiple discontinuous periods) + :param spatial_idx_name: [str] name of column that is used for spatial + index (e.g., 'seg_id_nat') + :param time_idx_name: [str] name of column that is used for temporal index + (usually 'time') + :param x_vars: [list] variables_to_log that should be used as input. If None, all + of the variables_to_log will be used + :param y_vars_finetune: [str or list of str] finetune target variable(s) + :param y_vars_pretrain: [str or list of str] pretrain target variable(s) + :param seq_len: [int] length of sequences (e.g., 365) + :param pretrain_file: [str] Zarr file with the pretraining data. Should have + a spatial coordinate and a time coordinate that are specified in the + `spatial_idx_name` and `time_idx_name` arguments + :param distfile: [str] path to the distance matrix .npz file + :param dist_idx_name: [str] name of index to sort dist_matrix by. This is + the name of an array in the distance matrix .npz file + :param dist_type: [str] type of distance matrix ("upstream", "downstream" or + "updown") :param catch_prop_file: [str] the path to the catchment properties file. 
If left unfilled, the catchment properties will not be included as predictors :param exclude_file: [str] path to exclude file - :param log_q: [bool] whether or not to take the log of discharge in training + :param log_y_vars: [bool] whether or not to take the log of discharge in + training + :param segs: [list-like] which segments to prepare the data for + :param normalize_y: [bool] whether or not to normalize the y_dataset values :param out_file: [str] file to where the values will be written :returns: training and testing data along with the means and standard deviations of the training input and output data - 'y_trn_pre': batched, scaled, and centered output data for entire - period of record of SNTemp [n_samples, seq_len, n_out] - 'y_obs_trn': batched, scaled, and centered output observation data - for the training period - 'y_trn_obs_std': standard deviation of the y observations training - data [n_out] - 'y_trn_obs_mean': mean of the observation training data [n_out] - 'y_obs_tst': un-batched, unscaled, uncentered observation data for - the test period [n_yrs, n_seg, len_seq, n_out] - 'dates_ids_trn: batched dates and national seg ids for training data - [n_samples, seq_len, 2] - 'dates_ids_tst: un-batched dates and national seg ids for testing - data [n_yrs, n_seg, len_seq, 2] - """ - ds_pre = xr.open_zarr(pretrain_file) + "x_trn": x training data + "x_val": x validation data + "x_tst": x test data + "x_std": x standard deviations + "x_mean": x means + "x_cols": x column names + "ids_trn": segment ids of the training data + "times_trn": dates of the training data + "ids_val": segment ids of the validation data + "times_val": dates of the validation data + "ids_tst": segment ids of the test data + "times_tst": dates of the test data + 'y_pre_trn': y_dataset pretrain data for train set + 'y_obs_trn': y_dataset observations for train set + "y_pre_wgts": y_dataset weights for pretrain data + "y_obs_wgts": weights for y_dataset observations + "y_obs_val": y_dataset observations for train set + "y_obs_tst": y_dataset observations for train set + "y_std": standard deviations of y_dataset data + "y_mean": means of y_dataset data + "y_vars": y_dataset variable names + "dist_matrix": prepared adjacency matrix + """ + if pretrain_file and not y_vars_pretrain: + raise ValueError("included pretrain file but no pretrain vars") + + x_data = xr.open_zarr(x_data_file) + x_data = x_data.sortby([spatial_idx_name, time_idx_name]) if segs: - ds_pre = ds_pre.loc[dict(seg_id_nat=segs)] + x_data = x_data.sel({spatial_idx_name: segs}) - x_data = ds_pre[x_vars] + x_data = x_data[x_vars] if catch_prop_file: x_data = prep_catch_props(x_data, catch_prop_file) - # make sure we don't have any weird input values + # make sure we don't have any weird or missing input values check_if_finite(x_data) x_trn, x_val, x_tst = separate_trn_tst( x_data, + time_idx_name, train_start_date, train_end_date, val_start_date, @@ -526,88 +717,160 @@ def prep_data( x_tst_scl, _, _ = scale(x_tst, std=x_std, mean=x_mean) # read, filter observations for finetuning - if not y_vars: - if primary_variable == "temp": - y_vars = ["seg_tave_water", "seg_outflow"] - else: - y_vars = ["seg_outflow", "seg_tave_water"] - - y_obs = read_multiple_obs([obs_temper_file, obs_flow_file], x_data) - y_obs = y_obs[y_vars] - y_pre = ds_pre[y_vars] - if segs: - y_obs = y_obs.loc[dict(seg_id_nat=segs)] - y_obs_trn, y_obs_val, y_obs_tst = separate_trn_tst( - y_obs, - train_start_date, - train_end_date, - val_start_date, - val_end_date, - test_start_date, - 
test_end_date, - ) - y_pre_trn, _, _ = separate_trn_tst( - y_pre, - train_start_date, - train_end_date, - val_start_date, - val_end_date, - test_start_date, - test_end_date, - ) - - if log_q: - y_obs_trn = log_discharge(y_obs_trn) - y_pre_trn = log_discharge(y_pre_trn) - - # filter pretrain/finetune y - y_pre_wgts = initialize_weights(y_pre_trn) - if exclude_file: - exclude_segs = read_exclude_segs_file(exclude_file) - y_obs_wgts = exclude_segments(y_obs_trn, exclude_segs=exclude_segs) - else: - y_obs_wgts = initialize_weights(y_obs_trn) - - if normalize_y: - # scale y training data and get the mean and std - y_obs_trn, y_std, y_mean = scale(y_obs_trn) - y_pre_trn, _, _ = scale(y_pre_trn, y_std, y_mean) - else: - _, y_std, y_mean = scale(y_obs_trn) - - data = { - "x_trn": convert_batch_reshape(x_trn_scl), - "x_val": convert_batch_reshape(x_val_scl, offset=0.5), - "x_tst": convert_batch_reshape(x_tst_scl, offset=0.5), + x_data_dict = { + "x_trn": convert_batch_reshape( + x_trn_scl, spatial_idx_name, time_idx_name, seq_len=seq_len + ), + "x_val": convert_batch_reshape( + x_val_scl, + spatial_idx_name, + time_idx_name, + offset=0.5, + seq_len=seq_len, + ), + "x_tst": convert_batch_reshape( + x_tst_scl, + spatial_idx_name, + time_idx_name, + offset=0.5, + seq_len=seq_len, + ), "x_std": x_std.to_array().values, "x_mean": x_mean.to_array().values, - "x_cols": np.array(x_vars), - "ids_trn": coord_as_reshaped_array(x_trn, "seg_id_nat"), - "dates_trn": coord_as_reshaped_array(x_trn, "date"), - "ids_val": coord_as_reshaped_array(x_val, "seg_id_nat", offset=0.5), - "dates_val": coord_as_reshaped_array(x_val, "date", offset=0.5), - "ids_tst": coord_as_reshaped_array(x_tst, "seg_id_nat", offset=0.5), - "dates_tst": coord_as_reshaped_array(x_tst, "date", offset=0.5), - "y_pre_trn": convert_batch_reshape(y_pre_trn), - "y_pre_wgts": convert_batch_reshape(y_pre_wgts), - "y_obs_trn": convert_batch_reshape(y_obs_trn), - "y_obs_wgts": convert_batch_reshape(y_obs_wgts), - "y_obs_val": convert_batch_reshape(y_obs_val, offset=0.5), - "y_obs_tst": convert_batch_reshape(y_obs_tst, offset=0.5), - "y_std": y_std.to_array().values, - "y_mean": y_mean.to_array().values, - "y_vars": np.array(y_vars), - "dist_matrix": prep_adj_matrix(distfile, "upstream", segs=segs), + "x_vars": np.array(x_vars), + "ids_trn": coord_as_reshaped_array( + x_trn, + spatial_idx_name, + spatial_idx_name, + time_idx_name, + seq_len=seq_len, + ), + "times_trn": coord_as_reshaped_array( + x_trn, + time_idx_name, + spatial_idx_name, + time_idx_name, + seq_len=seq_len, + ), + "ids_val": coord_as_reshaped_array( + x_val, + spatial_idx_name, + spatial_idx_name, + time_idx_name, + offset=0.5, + seq_len=seq_len, + ), + "times_val": coord_as_reshaped_array( + x_val, + time_idx_name, + spatial_idx_name, + time_idx_name, + offset=0.5, + seq_len=seq_len, + ), + "ids_tst": coord_as_reshaped_array( + x_tst, + spatial_idx_name, + spatial_idx_name, + time_idx_name, + offset=0.5, + seq_len=seq_len, + ), + "times_tst": coord_as_reshaped_array( + x_tst, + time_idx_name, + spatial_idx_name, + time_idx_name, + offset=0.5, + seq_len=seq_len, + ), } + if distfile: + x_data_dict["dist_matrix"] = prep_adj_matrix( + infile=distfile, + dist_type=dist_type, + dist_idx_name=dist_idx_name, + segs=segs, + ) + + y_obs_data = {} + y_pre_data = {} + if y_data_file: + y_obs_data = prep_y_data( + y_data_file=y_data_file, + y_vars=y_vars_finetune, + x_data=x_data, + train_start_date=train_start_date, + train_end_date=train_end_date, + val_start_date=val_start_date, + 
val_end_date=val_end_date, + test_start_date=test_start_date, + test_end_date=test_end_date, + spatial_idx_name=spatial_idx_name, + time_idx_name=time_idx_name, + seq_len=seq_len, + log_vars=log_y_vars, + exclude_file=exclude_file, + normalize_y=normalize_y, + y_type="obs", + ) + # if there is a y_data_file and a pretrain file, use the observation + # mean and standard deviation to do the scaling/centering + if pretrain_file: + y_pre_data = prep_y_data( + y_data_file=pretrain_file, + y_vars=y_vars_pretrain, + x_data=x_data, + train_start_date=train_start_date, + train_end_date=train_end_date, + val_start_date=val_start_date, + val_end_date=val_end_date, + test_start_date=test_start_date, + test_end_date=test_end_date, + spatial_idx_name=spatial_idx_name, + time_idx_name=time_idx_name, + seq_len=seq_len, + log_vars=log_y_vars, + exclude_file=exclude_file, + normalize_y=normalize_y, + y_type="pre", + y_std=y_obs_data["y_std"], + y_mean=y_obs_data["y_mean"], + ) + # if there is no observation file, use the pretrain mean and standard dev + # to do the scaling/centering + elif pretrain_file and not y_obs_data: + y_pre_data = prep_y_data( + y_data_file=pretrain_file, + y_vars=y_vars_pretrain, + x_data=x_data, + train_start_date=train_start_date, + train_end_date=train_end_date, + val_start_date=val_start_date, + val_end_date=val_end_date, + test_start_date=test_start_date, + test_end_date=test_end_date, + spatial_idx_name=spatial_idx_name, + time_idx_name=time_idx_name, + seq_len=seq_len, + log_vars=log_y_vars, + exclude_file=exclude_file, + normalize_y=normalize_y, + y_type="pre", + ) + else: + raise Warning("No y_dataset data was provided") + + all_data = {**x_data_dict, **y_obs_data, **y_pre_data} if out_file: - np.savez_compressed(out_file, **data) - return data + np.savez_compressed(out_file, **all_data) + return all_data def sort_dist_matrix(mat, row_col_names, segs=None): """ - sort the distance matrix by seg_id_nat + sort the distance matrix by id :return: """ if segs is not None: @@ -621,19 +884,22 @@ def sort_dist_matrix(mat, row_col_names, segs=None): return df -def prep_adj_matrix(infile, dist_type, out_file=None, segs=None): +def prep_adj_matrix(infile, dist_type, dist_idx_name, segs=None, out_file=None): """ process adj matrix. - **The resulting matrix is sorted by seg_id_nat ** - :param infile: + **The resulting matrix is sorted by id ** + :param infile: [str] path to the distance matrix .npz file :param dist_type: [str] type of distance matrix ("upstream", "downstream" or "updown") - :param out_file: + :param dist_idx_name: [str] name of index to sort dist_matrix by. 
This is + the name of an array in the distance matrix .npz file + :param segs: [list-like] which segments to prepare the data for + :param out_file: [str] path to save the .npz file to :return: [numpy array] processed adjacency matrix """ adj_matrices = np.load(infile) adj = adj_matrices[dist_type] - adj = sort_dist_matrix(adj, adj_matrices["rowcolnames"], segs=segs) + adj = sort_dist_matrix(adj, adj_matrices[dist_idx_name], segs=segs) adj = np.where(np.isinf(adj), 0, adj) adj = -adj mean_adj = np.mean(adj[adj != 0]) diff --git a/river_dl/rnns.py b/river_dl/rnns.py index 7652364..89062dd 100644 --- a/river_dl/rnns.py +++ b/river_dl/rnns.py @@ -10,7 +10,7 @@ def __init__( ): """ :param hidden_size: [int] the number of hidden units - :param num_tasks: [int] number of tasks (variables to be predicted) + :param num_tasks: [int] number of tasks (variables_to_log to be predicted) :param recurrent_dropout: [float] value between 0 and 1 for the probability of a recurrent element to be zero :param dropout: [float] value between 0 and 1 for the probability of an @@ -59,7 +59,7 @@ def __init__( ): """ :param hidden_size: [int] the number of hidden units - :param num_tasks: [int] number of tasks (variables to be predicted) + :param num_tasks: [int] number of tasks (variables_to_log to be predicted) :param recurrent_dropout: [float] value between 0 and 1 for the probability of a recurrent element to be zero :param dropout: [float] value between 0 and 1 for the probability of an diff --git a/river_dl/tests/generate_test_data.py b/river_dl/tests/generate_test_data.py index 415105f..07e8f5d 100644 --- a/river_dl/tests/generate_test_data.py +++ b/river_dl/tests/generate_test_data.py @@ -1,4 +1,5 @@ import pandas as pd +import os def select_data(df, col, selection): @@ -9,34 +10,45 @@ def select_data(df, col, selection): def sel_date_segs(df, segs, start_date, end_date): + df = df[df["seg_id_nat"].notna()] + df["seg_id_nat"] = df.seg_id_nat.astype(int) df = select_data(df, "date", slice(start_date, end_date)) df = select_data(df, "seg_id_nat", segs) + df = df.rename(columns={"seg_id_nat": "segs_test", "date": "times_test"}) + df.set_index(["segs_test", "times_test"], inplace=True) return df # need to subset this data so it's just two years and two sites. 
I think # such a dataset should be representative enough to run tests against -dfs = pd.read_feather("../../data/in/uncal_sntemp_input_output_subset.feather") +data_dir = "../../../drb-dl-model/data/in/" +dfs = pd.read_feather( + os.path.join(data_dir, "uncal_sntemp_input_output_subset.feather") +) dfs["date"] = pd.to_datetime(dfs["date"]) dft = pd.read_csv( - "../../data/in/obs_flow_subset.csv", + os.path.join(data_dir, "obs_temp_full.csv"), parse_dates=["date"], infer_datetime_format=True, ) dfq = pd.read_csv( - "../../data/in/obs_temp_subset.csv", + os.path.join(data_dir, "obs_flow_full.csv"), parse_dates=["date"], infer_datetime_format=True, ) -start_date = "2004-09-15" +start_date = "2003-09-15" end_date = "2006-10-15" -segs = ["2012", "2007"] +segs = [2012, 2007] dft = sel_date_segs(dft, segs, start_date, end_date) dfq = sel_date_segs(dfq, segs, start_date, end_date) dfs = sel_date_segs(dfs, segs, start_date, end_date) -dft.to_csv("test_data/obs_temp_full.csv", index=False) -dfq.to_csv("test_data/obs_flow_full.csv", index=False) -dfs.to_feather("test_data/uncal_sntemp_input_output.feather") +dft = dft[["temp_c"]] +dfq = dfq[["discharge_cms"]] + +df_combined = dft.join(dfq) + +df_combined.to_xarray().to_zarr("test_data/obs_temp_flow", mode="w") +dfs.to_xarray().to_zarr("test_data/test_data", mode="w") diff --git a/river_dl/tests/test_data/obs_temp_flow/.zattrs b/river_dl/tests/test_data/obs_temp_flow/.zattrs new file mode 100644 index 0000000..9e26dfe --- /dev/null +++ b/river_dl/tests/test_data/obs_temp_flow/.zattrs @@ -0,0 +1 @@ +{} \ No newline at end of file diff --git a/river_dl/tests/test_data/obs_temp_flow/.zgroup b/river_dl/tests/test_data/obs_temp_flow/.zgroup new file mode 100644 index 0000000..3b7daf2 --- /dev/null +++ b/river_dl/tests/test_data/obs_temp_flow/.zgroup @@ -0,0 +1,3 @@ +{ + "zarr_format": 2 +} \ No newline at end of file diff --git a/river_dl/tests/test_data/obs_temp_flow/discharge_cms/.zarray b/river_dl/tests/test_data/obs_temp_flow/discharge_cms/.zarray new file mode 100644 index 0000000..46295a0 --- /dev/null +++ b/river_dl/tests/test_data/obs_temp_flow/discharge_cms/.zarray @@ -0,0 +1,22 @@ +{ + "chunks": [ + 2, + 1127 + ], + "compressor": { + "blocksize": 0, + "clevel": 5, + "cname": "lz4", + "id": "blosc", + "shuffle": 1 + }, + "dtype": " 0 - assert seg_2012_sum > 0 - - -def test_exclude_in(): - # try just *including* 2007 (exclude_in.yml) - prepped = PreppedData(exclude_file="test_data/exclude_in.yml") - wgts = prepped.sample_y_obs_wgts.to_array() - seg_2007_sum = wgts.sel(seg_id_nat=2007).sum().values - seg_2012_sum = wgts.sel(seg_id_nat=2012).sum().values - assert seg_2007_sum > 0 - assert seg_2012_sum == 0 - - -def test_exclude_2007(): - # try just excluding 2007 (exclude_2007.yml) - prepped = PreppedData(exclude_file="test_data/exclude_2007.yml") - wgts = prepped.sample_y_obs_wgts.to_array() - seg_2007_sum = wgts.sel(seg_id_nat=2007).sum().values - seg_2012_sum = wgts.sel(seg_id_nat=2012).sum().values - assert seg_2007_sum == 0 - assert seg_2012_sum > 0 - - -def test_exclude_periods(): - # excluding 2007 for second half, 2012 first half(exclude1.yml) - prepped = PreppedData(exclude_file="test_data/exclude1.yml") - wgts = prepped.sample_y_obs_wgts.to_array() - start_date = wgts.date.min() - mid_date = "2005-03-15" - end_date = wgts.date.max() - seg_2012_sum_1st = ( - wgts.sel(seg_id_nat=2012, date=slice(start_date, mid_date)).sum().values - ) - seg_2012_sum_2nd = ( - wgts.sel(seg_id_nat=2012, date=slice(mid_date, end_date)).sum().values - ) - 
seg_2007_sum_1st = ( - wgts.sel(seg_id_nat=2007, date=slice(start_date, mid_date)).sum().values - ) - seg_2007_sum_2nd = ( - wgts.sel(seg_id_nat=2007, date=slice(mid_date, end_date)).sum().values - ) - assert seg_2007_sum_1st > 0 - assert seg_2007_sum_2nd == 0 - assert seg_2012_sum_1st == 0 - assert seg_2012_sum_2nd > 0 - - -def test_var_weights(): - # try with both variables - prepped = PreppedData() - wgts = prepped.sample_y_obs_wgts.to_array() - flow_sum = wgts.sel(variable="seg_outflow").sum().values - temp_sum = wgts.sel(variable="seg_tave_water").sum().values - assert flow_sum > 0 - assert temp_sum > 0 - - -def test_var_weights_temp_ft(): - # try with just temp in finetuning - prepped = PreppedData(ft_vars=["seg_tave_water"]) - wgts_obs = prepped.sample_y_obs_wgts.to_array() - wgts_pre = prepped.sample_y_pre_wgts.to_array() - flow_sum_pre = wgts_pre.sel(variable="seg_outflow").sum().values - temp_sum_pre = wgts_pre.sel(variable="seg_tave_water").sum().values - flow_sum = wgts_obs.sel(variable="seg_outflow").sum().values - temp_sum = wgts_obs.sel(variable="seg_tave_water").sum().values - assert flow_sum_pre > 0 - assert temp_sum_pre > 0 - assert flow_sum == 0 - assert temp_sum > 0 - - -def test_var_weights_flow_ft(): - # try with just flow in finetuning - prepped = PreppedData(ft_vars=["seg_outflow"]) - wgts_obs = prepped.sample_y_obs_wgts.to_array() - wgts_pre = prepped.sample_y_pre_wgts.to_array() - flow_sum_pre = wgts_pre.sel(variable="seg_outflow").sum().values - temp_sum_pre = wgts_pre.sel(variable="seg_tave_water").sum().values - flow_sum = wgts_obs.sel(variable="seg_outflow").sum().values - temp_sum = wgts_obs.sel(variable="seg_tave_water").sum().values - assert flow_sum_pre > 0 - assert temp_sum_pre > 0 - assert flow_sum > 0 - assert temp_sum == 0 - - -def test_var_weights_temp_ft_pt(): - # try with just temp in finetuning and pretraining - prepped = PreppedData( - ft_vars=["seg_tave_water"], pt_vars=["seg_tave_water"] - ) - wgts_obs = prepped.sample_y_obs_wgts.to_array() - wgts_pre = prepped.sample_y_pre_wgts.to_array() - temp_sum_pre = wgts_pre.sel(variable="seg_tave_water").sum().values - temp_sum = wgts_obs.sel(variable="seg_tave_water").sum().values - assert len(wgts_obs.variable) == 1 - assert len(wgts_pre.variable) == 1 - assert temp_sum_pre > 0 - assert temp_sum > 0 - -def test_var_weights_flow_ft_pt(): - # try with just flow in finetuning and pretraining - prepped = PreppedData(ft_vars=["seg_outflow"], pt_vars=["seg_outflow"]) - wgts_obs = prepped.sample_y_obs_wgts.to_array() - wgts_pre = prepped.sample_y_pre_wgts.to_array() - flow_sum_pre = wgts_pre.sel(variable="seg_outflow").sum().values - flow_sum = wgts_obs.sel(variable="seg_outflow").sum().values - assert len(wgts_obs.variable) == 1 - assert len(wgts_pre.variable) == 1 - assert flow_sum_pre > 0 - assert flow_sum > 0 diff --git a/river_dl/tests/test_train.py b/river_dl/tests/test_train.py new file mode 100644 index 0000000..e1e2839 --- /dev/null +++ b/river_dl/tests/test_train.py @@ -0,0 +1,76 @@ +import pytest +import os +import shutil + +from river_dl import preproc_utils +from river_dl import train +from river_dl import loss_functions + + +def test_finetune_rgcn(): + prepped_data = preproc_utils.prep_all_data( + x_data_file="test_data/test_data", + y_data_file="test_data/obs_temp_flow", + train_start_date="2003-09-15", + train_end_date="2004-09-16", + val_start_date="2004-09-17", + val_end_date="2005-09-18", + test_start_date="2005-09-19", + test_end_date="2006-09-20", + spatial_idx_name="segs_test", + 
time_idx_name="times_test", + segs=[2007, 2012], + distfile="../../../drb-dl-model/data/in/distance_matrix.npz", + x_vars=["seg_rain", "seg_tave_air"], + y_vars=["temp_c"], + ) + + test_out_dir = 'test_data/test_training_out' + if os.path.exists(test_out_dir): + shutil.rmtree(test_out_dir) + + os.mkdir(test_out_dir) + + model = train.train_model( + io_data=prepped_data, + finetune_epochs=2, + pretrain_epochs=0, + hidden_units=10, + out_dir='test_data/test_training_out', + model_type="rgcn", + seed=2, + dropout=0.12, + loss_func=loss_functions.rmse + ) + + +def test_pretrain_fail(): + prepped_data = preproc_utils.prep_all_data( + x_data_file="test_data/test_data", + y_data_file="test_data/obs_temp_flow", + # pretrain_file="test_data/obs_temp_flow", + train_start_date="2003-09-15", + train_end_date="2004-09-16", + val_start_date="2004-09-17", + val_end_date="2005-09-18", + test_start_date="2005-09-19", + test_end_date="2006-09-20", + spatial_idx_name="segs_test", + time_idx_name="times_test", + x_vars=["seg_rain", "seg_tave_air"], + y_vars=["temp_c"], + ) + + with pytest.raises(KeyError): + model = train.train_model( + io_data=prepped_data, + pretrain_epochs=2, + finetune_epochs=2, + hidden_units=10, + out_dir='test_data/test_training_out', + model_type="lstm", + seed=2, + dropout=0.12, + loss_func=loss_functions.rmse_masked_one_var + ) + diff --git a/river_dl/train.py b/river_dl/train.py index 44d4b84..42476a5 100644 --- a/river_dl/train.py +++ b/river_dl/train.py @@ -54,7 +54,7 @@ def train_model( of a reccurent element to be zero :param dropout: [float] value between 0 and 1 for the probability of an input element to be zero - :param num_tasks: [int] number of tasks (variables to be predicted) + :param num_tasks: [int] number of tasks (variables_to_log to be predicted) :param learning_rate_pre: [float] the pretrain learning rate :param learning_rate_ft: [float] the finetune learning rate :return: [tf model] finetuned model @@ -66,7 +66,6 @@ def train_model( start_time = datetime.datetime.now() io_data = get_data_if_file(io_data) - dist_matrix = io_data["dist_matrix"] n_seg = len(np.unique(io_data["ids_trn"])) if n_seg > 1: @@ -83,6 +82,7 @@ def train_model( dropout=dropout, ) elif model_type == "rgcn": + dist_matrix = io_data["dist_matrix"] model = RGCNModel( hidden_units, num_tasks=num_tasks, @@ -147,9 +147,7 @@ def train_model( # finetune if finetune_epochs > 0: optimizer_ft = tf.optimizers.Adam(learning_rate=learning_rate_ft) - - - + if model_type == "rgcn" and loss_type.lower()=="gw": #extract these for use in the GW loss function temp_index = np.where(io_data['y_vars']=="seg_tave_water")[0] diff --git a/river_dl/train_model_cli.py b/river_dl/train_model_cli.py index 08a1e1b..5e188d4 100644 --- a/river_dl/train_model_cli.py +++ b/river_dl/train_model_cli.py @@ -83,7 +83,7 @@ def get_loss_func_from_str(loss_func_str, lambdas=None): parser.add_argument( "--num_tasks", - help="number of tasks (variables to be predicted)", + help="number of tasks (variables_to_log to be predicted)", default=1, type=int, )