Update time_utils.py and add tests
Update get_agg_dates to work with multi-dimensional data, remove the deprecated loffset parameter from resample, and add a unit test. Speed up the min_tsteps option in temporal_aggregation by reducing the values loaded to two and removing the xr.where search; this assumes only the first and last aggregation periods need to be checked and that they are consistent across all other dimensions (mainly used to ensure full time periods are aggregated when target_freq != YE-DEC). Add unit tests for the min_tsteps option in temporal_aggregation. Fix the mask.values assignment error in select_time_period.
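For context, the min_tsteps speedup replaces a full xr.where mask over every aggregation period with a check of just the first and last resample bins. A minimal standalone sketch of the idea, using illustrative names rather than the unseen API:

    import numpy as np
    import pandas as pd
    import xarray as xr

    # Daily data starting mid-month, so the first monthly bin is incomplete
    time = pd.date_range("2000-01-15", "2000-12-31", freq="D")
    da = xr.DataArray(np.arange(time.size), coords={"time": time}, dims="time")

    agg = da.resample(time="ME").max("time")
    counts = da.resample(time="ME").count("time")

    min_tsteps = 28
    # For contiguous input only the first and last bins can be short,
    # so load just those two counts instead of masking every bin
    first_last = counts.isel(time=[0, -1]).load()
    if first_last[0] < min_tsteps:
        agg = agg.isel(time=slice(1, None))
    if first_last[-1] < min_tsteps:
        agg = agg.isel(time=slice(None, -1))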
stellema committed Jul 15, 2024
1 parent cacbb06 commit f83dcf5
Showing 2 changed files with 77 additions and 29 deletions.
57 changes: 54 additions & 3 deletions unseen/tests/test_time_utils.py
@@ -3,9 +3,7 @@
 import numpy as np
 from xarray.coding.times import cftime_to_nptime
 
-from unseen.time_utils import (
-    select_time_period,
-)
+from unseen.time_utils import select_time_period, temporal_aggregation
 
 
 @pytest.mark.parametrize("example_da_forecast", ["numpy"], indirect=True)
@@ -30,3 +28,56 @@ def test_select_time_period(example_da_forecast, add_nans, data_object):
 
     assert min_time >= np.datetime64(PERIOD[0])
     assert max_time <= np.datetime64(PERIOD[1])
+
+
+@pytest.mark.parametrize("example_da_timeseries", ["numpy"], indirect=True)
+def test_temporal_aggregation_agg_dates(example_da_timeseries):
+    """Test temporal_aggregation using agg_dates."""
+    # Monotonically increasing time series
+    data = example_da_timeseries
+    ds = data.to_dataset(name="var")
+
+    ds_resampled = temporal_aggregation(
+        ds,
+        target_freq="ME",
+        input_freq="D",
+        agg_method="max",
+        variables=["var"],
+        season=None,
+        reset_times=False,
+        min_tsteps=None,
+        agg_dates=True,
+        time_dim="time",
+    )
+    event_time = ds_resampled["event_time"].astype(dtype="datetime64[ns]")
+    # Monthly maximum should be the last day of each month
+    assert np.all(event_time.dt.day >= 28)
+
+
+@pytest.mark.parametrize("example_da_timeseries", ["numpy"], indirect=True)
+def test_temporal_aggregation_min_tsteps(example_da_timeseries):
+    """Test temporal_aggregation using min_tsteps."""
+    data = example_da_timeseries
+    ds = data.to_dataset(name="var")
+    # Remove days from first & last month and test the months are removed
+    ds = ds.isel(time=slice(5, -5))
+
+    variables = ["var"]
+    target_freq = "ME"
+    time_dim = "time"
+    min_tsteps = 28
+    counts = ds[variables[0]].resample(time=target_freq).count(dim=time_dim).load()
+
+    ds_resampled = temporal_aggregation(
+        ds,
+        target_freq,
+        input_freq="D",
+        agg_method="max",
+        variables=variables,
+        season=None,
+        reset_times=False,
+        min_tsteps=min_tsteps,
+        agg_dates=False,
+        time_dim=time_dim,
+    )
+    assert np.all((counts >= min_tsteps).sum(time_dim) == len(ds_resampled.time))
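Assuming a standard pytest layout at the repository root, the new tests can be run on their own with something like:

    python -m pytest unseen/tests/test_time_utils.py -k temporal_aggregation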
49 changes: 23 additions & 26 deletions unseen/time_utils.py
@@ -15,8 +15,8 @@ def get_agg_dates(ds, var, target_freq, agg_method, time_dim="time"):
     Parameters
     ----------
-    ds : xarray Dataset
-        A time resampled dataset
+    ds : xarray.Dataset
+        Dataset to be resampled
     var : str
         A variable in the dataset
     target_freq : str
@@ -29,28 +29,19 @@
     Returns
     -------
     event_datetimes_str : numpy.ndarray
-        Array of event dates
+        Event date strings for the resampled array
     """
-    reduce_funcs = {"min": np.nanargmin, "max": np.nanargmax}
-
-    ds_arg = ds.resample(
-        time=target_freq, label="left", loffset=datetime.timedelta(days=1)
-    ).reduce(reduce_funcs[agg_method], dim=time_dim)
-    time_diffs = ds_arg[var].values.astype("timedelta64[D]")
-    str_time_axis = [time.strftime("%Y-%m-%d") for time in ds_arg[time_dim].values]
-    datetime_time_axis = np.array(str_time_axis, dtype="datetime64")
-    assert time_diffs.ndim <= 2
-    if time_diffs.ndim == 2:
-        other_dims = list(ds_arg[var].dims)
-        other_dims.remove(time_dim)
-        other_dim_name = other_dims[0]
-        other_dim_index = ds_arg[var].dims.index(other_dim_name)
-        datetime_time_axis = np.expand_dims(datetime_time_axis, axis=other_dim_index)
-    event_datetimes_np = datetime_time_axis + time_diffs
-    event_datetimes_str = np.datetime_as_string(event_datetimes_np)
+    ds_arg = ds[var].resample(time=target_freq, label="left")
+    if agg_method == "max":
+        dates = [da.idxmax(time_dim) for _, da in ds_arg]
+    elif agg_method == "min":
+        dates = [da.idxmin(time_dim) for _, da in ds_arg]
+
+    dates = xr.concat(dates, dim=time_dim)
+    event_datetimes_str = dates.load().dt.strftime("%Y-%m-%d")
+    event_datetimes_str = event_datetimes_str.astype(dtype=str)
 
-    return event_datetimes_str
+    return event_datetimes_str.values
 
 
 def temporal_aggregation(
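For reference, a hedged usage sketch of the rewritten idxmax/idxmin-based get_agg_dates on two-dimensional daily data; the inputs here are made up for illustration:

    import numpy as np
    import pandas as pd
    import xarray as xr

    from unseen.time_utils import get_agg_dates

    # Two-dimensional daily data to exercise the multi-dim path
    time = pd.date_range("2000-01-01", "2000-12-31", freq="D")
    ds = xr.Dataset(
        {"var": (("time", "x"), np.random.rand(time.size, 3))},
        coords={"time": time},
    )

    # Returns one "YYYY-MM-DD" string per month per x point
    event_dates = get_agg_dates(ds, "var", target_freq="ME", agg_method="max")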
@@ -162,9 +153,14 @@ def temporal_aggregation(
     assert ds[time_dim].values[0] == start_time
 
     if min_tsteps:
-        for var in variables:
-            ds[var] = ds[var].where(counts.values >= min_tsteps)
-        ds = ds.dropna(dim=time_dim)
+        # Drop first and last time points with insufficient time steps
+        counts = counts.isel({time_dim: [0, -1]})
+        # Select the minimum of the non-time dimensions
+        counts = counts.min([dim for dim in ds[variables[0]].dims if dim != time_dim])
+        if counts[0] < min_tsteps:
+            ds = ds.isel({time_dim: slice(1, None)})
+        if counts[-1] < min_tsteps:
+            ds = ds.isel({time_dim: slice(None, -1)})
 
     if reindexed:
         ds = ds.compute()
@@ -219,8 +215,9 @@ def _inbounds(t, bnds):
             start=start, end=stop, periods=2, freq=None, calendar=calendar
         )
        mask_values = _vinbounds(time_values, time_bounds)
-        mask = ds[time_name].copy()
-        mask.values = mask_values
+        mask = xr.DataArray(
+            mask_values, dims=ds[time_name].dims, coords=ds[time_name].coords
+        )
         selection = ds.where(mask, drop=True)
     else:
         raise ValueError("No time axis for masking")
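The select_time_period fix replaces the in-place mask.values assignment, which the commit message notes was raising an error, with a fresh boolean DataArray built on the time axis dims and coords. A minimal sketch of the new pattern, simplified here to a numpy datetime axis:

    import numpy as np
    import xarray as xr

    ds = xr.Dataset(
        {"var": ("time", np.arange(5.0))},
        coords={"time": np.arange("2000-01-01", "2000-01-06", dtype="datetime64[D]")},
    )
    mask_values = np.array([False, True, True, True, False])

    # Wrap the boolean array with the time coordinate's dims/coords
    # instead of assigning into a copy of the coordinate
    mask = xr.DataArray(mask_values, dims=ds["time"].dims, coords=ds["time"].coords)
    selection = ds.where(mask, drop=True)  # keeps only the True time steps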