Skip to content
This repository has been archived by the owner on Jun 2, 2023. It is now read-only.

Commit

Permalink
Merge pull request #120 from jsadler2/103-flex-io-names
Browse files Browse the repository at this point in the history
103 flexible input and output names
  • Loading branch information
SimonTopp authored Aug 17, 2021
2 parents 715df1b + b27c9c2 commit a7629eb
Show file tree
Hide file tree
Showing 102 changed files with 1,064 additions and 746 deletions.
28 changes: 15 additions & 13 deletions Snakefile
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os

from river_dl.preproc_utils import prep_data
from river_dl.preproc_utils import prep_all_data
from river_dl.evaluate import combined_metrics
from river_dl.postproc_utils import plot_obs
from river_dl.predict import predict_from_io_data
Expand All @@ -20,15 +20,20 @@ rule all:

rule prep_io_data:
input:
config['obs_temp'],
config['obs_flow'],
config['sntemp_file'],
config['obs_file'],
config['dist_matrix'],
output:
"{outdir}/prepped.npz"
run:
prep_data(input[0], input[1], input[2], input[3],
prep_all_data(
x_data_file=input[0],
pretrain_file=input[0],
y_data_file=input[1],
distfile=input[2],
x_vars=config['x_vars'],
y_vars_pretrain=['seg_tave_water', 'seg_outflow'],
y_vars_finetune=['temp_c', 'discharge_cms'],
catch_prop_file=None,
exclude_file=None,
train_start_date=config['train_start_date'],
Expand All @@ -37,8 +42,7 @@ rule prep_io_data:
val_end_date=config['val_end_date'],
test_start_date=config['test_start_date'],
test_end_date=config['test_end_date'],
primary_variable=config['primary_variable'],
log_q=False, segs=None,
segs=None,
out_file=output[0])


Expand Down Expand Up @@ -90,7 +94,7 @@ rule make_predictions:
predict_from_io_data(model_type='rgcn', model_weights_dir=model_dir,
hidden_size=config['hidden_size'], io_data=input[1],
partition=wildcards.partition, outfile=output[0],
logged_q=False, num_tasks=2)
num_tasks=2)


def get_grp_arg(wildcards):
Expand All @@ -106,8 +110,7 @@ def get_grp_arg(wildcards):

