Allow offline diags to work for 2D only outputs #2137

Open · wants to merge 2 commits into master
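This PR threads a new offline-diags-flags parameter through the Argo workflows and adds an --outputs-2d-only flag to fv3net.diagnostics.offline.compute, so that pressure thickness does not need to be present in the test dataset when every model output is 2D. A minimal sketch of how the flag might be supplied at submission time, assuming the template is submitted directly rather than invoked from a parent workflow such as train-diags-prog.yaml; the bucket paths are placeholders and the other required parameters are omitted:

argo submit workflows/argo/offline-diags.yaml \
    -p ml-model="gs://bucket/trained-model" \
    -p offline-diags-output="gs://bucket/offline-diags" \
    -p offline-diags-flags="--outputs-2d-only"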
4 changes: 3 additions & 1 deletion workflows/argo/offline-diags.yaml
@@ -23,6 +23,7 @@ spec:
- name: test_data_config
- name: offline-diags-output
- name: report-output
- {name: offline-diags-flags, value: " "}
- {name: no-wandb, value: "false"}
- {name: wandb-project, value: "argo-default"}
- {name: wandb-tags, value: ""}
@@ -69,7 +70,8 @@ spec:
python -m fv3net.diagnostics.offline.compute \
{{inputs.parameters.ml-model}} \
test_data.yaml \
{{inputs.parameters.offline-diags-output}}
{{inputs.parameters.offline-diags-output}} \
{{inputs.parameters.offline-diags-flags}}

cat << EOF > training.yaml
{{inputs.parameters.training_config}}
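With offline-diags-flags set to "--outputs-2d-only", the templated command above would render to something like the following (the model and output URLs are placeholders):

python -m fv3net.diagnostics.offline.compute \
    gs://bucket/trained-model \
    test_data.yaml \
    gs://bucket/offline-diags \
    --outputs-2d-only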
3 changes: 3 additions & 0 deletions workflows/argo/train-diags-prog.yaml
@@ -34,6 +34,7 @@ spec:
- {name: memory-training, value: 6Gi}
- {name: memory-offline-diags, value: 10Gi}
- {name: training-flags, value: " "}
- {name: offline-diags-flags, value: " "}
- {name: online-diags-flags, value: " "}
- {name: do-prognostic-run, value: "true"}
- {name: no-wandb, value: "false"}
@@ -115,6 +116,8 @@ spec:
value: "{{inputs.parameters.tag}},{{inputs.parameters.wandb-tags}}"
- name: wandb-group
value: "{{inputs.parameters.wandb-group}}"
- name: offline-diags-flags
value: "{{inputs.parameters.offline-diags-flags}}"
- name: insert-model-urls
when: "{{inputs.parameters.do-prognostic-run}} == true"
dependencies: [resolve-output-url]
58 changes: 48 additions & 10 deletions workflows/diagnostics/fv3net/diagnostics/offline/compute.py
@@ -1,6 +1,7 @@
import argparse
import json
import logging
import numpy as np
import os
import sys
from tempfile import NamedTemporaryFile
@@ -103,6 +104,14 @@ def _get_parser() -> argparse.ArgumentParser:
default=-1,
help=("Optional n_jobs parameter for joblib.parallel when computing metrics."),
)
parser.add_argument(
"--outputs-2d-only",
action="store_true",
help=(
"Use flag if all model outputs are 2D, in which case pressure thickness"
"does not have to be in test dataset."
),
)
return parser


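For reference, a minimal sketch of how the new flag surfaces on the parsed arguments; the positional values are placeholders:

parser = _get_parser()
args = parser.parse_args(
    ["gs://bucket/model", "test_data.yaml", "gs://bucket/diags", "--outputs-2d-only"]
)
assert args.outputs_2d_only  # store_true: defaults to False when the flag is omitted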
@@ -144,6 +153,7 @@ def _compute_diagnostics(
target = safe.get_variables(
ds.sel({DERIVATION_DIM_NAME: TARGET_COORD}), full_predicted_vars
)

ds_summary = compute_diagnostics(prediction, target, grid, ds[DELP], n_jobs=n_jobs)

timesteps.append(ds["time"])
@@ -214,17 +224,26 @@ def transform(ds):
return transform


def _fill_delp():
# filler for the delp array, which is required as an arg but not used in diagnostics
def transform(ds):
ds[DELP] = xr.DataArray([np.nan, np.nan])
return ds

return transform


def _get_data_mapper_if_exists(config):
if isinstance(config, loaders.BatchesFromMapperConfig):
return config.load_mapper()
else:
return None


def _variables_to_load(model):
vars = list(
set(list(model.input_variables) + list(model.output_variables) + [DELP])
)
def _variables_to_load(model, outputs_2d_only=False):
vars = list(set(list(model.input_variables) + list(model.output_variables)))
if not outputs_2d_only:
vars.append(DELP)
if "Q2" in model.output_variables:
vars.append("water_vapor_path")
return vars
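A rough sketch of how the two new helpers interact when the flag is set, assuming a fitted model and a loaded batch ds are in scope, and that DELP names "pressure_thickness_of_atmospheric_layer":

# Flag set: DELP is left out of the variables loaded from the test data ...
model_variables = _variables_to_load(model, outputs_2d_only=True)
# ... and a NaN placeholder is injected later so compute_diagnostics still receives its delp argument.
ds_with_filler_delp = _fill_delp()(ds)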
@@ -258,12 +277,14 @@ def get_prediction(
config: loaders.BatchesFromMapperConfig,
model: fv3fit.Predictor,
evaluation_resolution: int,
outputs_2d_only: bool,
) -> xr.Dataset:
model_variables = _variables_to_load(model)
model_variables = _variables_to_load(model, outputs_2d_only)

if config.timesteps:
config.timesteps = sorted(config.timesteps)
batches = config.load_batches(model_variables)

batches = config.load_batches(model_variables)
transforms = [_get_predict_function(model, model_variables)]

prediction_resolution = res_from_string(config.res)
@@ -274,19 +295,33 @@
prediction_resolution=prediction_resolution,
)
)
if outputs_2d_only:
transforms.append(_fill_delp())

mapping_function = compose_left(*transforms)
batches = loaders.Map(mapping_function, batches)

concatted_batches = _daskify_sequence(batches)
try:
concatted_batches = _daskify_sequence(batches)
except KeyError as e:
key_err = str(e)
if "pressure_thickness_of_atmospheric_layer" in key_err:
raise KeyError(
"Variable 'pressure_thickness_of_atmospheric_layer' "
"not in dataset. If outputs are 2D and this variable "
"is not needed for diagnostics, include the CLI flag "
"--outputs-2d-only. If outputs are 3D, make sure this "
"variable is present in the test dataset."
)
else:
raise
del batches
return concatted_batches


def _daskify_sequence(batches):
temp_data_dir = temporary_directory()
for i, batch in enumerate(batches):
logger.info(f"Locally caching batch {i+1}/{len(batches)+1}")
logger.info(f"Locally caching batch {i+1}/{len(batches)}")
batch.to_netcdf(os.path.join(temp_data_dir.name, f"{i}.nc"))
dask_ds = xr.open_mfdataset(os.path.join(temp_data_dir.name, "*.nc"))
return dask_ds
@@ -314,7 +349,10 @@ def main(args):
model = fv3fit.DerivedModel(model, derived_output_variables=["Q2"])

ds_predicted = get_prediction(
config=config, model=model, evaluation_resolution=evaluation_grid.sizes["x"]
config=config,
model=model,
evaluation_resolution=evaluation_grid.sizes["x"],
outputs_2d_only=args.outputs_2d_only,
)

output_data_yaml = os.path.join(args.output_path, "data_config.yaml")