Skip to content

Commit

Permalink
various cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
raehik committed Nov 30, 2023
1 parent 7edc316 commit 998cf51
Show file tree
Hide file tree
Showing 7 changed files with 27 additions and 73 deletions.
11 changes: 6 additions & 5 deletions src/gz21_ocean_momentum/cli/data.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import gz21_ocean_momentum.step.data.lib as lib
import gz21_ocean_momentum.lib.data as lib
import gz21_ocean_momentum.common.cli as cli
from gz21_ocean_momentum.common.bounding_box import BoundingBox
import gz21_ocean_momentum.common.bounding_box as bounding_box
Expand Down Expand Up @@ -60,14 +60,15 @@
logger.debug("dropping irrelevant data variables...")
surface_fields = surface_fields[["usurf", "vsurf"]]

if options.ntimes is not None:
logger.info(f"slicing {options.ntimes} time points...")
surface_fields = surface_fields.isel(time=slice(options.ntimes))

logger.info("selecting input data bounding box...")
surface_fields = bounding_box.bound_dataset("yu_ocean", "xu_ocean", surface_fields, bbox)
grid = bounding_box.bound_dataset("yu_ocean", "xu_ocean", grid, bbox)

# TODO 2023-11-29 raehik: original bounded first, sliced (immediately) after
if options.ntimes is not None:
logger.info(f"slicing {options.ntimes} time points...")
surface_fields = surface_fields.isel(time=slice(options.ntimes))

logger.debug("placing grid dataset into local memory...")
grid = grid.compute()

Expand Down
31 changes: 8 additions & 23 deletions src/gz21_ocean_momentum/cli/infer.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,31 @@
import configargparse

import gz21_ocean_momentum.common.cli as cli
import logging
from dask.diagnostics import ProgressBar
from gz21_ocean_momentum.utils import TaskInfo

import gz21_ocean_momentum.lib.model as lib
from gz21_ocean_momentum.data.datasets import (
pytorch_dataset_from_cm2_6_forcing_dataset,
#DatasetPartitioner,
DatasetTransformer,
DatasetWithTransform,
ComposeTransforms,
)

import xarray as xr

import torch
from torch.utils.data import DataLoader

# TODO hardcode submodel, transformation, NN loss function
# it is unlikely that a CLI needs to provide dynamic code loading -- let's just give
# options
# we could enable such "dynamic loading" in the "library" interface!-- but, due
# to the class-based setup, it's a little complicated for a user to come in with
# their own code for some of these, and it needs documentation. so a task for
# later
import gz21_ocean_momentum.models.models1 as model
import gz21_ocean_momentum.models.submodels as submodels
import gz21_ocean_momentum.models.transforms as transforms
import gz21_ocean_momentum.train.losses as loss_funcs
from gz21_ocean_momentum.inference.utils import predict_lazy_cm2_6
#from gz21_ocean_momentum.train.base import Trainer

submodel = submodels.transform3

DESCRIPTION = """
_cli_desc = """
Use a trained GZ21 neural net to predict forcing for input ocean velocity data.
This script is intended as example of how use the GZ21 neural net, generating
Expand All @@ -54,21 +46,21 @@
into your GCM of choice.
"""

p = configargparse.ArgParser(description=DESCRIPTION)
p = configargparse.ArgParser(description=_cli_desc)
p.add("--config-file", is_config_file=True, help="config file path")

p.add("--input-data-dir", type=str, required=True, help="path to input ocean velocity data, in zarr format (folder)")
p.add("--model-state-dict-file", type=str, required=True, help="model state dict file (*.pth)")
p.add("--out-dir", type=str, required=True, help="folder to save forcing predictions dataset to (in zarr format)")

p.add("--device", type=str, default="cuda", help="neural net device (e.g. cuda, cuda:0, cpu)")
p.add("--splits", type=int)

options = p.parse_args()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

cli.fail_if_path_is_nonempty_dir(
1, f"--out-dir \"{options.out_dir}\" invalid", options.out_dir)

# ---

logger.info("loading input (coarse) ocean momentum data...")
Expand All @@ -77,14 +69,7 @@
with ProgressBar(), TaskInfo("Applying transforms to dataset"):
ds_computed_xr = submodel.fit_transform(ds_computed_xr)

ds_computed_torch = pytorch_dataset_from_cm2_6_forcing_dataset(ds_computed_xr)

logger.info("performing various dataset transforms...")
features_transform_ = ComposeTransforms()
targets_transform_ = ComposeTransforms()
transform = DatasetTransformer(features_transform_, targets_transform_)
dataset = DatasetWithTransform(ds_computed_torch, transform)

dataset = lib.gz21_train_data_subdomain_xr_to_torch(ds_computed_xr)
loader = DataLoader(dataset)

criterion = loss_funcs.HeteroskedasticGaussianLossV2(dataset.n_targets)
Expand Down
2 changes: 1 addition & 1 deletion src/gz21_ocean_momentum/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import gz21_ocean_momentum.common.cli as cli
import gz21_ocean_momentum.common.assorted as common
import gz21_ocean_momentum.common.bounding_box as bounding_box
import gz21_ocean_momentum.unsorted.train_data_xr_to_pytorch as lib
import gz21_ocean_momentum.lib.model as lib
import gz21_ocean_momentum.models.submodels as submodels
import gz21_ocean_momentum.models.transforms as transforms
import gz21_ocean_momentum.models.models1 as model
Expand Down
19 changes: 2 additions & 17 deletions src/gz21_ocean_momentum/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,6 @@ def transform(self, x):
return np.concatenate((left, x, right), axis=self.axis)

def transform_coordinate(self, coords, dim):
print(f"{dim}, {self.dim_name}")
if dim == self.dim_name:
left = coords[-self.nb_points :] - self.length
right = coords[: self.nb_points] + self.length
Expand Down Expand Up @@ -697,8 +696,6 @@ def __len__(self):
Number of samples of the dataset.
"""
print("xrrawdataset len called")
print(len(self.xr_dataset[self._index]))
try:
return len(self.xr_dataset[self._index])
except KeyError as e:
Expand All @@ -715,19 +712,6 @@ def __getattr__(self, attr_name):
return getattr(self.xr_dataset, attr_name)
raise AttributeError()