rule combine_metrics:
input:
config['obs_temp'],
config['obs_flow'],
config['obs_file'],
"{outdir}/trn_preds.feather",
"{outdir}/val_preds.feather"
output:
Expand All @@ -116,10 +119,9 @@ rule combine_metrics:
params:
grp_arg = get_grp_arg
run:
combined_metrics(obs_temp=input[0],
obs_flow=input[1],
pred_trn=input[2],
pred_val=input[3],
combined_metrics(obs_file=input[0],
pred_trn=input[1],
pred_val=input[2],
group=params.grp_arg,
outfile=output[0])

Expand Down
14 changes: 6 additions & 8 deletions config.yml
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
# Input files
obs_flow: "data_DRB/obs_flow_subset"
obs_temp: "data_DRB/obs_temp_subset"
sntemp_file: "data_DRB/uncal_sntemp_input_output_subset"
dist_matrix: "data_DRB/distance_matrix_subset.npz"
obs_file: "../drb-dl-model/data/in/obs_flow_temp_subset"
sntemp_file: "../drb-dl-model/data/in/uncal_sntemp_input_output_subset"
dist_matrix: "../drb-dl-model/data/in/distance_matrix_subset.npz"

out_dir: "output_DRB_test"
out_dir: "out_103_test"

code_dir: "river_dl"

x_vars: ["seg_rain", "seg_tave_air", "seginc_swrad", "seg_length", "seginc_potet", "seg_slope", "seg_humid", "seg_elev"]
primary_variable: "temp"

lambdas: [100,100]

Expand All @@ -29,6 +27,6 @@ test_end_date:
- '1985-09-30'
- '2021-09-30'

pt_epochs: 20
ft_epochs: 10
pt_epochs: 2
ft_epochs: 2
hidden_size: 20
132 changes: 84 additions & 48 deletions river_dl/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
def filter_negative_preds(y_true, y_pred):
"""
filters out negative predictions and prints a warning if there are >5% of predictions as negative
:param y_true: [array-like] observed y values
:param y_pred: [array-like] predicted y values
:param y_true: [array-like] observed y_dataset values
:param y_pred: [array-like] predicted y_dataset values
:return: [array-like] filtered data
"""
# print a warning if there are a lot of negatives
Expand All @@ -29,8 +29,8 @@ def filter_negative_preds(y_true, y_pred):
def rmse_logged(y_true, y_pred):
"""
compute the rmse of the logged data
:param y_true: [array-like] observed y values
:param y_pred: [array-like] predicted y values
:param y_true: [array-like] observed y_dataset values
:param y_pred: [array-like] predicted y_dataset values
:return: [float] the rmse of the logged data
"""
y_true, y_pred = filter_negative_preds(y_true, y_pred)
Expand All @@ -40,8 +40,8 @@ def rmse_logged(y_true, y_pred):
def nse_logged(y_true, y_pred):
"""
compute the rmse of the logged data
:param y_true: [array-like] observed y values
:param y_pred: [array-like] predicted y values
:param y_true: [array-like] observed y_dataset values
:param y_pred: [array-like] predicted y_dataset values
:return: [float] the nse of the logged data
"""
y_true, y_pred = filter_negative_preds(y_true, y_pred)
Expand All @@ -52,8 +52,8 @@ def filter_by_percentile(y_true, y_pred, percentile, less_than=True):
"""
filter an array by a percentile of the observations. The data less than
or greater than if `less_than=False`) will be changed to NaN
:param y_true: [array-like] observed y values
:param y_pred: [array-like] predicted y values
:param y_true: [array-like] observed y_dataset values
:param y_pred: [array-like] predicted y_dataset values
:param percentile: [number] percentile number 0-100
:param less_than: [bool] whether you want the data *less than* the
percentile. If False, the data greater than the percentile will remain.
Expand All @@ -72,8 +72,8 @@ def filter_by_percentile(y_true, y_pred, percentile, less_than=True):
def percentile_metric(y_true, y_pred, metric, percentile, less_than=True):
"""
compute an evaluation metric for a specified percentile of the observations
:param y_true: [array-like] observed y values
:param y_pred: [array-like] predicted y values
:param y_true: [array-like] observed y_dataset values
:param y_pred: [array-like] predicted y_dataset values
:param metric: [function] metric function
:param percentile: [number] percentile number 0-100
:param less_than: [bool] whether you want the data *less than* the
Expand Down Expand Up @@ -134,69 +134,93 @@ def calc_metrics(df):
return pd.Series(metrics)


def overall_metrics(
pred_file, obs_file, variable, partition, group=None, outfile=None
def partition_metrics(
pred_file,
obs_file,
partition,
spatial_idx_name="seg_id_nat",
time_idx_name="date",
group=None,
outfile=None
):
"""
calculate metrics for a certain group (or no group at all) for a given
partition and variable
:param pred_file: [str] path to predictions feather file
:param obs_file: [str] path to observations zarr file
:param variable: [str] variable for which the metrics are being calculated
:param partition: [str] data partition for which metrics are calculated
:param spatial_idx_name: [str] name of column that is used for spatial
index (e.g., 'seg_id_nat')
:param time_idx_name: [str] name of column that is used for temporal index
(usually 'time')
:param group: [str or list] which group the metrics should be computed for.
Currently only supports 'seg_id_nat' (segment-wise metrics), 'month'
(month-wise metrics), ['seg_id_nat', 'month'] (metrics broken out by segment
and month), and None (everything is left together)
:param outfile: [str] file where the metrics should be written
:return: [pd dataframe] the condensed metrics
"""
data = fmt_preds_obs(pred_file, obs_file, variable)
data.reset_index(inplace=True)
if not group:
metrics = calc_metrics(data)
# need to convert to dataframe and transpose so it looks like the others
metrics = pd.DataFrame(metrics).T
elif group == "seg_id_nat":
metrics = data.groupby("seg_id_nat").apply(calc_metrics).reset_index()
elif group == "month":
metrics = (
data.groupby(data["date"].dt.month)
var_data = fmt_preds_obs(pred_file, obs_file, spatial_idx_name,
time_idx_name)
var_metrics_list = []

for data_var, data in var_data.items():
data.reset_index(inplace=True)
if not group:
metrics = calc_metrics(data)
# need to convert to dataframe and transpose so it looks like the
# others
metrics = pd.DataFrame(metrics).T
elif group == "seg_id_nat":
metrics = data.groupby(spatial_idx_name).apply(calc_metrics).reset_index()
elif group == "month":
metrics = (
data.groupby(
data[time_idx_name].dt.month)
.apply(calc_metrics)
.reset_index()
)
elif group == ["seg_id_nat", "month"]:
metrics = (
data.groupby([data["date"].dt.month, "seg_id_nat"])
)
elif group == ["seg_id_nat", "month"]:
metrics = (
data.groupby(
[data[time_idx_name].dt.month,
spatial_idx_name])
.apply(calc_metrics)
.reset_index()
)
else:
raise ValueError("group value not valid")
metrics["variable"] = variable
metrics["partition"] = partition
)
else:
raise ValueError("group value not valid")

metrics["variable"] = data_var
metrics["partition"] = partition
var_metrics_list.append(metrics)
var_metrics = pd.concat(var_metrics_list)
if outfile:
metrics.to_csv(outfile, header=True, index=False)
return metrics
var_metrics.to_csv(outfile, header=True, index=False)
return var_metrics


def combined_metrics(
obs_temp,
obs_flow,
obs_file,
pred_trn=None,
pred_val=None,
pred_tst=None,
spatial_idx_name="seg_id_nat",
time_idx_name="date",
group=None,
outfile=None,
):
"""
calculate the metrics for flow and temp and training and test sets for a
given grouping
:param obs_file: [str] path to observations zarr file
:param pred_trn: [str] path to training prediction feather file
:param pred_val: [str] path to validation prediction feather file
:param pred_tst: [str] path to testing prediction feather file
:param obs_temp: [str] path to observations temperature zarr file
:param obs_flow: [str] path to observations flow zarr file
:param spatial_idx_name: [str] name of column that is used for spatial
index (e.g., 'seg_id_nat')
:param time_idx_name: [str] name of column that is used for temporal index
(usually 'time')
:param group: [str or list] which group the metrics should be computed for.
Currently only supports 'seg_id_nat' (segment-wise metrics), 'month'
(month-wise metrics), ['seg_id_nat', 'month'] (metrics broken out by segment
Expand All @@ -206,17 +230,29 @@ def combined_metrics(
"""
df_all = []
if pred_trn:
trn_temp = overall_metrics(pred_trn, obs_temp, "temp", "trn", group)
trn_flow = overall_metrics(pred_trn, obs_flow, "flow", "trn", group)
df_all.extend([trn_temp, trn_flow])
trn_metrics = partition_metrics(pred_file=pred_trn,
obs_file=obs_file,
partition="trn",
spatial_idx_name=spatial_idx_name,
time_idx_name=time_idx_name,
group=group)
df_all.extend([trn_metrics])
if pred_val:
val_temp = overall_metrics(pred_val, obs_temp, "temp", "val", group)
val_flow = overall_metrics(pred_val, obs_flow, "flow", "val", group)
df_all.extend([val_temp, val_flow])
val_metrics = partition_metrics(pred_file=pred_val,
obs_file=obs_file,
partition="val",
spatial_idx_name=spatial_idx_name,
time_idx_name=time_idx_name,
group=group)
df_all.extend([val_metrics])
if pred_tst:
tst_temp = overall_metrics(pred_tst, obs_temp, "temp", "tst", group)
tst_flow = overall_metrics(pred_tst, obs_flow, "flow", "tst", group)
df_all.extend([tst_temp, tst_flow])
tst_metrics = partition_metrics(pred_file=pred_tst,
obs_file=obs_file,
partition="tst",
spatial_idx_name=spatial_idx_name,
time_idx_name=time_idx_name,
group=group)
df_all.extend([tst_metrics])
df_all = pd.concat(df_all, axis=0)
if outfile:
df_all.to_csv(outfile, index=False)
Expand Down
10 changes: 5 additions & 5 deletions river_dl/gw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def prep_annual_signal_data(
:param io_data_file: [str] the prepped data file
:param train_start_date, train_end_date, val_start_date,val_end_date,test_start_date,test_end_date: [str]
the start and end dates of the training, validation, and testing periods
:param gwVarList: [str] list of groundwater-relevant variables
:param gwVarList: [str] list of groundwater-relevant variables_to_log
:param out_file: [str] file to where the values will be written
:param water_temp_pbm_col: str with the column name of the process-based model predicted water temperatures in degrees C
:param water_temp_obs_col: str with the column name of the observed water temperatures in degrees C
Expand Down Expand Up @@ -288,7 +288,7 @@ def prep_annual_signal_data(
GW_trn_scale['Ar_obs'] = (GW_trn['Ar_obs']-np.nanmean(GW_trn['Ar_obs']))/np.nanstd(GW_trn['Ar_obs'])
GW_trn_scale['delPhi_obs'] = (GW_trn['delPhi_obs']-np.nanmean(GW_trn['delPhi_obs']))/np.nanstd(GW_trn['delPhi_obs'])

#add the GW data to the y dataset
#add the GW data to the y_dataset dataset
preppedData = np.load(io_data_file)
data = {k:v for k, v in preppedData.items() if not k.startswith("GW")}
data['GW_trn_reshape']=make_GW_dataset(GW_trn_scale,obs_trn,gwVarList)
Expand Down Expand Up @@ -363,7 +363,7 @@ def make_GW_dataset (GW_data,x_data,varList):
prepares a GW-relevant dataset for the GW loss function that can be combined with y_true
:param GW_data: [dataframe] dataframe of annual temperature signal properties by segment
:param x_data: [str] observation dataset
:param varList: [str] variables to keep in the final dataset
:param varList: [str] variables_to_log to keep in the final dataset
:returns: GW dataset that is reshaped to match the shape of the first 2 dimensions of the y_true dataset
"""
#make a dataframe with all combinations of segment and date and then join the annual temperature signal properties dataframe to it
Expand Down Expand Up @@ -473,12 +473,12 @@ def calc_gw_metrics(trnFile,tstFile,valFile,outFile,figFile1, figFile2, pbm_name
for x in range(len(thisData['{}_obs'.format(thisMetric)])):
thisColor = colorDict[thisData.group[x]]
ax.plot([thisData['{}_obs'.format(thisMetric+"_low")][x],thisData['{}_obs'.format(thisMetric+"_high")][x]],[thisData['{}_pred'.format(thisMetric)][x],thisData['{}_pred'.format(thisMetric)][x]], color=thisColor)
# ax.scatter(x=thisData['{}_obs'.format(thisMetric)],y=thisData['{}_pred'.format(thisMetric)],label="RGCN",color="blue")
# ax.scatter(x=thisData['{}_obs'.format(thisMetric)],y_dataset=thisData['{}_pred'.format(thisMetric)],label="RGCN",color="blue")
for thisGroup in np.unique(thisData['group']):
thisColor = colorDict[thisGroup]
ax.scatter(x=thisData.loc[thisData.group==thisGroup,'{}_obs'.format(thisMetric)],y=thisData.loc[thisData.group==thisGroup,'{}_pred'.format(thisMetric)],label="RGCN - %s"%thisGroup,color=thisColor)

# ax.scatter(x=thisData['{}_obs'.format(thisMetric)],y=thisData['{}_sntemp'.format(thisMetric)],label="SNTEMP",color="red")
# ax.scatter(x=thisData['{}_obs'.format(thisMetric)],y_dataset=thisData['{}_sntemp'.format(thisMetric)],label="SNTEMP",color="red")
for i, label in enumerate(thisData.seg_id_nat):
ax.annotate(int(label), (thisData['{}_obs'.format(thisMetric)][i],thisData['{}_pred'.format(thisMetric)][i]))
if thisFig==1:
Expand Down
2 changes: 1 addition & 1 deletion river_dl/loss_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def multitask_kge(lambdas):

def multitask_loss(lambdas, loss_func):
"""
calculate a weighted multi-task loss for a given number of variables with a
calculate a weighted multi-task loss for a given number of variables_to_log with a
given loss function
:param lambdas: [array-like float] The factor that losses will be
multiplied by before being added together.
Expand Down
Loading

0 comments on commit a7629eb

Please sign in to comment.