Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates for spatial data #74

Merged
merged 5 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ci/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ name: unseen-test
channels:
- conda-forge
dependencies:
- cartopy
- cftime
- cmdline_provenance
- dask-core
Expand All @@ -18,6 +19,7 @@ dependencies:
- seaborn
- xarray
- xclim>=0.39.0
- xesmf
- xskillscore
- zarr
- pip:
Expand Down
73 changes: 62 additions & 11 deletions unseen/bias_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import argparse
import operator

import xarray as xr

from . import array_handling
Expand All @@ -20,15 +19,19 @@ def get_bias(
ensemble_dim="ensemble",
init_dim="init_date",
lead_dim="lead_time",
lat_dim="lat",
lon_dim="lon",
by_lead=False,
regrid="obs",
regrid_method="conservative",
):
"""Calculate forecast bias.

Parameters
----------
fcst : xarray DataArray
fcst : xarray.DataArray
Forecast array with ensemble, initial date and lead time dimensions
obs : xarray DataArray
obs : xarray.DataArray
Observational array with time dimension
method : {'additive', 'multiplicative'}
Bias removal method
Expand All @@ -42,12 +45,21 @@ def get_bias(
Name of the initial date dimension in fcst
lead_dim: str, default 'lead_time'
Name of the lead time dimension in fcst
lat_dim: str, default 'lat'
Name of the latitude dimension in fcst and obs (if regridding)
lon_dim: str, default 'lon'
Name of the longitude dimension in fcst and obs (if regridding)
by_lead: bool, default False
Calculate bias for each lead time separately
regrid: {None, 'obs', 'fcst'}, default 'obs'
Regrid observational data to model grid (or vice versa)
regrid_method: {'conservative', 'bilinear', 'nearest_s2d', 'nearest_d2s'}, default 'conservative'
Regriding method (see xesmf.Regridder)


Returns
-------
bias : xarray DataArray
bias : xarray.DataArray

Raises
------
Expand All @@ -63,10 +75,12 @@ def get_bias(

fcst_clim = time_utils.get_clim(
fcst,
fcst_ave_dims,
dims=fcst_ave_dims,
time_period=time_period,
groupby_init_month=True,
time_name="time",
)

obs_stacked = array_handling.stack_by_init_date(
obs,
init_dates=fcst[init_dim],
Expand All @@ -80,6 +94,25 @@ def get_bias(
groupby_init_month=True,
)

# Check fcst dimensions include lat_dim and lon_dim
if set({lat_dim, lon_dim}).issubset(fcst_clim.dims):
# Check lat and lon coordinates match
if not all(
[fcst_clim[dim].equals(obs_clim[dim]) for dim in [lat_dim, lon_dim]]
):
if regrid == "obs":
obs_clim = general_utils.regrid(
obs_clim, fcst_clim, method=regrid_method
)
elif regrid == "fcst":
fcst_clim = general_utils.regrid(
fcst_clim, obs_clim, method=regrid_method
)
else:
raise ValueError(
"Forecast and observational data coordinates do not match. Consider regridding."
)

with xr.set_options(keep_attrs=True):
if method == "additive":
bias = fcst_clim - obs_clim
Expand All @@ -100,9 +133,9 @@ def remove_bias(fcst, bias, method, init_dim="init_date"):

Parameters
----------
fcst : xarray DataArray
fcst : xarray.DataArray
Forecast array with initial date and lead time dimensions
bias : xarray DataArray
bias : xarray.DataArray
Bias array
method : {'additive', 'multiplicative'}
Bias removal method
Expand All @@ -111,7 +144,7 @@ def remove_bias(fcst, bias, method, init_dim="init_date"):

Returns
-------
fcst_bc : xarray DataArray
fcst_bc : xarray.DataArray
Bias corrected forecast array

Raises
Expand All @@ -128,7 +161,7 @@ def remove_bias(fcst, bias, method, init_dim="init_date"):
raise ValueError(f"Unrecognised bias removal method {method}")

with xr.set_options(keep_attrs=True):
fcst_bc = op(fcst.groupby(f"{init_dim}.month"), bias).drop("month")
fcst_bc = op(fcst.groupby(f"{init_dim}.month"), bias).drop_vars("month")

fcst_bc.attrs["bias_correction_method"] = bias.attrs["bias_correction_method"]
try:
Expand Down Expand Up @@ -190,6 +223,19 @@ def _parse_command_line():
default=False,
help="Remove bias for each lead time separately [default=False]",
)
parser.add_argument(
"--regrid",
choices=("obs", "fcst"),
default="obs",
help="Regrid observational or forecast data if they are on different grids[default=obs]",
)
parser.add_argument(
"--regrid_method",
choices=("conservative", "bilinear", "nearest_s2d", "nearest_d2s"),
default="conservative",
help="Regriding method for observational or forecast data [default=conservative]",
)

args = parser.parse_args()

return args
Expand All @@ -204,7 +250,8 @@ def _main():
da_obs = ds_obs[args.var]

ds_fcst = fileio.open_dataset(args.fcst_file, variables=[args.var])
da_fcst = ds_fcst[args.var]
da_fcst = ds_fcst[args.var].load()

if args.min_lead:
da_fcst = da_fcst.where(da_fcst["lead_time"] >= args.min_lead)

Expand All @@ -215,10 +262,14 @@ def _main():
time_period=args.base_period,
time_rounding=args.rounding_freq,
by_lead=args.by_lead,
regrid=args.regrid,
regrid_method=args.regrid_method,
)
da_fcst_bc = remove_bias(da_fcst, bias, args.method)

ds_fcst_bc = da_fcst_bc.to_dataset()
ds_fcst_bc = da_fcst_bc.to_dataset(name=args.var)
ds_fcst_bc.attrs.update(ds_fcst.attrs)

infile_logs = {
args.fcst_file: ds_fcst.attrs["history"],
args.obs_file: ds_obs.attrs["history"],
Expand Down
2 changes: 1 addition & 1 deletion unseen/fileio.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ def _fix_metadata(ds, metadata_file):

if "round_coords" in metadata_dict:
for coord in metadata_dict["round_coords"]:
ds = ds.assign_coords({coord: ds[coord].round(decimals=10)})
ds = ds.assign_coords({coord: ds[coord].round(decimals=6)})

if "units" in metadata_dict:
for var, units in metadata_dict["units"].items():
Expand Down
48 changes: 46 additions & 2 deletions unseen/general_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from matplotlib.ticker import AutoMinorLocator
import matplotlib.pyplot as plt
import numpy as np
import xclim
from xarray import Dataset
from xclim.core import units
from xesmf import Regridder


class store_dict(argparse.Action):
Expand Down Expand Up @@ -61,7 +63,7 @@ def convert_units(da, target_units):
da.attrs["units"] = xclim_unit_check[da.units]

try:
da = xclim.units.convert_units_to(da, target_units)
da = units.convert_units_to(da, target_units)
except Exception as e:
in_precip_kg = da.attrs["units"] == "kg m-2 s-1"
out_precip_mm = target_units in ["mm d-1", "mm day-1"]
Expand All @@ -74,6 +76,48 @@ def convert_units(da, target_units):
return da


def regrid(ds, ds_grid, method="conservative", **kwargs):
"""Regrid `ds` to the grid of `ds_grid` using xESMF.

