diff --git a/mesmer/prototype/__init__.py b/mesmer/prototype/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/mesmer/prototype/calibrate.py b/mesmer/prototype/calibrate.py
new file mode 100644
index 00000000..7bb646cb
--- /dev/null
+++ b/mesmer/prototype/calibrate.py
@@ -0,0 +1,142 @@
+import abc
+
+import numpy as np
+import sklearn.linear_model
+import statsmodels.tsa.ar_model
+import xarray as xr
+
+
+class MesmerCalibrateBase(metaclass=abc.ABCMeta):
+    """
+    Abstract base class for calibration
+    """
+
+
+class MesmerCalibrateTargetPredictor(MesmerCalibrateBase):
+    @abc.abstractmethod
+    def calibrate(self, target, predictor, **kwargs):
+        """
+        [TODO: update this based on however LinearRegression.calibrate's docs
+        end up looking]
+        """
+
+
+class LinearRegression(MesmerCalibrateTargetPredictor):
+    """
+    following
+
+    https://github.com/MESMER-group/mesmer/blob/d73e8f521a2e1d081a48b775ba14dd764cb671e8/mesmer/calibrate_mesmer/train_lt.py#L165
+
+    All the lines above and below that line are basically just data
+    preparation, which makes it very hard to see what the model actually is
+    """
+
+    @staticmethod
+    def _regress_single_group(target_point, predictor, weights=None):
+        # this is the method that actually does the regression
+        args = [predictor.T, target_point.reshape(-1, 1)]
+        if weights is not None:
+            args.append(weights)
+        reg = sklearn.linear_model.LinearRegression().fit(*args)
+        out_array = np.concatenate([reg.intercept_, *reg.coef_])
+
+        return out_array
+
+    def calibrate(
+        self,
+        target_flattened,
+        predictors_flattened,
+        stack_coord_name,
+        predictor_name="predictor",
+        weights=None,
+        predictor_temporary_name="__pred_store__",
+    ):
+        """
+        TODO: redo docstring
+        """
+        if predictor_name not in predictors_flattened.dims:
+            raise AssertionError(f"{predictor_name} not in {predictors_flattened.dims}")
+
+        if predictor_temporary_name in predictors_flattened.dims:
+            raise AssertionError(
+                f"{predictor_temporary_name} already in {predictors_flattened.dims}, choose a different temporary name"
+            )
+
+        res = xr.apply_ufunc(
+            self._regress_single_group,
+            target_flattened,
+            predictors_flattened,
+            input_core_dims=[[stack_coord_name], [predictor_name, stack_coord_name]],
+            output_core_dims=((predictor_temporary_name,),),
+            vectorize=True,
+            kwargs=dict(weights=weights),
+        )
+
+        # assuming that predictor's names are in the 'variable' coordinate
+        predictors_plus_intercept_order = ["intercept"] + list(
+            predictors_flattened["variable"].values
+        )
+        res = res.assign_coords(
+            {predictor_temporary_name: predictors_plus_intercept_order}
+        ).rename({predictor_temporary_name: predictor_name})
+
+        return res
+
+
+class MesmerCalibrateTarget(MesmerCalibrateBase):
+    @abc.abstractmethod
+    def calibrate(self, target, **kwargs):
+        """
+        [TODO: update this based on however LinearRegression.calibrate's docs
+        end up looking]
+        """
+
+    @staticmethod
+    def _check_target_is_one_dimensional(target, return_numpy_values):
+        if len(target.dims) > 1:
+            raise AssertionError(f"More than one dimension, found {target.dims}")
+
+        if not return_numpy_values:
+            return None
+
+        return target.dropna(dim=target.dims[0]).values
+
+
+class AutoRegression1DOrderSelection(MesmerCalibrateTarget):
+    def calibrate(
+        self,
+        target,
+        maxlag=12,
+        ic="bic",
+    ):
+        target_numpy = self._check_target_is_one_dimensional(
+            target, return_numpy_values=True
+        )
+
+        calibrated = statsmodels.tsa.ar_model.ar_select_order(
+            target_numpy, maxlag=maxlag, ic=ic, old_names=False
+        )
+
+        return calibrated.ar_lags
+
+
+class AutoRegression1D(MesmerCalibrateTarget):
+    def calibrate(
+        self,
+        target,
+        order,
+    ):
+        target_numpy = self._check_target_is_one_dimensional(
+            target, return_numpy_values=True
+        )
+
+        calibrated = statsmodels.tsa.ar_model.AutoReg(
+            target_numpy, lags=order, old_names=False
+        ).fit()
+
+        return {
+            "intercept": calibrated.params[0],
+            "lag_coefficients": calibrated.params[1:],
+            # square root of the fitted innovation (residual) variance, sigma2
+            "standard_innovations": np.sqrt(calibrated.sigma2),
+        }
diff --git a/mesmer/prototype/calibrate_multiple.py b/mesmer/prototype/calibrate_multiple.py
new file mode 100644
index 00000000..4b12f6b6
--- /dev/null
+++ b/mesmer/prototype/calibrate_multiple.py
@@ -0,0 +1,362 @@
+import numpy as np
+import pandas as pd
+import scipy.stats
+import xarray as xr
+
+import mesmer
+
+from .calibrate import AutoRegression1D, AutoRegression1DOrderSelection
+
+
+def _get_predictor_dims(predictors):
+    """Return the dimensions shared by all predictors, raising if they differ."""
+    predictors_dims = {k: v.dims for k, v in predictors.items()}
+    predictors_dims_unique = set(predictors_dims.values())
+    if len(predictors_dims_unique) > 1:
+        raise AssertionError(
+            f"Dimensions of predictors are not all the same, we have: {predictors_dims}"
+        )
+
+    return list(predictors_dims_unique)[0]
+
+
+def _get_stack_coord_name(inp_array):
+    """Return a name for the stacked coordinate that is not yet a dimension."""
+    stack_coord_name = "stacked_coord"
+    if stack_coord_name in inp_array.dims:
+        stack_coord_name = "mesmer_stacked_coord"
+
+    if stack_coord_name in inp_array.dims:
+        raise NotImplementedError("You have dimensions we can't safely unstack yet")
+
+    return stack_coord_name
+
+
+def _check_coords_match(obj, obj_other, check_coord):
+    """Raise if ``check_coord`` differs between ``obj`` and ``obj_other``."""
+    coords_match = obj.coords[check_coord].equals(obj_other.coords[check_coord])
+    if not coords_match:
+        raise AssertionError(f"{check_coord} is not the same on {obj} and {obj_other}")
+
+
+def _flatten(inp, dims_to_flatten):
+    """Stack ``dims_to_flatten`` into one dimension, dropping missing values."""
+    stack_coord_name = _get_stack_coord_name(inp)
+    inp_flat = inp.stack({stack_coord_name: dims_to_flatten}).dropna(stack_coord_name)
+
+    return inp_flat, stack_coord_name
+
+
+def _flatten_predictors(predictors, dims_to_flatten, stack_coord_name):
+    """Flatten each predictor and combine them along a new "predictor" dimension."""
+    predictors_flat = []
+    for v, vals in predictors.items():
+        if stack_coord_name in vals.dims:
+            raise AssertionError(f"{stack_coord_name} is already in {vals.dims}")
+
+        vals_flat = vals.stack({stack_coord_name: dims_to_flatten}).dropna(
+            stack_coord_name
+        )
+        vals_flat.name = v
+        predictors_flat.append(vals_flat)
+
+    out = xr.merge(predictors_flat).to_stacked_array(
+        "predictor", sample_dims=[stack_coord_name]
+    )
+
+    return out
+
+
+def flatten_predictors_and_target(predictors, target):
+    """
+    Flatten predictors and target over the predictors' dimensions
+
+    Returns the flattened predictors, the flattened target and the name of
+    the stacked coordinate, after checking that the stacked coordinate is the
+    same on both.
+    """
+    dims_to_flatten = _get_predictor_dims(predictors)
+    target_flattened, stack_coord_name = _flatten(target, dims_to_flatten)
+    predictors_flattened = _flatten_predictors(
+        predictors, dims_to_flatten, stack_coord_name
+    )
+    _check_coords_match(target_flattened, predictors_flattened, stack_coord_name)
+
+    return predictors_flattened, target_flattened, stack_coord_name
+
+
+def _loop_levels(inp, levels):
+    """Yield ``(names, values)`` for every combination of the grouped ``levels``."""
+    # annoyingly, there doesn't seem to be an inbuilt solution for this
+    # https://github.com/pydata/xarray/issues/2438
+    def _yield_level(inph, levels_left, out_names):
+        for name, values in inph.groupby(levels_left[0]):
+            out_names_here = out_names + [name]
+            if len(levels_left) == 1:
+                yield tuple(out_names_here), values
+            else:
+                yield from _yield_level(values, levels_left[1:], out_names_here)
+
+    yield from _yield_level(inp, levels, [])
+
+
+def _select_auto_regressive_process_order(
+    target,
+    maxlag,
+    ic,
+    scenario_level="scenario",
+    ensemble_member_level="ensemble_member",
+    q=50,
+    interpolation="nearest",
+):
+    """
+    Select a single auto-regressive order across scenarios and ensemble members
+
+    The order is selected for each (scenario, ensemble member) combination;
+    the ``q``-th percentile is then taken over ensemble members within each
+    scenario and afterwards over scenarios.
+
+    interpolation : str
+        Passed to :meth:`pandas.Series.quantile`. Interpolation is not a good
+        way to go here because it could lead to an AR order that wasn't
+        actually chosen by any run. We recommend using the default value i.e.
+        "nearest".
+    """
+    store = []
+
+    for (scenario, ensemble_member), values in _loop_levels(
+        target, (scenario_level, ensemble_member_level)
+    ):
+        orders = AutoRegression1DOrderSelection().calibrate(
+            values, maxlag=maxlag, ic=ic
+        )
+
+        # orders can be None
+        keep_order = np.nan if orders is None else orders[-1]
+
+        store.append(
+            {
+                "scenario": scenario,
+                "ensemble_member": ensemble_member,
+                "order": keep_order,
+            }
+        )
+
+    store = pd.DataFrame(store).set_index(["scenario", "ensemble_member"])
+    res = (
+        store.groupby("scenario")["order"]
+        # first operation gives result by scenario (i.e. over ensemble members)
+        .quantile(q=q / 100, interpolation=interpolation)
+        # second one gives result over all scenarios
+        .quantile(q=q / 100, interpolation=interpolation)
+    )
+
+    return res
+
+
+def _derive_auto_regressive_process_parameters(
+    target, order, scenario_level="scenario", ensemble_member_level="ensemble_member"
+):
+    """
+    Fit an AR process per (scenario, ensemble member) and average the parameters
+
+    Parameters are averaged over ensemble members within each scenario first
+    and then over scenarios.
+    """
+    store = []
+    for (scenario, ensemble_member), values in _loop_levels(
+        target, (scenario_level, ensemble_member_level)
+    ):
+        parameters = AutoRegression1D().calibrate(values, order=order)
+        parameters["scenario"] = scenario
+        parameters["ensemble_member"] = ensemble_member
+        store.append(parameters)
+
+    store = pd.DataFrame(store).set_index(["scenario", "ensemble_member"])
+
+    def _axis_mean(inp):
+        return inp.apply(np.mean, axis=0)
+
+    res = (
+        store.groupby("scenario")
+        # first operation gives result by scenario (i.e. over ensemble members)
+        .apply(_axis_mean)
+        # second one gives result over all scenarios
+        .apply(np.mean, axis=0).to_dict()
+    )
+
+    return res
+
+
+def calibrate_auto_regressive_process_multiple_scenarios_and_ensemble_members(
+    target,
+    maxlag=12,
+    ic="bic",
+):
+    """Select the AR order for ``target`` then calibrate a process of that order."""
+    ar_order = _select_auto_regressive_process_order(target, maxlag, ic)
+    ar_params = _derive_auto_regressive_process_parameters(target, ar_order)
+
+    return ar_params
+
+
+def calibrate_auto_regressive_process_with_spatially_correlated_errors_multiple_scenarios_and_ensemble_members(
+    target,
+    localisation_radii,
+    max_cross_validation_iterations=30,
+    gridpoint_dim_name="gridpoint",
+):
+    """
+    Calibrate an AR1 process with spatially correlated innovations
+
+    An AR1 process is fitted at each gridpoint, a localised empirical
+    covariance matrix of ``target`` is derived and the two are combined into
+    the covariance matrix of the AR1 innovations.
+    """
+    # NOTE: results are ordered as groupby yields them; this is assumed to
+    # match the gridpoint order used for the covariance matrix
+    gridpoint_autoregression_parameters = {
+        gridpoint: _derive_auto_regressive_process_parameters(gridpoint_vals, order=1)
+        for gridpoint, gridpoint_vals in target.groupby(gridpoint_dim_name)
+    }
+
+    geodist = mesmer.geospatial.geodist_exact(target.lon, target.lat)
+    gaspari_cohn_correlation_matrices = mesmer.stats.gaspari_cohn_correlation_matrices(
+        geodist, localisation_radii
+    )
+
+    localised_empirical_covariance_matrix = (
+        _calculate_localised_empirical_covariance_matrix(
+            target,
+            localisation_radii,
+            gaspari_cohn_correlation_matrices,
+            max_cross_validation_iterations,
+            gridpoint_dim_name=gridpoint_dim_name,
+        )
+    )
+
+    gridpoint_autoregression_coeffcients = np.hstack(
+        [v["lag_coefficients"] for v in gridpoint_autoregression_parameters.values()]
+    )
+
+    # scale element (i, j) by sqrt(1 - phi_i**2) * sqrt(1 - phi_j**2) so that the
+    # result stays symmetric (scaling with 1 - phi**2 directly only scales columns)
+    innovation_scaling = np.sqrt(1 - gridpoint_autoregression_coeffcients**2)
+    localised_empirical_covariance_matrix_with_ar1_errors = (
+        np.outer(innovation_scaling, innovation_scaling)
+        * localised_empirical_covariance_matrix
+    )
+
+    return localised_empirical_covariance_matrix_with_ar1_errors
+
+
+def _calculate_localised_empirical_covariance_matrix(
+    target,
+    localisation_radii,
+    gaspari_cohn_correlation_matrices,
+    max_cross_validation_iterations,
+    gridpoint_dim_name="gridpoint",
+):
+    """
+    Calculate the empirical covariance matrix localised with the best radius
+
+    The radius is chosen by cross-validation: radii are tried in the order
+    given (expected to be ascending) until the summed log-likelihood of the
+    left-out samples stops improving.
+
+    Raises
+    ------
+    ValueError
+        ``localisation_radii`` is empty
+    """
+    localisation_radii = list(localisation_radii)
+    if not localisation_radii:
+        raise ValueError("localisation_radii must contain at least one radius")
+
+    dims_to_flatten = [d for d in target.dims if d != gridpoint_dim_name]
+    target_flattened, stack_coord_name = _flatten(target, dims_to_flatten)
+    target_flattened = target_flattened.transpose(stack_coord_name, gridpoint_dim_name)
+
+    number_samples = target_flattened[stack_coord_name].shape[0]
+    number_iterations = min([number_samples, max_cross_validation_iterations])
+
+    # setup cross-validation stuff
+    index_cross_validation_out = np.zeros(
+        [number_iterations, number_samples], dtype=bool
+    )
+
+    for i in range(number_iterations):
+        index_cross_validation_out[i, i::max_cross_validation_iterations] = True
+
+    # best summed log-likelihood so far and the radius that produced it; the
+    # first radius is the fallback if no radius beats the initial threshold
+    log_likelihood_cross_validation_sum_max = -10000
+    selected_localisation_radius = localisation_radii[0]
+
+    for lr in localisation_radii:
+        log_likelihood_cross_validation_sum = 0
+
+        for i in range(number_iterations):
+            # split into samples used to estimate the covariance and samples
+            # left out for cross-validation
+            target_estimator = target_flattened.isel(
+                **{stack_coord_name: ~index_cross_validation_out[i]}
+            ).values
+            target_cross_validation = target_flattened.isel(
+                **{stack_coord_name: index_cross_validation_out[i]}
+            ).values
+            # selecting relevant weights goes in here
+
+            empirical_covariance = np.cov(target_estimator, rowvar=False)
+            # must be a better way to handle ensuring that the dimensions line up correctly (rather than
+            # just cheating by using `.values`)
+            empirical_covariance_localised = (
+                empirical_covariance * gaspari_cohn_correlation_matrices[lr].values
+            )
+
+            # calculate likelihood of cross validation samples
+            log_likelihood_cross_validation_samples = (
+                scipy.stats.multivariate_normal.logpdf(
+                    target_cross_validation,
+                    mean=np.zeros(gaspari_cohn_correlation_matrices[lr].shape[0]),
+                    cov=empirical_covariance_localised,
+                    allow_singular=True,
+                )
+            )
+            log_likelihood_cross_validation_samples_weighted_sum = (
+                np.average(
+                    log_likelihood_cross_validation_samples,
+                    # weights=wgt_scen_eq_cv # TODO: weights handling
+                )
+                * log_likelihood_cross_validation_samples.shape[0]
+            )
+
+            # add to full sum over all folds
+            log_likelihood_cross_validation_sum += (
+                log_likelihood_cross_validation_samples_weighted_sum
+            )
+
+        if (
+            log_likelihood_cross_validation_sum
+            > log_likelihood_cross_validation_sum_max
+        ):
+            log_likelihood_cross_validation_sum_max = (
+                log_likelihood_cross_validation_sum
+            )
+            selected_localisation_radius = lr
+        else:
+            # experience tells us that once we start selecting large localisation radii, performance
+            # will not improve ==> break (reduces computational effort and number of singular matrices
+            # encountered)
+            break
+
+    # TODO: replace print with logging
+    print(f"Selected localisation radius: {selected_localisation_radius}")
+
+    empirical_covariance = np.cov(target_flattened.values, rowvar=False)
+    empirical_covariance_localised = (
+        empirical_covariance
+        * gaspari_cohn_correlation_matrices[selected_localisation_radius].values
+    )
+
+    return empirical_covariance_localised
diff --git a/tests/integration/test_prototype.py b/tests/integration/test_prototype.py
new file mode 100644
index 00000000..123de595
--- /dev/null
+++ b/tests/integration/test_prototype.py
@@ -0,0 +1,447 @@
+import numpy as np
+import pytest
+import xarray as xr
+from statsmodels.tsa.arima_process import ArmaProcess
+
+import mesmer
+from mesmer.calibrate_mesmer.train_gv import train_gv
+from mesmer.calibrate_mesmer.train_lt import train_lt
+from mesmer.calibrate_mesmer.train_lv import train_lv
+from mesmer.prototype.calibrate import LinearRegression
+from mesmer.prototype.calibrate_multiple import (
+    calibrate_auto_regressive_process_multiple_scenarios_and_ensemble_members,
+    calibrate_auto_regressive_process_with_spatially_correlated_errors_multiple_scenarios_and_ensemble_members,
+    flatten_predictors_and_target,
+)
+
+
+class _MockConfig:
+    def __init__(
+        self,
+        method_lt="OLS",
+        method_lv="OLS_AR1_sci",
+        method_gv="AR",
+        separate_gridpoints=True,
+        weight_scenarios_equally=True,
+        target_variable="tas",
+        cross_validation_max_iterations=30,
+    ):
+        self.methods = {}
+        self.methods[target_variable] = {}
+        self.methods[target_variable]["lt"] = method_lt
+        self.methods[target_variable]["lv"] = method_lv
+        self.methods[target_variable]["gv"] = method_gv
+
+        # this has to be set but isn't actually used
+        self.preds = {}
+        self.preds[target_variable] = {}
+        self.preds[target_variable]["gv"] = None
+
+        self.method_lt_each_gp_sep = separate_gridpoints
+        self.wgt_scen_tr_eq = weight_scenarios_equally
+
+        self.max_iter_cv = cross_validation_max_iterations
+
+
+def _do_legacy_run_train_lt(
+    emulator_tas,
+    emulator_tas_squared,
+    emulator_hfds,
+    global_variability,
+    esm_tas,
+    cfg,
+):
+    preds_legacy = {}
+    for name, vals in (
+        ("gttas", emulator_tas),
+        ("gttas2", emulator_tas_squared),
+        ("gthfds", emulator_hfds),
+        ("gvtas", global_variability),
+    ):
+        preds_legacy[name] = {}
+        for scenario, vals_scen in vals.groupby("scenario"):
+            # we have to force this to have an extra dimension, run, for legacy
+            # to work although this really isn't how it should be because it
+            # means you have some weird reshaping to do at a really low level
+            preds_legacy[name][scenario] = vals_scen.dropna(dim="time").values[
+                np.newaxis, :
+            ]
+
+    targs_legacy = {}
+    for name, vals in (("tas", esm_tas),):
+        targs_legacy[name] = {}
+        for scenario, vals_scen in vals.groupby("scenario"):
+            # we have to force this to have an extra dimension, run, for legacy
+            # to work although this really isn't how it should be
+            # order of dimensions is very important for legacy too
+            targs_legacy[name][scenario] = vals_scen.T.dropna(dim="time").values[
+                np.newaxis, :, :
+            ]
+
+    res_legacy = train_lt(
+        preds_legacy,
+        targs_legacy,
+        esm="esm_name",
+        cfg=cfg,
+        save_params=False,
+    )
+
+    return res_legacy
+
+
+def test_prototype_train_lt():
+    time = [1850, 1950, 2014, 2015, 2050, 2100, 2300]
+    scenarios = ["hist", "ssp126"]
+
+    pred_dims = ["scenario", "time"]
+    pred_coords = dict(
+        time=time,
+        scenario=scenarios,
+    )
+
+    emulator_tas = xr.DataArray(
+        np.array(
+            [
+                [0, 0.5, 1, np.nan, np.nan, np.nan, np.nan],
+                [np.nan, np.nan, np.nan, 1.1, 1.5, 1.4, 1.2],
+            ]
+        ),
+        dims=pred_dims,
+        coords=pred_coords,
+    )
+    emulator_tas_squared = emulator_tas**2
+    global_variability = xr.DataArray(
+        np.array(
+            [
+                [-0.1, 0.1, 0.03, np.nan, np.nan, np.nan, np.nan],
+                [np.nan, np.nan, np.nan, 0.04, 0.2, -0.03, 0.0],
+            ]
+        ),
+        dims=pred_dims,
+        coords=pred_coords,
+    )
+    emulator_hfds = xr.DataArray(
+        np.array(
+            [
+                [0.5, 1.5, 2.0, np.nan, np.nan, np.nan, np.nan],
+                [np.nan, np.nan, np.nan, 2.1, 2.0, 1.5, 0.4],
+            ]
+        ),
+        dims=pred_dims,
+        coords=pred_coords,
+    )
+
+    # we wouldn't actually start like this, but we'd write a utils function
+    # to simply go from lat-lon to gridpoint and back
+    targ_dims = ["scenario", "gridpoint", "time"]
+    targ_coords = dict(
+        time=time,
+        scenario=scenarios,
+        gridpoint=[0, 1],
+        lat=(["gridpoint"], [-60, 60]),
+        lon=(["gridpoint"], [120, 240]),
+    )
+    esm_tas = xr.DataArray(
+        np.array(
+            [
+                [
+                    [0.6, 1.6, 2.6, np.nan, np.nan, np.nan, np.nan],
+                    [0.43, 1.13, 2.21, np.nan, np.nan, np.nan, np.nan],
+                ],
+                [
+                    [np.nan, np.nan, np.nan, 2.11, 2.01, 1.54, 1.22],
+                    [np.nan, np.nan, np.nan, 2.19, 2.04, 1.53, 1.21],
+                ],
+            ]
+        ),
+        dims=targ_dims,
+        coords=targ_coords,
+    )
+
+    res_legacy = _do_legacy_run_train_lt(
+        emulator_tas,
+        emulator_tas_squared,
+        emulator_hfds,
+        global_variability,
+        esm_tas,
+        cfg=_MockConfig(),
+    )
+
+    (
+        predictors_flattened,
+        target_flattened,
+        stack_coord_name,
+    ) = flatten_predictors_and_target(
+        predictors={
+            "emulator_tas": emulator_tas,
+            "emulator_tas_squared": emulator_tas_squared,
+            "emulator_hfds": emulator_hfds,
+            "global_variability": global_variability,
+        },
+        target=esm_tas,
+    )
+
+    res_updated = LinearRegression().calibrate(
+        target_flattened,
+        predictors_flattened,
+        stack_coord_name,
+    )
+
+    # check that calibrated parameters match for each predictor variable
+    for updated_name, legacy_vals in (
+        ("emulator_tas", res_legacy[0]["coef_gttas"]["tas"]),
+        ("emulator_tas_squared", res_legacy[0]["coef_gttas2"]["tas"]),
+        ("emulator_hfds", res_legacy[0]["coef_gthfds"]["tas"]),
+        ("global_variability", res_legacy[1]["coef_gvtas"]["tas"]),
+        ("intercept", res_legacy[0]["intercept"]["tas"]),
+    ):
+        np.testing.assert_allclose(res_updated.sel(predictor=updated_name), legacy_vals)
+
+
+def _do_legacy_run_train_gv(
+    esm_tas_global_variability,
+    cfg,
+):
+    targs_legacy = {}
+    var_name = "tas"
+
+    targs_legacy = {}
+    for scenario, vals_scen in esm_tas_global_variability.groupby("scenario"):
+        targs_legacy[scenario] = (
+            vals_scen.T.dropna(dim="time").transpose("ensemble_member", "time").values
+        )
+
+    res_legacy = train_gv(
+        targs_legacy,
+        targ=var_name,
+        esm="esm_name",
+        cfg=cfg,
+        save_params=False,
+        max_lag=2,
+    )
+
+    return res_legacy
+
+
+@pytest.mark.parametrize(
+    "ar",
+    (
+        [1, 0.5, 0.3],
+        [1, 0.5, 0.3, 0.3, 0.7],
+        [0.9, 1, 0.2, -0.1],
+    ),
+)
+def test_prototype_train_gv(ar):
+    time_history = np.arange(1850, 2014 + 1)
+    time_scenario = np.arange(2015, 2100 + 1)
+    time = np.concatenate((time_history, time_scenario))
+
+    magnitude = np.array([0.1])
+
+    scenarios = ["hist", "ssp126"]
+
+    targ_dims = ["scenario", "ensemble_member", "time"]
+    targ_coords = dict(
+        time=time,
+        scenario=scenarios,
+        ensemble_member=["r1i1p1f1", "r2i1p1f1"],
+    )
+
+    def _get_history_sample():
+        return np.concatenate(
+            [
+                ArmaProcess(ar, magnitude).generate_sample(nsample=time_history.size),
+                np.full(time_scenario.size, np.nan),
+            ]
+        )
+
+    def _get_scenario_sample():
+        return np.concatenate(
+            [
+                np.full(time_history.size, np.nan),
+                ArmaProcess(ar, magnitude).generate_sample(nsample=len(time_scenario)),
+            ]
+        )
+
+    data = np.array(
+        [
+            [
+                _get_history_sample(),
+                _get_history_sample(),
+            ],
+            [
+                _get_scenario_sample(),
+                _get_scenario_sample(),
+            ],
+        ],
+    )
+
+    esm_tas_global_variability = xr.DataArray(
+        data,
+        dims=targ_dims,
+        coords=targ_coords,
+    )
+
+    res_legacy = _do_legacy_run_train_gv(
+        esm_tas_global_variability,
+        cfg=_MockConfig(),
+    )
+
+    res_updated = (
+        calibrate_auto_regressive_process_multiple_scenarios_and_ensemble_members(
+            esm_tas_global_variability,
+            maxlag=2,
+        )
+    )
+
+    for key, comparison in (
+        ("intercept", res_legacy["AR_int"]),
+        ("lag_coefficients", res_legacy["AR_coefs"]),
+        ("standard_innovations", res_legacy["AR_std_innovs"]),
+    ):
+        np.testing.assert_allclose(res_updated[key], comparison)
+
+
+def _do_legacy_run_train_lv(
+    esm_tas_residual_local_variability,
+    localisation_radii,
+    cfg,
+):
+    targs_legacy = {"tas": {}}
+    for scenario, vals_scen in esm_tas_residual_local_variability.groupby("scenario"):
+        targs_legacy["tas"][scenario] = (
+            vals_scen.T.dropna(dim="time")
+            .transpose("ensemble_member", "time", "gridpoint")
+            .values
+        )
+
+    geodist = mesmer.geospatial.geodist_exact(
+        esm_tas_residual_local_variability.lon, esm_tas_residual_local_variability.lat
+    )
+    gaspari_cohn_correlation_matrices = mesmer.stats.gaspari_cohn_correlation_matrices(
+        geodist, localisation_radii
+    )
+
+    gaspari_cohn_correlation_matrices = {
+        k: v.values for k, v in gaspari_cohn_correlation_matrices.items()
+    }
+    aux = {"phi_gc": gaspari_cohn_correlation_matrices}
+
+    res_legacy = train_lv(
+        preds={},
+        targs=targs_legacy,
+        esm="test",
+        cfg=cfg,
+        save_params=False,
+        aux=aux,
+        params_lv={},  # unclear why this is passed in
+    )
+
+    return res_legacy
+
+
+# how train_lv works:
+# 1. AR1 process for each individual gridpoint (use calibrate_auto_regressive_process_multiple_scenarios_and_ensemble_members)
+# 2. find localised empirical covariance matrix
+# 3. combine AR1 and localised empirical covariance matrix to get localised empirical covariance matrix
+# of innovations (i.e. errors) which can be used for later draws (again, with a bit of custom code
+# that feels like it should really be done using an existing library)
+
+
+def test_prototype_train_lv():
+    # input is residual local variability we want to reproduce (for Leah, residual
+    # means after removing local trend and local variability due to global variability
+    # but it could be whatever in reality)
+
+    # also input the phi gc stuff (that's part of the calibration but doesn't go in the train
+    # lv function for some reason --> put it in calibrate_local_variability prototype function
+    # although split that function to allow for pre-calculating distance between points in future
+    # as a performance boost)
+
+    localisation_radii = np.arange(700, 2000, 1000)
+
+    # see how much code I can reuse for the AR1 calibration
+    time_history = np.arange(1850, 2014 + 1)
+    time_scenario = np.arange(2015, 2100 + 1)
+    time = np.concatenate((time_history, time_scenario))
+    scenarios = ["hist", "ssp126"]
+
+    # we wouldn't actually start like this, but we'd write a utils function
+    # to simply go from lat-lon to gridpoint and back
+    targ_dims = ["scenario", "ensemble_member", "gridpoint", "time"]
+    targ_coords = dict(
+        time=time,
+        scenario=scenarios,
+        ensemble_member=["r1i1p1f1", "r2i1p1f1"],
+        gridpoint=[0, 1],
+        lat=(["gridpoint"], [-60, 60]),
+        lon=(["gridpoint"], [120, 240]),
+    )
+
+    ar = np.array([1])
+    magnitude = np.array([0.05])
+
+    def _get_history_sample():
+        return np.concatenate(
+            [
+                ArmaProcess(ar, magnitude).generate_sample(nsample=time_history.size),
+                np.full(time_scenario.size, np.nan),
+            ]
+        )
+
+    def _get_scenario_sample():
+        return np.concatenate(
+            [
+                np.full(time_history.size, np.nan),
+                ArmaProcess(ar, magnitude).generate_sample(nsample=len(time_scenario)),
+            ]
+        )
+
+    esm_tas_residual_local_variability = xr.DataArray(
+        np.array(
+            [
+                [
+                    [
+                        _get_history_sample(),
+                        _get_history_sample(),
+                    ],
+                    [
+                        _get_history_sample(),
+                        _get_history_sample(),
+                    ],
+                ],
+                [
+                    [
+                        _get_scenario_sample(),
+                        _get_scenario_sample(),
+                    ],
+                    [
+                        _get_scenario_sample(),
+                        _get_scenario_sample(),
+                    ],
+                ],
+            ]
+        ),
+        dims=targ_dims,
+        coords=targ_coords,
+    )
+
+    res_legacy = _do_legacy_run_train_lv(
+        esm_tas_residual_local_variability,
+        localisation_radii,
+        cfg=_MockConfig(),
+    )
+
+    res_updated = calibrate_auto_regressive_process_with_spatially_correlated_errors_multiple_scenarios_and_ensemble_members(
+        esm_tas_residual_local_variability,
+        localisation_radii,
+    )
+
+    # check localised_empirical_covariance_matrix_with_ar1_errors
+    np.testing.assert_allclose(res_updated, res_legacy["loc_ecov_AR1_innovs"]["tas"])
+
+
+# things that aren't tested well:
+# - what happens if ensemble member and scenario don't actually make a coherent set
+# - units (should probably be using dataset rather than dataarray for inputs and outputs?)
+# - weights