From fa9300c9deac21e45636364e8243c8cab16a2e1b Mon Sep 17 00:00:00 2001
From: Ben Orchard
Date: Tue, 22 Aug 2023 17:33:55 +0100
Subject: [PATCH] continue refactoring data step

---
 src-new/forcing_computation.py              |  79 -------
 src-new/step_data.py                        |  18 --
 .../new/common/bounding_box.py              |  10 +
 src/gz21_ocean_momentum/new/data/cli.py     |  39 ++++
 src/gz21_ocean_momentum/new/data/coarsen.py | 198 ++++++++++++++++++
 src/gz21_ocean_momentum/new/data/step.py    |  75 +++++++
 src/gz21_ocean_momentum/new/data/utils.py   |  41 ++++
 7 files changed, 363 insertions(+), 97 deletions(-)
 delete mode 100644 src-new/forcing_computation.py
 delete mode 100644 src-new/step_data.py
 create mode 100644 src/gz21_ocean_momentum/new/common/bounding_box.py
 create mode 100644 src/gz21_ocean_momentum/new/data/cli.py
 create mode 100644 src/gz21_ocean_momentum/new/data/coarsen.py
 create mode 100644 src/gz21_ocean_momentum/new/data/step.py
 create mode 100644 src/gz21_ocean_momentum/new/data/utils.py

diff --git a/src-new/forcing_computation.py b/src-new/forcing_computation.py
deleted file mode 100644
index 97097c33..00000000
--- a/src-new/forcing_computation.py
+++ /dev/null
@@ -1,79 +0,0 @@
-from dataclasses import dataclass
-from typing import assert_never
-from typing import Optional
-from typing import Literal
-from typing import Tuple
-import enum
-from enum import Enum
-
-import xarray as xr
-
-@dataclass
-class CO2Change(enum.Enum):
-    Control0 = enum.auto()
-    "TODO control"
-
-    AnnualIncrease1 = enum.auto()
-    "TODO annual increase"
-
-@dataclass
-class BoundingBox():
-    lat_min: float
-    lat_max: float
-    long_min: float
-    long_max: float
-
-def preprocess(
-        grid,
-        surface_fields,
-        bounding_box: Optional[BoundingBox],
-        ntimes: Optional[int],
-        cyclize: bool,
-        factor: int,
-        *selected_vars: str,
-        ) -> Tuple[xr.Dataset, xr.Dataset]:
-    """
-    Perform various preprocessing on a dataset.
-    """
-
-    # transform non-primary coords into vars
-    grid = grid.reset_coords()[["dxu", "dyu", "wet"]]
-
-    if bounding_box is not None:
-        surface_fields = surface_fields.sel(
-            xu_ocean=slice(bounding_box.lat_min, bounding_box.lat_max, None),
-            yu_ocean=slice(bounding_box.long_min, bounding_box.long_max, None))
-        grid = grid.sel(
-            xu_ocean=slice(bounding_box.lat_min, bounding_box.lat_max, None),
-            yu_ocean=slice(bounding_box.long_min, bounding_box.long_max, None))
-
-    if ntimes is not None:
-        surface_fields = surface_fields.isel(time=slice(0, ntimes))
-
-    if len(selected_vars) != 0:
-        surface_fields = surface_fields[list(selected_vars)]
-
-    if cyclize:
-        logger.info("Cyclic data... Making the dataset cyclic along longitude...")
-        surface_fields = cyclize_dataset(surface_fields, "xu_ocean", factor)
-        grid = cyclize_dataset(grid, "xu_ocean", factor)
-
-        # rechunk along the cyclized dimension
-        surface_fields = surface_fields.chunk({"xu_ocean": -1})
-        grid = grid.chunk({"xu_ocean": -1})
-
-    # TODO should this be earlier? later? never? ???
-    logger.debug("Getting grid data locally")
-    # grid data is saved locally, no need for dask
-    grid_data = grid_data.compute()
-
-    return surface_fields, grid
-
-def prepare_cmip(resolution_degrading_factor, make_cyclic):
-    return 0
-
-def compute_forcing(resolution_degrading_factor):
-    """
-    Returns an xarray. (TODO dataset or dataarray?)
-    """
-    return 0
diff --git a/src-new/step_data.py b/src-new/step_data.py
deleted file mode 100644
index 28be6e69..00000000
--- a/src-new/step_data.py
+++ /dev/null
@@ -1,18 +0,0 @@
-def run_data_step_cm2_6(
-        catalog_url,
-        bounding_box: Optional[BoundingBox],
-        ntimes: Optional[int],
-        cyclize: bool,
-        factor: int
-        ) -> xr.Dataset:
-    """Run data step on CM2.6 dataset."""
-    catalog = intake.open_catalog(catalog_url)
-    grid = catalog.ocean.GFDL_CM2_6.GFDL_CM2_6_control_ocean_surface
-    grid = grid.to_dask()
-    if co2_increase:
-        surface_fields = catalog.ocean.GFDL_CM2_6.GFDL_CM2_6_control_ocean_surface
-    else:
-        surface_fields = catalog.ocean.GFDL_CM2_6.GFDL_CM2_6_one_percent_ocean_surface
-    surface_fields = surface_fields.to_dask()
-    preprocess(grid, surface_fields, bounding_box, ntimes, cyclize, factor,
-               "usurf", "vsurf")
diff --git a/src/gz21_ocean_momentum/new/common/bounding_box.py b/src/gz21_ocean_momentum/new/common/bounding_box.py
new file mode 100644
index 00000000..445273c5
--- /dev/null
+++ b/src/gz21_ocean_momentum/new/common/bounding_box.py
@@ -0,0 +1,10 @@
+from dataclasses import dataclass
+
+@dataclass
+class BoundingBox():
+    """A geographical bounding box, given by latitude and longitude extents."""
+    lat_min: float
+    lat_max: float
+    long_min: float
+    long_max: float
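+
+# Usage sketch (hypothetical coordinate values, not from the codebase): a box
+# over part of the North Atlantic, as consumed by the data step's .sel() calls:
+#
+#   bbox = BoundingBox(lat_min=35.0, lat_max=50.0, long_min=-50.0, long_max=-20.0)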
- """ - return 0 diff --git a/src-new/step_data.py b/src-new/step_data.py deleted file mode 100644 index 28be6e69..00000000 --- a/src-new/step_data.py +++ /dev/null @@ -1,18 +0,0 @@ -def run_data_step_cm2_6( - catalog_url, - bounding_box: Optional[BoundingBox], - ntimes: Optional[int], - cyclize: bool, - factor: int - ) -> xr.Dataset: - """Run data step on CM2.6 dataset.""" - catalog = intake.open_catalog(catalog_url) - grid = catalog.ocean.GFDL_CM2_6.GFDL_CM2_6_control_ocean_surface - grid = grid.to_dask() - if co2_increase: - surface_fields = catalog.ocean.GFDL_CM2_6.GFDL_CM2_6_control_ocean_surface - else: - surface_fields = catalog.ocean.GFDL_CM2_6.GFDL_CM2_6_one_percent_ocean_surface - surface_fields = surface_fields.to_dask() - preprocess(grid, surface_fields, bounding_box, ntimes, cyclize, factor, - "usurf", "vsurf") diff --git a/src/gz21_ocean_momentum/new/common/bounding_box.py b/src/gz21_ocean_momentum/new/common/bounding_box.py new file mode 100644 index 00000000..445273c5 --- /dev/null +++ b/src/gz21_ocean_momentum/new/common/bounding_box.py @@ -0,0 +1,10 @@ +from dataclasses import dataclass +from typing import Optional +from typing import Tuple + +@dataclass +class BoundingBox(): + lat_min: float + lat_max: float + long_min: float + long_max: float diff --git a/src/gz21_ocean_momentum/new/data/cli.py b/src/gz21_ocean_momentum/new/data/cli.py new file mode 100644 index 00000000..02324dab --- /dev/null +++ b/src/gz21_ocean_momentum/new/data/cli.py @@ -0,0 +1,39 @@ +import gz21_ocean_momentum.new.data.step as step +from gz21_ocean_momentum.new.common.bounding_box import BoundingBox + +import configargparse + +DESCRIPTION = "Read data from the CM2.6 and \ + apply coarse graining. Stores the resulting dataset into an MLFLOW \ + experiment within a specific run." + +p = configargparse.ArgParser() +p.add("--config-file", is_config_file=True, help="config file path") +p.add("--out-dir", type=str, required=True) +p.add("--lat-min", type=float, required=True) +p.add("--lat-max", type=float, required=True) +p.add("--long-min", type=float, required=True) +p.add("--long-max", type=float, required=True) +p.add("--cyclize", action="store_true", help="global data; make cyclic along longitude") +p.add("--ntimes", type=int, required=True, help="number of days (TODO)") +p.add("--co2-increase", action="store_true", help="use 1%% annual CO2 increase CM2.6 dataset. 
diff --git a/src/gz21_ocean_momentum/new/data/coarsen.py b/src/gz21_ocean_momentum/new/data/coarsen.py
new file mode 100644
index 00000000..8c078c75
--- /dev/null
+++ b/src/gz21_ocean_momentum/new/data/coarsen.py
@@ -0,0 +1,198 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""Routines for coarsening a dataset."""
+
+import logging
+import xarray as xr
+from scipy.ndimage import gaussian_filter
+import numpy as np
+
+logger = logging.getLogger(__name__)
+
+def eddy_forcing(
+    u_v_dataset: xr.Dataset,
+    grid_data: xr.Dataset,
+    scale: int,
+    nan_or_zero: str = "zero",
+) -> xr.Dataset:
+    """
+    Compute the sub-grid forcing terms using mean coarse-graining.
+
+    Parameters
+    ----------
+    u_v_dataset : xarray Dataset
+        High-resolution velocity field.
+    grid_data : xarray Dataset
+        High-resolution grid details.
+    scale : int
+        Resolution degradation factor.
+    nan_or_zero : str, optional
+        Either 'nan' or 'zero'. Determines whether we keep the nan values in
+        the initial surface velocities array or replace them by zeros before
+        applying the procedure. In the latter case, remaining zeros after
+        applying the procedure are replaced by nans for consistency.
+        The default is 'zero'.
+
+    Returns
+    -------
+    forcing_coarse : xarray Dataset
+        Dataset containing the low-resolution velocity field and forcing.
+    """
+    # Replace nan values with zeros.
+    if nan_or_zero == "zero":
+        u_v_dataset = u_v_dataset.fillna(0.0)
+
+    # Interpolate temperature
+    # interp_coords = dict(xt_ocean=u_v_dataset.coords['xu_ocean'],
+    #                      yt_ocean=u_v_dataset.coords['yu_ocean'])
+    # u_v_dataset['temp'] = u_v_dataset['surface_temperature'].interp(
+    #     interp_coords)
+
+    scale_filter = scale / 2
+    # High-resolution advection terms
+    adv = advections(u_v_dataset, grid_data)
+    # Filtered advections
+    filtered_adv = spatial_filter_dataset(adv, grid_data, scale_filter)
+    # Filtered u,v field
+    u_v_filtered = spatial_filter_dataset(u_v_dataset, grid_data, scale_filter)
+    # Advection term computed from the filtered velocity field
+    adv_filtered = advections(u_v_filtered, grid_data)
+    # Forcing
+    forcing = adv_filtered - filtered_adv
+    forcing = forcing.rename({"adv_x": "S_x", "adv_y": "S_y"})
+    # Merge filtered u,v and forcing terms
+    forcing = forcing.merge(u_v_filtered)
+    logger.debug(forcing)
+
+    # Coarsen by the full degradation factor (the filtering above used scale / 2)
+    forcing_coarse = forcing.coarsen(
+        {"xu_ocean": int(scale), "yu_ocean": int(scale)}, boundary="trim"
+    )
+    forcing_coarse = forcing_coarse.mean()
+
+    if nan_or_zero == "zero":
+        # Replace zeros with nans for consistency
+        forcing_coarse = forcing_coarse.where(forcing_coarse["usurf"] != 0)
+    return forcing_coarse
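+
+# Sketch of the intended call pattern (names assumed from step.py; not a
+# definitive API). grid here is assumed to already expose "dxu" and "dyu" as
+# variables (step.py applies reset_coords() before calling in):
+#
+#   surface_fields, grid = step.download_cm2_6(CATALOG_URL, co2_increase=False)
+#   forcing = eddy_forcing(surface_fields[["usurf", "vsurf"]], grid, scale=4)
+#
+# The result holds the coarsened velocities ("usurf", "vsurf") and the
+# sub-grid forcing components ("S_x", "S_y").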
+
+def advections(u_v_field: xr.Dataset, grid_data: xr.Dataset):
+    """
+    Compute advection terms corresponding to the passed velocity field.
+
+    Parameters
+    ----------
+    u_v_field : xarray Dataset
+        Velocity field, must contain variables "usurf" and "vsurf".
+    grid_data : xarray Dataset
+        Grid data, must contain variables "dxu" and "dyu".
+
+    Returns
+    -------
+    result : xarray Dataset
+        Advection components, under variable names "adv_x" and "adv_y".
+    """
+    dxu = grid_data["dxu"]
+    dyu = grid_data["dyu"]
+    gradient_x = u_v_field.diff(dim="xu_ocean") / dxu
+    gradient_y = u_v_field.diff(dim="yu_ocean") / dyu
+    # Interpolate the gradients back onto the original coordinates
+    interp_coords = {
+        "xu_ocean": u_v_field.coords["xu_ocean"],
+        "yu_ocean": u_v_field.coords["yu_ocean"],
+    }
+    gradient_x = gradient_x.interp(interp_coords)
+    gradient_y = gradient_y.interp(interp_coords)
+    u, v = u_v_field["usurf"], u_v_field["vsurf"]
+    adv_x = u * gradient_x["usurf"] + v * gradient_y["usurf"]
+    adv_y = u * gradient_x["vsurf"] + v * gradient_y["vsurf"]
+    result = xr.Dataset({"adv_x": adv_x, "adv_y": adv_y})
+    # TODO check if we can simply prevent the previous operation from adding
+    # chunks
+    # result = result.chunk(dict(xu_ocean=-1, yu_ocean=-1))
+    return result
+
+def spatial_filter_dataset(
+        dataset: xr.Dataset, grid_info: xr.Dataset, sigma: float
+    ) -> xr.Dataset:
+    """
+    Apply spatial filtering to the dataset across the spatial dimensions.
+
+    Parameters
+    ----------
+    dataset : xarray Dataset
+        Dataset to filter. First dimension must be time, followed by the
+        spatial dimensions.
+    grid_info : xarray Dataset
+        Grid data, must include variables "dxu" and "dyu".
+    sigma : float
+        Scale of the filtering, same unit as those of the grid (often meters).
+
+    Returns
+    -------
+    filt_dataset : xarray Dataset
+        Filtered dataset.
+    """
+    area_u = grid_info["dxu"] * grid_info["dyu"] / 1e8
+    # area-weight the data before filtering
+    dataset = dataset * area_u
+    # Normalisation term, so that if the quantity we filter is constant
+    # over the domain, the filtered quantity is constant with the same value
+    norm = xr.apply_ufunc(
+        lambda x: gaussian_filter(x, (sigma, sigma), mode="constant"),
+        area_u,
+        dask="parallelized",
+        output_dtypes=[float],
+    )
+    filtered = xr.apply_ufunc(
+        lambda x: spatial_filter(x, sigma),
+        dataset,
+        dask="parallelized",
+        output_dtypes=[float],
+    )
+    return filtered / norm
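+
+# Why the normalisation works (a sketch, not part of the module): filtering is
+# linear, so for a spatially constant field c the numerator is
+# filter(c * area_u) == c * filter(area_u), and dividing by
+# norm == filter(area_u) recovers c, even near the boundary where the
+# zero padding damps numerator and denominator identically:
+#
+#   ones = xr.ones_like(dataset)                        # hypothetical dataset
+#   out = spatial_filter_dataset(ones, grid_info, 2.0)
+#   # out == 1 (up to floating point) wherever norm is non-zero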
+
+def spatial_filter(data: np.ndarray, sigma: float):
+    """
+    Apply a gaussian filter to spatial data.
+
+    Applies the scipy gaussian filter along all dimensions except the first,
+    which corresponds to time.
+
+    Parameters
+    ----------
+    data : ndarray
+        Data to filter.
+    sigma : float
+        Unitless scale of the filter.
+
+    Returns
+    -------
+    result : ndarray
+        Filtered data.
+    """
+    result = np.zeros_like(data)
+    for t in range(data.shape[0]):
+        data_t = data[t, ...]
+        result_t = gaussian_filter(data_t, (sigma, sigma), mode="constant")
+        result[t, ...] = result_t
+    return result
diff --git a/src/gz21_ocean_momentum/new/data/step.py b/src/gz21_ocean_momentum/new/data/step.py
new file mode 100644
index 00000000..a89af1b9
--- /dev/null
+++ b/src/gz21_ocean_momentum/new/data/step.py
@@ -0,0 +1,75 @@
+import gz21_ocean_momentum.new.data.coarsen as coarsen
+import gz21_ocean_momentum.new.data.utils as utils
+from gz21_ocean_momentum.new.common.bounding_box import BoundingBox
+
+import logging
+import xarray as xr
+import intake
+
+from typing import Optional
+from typing import Tuple
+
+logger = logging.getLogger(__name__)
+
+def preprocess_and_compute_forcings(
+        grid: xr.Dataset,
+        surface_fields: xr.Dataset,
+        bounding_box: Optional[BoundingBox],
+        ntimes: Optional[int],
+        cyclize: bool,
+        resolution_degrading_factor: int,
+        *selected_vars: str,
+    ) -> xr.Dataset:
+    """
+    Preprocess a dataset and compute its sub-grid forcings.
+
+    Restricts the surface fields to the given bounding box and number of time
+    steps, optionally makes the data cyclic along longitude, then computes
+    the eddy forcings via mean coarse-graining.
+    """
+
+    # transform non-primary coords into vars
+    grid = grid.reset_coords()[["dxu", "dyu", "wet"]]
+
+    if bounding_box is not None:
+        # note: xu_ocean is the longitude coordinate, yu_ocean the latitude
+        surface_fields = surface_fields.sel(
+            xu_ocean=slice(bounding_box.long_min, bounding_box.long_max, None),
+            yu_ocean=slice(bounding_box.lat_min, bounding_box.lat_max, None))
+        grid = grid.sel(
+            xu_ocean=slice(bounding_box.long_min, bounding_box.long_max, None),
+            yu_ocean=slice(bounding_box.lat_min, bounding_box.lat_max, None))
+
+    if ntimes is not None:
+        surface_fields = surface_fields.isel(time=slice(0, ntimes))
+
+    if len(selected_vars) != 0:
+        surface_fields = surface_fields[list(selected_vars)]
+
+    if cyclize:
+        logger.info("Cyclic data... Making the dataset cyclic along longitude...")
+        surface_fields = utils.cyclize(
+            surface_fields, "xu_ocean", resolution_degrading_factor)
+        grid = utils.cyclize(
+            grid, "xu_ocean", resolution_degrading_factor)
+
+        # rechunk along the cyclized dimension
+        surface_fields = surface_fields.chunk({"xu_ocean": -1})
+        grid = grid.chunk({"xu_ocean": -1})
+
+    # TODO should this be earlier? later? never?
+    logger.debug("Getting grid data locally")
+    # grid data is saved locally, no need for dask
+    grid = grid.compute()
+
+    # calculate eddy-forcing dataset for that particular patch
+    return coarsen.eddy_forcing(surface_fields, grid, resolution_degrading_factor)
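+
+# Example call (hypothetical values): restrict to a patch, keep 100 time
+# steps, no cyclizing, coarsen by a factor of 4:
+#
+#   bbox = BoundingBox(35.0, 50.0, -50.0, -20.0)
+#   forcings = preprocess_and_compute_forcings(
+#       grid, surface_fields, bbox, 100, False, 4, "usurf", "vsurf")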
+
+def download_cm2_6(
+        catalog_url: str,
+        co2_increase: bool,
+    ) -> Tuple[xr.Dataset, xr.Dataset]:
+    """Retrieve the CM2.6 surface fields and grid data via the given intake catalog."""
+    catalog = intake.open_catalog(catalog_url)
+    # grid spacings (dxu, dyu) live in the grid dataset, not the surface fields
+    grid = catalog.ocean.GFDL_CM2_6.GFDL_CM2_6_grid
+    grid = grid.to_dask()
+    if co2_increase:
+        surface_fields = catalog.ocean.GFDL_CM2_6.GFDL_CM2_6_one_percent_ocean_surface
+    else:
+        surface_fields = catalog.ocean.GFDL_CM2_6.GFDL_CM2_6_control_ocean_surface
+    surface_fields = surface_fields.to_dask()
+    return surface_fields, grid
diff --git a/src/gz21_ocean_momentum/new/data/utils.py b/src/gz21_ocean_momentum/new/data/utils.py
new file mode 100644
index 00000000..09c77c98
--- /dev/null
+++ b/src/gz21_ocean_momentum/new/data/utils.py
@@ -0,0 +1,41 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""Utilities for handling data."""
+import xarray as xr
+
+def cyclize(ds: xr.Dataset, coord_name: str, nb_points: int):
+    """
+    Generate a cyclic dataset from non-cyclic input.
+
+    Return a cyclic dataset, with nb_points added on each end, along
+    the coordinate specified by coord_name.
+
+    Parameters
+    ----------
+    ds : xr.Dataset
+        Dataset to process.
+    coord_name : str
+        Name of the coordinate along which the data is made cyclic.
+    nb_points : int
+        Number of points added on each end.
+
+    Returns
+    -------
+    New extended dataset.
+    """
+    # TODO make this flexible
+    cycle_length = 360.0
+    left = ds.roll({coord_name: nb_points}, roll_coords=True)
+    # the rolled dataset also provides the right-hand extension
+    right = left.isel({coord_name: slice(0, 2 * nb_points)})
+    left[coord_name] = xr.concat(
+        (left[coord_name][:nb_points] - cycle_length, left[coord_name][nb_points:]),
+        coord_name,
+    )
+    right[coord_name] = xr.concat(
+        (right[coord_name][:nb_points], right[coord_name][nb_points:] + cycle_length),
+        coord_name,
+    )
+    new_ds = xr.concat((left, right), coord_name)
+    return new_ds
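+
+# Behaviour sketch (assuming a 1-degree longitude coordinate [0, 1, ..., 359]
+# and nb_points=2): the returned coordinate runs [-2, -1, 0, ..., 359, 360,
+# 361], where the added points carry data copied from the opposite end of the
+# domain.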