Parameters
----------
ds : Union[xarray.DataArray, xarray.Dataset]
Input data
ds_grid : Union[xarray.DataArray, xarray.Dataset]
Target grid.
method : {"conservative", "bilinear", "nearest_s2d", "nearest_d2s"}, default "conservative"
Regridding method
**kwargs
Additional keyword arguments for xESMF.Regridder

Returns
-------
ds_regrid : Union[xarray.DataArray, xarray.Dataset]
Regridded xarray.DataArray or xarray.Dataset

Notes
-----
- The input and target grids should have the same coordinate names.
- Recommended using the "conservative" method for regridding from fine to course and "bilinear" for the opposite.
"""
# Copy attributes
global_attrs = ds.attrs
if isinstance(ds, Dataset):
var_attrs = {var: ds[var].attrs for var in ds.data_vars}

# Regrid data
regridder = Regridder(ds, ds_grid, method, **kwargs)
ds_regrid = regridder(ds)

# Update regridded data attributes
ds_regrid.attrs.update(global_attrs)
if isinstance(ds_regrid, Dataset):
for var in ds_regrid.data_vars:
ds_regrid[var].attrs.update(var_attrs[var])

return ds_regrid


def plot_timeseries_scatter(
da,
da_obs=None,
Expand Down
Loading
Loading