def pytorch_dataset_from_cm2_6_forcing_dataset(ds: xr.Dataset) -> torch.Dataset:
"""Obtain a PyTorch `Dataset` view over an xarray dataset, specifically for
CM2.6 data annotated with forcings in `S_x` and `S_y`.
The same snippet is used for both training and inference."""
ds_torch = RawDataFromXrDataset(ds)
ds_torch.index = "time"
ds_torch.add_input("usurf")
ds_torch.add_input("vsurf")
ds_torch.add_output("S_x")
ds_torch.add_output("S_y")
return ds_torch

class DatasetWithTransform:
def __init__(self, dataset, transform: DatasetTransformer):
self.dataset = dataset
Expand Down Expand Up @@ -797,7 +781,6 @@ def __getattr__(self, attr):
raise AttributeError()

def __len__(self):
print("len on datasetwithtransform")
return len(self.dataset)

def add_transforms_from_model(self, model):
Expand Down Expand Up @@ -914,6 +897,8 @@ class ConcatDataset_(ConcatDataset):
- enforces the concatenated dataset to have the same shapes
- passes on attributes (from the first dataset, assuming they are
equal across concatenated datasets)
TODO input datasets need to have .height, .width
"""

def __init__(self, datasets):
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
# Common functions relating to neural net model, training data.

import xarray as xr
import numpy as np
import torch.utils.data as torch

from gz21_ocean_momentum.common.assorted import at_idx_pct

from gz21_ocean_momentum.data.datasets import (
DatasetWithTransform,
DatasetTransformer,
Expand All @@ -11,7 +15,11 @@
ComposeTransforms,
)

def cm26_xarray_to_torch(ds_xr: xr.Dataset):
def cm26_xarray_to_torch(ds_xr: xr.Dataset) -> torch.Dataset:
"""
Obtain a PyTorch `Dataset` view over an xarray dataset, specifically for
CM2.6 ocean velocity data annotated with forcings in `S_x` and `S_y`.
"""
ds_torch = RawDataFromXrDataset(ds_xr)
ds_torch.index = "time"
ds_torch.add_input("usurf")
Expand All @@ -20,7 +28,7 @@ def cm26_xarray_to_torch(ds_xr: xr.Dataset):
ds_torch.add_output("S_y")
return ds_torch

def gz21_train_data_subdomain_xr_to_torch(ds_xr: xr.Dataset):
def gz21_train_data_subdomain_xr_to_torch(ds_xr: xr.Dataset) -> torch.Dataset:
"""
Convert GZ21 training data (coarsened CM2.6 data with diagnosed forcings)
into a PyTorch dataset.
Expand All @@ -39,22 +47,6 @@ def gz21_train_data_subdomain_xr_to_torch(ds_xr: xr.Dataset):

return ds_torch_with_transform

def at_idx_pct(pct: float, a) -> int:
"""
Obtain the index into the given list-like at the given fraction of its length.
No interpolation is performed: we choose the leftmost closest index i.e. the
result is floored.
e.g. `at_idx_pct(0.5, [0,1,2]) == 1`
Must be able to `len(a)`.
Invariant: `0<=pct<=1`.
Returns a valid index into `a`.
"""
return int(pct * len(a))

def prep_train_test_dataloaders(
dss: list,
pct_train_end: float,
Expand Down
9 changes: 0 additions & 9 deletions src/gz21_ocean_momentum/step/inference/lib.py

This file was deleted.

0 comments on commit 998cf51

Please sign in to comment.